diff --git a/python/pyspark/ml/image.pyi b/.github/workflows/ansi_sql_mode_test.yml similarity index 56% rename from python/pyspark/ml/image.pyi rename to .github/workflows/ansi_sql_mode_test.yml index 206490aaa82d5..e68b04b5420f0 100644 --- a/python/pyspark/ml/image.pyi +++ b/.github/workflows/ansi_sql_mode_test.yml @@ -15,26 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# -from typing import Dict, List +name: ANSI SQL mode test -from pyspark.sql.types import Row, StructType +on: + push: + branches: + - master -from numpy import ndarray +jobs: + ansi_sql_test: + uses: ./.github/workflows/build_and_test.yml + if: github.repository == 'apache/spark' + with: + ansi_enabled: true -class _ImageSchema: - def __init__(self) -> None: ... - @property - def imageSchema(self) -> StructType: ... - @property - def ocvTypes(self) -> Dict[str, int]: ... - @property - def columnSchema(self) -> StructType: ... - @property - def imageFields(self) -> List[str]: ... - @property - def undefinedImageType(self) -> str: ... - def toNDArray(self, image: Row) -> ndarray: ... - def toImage(self, array: ndarray, origin: str = ...) -> Row: ... -ImageSchema: _ImageSchema diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 32f46d35c5b3e..ebe17b5963f20 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -37,6 +37,12 @@ on: - cron: '0 13 * * *' # Java 17 - cron: '0 16 * * *' + workflow_call: + inputs: + ansi_enabled: + required: false + type: boolean + default: false jobs: configure-jobs: @@ -92,7 +98,7 @@ jobs: echo '::set-output name=java::8' echo '::set-output name=branch::master' # Default branch to run on. CHANGE here when a branch is cut out. echo '::set-output name=type::regular' - echo '::set-output name=envs::{}' + echo '::set-output name=envs::{"SPARK_ANSI_SQL_MODE": "${{ inputs.ansi_enabled }}"}' echo '::set-output name=hadoop::hadoop3' fi @@ -252,7 +258,7 @@ jobs: - name: Install Python packages (Python 3.8) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) run: | - python3.8 -m pip install 'numpy>=1.20.0' 'pyarrow<5.0.0' pandas scipy xmlrunner + python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy xmlrunner python3.8 -m pip list # Run the tests. 
- name: Run tests @@ -287,7 +293,7 @@ jobs: name: "Build modules (${{ format('{0}, {1} job', needs.configure-jobs.outputs.branch, needs.configure-jobs.outputs.type) }}): ${{ matrix.modules }}" runs-on: ubuntu-20.04 container: - image: dongjoon/apache-spark-github-action-image:20211228 + image: dongjoon/apache-spark-github-action-image:20220207 strategy: fail-fast: false matrix: @@ -311,6 +317,7 @@ jobs: SKIP_UNIDOC: true SKIP_MIMA: true METASPACE_SIZE: 1g + SPARK_ANSI_SQL_MODE: ${{ inputs.ansi_enabled }} steps: - name: Checkout Spark repository uses: actions/checkout@v2 @@ -391,13 +398,14 @@ jobs: name: "Build modules: sparkr" runs-on: ubuntu-20.04 container: - image: dongjoon/apache-spark-github-action-image:20211228 + image: dongjoon/apache-spark-github-action-image:20220207 env: HADOOP_PROFILE: ${{ needs.configure-jobs.outputs.hadoop }} HIVE_PROFILE: hive2.3 GITHUB_PREV_SHA: ${{ github.event.before }} SPARK_LOCAL_IP: localhost SKIP_MIMA: true + SPARK_ANSI_SQL_MODE: ${{ inputs.ansi_enabled }} steps: - name: Checkout Spark repository uses: actions/checkout@v2 @@ -462,7 +470,7 @@ jobs: PYSPARK_DRIVER_PYTHON: python3.9 PYSPARK_PYTHON: python3.9 container: - image: dongjoon/apache-spark-github-action-image:20211228 + image: dongjoon/apache-spark-github-action-image:20220207 steps: - name: Checkout Spark repository uses: actions/checkout@v2 @@ -529,11 +537,14 @@ jobs: # See also https://github.com/sphinx-doc/sphinx/issues/7551. # Jinja2 3.0.0+ causes error when building with Sphinx. # See also https://issues.apache.org/jira/browse/SPARK-35375. - python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme ipython nbsphinx numpydoc 'jinja2<3.0.0' - python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' 'pyarrow<5.0.0' pandas 'plotly>=4.8' + # Pin the MarkupSafe to 2.0.1 to resolve the CI error. + # See also https://issues.apache.org/jira/browse/SPARK-38279. + python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme ipython nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' + python3.9 -m pip install ipython_genutils # See SPARK-38517 + python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' apt-get update -y apt-get install -y ruby ruby-dev - Rscript -e "install.packages(c('devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2'), repos='https://cloud.r-project.org/')" + Rscript -e "install.packages(c('devtools', 'testthat', 'knitr', 'rmarkdown', 'markdown', 'e1071', 'roxygen2'), repos='https://cloud.r-project.org/')" Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" gem install bundler @@ -614,7 +625,7 @@ jobs: export MAVEN_CLI_OPTS="--no-transfer-progress" export JAVA_VERSION=${{ matrix.java }} # It uses Maven's 'install' intentionally, see https://github.com/apache/spark/pull/26414. 
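The ANSI wiring above works by having the new ansi_sql_mode_test.yml re-invoke build_and_test.yml through workflow_call with ansi_enabled: true, which the called workflow forwards to the test jobs as the SPARK_ANSI_SQL_MODE environment variable. The diff does not show how the test scripts consume that variable, so the following is only a minimal sketch under that assumption (AnsiModeHarness is a hypothetical helper, not Spark's actual test wiring); it maps the env var onto the spark.sql.ansi.enabled conf:

import org.apache.spark.sql.SparkSession;

// Hypothetical helper, not part of this patch: shows how SPARK_ANSI_SQL_MODE could be
// translated into the spark.sql.ansi.enabled SQL conf when a test session is created.
public final class AnsiModeHarness {
  public static SparkSession createTestSession() {
    // "true" only when the reusable workflow was invoked with ansi_enabled: true.
    String ansi = System.getenv().getOrDefault("SPARK_ANSI_SQL_MODE", "false");
    return SparkSession.builder()
        .master("local[2]")
        .appName("ansi-sql-mode-test")
        .config("spark.sql.ansi.enabled", Boolean.toString(Boolean.parseBoolean(ansi)))
        .getOrCreate();
  }
}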
- ./build/mvn $MAVEN_CLI_OPTS -DskipTests -Pyarn -Pmesos -Pkubernetes -Phive -Phive-thriftserver -Phadoop-cloud -Djava.version=${JAVA_VERSION/-ea} install + ./build/mvn $MAVEN_CLI_OPTS -DskipTests -Pyarn -Pmesos -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Djava.version=${JAVA_VERSION/-ea} install rm -rf ~/.m2/repository/org/apache/spark scala-213: @@ -660,7 +671,7 @@ jobs: - name: Build with SBT run: | ./dev/change-scala-version.sh 2.13 - ./build/sbt -Pyarn -Pmesos -Pkubernetes -Phive -Phive-thriftserver -Phadoop-cloud -Pkinesis-asl -Pdocker-integration-tests -Pkubernetes-integration-tests -Pspark-ganglia-lgpl -Pscala-2.13 compile test:compile + ./build/sbt -Pyarn -Pmesos -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Pkinesis-asl -Pdocker-integration-tests -Pkubernetes-integration-tests -Pspark-ganglia-lgpl -Pscala-2.13 compile test:compile tpcds-1g: needs: [configure-jobs, precondition] @@ -669,6 +680,7 @@ jobs: runs-on: ubuntu-20.04 env: SPARK_LOCAL_IP: localhost + SPARK_ANSI_SQL_MODE: ${{ inputs.ansi_enabled }} steps: - name: Checkout Spark repository uses: actions/checkout@v2 diff --git a/.github/workflows/notify_test_workflow.yml b/.github/workflows/notify_test_workflow.yml index bd9147abe1f75..04e7ab8309025 100644 --- a/.github/workflows/notify_test_workflow.yml +++ b/.github/workflows/notify_test_workflow.yml @@ -38,12 +38,19 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | - const endpoint = 'GET /repos/:owner/:repo/commits/:ref/check-runs' + const endpoint = 'GET /repos/:owner/:repo/actions/workflows/:id/runs?&branch=:branch' + const check_run_endpoint = 'GET /repos/:owner/:repo/commits/:ref/check-runs' // TODO: Should use pull_request.user and pull_request.user.repos_url? // If a different person creates a commit to another forked repo, // it wouldn't be able to detect. const params = { + owner: context.payload.pull_request.head.repo.owner.login, + repo: context.payload.pull_request.head.repo.name, + id: 'build_and_test.yml', + branch: context.payload.pull_request.head.ref, + } + const check_run_params = { owner: context.payload.pull_request.head.repo.owner.login, repo: context.payload.pull_request.head.repo.name, ref: context.payload.pull_request.head.ref, @@ -67,7 +74,7 @@ jobs: const head_sha = context.payload.pull_request.head.sha let status = 'queued' - if (!runs || runs.data.check_runs.filter(r => r.name === "Configure jobs").length === 0) { + if (!runs || runs.data.workflow_runs.length === 0) { status = 'completed' const conclusion = 'action_required' @@ -99,16 +106,29 @@ jobs: } }) } else { - const runID = runs.data.check_runs.filter(r => r.name === "Configure jobs")[0].id + const run_id = runs.data.workflow_runs[0].id - if (runs.data.check_runs[0].head_sha != context.payload.pull_request.head.sha) { + if (runs.data.workflow_runs[0].head_sha != context.payload.pull_request.head.sha) { throw new Error('There was a new unsynced commit pushed. Please retrigger the workflow.'); } - const runUrl = 'https://github.com/' + // Here we get check run ID to provide Check run view instead of Actions view, see also SPARK-37879. + const check_runs = await github.request(check_run_endpoint, check_run_params) + const check_run_head = check_runs.data.check_runs.filter(r => r.name === "Configure jobs")[0] + + if (check_run_head.head_sha != context.payload.pull_request.head.sha) { + throw new Error('There was a new unsynced commit pushed. 
Please retrigger the workflow.'); + } + + const check_run_url = 'https://github.com/' + context.payload.pull_request.head.repo.full_name + '/runs/' - + runID + + check_run_head.id + + const actions_url = 'https://github.com/' + + context.payload.pull_request.head.repo.full_name + + '/actions/runs/' + + run_id github.checks.create({ owner: context.repo.owner, @@ -118,13 +138,13 @@ jobs: status: status, output: { title: 'Test results', - summary: '[See test results](' + runUrl + ')', + summary: '[See test results](' + check_run_url + ')', text: JSON.stringify({ owner: context.payload.pull_request.head.repo.owner.login, repo: context.payload.pull_request.head.repo.name, - run_id: runID + run_id: run_id }) }, - details_url: runUrl, + details_url: actions_url, }) } diff --git a/.gitignore b/.gitignore index 560265e4f4cf1..b75878189a975 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,8 @@ *~ .java-version .DS_Store +.ammonite +.bloop .bsp/ .cache .classpath @@ -21,10 +23,12 @@ # SPARK-35223: Add IssueNavigationLink to make IDEA support hyperlink on JIRA Ticket and GitHub PR on Git plugin. !.idea/vcs.xml .idea_modules/ +.metals .project .pydevproject .scala_dependencies .settings +.vscode /lib/ R-unit-tests.log R/unit-tests.out @@ -59,6 +63,7 @@ lint-r-report.log lint-js-report.log log/ logs/ +metals.sbt out/ project/boot/ project/build/target/ diff --git a/LICENSE-binary b/LICENSE-binary index 7e29e613f8057..8bbc913262c89 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -456,6 +456,7 @@ net.sf.py4j:py4j org.jpmml:pmml-model org.jpmml:pmml-schema org.threeten:threeten-extra +org.jdom:jdom2 python/lib/py4j-*-src.zip python/pyspark/cloudpickle.py @@ -504,6 +505,7 @@ Common Development and Distribution License (CDDL) 1.0 javax.activation:activation http://www.oracle.com/technetwork/java/javase/tech/index-jsp-138795.html javax.xml.stream:stax-api https://jcp.org/en/jsr/detail?id=173 javax.transaction:javax.transaction-api +javax.xml.bind:jaxb-api Common Development and Distribution License (CDDL) 1.1 diff --git a/NOTICE-binary b/NOTICE-binary index 4ce8bf2f86b2a..95653c6f49a07 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -917,6 +917,9 @@ This product includes code (JaspellTernarySearchTrie) from Java Spelling Checkin g Package (jaspell): http://jaspell.sourceforge.net/ License: The BSD License (http://www.opensource.org/licenses/bsd-license.php) +This product includes software developed by the JDOM Project (http://www.jdom.org/) +License: https://raw.githubusercontent.com/hunterhacker/jdom/master/LICENSE.txt + The snowball stemmers in analysis/common/src/java/net/sf/snowball were developed by Martin Porter and Richard Boulton. 
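The notify_test_workflow.yml script above now resolves the PR's run through the Actions workflow-runs endpoint (to obtain run_id and head_sha) and keeps the check-runs endpoint only for linking the "Configure jobs" check run. As a hedged illustration of that first lookup outside actions/github-script, using the same REST path the script queries (owner, repo, branch and the token below are placeholders):

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public final class ListWorkflowRuns {
  public static void main(String[] args) throws Exception {
    String owner = "some-user";          // placeholder: fork owner
    String repo = "spark";               // placeholder: fork name
    String branch = "my-feature-branch"; // placeholder: PR head ref
    // GET /repos/{owner}/{repo}/actions/workflows/build_and_test.yml/runs?branch=...
    URI uri = URI.create(String.format(
        "https://api.github.com/repos/%s/%s/actions/workflows/build_and_test.yml/runs?branch=%s",
        owner, repo, branch));
    HttpRequest request = HttpRequest.newBuilder(uri)
        .header("Accept", "application/vnd.github+json")
        .header("Authorization", "Bearer PLACEHOLDER_TOKEN")
        .GET()
        .build();
    HttpResponse<String> response =
        HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString());
    // The JSON body carries workflow_runs[0].id and head_sha, which the script compares
    // against the pull request's head commit before creating the check run.
    System.out.println(response.body());
  }
}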
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 6b85bb758a081..d147ff2b34cfd 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -60,7 +60,7 @@ Collate: 'types.R' 'utils.R' 'window.R' -RoxygenNote: 7.1.1 +RoxygenNote: 7.1.2 VignetteBuilder: knitr NeedsCompilation: no Encoding: UTF-8 diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 0e46324ed5c47..df1094bacef64 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1690,9 +1690,9 @@ test_that("column functions", { df <- as.DataFrame(list(list("col" = "1"))) c <- collect(select(df, schema_of_csv("Amsterdam,2018"))) - expect_equal(c[[1]], "STRUCT<`_c0`: STRING, `_c1`: INT>") + expect_equal(c[[1]], "STRUCT<_c0: STRING, _c1: INT>") c <- collect(select(df, schema_of_csv(lit("Amsterdam,2018")))) - expect_equal(c[[1]], "STRUCT<`_c0`: STRING, `_c1`: INT>") + expect_equal(c[[1]], "STRUCT<_c0: STRING, _c1: INT>") # Test to_json(), from_json(), schema_of_json() df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") @@ -1725,9 +1725,9 @@ test_that("column functions", { df <- as.DataFrame(list(list("col" = "1"))) c <- collect(select(df, schema_of_json('{"name":"Bob"}'))) - expect_equal(c[[1]], "STRUCT<`name`: STRING>") + expect_equal(c[[1]], "STRUCT<name: STRING>") c <- collect(select(df, schema_of_json(lit('{"name":"Bob"}')))) - expect_equal(c[[1]], "STRUCT<`name`: STRING>") + expect_equal(c[[1]], "STRUCT<name: STRING>") # Test to_json() supports arrays of primitive types and arrays df <- sql("SELECT array(19, 42, 70) as age") @@ -2051,13 +2051,19 @@ test_that("date functions on a DataFrame", { }) test_that("SPARK-37108: expose make_date expression in R", { + ansiEnabled <- sparkR.conf("spark.sql.ansi.enabled")[[1]] == "true" df <- createDataFrame( - list(list(2021, 10, 22), list(2021, 13, 1), - list(2021, 2, 29), list(2020, 2, 29)), + c( + list(list(2021, 10, 22), list(2020, 2, 29)), + if (ansiEnabled) list() else list(list(2021, 13, 1), list(2021, 2, 29)) + ), list("year", "month", "day") ) expect <- createDataFrame( - list(list(as.Date("2021-10-22")), NA, NA, list(as.Date("2020-02-29"))), + c( + list(list(as.Date("2021-10-22")), list(as.Date("2020-02-29"))), + if (ansiEnabled) list() else list(NA, NA) + ), list("make_date(year, month, day)") ) actual <- select(df, make_date(df$year, df$month, df$day)) diff --git a/bin/pyspark b/bin/pyspark index 4840589ffb7bd..1e16c56bc9724 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -50,7 +50,7 @@ export PYSPARK_DRIVER_PYTHON_OPTS # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.9.3-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.9.4-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index a19627a3b220a..f20c320494757 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.9.3-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.9.4-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/ArrayWrappers.java
b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/ArrayWrappers.java index 825355ed5d587..6f9487322bb57 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/ArrayWrappers.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/ArrayWrappers.java @@ -200,7 +200,7 @@ public int hashCode() { public int compareTo(ComparableObjectArray other) { int len = Math.min(array.length, other.array.length); for (int i = 0; i < len; i++) { - int diff = ((Comparable) array[i]).compareTo((Comparable) other.array[i]); + int diff = ((Comparable) array[i]).compareTo(other.array[i]); if (diff != 0) { return diff; } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index ef92a6cbba31a..c43c9b171f5a4 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -329,13 +329,14 @@ private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0; - DBIterator it = db.db().iterator(); - it.seek(prefix); - - while (it.hasNext()) { - byte[] key = it.next().getKey(); - if (LevelDBIterator.startsWith(key, prefix)) { - count++; + try (DBIterator it = db.db().iterator()) { + it.seek(prefix); + + while (it.hasNext()) { + byte[] key = it.next().getKey(); + if (LevelDBIterator.startsWith(key, prefix)) { + count++; + } } } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java index 1bae764ae96ad..cd18d227cba72 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java @@ -330,15 +330,16 @@ private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0; - RocksIterator it = db.db().newIterator(); - it.seek(prefix); - - while (it.isValid()) { - byte[] key = it.key(); - if (RocksDBIterator.startsWith(key, prefix)) { - count++; + try (RocksIterator it = db.db().newIterator()) { + it.seek(prefix); + + while (it.isValid()) { + byte[] key = it.key(); + if (RocksDBIterator.startsWith(key, prefix)) { + count++; + } + it.next(); } - it.next(); } return count; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 576c08858d6c3..261f20540a297 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -140,7 +140,7 @@ public void channelActive() { @Override public void channelInactive() { - if (numOutstandingRequests() > 0) { + if (hasOutstandingRequests()) { String remoteAddress = getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", numOutstandingRequests(), remoteAddress); @@ -150,7 +150,7 @@ public void channelInactive() { @Override public void exceptionCaught(Throwable cause) { - if (numOutstandingRequests() > 0) { + if (hasOutstandingRequests()) { String remoteAddress = getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection 
from {} is closed", numOutstandingRequests(), remoteAddress); @@ -275,6 +275,12 @@ public int numOutstandingRequests() { (streamActive ? 1 : 0); } + /** Check if there are any outstanding requests (fetch requests + rpcs) */ + public Boolean hasOutstandingRequests() { + return streamActive || !outstandingFetches.isEmpty() || !outstandingRpcs.isEmpty() || + !streamCallbacks.isEmpty(); + } + /** Returns the time in nanoseconds of when the last request was sent out. */ public long getTimeOfLastRequestNs() { return timeOfLastRequestNs.get(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 275e64ee50f26..d197032003e6e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -161,8 +161,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc boolean isActuallyOverdue = System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { - boolean hasInFlightRequests = responseHandler.numOutstandingRequests() > 0; - if (hasInFlightRequests) { + if (responseHandler.hasOutstandingRequests()) { String address = getRemoteAddress(ctx.channel()); logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + "requests. Assuming connection is dead; please adjust" + diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java index e62b8cb24e0ed..cff115d12b5fe 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java @@ -28,8 +28,6 @@ import org.apache.commons.crypto.stream.CryptoOutputStream; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; -import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -81,7 +79,7 @@ CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { channel.writeInbound(buffer2); fail("Should have raised an exception"); } catch (Throwable expected) { - MatcherAssert.assertThat(expected, CoreMatchers.instanceOf(IOException.class)); + assertEquals(expected.getClass(), IOException.class); assertEquals(0, buffer2.refCnt()); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java index 9136ff6af4e7e..519b02d12421a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java @@ -82,8 +82,8 @@ public boolean shouldRetryError(Throwable t) { // If it is a FileNotFoundException originating from the client while pushing the shuffle // blocks to the server, even then there is no need to retry. We will still log this // exception once which helps with debugging. 
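The TransportResponseHandler change above introduces hasOutstandingRequests() and uses it in channelInactive(), exceptionCaught() and TransportChannelHandler in place of numOutstandingRequests() > 0. The boolean form short-circuits on the first non-empty collection, while the counting form sizes every collection (and a queue like ConcurrentLinkedQueue computes size() by walking its nodes). A minimal sketch with simplified stand-in fields, not the real Netty handler:

import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

// Simplified stand-in for the handler's bookkeeping collections.
final class OutstandingRequests {
  private final Map<Long, Object> outstandingFetches = new ConcurrentHashMap<>();
  private final Map<Long, Object> outstandingRpcs = new ConcurrentHashMap<>();
  private final Queue<Object> streamCallbacks = new ConcurrentLinkedQueue<>();
  private volatile boolean streamActive;

  // Counting form: computes every size; ConcurrentLinkedQueue.size() is O(n).
  int numOutstandingRequests() {
    return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size()
        + (streamActive ? 1 : 0);
  }

  // Boolean form: stops at the first non-empty collection, no traversal needed.
  boolean hasOutstandingRequests() {
    return streamActive || !outstandingFetches.isEmpty() || !outstandingRpcs.isEmpty()
        || !streamCallbacks.isEmpty();
  }
}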
- if (t.getCause() != null && (t.getCause() instanceof ConnectException || - t.getCause() instanceof FileNotFoundException)) { + if (t.getCause() instanceof ConnectException || + t.getCause() instanceof FileNotFoundException) { return false; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java index e5e61aae92d2f..2ed0718628380 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java @@ -27,7 +27,7 @@ public class ExecutorDiskUtils { * Hashes a filename into the corresponding local directory, in a manner consistent with * Spark's DiskBlockManager.getFile(). */ - public static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { + public static String getFilePath(String[] localDirs, int subDirsPerLocalDir, String filename) { int hash = JavaUtils.nonNegativeHash(filename); String localDir = localDirs[hash % localDirs.length]; int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; @@ -38,9 +38,8 @@ public static File getFile(String[] localDirs, int subDirsPerLocalDir, String fi // Unfortunately, we cannot just call the normalization code that java.io.File // uses, since it is in the package-private class java.io.FileSystem. // So we are creating a File just to get the normalized path back to intern it. - // Finally a new File is built and returned with this interned normalized path. - final String normalizedInternedPath = new File(notNormalizedPath).getPath().intern(); - return new File(normalizedInternedPath); + // We return this interned normalized path. + return new File(notNormalizedPath).getPath().intern(); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 1e413f6b2f375..52bc0f9c2226d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -512,14 +512,14 @@ private class ShuffleManagedBufferIterator implements Iterator { mapIds = msg.mapIds; reduceIds = msg.reduceIds; batchFetchEnabled = msg.batchFetchEnabled; - } - - @Override - public boolean hasNext() { // mapIds.length must equal to reduceIds.length, and the passed in FetchShuffleBlocks // must have non-empty mapIds and reduceIds, see the checking logic in // OneForOneBlockFetcher. 
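ExecutorDiskUtils above now hands back an interned path String instead of a File, so long-lived structures such as the index caches can key on one shared String and only materialize a File when the disk is actually touched. A minimal sketch of the hash-into-subdirectory-and-intern idea; the non-negative hash and the "%02x" sub-directory naming are stand-ins for unchanged lines that this hunk does not show:

import java.io.File;

final class DiskPathSketch {
  // Illustrative only: picks a local dir and sub dir from a filename hash, then returns
  // the normalized, interned path so repeated lookups share a single String instance.
  static String getFilePath(String[] localDirs, int subDirsPerLocalDir, String filename) {
    int hash = filename.hashCode() & Integer.MAX_VALUE;   // stand-in non-negative hash
    String localDir = localDirs[hash % localDirs.length];
    int subDirId = (hash / localDirs.length) % subDirsPerLocalDir;
    String notNormalizedPath =
        localDir + File.separator + String.format("%02x", subDirId) + File.separator + filename;
    // Normalization still goes through java.io.File, as the comment in the diff explains.
    return new File(notNormalizedPath).getPath().intern();
  }
}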
assert(mapIds.length != 0 && mapIds.length == reduceIds.length); + } + + @Override + public boolean hasNext() { return mapIdx < mapIds.length && reduceIdx < reduceIds[mapIdx].length; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index bf8c6ae0ab31a..4b8a5e82d7445 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -80,7 +80,7 @@ public class ExternalShuffleBlockResolver { * Caches index file information so that we can avoid open/close the index files * for each block fetch. */ - private final LoadingCache shuffleIndexCache; + private final LoadingCache shuffleIndexCache; // Single-threaded Java executor used to perform expensive recursive directory deletion. private final Executor directoryCleaner; @@ -112,15 +112,16 @@ public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorF Boolean.parseBoolean(conf.get(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, "false")); this.registeredExecutorFile = registeredExecutorFile; String indexCacheSize = conf.get("spark.shuffle.service.index.cache.size", "100m"); - CacheLoader indexCacheLoader = - new CacheLoader() { - public ShuffleIndexInformation load(File file) throws IOException { - return new ShuffleIndexInformation(file); + CacheLoader indexCacheLoader = + new CacheLoader() { + public ShuffleIndexInformation load(String filePath) throws IOException { + return new ShuffleIndexInformation(filePath); } }; shuffleIndexCache = CacheBuilder.newBuilder() .maximumWeight(JavaUtils.byteStringAsBytes(indexCacheSize)) - .weigher((Weigher) (file, indexInfo) -> indexInfo.getSize()) + .weigher((Weigher) + (filePath, indexInfo) -> indexInfo.getRetainedMemorySize()) .build(indexCacheLoader); db = LevelDBProvider.initLevelDB(this.registeredExecutorFile, CURRENT_VERSION, mapper); if (db != null) { @@ -300,28 +301,35 @@ private void deleteNonShuffleServiceServedFiles(String[] dirs) { */ private ManagedBuffer getSortBasedShuffleBlockData( ExecutorShuffleInfo executor, int shuffleId, long mapId, int startReduceId, int endReduceId) { - File indexFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.index"); + String indexFilePath = + ExecutorDiskUtils.getFilePath( + executor.localDirs, + executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.index"); try { - ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile); + ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFilePath); ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex( startReduceId, endReduceId); return new FileSegmentManagedBuffer( conf, - ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.data"), + new File( + ExecutorDiskUtils.getFilePath( + executor.localDirs, + executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.data")), shuffleIndexRecord.getOffset(), shuffleIndexRecord.getLength()); } catch (ExecutionException e) { - throw new RuntimeException("Failed to open file: " + indexFile, e); + throw new RuntimeException("Failed to open file: " + indexFilePath, e); } } public 
ManagedBuffer getDiskPersistedRddBlockData( ExecutorShuffleInfo executor, int rddId, int splitIndex) { - File file = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, - "rdd_" + rddId + "_" + splitIndex); + File file = new File( + ExecutorDiskUtils.getFilePath( + executor.localDirs, executor.subDirsPerLocalDir, "rdd_" + rddId + "_" + splitIndex)); long fileLength = file.length(); ManagedBuffer res = null; if (file.exists()) { @@ -348,8 +356,8 @@ public int removeBlocks(String appId, String execId, String[] blockIds) { } int numRemovedBlocks = 0; for (String blockId : blockIds) { - File file = - ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); + File file = new File( + ExecutorDiskUtils.getFilePath(executor.localDirs, executor.subDirsPerLocalDir, blockId)); if (file.delete()) { numRemovedBlocks++; } else { @@ -386,10 +394,8 @@ public Cause diagnoseShuffleBlockCorruption( ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); // This should be in sync with IndexShuffleBlockResolver.getChecksumFile String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum." + algorithm; - File checksumFile = ExecutorDiskUtils.getFile( - executor.localDirs, - executor.subDirsPerLocalDir, - fileName); + File checksumFile = new File( + ExecutorDiskUtils.getFilePath(executor.localDirs, executor.subDirsPerLocalDir, fileName)); ManagedBuffer data = getBlockData(appId, execId, shuffleId, mapId, reduceId); return ShuffleChecksumHelper.diagnoseCorruption( algorithm, checksumFile, reduceId, data, checksumByReader); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java index d626cc3efaf07..62ab34028963e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java @@ -83,20 +83,11 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { public static final String MERGE_DIR_KEY = "mergeDir"; public static final String ATTEMPT_ID_KEY = "attemptId"; private static final int UNDEFINED_ATTEMPT_ID = -1; - // Shuffles of determinate stages will have shuffleMergeId set to 0 - private static final int DETERMINATE_SHUFFLE_MERGE_ID = 0; private static final ErrorHandler.BlockPushErrorHandler ERROR_HANDLER = createErrorHandler(); // ByteBuffer to respond to client upon a successful merge of a pushed block private static final ByteBuffer SUCCESS_RESPONSE = new BlockPushReturnCode(ReturnCode.SUCCESS.id(), "").toByteBuffer().asReadOnlyBuffer(); - // ConcurrentHashMap doesn't allow null for keys or values which is why this is required. - // Marker to identify finalized indeterminate shuffle partitions in the case of indeterminate - // stage retries. - @VisibleForTesting - public static final Map INDETERMINATE_SHUFFLE_FINALIZED = - Collections.emptyMap(); - /** * A concurrent hashmap where the key is the applicationId, and the value includes * all the merged shuffle information for this application. 
AppShuffleInfo stores @@ -111,7 +102,7 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager { private final int ioExceptionsThresholdDuringMerge; @SuppressWarnings("UnstableApiUsage") - private final LoadingCache indexCache; + private final LoadingCache indexCache; @SuppressWarnings("UnstableApiUsage") public RemoteBlockPushResolver(TransportConf conf) { @@ -122,15 +113,16 @@ public RemoteBlockPushResolver(TransportConf conf) { NettyUtils.createThreadFactory("spark-shuffle-merged-shuffle-directory-cleaner")); this.minChunkSize = conf.minChunkSizeInMergedShuffleFile(); this.ioExceptionsThresholdDuringMerge = conf.ioExceptionsThresholdDuringMerge(); - CacheLoader indexCacheLoader = - new CacheLoader() { - public ShuffleIndexInformation load(File file) throws IOException { - return new ShuffleIndexInformation(file); + CacheLoader indexCacheLoader = + new CacheLoader() { + public ShuffleIndexInformation load(String filePath) throws IOException { + return new ShuffleIndexInformation(filePath); } }; indexCache = CacheBuilder.newBuilder() .maximumWeight(conf.mergedIndexCacheSize()) - .weigher((Weigher)(file, indexInfo) -> indexInfo.getSize()) + .weigher((Weigher) + (filePath, indexInfo) -> indexInfo.getRetainedMemorySize()) .build(indexCacheLoader); } @@ -169,75 +161,61 @@ private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo( String blockId) throws BlockPushNonFatalFailure { ConcurrentMap shuffles = appShuffleInfo.shuffles; AppShuffleMergePartitionsInfo shufflePartitionsWithMergeId = - shuffles.compute(shuffleId, (id, appShuffleMergePartitionsInfo) -> { - if (appShuffleMergePartitionsInfo == null) { - File dataFile = - appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId); - // If this partition is already finalized then the partitions map will not contain the - // shuffleId for determinate stages but the data file would exist. - // In that case the block is considered late. In the case of indeterminate stages, most - // recent shuffleMergeId finalized would be pointing to INDETERMINATE_SHUFFLE_FINALIZED - if (dataFile.exists()) { - throw new BlockPushNonFatalFailure(new BlockPushReturnCode( - ReturnCode.TOO_LATE_BLOCK_PUSH.id(), blockId).toByteBuffer(), - BlockPushNonFatalFailure.getErrorMsg(blockId, ReturnCode.TOO_LATE_BLOCK_PUSH)); - } else { - logger.info("Creating a new attempt for shuffle blocks push request for shuffle {}" - + " with shuffleMergeId {} for application {}_{}", shuffleId, shuffleMergeId, - appShuffleInfo.appId, appShuffleInfo.attemptId); - return new AppShuffleMergePartitionsInfo(shuffleMergeId, false); - } + shuffles.compute(shuffleId, (id, mergePartitionsInfo) -> { + if (mergePartitionsInfo == null) { + logger.info("{} attempt {} shuffle {} shuffleMerge {}: creating a new shuffle " + + "merge metadata", appShuffleInfo.appId, appShuffleInfo.attemptId, shuffleId, + shuffleMergeId); + return new AppShuffleMergePartitionsInfo(shuffleMergeId, false); } else { - // Reject the request as we have already seen a higher shuffleMergeId than the - // current incoming one - int latestShuffleMergeId = appShuffleMergePartitionsInfo.shuffleMergeId; + int latestShuffleMergeId = mergePartitionsInfo.shuffleMergeId; if (latestShuffleMergeId > shuffleMergeId) { + // Reject the request as we have already seen a higher shuffleMergeId than the one + // in the current request. 
throw new BlockPushNonFatalFailure( new BlockPushReturnCode(ReturnCode.STALE_BLOCK_PUSH.id(), blockId).toByteBuffer(), BlockPushNonFatalFailure.getErrorMsg(blockId, ReturnCode.STALE_BLOCK_PUSH)); - } else if (latestShuffleMergeId == shuffleMergeId) { - return appShuffleMergePartitionsInfo; - } else { + } else if (latestShuffleMergeId < shuffleMergeId){ // Higher shuffleMergeId seen for the shuffle ID meaning new stage attempt is being // run for the shuffle ID. Close and clean up old shuffleMergeId files, // happens in the indeterminate stage retries - logger.info("Creating a new attempt for shuffle blocks push request for shuffle {}" - + " with shuffleMergeId {} for application {}_{} since it is higher than the" - + " latest shuffleMergeId {} already seen", shuffleId, shuffleMergeId, - appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId); + logger.info("{} attempt {} shuffle {} shuffleMerge {}: creating a new shuffle " + + "merge metadata since received shuffleMergeId is higher than latest " + + "shuffleMergeId {}", appShuffleInfo.appId, appShuffleInfo.attemptId, shuffleId, + shuffleMergeId, latestShuffleMergeId); mergedShuffleCleaner.execute(() -> - closeAndDeletePartitionFiles(appShuffleMergePartitionsInfo.shuffleMergePartitions)); + closeAndDeletePartitionFiles(mergePartitionsInfo.shuffleMergePartitions)); return new AppShuffleMergePartitionsInfo(shuffleMergeId, false); + } else { + // The request is for block with same shuffleMergeId as the latest shuffleMergeId + if (mergePartitionsInfo.isFinalized()) { + throw new BlockPushNonFatalFailure( + new BlockPushReturnCode( + ReturnCode.TOO_LATE_BLOCK_PUSH.id(), blockId).toByteBuffer(), + BlockPushNonFatalFailure.getErrorMsg(blockId, ReturnCode.TOO_LATE_BLOCK_PUSH)); + } + return mergePartitionsInfo; } } }); - - // It only gets here when the shuffle is already finalized. - if (null == shufflePartitionsWithMergeId || - INDETERMINATE_SHUFFLE_FINALIZED == shufflePartitionsWithMergeId.shuffleMergePartitions) { - throw new BlockPushNonFatalFailure( - new BlockPushReturnCode(ReturnCode.TOO_LATE_BLOCK_PUSH.id(), blockId).toByteBuffer(), - BlockPushNonFatalFailure.getErrorMsg(blockId, ReturnCode.TOO_LATE_BLOCK_PUSH)); - } - Map shuffleMergePartitions = - shufflePartitionsWithMergeId.shuffleMergePartitions; + shufflePartitionsWithMergeId.shuffleMergePartitions; return shuffleMergePartitions.computeIfAbsent(reduceId, key -> { // It only gets here when the key is not present in the map. The first time the merge // manager receives a pushed block for a given application shuffle partition. 
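Both index caches above (shuffleIndexCache in ExternalShuffleBlockResolver and indexCache in RemoteBlockPushResolver) are now keyed by the index file path String and weighed by getRetainedMemorySize() rather than the raw index file size. A hedged Guava sketch of that cache shape; IndexInfo is a toy stand-in for ShuffleIndexInformation, and its 16 + 176 weight mirrors the two-offset example and INSTANCE_MEMORY_FOOTPRINT constant from this patch:

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.cache.Weigher;

final class IndexCacheSketch {
  // Toy stand-in for ShuffleIndexInformation: a real loader would read the index file.
  static final class IndexInfo {
    private final int retainedMemorySize = 16 + 176;  // two long offsets + instance footprint
    IndexInfo(String indexFilePath) { }
    int getRetainedMemorySize() { return retainedMemorySize; }
  }

  static LoadingCache<String, IndexInfo> newIndexCache(long maxWeightBytes) {
    CacheLoader<String, IndexInfo> loader = new CacheLoader<String, IndexInfo>() {
      @Override
      public IndexInfo load(String filePath) {
        return new IndexInfo(filePath);
      }
    };
    // Evict by estimated retained memory, which tracks the cache's real heap cost better
    // than the on-disk index length used before this change.
    return CacheBuilder.newBuilder()
        .maximumWeight(maxWeightBytes)
        .weigher((Weigher<String, IndexInfo>) (path, info) -> info.getRetainedMemorySize())
        .build(loader);
  }
}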
File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId); - File indexFile = - appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId); + File indexFile = new File( + appShuffleInfo.getMergedShuffleIndexFilePath(shuffleId, shuffleMergeId, reduceId)); File metaFile = appShuffleInfo.getMergedShuffleMetaFile(shuffleId, shuffleMergeId, reduceId); try { return newAppShufflePartitionInfo(appShuffleInfo.appId, shuffleId, shuffleMergeId, reduceId, dataFile, indexFile, metaFile); } catch (IOException e) { - logger.error( - "Cannot create merged shuffle partition with data file {}, index file {}, and " - + "meta file {}", dataFile.getAbsolutePath(), + logger.error("{} attempt {} shuffle {} shuffleMerge {}: cannot create merged shuffle " + + "partition with data file {}, index file {}, and meta file {}", appShuffleInfo.appId, + appShuffleInfo.attemptId, shuffleId, shuffleMergeId, dataFile.getAbsolutePath(), indexFile.getAbsolutePath(), metaFile.getAbsolutePath()); throw new RuntimeException( String.format("Cannot initialize merged shuffle partition for appId %s shuffleId %s " @@ -274,8 +252,8 @@ public MergedBlockMeta getMergedBlockMeta( shuffleId, shuffleMergeId, reduceId, ErrorHandler.BlockFetchErrorHandler.STALE_SHUFFLE_BLOCK_FETCH)); } - File indexFile = - appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId); + File indexFile = new File( + appShuffleInfo.getMergedShuffleIndexFilePath(shuffleId, shuffleMergeId, reduceId)); if (!indexFile.exists()) { throw new RuntimeException(String.format( "Merged shuffle index file %s not found", indexFile.getPath())); @@ -313,18 +291,18 @@ public ManagedBuffer getMergedBlockData( throw new RuntimeException(String.format("Merged shuffle data file %s not found", dataFile.getPath())); } - File indexFile = - appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId); + String indexFilePath = + appShuffleInfo.getMergedShuffleIndexFilePath(shuffleId, shuffleMergeId, reduceId); try { // If we get here, the merged shuffle file should have been properly finalized. Thus we can // use the file length to determine the size of the merged shuffle block. - ShuffleIndexInformation shuffleIndexInformation = indexCache.get(indexFile); + ShuffleIndexInformation shuffleIndexInformation = indexCache.get(indexFilePath); ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(chunkId); return new FileSegmentManagedBuffer( conf, dataFile, shuffleIndexRecord.getOffset(), shuffleIndexRecord.getLength()); } catch (ExecutionException e) { throw new RuntimeException(String.format( - "Failed to open merged shuffle index file %s", indexFile.getPath()), e); + "Failed to open merged shuffle index file %s", indexFilePath), e); } } @@ -350,6 +328,7 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { * If cleanupLocalDirs is true, the merged shuffle files will also be deleted. * The cleanup will be executed in a separate thread. 
*/ + @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter") @VisibleForTesting void closeAndDeletePartitionFilesIfNeeded( AppShuffleInfo appShuffleInfo, @@ -512,10 +491,11 @@ public ByteBuffer getCompletionResponse() { } } + @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter") @Override public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) { - logger.info("Finalizing shuffle {} with shuffleMergeId {} from Application {}_{}.", - msg.shuffleId, msg.shuffleMergeId, msg.appId, msg.appAttemptId); + logger.info("{} attempt {} shuffle {} shuffleMerge {}: finalize shuffle merge", + msg.appId, msg.appAttemptId, msg.shuffleId, msg.shuffleMergeId); AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(msg.appId); if (appShuffleInfo.attemptId != msg.appAttemptId) { // If finalizeShuffleMerge from a former application attempt, it is considered late, @@ -534,35 +514,33 @@ public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) { } AtomicReference> shuffleMergePartitionsRef = new AtomicReference<>(null); - // Metadata of the determinate stage shuffle can be safely removed as part of finalizing - // shuffle merge. Currently once the shuffle is finalized for a determinate stages, retry - // stages of the same shuffle will have shuffle push disabled. - if (msg.shuffleMergeId == DETERMINATE_SHUFFLE_MERGE_ID) { - AppShuffleMergePartitionsInfo appShuffleMergePartitionsInfo = - appShuffleInfo.shuffles.remove(msg.shuffleId); - if (appShuffleMergePartitionsInfo != null) { - shuffleMergePartitionsRef.set(appShuffleMergePartitionsInfo.shuffleMergePartitions); - } - } else { - appShuffleInfo.shuffles.compute(msg.shuffleId, (id, value) -> { - if (null == value || msg.shuffleMergeId < value.shuffleMergeId || - INDETERMINATE_SHUFFLE_FINALIZED == value.shuffleMergePartitions) { + appShuffleInfo.shuffles.compute(msg.shuffleId, (shuffleId, mergePartitionsInfo) -> { + if (null != mergePartitionsInfo) { + if (msg.shuffleMergeId < mergePartitionsInfo.shuffleMergeId || + mergePartitionsInfo.isFinalized()) { throw new RuntimeException(String.format( - "Shuffle merge finalize request for shuffle %s with" + " shuffleMergeId %s is %s", - msg.shuffleId, msg.shuffleMergeId, - ErrorHandler.BlockPushErrorHandler.STALE_SHUFFLE_FINALIZE_SUFFIX)); - } else if (msg.shuffleMergeId > value.shuffleMergeId) { + "Shuffle merge finalize request for shuffle %s with" + " shuffleMergeId %s is %s", + msg.shuffleId, msg.shuffleMergeId, + ErrorHandler.BlockPushErrorHandler.STALE_SHUFFLE_FINALIZE_SUFFIX)); + } else if (msg.shuffleMergeId > mergePartitionsInfo.shuffleMergeId) { // If no blocks pushed for the finalizeShuffleMerge shuffleMergeId then return // empty MergeStatuses but cleanup the older shuffleMergeId files. mergedShuffleCleaner.execute(() -> - closeAndDeletePartitionFiles(value.shuffleMergePartitions)); - return new AppShuffleMergePartitionsInfo(msg.shuffleMergeId, true); + closeAndDeletePartitionFiles(mergePartitionsInfo.shuffleMergePartitions)); } else { - shuffleMergePartitionsRef.set(value.shuffleMergePartitions); - return new AppShuffleMergePartitionsInfo(msg.shuffleMergeId, true); + // This block covers: + // 1. finalization of determinate stage + // 2. finalization of indeterminate stage if the shuffleMergeId related to it is the one + // for which the message is received. 
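The finalizeShuffleMerge rewrite above collapses the determinate and indeterminate paths into a single shuffles.compute(...) call and records finalization by storing a sentinel partitions map. A compact sketch of the sentinel idea with simplified types: ConcurrentHashMap cannot hold null values, so an empty immutable map serves as the finalized marker, and isFinalized() uses reference equality so a freshly created, still-empty ConcurrentHashMap is never mistaken for a finalized shuffle.

import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Simplified stand-in for AppShuffleMergePartitionsInfo's finalized-marker pattern.
final class MergeStateSketch {
  private static final Map<Integer, Object> SHUFFLE_FINALIZED_MARKER = Collections.emptyMap();

  final int shuffleMergeId;
  final Map<Integer, Object> shuffleMergePartitions;

  MergeStateSketch(int shuffleMergeId, boolean shuffleFinalized) {
    this.shuffleMergeId = shuffleMergeId;
    this.shuffleMergePartitions =
        shuffleFinalized ? SHUFFLE_FINALIZED_MARKER : new ConcurrentHashMap<>();
  }

  boolean isFinalized() {
    // Reference comparison, not isEmpty(): an empty live map means "no blocks merged yet",
    // while the shared marker instance means "this shuffle merge is closed".
    return shuffleMergePartitions == SHUFFLE_FINALIZED_MARKER;
  }
}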
+ shuffleMergePartitionsRef.set(mergePartitionsInfo.shuffleMergePartitions); } - }); - } + } + // Even when the mergePartitionsInfo is null, we mark the shuffle as finalized but the results + // sent to the driver will be empty. This cam happen when the service didn't receive any + // blocks for the shuffle yet and the driver didn't wait for enough time to finalize the + // shuffle. + return new AppShuffleMergePartitionsInfo(msg.shuffleMergeId, true); + }); Map shuffleMergePartitions = shuffleMergePartitionsRef.get(); MergeStatuses mergeStatuses; if (null == shuffleMergePartitions || shuffleMergePartitions.isEmpty()) { @@ -576,14 +554,25 @@ public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) { for (AppShufflePartitionInfo partition: shuffleMergePartitions.values()) { synchronized (partition) { try { + logger.debug("{} attempt {} shuffle {} shuffleMerge {}: finalizing shuffle " + + "partition {} ", msg.appId, msg.appAttemptId, msg.shuffleId, + msg.shuffleMergeId, partition.reduceId); // This can throw IOException which will marks this shuffle partition as not merged. partition.finalizePartition(); - bitmaps.add(partition.mapTracker); - reduceIds.add(partition.reduceId); - sizes.add(partition.getLastChunkOffset()); + if (partition.mapTracker.getCardinality() > 0) { + bitmaps.add(partition.mapTracker); + reduceIds.add(partition.reduceId); + sizes.add(partition.getLastChunkOffset()); + logger.debug("{} attempt {} shuffle {} shuffleMerge {}: finalization results " + + "added for partition {} data size {} index size {} meta size {}", + msg.appId, msg.appAttemptId, msg.shuffleId, + msg.shuffleMergeId, partition.reduceId, partition.getLastChunkOffset(), + partition.indexFile.getPos(), partition.metaFile.getPos()); + } } catch (IOException ioe) { - logger.warn("Exception while finalizing shuffle partition {}_{} {} {}", msg.appId, - msg.appAttemptId, msg.shuffleId, partition.reduceId, ioe); + logger.warn("{} attempt {} shuffle {} shuffleMerge {}: exception while " + + "finalizing shuffle partition {}", msg.appId, msg.appAttemptId, msg.shuffleId, + msg.shuffleMergeId, partition.reduceId); } finally { partition.closeAllFilesAndDeleteIfNeeded(false); } @@ -593,8 +582,8 @@ public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) { bitmaps.toArray(new RoaringBitmap[bitmaps.size()]), Ints.toArray(reduceIds), Longs.toArray(sizes)); } - logger.info("Finalized shuffle {} with shuffleMergeId {} from Application {}_{}.", - msg.shuffleId, msg.shuffleMergeId, msg.appId, msg.appAttemptId); + logger.info("{} attempt {} shuffle {} shuffleMerge {}: finalization of shuffle merge completed", + msg.appId, msg.appAttemptId, msg.shuffleId, msg.shuffleMergeId); return mergeStatuses; } @@ -808,7 +797,7 @@ private boolean isTooLate( AppShuffleMergePartitionsInfo appShuffleMergePartitionsInfo, int reduceId) { return null == appShuffleMergePartitionsInfo || - INDETERMINATE_SHUFFLE_FINALIZED == appShuffleMergePartitionsInfo.shuffleMergePartitions || + appShuffleMergePartitionsInfo.isFinalized() || !appShuffleMergePartitionsInfo.shuffleMergePartitions.containsKey(reduceId); } @@ -1008,20 +997,27 @@ AppShufflePartitionInfo getPartitionInfo() { * required for the shuffles of indeterminate stages. */ public static class AppShuffleMergePartitionsInfo { + // ConcurrentHashMap doesn't allow null for keys or values which is why this is required. + // Marker to identify finalized shuffle partitions. 
+ private static final Map SHUFFLE_FINALIZED_MARKER = + Collections.emptyMap(); private final int shuffleMergeId; private final Map shuffleMergePartitions; - public AppShuffleMergePartitionsInfo( - int shuffleMergeId, boolean shuffleFinalized) { + public AppShuffleMergePartitionsInfo(int shuffleMergeId, boolean shuffleFinalized) { this.shuffleMergeId = shuffleMergeId; - this.shuffleMergePartitions = shuffleFinalized ? - INDETERMINATE_SHUFFLE_FINALIZED : new ConcurrentHashMap<>(); + this.shuffleMergePartitions = shuffleFinalized ? SHUFFLE_FINALIZED_MARKER : + new ConcurrentHashMap<>(); } @VisibleForTesting public Map getShuffleMergePartitions() { return shuffleMergePartitions; } + + public boolean isFinalized() { + return shuffleMergePartitions == SHUFFLE_FINALIZED_MARKER; + } } /** Metadata tracked for an actively merged shuffle partition */ @@ -1308,11 +1304,14 @@ public ConcurrentMap getShuffles() { * @see [[org.apache.spark.storage.DiskBlockManager#getMergedShuffleFile( * org.apache.spark.storage.BlockId, scala.Option)]] */ - private File getFile(String filename) { + private String getFilePath(String filename) { // TODO: [SPARK-33236] Change the message when this service is able to handle NM restart - File targetFile = ExecutorDiskUtils.getFile(appPathsInfo.activeLocalDirs, - appPathsInfo.subDirsPerLocalDir, filename); - logger.debug("Get merged file {}", targetFile.getAbsolutePath()); + String targetFile = + ExecutorDiskUtils.getFilePath( + appPathsInfo.activeLocalDirs, + appPathsInfo.subDirsPerLocalDir, + filename); + logger.debug("Get merged file {}", targetFile); return targetFile; } @@ -1332,16 +1331,16 @@ public File getMergedShuffleDataFile( int reduceId) { String fileName = String.format("%s.data", generateFileName(appId, shuffleId, shuffleMergeId, reduceId)); - return getFile(fileName); + return new File(getFilePath(fileName)); } - public File getMergedShuffleIndexFile( + public String getMergedShuffleIndexFilePath( int shuffleId, int shuffleMergeId, int reduceId) { String indexName = String.format("%s.index", generateFileName(appId, shuffleId, shuffleMergeId, reduceId)); - return getFile(indexName); + return getFilePath(indexName); } public File getMergedShuffleMetaFile( @@ -1350,7 +1349,7 @@ public File getMergedShuffleMetaFile( int reduceId) { String metaName = String.format("%s.meta", generateFileName(appId, shuffleId, shuffleMergeId, reduceId)); - return getFile(metaName); + return new File(getFilePath(metaName)); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 512e4a52c8628..463edc770d28e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -191,7 +191,7 @@ private synchronized void initiateRetry() { */ private synchronized boolean shouldRetry(Throwable e) { boolean isIOException = e instanceof IOException - || (e.getCause() != null && e.getCause() instanceof IOException); + || e.getCause() instanceof IOException; boolean hasRemainingRetries = retryCount < maxRetries; return isIOException && hasRemainingRetries && errorHandler.shouldRetryError(e); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index b65aacfcc4b9e..6669255f30299 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -29,25 +29,28 @@ * as an in-memory LongBuffer. */ public class ShuffleIndexInformation { + + // The estimate of `ShuffleIndexInformation` memory footprint which is relevant in case of small + // index files (i.e. storing only 2 offsets = 16 bytes). + static final int INSTANCE_MEMORY_FOOTPRINT = 176; + /** offsets as long buffer */ private final LongBuffer offsets; - private int size; - public ShuffleIndexInformation(File indexFile) throws IOException { - size = (int)indexFile.length(); - ByteBuffer buffer = ByteBuffer.allocate(size); + public ShuffleIndexInformation(String indexFilePath) throws IOException { + File indexFile = new File(indexFilePath); + ByteBuffer buffer = ByteBuffer.allocate((int)indexFile.length()); offsets = buffer.asLongBuffer(); try (DataInputStream dis = new DataInputStream(Files.newInputStream(indexFile.toPath()))) { dis.readFully(buffer.array()); } } - /** - * Size of the index file - * @return size - */ - public int getSize() { - return size; + public int getRetainedMemorySize() { + // SPARK-33206: here the offsets' capacity is multiplied by 8 as offsets stores long values. + // Integer overflow won't be an issue here as long as the number of reducers is under + // (Integer.MAX_VALUE - INSTANCE_MEMORY_FOOTPRINT) / 8 - 1 = 268435432. + return (offsets.capacity() << 3) + INSTANCE_MEMORY_FOOTPRINT; } /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockPushReturnCode.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockPushReturnCode.java index 0455d67c5ace2..d3f170f91507f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockPushReturnCode.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockPushReturnCode.java @@ -68,7 +68,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof BlockPushReturnCode) { + if (other instanceof BlockPushReturnCode) { BlockPushReturnCode o = (BlockPushReturnCode) other; return returnCode == o.returnCode && Objects.equals(failureBlockId, o.failureBlockId); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java index a4d6035df807c..452f70c6cd221 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java @@ -51,7 +51,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof BlocksRemoved) { + if (other instanceof BlocksRemoved) { BlocksRemoved o = (BlocksRemoved) other; return numRemovedBlocks == o.numRemovedBlocks; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index f123ccb663377..ead13f5b14f1a 100644 --- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -69,7 +69,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof ExecutorShuffleInfo) { + if (other instanceof ExecutorShuffleInfo) { ExecutorShuffleInfo o = (ExecutorShuffleInfo) other; return Arrays.equals(localDirs, o.localDirs) && subDirsPerLocalDir == o.subDirsPerLocalDir diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java index 675739a41e817..e99fe1707092b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java @@ -69,7 +69,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof FinalizeShuffleMerge) { + if (other instanceof FinalizeShuffleMerge) { FinalizeShuffleMerge o = (FinalizeShuffleMerge) other; return Objects.equal(appId, o.appId) && appAttemptId == o.appAttemptId diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java index b2658d62b445b..b6bfc302d218b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java @@ -95,7 +95,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof MergeStatuses) { + if (other instanceof MergeStatuses) { MergeStatuses o = (MergeStatuses) other; return Objects.equal(shuffleId, o.shuffleId) && Objects.equal(shuffleMergeId, o.shuffleMergeId) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java index 771e17b3233ec..91f203764ecd8 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -60,7 +60,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof OpenBlocks) { + if (other instanceof OpenBlocks) { OpenBlocks o = (OpenBlocks) other; return Objects.equals(appId, o.appId) && Objects.equals(execId, o.execId) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java index b868d7ccff568..fc9900bae1e8a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java @@ -87,7 +87,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof PushBlockStream) { + if (other instanceof 
PushBlockStream) { PushBlockStream o = (PushBlockStream) other; return Objects.equal(appId, o.appId) && appAttemptId == o.appAttemptId diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java index f6af755cd9cd5..6189820726205 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -65,7 +65,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof RegisterExecutor) { + if (other instanceof RegisterExecutor) { RegisterExecutor o = (RegisterExecutor) other; return Objects.equals(appId, o.appId) && Objects.equals(execId, o.execId) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RemoveBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RemoveBlocks.java index ade838bd4286c..6c194d1a14cf2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RemoveBlocks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RemoveBlocks.java @@ -60,7 +60,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof RemoveBlocks) { + if (other instanceof RemoveBlocks) { RemoveBlocks o = (RemoveBlocks) other; return Objects.equals(appId, o.appId) && Objects.equals(execId, o.execId) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java index dd7715a4e82d4..20954914a7ced 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -57,7 +57,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof StreamHandle) { + if (other instanceof StreamHandle) { StreamHandle o = (StreamHandle) other; return Objects.equals(streamId, o.streamId) && Objects.equals(numChunks, o.numChunks); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java index a5bc3f7009b46..c5e07d0d991b7 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -79,7 +79,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof UploadBlock) { + if (other instanceof UploadBlock) { UploadBlock o = (UploadBlock) other; return Objects.equals(appId, o.appId) && Objects.equals(execId, o.execId) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java index 958a84e516c81..a1ac9da0956da 100644 --- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java @@ -63,7 +63,7 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof UploadBlockStream) { + if (other instanceof UploadBlockStream) { UploadBlockStream o = (UploadBlockStream) other; return Objects.equals(blockId, o.blockId) && Arrays.equals(metadata, o.metadata); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java index f4a29aaac19f7..603b20c7dbacf 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java @@ -1161,8 +1161,8 @@ public void testFinalizeOfDeterminateShuffle() throws IOException { RemoteBlockPushResolver.AppShuffleInfo appShuffleInfo = pushResolver.validateAndGetAppShuffleInfo(TEST_APP); - assertTrue("Metadata of determinate shuffle should be removed after finalize shuffle" - + " merge", appShuffleInfo.getShuffles().get(0) == null); + assertTrue("Determinate shuffle should be marked finalized", + appShuffleInfo.getShuffles().get(0).isFinalized()); validateMergeStatuses(statuses, new int[] {0}, new long[] {9}); MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{4, 5}, new int[][]{{0}, {1}}); @@ -1247,7 +1247,7 @@ void closeAndDeletePartitionFiles(Map partitio assertFalse("Meta files on the disk should be cleaned up", appShuffleInfo.getMergedShuffleMetaFile(0, 1, 0).exists()); assertFalse("Index files on the disk should be cleaned up", - appShuffleInfo.getMergedShuffleIndexFile(0, 1, 0).exists()); + new File(appShuffleInfo.getMergedShuffleIndexFilePath(0, 1, 0)).exists()); stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); // stream 2 now completes @@ -1282,11 +1282,84 @@ void closeAndDeletePartitionFiles(Map partitio assertFalse("MergedBlock meta file for shuffle 0 and shuffleMergeId 4 should be cleaned" + " up", appShuffleInfo.getMergedShuffleMetaFile(0, 4, 0).exists()); assertFalse("MergedBlock index file for shuffle 0 and shuffleMergeId 4 should be cleaned" - + " up", appShuffleInfo.getMergedShuffleIndexFile(0, 4, 0).exists()); + + " up", new File(appShuffleInfo.getMergedShuffleIndexFilePath(0, 4, 0)).exists()); assertFalse("MergedBlock data file for shuffle 0 and shuffleMergeId 4 should be cleaned" + " up", appShuffleInfo.getMergedShuffleDataFile(0, 4, 0).exists()); } + @Test + public void testFinalizationResultIsEmptyWhenTheServerDidNotReceiveAnyBlocks() { + //shuffle 1 0 is finalized even though the server didn't receive any blocks for it. 
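A side note on the equals() simplifications in the protocol classes above (ExecutorShuffleInfo through UploadBlockStream): `instanceof` already evaluates to false when its left operand is null, so the explicit `other != null &&` guard it replaces was redundant. A minimal standalone sketch of that behavior; the `Point` class here is hypothetical and not part of Spark:

```java
public class InstanceofNullDemo {
  static final class Point {
    final int x, y;
    Point(int x, int y) { this.x = x; this.y = y; }

    @Override
    public boolean equals(Object other) {
      // No explicit null check needed: `null instanceof Point` is false,
      // so this single test covers both the null and the wrong-type case.
      if (other instanceof Point) {
        Point o = (Point) other;
        return x == o.x && y == o.y;
      }
      return false;
    }

    @Override
    public int hashCode() { return 31 * x + y; }
  }

  public static void main(String[] args) {
    Point p = new Point(1, 2);
    System.out.println(p.equals(null));            // false -- instanceof handles null
    System.out.println(p.equals("not a point"));   // false -- wrong type
    System.out.println(p.equals(new Point(1, 2))); // true
  }
}
```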
+ MergeStatuses statuses = pushResolver.finalizeShuffleMerge( + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 1, 0)); + assertEquals("no partitions were merged", 0, statuses.reduceIds.length); + RemoteBlockPushResolver.AppShuffleInfo appShuffleInfo = + pushResolver.validateAndGetAppShuffleInfo(TEST_APP); + assertTrue("shuffle 1 should be marked finalized", + appShuffleInfo.getShuffles().get(1).isFinalized()); + removeApplication(TEST_APP); + } + + // Test for SPARK-37675 and SPARK-37793 + @Test + public void testEmptyMergePartitionsAreNotReported() throws IOException { + //shufflePush_1_0_0_100 is received by the server + StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 0, 100, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[4])); + //shuffle 1 0 is finalized + MergeStatuses statuses = pushResolver.finalizeShuffleMerge( + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 1, 0)); + assertEquals("no partitions were merged", 0, statuses.reduceIds.length); + removeApplication(TEST_APP); + } + + // Test for SPARK-37675 and SPARK-37793 + @Test + public void testAllBlocksAreRejectedWhenReceivedAfterFinalization() throws IOException { + //shufflePush_1_0_0_100 is received by the server + StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 0, 100, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[4])); + stream1.onComplete(stream1.getID()); + //shuffle 1 0 is finalized + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 1, 0)); + BlockPushNonFatalFailure errorToValidate = null; + try { + //shufflePush_1_0_0_200 is received by the server after finalization of shuffle 1 0 which + //should be rejected + StreamCallbackWithID failureCallback = pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 0, 200, 0)); + failureCallback.onComplete(failureCallback.getID()); + } catch (BlockPushNonFatalFailure e) { + BlockPushReturnCode errorCode = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(), + errorCode.returnCode); + errorToValidate = e; + assertEquals(errorCode.failureBlockId, "shufflePush_1_0_0_200"); + } + assertNotNull("shufflePush_1_0_0_200 should be rejected", errorToValidate); + try { + //shufflePush_1_0_1_100 is received by the server after finalization of shuffle 1 0 which + //should also be rejected + StreamCallbackWithID failureCallback = pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 1, 100, 0)); + failureCallback.onComplete(failureCallback.getID()); + } catch (BlockPushNonFatalFailure e) { + BlockPushReturnCode errorCode = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(), + errorCode.returnCode); + errorToValidate = e; + assertEquals(errorCode.failureBlockId, "shufflePush_1_0_1_100"); + } + assertNotNull("shufflePush_1_0_1_100 should be rejected", errorToValidate); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 1, 0, 100); + validateChunks(TEST_APP, 1, 0, 100, blockMeta, new int[]{4}, new int[][]{{0}}); + removeApplication(TEST_APP); + } + private void useTestFiles(boolean useTestIndexFile, boolean useTestMetaFile) throws IOException { pushResolver = new 
RemoteBlockPushResolver(conf) { @Override diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index 1b44b061f3d81..985a7a364282e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -240,7 +240,6 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException * retries -- the first interaction may include an IOException, which causes a retry of some * subset of the original blocks in a second interaction. */ - @SuppressWarnings("unchecked") private static void performInteractions(List> interactions, BlockFetchingListener listener) throws IOException, InterruptedException { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleIndexInformationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleIndexInformationSuite.java new file mode 100644 index 0000000000000..c4ff8935e2d64 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleIndexInformationSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; + +import java.nio.charset.StandardCharsets; + +import static org.junit.Assert.*; + +public class ShuffleIndexInformationSuite { + private static final String sortBlock0 = "tiny block"; + private static final String sortBlock1 = "a bit longer block"; + + private static TestShuffleDataContext dataContext; + private static String blockId; + + @BeforeClass + public static void before() throws IOException { + dataContext = new TestShuffleDataContext(2, 5); + + dataContext.create(); + // Write some sort data. 
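The too-late block push tests above assert the expected BlockPushNonFatalFailure with a try/catch plus a nullable `errorToValidate` local. On JUnit 4.13+, org.junit.Assert.assertThrows expresses the same check more directly; the sketch below assumes that JUnit version and uses stand-in types rather than Spark's real resolver API:

```java
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;

import org.junit.Test;

public class AssertThrowsSketch {
  // Stand-in for BlockPushNonFatalFailure; the real exception type lives in Spark.
  static final class TooLateBlockPush extends RuntimeException {
    final String blockId;
    TooLateBlockPush(String blockId) {
      super("Block " + blockId + " arrived after finalization");
      this.blockId = blockId;
    }
  }

  // Stand-in for the push path that rejects blocks arriving after finalization.
  static void pushAfterFinalization(String blockId) {
    throw new TooLateBlockPush(blockId);
  }

  @Test
  public void lateBlockIsRejected() {
    // assertThrows both asserts that the exception was thrown and returns it,
    // so no nullable local or trailing assertNotNull is needed.
    TooLateBlockPush e = assertThrows(
        TooLateBlockPush.class,
        () -> pushAfterFinalization("shufflePush_1_0_0_200"));
    assertEquals("shufflePush_1_0_0_200", e.blockId);
  }
}
```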
+ blockId = dataContext.insertSortShuffleData(0, 0, new byte[][] { + sortBlock0.getBytes(StandardCharsets.UTF_8), + sortBlock1.getBytes(StandardCharsets.UTF_8)}); + } + + @AfterClass + public static void afterAll() { + dataContext.cleanup(); + } + + @Test + public void test() throws IOException { + String path = ExecutorDiskUtils.getFilePath( + dataContext.localDirs, + dataContext.subDirsPerLocalDir, + blockId + ".index"); + ShuffleIndexInformation s = new ShuffleIndexInformation(path); + // the index file contains 3 offsets: + // 0, sortBlock0.length, sortBlock0.length + sortBlock1.length + assertEquals(0L, s.getIndex(0).getOffset()); + assertEquals(sortBlock0.length(), s.getIndex(0).getLength()); + + assertEquals(sortBlock0.length(), s.getIndex(1).getOffset()); + assertEquals(sortBlock1.length(), s.getIndex(1).getLength()); + + assertEquals((3 * 8) + ShuffleIndexInformation.INSTANCE_MEMORY_FOOTPRINT, + s.getRetainedMemorySize()); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index fb67d7220a0b4..bcf57ea621979 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -68,7 +68,8 @@ public void cleanup() { } /** Creates reducer blocks in a sort-based data format within our local dirs. */ - public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + public String insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) + throws IOException { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; OutputStream dataStream = null; @@ -76,10 +77,10 @@ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) thr boolean suppressExceptionsDuringClose = true; try { - dataStream = new FileOutputStream( - ExecutorDiskUtils.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); - indexStream = new DataOutputStream(new FileOutputStream( - ExecutorDiskUtils.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + dataStream = new FileOutputStream(new File( + ExecutorDiskUtils.getFilePath(localDirs, subDirsPerLocalDir, blockId + ".data"))); + indexStream = new DataOutputStream(new FileOutputStream(new File( + ExecutorDiskUtils.getFilePath(localDirs, subDirsPerLocalDir, blockId + ".index")))); long offset = 0; indexStream.writeLong(offset); @@ -93,6 +94,7 @@ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) thr Closeables.close(dataStream, suppressExceptionsDuringClose); Closeables.close(indexStream, suppressExceptionsDuringClose); } + return blockId; } /** Creates spill file(s) within the local dirs. 
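To make the new ShuffleIndexInformationSuite above easier to follow: a sort-shuffle .index file, as written by TestShuffleDataContext, is just a sequence of cumulative byte offsets stored as longs, and partition i's (offset, length) pair comes from two consecutive entries; that is also why the retained-size assertion counts 3 * 8 bytes for the three offsets. A standalone sketch of the same layout using plain java.nio file I/O rather than Spark's classes:

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

public class IndexFileLayoutDemo {
  public static void main(String[] args) throws IOException {
    byte[][] blocks = {
        "tiny block".getBytes(StandardCharsets.UTF_8),
        "a bit longer block".getBytes(StandardCharsets.UTF_8)};
    Path index = Files.createTempFile("shuffle_0_0_0", ".index");

    // Write cumulative offsets: 0, len(block0), len(block0) + len(block1), ...
    try (DataOutputStream out = new DataOutputStream(Files.newOutputStream(index))) {
      long offset = 0;
      out.writeLong(offset);
      for (byte[] block : blocks) {
        offset += block.length;
        out.writeLong(offset);
      }
    }

    // Read them back: partition i spans [offsets[i], offsets[i + 1]).
    try (DataInputStream in = new DataInputStream(Files.newInputStream(index))) {
      long[] offsets = new long[blocks.length + 1];
      for (int i = 0; i < offsets.length; i++) {
        offsets[i] = in.readLong();
      }
      for (int i = 0; i < blocks.length; i++) {
        System.out.println("partition " + i + ": offset=" + offsets[i]
            + " length=" + (offsets[i + 1] - offsets[i]));
      }
    }
    Files.delete(index);
  }
}
```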
*/ @@ -122,11 +124,11 @@ private void insertFile(String filename) throws IOException { private void insertFile(String filename, byte[] block) throws IOException { OutputStream dataStream = null; - File file = ExecutorDiskUtils.getFile(localDirs, subDirsPerLocalDir, filename); + File file = new File(ExecutorDiskUtils.getFilePath(localDirs, subDirsPerLocalDir, filename)); Assert.assertFalse("this test file has been already generated", file.exists()); try { dataStream = new FileOutputStream( - ExecutorDiskUtils.getFile(localDirs, subDirsPerLocalDir, filename)); + new File(ExecutorDiskUtils.getFilePath(localDirs, subDirsPerLocalDir, filename))); dataStream.write(block); } finally { Closeables.close(dataStream, false); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java index c01c0470fa8c5..31857443e8c68 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -115,7 +115,7 @@ static BitArray readFrom(DataInputStream in) throws IOException { @Override public boolean equals(Object other) { if (this == other) return true; - if (other == null || !(other instanceof BitArray)) return false; + if (!(other instanceof BitArray)) return false; BitArray that = (BitArray) other; return Arrays.equals(data, that.data); } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 5afe5fe45b18d..e7766ee903480 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -42,7 +42,7 @@ public boolean equals(Object other) { return true; } - if (other == null || !(other instanceof BloomFilterImpl)) { + if (!(other instanceof BloomFilterImpl)) { return false; } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index f6c1c39bbfd0a..80e71738198b2 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -70,7 +70,7 @@ public boolean equals(Object other) { return true; } - if (other == null || !(other instanceof CountMinSketchImpl)) { + if (!(other instanceof CountMinSketchImpl)) { return false; } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 5a7e32b0d9d3b..deb7d2bf1b0f8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -19,6 +19,8 @@ import org.apache.spark.unsafe.Platform; +import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; + public class ByteArrayMethods { private ByteArrayMethods() { @@ -91,4 +93,39 @@ public static boolean arrayEquals( } return true; } + + public static boolean contains(byte[] arr, byte[] sub) { + if (sub.length == 0) { + return true; + } + byte first = sub[0]; + for (int i = 0; i <= arr.length - sub.length; i++) { + if (arr[i] == first && matchAt(arr, sub, i)) { + return true; + } + } + return false; + } + + public static boolean 
startsWith(byte[] array, byte[] target) { + if (target.length > array.length) { + return false; + } + return arrayEquals(array, BYTE_ARRAY_OFFSET, target, BYTE_ARRAY_OFFSET, target.length); + } + + public static boolean endsWith(byte[] array, byte[] target) { + if (target.length > array.length) { + return false; + } + return arrayEquals(array, BYTE_ARRAY_OFFSET + array.length - target.length, + target, BYTE_ARRAY_OFFSET, target.length); + } + + public static boolean matchAt(byte[] arr, byte[] sub, int pos) { + if (sub.length + pos > arr.length || pos < 0) { + return false; + } + return arrayEquals(arr, BYTE_ARRAY_OFFSET + pos, sub, BYTE_ARRAY_OFFSET, sub.length); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 4126cf5150fa9..aae47aa963201 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -178,7 +178,7 @@ private static void fillWithPattern(byte[] result, int firstPos, int beyondPos, for (int pos = firstPos; pos < beyondPos; pos += pad.length) { final int jMax = Math.min(pad.length, beyondPos - pos); for (int j = 0; j < jMax; ++j) { - result[pos + j] = (byte) pad[j]; + result[pos + j] = pad[j]; } } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index c47b90d4be6af..98c61cfd9bb9b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -20,7 +20,6 @@ import javax.annotation.Nonnull; import java.io.*; import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Map; @@ -96,9 +95,6 @@ public final class UTF8String implements Comparable, Externalizable, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 // 0xF5..0xFF - disallowed in UTF-8 }; - private static final boolean IS_LITTLE_ENDIAN = - ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; - private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); @@ -373,7 +369,7 @@ public UTF8String toUpperCase() { // fallback return toUpperCaseSlow(); } - int upper = Character.toUpperCase((int) b); + int upper = Character.toUpperCase(b); if (upper > 127) { // fallback return toUpperCaseSlow(); @@ -403,7 +399,7 @@ public UTF8String toLowerCase() { // fallback return toLowerCaseSlow(); } - int lower = Character.toLowerCase((int) b); + int lower = Character.toLowerCase(b); if (lower > 127) { // fallback return toLowerCaseSlow(); diff --git a/conf/log4j2.properties.template b/conf/log4j2.properties.template index 85b4f679a93e2..99f68a8a9e98c 100644 --- a/conf/log4j2.properties.template +++ b/conf/log4j2.properties.template @@ -57,7 +57,7 @@ logger.FunctionRegistry.level = error # For deploying Spark ThriftServer # SPARK-34128: Suppress undesirable TTransportException warnings involved in THRIFT-4805 -appender.console.filter.1.type = MarkerFilter -appender.console.filter.1.marker = Thrift error occurred during processing of message +appender.console.filter.1.type = RegexFilter +appender.console.filter.1.regex = .*Thrift error occurred during processing of message.* appender.console.filter.1.onMatch = deny appender.console.filter.1.onMismatch = 
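Back to the ByteArrayMethods helpers added above: contains, startsWith, endsWith and matchAt are plain static utilities over byte arrays that delegate to arrayEquals with the appropriate offsets. A small usage sketch, assuming the spark-unsafe module that hosts the class is on the classpath:

```java
import java.nio.charset.StandardCharsets;

import org.apache.spark.unsafe.array.ByteArrayMethods;

public class ByteArrayMethodsDemo {
  public static void main(String[] args) {
    byte[] haystack = "spark shuffle service".getBytes(StandardCharsets.UTF_8);
    byte[] needle = "shuffle".getBytes(StandardCharsets.UTF_8);

    // Substring search on raw bytes, analogous to String.contains.
    System.out.println(ByteArrayMethods.contains(haystack, needle));     // true

    // Prefix and suffix checks.
    System.out.println(ByteArrayMethods.startsWith(haystack,
        "spark".getBytes(StandardCharsets.UTF_8)));                      // true
    System.out.println(ByteArrayMethods.endsWith(haystack,
        "service".getBytes(StandardCharsets.UTF_8)));                    // true

    // matchAt checks for a match at an explicit byte position.
    System.out.println(ByteArrayMethods.matchAt(haystack, needle, 6));   // true
  }
}
```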
neutral diff --git a/core/benchmarks/ZStandardBenchmark-jdk11-results.txt b/core/benchmarks/ZStandardBenchmark-jdk11-results.txt index d975b1d05fc98..53c9299e84366 100644 --- a/core/benchmarks/ZStandardBenchmark-jdk11-results.txt +++ b/core/benchmarks/ZStandardBenchmark-jdk11-results.txt @@ -2,26 +2,26 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1027-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 517 518 1 0.0 51679.1 1.0X -Compression 10000 times at level 2 without buffer pool 828 829 1 0.0 82770.5 0.6X -Compression 10000 times at level 3 without buffer pool 1031 1035 6 0.0 103117.5 0.5X -Compression 10000 times at level 1 with buffer pool 474 475 1 0.0 47377.9 1.1X -Compression 10000 times at level 2 with buffer pool 544 545 1 0.0 54382.9 1.0X -Compression 10000 times at level 3 with buffer pool 728 732 5 0.0 72791.2 0.7X +Compression 10000 times at level 1 without buffer pool 584 604 15 0.0 58407.5 1.0X +Compression 10000 times at level 2 without buffer pool 654 665 11 0.0 65444.9 0.9X +Compression 10000 times at level 3 without buffer pool 907 916 8 0.0 90677.0 0.6X +Compression 10000 times at level 1 with buffer pool 674 686 11 0.0 67437.9 0.9X +Compression 10000 times at level 2 with buffer pool 759 769 10 0.0 75916.2 0.8X +Compression 10000 times at level 3 with buffer pool 1006 1017 16 0.0 100600.2 0.6X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1027-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 1097 1097 1 0.0 109654.6 1.0X -Decompression 10000 times from level 2 without buffer pool 1097 1097 0 0.0 109695.5 1.0X -Decompression 10000 times from level 3 without buffer pool 1093 1093 1 0.0 109309.2 1.0X -Decompression 10000 times from level 1 with buffer pool 854 855 1 0.0 85422.5 1.3X -Decompression 10000 times from level 2 with buffer pool 853 853 0 0.0 85287.9 1.3X -Decompression 10000 times from level 3 with buffer pool 854 854 0 0.0 85417.9 1.3X +Decompression 10000 times from level 1 without buffer pool 693 698 9 0.0 69257.4 1.0X +Decompression 10000 times from level 2 without buffer pool 699 707 7 0.0 69857.8 1.0X +Decompression 10000 times from level 3 without buffer pool 689 697 7 0.0 68858.9 1.0X +Decompression 10000 times from level 1 with buffer pool 450 476 37 0.0 45005.9 1.5X +Decompression 10000 times from level 2 with buffer pool 527 550 26 0.0 52653.2 1.3X +Decompression 10000 times from level 3 with buffer pool 452 513 43 0.0 45201.4 1.5X diff --git a/core/benchmarks/ZStandardBenchmark-jdk17-results.txt b/core/benchmarks/ZStandardBenchmark-jdk17-results.txt index 
ecb7e6c6bcfd6..c6d84b79cb29c 100644 --- a/core/benchmarks/ZStandardBenchmark-jdk17-results.txt +++ b/core/benchmarks/ZStandardBenchmark-jdk17-results.txt @@ -2,26 +2,26 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1027-azure Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 2930 2953 33 0.0 293038.2 1.0X -Compression 10000 times at level 2 without buffer pool 1846 2728 1248 0.0 184565.8 1.6X -Compression 10000 times at level 3 without buffer pool 2109 2110 2 0.0 210881.8 1.4X -Compression 10000 times at level 1 with buffer pool 1466 1479 19 0.0 146569.0 2.0X -Compression 10000 times at level 2 with buffer pool 1570 1584 20 0.0 156976.5 1.9X -Compression 10000 times at level 3 with buffer pool 1845 1852 10 0.0 184465.3 1.6X +Compression 10000 times at level 1 without buffer pool 2380 2426 65 0.0 238014.5 1.0X +Compression 10000 times at level 2 without buffer pool 1532 2271 1045 0.0 153222.7 1.6X +Compression 10000 times at level 3 without buffer pool 1746 1757 15 0.0 174619.0 1.4X +Compression 10000 times at level 1 with buffer pool 1177 1178 2 0.0 117681.3 2.0X +Compression 10000 times at level 2 with buffer pool 1267 1273 8 0.0 126719.0 1.9X +Compression 10000 times at level 3 with buffer pool 1517 1603 122 0.0 151729.8 1.6X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1027-azure Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 2852 2887 49 0.0 285224.2 1.0X -Decompression 10000 times from level 2 without buffer pool 2903 2908 7 0.0 290287.1 1.0X -Decompression 10000 times from level 3 without buffer pool 2846 2858 18 0.0 284558.0 1.0X -Decompression 10000 times from level 1 with buffer pool 2637 2647 14 0.0 263714.3 1.1X -Decompression 10000 times from level 2 with buffer pool 2619 2629 14 0.0 261915.2 1.1X -Decompression 10000 times from level 3 with buffer pool 2640 2652 17 0.0 263976.7 1.1X +Decompression 10000 times from level 1 without buffer pool 2241 2271 42 0.0 224123.2 1.0X +Decompression 10000 times from level 2 without buffer pool 2210 2253 62 0.0 220980.7 1.0X +Decompression 10000 times from level 3 without buffer pool 2220 2228 12 0.0 221964.2 1.0X +Decompression 10000 times from level 1 with buffer pool 1987 1995 12 0.0 198705.4 1.1X +Decompression 10000 times from level 2 with buffer pool 1966 1968 4 0.0 196572.3 1.1X +Decompression 10000 times from level 3 with buffer pool 1983 1991 11 0.0 198277.7 1.1X diff --git a/core/benchmarks/ZStandardBenchmark-results.txt b/core/benchmarks/ZStandardBenchmark-results.txt index 24b982ad63a5e..5de6d182fa6de 100644 --- a/core/benchmarks/ZStandardBenchmark-results.txt +++ b/core/benchmarks/ZStandardBenchmark-results.txt @@ -2,26 +2,26 @@ Benchmark ZStandardCompressionCodec 
================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1027-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 398 523 144 0.0 39785.2 1.0X -Compression 10000 times at level 2 without buffer pool 452 457 5 0.0 45210.8 0.9X -Compression 10000 times at level 3 without buffer pool 634 650 15 0.0 63405.8 0.6X -Compression 10000 times at level 1 with buffer pool 329 334 4 0.0 32851.3 1.2X -Compression 10000 times at level 2 with buffer pool 384 393 7 0.0 38421.9 1.0X -Compression 10000 times at level 3 with buffer pool 561 570 7 0.0 56070.4 0.7X +Compression 10000 times at level 1 without buffer pool 633 774 122 0.0 63315.3 1.0X +Compression 10000 times at level 2 without buffer pool 748 749 2 0.0 74771.7 0.8X +Compression 10000 times at level 3 without buffer pool 945 949 7 0.0 94461.5 0.7X +Compression 10000 times at level 1 with buffer pool 287 289 2 0.0 28703.6 2.2X +Compression 10000 times at level 2 with buffer pool 336 342 3 0.0 33641.3 1.9X +Compression 10000 times at level 3 with buffer pool 517 528 8 0.0 51747.9 1.2X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1027-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 686 686 0 0.0 68582.6 1.0X -Decompression 10000 times from level 2 without buffer pool 683 686 3 0.0 68270.5 1.0X -Decompression 10000 times from level 3 without buffer pool 687 690 4 0.0 68653.8 1.0X -Decompression 10000 times from level 1 with buffer pool 495 497 3 0.0 49467.7 1.4X -Decompression 10000 times from level 2 with buffer pool 438 467 26 0.0 43839.3 1.6X -Decompression 10000 times from level 3 with buffer pool 495 496 1 0.0 49474.0 1.4X +Decompression 10000 times from level 1 without buffer pool 683 689 9 0.0 68294.8 1.0X +Decompression 10000 times from level 2 without buffer pool 684 685 1 0.0 68441.8 1.0X +Decompression 10000 times from level 3 without buffer pool 684 685 1 0.0 68446.7 1.0X +Decompression 10000 times from level 1 with buffer pool 494 495 2 0.0 49362.5 1.4X +Decompression 10000 times from level 2 with buffer pool 493 495 2 0.0 49330.7 1.4X +Decompression 10000 times from level 3 with buffer pool 494 497 5 0.0 49359.8 1.4X diff --git a/core/pom.xml b/core/pom.xml index 8c9bbd3b8d277..953c76b73469f 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -193,8 +193,8 @@ commons-io - commons-collections - commons-collections + org.apache.commons + commons-collections4 com.google.code.findbugs @@ -250,10 +250,6 @@ org.roaringbitmap RoaringBitmap - - commons-net - commons-net - org.scala-lang.modules scala-xml_${scala.binary.version} @@ -399,16 +395,6 @@ xml-apis test - - org.hamcrest - hamcrest-core - test - - - org.hamcrest - hamcrest-library - test - org.mockito mockito-core @@ -437,7 +423,7 @@ net.sf.py4j py4j - 0.10.9.3 
+ 0.10.9.4 org.apache.spark diff --git a/core/src/main/java/org/apache/hadoop/shaded/net/jpountz/lz4/LZ4Factory.java b/core/src/main/java/org/apache/hadoop/shaded/net/jpountz/lz4/LZ4Factory.java deleted file mode 100644 index 61829b2728bce..0000000000000 --- a/core/src/main/java/org/apache/hadoop/shaded/net/jpountz/lz4/LZ4Factory.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hadoop.shaded.net.jpountz.lz4; - -/** - * TODO(SPARK-36679): A temporary workaround for SPARK-36669. We should remove this after - * Hadoop 3.3.2 release which fixes the LZ4 relocation in shaded Hadoop client libraries. - * This does not need implement all net.jpountz.lz4.LZ4Factory API, just the ones used by - * Hadoop Lz4Compressor. - */ -public final class LZ4Factory { - - private net.jpountz.lz4.LZ4Factory lz4Factory; - - public LZ4Factory(net.jpountz.lz4.LZ4Factory lz4Factory) { - this.lz4Factory = lz4Factory; - } - - public static LZ4Factory fastestInstance() { - return new LZ4Factory(net.jpountz.lz4.LZ4Factory.fastestInstance()); - } - - public LZ4Compressor highCompressor() { - return new LZ4Compressor(lz4Factory.highCompressor()); - } - - public LZ4Compressor fastCompressor() { - return new LZ4Compressor(lz4Factory.fastCompressor()); - } - - public LZ4SafeDecompressor safeDecompressor() { - return new LZ4SafeDecompressor(lz4Factory.safeDecompressor()); - } -} diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index 7ca5ade7b9a74..91910b99ac999 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -49,7 +49,7 @@ public NioBufferedFileInputStream(File file) throws IOException { } /** - * Checks weather data is left to be read from the input stream. + * Checks whether data is left to be read from the input stream. * @return true if data is left, false otherwise */ private boolean refill() throws IOException { diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 2e18715b600e0..011fecb315639 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -302,7 +302,7 @@ public int available() throws IOException { stateChangeLock.lock(); // Make sure we have no integer overflow. 
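The ReadAheadInputStream.available() tweak just below drops an explicit (long) cast on Integer.MAX_VALUE. Binary numeric promotion already widens the int operand because the other argument of Math.min is a long, so behavior is unchanged; a tiny standalone check:

```java
public class AvailableClampDemo {
  public static void main(String[] args) {
    long combined = (long) Integer.MAX_VALUE + 10L;  // larger than any int

    // Integer.MAX_VALUE is widened to long automatically because the other
    // operand of Math.min is a long, so an explicit (long) cast adds nothing.
    int clamped = (int) Math.min(Integer.MAX_VALUE, combined);

    System.out.println(clamped == Integer.MAX_VALUE); // true
  }
}
```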
try { - return (int) Math.min((long) Integer.MAX_VALUE, + return (int) Math.min(Integer.MAX_VALUE, (long) activeBuffer.remaining() + readAheadBuffer.remaining()); } finally { stateChangeLock.unlock(); diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index f474c30b8b3d8..f4f4052b4faf4 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -941,7 +941,7 @@ public long getPeakMemoryUsedBytes() { /** * Returns the average number of probes per key lookup. */ - public double getAvgHashProbeBucketListIterations() { + public double getAvgHashProbesPerKey() { return (1.0 * numProbes) / numKeyLookups; } diff --git a/core/src/main/resources/META-INF/services/org.apache.spark.deploy.history.EventFilterBuilder b/core/src/main/resources/META-INF/services/org.apache.spark.deploy.history.EventFilterBuilder index 784e58270ab42..e349eac3d0d07 100644 --- a/core/src/main/resources/META-INF/services/org.apache.spark.deploy.history.EventFilterBuilder +++ b/core/src/main/resources/META-INF/services/org.apache.spark.deploy.history.EventFilterBuilder @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.deploy.history.BasicEventFilterBuilder \ No newline at end of file diff --git a/core/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider b/core/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider index c1f2060cabcff..3c2e241793d8e 100644 --- a/core/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider +++ b/core/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider @@ -1,2 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
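For context on the two META-INF/services files above, which only gain Apache license headers in this change: they are standard java.util.ServiceLoader registration files, and the ServiceLoader file format treats lines starting with '#' as comments, which is why prepending a header is safe. A generic sketch of the mechanism with hypothetical provider names (Spark's real entries here register EventFilterBuilder and HadoopDelegationTokenProvider implementations):

```java
import java.util.ServiceLoader;

// Hypothetical service interface, standing in for e.g. EventFilterBuilder.
interface GreetingProvider {
  String greet();
}

// A provider implementation. Its fully qualified class name would be listed,
// one per non-comment line, in META-INF/services/<interface FQCN>;
// '#' lines in that file are ignored as comments.
class EnglishGreetingProvider implements GreetingProvider {
  @Override
  public String greet() { return "hello"; }
}

public class ServiceLoaderDemo {
  public static void main(String[] args) {
    // ServiceLoader scans META-INF/services entries on the classpath and
    // instantiates every registered implementation; without the registration
    // file this loop simply finds nothing.
    for (GreetingProvider p : ServiceLoader.load(GreetingProvider.class)) {
      System.out.println(p.greet());
    }
  }
}
```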
+# + org.apache.spark.deploy.security.HadoopFSDelegationTokenProvider org.apache.spark.deploy.security.HBaseDelegationTokenProvider diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 34ccc63ff072c..c7a9c854cb486 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -1,7 +1,4 @@ { - "AES_CRYPTO_ERROR" : { - "message" : [ "AES crypto operation failed with: %s" ] - }, "AMBIGUOUS_FIELD_NAME" : { "message" : [ "Field name %s is ambiguous and has %s matching fields in the struct." ], "sqlState" : "42000" @@ -18,6 +15,12 @@ "message" : [ "Cannot parse decimal" ], "sqlState" : "42000" }, + "CANNOT_UP_CAST_DATATYPE" : { + "message" : [ "Cannot up cast %s from %s to %s.\n%s" ] + }, + "CANNOT_USE_MIXTURE" : { + "message" : [ "Cannot use a mixture of aggregate function and group aggregate pandas UDF" ] + }, "CAST_CAUSES_OVERFLOW" : { "message" : [ "Casting %s to %s causes overflow. To return NULL instead, use 'try_cast'. If necessary set %s to false to bypass this error." ], "sqlState" : "22005" @@ -25,6 +28,10 @@ "CONCURRENT_QUERY" : { "message" : [ "Another instance of this query was just started by a concurrent session." ] }, + "DATETIME_OVERFLOW" : { + "message" : [ "Datetime operation overflow: %s." ], + "sqlState" : "22008" + }, "DIVIDE_BY_ZERO" : { "message" : [ "divide by zero. To return NULL instead, use 'try_divide'. If necessary set %s to false (except for ANSI interval type) to bypass this error." ], "sqlState" : "22012" @@ -43,6 +50,12 @@ "FAILED_SET_ORIGINAL_PERMISSION_BACK" : { "message" : [ "Failed to set original permission %s back to the created path: %s. Exception: %s" ] }, + "GRAPHITE_SINK_INVALID_PROTOCOL" : { + "message" : [ "Invalid Graphite protocol: %s" ] + }, + "GRAPHITE_SINK_PROPERTY_MISSING" : { + "message" : [ "Graphite sink requires '%s' property." ] + }, "GROUPING_COLUMN_MISMATCH" : { "message" : [ "Column of grouping (%s) can't be found in grouping columns %s" ], "sqlState" : "42000" @@ -54,9 +67,6 @@ "GROUPING_SIZE_LIMIT_EXCEEDED" : { "message" : [ "Grouping sets size cannot be greater than %s" ] }, - "IF_PARTITION_NOT_EXISTS_UNSUPPORTED" : { - "message" : [ "Cannot write, IF NOT EXISTS is not supported for table: %s" ] - }, "ILLEGAL_SUBSTRING" : { "message" : [ "%s cannot contain %s." ] }, @@ -67,6 +77,9 @@ "INCOMPATIBLE_DATASOURCE_REGISTER" : { "message" : [ "Detected an incompatible DataSourceRegister. Please remove the incompatible library from classpath or upgrade it. Error: %s" ] }, + "INCONSISTENT_BEHAVIOR_CROSS_VERSION" : { + "message" : [ "You may get a different result due to the upgrading to Spark >= %s: %s" ] + }, "INDEX_OUT_OF_BOUNDS" : { "message" : [ "Index %s must be between 0 and the length of the ArrayData." ], "sqlState" : "22023" @@ -74,10 +87,6 @@ "INTERNAL_ERROR" : { "message" : [ "%s" ] }, - "INVALID_AES_KEY_LENGTH" : { - "message" : [ "The key length of aes_encrypt/aes_decrypt should be one of 16, 24 or 32 bytes, but got: %s" ], - "sqlState" : "42000" - }, "INVALID_ARRAY_INDEX" : { "message" : [ "Invalid index: %s, numElements: %s. If necessary set %s to false to bypass this error." ] }, @@ -99,6 +108,14 @@ "INVALID_JSON_SCHEMA_MAPTYPE" : { "message" : [ "Input schema %s can only contain StringType as a key type for a MapType." 
] }, + "INVALID_PARAMETER_VALUE" : { + "message" : [ "The value of parameter(s) '%s' in %s is invalid: %s" ], + "sqlState" : "22023" + }, + "INVALID_SQL_SYNTAX" : { + "message" : [ "Invalid SQL syntax: %s" ], + "sqlState" : "42000" + }, "MAP_KEY_DOES_NOT_EXIST" : { "message" : [ "Key %s does not exist. If necessary set %s to false to bypass this error." ] }, @@ -121,6 +138,14 @@ "message" : [ "PARTITION clause cannot contain a non-partition column name: %s" ], "sqlState" : "42000" }, + "PARSE_EMPTY_STATEMENT" : { + "message" : [ "Syntax error, unexpected empty statement" ], + "sqlState" : "42000" + }, + "PARSE_INPUT_MISMATCHED" : { + "message" : [ "Syntax error at or near %s" ], + "sqlState" : "42000" + }, "PIVOT_VALUE_DATA_TYPE_MISMATCH" : { "message" : [ "Invalid pivot value '%s': value data type %s does not match pivot column data type %s" ], "sqlState" : "42000" @@ -129,10 +154,6 @@ "message" : [ "Failed to rename as %s was not found" ], "sqlState" : "22023" }, - "ROW_FROM_CSV_PARSER_NOT_EXPECTED" : { - "message" : [ "Expected one row from CSV parser." ], - "sqlState" : "42000" - }, "SECOND_FUNCTION_ARGUMENT_NOT_INTEGER" : { "message" : [ "The second argument of '%s' function needs to be an integer." ], "sqlState" : "22023" @@ -144,28 +165,19 @@ "message" : [ "Unrecognized SQL type %s" ], "sqlState" : "42000" }, - "UNSUPPORTED_AES_MODE" : { - "message" : [ "The AES mode %s with the padding %s is not supported" ], - "sqlState" : "0A000" - }, - "UNSUPPORTED_CHANGE_COLUMN" : { - "message" : [ "Please add an implementation for a column change here" ], - "sqlState" : "0A000" - }, "UNSUPPORTED_DATATYPE" : { "message" : [ "Unsupported data type %s" ], "sqlState" : "0A000" }, - "UNSUPPORTED_LITERAL_TYPE" : { - "message" : [ "Unsupported literal type %s %s" ], + "UNSUPPORTED_FEATURE" : { + "message" : [ "The feature is not supported: %s" ], "sqlState" : "0A000" }, - "UNSUPPORTED_SIMPLE_STRING_WITH_NODE_ID" : { - "message" : [ "%s does not implement simpleStringWithNodeId" ] + "UNSUPPORTED_GROUPING_EXPRESSION" : { + "message" : [ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup" ] }, - "UNSUPPORTED_TRANSACTION_BY_JDBC_SERVER" : { - "message" : [ "The target JDBC server does not support transaction and can only support ALTER TABLE with a single action." 
], - "sqlState" : "0A000" + "UNSUPPORTED_OPERATION" : { + "message" : [ "The operation is not supported: %s" ] }, "WRITING_JOB_ABORTED" : { "message" : [ "Writing job aborted" ], diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 8e348eefef6c2..fbb92b4b4e293 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -104,15 +104,17 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( private[this] val numPartitions = rdd.partitions.length - // By default, shuffle merge is enabled for ShuffleDependency if push based shuffle + // By default, shuffle merge is allowed for ShuffleDependency if push based shuffle // is enabled - private[this] var _shuffleMergeEnabled = canShuffleMergeBeEnabled() + private[this] var _shuffleMergeAllowed = canShuffleMergeBeEnabled() - private[spark] def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean): Unit = { - _shuffleMergeEnabled = shuffleMergeEnabled + private[spark] def setShuffleMergeAllowed(shuffleMergeAllowed: Boolean): Unit = { + _shuffleMergeAllowed = shuffleMergeAllowed } - def shuffleMergeEnabled : Boolean = _shuffleMergeEnabled + def shuffleMergeEnabled : Boolean = shuffleMergeAllowed && mergerLocs.nonEmpty + + def shuffleMergeAllowed : Boolean = _shuffleMergeAllowed /** * Stores the location of the list of chosen external shuffle services for handling the @@ -124,7 +126,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( * Stores the information about whether the shuffle merge is finalized for the shuffle map stage * associated with this shuffle dependency */ - private[this] var _shuffleMergedFinalized: Boolean = false + private[this] var _shuffleMergeFinalized: Boolean = false /** * shuffleMergeId is used to uniquely identify merging process of shuffle @@ -135,31 +137,34 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( def shuffleMergeId: Int = _shuffleMergeId def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = { + assert(shuffleMergeAllowed) this.mergerLocs = mergerLocs } def getMergerLocs: Seq[BlockManagerId] = mergerLocs private[spark] def markShuffleMergeFinalized(): Unit = { - _shuffleMergedFinalized = true + _shuffleMergeFinalized = true + } + + private[spark] def isShuffleMergeFinalizedMarked: Boolean = { + _shuffleMergeFinalized } /** - * Returns true if push-based shuffle is disabled for this stage or empty RDD, - * or if the shuffle merge for this stage is finalized, i.e. the shuffle merge - * results for all partitions are available. + * Returns true if push-based shuffle is disabled or if the shuffle merge for + * this shuffle is finalized. */ def shuffleMergeFinalized: Boolean = { - // Empty RDD won't be computed therefore shuffle merge finalized should be true by default. 
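The Dependency.scala hunk around this point separates "shuffle merge allowed" (a property decided once for the dependency) from "shuffle merge enabled" (allowed and merger locations actually assigned), and finalization now keys off the latter. A conceptual Java model of how the three flags relate; this is an illustration only, not Spark's API (the real class is the Scala ShuffleDependency):

```java
import java.util.ArrayList;
import java.util.List;

// Toy model of the ShuffleDependency merge flags after this change.
class ShuffleMergeStateModel {
  private final boolean shuffleMergeAllowed;     // decided from config and the RDD
  private final List<String> mergerLocs = new ArrayList<>(); // chosen shuffle services
  private boolean shuffleMergeFinalizedMarked;   // set once merge results are in

  ShuffleMergeStateModel(boolean allowed) { this.shuffleMergeAllowed = allowed; }

  void setMergerLocs(List<String> locs) {
    assert shuffleMergeAllowed;                  // mirrors the new assert in setMergerLocs
    mergerLocs.addAll(locs);
  }

  // "Enabled" now means: allowed AND merger locations were actually assigned.
  boolean shuffleMergeEnabled() {
    return shuffleMergeAllowed && !mergerLocs.isEmpty();
  }

  void markShuffleMergeFinalized() { shuffleMergeFinalizedMarked = true; }

  // Finalized: trivially true when merge is not enabled, otherwise only
  // after it has been explicitly marked.
  boolean shuffleMergeFinalized() {
    return !shuffleMergeEnabled() || shuffleMergeFinalizedMarked;
  }
}
```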
- if (shuffleMergeEnabled && numPartitions > 0) { - _shuffleMergedFinalized + if (shuffleMergeEnabled) { + isShuffleMergeFinalizedMarked } else { true } } def newShuffleMergeState(): Unit = { - _shuffleMergedFinalized = false + _shuffleMergeFinalized = false mergerLocs = Nil _shuffleMergeId += 1 finalizeTask = None @@ -187,7 +192,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( * @param mapIndex Map task index * @return number of map tasks with block push completed */ - def incPushCompleted(mapIndex: Int): Int = { + private[spark] def incPushCompleted(mapIndex: Int): Int = { shufflePushCompleted.add(mapIndex) shufflePushCompleted.getCardinality } @@ -195,9 +200,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( // Only used by DAGScheduler to coordinate shuffle merge finalization @transient private[this] var finalizeTask: Option[ScheduledFuture[_]] = None - def getFinalizeTask: Option[ScheduledFuture[_]] = finalizeTask + private[spark] def getFinalizeTask: Option[ScheduledFuture[_]] = finalizeTask - def setFinalizeTask(task: ScheduledFuture[_]): Unit = { + private[spark] def setFinalizeTask(task: ScheduledFuture[_]): Unit = { finalizeTask = Option(task) } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index d71fb09682924..e6ed469250b47 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -144,6 +144,8 @@ private class ShuffleStatus( */ private[this] var _numAvailableMergeResults: Int = 0 + private[this] var shufflePushMergerLocations: Seq[BlockManagerId] = Seq.empty + /** * Register a map output. If there is already a registered location for the map output then it * will be replaced by the new location. @@ -213,6 +215,16 @@ private class ShuffleStatus( mergeStatuses(reduceId) = status } + def registerShuffleMergerLocations(shuffleMergers: Seq[BlockManagerId]): Unit = withWriteLock { + if (shufflePushMergerLocations.isEmpty) { + shufflePushMergerLocations = shuffleMergers + } + } + + def removeShuffleMergerLocations(): Unit = withWriteLock { + shufflePushMergerLocations = Nil + } + // TODO support updateMergeResult for similar use cases as updateMapOutput /** @@ -392,6 +404,10 @@ private class ShuffleStatus( f(mergeStatuses) } + def getShufflePushMergerLocations: Seq[BlockManagerId] = withReadLock { + shufflePushMergerLocations + } + /** * Clears the cached serialized map output statuses. 
*/ @@ -429,6 +445,8 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case class GetMapAndMergeResultStatuses(shuffleId: Int) extends MapOutputTrackerMessage +private[spark] case class GetShufflePushMergerLocations(shuffleId: Int) + extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage private[spark] sealed trait MapOutputTrackerMasterMessage @@ -436,6 +454,8 @@ private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext) extends MapOutputTrackerMasterMessage private[spark] case class GetMapAndMergeOutputMessage(shuffleId: Int, context: RpcCallContext) extends MapOutputTrackerMasterMessage +private[spark] case class GetShufflePushMergersMessage(shuffleId: Int, + context: RpcCallContext) extends MapOutputTrackerMasterMessage private[spark] case class MapSizesByExecutorId( iter: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], enableBatchFetch: Boolean) @@ -457,6 +477,11 @@ private[spark] class MapOutputTrackerMasterEndpoint( logInfo(s"Asked to send map/merge result locations for shuffle $shuffleId to $hostPort") tracker.post(GetMapAndMergeOutputMessage(shuffleId, context)) + case GetShufflePushMergerLocations(shuffleId: Int) => + logInfo(s"Asked to send shuffle push merger locations for shuffle" + + s" $shuffleId to ${context.senderAddress.hostPort}") + tracker.post(GetShufflePushMergersMessage(shuffleId, context)) + case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") context.reply(true) @@ -596,6 +621,16 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging partitionId: Int, chunkBitmap: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** + * Called from executors whenever a task with push based shuffle is enabled doesn't have shuffle + * mergers available. This typically happens when the initial stages doesn't have enough shuffle + * mergers available since very few executors got registered. This is on a best effort basis, + * if there is not enough shuffle mergers available for this stage then an empty sequence would + * be returned indicating the task to avoid shuffle push. + * @param shuffleId + */ + def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] + /** * Deletes map output status information for the specified shuffle stage. 
*/ @@ -711,6 +746,11 @@ private[spark] class MapOutputTrackerMaster( handleStatusMessage(shuffleId, context, false) case GetMapAndMergeOutputMessage(shuffleId, context) => handleStatusMessage(shuffleId, context, true) + case GetShufflePushMergersMessage(shuffleId, context) => + logDebug(s"Handling request to send shuffle push merger locations for shuffle" + + s" $shuffleId to ${context.senderAddress.hostPort}") + context.reply(shuffleStatuses.get(shuffleId).map(_.getShufflePushMergerLocations) + .getOrElse(Seq.empty[BlockManagerId])) } } catch { case NonFatal(e) => logError(e.getMessage, e) @@ -772,6 +812,7 @@ private[spark] class MapOutputTrackerMaster( case Some(shuffleStatus) => shuffleStatus.removeOutputsByFilter(x => true) shuffleStatus.removeMergeResultsByFilter(x => true) + shuffleStatus.removeShuffleMergerLocations() incrementEpoch() case None => throw new SparkException( @@ -789,6 +830,12 @@ private[spark] class MapOutputTrackerMaster( } } + def registerShufflePushMergerLocations( + shuffleId: Int, + shuffleMergers: Seq[BlockManagerId]): Unit = { + shuffleStatuses(shuffleId).registerShuffleMergerLocations(shuffleMergers) + } + /** * Unregisters a merge result corresponding to the reduceId if present. If the optional mapIndex * is specified, it will only unregister the merge result if the mapIndex is part of that merge @@ -1130,7 +1177,7 @@ private[spark] class MapOutputTrackerMaster( override def getMapSizesForMergeResult( shuffleId: Int, partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - Seq.empty.toIterator + Seq.empty.iterator } // This method is only called in local-mode. Since push based shuffle won't be @@ -1139,7 +1186,12 @@ private[spark] class MapOutputTrackerMaster( shuffleId: Int, partitionId: Int, chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - Seq.empty.toIterator + Seq.empty.iterator + } + + // This method is only called in local-mode. + override def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] = { + shuffleStatuses(shuffleId).getShufflePushMergerLocations } override def stop(): Unit = { @@ -1176,6 +1228,14 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr // instantiate a serializer. See the followup to SPARK-36705 for more details. private lazy val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf, isDriver = false) + /** + * [[shufflePushMergerLocations]] tracks shuffle push merger locations for the latest + * shuffle execution + * + * Exposed for testing + */ + val shufflePushMergerLocations = new ConcurrentHashMap[Int, Seq[BlockManagerId]]().asScala + /** * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching * the same shuffle block. @@ -1213,10 +1273,10 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr useMergeResult: Boolean): MapSizesByExecutorId = { logDebug(s"Fetching outputs for shuffle $shuffleId") val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf, - // EnableBatchFetch can be set to false during stage retry when the - // shuffleDependency.shuffleMergeEnabled is set to false, and Driver + // enableBatchFetch can be set to false during stage retry when the + // shuffleDependency.isShuffleMergeFinalizedMarked is set to false, and Driver // has already collected the mergedStatus for its shuffle dependency. 
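The MapOutputTrackerWorker additions just after this point cache shuffle-push merger locations per shuffle and only ask the driver when nothing is cached, deliberately caching only non-empty answers so that a later attempt can retry once enough mergers have registered. A generic Java sketch of that fetch-and-cache-if-non-empty pattern; names and types are illustrative, not Spark's:

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.IntFunction;

class MergerLocationCache {
  private final Map<Integer, List<String>> cache = new ConcurrentHashMap<>();
  private final IntFunction<List<String>> askDriver; // stands in for the RPC to the master

  MergerLocationCache(IntFunction<List<String>> askDriver) {
    this.askDriver = askDriver;
  }

  List<String> get(int shuffleId) {
    List<String> cached = cache.get(shuffleId);
    if (cached != null) {
      return cached;
    }
    // The real code serializes concurrent fetches of the same shuffle id with a
    // KeyLock; computeIfAbsent is avoided here because an empty answer must not
    // be cached -- the next caller should ask the driver again.
    List<String> fetched = askDriver.apply(shuffleId);
    if (!fetched.isEmpty()) {
      cache.put(shuffleId, fetched);
    }
    return fetched;
  }

  void unregister(int shuffleId) {
    cache.remove(shuffleId);
  }
}
```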
- // In this case, boolean check helps to insure that the unnecessary + // In this case, boolean check helps to ensure that the unnecessary // mergeStatus won't be fetched, thus mergedOutputStatuses won't be // passed to convertMapStatuses. See details in [SPARK-37023]. if (useMergeResult) fetchMergeResult else false) @@ -1281,6 +1341,26 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } } + override def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] = { + shufflePushMergerLocations.getOrElse(shuffleId, getMergerLocations(shuffleId)) + } + + private def getMergerLocations(shuffleId: Int): Seq[BlockManagerId] = { + fetchingLock.withLock(shuffleId) { + var fetchedMergers = shufflePushMergerLocations.get(shuffleId).orNull + if (null == fetchedMergers) { + fetchedMergers = + askTracker[Seq[BlockManagerId]](GetShufflePushMergerLocations(shuffleId)) + if (fetchedMergers.nonEmpty) { + shufflePushMergerLocations(shuffleId) = fetchedMergers + } else { + fetchedMergers = Seq.empty[BlockManagerId] + } + } + fetchedMergers + } + } + /** * Get or fetch the array of MapStatuses and MergeStatuses if push based shuffle enabled * for a given shuffle ID. NOTE: clients MUST synchronize @@ -1364,6 +1444,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr def unregisterShuffle(shuffleId: Int): Unit = { mapStatuses.remove(shuffleId) mergeStatuses.remove(shuffleId) + shufflePushMergerLocations.remove(shuffleId) } /** @@ -1378,6 +1459,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr epoch = newEpoch mapStatuses.clear() mergeStatuses.clear() + shufflePushMergerLocations.clear() } } } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index d061627bea69c..f11176cc23310 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -324,7 +324,7 @@ private[spark] class SecurityManager( case "yarn" | "local" | LOCAL_N_REGEX(_) | LOCAL_N_FAILURES_REGEX(_, _) => true - case k8sRegex() => + case KUBERNETES_REGEX(_) => // Don't propagate the secret through the user's credentials in kubernetes. That conflicts // with the way k8s handles propagation of delegation tokens. false @@ -354,7 +354,7 @@ private[spark] class SecurityManager( private def secretKeyFromFile(): Option[String] = { sparkConf.get(authSecretFileConf).flatMap { secretFilePath => sparkConf.getOption(SparkLauncher.SPARK_MASTER).map { - case k8sRegex() => + case SparkMasterRegex.KUBERNETES_REGEX(_) => val secretFile = new File(secretFilePath) require(secretFile.isFile, s"No file found containing the secret key at $secretFilePath.") val base64Key = Base64.getEncoder.encodeToString(Files.readAllBytes(secretFile.toPath)) @@ -391,7 +391,6 @@ private[spark] class SecurityManager( private[spark] object SecurityManager { - val k8sRegex = "k8s.*".r val SPARK_AUTH_CONF = NETWORK_AUTH_ENABLED.key val SPARK_AUTH_SECRET_CONF = AUTH_SECRET.key // This is used to set auth secret to an executor's env variable. 
It should have the same diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 5f37a1abb1909..cf121749b7348 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -636,7 +636,9 @@ private[spark] object SparkConf extends Logging { DeprecatedConfig("spark.blacklist.killBlacklistedExecutors", "3.1.0", "Please use spark.excludeOnFailure.killExcludedExecutors"), DeprecatedConfig("spark.yarn.blacklist.executor.launch.blacklisting.enabled", "3.1.0", - "Please use spark.yarn.executor.launch.excludeOnFailure.enabled") + "Please use spark.yarn.executor.launch.excludeOnFailure.enabled"), + DeprecatedConfig("spark.kubernetes.memoryOverheadFactor", "3.3.0", + "Please use spark.driver.memoryOverheadFactor and spark.executor.memoryOverheadFactor") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 86bf7255ee1e0..02c58d2a9b4f2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -560,7 +560,7 @@ class SparkContext(config: SparkConf) extends Logging { _plugins = PluginContainer(this, _resources.asJava) // Create and start the scheduler - val (sched, ts) = SparkContext.createTaskScheduler(this, master, deployMode) + val (sched, ts) = SparkContext.createTaskScheduler(this, master) _schedulerBackend = sched _taskScheduler = ts _dagScheduler = new DAGScheduler(this) @@ -2890,8 +2890,7 @@ object SparkContext extends Logging { */ private def createTaskScheduler( sc: SparkContext, - master: String, - deployMode: String): (SchedulerBackend, TaskScheduler) = { + master: String): (SchedulerBackend, TaskScheduler) = { import SparkMasterRegex._ // When running locally, don't try to re-execute tasks on failure. diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index aea09e36ade74..8442c8eb8d35d 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -71,9 +71,22 @@ private[spark] case class ExecutorDeadException(message: String) /** * Exception thrown when Spark returns different result after upgrading to a new version. */ -private[spark] class SparkUpgradeException(version: String, message: String, cause: Throwable) - extends RuntimeException("You may get a different result due to the upgrading of Spark" + - s" $version: $message", cause) +private[spark] class SparkUpgradeException( + errorClass: String, + messageParameters: Array[String], + cause: Throwable) + extends RuntimeException(SparkThrowableHelper.getMessage(errorClass, messageParameters), cause) + with SparkThrowable { + + def this(version: String, message: String, cause: Throwable) = + this ( + errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION", + messageParameters = Array(version, message), + cause = cause + ) + + override def getErrorClass: String = errorClass +} /** * Arithmetic exception thrown from Spark with an error class. 
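The SparkUpgradeException rewrite above follows the error-class pattern: the exception carries an error class plus message parameters and renders its text through a shared helper, while keeping the old (version, message, cause) constructor for existing call sites. A compact sketch of that shape follows; the hard-coded template map stands in for SparkThrowableHelper and error-classes.json, and all names here are illustrative.

// Minimal stand-in for the error-class machinery; the real message text comes from
// SparkThrowableHelper, not from this hard-coded map.
object ErrorClassHelper {
  private val templates = Map(
    "INCONSISTENT_BEHAVIOR_CROSS_VERSION" ->
      "You may get a different result due to the upgrading of Spark %s: %s")

  def getMessage(errorClass: String, parameters: Array[String]): String =
    templates(errorClass).format(parameters: _*)
}

class UpgradeException(val errorClass: String, parameters: Array[String], cause: Throwable)
  extends RuntimeException(ErrorClassHelper.getMessage(errorClass, parameters), cause) {

  // Keep the old (version, message, cause) constructor so existing call sites still compile.
  def this(version: String, message: String, cause: Throwable) =
    this("INCONSISTENT_BEHAVIOR_CROSS_VERSION", Array(version, message), cause)
}

object UpgradeExceptionDemo extends App {
  println(new UpgradeException("3.0", "datetime rebase mode changed", null).getMessage)
}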
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 9bc6ccbd0df65..104e98b8ae0a4 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -24,7 +24,7 @@ import java.nio.file.{Files => JavaFiles, Paths} import java.nio.file.attribute.PosixFilePermission.{OWNER_EXECUTE, OWNER_READ, OWNER_WRITE} import java.security.SecureRandom import java.security.cert.X509Certificate -import java.util.{Arrays, EnumSet, Locale, Properties} +import java.util.{Arrays, EnumSet, Locale} import java.util.concurrent.{TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream, Manifest} import java.util.regex.Pattern @@ -41,9 +41,10 @@ import scala.util.Try import com.google.common.io.{ByteStreams, Files} import org.apache.commons.lang3.StringUtils -// scalastyle:off -import org.apache.log4j.PropertyConfigurator -// scalastyle:on +import org.apache.logging.log4j.LogManager +import org.apache.logging.log4j.core.LoggerContext +import org.apache.logging.log4j.core.appender.ConsoleAppender +import org.apache.logging.log4j.core.config.builder.api.ConfigurationBuilderFactory import org.eclipse.jetty.server.Handler import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.handler.DefaultHandler @@ -262,7 +263,8 @@ private[spark] object TestUtils { contains = contain(e, msg) } assert(contains, - s"Exception tree doesn't contain the expected exception ${typeMsg}with message: $msg") + s"Exception tree doesn't contain the expected exception ${typeMsg}with message: $msg\n" + + Utils.exceptionString(e)) } /** @@ -336,22 +338,26 @@ private[spark] object TestUtils { connection.setRequestMethod(method) headers.foreach { case (k, v) => connection.setRequestProperty(k, v) } - // Disable cert and host name validation for HTTPS tests. - if (connection.isInstanceOf[HttpsURLConnection]) { - val sslCtx = SSLContext.getInstance("SSL") - val trustManager = new X509TrustManager { - override def getAcceptedIssuers(): Array[X509Certificate] = null - override def checkClientTrusted(x509Certificates: Array[X509Certificate], - s: String): Unit = {} - override def checkServerTrusted(x509Certificates: Array[X509Certificate], - s: String): Unit = {} - } - val verifier = new HostnameVerifier() { - override def verify(hostname: String, session: SSLSession): Boolean = true - } - sslCtx.init(null, Array(trustManager), new SecureRandom()) - connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory()) - connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier) + connection match { + // Disable cert and host name validation for HTTPS tests. 
+ case httpConnection: HttpsURLConnection => + val sslCtx = SSLContext.getInstance("SSL") + val trustManager = new X509TrustManager { + override def getAcceptedIssuers: Array[X509Certificate] = null + + override def checkClientTrusted(x509Certificates: Array[X509Certificate], + s: String): Unit = {} + + override def checkServerTrusted(x509Certificates: Array[X509Certificate], + s: String): Unit = {} + } + val verifier = new HostnameVerifier() { + override def verify(hostname: String, session: SSLSession): Boolean = true + } + sslCtx.init(null, Array(trustManager), new SecureRandom()) + httpConnection.setSSLSocketFactory(sslCtx.getSocketFactory) + httpConnection.setHostnameVerifier(verifier) + case _ => // do nothing } try { @@ -418,17 +424,18 @@ private[spark] object TestUtils { } /** - * config a log4j properties used for testsuite + * config a log4j2 properties used for testsuite */ - def configTestLog4j(level: String): Unit = { - val pro = new Properties() - pro.put("log4j.rootLogger", s"$level, console") - pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") - pro.put("log4j.appender.console.target", "System.err") - pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") - pro.put("log4j.appender.console.layout.ConversionPattern", - "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") - PropertyConfigurator.configure(pro) + def configTestLog4j2(level: String): Unit = { + val builder = ConfigurationBuilderFactory.newConfigurationBuilder() + val appenderBuilder = builder.newAppender("console", "CONSOLE") + .addAttribute("target", ConsoleAppender.Target.SYSTEM_ERR) + appenderBuilder.add(builder.newLayout("PatternLayout") + .addAttribute("pattern", "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n")) + builder.add(appenderBuilder) + builder.add(builder.newRootLogger(level).add(builder.newAppenderRef("console"))) + val configuration = builder.build() + LogManager.getContext(false).asInstanceOf[LoggerContext].reconfigure(configuration) } /** @@ -440,6 +447,21 @@ private[spark] object TestUtils { current ++ current.filter(_.isDirectory).flatMap(recursiveList) } + /** + * Returns the list of files at 'path' recursively. This skips files that are ignored normally + * by MapReduce. + */ + def listDirectory(path: File): Array[String] = { + val result = ArrayBuffer.empty[String] + if (path.isDirectory) { + path.listFiles.foreach(f => result.appendAll(listDirectory(f))) + } else { + val c = path.getName.charAt(0) + if (c != '.' && c != '_') result.append(path.getAbsolutePath) + } + result.toArray + } + /** Creates a temp JSON file that contains the input JSON record. 
*/ def createTempJsonFile(dir: File, prefix: String, jsonValue: JValue): String = { val file = File.createTempFile(prefix, ".json", dir) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 6d4dc3d3dfe92..6dc9e71a00848 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -245,7 +245,7 @@ private[spark] object PythonRDD extends Logging { out.writeInt(1) // Write the next object and signal end of data for this iteration - writeIteratorToStream(partitionArray.toIterator, out) + writeIteratorToStream(partitionArray.iterator, out) out.writeInt(SpecialLengths.END_OF_DATA_SECTION) out.flush() } else { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 8daba86758412..a9c353691b466 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} private[spark] object PythonUtils { - val PY4J_ZIP_NAME = "py4j-0.10.9.3-src.zip" + val PY4J_ZIP_NAME = "py4j-0.10.9.4-src.zip" /** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */ def sparkPythonPath: String = { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 917203831404f..f9f8c56eb86c4 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Time, Timestamp} import scala.collection.JavaConverters._ -import scala.collection.mutable.WrappedArray +import scala.collection.mutable /** * Utility functions to serialize, deserialize objects to / from R @@ -303,12 +303,10 @@ private[spark] object SerDe { // Convert ArrayType collected from DataFrame to Java array // Collected data of ArrayType from a DataFrame is observed to be of // type "scala.collection.mutable.WrappedArray" - val value = - if (obj.isInstanceOf[WrappedArray[_]]) { - obj.asInstanceOf[WrappedArray[_]].toArray - } else { - obj - } + val value = obj match { + case wa: mutable.WrappedArray[_] => wa.array + case other => other + } value match { case v: java.lang.Character => diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 989a1941d1791..b6f59c36081f5 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -22,7 +22,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag -import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap} +import org.apache.commons.collections4.map.AbstractReferenceMap.ReferenceStrength +import org.apache.commons.collections4.map.ReferenceMap import org.apache.spark.SparkConf import org.apache.spark.api.python.PythonBroadcast @@ -55,7 +56,7 @@ private[spark] class BroadcastManager( private[broadcast] val cachedValues = Collections.synchronizedMap( - new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK) + new 
ReferenceMap(ReferenceStrength.HARD, ReferenceStrength.WEAK) .asInstanceOf[java.util.Map[Any, Any]] ) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 5915fb8cc7c84..7209e2c373ab1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -24,7 +24,9 @@ import java.util.concurrent.TimeoutException import scala.collection.mutable.ListBuffer import scala.concurrent.{Future, Promise} +// scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global +// scalastyle:on executioncontextglobal import scala.concurrent.duration._ import scala.sys.process._ diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index d3c5f0eaf0341..dab1474725d9e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -908,7 +908,7 @@ private[spark] class SparkSubmit extends Logging { logInfo(s"Main class:\n$childMainClass") logInfo(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing - logInfo(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") + logInfo(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).sorted.mkString("\n")}") logInfo(s"Classpath elements:\n${childClasspath.mkString("\n")}") logInfo("\n") } @@ -1456,9 +1456,9 @@ private[spark] object SparkSubmitUtils extends Logging { throw new RuntimeException(rr.getAllProblemMessages.toString) } // retrieve all resolved dependencies + retrieveOptions.setDestArtifactPattern(packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision](-[classifier]).[ext]") ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + - "[organization]_[artifact]-[revision](-[classifier]).[ext]", retrieveOptions.setConfs(Array(ivyConfName))) resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) } finally { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 47fbab52659a4..9a5123f218a63 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -327,7 +327,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | |Spark properties used, including those specified through | --conf and those from the properties file $propertiesFile: - |${Utils.redact(sparkProperties).mkString(" ", "\n ", "\n")} + |${Utils.redact(sparkProperties).sorted.mkString(" ", "\n ", "\n")} """.stripMargin } diff --git a/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala b/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala index c7c31a85b0636..641c5416cbb33 100644 --- a/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala @@ -99,7 +99,7 @@ private[spark] object StandaloneResourceUtils extends Logging { ResourceAllocation(new ResourceID(componentName, rName), rInfo.addresses) }.toSeq try { - writeResourceAllocationJson(componentName, 
allocations, tmpFile) + writeResourceAllocationJson(allocations, tmpFile) } catch { case NonFatal(e) => val errMsg = s"Exception threw while preparing resource file for $compShortName" @@ -112,7 +112,6 @@ private[spark] object StandaloneResourceUtils extends Logging { } private def writeResourceAllocationJson[T]( - componentName: String, allocations: Seq[T], jsonFile: File): Unit = { implicit val formats = DefaultFormats diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 55f648a4a05c8..a9adaed374af1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -26,7 +26,7 @@ import java.util.zip.ZipOutputStream import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.io.Source +import scala.io.{Codec, Source} import scala.util.control.NonFatal import scala.xml.Node @@ -144,11 +144,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val dbPath = Files.createDirectories(new File(path, dir).toPath()).toFile() Utils.chmod700(dbPath) - val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, - AppStatusStore.CURRENT_VERSION, logDir.toString()) + val metadata = FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, + AppStatusStore.CURRENT_VERSION, logDir) try { - open(dbPath, metadata) + open(dbPath, metadata, conf) } catch { // If there's an error, remove the listing database and any existing UI database // from the store directory, since it's extremely likely that they'll all contain @@ -156,12 +156,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) case _: UnsupportedStoreVersionException | _: MetadataMismatchException => logInfo("Detected incompatible DB versions, deleting...") path.listFiles().foreach(Utils.deleteRecursively) - open(dbPath, metadata) + open(dbPath, metadata, conf) case dbExc @ (_: NativeDB.DBException | _: RocksDBException) => // Get rid of the corrupted data and re-create it. logWarning(s"Failed to load disk store $dbPath :", dbExc) Utils.deleteRecursively(dbPath) - open(dbPath, metadata) + open(dbPath, metadata, conf) } }.getOrElse(new InMemoryStore()) @@ -414,7 +414,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } else { Map() } - Map("Event log directory" -> logDir.toString) ++ safeMode + Map("Event log directory" -> logDir) ++ safeMode } override def start(): Unit = { @@ -819,7 +819,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - val source = Source.fromInputStream(in).getLines() + val source = Source.fromInputStream(in)(Codec.UTF8).getLines() // Because skipping may leave the stream in the middle of a line, read the next line // before replaying. @@ -1218,7 +1218,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // the existing data. 
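The switch above to Source.fromInputStream(in)(Codec.UTF8) pins the charset when replaying event logs instead of relying on the JVM's platform default. A small standalone sketch of the same idea, with a synthetic byte stream in place of a real event log:

import java.io.ByteArrayInputStream
import java.nio.charset.StandardCharsets
import scala.io.{Codec, Source}

object Utf8ReadSketch extends App {
  val bytes = "SparkListenerLogStart \u00e9v\u00e9nement".getBytes(StandardCharsets.UTF_8)
  val in = new ByteArrayInputStream(bytes)
  // Passing Codec.UTF8 explicitly avoids mojibake when the default charset is not UTF-8.
  val lines = Source.fromInputStream(in)(Codec.UTF8).getLines().toList
  lines.foreach(println)
}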
dm.openStore(appId, attempt.info.attemptId).foreach { path => try { - return KVUtils.open(path, metadata) + return KVUtils.open(path, metadata, conf) } catch { case e: Exception => logInfo(s"Failed to open existing store for $appId/${attempt.info.attemptId}.", e) @@ -1284,14 +1284,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { logInfo(s"Leasing disk manager space for app $appId / ${attempt.info.attemptId}...") lease = dm.lease(reader.totalSize, reader.compressionCodec.isDefined) - val diskStore = KVUtils.open(lease.tmpPath, metadata) + val diskStore = KVUtils.open(lease.tmpPath, metadata, conf) hybridStore.setDiskStore(diskStore) hybridStore.switchToDiskStore(new HybridStore.SwitchToDiskStoreListener { override def onSwitchToDiskStoreSuccess: Unit = { logInfo(s"Completely switched to diskStore for app $appId / ${attempt.info.attemptId}.") diskStore.close() val newStorePath = lease.commit(appId, attempt.info.attemptId) - hybridStore.setDiskStore(KVUtils.open(newStorePath, metadata)) + hybridStore.setDiskStore(KVUtils.open(newStorePath, metadata, conf)) memoryManager.release(appId, attempt.info.attemptId) } override def onSwitchToDiskStoreFail(e: Exception): Unit = { @@ -1327,7 +1327,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logInfo(s"Leasing disk manager space for app $appId / ${attempt.info.attemptId}...") val lease = dm.lease(reader.totalSize, isCompressed) try { - Utils.tryWithResource(KVUtils.open(lease.tmpPath, metadata)) { store => + Utils.tryWithResource(KVUtils.open(lease.tmpPath, metadata, conf)) { store => rebuildAppStore(store, reader, attempt.info.lastUpdated.getTime()) } newStorePath = lease.commit(appId, attempt.info.attemptId) @@ -1345,7 +1345,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - KVUtils.open(newStorePath, metadata) + KVUtils.open(newStorePath, metadata, conf) } private def createInMemoryStore(attempt: AttemptInfoWrapper): KVStore = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryAppStatusStore.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryAppStatusStore.scala index ac0f102d81a6a..d86243df7163f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryAppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryAppStatusStore.scala @@ -44,21 +44,23 @@ private[spark] class HistoryAppStatusStore( override def executorList(activeOnly: Boolean): Seq[v1.ExecutorSummary] = { val execList = super.executorList(activeOnly) - logUrlPattern match { - case Some(pattern) => execList.map(replaceLogUrls(_, pattern)) - case None => execList + if (logUrlPattern.nonEmpty) { + execList.map(replaceLogUrls) + } else { + execList } } override def executorSummary(executorId: String): v1.ExecutorSummary = { val execSummary = super.executorSummary(executorId) - logUrlPattern match { - case Some(pattern) => replaceLogUrls(execSummary, pattern) - case None => execSummary + if (logUrlPattern.nonEmpty) { + replaceLogUrls(execSummary) + } else { + execSummary } } - private def replaceLogUrls(exec: v1.ExecutorSummary, urlPattern: String): v1.ExecutorSummary = { + private def replaceLogUrls(exec: v1.ExecutorSummary): v1.ExecutorSummary = { val newLogUrlMap = logUrlHandler.applyPattern(exec.executorLogs, exec.attributes) replaceExecutorLogs(exec, newLogUrlMap) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala 
b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala index 40e337a725430..72d407d8643cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala @@ -28,6 +28,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config.History._ +import org.apache.spark.internal.config.History.HybridStoreDiskBackend.LEVELDB import org.apache.spark.status.KVUtils._ import org.apache.spark.util.{Clock, Utils} import org.apache.spark.util.kvstore.KVStore @@ -55,6 +56,8 @@ private class HistoryServerDiskManager( if (!appStoreDir.isDirectory() && !appStoreDir.mkdir()) { throw new IllegalArgumentException(s"Failed to create app directory ($appStoreDir).") } + private val extension = + if (conf.get(HYBRID_STORE_DISK_BACKEND) == LEVELDB.toString) ".ldb" else ".rdb" private val tmpStoreDir = new File(path, "temp") if (!tmpStoreDir.isDirectory() && !tmpStoreDir.mkdir()) { @@ -251,7 +254,7 @@ private class HistoryServerDiskManager( } private[history] def appStorePath(appId: String, attemptId: Option[String]): File = { - val fileName = appId + attemptId.map("_" + _).getOrElse("") + ".ldb" + val fileName = appId + attemptId.map("_" + _).getOrElse("") + extension new File(appStoreDir, fileName) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 7dbf6b92b4088..775b27bcbf279 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random import org.apache.spark.{SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} +import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ @@ -53,8 +53,6 @@ private[deploy] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") - private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - // For application IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) @@ -95,11 +93,6 @@ private[deploy] class Master( // After onStart, webUi will be set private var webUi: MasterWebUI = null - private val masterPublicAddress = { - val envVar = conf.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else address.host - } - private val masterUrl = address.toSparkURL private var masterWebUiUrl: String = _ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 6143321427d4c..a71eb33a2fe1d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -76,13 +76,13 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def formatMasterResourcesInUse(aliveWorkers: Array[WorkerInfo]): String = { val totalInfo = 
aliveWorkers.map(_.resourcesInfo) - .flatMap(_.toIterator) + .flatMap(_.iterator) .groupBy(_._1) // group by resource name .map { case (rName, rInfoArr) => rName -> rInfoArr.map(_._2.addresses.size).sum } val usedInfo = aliveWorkers.map(_.resourcesInfoUsed) - .flatMap(_.toIterator) + .flatMap(_.iterator) .groupBy(_._1) // group by resource name .map { case (rName, rInfoArr) => rName -> rInfoArr.map(_._2.addresses.size).sum @@ -277,7 +277,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { s"if (window.confirm('Are you sure you want to kill application ${app.id} ?')) " + "{ this.parentNode.submit(); return true; } else { return false; }"
- + (kill)
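The MasterPage hunk above folds per-worker resource maps into cluster totals with flatMap(_.iterator) / groupBy / sum. A standalone sketch of that aggregation over plain maps (the worker data is made up, and address counts are simplified to plain Ints):

object ResourceTotalsSketch extends App {
  // Each map mirrors WorkerInfo.resourcesInfo: resource name -> number of addresses.
  val workers: Seq[Map[String, Int]] = Seq(
    Map("gpu" -> 2, "fpga" -> 1),
    Map("gpu" -> 4))

  val totals = workers
    .flatMap(_.iterator)      // flatten to (name, count) pairs
    .groupBy(_._1)            // group by resource name
    .map { case (name, counts) => name -> counts.map(_._2).sum }

  println(totals) // Map(gpu -> 6, fpga -> 1), order may differ
}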
@@ -328,7 +328,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { s"if (window.confirm('Are you sure you want to kill driver ${driver.id} ?')) " + "{ this.parentNode.submit(); return true; } else { return false; }"
- + (kill)
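Several files in this diff wrap the import of the global ExecutionContext in scalastyle:off/on executioncontextglobal markers and use it only for short blocking work, as RestSubmissionClient.readResponse does in the next hunk. A tiny sketch of that wrap-and-await pattern; the blocking call is simulated and the 10-second bound is an illustrative choice, not the client's actual timeout.

// scalastyle:off executioncontextglobal
import scala.concurrent.ExecutionContext.Implicits.global
// scalastyle:on executioncontextglobal
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._

object BlockingReadSketch extends App {
  // Stand-in for connection.getResponseCode / reading the response body.
  def blockingRead(): Int = { Thread.sleep(100); 200 }

  val responseFuture = Future { blockingRead() }
  // Bound the wait so a hung connection cannot block the caller forever.
  val code = Await.result(responseFuture, 10.seconds)
  println(s"response code: $code")
}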
@@ -339,10 +339,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {driver.worker.map(w => if (w.isAlive()) { - {w.id.toString} + {w.id} } else { - w.id.toString + w.id }).getOrElse("None")} {driver.state} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index cc1d60a097b2e..8a0fc886e60ca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -229,7 +229,9 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { * Exposed for testing. */ private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { + // scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global + // scalastyle:on executioncontextglobal val responseFuture = Future { val responseCode = connection.getResponseCode diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index 5c98762d4181d..3120d482f11e1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -57,7 +57,7 @@ private[deploy] class HadoopFSDelegationTokenProvider // Get the token renewal interval if it is not set. It will only be called once. if (tokenRenewalInterval == null) { - tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, sparkConf, fileSystems) + tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, fileSystems) } // Get the time of next renewal. @@ -123,7 +123,6 @@ private[deploy] class HadoopFSDelegationTokenProvider private def getTokenRenewalInterval( hadoopConf: Configuration, - sparkConf: SparkConf, filesystems: Set[FileSystem]): Option[Long] = { // We cannot use the tokens generated with renewer yarn. Trying to renew // those will fail with an access control issue. 
So create new tokens with the logged in diff --git a/core/src/main/scala/org/apache/spark/errors/SparkCoreErrors.scala b/core/src/main/scala/org/apache/spark/errors/SparkCoreErrors.scala index 95925deca0c30..aecef8ed2d63d 100644 --- a/core/src/main/scala/org/apache/spark/errors/SparkCoreErrors.scala +++ b/core/src/main/scala/org/apache/spark/errors/SparkCoreErrors.scala @@ -315,4 +315,14 @@ object SparkCoreErrors { def failToGetNonShuffleBlockError(blockId: BlockId, e: Throwable): Throwable = { new SparkException(s"Failed to get block $blockId, which is not a shuffle block", e) } + + def graphiteSinkInvalidProtocolError(invalidProtocol: String): Throwable = { + new SparkException(errorClass = "GRAPHITE_SINK_INVALID_PROTOCOL", + messageParameters = Array(invalidProtocol), cause = null) + } + + def graphiteSinkPropertyMissingError(missingProperty: String): Throwable = { + new SparkException(errorClass = "GRAPHITE_SINK_PROPERTY_MISSING", + messageParameters = Array(missingProperty), cause = null) + } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index fb7b4e62150db..a94e63656e1a1 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -42,7 +42,6 @@ import org.apache.spark.resource.ResourceUtils._ import org.apache.spark.rpc._ import org.apache.spark.scheduler.{ExecutorLossMessage, ExecutorLossReason, TaskDescription} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, SignalUtils, ThreadUtils, Utils} private[spark] class CoarseGrainedExecutorBackend( @@ -65,10 +64,6 @@ private[spark] class CoarseGrainedExecutorBackend( var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None - // If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need - // to be changed so that we don't share the serializer instance across threads - private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() - private var _resources = Map.empty[String, ResourceInformation] /** diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala index bdc5139fd918e..d483a93464c06 100644 --- a/core/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala @@ -154,7 +154,11 @@ trait Logging { // Use the repl's main class to define the default log level when running the shell, // overriding the root logger's config if they're different. val replLogger = LogManager.getLogger(logName).asInstanceOf[Log4jLogger] - val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN) + val replLevel = if (Logging.loggerWithCustomConfig(replLogger)) { + replLogger.getLevel() + } else { + Level.WARN + } // Update the consoleAppender threshold to replLevel if (replLevel != rootLogger.getLevel()) { if (!silent) { @@ -229,6 +233,17 @@ private[spark] object Logging { "org.apache.logging.slf4j.Log4jLoggerFactory".equals(binderClass) } + // Return true if the logger has custom configuration. It depends on: + // 1. If the logger isn't attached with root logger config (i.e., with custom configuration), or + // 2. 
the logger level is different to root config level (i.e., it is changed programmatically). + // + // Note that if a logger is programmatically changed log level but set to same level + // as root config level, we cannot tell if it is with custom configuration. + private def loggerWithCustomConfig(logger: Log4jLogger): Boolean = { + val rootConfig = LogManager.getRootLogger.asInstanceOf[Log4jLogger].get() + (logger.get() ne rootConfig) || (logger.getLevel != rootConfig.getLevel()) + } + /** * Return true if log4j2 is initialized by default configuration which has one * appender with error level. See `org.apache.logging.log4j.core.config.DefaultConfiguration`. @@ -267,17 +282,6 @@ private[spark] object Logging { } } - // Return true if the logger has custom configuration. It depends on: - // 1. If the logger isn't attached with root logger config (i.e., with custom configuration), or - // 2. the logger level is different to root config level (i.e., it is changed programmatically). - // - // Note that if a logger is programmatically changed log level but set to same level - // as root config level, we cannot tell if it is with custom configuration. - private def loggerWithCustomConfig(logger: Log4jLogger): Boolean = { - val rootConfig = LogManager.getRootLogger.asInstanceOf[Log4jLogger].get() - (logger.get() ne rootConfig) || (logger.getLevel != rootConfig.getLevel()) - } - override def getState: LifeCycle.State = status override def initialize(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index 38e057b16dcc5..e3190269a5349 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -140,15 +140,15 @@ private[spark] class TypedConfigBuilder[T]( def createWithDefault(default: T): ConfigEntry[T] = { // Treat "String" as a special case, so that both createWithDefault and createWithDefaultString // behave the same w.r.t. variable expansion of default values. 
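The relocated loggerWithCustomConfig helper above decides whether the repl logger carries its own configuration by checking two things: whether its LoggerConfig is a different object from the root config, and whether its level differs from the root level. A condensed, self-contained restatement of that check, assuming log4j-core 2.x is the active logging provider (as it is for Spark itself):

import org.apache.logging.log4j.LogManager
import org.apache.logging.log4j.core.{Logger => Log4jCoreLogger}

object CustomLogConfigSketch extends App {
  def hasCustomConfig(name: String): Boolean = {
    val logger = LogManager.getLogger(name).asInstanceOf[Log4jCoreLogger]
    val root = LogManager.getRootLogger.asInstanceOf[Log4jCoreLogger]
    // Custom config if the logger owns its own LoggerConfig, or its level was
    // changed programmatically away from the root level.
    (logger.get() ne root.get()) || (logger.getLevel != root.getLevel)
  }

  println(hasCustomConfig("org.apache.spark.repl.Main"))
}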
- if (default.isInstanceOf[String]) { - createWithDefaultString(default.asInstanceOf[String]) - } else { - val transformedDefault = converter(stringConverter(default)) - val entry = new ConfigEntryWithDefault[T](parent.key, parent._prependedKey, - parent._prependSeparator, parent._alternatives, transformedDefault, converter, - stringConverter, parent._doc, parent._public, parent._version) - parent._onCreate.foreach(_(entry)) - entry + default match { + case str: String => createWithDefaultString(str) + case _ => + val transformedDefault = converter(stringConverter(default)) + val entry = new ConfigEntryWithDefault[T](parent.key, parent._prependedKey, + parent._prependSeparator, parent._alternatives, transformedDefault, converter, + stringConverter, parent._doc, parent._public, parent._version) + parent._onCreate.foreach(_ (entry)) + entry } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a942ba5401ab2..ffe4501248f43 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -105,6 +105,22 @@ package object config { .bytesConf(ByteUnit.MiB) .createOptional + private[spark] val DRIVER_MEMORY_OVERHEAD_FACTOR = + ConfigBuilder("spark.driver.memoryOverheadFactor") + .doc("Fraction of driver memory to be allocated as additional non-heap memory per driver " + + "process in cluster mode. This is memory that accounts for things like VM overheads, " + + "interned strings, other native overheads, etc. This tends to grow with the container " + + "size. This value defaults to 0.10 except for Kubernetes non-JVM jobs, which defaults to " + + "0.40. This is done as non-JVM tasks need more non-JVM heap space and such tasks " + + "commonly fail with \"Memory Overhead Exceeded\" errors. This preempts this error " + + "with a higher default. This value is ignored if spark.driver.memoryOverhead is set " + + "directly.") + .version("3.3.0") + .doubleConf + .checkValue(factor => factor > 0, + "Ensure that memory overhead is a double greater than 0") + .createWithDefault(0.1) + private[spark] val DRIVER_LOG_DFS_DIR = ConfigBuilder("spark.driver.log.dfsDir").version("3.0.0").stringConf.createOptional @@ -315,6 +331,18 @@ package object config { .bytesConf(ByteUnit.MiB) .createOptional + private[spark] val EXECUTOR_MEMORY_OVERHEAD_FACTOR = + ConfigBuilder("spark.executor.memoryOverheadFactor") + .doc("Fraction of executor memory to be allocated as additional non-heap memory per " + + "executor process. This is memory that accounts for things like VM overheads, " + + "interned strings, other native overheads, etc. This tends to grow with the container " + + "size. 
This value is ignored if spark.executor.memoryOverhead is set directly.") + .version("3.3.0") + .doubleConf + .checkValue(factor => factor > 0, + "Ensure that memory overhead is a double greater than 0") + .createWithDefault(0.1) + private[spark] val CORES_MAX = ConfigBuilder("spark.cores.max") .doc("When running on a standalone deploy cluster or a Mesos cluster in coarse-grained " + "sharing mode, the maximum amount of CPU cores to request for the application from across " + @@ -364,11 +392,6 @@ package object config { .doubleConf .createWithDefault(0.6) - private[spark] val STORAGE_SAFETY_FRACTION = ConfigBuilder("spark.storage.safetyFraction") - .version("1.1.0") - .doubleConf - .createWithDefault(0.9) - private[spark] val STORAGE_UNROLL_MEMORY_THRESHOLD = ConfigBuilder("spark.storage.unrollMemoryThreshold") .doc("Initial memory to request before unrolling any block") @@ -1178,6 +1201,30 @@ package object config { .bytesConf(ByteUnit.BYTE) .createWithDefault(100 * 1024 * 1024) + private[spark] val SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR = + ConfigBuilder("spark.shuffle.accurateBlockSkewedFactor") + .internal() + .doc("A shuffle block is considered as skewed and will be accurately recorded in " + + "HighlyCompressedMapStatus if its size is larger than this factor multiplying " + + "the median shuffle block size or SHUFFLE_ACCURATE_BLOCK_THRESHOLD. It is " + + "recommended to set this parameter to be the same as SKEW_JOIN_SKEWED_PARTITION_FACTOR." + + "Set to -1.0 to disable this feature by default.") + .version("3.3.0") + .doubleConf + .createWithDefault(-1.0) + + private[spark] val SHUFFLE_MAX_ACCURATE_SKEWED_BLOCK_NUMBER = + ConfigBuilder("spark.shuffle.maxAccurateSkewedBlockNumber") + .internal() + .doc("Max skewed shuffle blocks allowed to be accurately recorded in " + + "HighlyCompressedMapStatus if its size is larger than " + + "SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR multiplying the median shuffle block size or " + + "SHUFFLE_ACCURATE_BLOCK_THRESHOLD.") + .version("3.3.0") + .intConf + .checkValue(_ > 0, "Allowed max accurate skewed block number must be positive.") + .createWithDefault(100) + private[spark] val SHUFFLE_REGISTRATION_TIMEOUT = ConfigBuilder("spark.shuffle.registration.timeout") .doc("Timeout in milliseconds for registration to the external shuffle service.") diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 5cd7397ea358f..e2a96267082b8 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -95,6 +95,7 @@ abstract class FileCommitProtocol extends Logging { * if a task is going to write out multiple files to the same dir. The file commit protocol only * guarantees that files written by different tasks will not conflict. */ + @deprecated("use newTaskTempFile(..., spec: FileNameSpec) instead", "3.3.0") def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String /** @@ -132,6 +133,7 @@ abstract class FileCommitProtocol extends Logging { * if a task is going to write out multiple files to the same dir. The file commit protocol only * guarantees that files written by different tasks will not conflict. 
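The new spark.driver.memoryOverheadFactor / spark.executor.memoryOverheadFactor entries above are fractions applied to the heap size whenever no explicit memoryOverhead is set. A hedged sketch of how such a factor is typically resolved; the 384 MiB floor mirrors Spark's usual minimum, but the helper and its exact rule are illustrative rather than the resource-manager code itself.

object OverheadSketch extends App {
  // Resolve the overhead in MiB: an explicit memoryOverhead wins, otherwise
  // apply the factor to the heap size and enforce a minimum floor.
  def resolveOverheadMiB(heapMiB: Long,
                         explicitOverheadMiB: Option[Long],
                         factor: Double,
                         minMiB: Long = 384L): Long =
    explicitOverheadMiB.getOrElse(math.max((heapMiB * factor).toLong, minMiB))

  println(resolveOverheadMiB(8192L, None, 0.1))        // 819
  println(resolveOverheadMiB(2048L, None, 0.1))        // 384 (floor applies)
  println(resolveOverheadMiB(8192L, Some(2048L), 0.1)) // 2048 (explicit value wins)
}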
*/ + @deprecated("use newTaskTempFileAbsPath(..., spec: FileNameSpec) instead", "3.3.0") def newTaskTempFileAbsPath( taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index a39e9abd9bdc4..3a24da98ecc24 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -104,7 +104,7 @@ class HadoopMapReduceCommitProtocol( * The staging directory of this write job. Spark uses it to deal with files with absolute output * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true. */ - protected def stagingDir = getStagingDir(path, jobId) + @transient protected lazy val stagingDir = getStagingDir(path, jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.getConstructor().newInstance() diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index c08b47f99dda3..596974f338fd8 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -17,8 +17,11 @@ package org.apache.spark.memory +import java.lang.management.{ManagementFactory, PlatformManagedObject} import javax.annotation.concurrent.GuardedBy +import scala.util.Try + import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -27,6 +30,7 @@ import org.apache.spark.storage.memory.MemoryStore import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.memory.MemoryAllocator +import org.apache.spark.util.Utils /** * An abstract memory manager that enforces how memory is shared between execution and storage. @@ -242,8 +246,12 @@ private[spark] abstract class MemoryManager( * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value * by looking at the number of cores available to the process, and the total amount of memory, * and then divide it by a factor of safety. + * + * SPARK-37593 If we are using G1GC, it's better to take the LONG_ARRAY_OFFSET + * into consideration so that the requested memory size is power of 2 + * and can be divided by G1 heap region size to reduce memory waste within one G1 region. 
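The MemoryManager change whose scaladoc ends here (SPARK-37593) subtracts Platform.LONG_ARRAY_OFFSET from the default page size when running on G1 with on-heap Tungsten memory, so pages divide evenly into G1 regions. A standalone sketch of the G1 probe via the HotSpot diagnostic MXBean, as implemented later in this hunk, together with the page-size adjustment; the 16-byte offset is hard-coded here as a stand-in for Platform.LONG_ARRAY_OFFSET on a typical 64-bit JVM.

import java.lang.management.{ManagementFactory, PlatformManagedObject}
import scala.util.Try

object G1PageSizeSketch extends App {
  // Probe -XX:+UseG1GC through the HotSpot diagnostic MXBean; fall back to false
  // on JVMs where these com.sun.management classes do not exist.
  lazy val isG1GC: Boolean = Try {
    val clazz = Class.forName("com.sun.management.HotSpotDiagnosticMXBean")
      .asInstanceOf[Class[_ <: PlatformManagedObject]]
    val vmOptionClazz = Class.forName("com.sun.management.VMOption")
    val bean = ManagementFactory.getPlatformMXBean(clazz)
    val option = clazz.getMethod("getVMOption", classOf[String]).invoke(bean, "UseG1GC")
    "true" == vmOptionClazz.getMethod("getValue").invoke(option).asInstanceOf[String]
  }.getOrElse(false)

  val longArrayOffset = 16L // stand-in for Platform.LONG_ARRAY_OFFSET
  def defaultPageSize(powerOfTwoSize: Long, onHeap: Boolean): Long =
    if (isG1GC && onHeap) powerOfTwoSize - longArrayOffset else powerOfTwoSize

  println(s"G1 in use: $isG1GC, page size: ${defaultPageSize(1L << 20, onHeap = true)}")
}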
*/ - val pageSizeBytes: Long = { + private lazy val defaultPageSizeBytes = { val minPageSize = 1L * 1024 * 1024 // 1MB val maxPageSize = 64L * minPageSize // 64MB val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() @@ -254,10 +262,16 @@ private[spark] abstract class MemoryManager( case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.poolSize } val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor) - val default = math.min(maxPageSize, math.max(minPageSize, size)) - conf.get(BUFFER_PAGESIZE).getOrElse(default) + val chosenPageSize = math.min(maxPageSize, math.max(minPageSize, size)) + if (isG1GC && tungstenMemoryMode == MemoryMode.ON_HEAP) { + chosenPageSize - Platform.LONG_ARRAY_OFFSET + } else { + chosenPageSize + } } + val pageSizeBytes: Long = conf.get(BUFFER_PAGESIZE).getOrElse(defaultPageSizeBytes) + /** * Allocates memory for use by Unsafe/Tungsten code. */ @@ -267,4 +281,22 @@ private[spark] abstract class MemoryManager( case MemoryMode.OFF_HEAP => MemoryAllocator.UNSAFE } } + + /** + * Return whether we are using G1GC or not + */ + private lazy val isG1GC: Boolean = { + Try { + val clazz = Utils.classForName("com.sun.management.HotSpotDiagnosticMXBean") + .asInstanceOf[Class[_ <: PlatformManagedObject]] + val vmOptionClazz = Utils.classForName("com.sun.management.VMOption") + val hotSpotDiagnosticMXBean = ManagementFactory.getPlatformMXBean(clazz) + val vmOptionMethod = clazz.getMethod("getVMOption", classOf[String]) + val valueMethod = vmOptionClazz.getMethod("getValue") + + val useG1GCObject = vmOptionMethod.invoke(hotSpotDiagnosticMXBean, "UseG1GC") + val useG1GC = valueMethod.invoke(useG1GCObject).asInstanceOf[String] + "true".equals(useG1GC) + }.getOrElse(false) + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index bddd18adc683e..4b53aad6fc48b 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -107,9 +107,9 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging { def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = { val subProperties = new mutable.HashMap[String, Properties] prop.asScala.foreach { kv => - if (regex.findPrefixOf(kv._1.toString).isDefined) { - val regex(prefix, suffix) = kv._1.toString - subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2.toString) + if (regex.findPrefixOf(kv._1).isDefined) { + val regex(prefix, suffix) = kv._1 + subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2) } } subProperties diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 1c59e191db531..13460954c061c 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} import com.codahale.metrics.graphite.{Graphite, GraphiteReporter, GraphiteUDP} +import org.apache.spark.errors.SparkCoreErrors import org.apache.spark.metrics.MetricsSystem private[spark] class GraphiteSink( @@ -42,11 +43,11 @@ private[spark] class GraphiteSink( def propertyToOption(prop: String): Option[String] = 
Option(property.getProperty(prop)) if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) { - throw new Exception("Graphite sink requires 'host' property.") + throw SparkCoreErrors.graphiteSinkPropertyMissingError("host") } if (!propertyToOption(GRAPHITE_KEY_PORT).isDefined) { - throw new Exception("Graphite sink requires 'port' property.") + throw SparkCoreErrors.graphiteSinkPropertyMissingError("port") } val host = propertyToOption(GRAPHITE_KEY_HOST).get @@ -69,7 +70,7 @@ private[spark] class GraphiteSink( val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase(Locale.ROOT)) match { case Some("udp") => new GraphiteUDP(host, port) case Some("tcp") | None => new Graphite(host, port) - case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") + case Some(p) => throw SparkCoreErrors.graphiteSinkInvalidProtocolError(p) } val filter = propertyToOption(GRAPHITE_KEY_REGEX) match { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 701145107482e..fcc2275585e83 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -61,14 +61,14 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * @return a Map with the environment variables and corresponding values, it could be empty */ def getPipeEnvVars(): Map[String, String] = { - val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) { - val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit] - // map_input_file is deprecated in favor of mapreduce_map_input_file but set both - // since it's not removed yet - Map("map_input_file" -> is.getPath().toString(), - "mapreduce_map_input_file" -> is.getPath().toString()) - } else { - Map() + val envVars: Map[String, String] = inputSplit.value match { + case is: FileSplit => + // map_input_file is deprecated in favor of mapreduce_map_input_file but set both + // since it's not removed yet + Map("map_input_file" -> is.getPath().toString(), + "mapreduce_map_input_file" -> is.getPath().toString()) + case _ => + Map() } envVars } @@ -161,29 +161,31 @@ class HadoopRDD[K, V]( newJobConf } } else { - if (conf.isInstanceOf[JobConf]) { - logDebug("Re-using user-broadcasted JobConf") - conf.asInstanceOf[JobConf] - } else { - Option(HadoopRDD.getCachedMetadata(jobConfCacheKey)) - .map { conf => - logDebug("Re-using cached JobConf") - conf.asInstanceOf[JobConf] - } - .getOrElse { - // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in - // the local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary - // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097, - // HADOOP-10456). 
- HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { - logDebug("Creating new JobConf and caching it for later re-use") - val newJobConf = new JobConf(conf) - initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - newJobConf - } - } + conf match { + case jobConf: JobConf => + logDebug("Re-using user-broadcasted JobConf") + jobConf + case _ => + Option(HadoopRDD.getCachedMetadata(jobConfCacheKey)) + .map { conf => + logDebug("Re-using cached JobConf") + conf.asInstanceOf[JobConf] + } + .getOrElse { + // Create a JobConf that will be cached and used across this RDD's getJobConf() + // calls in the local process. The local cache is accessed through + // HadoopRDD.putCachedMetadata(). + // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary + // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097, + // HADOOP-10456). + HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Creating new JobConf and caching it for later re-use") + val newJobConf = new JobConf(conf) + initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) + HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } + } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index c6959a5a4dafa..596298b222e05 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -244,7 +244,6 @@ class NewHadoopRDD[K, V]( } private var havePair = false - private var recordsSinceMetricsUpdate = 0 override def hasNext: Boolean = { if (!finished && !havePair) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 285da043c0b9a..7e121e9a7ef2c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -72,9 +72,10 @@ private[spark] class PipedRDD[T: ClassTag]( // for compatibility with Hadoop which sets these env variables // so the user code can access the input filename - if (split.isInstanceOf[HadoopPartition]) { - val hadoopSplit = split.asInstanceOf[HadoopPartition] - currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava) + split match { + case hadoopSplit: HadoopPartition => + currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava) + case _ => // do nothing } // When spark.worker.separated.working.directory option is turned on, each diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 4c39d178d38e8..c76b0d95d103d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1746,7 +1746,6 @@ abstract class RDD[T: ClassTag]( } /** - * :: Experimental :: * Removes an RDD's shuffles and it's non-persisted ancestors. * When running without a shuffle service, cleaning up shuffle files enables downscaling. * If you use the RDD after this call, you should checkpoint and materialize it first. 
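Several hunks in this diff (HadoopPartition.getPipeEnvVars, HadoopRDD.getJobConf, PipedRDD, RDD.cleanShuffleDependencies, SerDe) replace isInstanceOf/asInstanceOf pairs with a single pattern match. A minimal before/after illustration of that idiom on a made-up type:

object MatchIdiomSketch extends App {
  sealed trait Split
  final case class FileSplitLike(path: String) extends Split
  case object OtherSplit extends Split

  val split: Split = FileSplitLike("/tmp/part-00000")

  // Before: explicit type test plus cast.
  val envBefore: Map[String, String] =
    if (split.isInstanceOf[FileSplitLike]) {
      val fs = split.asInstanceOf[FileSplitLike]
      Map("map_input_file" -> fs.path)
    } else Map.empty

  // After: one pattern match, no cast, and the compiler checks exhaustiveness.
  val envAfter: Map[String, String] = split match {
    case fs: FileSplitLike => Map("map_input_file" -> fs.path)
    case _ => Map.empty
  }

  assert(envBefore == envAfter)
  println(envAfter)
}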
@@ -1755,7 +1754,6 @@ abstract class RDD[T: ClassTag]( * * Tuning the driver GC to be more aggressive, so the regular context cleaner is triggered * * Setting an appropriate TTL for shuffle files to be auto cleaned */ - @Experimental @DeveloperApi @Since("3.1.0") def cleanShuffleDependencies(blocking: Boolean = false): Unit = { @@ -1764,9 +1762,11 @@ abstract class RDD[T: ClassTag]( * Clean the shuffles & all of its parents. */ def cleanEagerly(dep: Dependency[_]): Unit = { - if (dep.isInstanceOf[ShuffleDependency[_, _, _]]) { - val shuffleId = dep.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId - cleaner.doCleanupShuffle(shuffleId, blocking) + dep match { + case dependency: ShuffleDependency[_, _, _] => + val shuffleId = dependency.shuffleId + cleaner.doCleanupShuffle(shuffleId, blocking) + case _ => // do nothing } val rdd = dep.rdd val rddDepsOpt = rdd.internalDependencies diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 0a26b7b0678eb..0d1bc1425161e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -46,7 +46,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v */ def getCheckpointDir: Option[String] = RDDCheckpointData.synchronized { if (isCheckpointed) { - Some(cpDir.toString) + Some(cpDir) } else { None } diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala index 339870195044c..087897ff73097 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala @@ -87,7 +87,7 @@ class ResourceProfile( } private[spark] def getPySparkMemory: Option[Long] = { - executorResources.get(ResourceProfile.PYSPARK_MEM).map(_.amount.toLong) + executorResources.get(ResourceProfile.PYSPARK_MEM).map(_.amount) } /* diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala b/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala index 837b2d80aace6..3f0a0d36dff6e 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala @@ -386,7 +386,6 @@ private[spark] object ResourceUtils extends Logging { val resourcePlugins = Utils.loadExtensions(classOf[ResourceDiscoveryPlugin], pluginClasses, sparkConf) // apply each plugin until one of them returns the information for this resource - var riOption: Optional[ResourceInformation] = Optional.empty() resourcePlugins.foreach { plugin => val riOption = plugin.discoverResource(resourceRequest, sparkConf) if (riOption.isPresent()) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index eed71038b3e33..ffaabba71e8cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1369,24 +1369,37 @@ private[spark] class DAGScheduler( * locations for block push/merge by getting the historical locations of past executors. 
*/ private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = { - assert(stage.shuffleDep.shuffleMergeEnabled && !stage.shuffleDep.shuffleMergeFinalized) + assert(stage.shuffleDep.shuffleMergeAllowed && !stage.shuffleDep.isShuffleMergeFinalizedMarked) if (stage.shuffleDep.getMergerLocs.isEmpty) { - val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations( - stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId) - if (mergerLocs.nonEmpty) { - stage.shuffleDep.setMergerLocs(mergerLocs) - logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" + - s" ${stage.shuffleDep.getMergerLocs.size} merger locations") - - logDebug("List of shuffle push merger locations " + - s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}") - } else { - stage.shuffleDep.setShuffleMergeEnabled(false) - logInfo(s"Push-based shuffle disabled for $stage (${stage.name})") - } + getAndSetShufflePushMergerLocations(stage) + } + + val shuffleId = stage.shuffleDep.shuffleId + val shuffleMergeId = stage.shuffleDep.shuffleMergeId + if (stage.shuffleDep.shuffleMergeEnabled) { + logInfo(s"Shuffle merge enabled before starting the stage for $stage with shuffle" + + s" $shuffleId and shuffle merge $shuffleMergeId with" + + s" ${stage.shuffleDep.getMergerLocs.size} merger locations") + } else { + logInfo(s"Shuffle merge disabled for $stage with shuffle $shuffleId" + + s" and shuffle merge $shuffleMergeId, but can get enabled later adaptively" + + s" once enough mergers are available") } } + private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): Seq[BlockManagerId] = { + val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations( + stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId) + if (mergerLocs.nonEmpty) { + stage.shuffleDep.setMergerLocs(mergerLocs) + } + + logDebug(s"Shuffle merge locations for shuffle ${stage.shuffleDep.shuffleId} with" + + s" shuffle merge ${stage.shuffleDep.shuffleMergeId} is" + + s" ${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}") + mergerLocs + } + /** Called when stage's parents are available and we can now do its task. */ private def submitMissingTasks(stage: Stage, jobId: Int): Unit = { logDebug("submitMissingTasks(" + stage + ")") @@ -1418,15 +1431,15 @@ private[spark] class DAGScheduler( case s: ShuffleMapStage => outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1) // Only generate merger location for a given shuffle dependency once. - if (s.shuffleDep.shuffleMergeEnabled) { - if (!s.shuffleDep.shuffleMergeFinalized) { + if (s.shuffleDep.shuffleMergeAllowed) { + if (!s.shuffleDep.isShuffleMergeFinalizedMarked) { prepareShuffleServicesForShuffleMapStage(s) } else { // Disable Shuffle merge for the retry/reuse of the same shuffle dependency if it has // already been merge finalized. If the shuffle dependency was previously assigned // merger locations but the corresponding shuffle map stage did not complete // successfully, we would still enable push for its retry. 
- s.shuffleDep.setShuffleMergeEnabled(false) + s.shuffleDep.setShuffleMergeAllowed(false) logInfo(s"Push-based shuffle disabled for $stage (${stage.name}) since it" + " is already shuffle merge finalized") } @@ -1821,7 +1834,7 @@ private[spark] class DAGScheduler( } if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { - if (!shuffleStage.shuffleDep.shuffleMergeFinalized && + if (!shuffleStage.shuffleDep.isShuffleMergeFinalizedMarked && shuffleStage.shuffleDep.getMergerLocs.nonEmpty) { checkAndScheduleShuffleMergeFinalize(shuffleStage) } else { @@ -2313,7 +2326,7 @@ private[spark] class DAGScheduler( // Register merge statuses if the stage is still running and shuffle merge is not finalized yet. // TODO: SPARK-35549: Currently merge statuses results which come after shuffle merge // TODO: is finalized is not registered. - if (runningStages.contains(stage) && !stage.shuffleDep.shuffleMergeFinalized) { + if (runningStages.contains(stage) && !stage.shuffleDep.isShuffleMergeFinalizedMarked) { mapOutputTracker.registerMergeResults(stage.shuffleDep.shuffleId, mergeStatuses) } } @@ -2350,7 +2363,7 @@ private[spark] class DAGScheduler( // This is required to prevent shuffle merge finalization by dangling tasks of a // previous attempt in the case of indeterminate stage. if (shuffleDep.shuffleMergeId == shuffleMergeId) { - if (!shuffleDep.shuffleMergeFinalized && + if (!shuffleDep.isShuffleMergeFinalizedMarked && shuffleDep.incPushCompleted(mapIndex).toDouble / shuffleDep.rdd.partitions.length >= shufflePushMinRatio) { scheduleShuffleMergeFinalize(mapStage, delay = 0) @@ -2487,6 +2500,23 @@ private[spark] class DAGScheduler( executorFailureEpoch -= execId } shuffleFileLostEpoch -= execId + + if (pushBasedShuffleEnabled) { + // Only set merger locations for stages that are not yet finished and have empty mergers + shuffleIdToMapStage.filter { case (_, stage) => + stage.shuffleDep.shuffleMergeAllowed && stage.shuffleDep.getMergerLocs.isEmpty && + runningStages.contains(stage) + }.foreach { case(_, stage: ShuffleMapStage) => + if (getAndSetShufflePushMergerLocations(stage).nonEmpty) { + logInfo(s"Shuffle merge enabled adaptively for $stage with shuffle" + + s" ${stage.shuffleDep.shuffleId} and shuffle merge" + + s" ${stage.shuffleDep.shuffleMergeId} with ${stage.shuffleDep.getMergerLocs.size}" + + s" merger locations") + mapOutputTracker.registerShufflePushMergerLocations(stage.shuffleDep.shuffleId, + stage.shuffleDep.getMergerLocs) + } + } + } } private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]): Unit = { @@ -2540,7 +2570,7 @@ private[spark] class DAGScheduler( stage.latestInfo.stageFailed(errorMessage.get) logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}") } - + updateStageInfoForPushBasedShuffle(stage) if (!willRetry) { outputCommitCoordinator.stageEnd(stage.id) } @@ -2563,6 +2593,7 @@ private[spark] class DAGScheduler( val dependentJobs: Seq[ActiveJob] = activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) + updateStageInfoForPushBasedShuffle(failedStage) for (job <- dependentJobs) { failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) } @@ -2571,6 +2602,19 @@ private[spark] class DAGScheduler( } } + private def updateStageInfoForPushBasedShuffle(stage: Stage): Unit = { + // With adaptive shuffle mergers, StageInfo's + // isPushBasedShuffleEnabled and shuffleMergers need 
to be updated at the end. + stage match { + case s: ShuffleMapStage => + stage.latestInfo.setPushBasedShuffleEnabled(s.shuffleDep.shuffleMergeEnabled) + if (s.shuffleDep.shuffleMergeEnabled) { + stage.latestInfo.setShuffleMergerCount(s.shuffleDep.getMergerLocs.size) + } + case _ => + } + } + /** Cancel all independent, running stages that are only used by this job. */ private def cancelRunningIndependentStages(job: ActiveJob, reason: String): Boolean = { var ableToCancelStages = true diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 07eed76805dd2..d10cf55ed0d10 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -255,9 +255,35 @@ private[spark] object HighlyCompressedMapStatus { // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. val emptyBlocks = new RoaringBitmap() val totalNumBlocks = uncompressedSizes.length - val threshold = Option(SparkEnv.get) - .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD)) - .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get) + val accurateBlockSkewedFactor = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR)) + .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR.defaultValue.get) + val shuffleAccurateBlockThreshold = + Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD)) + .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get) + val threshold = + if (accurateBlockSkewedFactor > 0) { + val sortedSizes = uncompressedSizes.sorted + val medianSize: Long = Utils.median(sortedSizes, true) + val maxAccurateSkewedBlockNumber = + Math.min( + Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_MAX_ACCURATE_SKEWED_BLOCK_NUMBER)) + .getOrElse(config.SHUFFLE_MAX_ACCURATE_SKEWED_BLOCK_NUMBER.defaultValue.get), + totalNumBlocks + ) + val skewSizeThreshold = + Math.max( + medianSize * accurateBlockSkewedFactor, + sortedSizes(totalNumBlocks - maxAccurateSkewedBlockNumber) + ) + Math.min(shuffleAccurateBlockThreshold, skewSizeThreshold) + } else { + // Disable skew detection if accurateBlockSkewedFactor <= 0 + shuffleAccurateBlockThreshold + } + val hugeBlockSizes = mutable.Map.empty[Int, Byte] while (i < totalNumBlocks) { val size = uncompressedSizes(i) diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 7b681bf0abfe8..29835c482dfa1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -39,7 +39,9 @@ class StageInfo( val taskMetrics: TaskMetrics = null, private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty, private[spark] val shuffleDepId: Option[Int] = None, - val resourceProfileId: Int) { + val resourceProfileId: Int, + private[spark] var isPushBasedShuffleEnabled: Boolean = false, + private[spark] var shuffleMergerCount: Int = 0) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None /** Time when the stage completed or when the stage was cancelled. 
*/ @@ -73,6 +75,14 @@ class StageInfo( "running" } } + + private[spark] def setShuffleMergerCount(mergers: Int): Unit = { + shuffleMergerCount = mergers + } + + private[spark] def setPushBasedShuffleEnabled(pushBasedShuffleEnabled: Boolean): Unit = { + isPushBasedShuffleEnabled = pushBasedShuffleEnabled + } } private[spark] object StageInfo { @@ -108,6 +118,8 @@ private[spark] object StageInfo { taskMetrics, taskLocalityPreferences, shuffleDepId, - resourceProfileId) + resourceProfileId, + false, + 0) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala index d20f3ed65472e..f479e5e32bc2f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala @@ -117,7 +117,7 @@ private[scheduler] class TaskSetExcludelist( // over the limit, exclude this task from the entire host. val execsWithFailuresOnNode = nodeToExecsWithFailures.getOrElseUpdate(host, new HashSet()) execsWithFailuresOnNode += exec - val failuresOnHost = execsWithFailuresOnNode.toIterator.flatMap { exec => + val failuresOnHost = execsWithFailuresOnNode.iterator.flatMap { exec => execToFailures.get(exec).map { failures => // We count task attempts here, not the number of unique executors with failures. This is // because jobs are aborted based on the number task attempts; if we counted unique diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 68b7065002f4e..b7fae2a533f0e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -820,6 +820,7 @@ private[spark] class TaskSetManager( s"on ${info.host} (executor ${info.executorId}) ($tasksSuccessful/$numTasks)") // Mark successful and stop if all the tasks have succeeded. successful(index) = true + numFailures(index) = 0 if (tasksSuccessful == numTasks) { isZombie = true } @@ -843,6 +844,7 @@ private[spark] class TaskSetManager( if (!successful(index)) { tasksSuccessful += 1 successful(index) = true + numFailures(index) = 0 if (tasksSuccessful == numTasks) { isZombie = true } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 13a7183a29dd6..08f0e5f9c5f1d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -21,7 +21,7 @@ import java.util.concurrent.{ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import javax.annotation.concurrent.GuardedBy -import scala.collection.mutable.{HashMap, HashSet} +import scala.collection.mutable.{HashMap, HashSet, Queue} import scala.concurrent.Future import org.apache.hadoop.security.UserGroupInformation @@ -82,6 +82,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp @GuardedBy("CoarseGrainedSchedulerBackend.this") private val requestedTotalExecutorsPerResourceProfile = new HashMap[ResourceProfile, Int] + // Profile IDs to the times that executors were requested for. 
+ // The operations we do on queue are all amortized constant cost + // see https://www.scala-lang.org/api/2.13.x/scala/collection/mutable/ArrayDeque.html + @GuardedBy("CoarseGrainedSchedulerBackend.this") + private val execRequestTimes = new HashMap[Int, Queue[(Int, Long)]] + private val listenerBus = scheduler.sc.listenerBus // Executors we have requested the cluster manager to kill that have not died yet; maps @@ -260,9 +266,27 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp .resourceProfileFromId(resourceProfileId).getNumSlotsPerAddress(rName, conf) (info.name, new ExecutorResourceInfo(info.name, info.addresses, numParts)) } + // If we've requested the executor figure out when we did. + val reqTs: Option[Long] = CoarseGrainedSchedulerBackend.this.synchronized { + execRequestTimes.get(resourceProfileId).flatMap { + times => + times.headOption.map { + h => + // Take off the top element + times.dequeue() + // If we requested more than one exec reduce the req count by 1 and prepend it back + if (h._1 > 1) { + ((h._1 - 1, h._2)) +=: times + } + h._2 + } + } + } + val data = new ExecutorData(executorRef, executorAddress, hostname, 0, cores, logUrlHandler.applyPattern(logUrls, attributes), attributes, - resourcesInfo, resourceProfileId, registrationTs = System.currentTimeMillis()) + resourcesInfo, resourceProfileId, registrationTs = System.currentTimeMillis(), + requestTs = reqTs) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -742,6 +766,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val numExisting = requestedTotalExecutorsPerResourceProfile.getOrElse(defaultProf, 0) requestedTotalExecutorsPerResourceProfile(defaultProf) = numExisting + numAdditionalExecutors // Account for executors pending to be added or removed + updateExecRequestTime(defaultProf.id, numAdditionalExecutors) doRequestTotalExecutors(requestedTotalExecutorsPerResourceProfile.toMap) } @@ -780,15 +805,53 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp (scheduler.sc.resourceProfileManager.resourceProfileFromId(rpid), num) } val response = synchronized { + val oldResourceProfileToNumExecutors = requestedTotalExecutorsPerResourceProfile.map { + case (rp, num) => + (rp.id, num) + }.toMap this.requestedTotalExecutorsPerResourceProfile.clear() this.requestedTotalExecutorsPerResourceProfile ++= resourceProfileToNumExecutors this.numLocalityAwareTasksPerResourceProfileId = numLocalityAwareTasksPerResourceProfileId this.rpHostToLocalTaskCount = hostToLocalTaskCount + updateExecRequestTimes(oldResourceProfileToNumExecutors, resourceProfileIdToNumExecutors) doRequestTotalExecutors(requestedTotalExecutorsPerResourceProfile.toMap) } defaultAskTimeout.awaitResult(response) } + private def updateExecRequestTimes(oldProfile: Map[Int, Int], newProfile: Map[Int, Int]): Unit = { + newProfile.map { + case (k, v) => + val delta = v - oldProfile.getOrElse(k, 0) + if (delta != 0) { + updateExecRequestTime(k, delta) + } + } + } + + private def updateExecRequestTime(profileId: Int, delta: Int) = { + val times = execRequestTimes.getOrElseUpdate(profileId, Queue[(Int, Long)]()) + if (delta > 0) { + // Add the request to the end, constant time op + times += ((delta, System.currentTimeMillis())) + } else if (delta < 0) { + // Consume as if |delta| had been allocated + var c = -delta + // Note: it's possible that something else 
allocated an executor and we have + // a negative delta, we can just avoid mutating the queue. + while (c > 0 && !times.isEmpty) { + val h = times.dequeue + if (h._1 > c) { + // Prepend updated first req to times, constant time op + ((h._1 - c, h._2)) +=: times + c = 0 + } else { + c = c - h._1 + } + } + } + } + /** * Request executors from the cluster manager by specifying the total number desired, * including existing pending and running executors. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 86b44e835368c..07236d4007faa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -31,6 +31,7 @@ import org.apache.spark.scheduler.ExecutorResourceInfo * @param resourcesInfo The information of the currently available resources on the executor * @param resourceProfileId The id of the ResourceProfile being used by this executor * @param registrationTs The registration timestamp of this executor + * @param requestTs What time this executor was most likely requested at */ private[cluster] class ExecutorData( val executorEndpoint: RpcEndpointRef, @@ -42,6 +43,7 @@ private[cluster] class ExecutorData( override val attributes: Map[String, String], override val resourcesInfo: Map[String, ExecutorResourceInfo], override val resourceProfileId: Int, - val registrationTs: Long + val registrationTs: Long, + val requestTs: Option[Long] ) extends ExecutorInfo(executorHost, totalCores, logUrlMap, attributes, - resourcesInfo, resourceProfileId) + resourcesInfo, resourceProfileId, Some(registrationTs), requestTs) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala index a97b08941ba78..5be8950192c8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala @@ -31,10 +31,19 @@ class ExecutorInfo( val logUrlMap: Map[String, String], val attributes: Map[String, String], val resourcesInfo: Map[String, ResourceInformation], - val resourceProfileId: Int) { + val resourceProfileId: Int, + val registrationTime: Option[Long], + val requestTime: Option[Long]) { + def this(executorHost: String, totalCores: Int, logUrlMap: Map[String, String], + attributes: Map[String, String], resourcesInfo: Map[String, ResourceInformation], + resourceProfileId: Int) = { + this(executorHost, totalCores, logUrlMap, attributes, resourcesInfo, resourceProfileId, + None, None) + } def this(executorHost: String, totalCores: Int, logUrlMap: Map[String, String]) = { - this(executorHost, totalCores, logUrlMap, Map.empty, Map.empty, DEFAULT_RESOURCE_PROFILE_ID) + this(executorHost, totalCores, logUrlMap, Map.empty, Map.empty, DEFAULT_RESOURCE_PROFILE_ID, + None, None) } def this( @@ -42,7 +51,8 @@ class ExecutorInfo( totalCores: Int, logUrlMap: Map[String, String], attributes: Map[String, String]) = { - this(executorHost, totalCores, logUrlMap, attributes, Map.empty, DEFAULT_RESOURCE_PROFILE_ID) + this(executorHost, totalCores, logUrlMap, attributes, Map.empty, DEFAULT_RESOURCE_PROFILE_ID, + None, None) } def this( @@ -52,7 +62,7 @@ class ExecutorInfo( attributes: Map[String, String], resourcesInfo: Map[String, ResourceInformation]) = { this(executorHost, totalCores, logUrlMap, attributes, 
resourcesInfo, - DEFAULT_RESOURCE_PROFILE_ID) + DEFAULT_RESOURCE_PROFILE_ID, None, None) } def canEqual(other: Any): Boolean = other.isInstanceOf[ExecutorInfo] @@ -72,6 +82,6 @@ class ExecutorInfo( override def hashCode(): Int = { val state = Seq(executorHost, totalCores, logUrlMap, attributes, resourcesInfo, resourceProfileId) - state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) + state.filter(_ != null).map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala index 4939dab5702a7..defef5bfcf23b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala @@ -134,7 +134,7 @@ private[spark] class ExecutorMonitor( .toSeq updateNextTimeout(newNextTimeout) } - timedOutExecs + timedOutExecs.sortBy(_._1) } /** @@ -356,7 +356,8 @@ private[spark] class ExecutorMonitor( if (removed != null) { decrementExecResourceProfileCount(removed.resourceProfileId) if (removed.decommissioning) { - if (event.reason == ExecutorLossMessage.decommissionFinished) { + if (event.reason == ExecutorLossMessage.decommissionFinished || + event.reason == ExecutorDecommission().message) { metrics.gracefullyDecommissioned.inc() } else { metrics.decommissionUnfinished.inc() @@ -378,6 +379,10 @@ private[spark] class ExecutorMonitor( } override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { + if (!client.isExecutorActive(event.blockUpdatedInfo.blockManagerId.executorId)) { + return + } + val exec = ensureExecutorIsTracked(event.blockUpdatedInfo.blockManagerId.executorId, UNKNOWN_RESOURCE_PROFILE_ID) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 7454a74094541..f1485ec99789d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -93,7 +93,8 @@ private[spark] class IndexShuffleBlockResolver( def getDataFile(shuffleId: Int, mapId: Long, dirs: Option[Array[String]]): File = { val blockId = ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID) dirs - .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) + .map(d => + new File(ExecutorDiskUtils.getFilePath(d, blockManager.subDirsPerLocalDir, blockId.name))) .getOrElse(blockManager.diskBlockManager.getFile(blockId)) } @@ -109,7 +110,8 @@ private[spark] class IndexShuffleBlockResolver( dirs: Option[Array[String]] = None): File = { val blockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID) dirs - .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) + .map(d => + new File(ExecutorDiskUtils.getFilePath(d, blockManager.subDirsPerLocalDir, blockId.name))) .getOrElse(blockManager.diskBlockManager.getFile(blockId)) } @@ -546,7 +548,8 @@ private[spark] class IndexShuffleBlockResolver( val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID) val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId.name, algorithm) dirs - .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileName)) + .map(d => + new File(ExecutorDiskUtils.getFilePath(d, blockManager.subDirsPerLocalDir, fileName))) .getOrElse(blockManager.diskBlockManager.getFile(fileName)) } diff 
--git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala index d6972cd470c9e..230ec7efdb14f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala @@ -118,11 +118,11 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { pushRequests ++= Utils.randomize(requests) if (pushRequests.isEmpty) { notifyDriverAboutPushCompletion() + } else { + submitTask(() => { + tryPushUpToMax() + }) } - - submitTask(() => { - tryPushUpToMax() - }) } private[shuffle] def tryPushUpToMax(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala index 270d23efc1b2d..be5b8385f5e7e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala @@ -59,13 +59,25 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging { rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) val mapStatus = writer.stop(success = true) if (mapStatus.isDefined) { + // Check if sufficient shuffle mergers are available now for the ShuffleMapTask to push + if (dep.shuffleMergeAllowed && dep.getMergerLocs.isEmpty) { + val mapOutputTracker = SparkEnv.get.mapOutputTracker + val mergerLocs = + mapOutputTracker.getShufflePushMergerLocations(dep.shuffleId) + if (mergerLocs.nonEmpty) { + dep.setMergerLocs(mergerLocs) + } + } // Initiate shuffle push process if push based shuffle is enabled // The map task only takes care of converting the shuffle data file into multiple // block push requests. It delegates pushing the blocks to a different thread-pool - // ShuffleBlockPusher.BLOCK_PUSHER_POOL. 
- if (dep.shuffleMergeEnabled && dep.getMergerLocs.nonEmpty && !dep.shuffleMergeFinalized) { + if (!dep.shuffleMergeFinalized) { manager.shuffleBlockResolver match { case resolver: IndexShuffleBlockResolver => + logInfo(s"Shuffle merge enabled with ${dep.getMergerLocs.size} merger locations " + + s" for stage ${context.stageId()} with shuffle ID ${dep.shuffleId}") + logDebug(s"Starting pushing blocks for the task ${context.taskAttemptId()}") val dataFile = resolver.getDataFile(dep.shuffleId, mapId) new ShuffleBlockPusher(SparkEnv.get.conf) .initiateBlockPush(dataFile, writer.getPartitionLengths(), dep, partition.index) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index e8c7f1f4d91c3..46aca07ce43f6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -131,7 +131,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]] val (blocksByAddress, canEnableBatchFetch) = - if (baseShuffleHandle.dependency.shuffleMergeEnabled) { + if (baseShuffleHandle.dependency.isShuffleMergeFinalizedMarked) { val res = SparkEnv.get.mapOutputTracker.getPushBasedShuffleMapSizesByExecutorId( handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) (res.iter, res.enableBatchFetch) diff --git a/core/src/main/scala/org/apache/spark/status/KVUtils.scala b/core/src/main/scala/org/apache/spark/status/KVUtils.scala index ddee539eb9eb4..7a4b613ac0696 100644 --- a/core/src/main/scala/org/apache/spark/status/KVUtils.scala +++ b/core/src/main/scala/org/apache/spark/status/KVUtils.scala @@ -38,8 +38,8 @@ private[spark] object KVUtils extends Logging { /** Use this to annotate constructor params to be used as KVStore indices. */ type KVIndexParam = KVIndex @getter - private lazy val backend = - HybridStoreDiskBackend.withName(new SparkConf().get(HYBRID_STORE_DISK_BACKEND)) + private def backend(conf: SparkConf) = + HybridStoreDiskBackend.withName(conf.get(HYBRID_STORE_DISK_BACKEND)) /** * A KVStoreSerializer that provides Scala types serialization too, and uses the same options as @@ -59,11 +59,12 @@ private[spark] object KVUtils extends Logging { * @param metadata Metadata value to compare to the data in the store. If the store does not * contain any metadata (e.g. it's a new store), this value is written as * the store's metadata. 
+ * @param conf SparkConf use to get `HYBRID_STORE_DISK_BACKEND` */ - def open[M: ClassTag](path: File, metadata: M): KVStore = { + def open[M: ClassTag](path: File, metadata: M, conf: SparkConf): KVStore = { require(metadata != null, "Metadata is required.") - val db = backend match { + val db = backend(conf) match { case LEVELDB => new LevelDB(path, new KVStoreScalaSerializer()) case ROCKSDB => new RocksDB(path, new KVStoreScalaSerializer()) } diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 103e4bab411e5..39bf593274904 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -344,7 +344,7 @@ private[spark] class TaskDataWrapper( @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_READS, parent = TaskIndexNames.STAGE) private def shuffleTotalReads: Long = { if (hasMetrics) { - getMetricValue(shuffleLocalBytesRead) + getMetricValue(shuffleRemoteBytesRead) + shuffleLocalBytesRead + shuffleRemoteBytesRead } else { -1L } @@ -353,7 +353,7 @@ private[spark] class TaskDataWrapper( @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_BLOCKS, parent = TaskIndexNames.STAGE) private def shuffleTotalBlocks: Long = { if (hasMetrics) { - getMetricValue(shuffleLocalBlocksFetched) + getMetricValue(shuffleRemoteBlocksFetched) + shuffleLocalBlocksFetched + shuffleRemoteBlocksFetched } else { -1L } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ec4dc7722e681..d5901888d1abf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1143,7 +1143,8 @@ private[spark] class BlockManager( val buf = blockTransferService.fetchBlockSync(loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager) if (blockSize > 0 && buf.size() == 0) { - throw new IllegalStateException("Empty buffer received for non empty block") + throw new IllegalStateException("Empty buffer received for non empty block " + + s"when fetching remote block $blockId from $loc") } buf } catch { @@ -1155,7 +1156,8 @@ private[spark] class BlockManager( // Give up trying anymore locations. Either we've tried all of the original locations, // or we've refreshed the list of locations from the master, and have still // hit failures after trying locations from the refreshed list. - logWarning(s"Failed to fetch block after $totalFailureCount fetch failures. " + + logWarning(s"Failed to fetch remote block $blockId " + + s"from [${locations.mkString(", ")}] after $totalFailureCount fetch failures. 
" + s"Most recent failure cause:", e) return None } @@ -1200,7 +1202,7 @@ private[spark] class BlockManager( blockId: BlockId, localDirs: Array[String], blockSize: Long): Option[ManagedBuffer] = { - val file = ExecutorDiskUtils.getFile(localDirs, subDirsPerLocalDir, blockId.name) + val file = new File(ExecutorDiskUtils.getFilePath(localDirs, subDirsPerLocalDir, blockId.name)) if (file.exists()) { val managedBuffer = securityManager.getIOEncryptionKey() match { case Some(key) => @@ -1870,7 +1872,7 @@ private[spark] class BlockManager( serializerManager.dataSerializeStream( blockId, out, - elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) + elements.iterator)(info.classTag.asInstanceOf[ClassTag[T]]) } case Right(bytes) => diskStore.putBytes(blockId, bytes) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala index aef5cbf07d681..cb01faf7d401d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala @@ -109,17 +109,21 @@ private[storage] class BlockManagerDecommissioner( s"to $peer ($retryCount / $maxReplicationFailuresForDecommission)") // Migrate the components of the blocks. try { - blocks.foreach { case (blockId, buffer) => - logDebug(s"Migrating sub-block ${blockId}") - bm.blockTransferService.uploadBlockSync( - peer.host, - peer.port, - peer.executorId, - blockId, - buffer, - StorageLevel.DISK_ONLY, - null) // class tag, we don't need for shuffle - logDebug(s"Migrated sub-block $blockId") + if (fallbackStorage.isDefined && peer == FallbackStorage.FALLBACK_BLOCK_MANAGER_ID) { + fallbackStorage.foreach(_.copy(shuffleBlockInfo, bm)) + } else { + blocks.foreach { case (blockId, buffer) => + logDebug(s"Migrating sub-block ${blockId}") + bm.blockTransferService.uploadBlockSync( + peer.host, + peer.port, + peer.executorId, + blockId, + buffer, + StorageLevel.DISK_ONLY, + null) // class tag, we don't need for shuffle + logDebug(s"Migrated sub-block $blockId") + } } logInfo(s"Migrated $shuffleBlockInfo to $peer") } catch { @@ -131,7 +135,10 @@ private[storage] class BlockManagerDecommissioner( // driver a no longer referenced RDD with shuffle files. if (bm.migratableResolver.getMigrationBlocks(shuffleBlockInfo).size < blocks.size) { logWarning(s"Skipping block $shuffleBlockInfo, block deleted.") - } else if (fallbackStorage.isDefined) { + } else if (fallbackStorage.isDefined + // Confirm peer is not the fallback BM ID because fallbackStorage would already + // have been used in the try-block above so there's no point trying again + && peer != FallbackStorage.FALLBACK_BLOCK_MANAGER_ID) { fallbackStorage.foreach(_.copy(shuffleBlockInfo, bm)) } else { logError(s"Error occurred during migrating $shuffleBlockInfo", e) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index bebe32b95203c..c6a22972d2a0f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -79,7 +79,7 @@ private[spark] class DiskBlockManager( /** Looks up a file by hashing it into one of our local subdirectories. */ // This method should be kept in sync with - // org.apache.spark.network.shuffle.ExecutorDiskUtils#getFile(). 
+ // org.apache.spark.network.shuffle.ExecutorDiskUtils#getFilePath(). def getFile(filename: String): File = { // Figure out which local directory it hashes to, and which subdirectory in that val hash = Utils.nonNegativeHash(filename) @@ -130,7 +130,7 @@ private[spark] class DiskBlockManager( throw new IllegalArgumentException( s"Cannot read $filename because merged shuffle dirs is empty") } - ExecutorDiskUtils.getFile(dirs.get, subDirsPerLocalDir, filename) + new File(ExecutorDiskUtils.getFilePath(dirs.get, subDirsPerLocalDir, filename)) } /** Check if disk block manager has a block. */ diff --git a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala index d137099e73437..0c1206cb9010b 100644 --- a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala +++ b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala @@ -95,7 +95,9 @@ private[storage] class FallbackStorage(conf: SparkConf) extends Logging { } private[storage] class NoopRpcEndpointRef(conf: SparkConf) extends RpcEndpointRef(conf) { + // scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global + // scalastyle:on executioncontextglobal override def address: RpcAddress = null override def name: String = "fallback" override def send(message: Any): Unit = {} diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 1d3543ed8b23c..144d8cff7d4fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -305,7 +305,7 @@ private[spark] class MemoryStore( val unrolledIterator = if (valuesHolder.vector != null) { valuesHolder.vector.iterator } else { - valuesHolder.arrayValues.toIterator + valuesHolder.arrayValues.iterator } Left(new PartiallyUnrolledIterator( diff --git a/core/src/main/scala/org/apache/spark/ui/GraphUIData.scala b/core/src/main/scala/org/apache/spark/ui/GraphUIData.scala index 87ff677514461..ab8757ff9d1f2 100644 --- a/core/src/main/scala/org/apache/spark/ui/GraphUIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/GraphUIData.scala @@ -102,7 +102,7 @@ private[spark] class GraphUIData( val jsForLabels = operationLabels.toSeq.sorted.mkString("[\"", "\",\"", "\"]") val (maxX, minX, maxY, minY) = if (values != null && values.length > 0) { - val xValues = values.map(_._1.toLong) + val xValues = values.map(_._1) val yValues = values.map(_._2.asScala.toSeq.map(_._2.toLong).sum) (xValues.max, xValues.min, yValues.max, yValues.min) } else { diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index fb43af357f7b8..c1708c320c5d4 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -110,7 +110,7 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends // Don't show the tables if there is no stream block Nil } else { - val sorted = blocks.groupBy(_.name).toSeq.sortBy(_._1.toString) + val sorted = blocks.groupBy(_.name).toSeq.sortBy(_._1)
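Annotation: the HighlyCompressedMapStatus hunk earlier in this diff chooses the accurate-size threshold from three config values (config.SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR, config.SHUFFLE_MAX_ACCURATE_SKEWED_BLOCK_NUMBER and config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD) plus the Utils.median helper added further below. A simplified, standalone Scala sketch of that decision is shown here; the function name and plain parameters stand in for the SparkEnv/config lookups and are not part of the patch.

    // Sketch only: mirrors the threshold selection in HighlyCompressedMapStatus, with the
    // three config lookups replaced by plain parameters.
    def skewAwareThreshold(
        uncompressedSizes: Array[Long],
        accurateBlockSkewedFactor: Double,       // SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR
        maxAccurateSkewedBlockNumber: Int,       // SHUFFLE_MAX_ACCURATE_SKEWED_BLOCK_NUMBER
        accurateBlockThreshold: Long): Long = {  // SHUFFLE_ACCURATE_BLOCK_THRESHOLD
      if (accurateBlockSkewedFactor <= 0) {
        // Skew detection disabled: fall back to the fixed accurate-block threshold.
        accurateBlockThreshold
      } else {
        val sorted = uncompressedSizes.sorted
        val medianSize = sorted(sorted.length / 2)  // simplified stand-in for Utils.median
        val topN = math.min(maxAccurateSkewedBlockNumber, sorted.length)
        // A block counts as skewed if it is much larger than the median, or if it is among
        // the topN largest blocks, whichever bound is looser...
        val skewSizeThreshold =
          math.max((medianSize * accurateBlockSkewedFactor).toLong, sorted(sorted.length - topN))
        // ...but never raise the cutoff above the configured accurate-block threshold.
        math.min(accurateBlockThreshold, skewSizeThreshold)
      }
    }

For example, with block sizes of 1, 2, 2, 3 and 100 MB, a skew factor of 5 and topN = 2, the median is 2 MB, so skewSizeThreshold = max(10 MB, 3 MB) = 10 MB; that value (still capped by the fixed accurate-block threshold) becomes the cutoff above which block sizes are recorded exactly rather than averaged.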
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 4e68ee0ed83cd..3287a786597c3 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -463,7 +463,7 @@ private[spark] object JsonProtocol { case ExecutorLostFailure(executorId, exitCausedByApp, reason) => ("Executor ID" -> executorId) ~ ("Exit Caused By App" -> exitCausedByApp) ~ - ("Loss Reason" -> reason.map(_.toString)) + ("Loss Reason" -> reason) case taskKilled: TaskKilled => val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList) ("Kill Reason" -> taskKilled.reason) ~ @@ -526,7 +526,9 @@ private[spark] object JsonProtocol { ("Log Urls" -> mapToJson(executorInfo.logUrlMap)) ~ ("Attributes" -> mapToJson(executorInfo.attributes)) ~ ("Resources" -> resourcesMapToJson(executorInfo.resourcesInfo)) ~ - ("Resource Profile Id" -> executorInfo.resourceProfileId) + ("Resource Profile Id" -> executorInfo.resourceProfileId) ~ + ("Registration Time" -> executorInfo.registrationTime) ~ + ("Request Time" -> executorInfo.requestTime) } def resourcesMapToJson(m: Map[String, ResourceInformation]): JValue = { @@ -1220,8 +1222,15 @@ private[spark] object JsonProtocol { case Some(id) => id.extract[Int] case None => ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID } + val registrationTs = jsonOption(json \ "Registration Time") map { ts => + ts.extract[Long] + } + val requestTs = jsonOption(json \ "Request Time") map { ts => + ts.extract[Long] + } + new ExecutorInfo(executorHost, totalCores, logUrls, attributes.toMap, resources.toMap, - resourceProfileId) + resourceProfileId, registrationTs, requestTs) } def blockUpdatedInfoFromJson(json: JValue): BlockUpdatedInfo = { diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 9ec93077d0a4c..55d13801d4abc 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -262,7 +262,7 @@ object SizeEstimator extends Logging { val s2 = sampleArray(array, state, rand, drawn, length) val size = math.min(s1, s2) state.size += math.max(s1, s2) + - (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong + (size * ((length - ARRAY_SAMPLE_SIZE) / ARRAY_SAMPLE_SIZE)) } } } @@ -282,7 +282,7 @@ object SizeEstimator extends Logging { drawn.add(index) val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] if (obj != null) { - size += SizeEstimator.estimate(obj, state.visited).toLong + size += SizeEstimator.estimate(obj, state.visited) } } size diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4410fe7fa8657..c8c7ea627b864 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -355,26 +355,26 @@ private[spark] object Utils extends Logging { closeStreams: Boolean = false, transferToEnabled: Boolean = false): Long = { tryWithSafeFinally { - if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream] - && transferToEnabled) { - // When both streams are File stream, use transferTo to improve copy performance. 
- val inChannel = in.asInstanceOf[FileInputStream].getChannel() - val outChannel = out.asInstanceOf[FileOutputStream].getChannel() - val size = inChannel.size() - copyFileStreamNIO(inChannel, outChannel, 0, size) - size - } else { - var count = 0L - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = in.read(buf) - if (n != -1) { - out.write(buf, 0, n) - count += n + (in, out) match { + case (input: FileInputStream, output: FileOutputStream) if transferToEnabled => + // When both streams are File stream, use transferTo to improve copy performance. + val inChannel = input.getChannel + val outChannel = output.getChannel + val size = inChannel.size() + copyFileStreamNIO(inChannel, outChannel, 0, size) + size + case (input, output) => + var count = 0L + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = input.read(buf) + if (n != -1) { + output.write(buf, 0, n) + count += n + } } - } - count + count } } { if (closeStreams) { @@ -593,6 +593,9 @@ private[spark] object Utils extends Logging { * basically copied from `org.apache.hadoop.yarn.util.FSDownload.unpack`. */ def unpack(source: File, dest: File): Unit = { + if (!source.exists()) { + throw new FileNotFoundException(source.getAbsolutePath) + } val lowerSrc = StringUtils.toLowerCase(source.getName) if (lowerSrc.endsWith(".jar")) { RunJar.unJar(source, dest, RunJar.MATCH_ANY) @@ -3216,6 +3219,23 @@ private[spark] object Utils extends Logging { } files.toSeq } + + /** + * Return the median number of a long array + * + * @param sizes + * @param alreadySorted + * @return + */ + def median(sizes: Array[Long], alreadySorted: Boolean): Long = { + val len = sizes.length + val sortedSize = if (alreadySorted) sizes else sizes.sorted + len match { + case _ if (len % 2 == 0) => + math.max((sortedSize(len / 2) + sortedSize(len / 2 - 1)) / 2, 1) + case _ => math.max(sortedSize(len / 2), 1) + } + } } private[util] object CallerContext extends Logging { diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 87f9ab32eb585..f4e09b7a0a38a 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -54,9 +54,6 @@ import org.apache.spark.storage.*; import org.apache.spark.util.Utils; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.lessThan; import static org.junit.Assert.*; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; @@ -92,7 +89,6 @@ public void tearDown() { } @Before - @SuppressWarnings("unchecked") public void setUp() throws Exception { MockitoAnnotations.openMocks(this).close(); tempDir = Utils.createTempDir(null, "test"); @@ -418,9 +414,9 @@ private void testMergingSpills( assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); - assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); - assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); - assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertTrue(taskMetrics.diskBytesSpilled() > 0L); + assertTrue(taskMetrics.diskBytesSpilled() < mergedOutputFile.length()); + assertTrue(taskMetrics.memoryBytesSpilled() > 0L); 
assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } @@ -510,9 +506,9 @@ public void writeEnoughDataToTriggerSpill() throws Exception { assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); - assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); - assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); - assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertTrue(taskMetrics.diskBytesSpilled() > 0L); + assertTrue(taskMetrics.diskBytesSpilled() < mergedOutputFile.length()); + assertTrue(taskMetrics.memoryBytesSpilled()> 0L); assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } @@ -543,9 +539,9 @@ private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exc assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); - assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); - assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); - assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertTrue(taskMetrics.diskBytesSpilled() > 0L); + assertTrue(taskMetrics.diskBytesSpilled() < mergedOutputFile.length()); + assertTrue(taskMetrics.memoryBytesSpilled()> 0L); assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 3685f6826752d..277c8ffa99a8f 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -24,7 +24,6 @@ import scala.Tuple2$; -import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -49,9 +48,7 @@ import org.apache.spark.util.Utils; import org.apache.spark.internal.config.package$; -import static org.hamcrest.Matchers.greaterThan; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.*; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; @@ -527,7 +524,7 @@ public void failureToGrow() { break; } } - MatcherAssert.assertThat(i, greaterThan(0)); + assertTrue(i > 0); Assert.assertFalse(success); } finally { map.free(); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 025bb47cbff78..04316a62f4f8c 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -25,7 +25,6 @@ import scala.Tuple2$; -import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -46,8 +45,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; -import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.junit.Assert.*; 
import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; @@ -225,7 +222,7 @@ public void testSortTimeMetric() throws Exception { sorter.insertRecord(null, 0, 0, 0, false); sorter.spill(); - MatcherAssert.assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime)); + assertTrue(sorter.getSortTimeNanos() > prevSortTime); prevSortTime = sorter.getSortTimeNanos(); sorter.spill(); // no sort needed @@ -233,7 +230,7 @@ public void testSortTimeMetric() throws Exception { sorter.insertRecord(null, 0, 0, 0, false); UnsafeSorterIterator iter = sorter.getSortedIterator(); - MatcherAssert.assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime)); + assertTrue(sorter.getSortTimeNanos() > prevSortTime); sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); @@ -252,7 +249,7 @@ public void spillingOccursInResponseToMemoryPressure() throws Exception { // The insertion of this record should trigger a spill: insertNumber(sorter, 0); // Ensure that spill files were created - MatcherAssert.assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1)); + assertTrue(tempDir.listFiles().length >= 1); // Read back the sorted data: UnsafeSorterIterator iter = sorter.getSortedIterator(); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 9d4909ddce792..ea1dc9957f466 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -33,10 +33,8 @@ import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.internal.config.package$; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.isIn; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; public class UnsafeInMemorySorterSuite { @@ -137,8 +135,8 @@ public int compare( final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength()); final long keyPrefix = iter.getKeyPrefix(); - assertThat(str, isIn(Arrays.asList(dataToSort))); - assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix)); + assertTrue(Arrays.asList(dataToSort).contains(str)); + assertTrue(keyPrefix >= prevPrefix); prevPrefix = keyPrefix; iterLength++; } diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 3796d3ba88ed6..fd91237a999a3 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -130,7 +130,6 @@ public void sparkContextUnion() { assertEquals(4, pUnion.count()); } - @SuppressWarnings("unchecked") @Test public void intersection() { List ints1 = Arrays.asList(1, 10, 2, 3, 4, 5); @@ -216,7 +215,6 @@ public void sortByKey() { assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); } - @SuppressWarnings("unchecked") @Test public void repartitionAndSortWithinPartitions() { List> pairs = new ArrayList<>(); @@ -356,7 +354,6 @@ public void zipWithIndex() { assertEquals(correctIndexes, indexes.collect()); } - @SuppressWarnings("unchecked") @Test public void lookup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( @@ -401,7 +398,6 @@ public 
void groupByOnPairRDD() { assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds } - @SuppressWarnings("unchecked") @Test public void keyByOnPairRDD() { // Regression test for SPARK-4459 @@ -413,7 +409,6 @@ public void keyByOnPairRDD() { assertEquals(1, (long) keyed.lookup("2").get(0)._1()); } - @SuppressWarnings("unchecked") @Test public void cogroup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( @@ -433,7 +428,6 @@ public void cogroup() { cogrouped.collect(); } - @SuppressWarnings("unchecked") @Test public void cogroup3() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( @@ -460,7 +454,6 @@ public void cogroup3() { cogrouped.collect(); } - @SuppressWarnings("unchecked") @Test public void cogroup4() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( @@ -491,7 +484,6 @@ public void cogroup4() { cogrouped.collect(); } - @SuppressWarnings("unchecked") @Test public void leftOuterJoin() { JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( @@ -557,7 +549,6 @@ public void treeAggregateWithFinalAggregateOnExecutor() { } } - @SuppressWarnings("unchecked") @Test public void aggregateByKey() { JavaPairRDD pairs = sc.parallelizePairs( @@ -583,7 +574,6 @@ public void aggregateByKey() { assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5)); } - @SuppressWarnings("unchecked") @Test public void foldByKey() { List> pairs = Arrays.asList( @@ -600,7 +590,6 @@ public void foldByKey() { assertEquals(3, sums.lookup(3).get(0).intValue()); } - @SuppressWarnings("unchecked") @Test public void reduceByKey() { List> pairs = Arrays.asList( @@ -836,7 +825,6 @@ public void flatMap() { assertEquals(11, pairsRDD.count()); } - @SuppressWarnings("unchecked") @Test public void mapsFromPairsToPairs() { List> pairs = Arrays.asList( @@ -919,7 +907,6 @@ public void repartition() { } } - @SuppressWarnings("unchecked") @Test public void persist() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); @@ -1018,7 +1005,6 @@ public void textFilesCompressed() { assertEquals(expected, readRDD.collect()); } - @SuppressWarnings("unchecked") @Test public void sequenceFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); @@ -1108,7 +1094,6 @@ public void binaryRecords() throws Exception { } } - @SuppressWarnings("unchecked") @Test public void writeWithNewAPIHadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); @@ -1159,7 +1144,6 @@ public void objectFilesOfInts() { assertEquals(expected, readRDD.collect()); } - @SuppressWarnings("unchecked") @Test public void objectFilesOfComplexTypes() { String outputDir = new File(tempDir, "output").getAbsolutePath(); @@ -1297,7 +1281,6 @@ public void combineByKey() { assertEquals(expected, results); } - @SuppressWarnings("unchecked") @Test public void mapOnPairRDD() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); @@ -1310,7 +1293,6 @@ public void mapOnPairRDD() { new Tuple2<>(0, 4)), rdd3.collect()); } - @SuppressWarnings("unchecked") @Test public void collectPartitions() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3); @@ -1391,7 +1373,6 @@ public void collectAsMapAndSerialize() throws Exception { } @Test - @SuppressWarnings("unchecked") public void sampleByKey() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i % 2, 1)); @@ -1411,7 +1392,6 @@ public void sampleByKey() { } @Test - @SuppressWarnings("unchecked") public void 
sampleByKeyExact() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i % 2, 1)); diff --git a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager index 33b162eb274c1..3ff68027f915d 100644 --- a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager +++ b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.scheduler.DummyExternalClusterManager org.apache.spark.scheduler.MockExternalClusterManager org.apache.spark.scheduler.CSMockExternalClusterManager diff --git a/core/src/test/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider b/core/src/test/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider index f4107befc825b..ed3908e95e4cb 100644 --- a/core/src/test/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider +++ b/core/src/test/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + org.apache.spark.deploy.security.ExceptionThrowingDelegationTokenProvider diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index f58777584d0ae..124a138ccf10f 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -51,7 +51,7 @@ class DriverSuite extends SparkFunSuite with TimeLimits { */ object DriverWithoutCleanup { def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val conf = new SparkConf val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index f1f2b4fc70cdb..ac7670014eb9d 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -28,7 +28,7 @@ import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io._ -import org.apache.hadoop.io.compress.{BZip2Codec, CompressionCodec, DefaultCodec, Lz4Codec} +import org.apache.hadoop.io.compress.{BZip2Codec, CompressionCodec, DefaultCodec, Lz4Codec, SnappyCodec} import org.apache.hadoop.mapred.{FileAlreadyExistsException, FileSplit, JobConf, TextInputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} @@ -136,8 +136,8 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } // Hadoop "gzip" and "zstd" codecs require native library installed for sequence files - // "snappy" codec does not work due to SPARK-36681. 
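For context on the FileSuite hunk that follows (re-adding Hadoop's SnappyCodec to the sequence-file codec loop): each codec in that list is exercised with a write/read round trip. A rough, standalone sketch of such a round trip, shown with BZip2Codec because it needs no native library; the local SparkContext and output path below are placeholders, not part of the patch:

    import org.apache.hadoop.io.compress.BZip2Codec
    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("seqfile-codec"))
    val out = java.nio.file.Files.createTempDirectory("seqfile").resolve("output").toString

    // Write a compressed SequenceFile, then read it back and compare.
    val data = sc.parallelize(1 to 100).map(i => (i, i.toString))
    data.saveAsSequenceFile(out, Some(classOf[BZip2Codec]))
    assert(sc.sequenceFile[Int, String](out).collect().toSet == data.collect().toSet)
    sc.stop()

The suite's loop performs the equivalent check once per codec in the list.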
- val codecs = Seq((new DefaultCodec(), "default"), (new BZip2Codec(), "bzip2")) ++ { + val codecs = Seq((new DefaultCodec(), "default"), (new BZip2Codec(), "bzip2"), + (new SnappyCodec(), "snappy")) ++ { if (VersionUtils.isHadoop3) Seq((new Lz4Codec(), "lz4")) else Seq() } codecs.foreach { case (codec, codecName) => diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 082a92ef41d3b..77bdb882c507d 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark import java.util.concurrent.{Semaphore, TimeUnit} import java.util.concurrent.atomic.AtomicInteger +// scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global +// scalastyle:on executioncontextglobal import scala.concurrent.Future import scala.concurrent.duration._ diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 0ee2c77997973..5e502eb568759 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -855,7 +855,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { rpcEnv.shutdown() } - test("SPARK-37023: Avoid fetching merge status when shuffleMergeEnabled is false") { + test("SPARK-37023: Avoid fetching merge status when useMergeResult is false") { val newConf = new SparkConf newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true) newConf.set(IS_TESTING, true) @@ -910,4 +910,32 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { rpcEnv.shutdown() slaveRpcEnv.shutdown() } + + test("SPARK-34826: Adaptive shuffle mergers") { + val newConf = new SparkConf + newConf.set("spark.shuffle.push.based.enabled", "true") + newConf.set("spark.shuffle.service.enabled", "true") + + // needs TorrentBroadcast so need a SparkContext + withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc => + val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val rpcEnv = sc.env.rpcEnv + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.stop(masterTracker.trackerEndpoint) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) + + val worker = new MapOutputTrackerWorker(newConf) + worker.trackerEndpoint = + rpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + masterTracker.registerShuffle(20, 100, 100) + worker.updateEpoch(masterTracker.getEpoch) + val mergerLocs = (1 to 10).map(x => BlockManagerId(s"exec-$x", s"host-$x", 7337)) + masterTracker.registerShufflePushMergerLocations(20, mergerLocs) + + assert(worker.getShufflePushMergerLocations(20).size == 10) + worker.unregisterShuffle(20) + assert(worker.shufflePushMergerLocations.isEmpty) + } + } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 0c72f770a787c..3a615d0ea6cf1 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -32,15 +32,10 @@ class SparkContextSchedulerCreationSuite def noOp(taskSchedulerImpl: TaskSchedulerImpl): Unit = {} def 
createTaskScheduler(master: String)(body: TaskSchedulerImpl => Unit = noOp): Unit = - createTaskScheduler(master, "client")(body) - - def createTaskScheduler(master: String, deployMode: String)( - body: TaskSchedulerImpl => Unit): Unit = - createTaskScheduler(master, deployMode, new SparkConf())(body) + createTaskScheduler(master, new SparkConf())(body) def createTaskScheduler( master: String, - deployMode: String, conf: SparkConf)(body: TaskSchedulerImpl => Unit): Unit = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. @@ -48,7 +43,7 @@ class SparkContextSchedulerCreationSuite val createTaskSchedulerMethod = PrivateMethod[Tuple2[SchedulerBackend, TaskScheduler]](Symbol("createTaskScheduler")) val (_, sched) = - SparkContext invokePrivate createTaskSchedulerMethod(sc, master, deployMode) + SparkContext invokePrivate createTaskSchedulerMethod(sc, master) try { body(sched.asInstanceOf[TaskSchedulerImpl]) } finally { @@ -132,7 +127,7 @@ class SparkContextSchedulerCreationSuite test("local-default-parallelism") { val conf = new SparkConf().set("spark.default.parallelism", "16") - val sched = createTaskScheduler("local", "client", conf) { sched => + val sched = createTaskScheduler("local", conf) { sched => sched.backend match { case s: LocalSchedulerBackend => assert(s.defaultParallelism() === 16) case _ => fail() diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 81b40a324d0de..02e67c0af1258 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -272,12 +272,14 @@ abstract class SparkFunSuite override def append(loggingEvent: LogEvent): Unit = loggingEvent.synchronized { val copyEvent = loggingEvent.toImmutable if (copyEvent.getLevel.isMoreSpecificThan(_threshold)) { - if (_loggingEvents.size >= maxEvents) { - val loggingInfo = if (msg == "") "." else s" while logging $msg." - throw new IllegalStateException( - s"Number of events reached the limit of $maxEvents$loggingInfo") + _loggingEvents.synchronized { + if (_loggingEvents.size >= maxEvents) { + val loggingInfo = if (msg == "") "." else s" while logging $msg." 
+ throw new IllegalStateException( + s"Number of events reached the limit of $maxEvents$loggingInfo") + } + _loggingEvents.append(copyEvent) } - _loggingEvents.append(copyEvent) } } @@ -285,6 +287,8 @@ abstract class SparkFunSuite _threshold = threshold } - def loggingEvents: ArrayBuffer[LogEvent] = _loggingEvents.filterNot(_ == null) + def loggingEvents: ArrayBuffer[LogEvent] = _loggingEvents.synchronized { + _loggingEvents.filterNot(_ == null) + } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 19e4875512a65..c5a72efcb786b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} import scala.collection.mutable.ArrayBuffer -import scala.io.Source +import scala.io.{Codec, Source} import com.google.common.io.ByteStreams import org.apache.commons.io.FileUtils @@ -647,7 +647,7 @@ class SparkSubmitSuite runSparkSubmit(args) val listStatus = fileSystem.listStatus(testDirPath) val logData = EventLogFileReader.openEventLog(listStatus.last.getPath, fileSystem) - Source.fromInputStream(logData).getLines().foreach { line => + Source.fromInputStream(logData)(Codec.UTF8).getLines().foreach { line => assert(!line.contains("secret_password")) } } @@ -1520,7 +1520,7 @@ class SparkSubmitSuite object JarCreationTest extends Logging { def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => @@ -1544,7 +1544,7 @@ object JarCreationTest extends Logging { object SimpleApplicationTest { def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val configs = Seq("spark.master", "spark.app.name") diff --git a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala index e6dd9ae4224d9..455e2e18b11e1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, FileOutputStream, IOException} import java.net.URI import scala.collection.mutable -import scala.io.Source +import scala.io.{Codec, Source} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} @@ -114,7 +114,7 @@ abstract class EventLogFileWritersSuite extends SparkFunSuite with LocalSparkCon protected def readLinesFromEventLogFile(log: Path, fs: FileSystem): List[String] = { val logDataStream = EventLogFileReader.openEventLog(log, fs) try { - Source.fromInputStream(logDataStream).getLines().toList + Source.fromInputStream(logDataStream)(Codec.UTF8).getLines().toList } finally { logDataStream.close() } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala index a8f372932672c..c534d66c1571c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala @@ -26,13 +26,20 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config.History._ +import org.apache.spark.internal.config.History.HybridStoreDiskBackend import org.apache.spark.status.KVUtils -import org.apache.spark.tags.ExtendedLevelDBTest +import org.apache.spark.tags.{ExtendedLevelDBTest, ExtendedRocksDBTest} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.util.kvstore.KVStore -@ExtendedLevelDBTest -class HistoryServerDiskManagerSuite extends SparkFunSuite with BeforeAndAfter { +abstract class HistoryServerDiskManagerSuite extends SparkFunSuite with BeforeAndAfter { + + protected def backend: HybridStoreDiskBackend.Value + + protected def extension: String + + protected def conf: SparkConf = new SparkConf() + .set(HYBRID_STORE_DISK_BACKEND, backend.toString) private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) @@ -43,7 +50,7 @@ class HistoryServerDiskManagerSuite extends SparkFunSuite with BeforeAndAfter { before { testDir = Utils.createTempDir() - store = KVUtils.open(new File(testDir, "listing"), "test") + store = KVUtils.open(new File(testDir, "listing"), "test", conf) } after { @@ -212,4 +219,21 @@ class HistoryServerDiskManagerSuite extends SparkFunSuite with BeforeAndAfter { assert(store.read(classOf[ApplicationStoreInfo], dstC.getAbsolutePath).size === 2) } + test("SPARK-38095: appStorePath should use backend extensions") { + val conf = new SparkConf().set(HYBRID_STORE_DISK_BACKEND, backend.toString) + val manager = new HistoryServerDiskManager(conf, testDir, store, new ManualClock()) + assert(manager.appStorePath("appId", None).getName.endsWith(extension)) + } +} + +@ExtendedLevelDBTest +class HistoryServerDiskManagerUseLevelDBSuite extends HistoryServerDiskManagerSuite { + override protected def backend: HybridStoreDiskBackend.Value = HybridStoreDiskBackend.LEVELDB + override protected def extension: String = ".ldb" +} + +@ExtendedRocksDBTest +class HistoryServerDiskManagerUseRocksDBSuite extends HistoryServerDiskManagerSuite { + override protected def backend: HybridStoreDiskBackend.Value = HybridStoreDiskBackend.ROCKSDB + override protected def extension: String = ".rdb" } diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala index cf34121fe73dc..3a6e7f4c12472 100644 --- a/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import com.codahale.metrics._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} class GraphiteSinkSuite extends SparkFunSuite { @@ -79,4 +79,42 @@ class GraphiteSinkSuite extends SparkFunSuite { assert(metricKeys.equals(filteredMetricKeys), "Should contain only metrics matches regex filter") } + + test("GraphiteSink without host") { + val props = new Properties + props.put("port", "54321") + val registry = new MetricRegistry + + val e = intercept[SparkException] { + new GraphiteSink(props, registry) + } + assert(e.getErrorClass === "GRAPHITE_SINK_PROPERTY_MISSING") + assert(e.getMessage === "Graphite sink requires 'host' property.") + } + + test("GraphiteSink without port") { + val props = new Properties + props.put("host", 
"127.0.0.1") + val registry = new MetricRegistry + + val e = intercept[SparkException] { + new GraphiteSink(props, registry) + } + assert(e.getErrorClass === "GRAPHITE_SINK_PROPERTY_MISSING") + assert(e.getMessage === "Graphite sink requires 'port' property.") + } + + test("GraphiteSink with invalid protocol") { + val props = new Properties + props.put("host", "127.0.0.1") + props.put("port", "54321") + props.put("protocol", "http") + val registry = new MetricRegistry + + val e = intercept[SparkException] { + new GraphiteSink(props, registry) + } + assert(e.getErrorClass === "GRAPHITE_SINK_INVALID_PROTOCOL") + assert(e.getMessage === "Invalid Graphite protocol: http") + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index a5bc557eef5ad..93daf9032323d 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore import scala.concurrent._ +// scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global +// scalastyle:on executioncontextglobal import scala.concurrent.duration.Duration import org.scalatest.BeforeAndAfterAll diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala index 16a92f54f9368..7875cbcc0dfae 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala @@ -138,5 +138,5 @@ class RDDOperationScopeSuite extends SparkFunSuite with BeforeAndAfter { private class MyCoolRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { override def getPartitions: Array[Partition] = Array.empty - override def compute(p: Partition, context: TaskContext): Iterator[Int] = { Nil.toIterator } + override def compute(p: Partition, context: TaskContext): Iterator[Int] = { Nil.iterator } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 4663717dc86be..e77ade60a61be 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -40,6 +40,7 @@ import org.apache.spark.resource.TestResourceIDs._ import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.{RpcUtils, SerializableBuffer, Utils} class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext @@ -189,6 +190,8 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo test("extra resources from executor") { + val testStartTime = System.currentTimeMillis() + val execCores = 3 val conf = new SparkConf() .set(EXECUTOR_CORES, execCores) @@ -207,6 +210,10 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo sc.resourceProfileManager.addResourceProfile(rp) assert(rp.id > ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) val backend = 
sc.schedulerBackend.asInstanceOf[TestCoarseGrainedSchedulerBackend] + // Note we get two in default profile and one in the new rp + // we need to put a req time in for all of them. + backend.requestTotalExecutors(Map((rp.id, 1)), Map(), Map()) + backend.requestExecutors(3) val mockEndpointRef = mock[RpcEndpointRef] val mockAddress = mock[RpcAddress] when(mockEndpointRef.send(LaunchTask)).thenAnswer((_: InvocationOnMock) => {}) @@ -214,8 +221,12 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo val resources = Map(GPU -> new ResourceInformation(GPU, Array("0", "1", "3"))) var executorAddedCount: Int = 0 + val infos = scala.collection.mutable.ArrayBuffer[ExecutorInfo]() val listener = new SparkListener() { override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + // Lets check that the exec allocation times "make sense" + val info = executorAdded.executorInfo + infos += info executorAddedCount += 1 } } @@ -271,8 +282,128 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo } sc.listenerBus.waitUntilEmpty(executorUpTimeout.toMillis) assert(executorAddedCount === 3) + infos.foreach { info => + assert(info.requestTime.get > 0, + "Exec allocation and request times don't make sense") + assert(info.requestTime.get > testStartTime, + "Exec allocation and request times don't make sense") + assert(info.registrationTime.get > info.requestTime.get, + "Exec allocation and request times don't make sense") + } } + test("exec alloc decrease.") { + + val testStartTime = System.currentTimeMillis() + + val execCores = 3 + val conf = new SparkConf() + .set(EXECUTOR_CORES, execCores) + .set(SCHEDULER_REVIVE_INTERVAL.key, "1m") // don't let it auto revive during test + .set(EXECUTOR_INSTANCES, 0) // avoid errors about duplicate executor registrations + .setMaster( + "coarseclustermanager[org.apache.spark.scheduler.TestCoarseGrainedSchedulerBackend]") + .setAppName("test") + conf.set(TASK_GPU_ID.amountConf, "1") + conf.set(EXECUTOR_GPU_ID.amountConf, "1") + + sc = new SparkContext(conf) + val execGpu = new ExecutorResourceRequests().cores(1).resource(GPU, 3) + val taskGpu = new TaskResourceRequests().cpus(1).resource(GPU, 1) + val rp = new ResourceProfile(execGpu.requests, taskGpu.requests) + sc.resourceProfileManager.addResourceProfile(rp) + assert(rp.id > ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + val backend = sc.schedulerBackend.asInstanceOf[TestCoarseGrainedSchedulerBackend] + // Note we get two in default profile and one in the new rp + // we need to put a req time in for all of them. + backend.requestTotalExecutors(Map((rp.id, 1)), Map(), Map()) + // Decrease the number of execs requested in the new rp. + backend.requestTotalExecutors(Map((rp.id, 0)), Map(), Map()) + // Request execs in the default profile. 
+ backend.requestExecutors(3) + val mockEndpointRef = mock[RpcEndpointRef] + val mockAddress = mock[RpcAddress] + when(mockEndpointRef.send(LaunchTask)).thenAnswer((_: InvocationOnMock) => {}) + + val resources = Map(GPU -> new ResourceInformation(GPU, Array("0", "1", "3"))) + + var executorAddedCount: Int = 0 + val infos = scala.collection.mutable.ArrayBuffer[ExecutorInfo]() + val listener = new SparkListener() { + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + // Lets check that the exec allocation times "make sense" + val info = executorAdded.executorInfo + infos += info + executorAddedCount += 1 + } + } + + sc.addSparkListener(listener) + + backend.driverEndpoint.askSync[Boolean]( + RegisterExecutor("1", mockEndpointRef, mockAddress.host, 1, Map.empty, Map.empty, resources, + ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) + backend.driverEndpoint.askSync[Boolean]( + RegisterExecutor("2", mockEndpointRef, mockAddress.host, 1, Map.empty, Map.empty, resources, + ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) + backend.driverEndpoint.askSync[Boolean]( + RegisterExecutor("3", mockEndpointRef, mockAddress.host, 1, Map.empty, Map.empty, resources, + rp.id)) + + val frameSize = RpcUtils.maxMessageSizeBytes(sc.conf) + val bytebuffer = java.nio.ByteBuffer.allocate(frameSize - 100) + val buffer = new SerializableBuffer(bytebuffer) + + var execResources = backend.getExecutorAvailableResources("1") + assert(execResources(GPU).availableAddrs.sorted === Array("0", "1", "3")) + + val exec3ResourceProfileId = backend.getExecutorResourceProfileId("3") + assert(exec3ResourceProfileId === rp.id) + + val taskResources = Map(GPU -> new ResourceInformation(GPU, Array("0"))) + val taskDescs: Seq[Seq[TaskDescription]] = Seq(Seq(new TaskDescription(1, 0, "1", + "t1", 0, 1, mutable.Map.empty[String, Long], + mutable.Map.empty[String, Long], mutable.Map.empty[String, Long], + new Properties(), 1, taskResources, bytebuffer))) + val ts = backend.getTaskSchedulerImpl() + when(ts.resourceOffers(any[IndexedSeq[WorkerOffer]], any[Boolean])).thenReturn(taskDescs) + + backend.driverEndpoint.send(ReviveOffers) + + eventually(timeout(5 seconds)) { + execResources = backend.getExecutorAvailableResources("1") + assert(execResources(GPU).availableAddrs.sorted === Array("1", "3")) + assert(execResources(GPU).assignedAddrs === Array("0")) + } + + // To avoid allocating any resources immediately after releasing the resource from the task to + // make sure that `availableAddrs` below won't change + when(ts.resourceOffers(any[IndexedSeq[WorkerOffer]], any[Boolean])).thenReturn(Seq.empty) + backend.driverEndpoint.send( + StatusUpdate("1", 1, TaskState.FINISHED, buffer, taskResources)) + + eventually(timeout(5 seconds)) { + execResources = backend.getExecutorAvailableResources("1") + assert(execResources(GPU).availableAddrs.sorted === Array("0", "1", "3")) + assert(execResources(GPU).assignedAddrs.isEmpty) + } + sc.listenerBus.waitUntilEmpty(executorUpTimeout.toMillis) + assert(executorAddedCount === 3) + infos.foreach { info => + info.requestTime.map { t => + assert(t > 0, + "Exec request times don't make sense") + assert(t >= testStartTime, + "Exec allocation and request times don't make sense") + assert(t >= info.requestTime.get, + "Exec allocation and request times don't make sense") + } + } + assert(infos.filter(_.requestTime.isEmpty).length === 1, + "Our unexpected executor does not have a request time.") + } + + private def testSubmitJob(sc: SparkContext, rdd: RDD[Int]): Unit = { 
sc.submitJob( rdd, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 76612cb605835..023e352ba1b02 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -3613,8 +3613,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val shuffleStage2 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage] assert(shuffleStage2.shuffleDep.getMergerLocs.nonEmpty) - assert(shuffleStage2.shuffleDep.shuffleMergeFinalized) - assert(shuffleStage1.shuffleDep.shuffleMergeFinalized) + assert(shuffleStage2.shuffleDep.isShuffleMergeFinalizedMarked) + assert(shuffleStage1.shuffleDep.isShuffleMergeFinalizedMarked) assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep1.shuffleId) == parts) assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep2.shuffleId) == parts) @@ -3671,7 +3671,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti completeShuffleMapStageSuccessfully(0, 0, parts) val shuffleStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] - assert(!shuffleStage.shuffleDep.shuffleMergeEnabled) + assert(shuffleStage.shuffleDep.mergerLocs.isEmpty) completeNextResultStageWithSuccess(1, 0) @@ -3686,14 +3686,13 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti completeNextStageWithFetchFailure(3, 0, shuffleDep) scheduler.resubmitFailedStages() - // Make sure shuffle merge is disabled for the retry val stage2 = scheduler.stageIdToStage(2).asInstanceOf[ShuffleMapStage] - assert(!stage2.shuffleDep.shuffleMergeEnabled) + assert(stage2.shuffleDep.shuffleMergeEnabled) // the scheduler now creates a new task set to regenerate the missing map output, but this time // using a different stage, the "skipped" one assert(scheduler.stageIdToStage(2).latestInfo.taskMetrics != null) - completeShuffleMapStageSuccessfully(2, 1, 2) + completeShuffleMapStageSuccessfully(2, 1, parts) completeNextResultStageWithSuccess(3, 1, idx => idx + 1234) val expected = (0 until parts).map(idx => (idx, idx + 1234)) @@ -3798,7 +3797,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti submit(reduceRdd, (0 until parts).toArray) completeShuffleMapStageSuccessfully(0, 0, reduceRdd.partitions.length) val shuffleMapStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] - assert(!shuffleMapStage.shuffleDep.shuffleMergeEnabled) + assert(!shuffleMapStage.shuffleDep.shuffleMergeAllowed) } test("SPARK-32920: metadata fetch failure should not unregister map status") { @@ -3926,7 +3925,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val finalizeTask1 = shuffleStage1.shuffleDep.getFinalizeTask.get .asInstanceOf[DummyScheduledFuture] assert(finalizeTask1.delay == 10 && finalizeTask1.registerMergeResults) - assert(shuffleStage1.shuffleDep.shuffleMergeFinalized) + assert(shuffleStage1.shuffleDep.isShuffleMergeFinalizedMarked) complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map { case (_, idx) => @@ -4051,8 +4050,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti runEvent(StageCancelled(0, Option("Explicit cancel check"))) scheduler.handleShuffleMergeFinalized(shuffleStage1, shuffleStage1.shuffleDep.shuffleMergeId) - assert(shuffleStage1.shuffleDep.shuffleMergeEnabled) - 
assert(!shuffleStage1.shuffleDep.shuffleMergeFinalized) + assert(shuffleStage1.shuffleDep.mergerLocs.nonEmpty) + assert(!shuffleStage1.shuffleDep.isShuffleMergeFinalizedMarked) assert(mapOutputTracker. getNumAvailableMergeResults(shuffleStage1.shuffleDep.shuffleId) == 0) @@ -4082,7 +4081,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(shuffleIndeterminateStage.isIndeterminate) scheduler.handleShuffleMergeFinalized(shuffleIndeterminateStage, 2) assert(shuffleIndeterminateStage.shuffleDep.shuffleMergeEnabled) - assert(!shuffleIndeterminateStage.shuffleDep.shuffleMergeFinalized) + assert(!shuffleIndeterminateStage.shuffleDep.isShuffleMergeFinalizedMarked) } // With Adaptive shuffle merge finalization, once minimum shuffle pushes complete after stage @@ -4130,7 +4129,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] assert(shuffleStage1.shuffleDep.shuffleMergeEnabled) - assert(!shuffleStage1.shuffleDep.shuffleMergeFinalized) + assert(!shuffleStage1.shuffleDep.isShuffleMergeFinalizedMarked) val finalizeTask1 = shuffleStage1.shuffleDep.getFinalizeTask.get. asInstanceOf[DummyScheduledFuture] assert(finalizeTask1.delay == 10 && finalizeTask1.registerMergeResults) @@ -4147,7 +4146,206 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults) } - /** + test("SPARK-34826: Adaptively fetch shuffle mergers") { + initPushBasedShuffleConfs(conf) + conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2) + DAGSchedulerSuite.clearMergerLocs() + DAGSchedulerSuite.addMergerLocs(Seq("host1")) + val parts = 2 + + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + + // Submit a reduce job that depends which will create a map stage + submit(reduceRdd, (0 until parts).toArray) + + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts), + Seq.empty, Array.empty, createFakeTaskInfoWithId(0))) + + val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + assert(!shuffleStage1.shuffleDep.shuffleMergeEnabled) + assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty) + + DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3")) + + // host2 executor added event to trigger registering of shuffle merger locations + // as shuffle mergers are tracked separately for test + runEvent(ExecutorAdded("exec2", "host2")) + + // Check if new shuffle merger locations are available for push or not + assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2) + assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2) + + // Complete remaining tasks in ShuffleMapStage 0 + runEvent(makeCompletionEvent(taskSets(0).tasks(1), Success, + makeMapStatus("host1", parts), Seq.empty, Array.empty, createFakeTaskInfoWithId(1))) + + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42, 1 -> 42)) + + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-34826: Adaptively fetch shuffle mergers with stage retry") { + initPushBasedShuffleConfs(conf) + conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2) + DAGSchedulerSuite.clearMergerLocs() + DAGSchedulerSuite.addMergerLocs(Seq("host1")) + val parts = 2 + + 
val shuffleMapRdd1 = new MyRDD(sc, parts, Nil) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts)) + val shuffleMapRdd2 = new MyRDD(sc, parts, Nil) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2), + tracker = mapOutputTracker) + + // Submit a reduce job that depends which will create a map stage + submit(reduceRdd, (0 until parts).toArray) + + val taskResults = taskSets(0).tasks.zipWithIndex.map { + case (_, idx) => + (Success, makeMapStatus("host" + idx, parts)) + }.toSeq + + val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3")) + // host2 executor added event to trigger registering of shuffle merger locations + // as shuffle mergers are tracked separately for test + runEvent(ExecutorAdded("exec2", "host2")) + // Check if new shuffle merger locations are available for push or not + assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2) + assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2) + val mergerLocsBeforeRetry = shuffleStage1.shuffleDep.getMergerLocs + + // Clear merger locations to check if new mergers are not getting set for the + // retry of determinate stage + DAGSchedulerSuite.clearMergerLocs() + + // Remove MapStatus on one of the host before the stage ends to trigger + // a scenario where stage 0 needs to be resubmitted upon finishing all tasks. + // Merge finalization should be scheduled in this case. + for ((result, i) <- taskResults.zipWithIndex) { + if (i == taskSets(0).tasks.size - 1) { + mapOutputTracker.removeOutputsOnHost("host0") + } + runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2)) + } + assert(shuffleStage1.shuffleDep.isShuffleMergeFinalizedMarked) + + DAGSchedulerSuite.addMergerLocs(Seq("host4", "host5")) + // host4 executor added event shouldn't reset merger locations given merger locations + // are already set + runEvent(ExecutorAdded("exec4", "host4")) + + // Successfully completing the retry of stage 0. 
+ complete(taskSets(2), taskSets(2).tasks.zipWithIndex.map { + case (_, idx) => + (Success, makeMapStatus("host" + idx, parts)) + }.toSeq) + + assert(shuffleStage1.shuffleDep.shuffleMergeId == 0) + assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2) + assert(shuffleStage1.shuffleDep.isShuffleMergeFinalizedMarked) + val newMergerLocs = + scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs + assert(mergerLocsBeforeRetry.sortBy(_.host) === newMergerLocs.sortBy(_.host)) + val shuffleStage2 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage] + complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map { + case (_, idx) => + (Success, makeMapStatus("host" + idx, parts, 10)) + }.toSeq) + assert(shuffleStage2.shuffleDep.getMergerLocs.size == 2) + completeNextResultStageWithSuccess(2, 0) + assert(results === Map(0 -> 42, 1 -> 42)) + + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-34826: Adaptively fetch shuffle mergers with stage retry for indeterminate stage") { + initPushBasedShuffleConfs(conf) + conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2) + DAGSchedulerSuite.clearMergerLocs() + DAGSchedulerSuite.addMergerLocs(Seq("host1")) + val parts = 2 + + val shuffleMapRdd1 = new MyRDD(sc, parts, Nil, indeterminate = true) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts)) + val shuffleMapRdd2 = new MyRDD(sc, parts, Nil, indeterminate = true) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2), + tracker = mapOutputTracker) + + // Submit a reduce job that depends which will create a map stage + submit(reduceRdd, (0 until parts).toArray) + + val taskResults = taskSets(0).tasks.zipWithIndex.map { + case (_, idx) => + (Success, makeMapStatus("host" + idx, parts)) + }.toSeq + + val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3")) + // host2 executor added event to trigger registering of shuffle merger locations + // as shuffle mergers are tracked separately for test + runEvent(ExecutorAdded("exec2", "host2")) + // Check if new shuffle merger locations are available for push or not + assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2) + assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2) + val mergerLocsBeforeRetry = shuffleStage1.shuffleDep.getMergerLocs + + // Clear merger locations to check if new mergers are getting set for the + // retry of indeterminate stage + DAGSchedulerSuite.clearMergerLocs() + + // Remove MapStatus on one of the host before the stage ends to trigger + // a scenario where stage 0 needs to be resubmitted upon finishing all tasks. + // Merge finalization should be scheduled in this case. 
+ for ((result, i) <- taskResults.zipWithIndex) { + if (i == taskSets(0).tasks.size - 1) { + mapOutputTracker.removeOutputsOnHost("host0") + } + runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2)) + } + + // Indeterminate stage should recompute all partitions, hence + // isShuffleMergeFinalizedMarked should be false here + assert(!shuffleStage1.shuffleDep.isShuffleMergeFinalizedMarked) + + DAGSchedulerSuite.addMergerLocs(Seq("host4", "host5")) + // host4 executor added event should reset merger locations given merger locations + // are already reset + runEvent(ExecutorAdded("exec4", "host4")) + assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2) + // Successfully completing the retry of stage 0. + complete(taskSets(2), taskSets(2).tasks.zipWithIndex.map { + case (_, idx) => + (Success, makeMapStatus("host" + idx, parts)) + }.toSeq) + + assert(shuffleStage1.shuffleDep.shuffleMergeId == 2) + assert(shuffleStage1.shuffleDep.isShuffleMergeFinalizedMarked) + val newMergerLocs = + scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs + assert(mergerLocsBeforeRetry.sortBy(_.host) !== newMergerLocs.sortBy(_.host)) + val shuffleStage2 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage] + complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map { + case (_, idx) => + (Success, makeMapStatus("host" + idx, parts, 10)) + }.toSeq) + assert(shuffleStage2.shuffleDep.getMergerLocs.size == 2) + completeNextResultStageWithSuccess(2, 0) + assert(results === Map(0 -> 42, 1 -> 42)) + + results.clear() + assertDataStructuresEmpty() + } + + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index b06e83e291c0a..edb2095004f71 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -23,7 +23,7 @@ import java.util.{Arrays, Properties} import scala.collection.immutable.Map import scala.collection.mutable import scala.collection.mutable.Set -import scala.io.Source +import scala.io.{Codec, Source} import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ @@ -661,7 +661,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit } private def readLines(in: InputStream): Seq[String] = { - Source.fromInputStream(in).getLines().toSeq + Source.fromInputStream(in)(Codec.UTF8).getLines().toSeq } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 23cc416f8572f..fe76b1bc322cd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.LocalSparkContext._ import org.apache.spark.internal.config import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils class MapStatusSuite extends SparkFunSuite { private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) @@ -191,4 +192,100 @@ class MapStatusSuite extends SparkFunSuite { assert(count === 3000) } } + + def 
compressAndDecompressSize(size: Long): Long = { + MapStatus.decompressSize(MapStatus.compressSize(size)) + } + + test("SPARK-36967: HighlyCompressedMapStatus should record accurately the size " + + "of skewed shuffle blocks") { + val emptyBlocksLength = 3 + val smallAndUntrackedBlocksLength = 2889 + val trackedSkewedBlocksLength = 20 + + val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR.key, "5") + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + + val emptyBlocks = Array.fill[Long](emptyBlocksLength)(0L) + val smallAndUntrackedBlocks = Array.tabulate[Long](smallAndUntrackedBlocksLength)(i => i) + val trackedSkewedBlocks = + Array.tabulate[Long](trackedSkewedBlocksLength)(i => i + 350 * 1024) + val allBlocks = emptyBlocks ++: smallAndUntrackedBlocks ++: trackedSkewedBlocks + val avg = smallAndUntrackedBlocks.sum / smallAndUntrackedBlocks.length + val loc = BlockManagerId("a", "b", 10) + val mapTaskAttemptId = 5 + val status = MapStatus(loc, allBlocks, mapTaskAttemptId) + val status1 = compressAndDecompressMapStatus(status) + assert(status1.isInstanceOf[HighlyCompressedMapStatus]) + assert(status1.location == loc) + assert(status1.mapId == mapTaskAttemptId) + assert(status1.getSizeForBlock(0) == 0) + for (i <- 1 until emptyBlocksLength) { + assert(status1.getSizeForBlock(i) === 0L) + } + for (i <- 1 until smallAndUntrackedBlocksLength) { + assert(status1.getSizeForBlock(emptyBlocksLength + i) === avg) + } + for (i <- 0 until trackedSkewedBlocksLength) { + assert(status1.getSizeForBlock(emptyBlocksLength + smallAndUntrackedBlocksLength + i) === + compressAndDecompressSize(trackedSkewedBlocks(i)), + "Only tracked skewed block size is accurate") + } + } + + test("SPARK-36967: Limit accurate skewed block number if too many blocks are skewed") { + val accurateBlockSkewedFactor = 5 + val emptyBlocksLength = 3 + val smallBlocksLength = 2500 + val untrackedSkewedBlocksLength = 500 + val trackedSkewedBlocksLength = 20 + + val conf = + new SparkConf() + .set(config.SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR.key, accurateBlockSkewedFactor.toString) + .set( + config.SHUFFLE_MAX_ACCURATE_SKEWED_BLOCK_NUMBER.key, + trackedSkewedBlocksLength.toString) + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + + val emptyBlocks = Array.fill[Long](emptyBlocksLength)(0L) + val smallBlockSizes = Array.tabulate[Long](smallBlocksLength)(i => i + 1) + val untrackedSkewedBlocksSizes = + Array.tabulate[Long](untrackedSkewedBlocksLength)(i => i + 3500 * 1024) + val trackedSkewedBlocksSizes = + Array.tabulate[Long](trackedSkewedBlocksLength)(i => i + 4500 * 1024) + val nonEmptyBlocks = + smallBlockSizes ++: untrackedSkewedBlocksSizes ++: trackedSkewedBlocksSizes + val allBlocks = emptyBlocks ++: nonEmptyBlocks + + val skewThreshold = Utils.median(allBlocks, false) * accurateBlockSkewedFactor + assert(nonEmptyBlocks.filter(_ > skewThreshold).size == + untrackedSkewedBlocksLength + trackedSkewedBlocksLength, + "number of skewed block sizes") + + val smallAndUntrackedBlocks = + nonEmptyBlocks.slice(0, nonEmptyBlocks.size - trackedSkewedBlocksLength) + val avg = smallAndUntrackedBlocks.sum / smallAndUntrackedBlocks.length + + val loc = BlockManagerId("a", "b", 10) + val mapTaskAttemptId = 5 + val status = MapStatus(loc, allBlocks, mapTaskAttemptId) + val status1 = compressAndDecompressMapStatus(status) + assert(status1.isInstanceOf[HighlyCompressedMapStatus]) + assert(status1.location == loc) + assert(status1.mapId == 
mapTaskAttemptId) + assert(status1.getSizeForBlock(0) == 0) + for (i <- emptyBlocksLength until allBlocks.length - trackedSkewedBlocksLength) { + assert(status1.getSizeForBlock(i) === avg) + } + for (i <- 0 until trackedSkewedBlocksLength) { + assert(status1.getSizeForBlock(allBlocks.length - trackedSkewedBlocksLength + i) === + compressAndDecompressSize(trackedSkewedBlocksSizes(i)), + "Only tracked skewed block size is accurate") + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index ac4ed13b25488..9ed26e712563e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -136,7 +136,9 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa func: (TaskContext, Iterator[_]) => _ = jobComputeFunc): Future[Any] = { val waiter: JobWaiter[Any] = scheduler.submitJob(rdd, func, partitions.toSeq, CallSite("", ""), (index, res) => results(index) = res, new Properties()) + // scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global + // scalastyle:on executioncontextglobal waiter.completionFuture.recover { case ex => failure = ex } @@ -697,7 +699,9 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor withBackend(runBackend _) { // Submit a job containing an RDD which will hang in getPartitions() until we release // the countdown latch: + // scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global + // scalastyle:on executioncontextglobal val slowJobFuture = Future { submit(rddWithSlowGetPartitions, Array(0)) }.flatten // Block the current thread until the other thread has started the getPartitions() call: diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index c84735c9665a7..8b81468406bbb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -52,6 +52,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext assert(listener.addedExecutorInfo.size == 2) assert(listener.addedExecutorInfo("0").totalCores == 1) assert(listener.addedExecutorInfo("1").totalCores == 1) + assert(listener.addedExecutorInfo("0").registrationTime.get > 0 ) } private class SaveExecutorInfo extends SparkListener { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 3d80a69246cc0..360a14b031139 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -2244,6 +2244,81 @@ class TaskSetManagerSuite // After 3s have elapsed now the task is marked as speculative task assert(sched.speculativeTasks.size == 1) } + + test("SPARK-37580: Reset numFailures when one of task attempts succeeds") { + sc = new SparkContext("local", "test") + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set(config.SPECULATION_MULTIPLIER, 0.0) + sc.conf.set(config.SPECULATION_QUANTILE, 0.6) + 
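The new TaskSetManagerSuite test being added here later verifies, via Java reflection, that the internal numFailures array is reset once a task attempt succeeds. As a standalone illustration of that reflection pattern only (the Counter class and hits field below are made up, not Spark code):

    // Reading a non-public field for a white-box assertion, similar to what
    // the test does with TaskSetManager's numFailures array.
    class Counter { private val hits: Array[Int] = Array(0, 3, 0) }

    val field = classOf[Counter].getDeclaredField("hits")
    field.setAccessible(true)
    val hits = field.get(new Counter).asInstanceOf[Array[Int]]
    assert(hits(1) == 3)

setAccessible(true) is what allows the field to be read despite its access modifier; the same call appears verbatim in the hunk below.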
sc.conf.set(config.SPECULATION_ENABLED, true) + + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) + sched.backend = mock(classOf[SchedulerBackend]) + val taskSet = FakeTask.createTaskSet(3) + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + + // Offer resources for 3 task to start + val tasks = new ArrayBuffer[TaskDescription]() + for ((k, v) <- List("exec1" -> "host1", "exec2" -> "host2", "exec3" -> "host3")) { + val taskOption = manager.resourceOffer(k, v, NO_PREF)._1 + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === k) + tasks += task + } + assert(sched.startedTasks.toSet === (0 until 3).toSet) + + def runningTaskForIndex(index: Int): TaskDescription = { + tasks.find { task => + task.index == index && !sched.endedTasks.contains(task.taskId) + }.getOrElse { + throw new RuntimeException(s"couldn't find index $index in " + + s"tasks: ${tasks.map { t => t.index -> t.taskId }} with endedTasks:" + + s" ${sched.endedTasks.keys}") + } + } + clock.advance(1) + + // running task with index 1 fail 3 times (not enough to abort the stage) + (0 until 3).foreach { attempt => + val task = runningTaskForIndex(1) + val endReason = ExceptionFailure("a", "b", Array(), "c", None) + manager.handleFailedTask(task.taskId, TaskState.FAILED, endReason) + sched.endedTasks(task.taskId) = endReason + assert(!manager.isZombie) + val nextTask = manager.resourceOffer(s"exec2", s"host2", NO_PREF)._1 + assert(nextTask.isDefined, s"no offer for attempt $attempt of 1") + tasks += nextTask.get + } + + val numFailuresField = classOf[TaskSetManager].getDeclaredField("numFailures") + numFailuresField.setAccessible(true) + val numFailures = numFailuresField.get(manager).asInstanceOf[Array[Int]] + // numFailures(1) should be 3 + assert(numFailures(1) == 3) + + // make task(TID 2) success to speculative other tasks + manager.handleSuccessfulTask(2, createTaskResult(2)) + + val originalTask = runningTaskForIndex(1) + clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(0, 1)) + + // make the speculative task(index 1) success + val speculativeTask = manager.resourceOffer("exec1", "host1", NO_PREF)._1 + assert(speculativeTask.isDefined) + manager.handleSuccessfulTask(speculativeTask.get.taskId, createTaskResult(1)) + // if task success, numFailures will be reset to 0 + assert(numFailures(1) == 0) + + // failed the originalTask(index 1) and check if the task manager is zombie + val failedReason = ExceptionFailure("a", "b", Array(), "c", None) + manager.handleFailedTask(originalTask.taskId, TaskState.FAILED, failedReason) + assert(!manager.isZombie) + } + } class FakeLongTasks(stageId: Int, partitionId: Int) extends FakeTask(stageId, partitionId) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala index 69afdb57ef404..c8916dcd6eb4f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala @@ -169,6 +169,8 @@ class ExecutorMonitorSuite extends SparkFunSuite { } test("keeps track of stored blocks for each rdd and split") { + knownExecs ++= Set("1", "2") + monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "1", execInfo)) monitor.onBlockUpdated(rddUpdate(1, 
0, "1")) @@ -233,7 +235,23 @@ class ExecutorMonitorSuite extends SparkFunSuite { assert(monitor.timedOutExecutors(clock.nanoTime()).toSet === Set("1", "2", "3")) } + test("SPARK-38019: timedOutExecutors should be deterministic") { + knownExecs ++= Set("1", "2", "3") + + // start exec 1, 2, 3 at 0s (should idle time out at 60s) + monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "1", execInfo)) + assert(monitor.isExecutorIdle("1")) + monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "2", execInfo)) + assert(monitor.isExecutorIdle("2")) + monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "3", execInfo)) + assert(monitor.isExecutorIdle("3")) + + clock.setTime(TimeUnit.SECONDS.toMillis(150)) + assert(monitor.timedOutExecutors().map(_._1) === Seq("1", "2", "3")) + } + test("SPARK-27677: don't track blocks stored on disk when using shuffle service") { + knownExecs += "1" // First make sure that blocks on disk are counted when no shuffle service is available. monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "1", execInfo)) monitor.onBlockUpdated(rddUpdate(1, 0, "1", level = StorageLevel.DISK_ONLY)) @@ -267,7 +285,7 @@ class ExecutorMonitorSuite extends SparkFunSuite { knownExecs ++= Set("1", "2", "3") val execInfoRp1 = new ExecutorInfo("host1", 1, Map.empty, - Map.empty, Map.empty, 1) + Map.empty, Map.empty, 1, None, None) monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "1", execInfo)) monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "2", execInfo)) @@ -443,6 +461,22 @@ class ExecutorMonitorSuite extends SparkFunSuite { assert(monitor.timedOutExecutors(idleDeadline).isEmpty) } + test("SPARK-37688: ignore SparkListenerBlockUpdated event if executor was not active") { + conf + .set(DYN_ALLOCATION_SHUFFLE_TRACKING_TIMEOUT, Long.MaxValue) + .set(DYN_ALLOCATION_SHUFFLE_TRACKING_ENABLED, true) + .set(SHUFFLE_SERVICE_ENABLED, false) + monitor = new ExecutorMonitor(conf, client, null, clock, allocationManagerSource()) + + monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "1", execInfo)) + monitor.onExecutorRemoved(SparkListenerExecutorRemoved(clock.getTimeMillis(), "1", + "heartbeats timeout")) + monitor.onBlockUpdated(rddUpdate(1, 1, "1", level = StorageLevel.MEMORY_AND_DISK)) + + assert(monitor.executorCount == 0 ) + } + + private def idleDeadline: Long = clock.nanoTime() + idleTimeoutNs + 1 private def storageDeadline: Long = clock.nanoTime() + storageTimeoutNs + 1 private def shuffleDeadline: Long = clock.nanoTime() + shuffleTimeoutNs + 1 diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala index 3814d2b6fb475..28e0e79a6fd7e 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala @@ -18,7 +18,9 @@ package org.apache.spark.serializer import scala.concurrent._ +// scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global +// scalastyle:on executioncontextglobal import scala.concurrent.duration._ import org.apache.spark.{SparkConf, SparkContext} diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index d964b28df2983..56b8e0b6df3fd 100644 --- 
a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -111,7 +111,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong, mapId) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).iterator } // Create a mocked shuffle handle to pass into HashShuffleReader. diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 6c13c7c8c3c61..9e52b5e15143b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -103,7 +103,7 @@ class SortShuffleWriterSuite mapId = 2, context, shuffleExecutorComponents) - writer.write(records.toIterator) + writer.write(records.iterator) writer.stop(success = true) val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 2) val writeMetrics = context.taskMetrics().shuffleWriteMetrics @@ -160,7 +160,7 @@ class SortShuffleWriterSuite context, new LocalDiskShuffleExecutorComponents( conf, shuffleBlockResolver._blockManager, shuffleBlockResolver)) - writer.write(records.toIterator) + writer.write(records.iterator) val sorterMethod = PrivateMethod[ExternalSorter[_, _, _]](Symbol("sorter")) val sorter = writer.invokePrivate(sorterMethod()) val expectSpillSize = if (doSpill) records.size else 0 diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index c6db626121fa2..5e2e931c37689 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} +import org.apache.spark.internal.config.History.{HYBRID_STORE_DISK_BACKEND, HybridStoreDiskBackend} import org.apache.spark.internal.config.Status._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.resource.ResourceProfile @@ -36,15 +37,11 @@ import org.apache.spark.scheduler.cluster._ import org.apache.spark.status.ListenerEventsTestHelper._ import org.apache.spark.status.api.v1 import org.apache.spark.storage._ -import org.apache.spark.tags.ExtendedLevelDBTest +import org.apache.spark.tags.{ExtendedLevelDBTest, ExtendedRocksDBTest} import org.apache.spark.util.Utils import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} -@ExtendedLevelDBTest -class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { - private val conf = new SparkConf() - .set(LIVE_ENTITY_UPDATE_PERIOD, 0L) - .set(ASYNC_TRACKING_ENABLED, false) +abstract class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { private val twoReplicaMemAndDiskLevel = StorageLevel(true, true, false, true, 2) @@ -53,7 +50,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { private var store: ElementTrackingStore = _ private var taskIdTracker = -1L - protected def createKVStore: KVStore = KVUtils.open(testDir, getClass().getName()) + protected def conf: SparkConf = new SparkConf() + 
.set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + .set(ASYNC_TRACKING_ENABLED, false) + + protected def createKVStore: KVStore = KVUtils.open(testDir, getClass().getName(), conf) before { time = 0L @@ -1891,3 +1892,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { class AppStatusListenerWithInMemoryStoreSuite extends AppStatusListenerSuite { override def createKVStore: KVStore = new InMemoryStore() } + +@ExtendedLevelDBTest +class AppStatusListenerWithLevelDBSuite extends AppStatusListenerSuite { + override def conf: SparkConf = super.conf + .set(HYBRID_STORE_DISK_BACKEND, HybridStoreDiskBackend.LEVELDB.toString) +} + +@ExtendedRocksDBTest +class AppStatusListenerWithRocksDBSuite extends AppStatusListenerSuite { + override def conf: SparkConf = super.conf + .set(HYBRID_STORE_DISK_BACKEND, HybridStoreDiskBackend.ROCKSDB.toString) +} diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala index 422d80976867d..53b01313d5d4c 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.status +import scala.util.Random + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config.History.{HYBRID_STORE_DISK_BACKEND, HybridStoreDiskBackend} import org.apache.spark.internal.config.Status.LIVE_ENTITY_UPDATE_PERIOD import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler.{SparkListenerStageSubmitted, SparkListenerTaskStart, StageInfo, TaskInfo, TaskLocality} @@ -81,7 +84,8 @@ class AppStatusStoreSuite extends SparkFunSuite { assert(store.count(classOf[CachedQuantile]) === 2) } - private def createAppStore(disk: Boolean, live: Boolean): AppStatusStore = { + private def createAppStore(disk: Boolean, diskStoreType: HybridStoreDiskBackend.Value = null, + live: Boolean): AppStatusStore = { val conf = new SparkConf() if (live) { return AppStatusStore.createLiveStore(conf) @@ -92,8 +96,9 @@ class AppStatusStoreSuite extends SparkFunSuite { } val store: KVStore = if (disk) { + conf.set(HYBRID_STORE_DISK_BACKEND, diskStoreType.toString) val testDir = Utils.createTempDir() - val diskStore = KVUtils.open(testDir, getClass.getName) + val diskStore = KVUtils.open(testDir, getClass.getName, conf) new ElementTrackingStore(diskStore, conf) } else { new ElementTrackingStore(new InMemoryStore, conf) @@ -102,7 +107,8 @@ class AppStatusStoreSuite extends SparkFunSuite { } Seq( - "disk" -> createAppStore(disk = true, live = false), + "disk leveldb" -> createAppStore(disk = true, HybridStoreDiskBackend.LEVELDB, live = false), + "disk rocksdb" -> createAppStore(disk = true, HybridStoreDiskBackend.ROCKSDB, live = false), "in memory" -> createAppStore(disk = false, live = false), "in memory live" -> createAppStore(disk = false, live = true) ).foreach { case (hint, appStore) => @@ -133,13 +139,52 @@ class AppStatusStoreSuite extends SparkFunSuite { * Task summary will consider (1, 3, 5) only */ val summary = appStore.taskSummary(stageId, attemptId, uiQuantiles).get + val successfulTasks = Array(getTaskMetrics(1), getTaskMetrics(3), getTaskMetrics(5)) - val values = Array(1.0, 3.0, 5.0) + def assertQuantiles(metricGetter: TaskMetrics => Double, + actualQuantiles: Seq[Double]): Unit = { + val values = successfulTasks.map(metricGetter) + val expectedQuantiles = new 
Distribution(values, 0, values.length) + .getQuantiles(uiQuantiles.sorted) - val dist = new Distribution(values, 0, values.length).getQuantiles(uiQuantiles.sorted) - dist.zip(summary.executorRunTime).foreach { case (expected, actual) => - assert(expected === actual) + assert(actualQuantiles === expectedQuantiles) } + + assertQuantiles(_.executorDeserializeTime, summary.executorDeserializeTime) + assertQuantiles(_.executorDeserializeCpuTime, summary.executorDeserializeCpuTime) + assertQuantiles(_.executorRunTime, summary.executorRunTime) + assertQuantiles(_.executorCpuTime, summary.executorCpuTime) + assertQuantiles(_.resultSize, summary.resultSize) + assertQuantiles(_.jvmGCTime, summary.jvmGcTime) + assertQuantiles(_.resultSerializationTime, summary.resultSerializationTime) + assertQuantiles(_.memoryBytesSpilled, summary.memoryBytesSpilled) + assertQuantiles(_.diskBytesSpilled, summary.diskBytesSpilled) + assertQuantiles(_.peakExecutionMemory, summary.peakExecutionMemory) + assertQuantiles(_.inputMetrics.bytesRead, summary.inputMetrics.bytesRead) + assertQuantiles(_.inputMetrics.recordsRead, summary.inputMetrics.recordsRead) + assertQuantiles(_.outputMetrics.bytesWritten, summary.outputMetrics.bytesWritten) + assertQuantiles(_.outputMetrics.recordsWritten, summary.outputMetrics.recordsWritten) + assertQuantiles(_.shuffleReadMetrics.remoteBlocksFetched, + summary.shuffleReadMetrics.remoteBlocksFetched) + assertQuantiles(_.shuffleReadMetrics.localBlocksFetched, + summary.shuffleReadMetrics.localBlocksFetched) + assertQuantiles(_.shuffleReadMetrics.fetchWaitTime, summary.shuffleReadMetrics.fetchWaitTime) + assertQuantiles(_.shuffleReadMetrics.remoteBytesRead, + summary.shuffleReadMetrics.remoteBytesRead) + assertQuantiles(_.shuffleReadMetrics.remoteBytesReadToDisk, + summary.shuffleReadMetrics.remoteBytesReadToDisk) + assertQuantiles( + t => t.shuffleReadMetrics.localBytesRead + t.shuffleReadMetrics.remoteBytesRead, + summary.shuffleReadMetrics.readBytes) + assertQuantiles( + t => t.shuffleReadMetrics.localBlocksFetched + t.shuffleReadMetrics.remoteBlocksFetched, + summary.shuffleReadMetrics.totalBlocksFetched) + assertQuantiles(_.shuffleWriteMetrics.bytesWritten, summary.shuffleWriteMetrics.writeBytes) + assertQuantiles(_.shuffleWriteMetrics.writeTime, summary.shuffleWriteMetrics.writeTime) + assertQuantiles(_.shuffleWriteMetrics.recordsWritten, + summary.shuffleWriteMetrics.writeRecords) + appStore.close() } } @@ -223,32 +268,41 @@ class AppStatusStoreSuite extends SparkFunSuite { liveTask.write(store.asInstanceOf[ElementTrackingStore], 1L) } - private def getTaskMetrics(i: Int): TaskMetrics = { + /** + * Creates fake task metrics + * @param seed The random seed. The output will be reproducible for a given seed. 
+ * @return The test metrics object with fake data + */ + private def getTaskMetrics(seed: Int): TaskMetrics = { + val random = new Random(seed) + val randomMax = 1000 + def nextInt(): Int = random.nextInt(randomMax) + val taskMetrics = new TaskMetrics() - taskMetrics.setExecutorDeserializeTime(i) - taskMetrics.setExecutorDeserializeCpuTime(i) - taskMetrics.setExecutorRunTime(i) - taskMetrics.setExecutorCpuTime(i) - taskMetrics.setResultSize(i) - taskMetrics.setJvmGCTime(i) - taskMetrics.setResultSerializationTime(i) - taskMetrics.incMemoryBytesSpilled(i) - taskMetrics.incDiskBytesSpilled(i) - taskMetrics.incPeakExecutionMemory(i) - taskMetrics.inputMetrics.incBytesRead(i) - taskMetrics.inputMetrics.incRecordsRead(i) - taskMetrics.outputMetrics.setBytesWritten(i) - taskMetrics.outputMetrics.setRecordsWritten(i) - taskMetrics.shuffleReadMetrics.incRemoteBlocksFetched(i) - taskMetrics.shuffleReadMetrics.incLocalBlocksFetched(i) - taskMetrics.shuffleReadMetrics.incFetchWaitTime(i) - taskMetrics.shuffleReadMetrics.incRemoteBytesRead(i) - taskMetrics.shuffleReadMetrics.incRemoteBytesReadToDisk(i) - taskMetrics.shuffleReadMetrics.incLocalBytesRead(i) - taskMetrics.shuffleReadMetrics.incRecordsRead(i) - taskMetrics.shuffleWriteMetrics.incBytesWritten(i) - taskMetrics.shuffleWriteMetrics.incWriteTime(i) - taskMetrics.shuffleWriteMetrics.incRecordsWritten(i) + taskMetrics.setExecutorDeserializeTime(nextInt()) + taskMetrics.setExecutorDeserializeCpuTime(nextInt()) + taskMetrics.setExecutorRunTime(nextInt()) + taskMetrics.setExecutorCpuTime(nextInt()) + taskMetrics.setResultSize(nextInt()) + taskMetrics.setJvmGCTime(nextInt()) + taskMetrics.setResultSerializationTime(nextInt()) + taskMetrics.incMemoryBytesSpilled(nextInt()) + taskMetrics.incDiskBytesSpilled(nextInt()) + taskMetrics.incPeakExecutionMemory(nextInt()) + taskMetrics.inputMetrics.incBytesRead(nextInt()) + taskMetrics.inputMetrics.incRecordsRead(nextInt()) + taskMetrics.outputMetrics.setBytesWritten(nextInt()) + taskMetrics.outputMetrics.setRecordsWritten(nextInt()) + taskMetrics.shuffleReadMetrics.incRemoteBlocksFetched(nextInt()) + taskMetrics.shuffleReadMetrics.incLocalBlocksFetched(nextInt()) + taskMetrics.shuffleReadMetrics.incFetchWaitTime(nextInt()) + taskMetrics.shuffleReadMetrics.incRemoteBytesRead(nextInt()) + taskMetrics.shuffleReadMetrics.incRemoteBytesReadToDisk(nextInt()) + taskMetrics.shuffleReadMetrics.incLocalBytesRead(nextInt()) + taskMetrics.shuffleReadMetrics.incRecordsRead(nextInt()) + taskMetrics.shuffleWriteMetrics.incBytesWritten(nextInt()) + taskMetrics.shuffleWriteMetrics.incWriteTime(nextInt()) + taskMetrics.shuffleWriteMetrics.incRecordsWritten(nextInt()) taskMetrics } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala index 8999a121bcd15..e004c334dee73 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala @@ -165,9 +165,10 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS } x.map(y => (y, y)) } - val testRdd = shuffle match { - case true => baseRdd.reduceByKey(_ + _) - case false => baseRdd + val testRdd = if (shuffle) { + baseRdd.reduceByKey(_ + _) + } else { + baseRdd } // Listen for the job & block updates diff --git 
a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index d1dc083868baf..0f99ea819f67f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -886,7 +886,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockSize = inv.getArguments()(2).asInstanceOf[Long] val res = store1.readDiskBlockFromSameHostExecutor(blockId, localDirs, blockSize) assert(res.isDefined) - val file = ExecutorDiskUtils.getFile(localDirs, store1.subDirsPerLocalDir, blockId.name) + val file = new File( + ExecutorDiskUtils.getFilePath(localDirs, store1.subDirsPerLocalDir, blockId.name)) // delete the file behind the blockId assert(file.delete()) sameHostExecutorTried = true @@ -2229,7 +2230,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE blockData: ManagedBuffer, level: StorageLevel, classTag: ClassTag[_]): Future[Unit] = { + // scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global + // scalastyle:on executioncontextglobal Future {} } diff --git a/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala index 7d648c979cd60..3828e9d8297a6 100644 --- a/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.storage import java.io.{DataOutputStream, File, FileOutputStream, IOException} -import java.net.{InetAddress, UnknownHostException} import java.nio.file.Files import scala.concurrent.duration._ import org.apache.hadoop.conf.Configuration import org.mockito.{ArgumentMatchers => mc} -import org.mockito.Mockito.{mock, times, verify, when} +import org.mockito.Mockito.{mock, never, verify, when} import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TestUtils} @@ -42,13 +41,6 @@ import org.apache.spark.util.Utils.tryWithResource class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext { def getSparkConf(initialExecutor: Int = 1, minExecutor: Int = 1): SparkConf = { - // Some DNS always replies for all hostnames including unknown host names - try { - InetAddress.getByName(FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.host) - assume(false) - } catch { - case _: UnknownHostException => - } new SparkConf(false) .setAppName(getClass.getName) .set(SPARK_MASTER, s"local-cluster[$initialExecutor,1,1024]") @@ -179,8 +171,8 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext { decommissioner.start() val fallbackStorage = new FallbackStorage(conf) eventually(timeout(10.second), interval(1.seconds)) { - // uploadBlockSync is not used - verify(blockTransferService, times(1)) + // uploadBlockSync should not be used, verify that it is not called + verify(blockTransferService, never()) .uploadBlockSync(mc.any(), mc.any(), mc.any(), mc.any(), mc.any(), mc.any(), mc.any()) Seq("shuffle_1_1_0.index", "shuffle_1_1_0.data").foreach { filename => diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index afb9a862b113c..e6f052510462d 100644 --- 
a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -24,7 +24,9 @@ import java.util.concurrent.{CompletableFuture, Semaphore} import java.util.zip.CheckedInputStream import scala.collection.mutable +// scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global +// scalastyle:on executioncontextglobal import scala.concurrent.Future import com.google.common.io.ByteStreams @@ -160,10 +162,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verify(buffer, times(0)).release() val delegateAccess = PrivateMethod[InputStream](Symbol("delegate")) var in = wrappedInputStream.invokePrivate(delegateAccess()) - if (in.isInstanceOf[CheckedInputStream]) { - val underlyingInputFiled = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in") - underlyingInputFiled.setAccessible(true) - in = underlyingInputFiled.get(in.asInstanceOf[CheckedInputStream]).asInstanceOf[InputStream] + in match { + case stream: CheckedInputStream => + val underlyingInputField = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in") + underlyingInputField.setAccessible(true) + in = underlyingInputField.get(stream).asInstanceOf[InputStream] + case _ => // do nothing } verify(in, times(0)).close() wrappedInputStream.close() @@ -197,7 +201,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager.getOrElse(createMockBlockManager()), mapOutputTracker, - blocksByAddress.toIterator, + blocksByAddress.iterator, (_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in), maxBytesInFlight, maxReqsInFlight, diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 1a2eb6950c403..8ca4bc9a1527b 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -222,14 +222,15 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { // assert(appender.getClass === classTag[ExpectedAppender].getClass) assert(appender.getClass.getSimpleName === classTag[ExpectedAppender].runtimeClass.getSimpleName) - if (appender.isInstanceOf[RollingFileAppender]) { - val rollingPolicy = appender.asInstanceOf[RollingFileAppender].rollingPolicy - val policyParam = if (rollingPolicy.isInstanceOf[TimeBasedRollingPolicy]) { - rollingPolicy.asInstanceOf[TimeBasedRollingPolicy].rolloverIntervalMillis - } else { - rollingPolicy.asInstanceOf[SizeBasedRollingPolicy].rolloverSizeBytes - } - assert(policyParam === expectedRollingPolicyParam) + appender match { + case rfa: RollingFileAppender => + val rollingPolicy = rfa.rollingPolicy + val policyParam = rollingPolicy match { + case timeBased: TimeBasedRollingPolicy => timeBased.rolloverIntervalMillis + case sizeBased: SizeBasedRollingPolicy => sizeBased.rolloverSizeBytes + } + assert(policyParam === expectedRollingPolicyParam) + case _ => // do nothing } testOutputStream.close() appender.awaitTermination() diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 4eea2256553f5..a3dc2d8fa735e 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -96,6 +96,9 @@ class JsonProtocolSuite extends SparkFunSuite { val applicationEnd = SparkListenerApplicationEnd(42L) val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap, attributes, resources.toMap, 4)) + val executorAddedWithTime = SparkListenerExecutorAdded(executorAddedTime, "exec1", + new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap, attributes, resources.toMap, 4, + Some(0), Some(1))) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") val executorBlacklisted = SparkListenerExecutorBlacklisted(executorExcludedTime, "exec1", 22) val executorUnblacklisted = @@ -155,6 +158,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(applicationStartWithLogs, applicationStartJsonWithLogUrlsString) testEvent(applicationEnd, applicationEndJsonString) testEvent(executorAdded, executorAddedJsonString) + testEvent(executorAddedWithTime, executorAddedWithTimeJsonString) testEvent(executorRemoved, executorRemovedJsonString) testEvent(executorBlacklisted, executorBlacklistedJsonString) testEvent(executorUnblacklisted, executorUnblacklistedJsonString) @@ -173,6 +177,7 @@ class JsonProtocolSuite extends SparkFunSuite { test("Dependent Classes") { val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap val attributes = Map("ContainerId" -> "ct1", "User" -> "spark").toMap + val rinfo = Map[String, ResourceInformation]().toMap testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L, DeterministicLevel.DETERMINATE)) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) @@ -180,6 +185,8 @@ class JsonProtocolSuite extends SparkFunSuite { 33333L, 44444L, 55555L, 66666L, 7, 8, hasHadoopInput = false, hasOutput = false)) testBlockManagerId(BlockManagerId("Hong", "Kong", 500)) testExecutorInfo(new ExecutorInfo("host", 43, logUrlMap, attributes)) + testExecutorInfo(new ExecutorInfo("host", 43, logUrlMap, attributes, + rinfo, 1, Some(0), Some(1))) // StorageLevel testStorageLevel(StorageLevel.NONE) @@ -2141,6 +2148,37 @@ private[spark] object JsonProtocolSuite extends Assertions { |} """.stripMargin + private val executorAddedWithTimeJsonString = + s""" + |{ + | "Event": "SparkListenerExecutorAdded", + | "Timestamp": ${executorAddedTime}, + | "Executor ID": "exec1", + | "Executor Info": { + | "Host": "Hostee.awesome.com", + | "Total Cores": 11, + | "Log Urls" : { + | "stderr" : "mystderr", + | "stdout" : "mystdout" + | }, + | "Attributes" : { + | "ContainerId" : "ct1", + | "User" : "spark" + | }, + | "Resources" : { + | "gpu" : { + | "name" : "gpu", + | "addresses" : [ "0", "1" ] + | } + | }, + | "Resource Profile Id": 4, + | "Registration Time" : 0, + | "Request Time" : 1 + | } + | + |} + """.stripMargin + private val executorRemovedJsonString = s""" |{ diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 6117decbf47eb..973c09884a6d4 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -227,15 +227,16 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { try { // Get a handle on the buffered data, to make sure memory gets freed once we read past the // end of it. 
Need to use reflection to get handle on inner structures for this check - val byteBufferInputStream = if (mergedStream.isInstanceOf[ChunkedByteBufferInputStream]) { - assert(inputLength < limit) - mergedStream.asInstanceOf[ChunkedByteBufferInputStream] - } else { - assert(inputLength >= limit) - val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream] - val fieldValue = getFieldValue(sequenceStream, "in") - assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream]) - fieldValue.asInstanceOf[ChunkedByteBufferInputStream] + val byteBufferInputStream = mergedStream match { + case stream: ChunkedByteBufferInputStream => + assert(inputLength < limit) + stream + case _ => + assert(inputLength >= limit) + val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream] + val fieldValue = getFieldValue(sequenceStream, "in") + assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream]) + fieldValue.asInstanceOf[ChunkedByteBufferInputStream] } (0 until inputLength).foreach { idx => assert(bytes(idx) === mergedStream.read().asInstanceOf[Byte]) @@ -463,7 +464,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("get iterator size") { val empty = Seq[Int]() - assert(Utils.getIteratorSize(empty.toIterator) === 0L) + assert(Utils.getIteratorSize(empty.iterator) === 0L) val iterator = Iterator.range(0, 5) assert(Utils.getIteratorSize(iterator) === 5L) } diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 770b0d2651bd5..7fdc3839d8a4f 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -112,7 +112,6 @@ spark-deps-.* .*\.tsv .*\.sql .Rbuildignore -META-INF/* spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin diff --git a/dev/.scalafmt.conf b/dev/.scalafmt.conf index 9598540752ebd..d2196e601aa2d 100644 --- a/dev/.scalafmt.conf +++ b/dev/.scalafmt.conf @@ -25,4 +25,3 @@ optIn = { danglingParentheses = false docstrings = JavaDoc maxColumn = 98 -newlines.topLevelStatements = [before,after] diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index dd9acef2451ee..d469c98fdb3a2 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -97,7 +97,7 @@ if (!(Test-Path $tools)) { # ========================== SBT Push-Location $tools -$sbtVer = "1.6.1" +$sbtVer = "1.6.2" Start-FileDownload "https://github.com/sbt/sbt/releases/download/v$sbtVer/sbt-$sbtVer.zip" "sbt.zip" # extract diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index b6abfb57c2019..6c93ff94fd9f2 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -189,6 +189,10 @@ + + + + diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3 index c35751c50622b..f2db663550407 100644 --- a/dev/deps/spark-deps-hadoop-2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2-hive-2.3 @@ -1,7 +1,7 @@ HikariCP/2.5.1//HikariCP-2.5.1.jar JLargeArrays/1.5//JLargeArrays-1.5.jar JTransforms/3.1//JTransforms-3.1.jar -RoaringBitmap/0.9.23//RoaringBitmap-0.9.23.jar +RoaringBitmap/0.9.25//RoaringBitmap-0.9.25.jar ST4/4.0.4//ST4-4.0.4.jar activation/1.1.1//activation-1.1.1.jar aircompressor/0.21//aircompressor-0.21.jar @@ -17,10 +17,10 @@ api-asn1-api/1.0.0-M20//api-asn1-api-1.0.0-M20.jar api-util/1.0.0-M20//api-util-1.0.0-M20.jar arpack/2.2.1//arpack-2.2.1.jar arpack_combined_all/0.1//arpack_combined_all-0.1.jar -arrow-format/6.0.1//arrow-format-6.0.1.jar -arrow-memory-core/6.0.1//arrow-memory-core-6.0.1.jar -arrow-memory-netty/6.0.1//arrow-memory-netty-6.0.1.jar 
-arrow-vector/6.0.1//arrow-vector-6.0.1.jar +arrow-format/7.0.0//arrow-format-7.0.0.jar +arrow-memory-core/7.0.0//arrow-memory-core-7.0.0.jar +arrow-memory-netty/7.0.0//arrow-memory-netty-7.0.0.jar +arrow-vector/7.0.0//arrow-vector-7.0.0.jar audience-annotations/0.5.0//audience-annotations-0.5.0.jar automaton/1.11-8//automaton-1.11-8.jar avro-ipc/1.11.0//avro-ipc-1.11.0.jar @@ -38,6 +38,7 @@ commons-beanutils/1.9.4//commons-beanutils-1.9.4.jar commons-cli/1.5.0//commons-cli-1.5.0.jar commons-codec/1.15//commons-codec-1.15.jar commons-collections/3.2.2//commons-collections-3.2.2.jar +commons-collections4/4.4//commons-collections4-4.4.jar commons-compiler/3.0.16//commons-compiler-3.0.16.jar commons-compress/1.21//commons-compress-1.21.jar commons-configuration/1.6//commons-configuration-1.6.jar @@ -49,11 +50,11 @@ commons-io/2.4//commons-io-2.4.jar commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.12.0//commons-lang3-3.12.0.jar commons-logging/1.1.3//commons-logging-1.1.3.jar -commons-math3/3.4.1//commons-math3-3.4.1.jar +commons-math3/3.6.1//commons-math3-3.6.1.jar commons-net/3.1//commons-net-3.1.jar commons-pool/1.5.4//commons-pool-1.5.4.jar -commons-text/1.6//commons-text-1.6.jar -compress-lzf/1.0.3//compress-lzf-1.0.3.jar +commons-text/1.9//commons-text-1.9.jar +compress-lzf/1.1//compress-lzf-1.1.jar core/1.1.2//core-1.1.2.jar curator-client/2.7.1//curator-client-2.7.1.jar curator-framework/2.7.1//curator-framework-2.7.1.jar @@ -111,16 +112,16 @@ httpclient/4.5.13//httpclient-4.5.13.jar httpcore/4.4.14//httpcore-4.4.14.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.0//ivy-2.5.0.jar -jackson-annotations/2.13.1//jackson-annotations-2.13.1.jar +jackson-annotations/2.13.2//jackson-annotations-2.13.2.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar -jackson-core/2.13.1//jackson-core-2.13.1.jar -jackson-databind/2.13.1//jackson-databind-2.13.1.jar -jackson-dataformat-cbor/2.13.1//jackson-dataformat-cbor-2.13.1.jar -jackson-dataformat-yaml/2.13.1//jackson-dataformat-yaml-2.13.1.jar -jackson-datatype-jsr310/2.13.0//jackson-datatype-jsr310-2.13.0.jar +jackson-core/2.13.2//jackson-core-2.13.2.jar +jackson-databind/2.13.2//jackson-databind-2.13.2.jar +jackson-dataformat-cbor/2.13.2//jackson-dataformat-cbor-2.13.2.jar +jackson-dataformat-yaml/2.13.2//jackson-dataformat-yaml-2.13.2.jar +jackson-datatype-jsr310/2.13.1//jackson-datatype-jsr310-2.13.1.jar jackson-jaxrs/1.9.13//jackson-jaxrs-1.9.13.jar jackson-mapper-asl/1.9.13//jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.12/2.13.1//jackson-module-scala_2.12-2.13.1.jar +jackson-module-scala_2.12/2.13.2//jackson-module-scala_2.12-2.13.2.jar jackson-xc/1.9.13//jackson-xc-1.9.13.jar jakarta.annotation-api/1.3.5//jakarta.annotation-api-1.3.5.jar jakarta.inject/2.6.1//jakarta.inject-2.6.1.jar @@ -145,10 +146,10 @@ jersey-hk2/2.34//jersey-hk2-2.34.jar jersey-server/2.34//jersey-server-2.34.jar jetty-sslengine/6.1.26//jetty-sslengine-6.1.26.jar jetty-util/6.1.26//jetty-util-6.1.26.jar -jetty-util/9.4.43.v20210629//jetty-util-9.4.43.v20210629.jar +jetty-util/9.4.44.v20210927//jetty-util-9.4.44.v20210927.jar jetty/6.1.26//jetty-6.1.26.jar jline/2.14.6//jline-2.14.6.jar -joda-time/2.10.12//joda-time-2.10.12.jar +joda-time/2.10.13//joda-time-2.10.13.jar jodd-core/3.5.2//jodd-core-3.5.2.jar jpam/1.1//jpam-1.1.jar json/1.8//json-1.8.jar @@ -161,27 +162,27 @@ jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar jul-to-slf4j/1.7.32//jul-to-slf4j-1.7.32.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar 
-kubernetes-client/5.10.2//kubernetes-client-5.10.2.jar -kubernetes-model-admissionregistration/5.10.2//kubernetes-model-admissionregistration-5.10.2.jar -kubernetes-model-apiextensions/5.10.2//kubernetes-model-apiextensions-5.10.2.jar -kubernetes-model-apps/5.10.2//kubernetes-model-apps-5.10.2.jar -kubernetes-model-autoscaling/5.10.2//kubernetes-model-autoscaling-5.10.2.jar -kubernetes-model-batch/5.10.2//kubernetes-model-batch-5.10.2.jar -kubernetes-model-certificates/5.10.2//kubernetes-model-certificates-5.10.2.jar -kubernetes-model-common/5.10.2//kubernetes-model-common-5.10.2.jar -kubernetes-model-coordination/5.10.2//kubernetes-model-coordination-5.10.2.jar -kubernetes-model-core/5.10.2//kubernetes-model-core-5.10.2.jar -kubernetes-model-discovery/5.10.2//kubernetes-model-discovery-5.10.2.jar -kubernetes-model-events/5.10.2//kubernetes-model-events-5.10.2.jar -kubernetes-model-extensions/5.10.2//kubernetes-model-extensions-5.10.2.jar -kubernetes-model-flowcontrol/5.10.2//kubernetes-model-flowcontrol-5.10.2.jar -kubernetes-model-metrics/5.10.2//kubernetes-model-metrics-5.10.2.jar -kubernetes-model-networking/5.10.2//kubernetes-model-networking-5.10.2.jar -kubernetes-model-node/5.10.2//kubernetes-model-node-5.10.2.jar -kubernetes-model-policy/5.10.2//kubernetes-model-policy-5.10.2.jar -kubernetes-model-rbac/5.10.2//kubernetes-model-rbac-5.10.2.jar -kubernetes-model-scheduling/5.10.2//kubernetes-model-scheduling-5.10.2.jar -kubernetes-model-storageclass/5.10.2//kubernetes-model-storageclass-5.10.2.jar +kubernetes-client/5.12.1//kubernetes-client-5.12.1.jar +kubernetes-model-admissionregistration/5.12.1//kubernetes-model-admissionregistration-5.12.1.jar +kubernetes-model-apiextensions/5.12.1//kubernetes-model-apiextensions-5.12.1.jar +kubernetes-model-apps/5.12.1//kubernetes-model-apps-5.12.1.jar +kubernetes-model-autoscaling/5.12.1//kubernetes-model-autoscaling-5.12.1.jar +kubernetes-model-batch/5.12.1//kubernetes-model-batch-5.12.1.jar +kubernetes-model-certificates/5.12.1//kubernetes-model-certificates-5.12.1.jar +kubernetes-model-common/5.12.1//kubernetes-model-common-5.12.1.jar +kubernetes-model-coordination/5.12.1//kubernetes-model-coordination-5.12.1.jar +kubernetes-model-core/5.12.1//kubernetes-model-core-5.12.1.jar +kubernetes-model-discovery/5.12.1//kubernetes-model-discovery-5.12.1.jar +kubernetes-model-events/5.12.1//kubernetes-model-events-5.12.1.jar +kubernetes-model-extensions/5.12.1//kubernetes-model-extensions-5.12.1.jar +kubernetes-model-flowcontrol/5.12.1//kubernetes-model-flowcontrol-5.12.1.jar +kubernetes-model-metrics/5.12.1//kubernetes-model-metrics-5.12.1.jar +kubernetes-model-networking/5.12.1//kubernetes-model-networking-5.12.1.jar +kubernetes-model-node/5.12.1//kubernetes-model-node-5.12.1.jar +kubernetes-model-policy/5.12.1//kubernetes-model-policy-5.12.1.jar +kubernetes-model-rbac/5.12.1//kubernetes-model-rbac-5.12.1.jar +kubernetes-model-scheduling/5.12.1//kubernetes-model-scheduling-5.12.1.jar +kubernetes-model-storageclass/5.12.1//kubernetes-model-storageclass-5.12.1.jar lapack/2.2.1//lapack-2.2.1.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar @@ -192,36 +193,35 @@ log4j-core/2.17.1//log4j-core-2.17.1.jar log4j-slf4j-impl/2.17.1//log4j-slf4j-impl-2.17.1.jar logging-interceptor/3.12.12//logging-interceptor-3.12.12.jar lz4-java/1.8.0//lz4-java-1.8.0.jar -macro-compat_2.12/1.1.1//macro-compat_2.12-1.1.1.jar mesos/1.4.3/shaded-protobuf/mesos-1.4.3-shaded-protobuf.jar -metrics-core/4.2.2//metrics-core-4.2.2.jar 
-metrics-graphite/4.2.2//metrics-graphite-4.2.2.jar -metrics-jmx/4.2.2//metrics-jmx-4.2.2.jar -metrics-json/4.2.2//metrics-json-4.2.2.jar -metrics-jvm/4.2.2//metrics-jvm-4.2.2.jar +metrics-core/4.2.7//metrics-core-4.2.7.jar +metrics-graphite/4.2.7//metrics-graphite-4.2.7.jar +metrics-jmx/4.2.7//metrics-jmx-4.2.7.jar +metrics-json/4.2.7//metrics-json-4.2.7.jar +metrics-jvm/4.2.7//metrics-jvm-4.2.7.jar minlog/1.3.0//minlog-1.3.0.jar -netty-all/4.1.72.Final//netty-all-4.1.72.Final.jar -netty-buffer/4.1.72.Final//netty-buffer-4.1.72.Final.jar -netty-codec/4.1.72.Final//netty-codec-4.1.72.Final.jar -netty-common/4.1.72.Final//netty-common-4.1.72.Final.jar -netty-handler/4.1.72.Final//netty-handler-4.1.72.Final.jar -netty-resolver/4.1.72.Final//netty-resolver-4.1.72.Final.jar -netty-tcnative-classes/2.0.46.Final//netty-tcnative-classes-2.0.46.Final.jar -netty-transport-classes-epoll/4.1.72.Final//netty-transport-classes-epoll-4.1.72.Final.jar -netty-transport-classes-kqueue/4.1.72.Final//netty-transport-classes-kqueue-4.1.72.Final.jar -netty-transport-native-epoll/4.1.72.Final/linux-aarch_64/netty-transport-native-epoll-4.1.72.Final-linux-aarch_64.jar -netty-transport-native-epoll/4.1.72.Final/linux-x86_64/netty-transport-native-epoll-4.1.72.Final-linux-x86_64.jar -netty-transport-native-kqueue/4.1.72.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.72.Final-osx-aarch_64.jar -netty-transport-native-kqueue/4.1.72.Final/osx-x86_64/netty-transport-native-kqueue-4.1.72.Final-osx-x86_64.jar -netty-transport-native-unix-common/4.1.72.Final//netty-transport-native-unix-common-4.1.72.Final.jar -netty-transport/4.1.72.Final//netty-transport-4.1.72.Final.jar +netty-all/4.1.74.Final//netty-all-4.1.74.Final.jar +netty-buffer/4.1.74.Final//netty-buffer-4.1.74.Final.jar +netty-codec/4.1.74.Final//netty-codec-4.1.74.Final.jar +netty-common/4.1.74.Final//netty-common-4.1.74.Final.jar +netty-handler/4.1.74.Final//netty-handler-4.1.74.Final.jar +netty-resolver/4.1.74.Final//netty-resolver-4.1.74.Final.jar +netty-tcnative-classes/2.0.48.Final//netty-tcnative-classes-2.0.48.Final.jar +netty-transport-classes-epoll/4.1.74.Final//netty-transport-classes-epoll-4.1.74.Final.jar +netty-transport-classes-kqueue/4.1.74.Final//netty-transport-classes-kqueue-4.1.74.Final.jar +netty-transport-native-epoll/4.1.74.Final/linux-aarch_64/netty-transport-native-epoll-4.1.74.Final-linux-aarch_64.jar +netty-transport-native-epoll/4.1.74.Final/linux-x86_64/netty-transport-native-epoll-4.1.74.Final-linux-x86_64.jar +netty-transport-native-kqueue/4.1.74.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.74.Final-osx-aarch_64.jar +netty-transport-native-kqueue/4.1.74.Final/osx-x86_64/netty-transport-native-kqueue-4.1.74.Final-osx-x86_64.jar +netty-transport-native-unix-common/4.1.74.Final//netty-transport-native-unix-common-4.1.74.Final.jar +netty-transport/4.1.74.Final//netty-transport-4.1.74.Final.jar objenesis/3.2//objenesis-3.2.jar okhttp/3.12.12//okhttp-3.12.12.jar okio/1.14.0//okio-1.14.0.jar opencsv/2.3//opencsv-2.3.jar -orc-core/1.7.2//orc-core-1.7.2.jar -orc-mapreduce/1.7.2//orc-mapreduce-1.7.2.jar -orc-shims/1.7.2//orc-shims-1.7.2.jar +orc-core/1.7.3//orc-core-1.7.3.jar +orc-mapreduce/1.7.3//orc-mapreduce-1.7.3.jar +orc-shims/1.7.3//orc-shims-1.7.3.jar oro/2.0.8//oro-2.0.8.jar osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar paranamer/2.8//paranamer-2.8.jar @@ -233,7 +233,7 @@ parquet-hadoop/1.12.2//parquet-hadoop-1.12.2.jar parquet-jackson/1.12.2//parquet-jackson-1.12.2.jar pickle/1.2//pickle-1.2.jar 
protobuf-java/2.5.0//protobuf-java-2.5.0.jar -py4j/0.10.9.3//py4j-0.10.9.3.jar +py4j/0.10.9.4//py4j-0.10.9.4.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar rocksdbjni/6.20.3//rocksdbjni-6.20.3.jar scala-collection-compat_2.12/2.1.1//scala-collection-compat_2.12-2.1.1.jar @@ -242,10 +242,10 @@ scala-library/2.12.15//scala-library-2.12.15.jar scala-parser-combinators_2.12/1.1.2//scala-parser-combinators_2.12-1.1.2.jar scala-reflect/2.12.15//scala-reflect-2.12.15.jar scala-xml_2.12/1.2.0//scala-xml_2.12-1.2.0.jar -shapeless_2.12/2.3.3//shapeless_2.12-2.3.3.jar -shims/0.9.23//shims-0.9.23.jar +shapeless_2.12/2.3.7//shapeless_2.12-2.3.7.jar +shims/0.9.25//shims-0.9.25.jar slf4j-api/1.7.32//slf4j-api-1.7.32.jar -snakeyaml/1.28//snakeyaml-1.28.jar +snakeyaml/1.30//snakeyaml-1.30.jar snappy-java/1.1.8.4//snappy-java-1.1.8.4.jar spire-macros_2.12/0.17.0//spire-macros_2.12-0.17.0.jar spire-platform_2.12/0.17.0//spire-platform_2.12-0.17.0.jar @@ -255,7 +255,7 @@ stax-api/1.0.1//stax-api-1.0.1.jar stream/2.9.6//stream-2.9.6.jar super-csv/2.2.0//super-csv-2.2.0.jar threeten-extra/1.5.0//threeten-extra-1.5.0.jar -tink/1.6.0//tink-1.6.0.jar +tink/1.6.1//tink-1.6.1.jar transaction-api/1.1//transaction-api-1.1.jar univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar velocity/1.5//velocity-1.5.jar @@ -267,4 +267,4 @@ xz/1.8//xz-1.8.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper-jute/3.6.2//zookeeper-jute-3.6.2.jar zookeeper/3.6.2//zookeeper-3.6.2.jar -zstd-jni/1.5.1-1//zstd-jni-1.5.1-1.jar +zstd-jni/1.5.2-1//zstd-jni-1.5.2-1.jar diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 51aaba30cf4e5..c56b4c9bb6826 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -1,32 +1,31 @@ HikariCP/2.5.1//HikariCP-2.5.1.jar JLargeArrays/1.5//JLargeArrays-1.5.jar JTransforms/3.1//JTransforms-3.1.jar -RoaringBitmap/0.9.23//RoaringBitmap-0.9.23.jar +RoaringBitmap/0.9.25//RoaringBitmap-0.9.25.jar ST4/4.0.4//ST4-4.0.4.jar activation/1.1.1//activation-1.1.1.jar aircompressor/0.21//aircompressor-0.21.jar algebra_2.12/2.0.1//algebra_2.12-2.0.1.jar -aliyun-java-sdk-core/3.4.0//aliyun-java-sdk-core-3.4.0.jar -aliyun-java-sdk-ecs/4.2.0//aliyun-java-sdk-ecs-4.2.0.jar -aliyun-java-sdk-ram/3.0.0//aliyun-java-sdk-ram-3.0.0.jar -aliyun-java-sdk-sts/3.0.0//aliyun-java-sdk-sts-3.0.0.jar -aliyun-sdk-oss/3.4.1//aliyun-sdk-oss-3.4.1.jar +aliyun-java-sdk-core/4.5.10//aliyun-java-sdk-core-4.5.10.jar +aliyun-java-sdk-kms/2.11.0//aliyun-java-sdk-kms-2.11.0.jar +aliyun-java-sdk-ram/3.1.0//aliyun-java-sdk-ram-3.1.0.jar +aliyun-sdk-oss/3.13.0//aliyun-sdk-oss-3.13.0.jar annotations/17.0.0//annotations-17.0.0.jar antlr-runtime/3.5.2//antlr-runtime-3.5.2.jar antlr4-runtime/4.8//antlr4-runtime-4.8.jar aopalliance-repackaged/2.6.1//aopalliance-repackaged-2.6.1.jar arpack/2.2.1//arpack-2.2.1.jar arpack_combined_all/0.1//arpack_combined_all-0.1.jar -arrow-format/6.0.1//arrow-format-6.0.1.jar -arrow-memory-core/6.0.1//arrow-memory-core-6.0.1.jar -arrow-memory-netty/6.0.1//arrow-memory-netty-6.0.1.jar -arrow-vector/6.0.1//arrow-vector-6.0.1.jar +arrow-format/7.0.0//arrow-format-7.0.0.jar +arrow-memory-core/7.0.0//arrow-memory-core-7.0.0.jar +arrow-memory-netty/7.0.0//arrow-memory-netty-7.0.0.jar +arrow-vector/7.0.0//arrow-vector-7.0.0.jar audience-annotations/0.5.0//audience-annotations-0.5.0.jar automaton/1.11-8//automaton-1.11-8.jar avro-ipc/1.11.0//avro-ipc-1.11.0.jar avro-mapred/1.11.0//avro-mapred-1.11.0.jar avro/1.11.0//avro-1.11.0.jar 
-aws-java-sdk-bundle/1.11.901//aws-java-sdk-bundle-1.11.901.jar +aws-java-sdk-bundle/1.11.1026//aws-java-sdk-bundle-1.11.1026.jar azure-data-lake-store-sdk/2.3.9//azure-data-lake-store-sdk-2.3.9.jar azure-keyvault-core/1.0.0//azure-keyvault-core-1.0.0.jar azure-storage/7.0.1//azure-storage-7.0.1.jar @@ -39,7 +38,7 @@ chill-java/0.10.0//chill-java-0.10.0.jar chill_2.12/0.10.0//chill_2.12-0.10.0.jar commons-cli/1.5.0//commons-cli-1.5.0.jar commons-codec/1.15//commons-codec-1.15.jar -commons-collections/3.2.2//commons-collections-3.2.2.jar +commons-collections4/4.4//commons-collections4-4.4.jar commons-compiler/3.0.16//commons-compiler-3.0.16.jar commons-compress/1.21//commons-compress-1.21.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar @@ -48,11 +47,10 @@ commons-io/2.11.0//commons-io-2.11.0.jar commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.12.0//commons-lang3-3.12.0.jar commons-logging/1.1.3//commons-logging-1.1.3.jar -commons-math3/3.4.1//commons-math3-3.4.1.jar -commons-net/3.1//commons-net-3.1.jar +commons-math3/3.6.1//commons-math3-3.6.1.jar commons-pool/1.5.4//commons-pool-1.5.4.jar -commons-text/1.6//commons-text-1.6.jar -compress-lzf/1.0.3//compress-lzf-1.0.3.jar +commons-text/1.9//commons-text-1.9.jar +compress-lzf/1.1//compress-lzf-1.1.jar core/1.1.2//core-1.1.2.jar cos_api-bundle/5.6.19//cos_api-bundle-5.6.19.jar curator-client/2.13.0//curator-client-2.13.0.jar @@ -68,18 +66,18 @@ generex/1.0.2//generex-1.0.2.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar gson/2.2.4//gson-2.2.4.jar guava/14.0.1//guava-14.0.1.jar -hadoop-aliyun/3.3.1//hadoop-aliyun-3.3.1.jar -hadoop-annotations/3.3.1//hadoop-annotations-3.3.1.jar -hadoop-aws/3.3.1//hadoop-aws-3.3.1.jar -hadoop-azure-datalake/3.3.1//hadoop-azure-datalake-3.3.1.jar -hadoop-azure/3.3.1//hadoop-azure-3.3.1.jar -hadoop-client-api/3.3.1//hadoop-client-api-3.3.1.jar -hadoop-client-runtime/3.3.1//hadoop-client-runtime-3.3.1.jar -hadoop-cloud-storage/3.3.1//hadoop-cloud-storage-3.3.1.jar -hadoop-cos/3.3.1//hadoop-cos-3.3.1.jar -hadoop-openstack/3.3.1//hadoop-openstack-3.3.1.jar +hadoop-aliyun/3.3.2//hadoop-aliyun-3.3.2.jar +hadoop-annotations/3.3.2//hadoop-annotations-3.3.2.jar +hadoop-aws/3.3.2//hadoop-aws-3.3.2.jar +hadoop-azure-datalake/3.3.2//hadoop-azure-datalake-3.3.2.jar +hadoop-azure/3.3.2//hadoop-azure-3.3.2.jar +hadoop-client-api/3.3.2//hadoop-client-api-3.3.2.jar +hadoop-client-runtime/3.3.2//hadoop-client-runtime-3.3.2.jar +hadoop-cloud-storage/3.3.2//hadoop-cloud-storage-3.3.2.jar +hadoop-cos/3.3.2//hadoop-cos-3.3.2.jar +hadoop-openstack/3.3.2//hadoop-openstack-3.3.2.jar hadoop-shaded-guava/1.1.1//hadoop-shaded-guava-1.1.1.jar -hadoop-yarn-server-web-proxy/3.3.1//hadoop-yarn-server-web-proxy-3.3.1.jar +hadoop-yarn-server-web-proxy/3.3.2//hadoop-yarn-server-web-proxy-3.3.2.jar hive-beeline/2.3.9//hive-beeline-2.3.9.jar hive-cli/2.3.9//hive-cli-2.3.9.jar hive-common/2.3.9//hive-common-2.3.9.jar @@ -98,20 +96,20 @@ hive-vector-code-gen/2.3.9//hive-vector-code-gen-2.3.9.jar hk2-api/2.6.1//hk2-api-2.6.1.jar hk2-locator/2.6.1//hk2-locator-2.6.1.jar hk2-utils/2.6.1//hk2-utils-2.6.1.jar -htrace-core4/4.1.0-incubating//htrace-core4-4.1.0-incubating.jar httpclient/4.5.13//httpclient-4.5.13.jar httpcore/4.4.14//httpcore-4.4.14.jar +ini4j/0.5.4//ini4j-0.5.4.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.0//ivy-2.5.0.jar -jackson-annotations/2.13.1//jackson-annotations-2.13.1.jar +jackson-annotations/2.13.2//jackson-annotations-2.13.2.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar 
-jackson-core/2.13.1//jackson-core-2.13.1.jar -jackson-databind/2.13.1//jackson-databind-2.13.1.jar -jackson-dataformat-cbor/2.13.1//jackson-dataformat-cbor-2.13.1.jar -jackson-dataformat-yaml/2.13.1//jackson-dataformat-yaml-2.13.1.jar -jackson-datatype-jsr310/2.13.0//jackson-datatype-jsr310-2.13.0.jar +jackson-core/2.13.2//jackson-core-2.13.2.jar +jackson-databind/2.13.2//jackson-databind-2.13.2.jar +jackson-dataformat-cbor/2.13.2//jackson-dataformat-cbor-2.13.2.jar +jackson-dataformat-yaml/2.13.2//jackson-dataformat-yaml-2.13.2.jar +jackson-datatype-jsr310/2.13.1//jackson-datatype-jsr310-2.13.1.jar jackson-mapper-asl/1.9.13//jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.12/2.13.1//jackson-module-scala_2.12-2.13.1.jar +jackson-module-scala_2.12/2.13.2//jackson-module-scala_2.12-2.13.2.jar jakarta.annotation-api/1.3.5//jakarta.annotation-api-1.3.5.jar jakarta.inject/2.6.1//jakarta.inject-2.6.1.jar jakarta.servlet-api/4.0.3//jakarta.servlet-api-4.0.3.jar @@ -122,10 +120,11 @@ janino/3.0.16//janino-3.0.16.jar javassist/3.25.0-GA//javassist-3.25.0-GA.jar javax.jdo/3.2.0-m3//javax.jdo-3.2.0-m3.jar javolution/5.5.1//javolution-5.5.1.jar +jaxb-api/2.2.11//jaxb-api-2.2.11.jar jaxb-runtime/2.3.2//jaxb-runtime-2.3.2.jar jcl-over-slf4j/1.7.32//jcl-over-slf4j-1.7.32.jar jdo-api/3.0.1//jdo-api-3.0.1.jar -jdom/1.1//jdom-1.1.jar +jdom2/2.0.6//jdom2-2.0.6.jar jersey-client/2.34//jersey-client-2.34.jar jersey-common/2.34//jersey-common-2.34.jar jersey-container-servlet-core/2.34//jersey-container-servlet-core-2.34.jar @@ -133,10 +132,10 @@ jersey-container-servlet/2.34//jersey-container-servlet-2.34.jar jersey-hk2/2.34//jersey-hk2-2.34.jar jersey-server/2.34//jersey-server-2.34.jar jettison/1.1//jettison-1.1.jar -jetty-util-ajax/9.4.43.v20210629//jetty-util-ajax-9.4.43.v20210629.jar -jetty-util/9.4.43.v20210629//jetty-util-9.4.43.v20210629.jar +jetty-util-ajax/9.4.44.v20210927//jetty-util-ajax-9.4.44.v20210927.jar +jetty-util/9.4.44.v20210927//jetty-util-9.4.44.v20210927.jar jline/2.14.6//jline-2.14.6.jar -joda-time/2.10.12//joda-time-2.10.12.jar +joda-time/2.10.13//joda-time-2.10.13.jar jodd-core/3.5.2//jodd-core-3.5.2.jar jpam/1.1//jpam-1.1.jar json/1.8//json-1.8.jar @@ -148,27 +147,27 @@ jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar jul-to-slf4j/1.7.32//jul-to-slf4j-1.7.32.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client/5.10.2//kubernetes-client-5.10.2.jar -kubernetes-model-admissionregistration/5.10.2//kubernetes-model-admissionregistration-5.10.2.jar -kubernetes-model-apiextensions/5.10.2//kubernetes-model-apiextensions-5.10.2.jar -kubernetes-model-apps/5.10.2//kubernetes-model-apps-5.10.2.jar -kubernetes-model-autoscaling/5.10.2//kubernetes-model-autoscaling-5.10.2.jar -kubernetes-model-batch/5.10.2//kubernetes-model-batch-5.10.2.jar -kubernetes-model-certificates/5.10.2//kubernetes-model-certificates-5.10.2.jar -kubernetes-model-common/5.10.2//kubernetes-model-common-5.10.2.jar -kubernetes-model-coordination/5.10.2//kubernetes-model-coordination-5.10.2.jar -kubernetes-model-core/5.10.2//kubernetes-model-core-5.10.2.jar -kubernetes-model-discovery/5.10.2//kubernetes-model-discovery-5.10.2.jar -kubernetes-model-events/5.10.2//kubernetes-model-events-5.10.2.jar -kubernetes-model-extensions/5.10.2//kubernetes-model-extensions-5.10.2.jar -kubernetes-model-flowcontrol/5.10.2//kubernetes-model-flowcontrol-5.10.2.jar -kubernetes-model-metrics/5.10.2//kubernetes-model-metrics-5.10.2.jar -kubernetes-model-networking/5.10.2//kubernetes-model-networking-5.10.2.jar 
-kubernetes-model-node/5.10.2//kubernetes-model-node-5.10.2.jar -kubernetes-model-policy/5.10.2//kubernetes-model-policy-5.10.2.jar -kubernetes-model-rbac/5.10.2//kubernetes-model-rbac-5.10.2.jar -kubernetes-model-scheduling/5.10.2//kubernetes-model-scheduling-5.10.2.jar -kubernetes-model-storageclass/5.10.2//kubernetes-model-storageclass-5.10.2.jar +kubernetes-client/5.12.1//kubernetes-client-5.12.1.jar +kubernetes-model-admissionregistration/5.12.1//kubernetes-model-admissionregistration-5.12.1.jar +kubernetes-model-apiextensions/5.12.1//kubernetes-model-apiextensions-5.12.1.jar +kubernetes-model-apps/5.12.1//kubernetes-model-apps-5.12.1.jar +kubernetes-model-autoscaling/5.12.1//kubernetes-model-autoscaling-5.12.1.jar +kubernetes-model-batch/5.12.1//kubernetes-model-batch-5.12.1.jar +kubernetes-model-certificates/5.12.1//kubernetes-model-certificates-5.12.1.jar +kubernetes-model-common/5.12.1//kubernetes-model-common-5.12.1.jar +kubernetes-model-coordination/5.12.1//kubernetes-model-coordination-5.12.1.jar +kubernetes-model-core/5.12.1//kubernetes-model-core-5.12.1.jar +kubernetes-model-discovery/5.12.1//kubernetes-model-discovery-5.12.1.jar +kubernetes-model-events/5.12.1//kubernetes-model-events-5.12.1.jar +kubernetes-model-extensions/5.12.1//kubernetes-model-extensions-5.12.1.jar +kubernetes-model-flowcontrol/5.12.1//kubernetes-model-flowcontrol-5.12.1.jar +kubernetes-model-metrics/5.12.1//kubernetes-model-metrics-5.12.1.jar +kubernetes-model-networking/5.12.1//kubernetes-model-networking-5.12.1.jar +kubernetes-model-node/5.12.1//kubernetes-model-node-5.12.1.jar +kubernetes-model-policy/5.12.1//kubernetes-model-policy-5.12.1.jar +kubernetes-model-rbac/5.12.1//kubernetes-model-rbac-5.12.1.jar +kubernetes-model-scheduling/5.12.1//kubernetes-model-scheduling-5.12.1.jar +kubernetes-model-storageclass/5.12.1//kubernetes-model-storageclass-5.12.1.jar lapack/2.2.1//lapack-2.2.1.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar @@ -179,36 +178,38 @@ log4j-core/2.17.1//log4j-core-2.17.1.jar log4j-slf4j-impl/2.17.1//log4j-slf4j-impl-2.17.1.jar logging-interceptor/3.12.12//logging-interceptor-3.12.12.jar lz4-java/1.8.0//lz4-java-1.8.0.jar -macro-compat_2.12/1.1.1//macro-compat_2.12-1.1.1.jar mesos/1.4.3/shaded-protobuf/mesos-1.4.3-shaded-protobuf.jar -metrics-core/4.2.2//metrics-core-4.2.2.jar -metrics-graphite/4.2.2//metrics-graphite-4.2.2.jar -metrics-jmx/4.2.2//metrics-jmx-4.2.2.jar -metrics-json/4.2.2//metrics-json-4.2.2.jar -metrics-jvm/4.2.2//metrics-jvm-4.2.2.jar +metrics-core/4.2.7//metrics-core-4.2.7.jar +metrics-graphite/4.2.7//metrics-graphite-4.2.7.jar +metrics-jmx/4.2.7//metrics-jmx-4.2.7.jar +metrics-json/4.2.7//metrics-json-4.2.7.jar +metrics-jvm/4.2.7//metrics-jvm-4.2.7.jar minlog/1.3.0//minlog-1.3.0.jar -netty-all/4.1.72.Final//netty-all-4.1.72.Final.jar -netty-buffer/4.1.72.Final//netty-buffer-4.1.72.Final.jar -netty-codec/4.1.72.Final//netty-codec-4.1.72.Final.jar -netty-common/4.1.72.Final//netty-common-4.1.72.Final.jar -netty-handler/4.1.72.Final//netty-handler-4.1.72.Final.jar -netty-resolver/4.1.72.Final//netty-resolver-4.1.72.Final.jar -netty-tcnative-classes/2.0.46.Final//netty-tcnative-classes-2.0.46.Final.jar -netty-transport-classes-epoll/4.1.72.Final//netty-transport-classes-epoll-4.1.72.Final.jar -netty-transport-classes-kqueue/4.1.72.Final//netty-transport-classes-kqueue-4.1.72.Final.jar -netty-transport-native-epoll/4.1.72.Final/linux-aarch_64/netty-transport-native-epoll-4.1.72.Final-linux-aarch_64.jar 
-netty-transport-native-epoll/4.1.72.Final/linux-x86_64/netty-transport-native-epoll-4.1.72.Final-linux-x86_64.jar -netty-transport-native-kqueue/4.1.72.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.72.Final-osx-aarch_64.jar -netty-transport-native-kqueue/4.1.72.Final/osx-x86_64/netty-transport-native-kqueue-4.1.72.Final-osx-x86_64.jar -netty-transport-native-unix-common/4.1.72.Final//netty-transport-native-unix-common-4.1.72.Final.jar -netty-transport/4.1.72.Final//netty-transport-4.1.72.Final.jar +netty-all/4.1.74.Final//netty-all-4.1.74.Final.jar +netty-buffer/4.1.74.Final//netty-buffer-4.1.74.Final.jar +netty-codec/4.1.74.Final//netty-codec-4.1.74.Final.jar +netty-common/4.1.74.Final//netty-common-4.1.74.Final.jar +netty-handler/4.1.74.Final//netty-handler-4.1.74.Final.jar +netty-resolver/4.1.74.Final//netty-resolver-4.1.74.Final.jar +netty-tcnative-classes/2.0.48.Final//netty-tcnative-classes-2.0.48.Final.jar +netty-transport-classes-epoll/4.1.74.Final//netty-transport-classes-epoll-4.1.74.Final.jar +netty-transport-classes-kqueue/4.1.74.Final//netty-transport-classes-kqueue-4.1.74.Final.jar +netty-transport-native-epoll/4.1.74.Final/linux-aarch_64/netty-transport-native-epoll-4.1.74.Final-linux-aarch_64.jar +netty-transport-native-epoll/4.1.74.Final/linux-x86_64/netty-transport-native-epoll-4.1.74.Final-linux-x86_64.jar +netty-transport-native-kqueue/4.1.74.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.74.Final-osx-aarch_64.jar +netty-transport-native-kqueue/4.1.74.Final/osx-x86_64/netty-transport-native-kqueue-4.1.74.Final-osx-x86_64.jar +netty-transport-native-unix-common/4.1.74.Final//netty-transport-native-unix-common-4.1.74.Final.jar +netty-transport/4.1.74.Final//netty-transport-4.1.74.Final.jar objenesis/3.2//objenesis-3.2.jar okhttp/3.12.12//okhttp-3.12.12.jar okio/1.14.0//okio-1.14.0.jar opencsv/2.3//opencsv-2.3.jar -orc-core/1.7.2//orc-core-1.7.2.jar -orc-mapreduce/1.7.2//orc-mapreduce-1.7.2.jar -orc-shims/1.7.2//orc-shims-1.7.2.jar +opentracing-api/0.33.0//opentracing-api-0.33.0.jar +opentracing-noop/0.33.0//opentracing-noop-0.33.0.jar +opentracing-util/0.33.0//opentracing-util-0.33.0.jar +orc-core/1.7.3//orc-core-1.7.3.jar +orc-mapreduce/1.7.3//orc-mapreduce-1.7.3.jar +orc-shims/1.7.3//orc-shims-1.7.3.jar oro/2.0.8//oro-2.0.8.jar osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar paranamer/2.8//paranamer-2.8.jar @@ -220,7 +221,7 @@ parquet-hadoop/1.12.2//parquet-hadoop-1.12.2.jar parquet-jackson/1.12.2//parquet-jackson-1.12.2.jar pickle/1.2//pickle-1.2.jar protobuf-java/2.5.0//protobuf-java-2.5.0.jar -py4j/0.10.9.3//py4j-0.10.9.3.jar +py4j/0.10.9.4//py4j-0.10.9.4.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar rocksdbjni/6.20.3//rocksdbjni-6.20.3.jar scala-collection-compat_2.12/2.1.1//scala-collection-compat_2.12-2.1.1.jar @@ -229,10 +230,10 @@ scala-library/2.12.15//scala-library-2.12.15.jar scala-parser-combinators_2.12/1.1.2//scala-parser-combinators_2.12-1.1.2.jar scala-reflect/2.12.15//scala-reflect-2.12.15.jar scala-xml_2.12/1.2.0//scala-xml_2.12-1.2.0.jar -shapeless_2.12/2.3.3//shapeless_2.12-2.3.3.jar -shims/0.9.23//shims-0.9.23.jar +shapeless_2.12/2.3.7//shapeless_2.12-2.3.7.jar +shims/0.9.25//shims-0.9.25.jar slf4j-api/1.7.32//slf4j-api-1.7.32.jar -snakeyaml/1.28//snakeyaml-1.28.jar +snakeyaml/1.30//snakeyaml-1.30.jar snappy-java/1.1.8.4//snappy-java-1.1.8.4.jar spire-macros_2.12/0.17.0//spire-macros_2.12-0.17.0.jar spire-platform_2.12/0.17.0//spire-platform_2.12-0.17.0.jar @@ -242,7 +243,7 @@ stax-api/1.0.1//stax-api-1.0.1.jar 
stream/2.9.6//stream-2.9.6.jar super-csv/2.2.0//super-csv-2.2.0.jar threeten-extra/1.5.0//threeten-extra-1.5.0.jar -tink/1.6.0//tink-1.6.0.jar +tink/1.6.1//tink-1.6.1.jar transaction-api/1.1//transaction-api-1.1.jar univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar velocity/1.5//velocity-1.5.jar @@ -252,4 +253,4 @@ xz/1.8//xz-1.8.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper-jute/3.6.2//zookeeper-jute-3.6.2.jar zookeeper/3.6.2//zookeeper-3.6.2.jar -zstd-jni/1.5.1-1//zstd-jni-1.5.1-1.jar +zstd-jni/1.5.2-1//zstd-jni-1.5.2-1.jar diff --git a/dev/lint-python b/dev/lint-python index c40198e87c2d6..f0ca8832be057 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -18,7 +18,7 @@ # define test binaries + versions FLAKE8_BUILD="flake8" MINIMUM_FLAKE8="3.9.0" -MINIMUM_MYPY="0.910" +MINIMUM_MYPY="0.920" MYPY_BUILD="mypy" PYTEST_BUILD="pytest" @@ -127,7 +127,6 @@ function mypy_examples_test { echo "starting mypy examples test..." MYPY_REPORT=$( (MYPYPATH=python $MYPY_BUILD \ - --allow-untyped-defs \ --config-file python/mypy.ini \ --exclude "mllib/*" \ examples/src/main/python/) 2>&1) @@ -152,6 +151,15 @@ function mypy_test { return fi + _MYPY_VERSION=($($MYPY_BUILD --version)) + MYPY_VERSION="${_MYPY_VERSION[1]}" + EXPECTED_MYPY="$(satisfies_min_version $MYPY_VERSION $MINIMUM_MYPY)" + + if [[ "$EXPECTED_MYPY" == "False" ]]; then + echo "The minimum mypy version needs to be $MINIMUM_MYPY. Your current version is $MYPY_VERSION. Skipping for now." + return + fi + mypy_annotation_test mypy_examples_test mypy_data_test diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 8d09c530dfb7f..e21a39a688170 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -135,11 +135,12 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): continue_maybe(msg) had_conflicts = True + # First commit author should be considered as the primary author when the rank is the same commit_authors = run_cmd( - ["git", "log", "HEAD..%s" % pr_branch_name, "--pretty=format:%an <%ae>"] + ["git", "log", "HEAD..%s" % pr_branch_name, "--pretty=format:%an <%ae>", "--reverse"] ).split("\n") distinct_authors = sorted( - set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True + list(dict.fromkeys(commit_authors)), key=lambda x: commit_authors.count(x), reverse=True ) primary_author = input( 'Enter primary author in the format of "name " [%s]: ' % distinct_authors[0] diff --git a/dev/package-lock.json b/dev/package-lock.json index a57f45bcf7184..c2a61b389ac53 100644 --- a/dev/package-lock.json +++ b/dev/package-lock.json @@ -1,979 +1,2244 @@ { - "requires": true, - "lockfileVersion": 1, - "dependencies": { - "@babel/code-frame": { - "version": "7.12.11", - "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.12.11.tgz", - "integrity": "sha512-Zt1yodBx1UcyiePMSkWnU4hPqhwq7hGi2nFL1LeA3EUl+q2LQx16MISgJ0+z7dnmgvP9QtIleuETGOiOH1RcIw==", - "dev": true, - "requires": { - "@babel/highlight": "^7.10.4" - } - }, - "@babel/helper-validator-identifier": { - "version": "7.14.0", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.14.0.tgz", - "integrity": "sha512-V3ts7zMSu5lfiwWDVWzRDGIN+lnCEUdaXgtVHJgLb1rGaA6jMrtB9EmE7L18foXJIE8Un/A/h6NJfGQp/e1J4A==", - "dev": true - }, - "@babel/highlight": { - "version": "7.14.0", - "resolved": "https://registry.npmjs.org/@babel/highlight/-/highlight-7.14.0.tgz", - "integrity": "sha512-YSCOwxvTYEIMSGaBQb5kDDsCopDdiUGsqpatp3fOlI4+2HQSkTmEVWnVuySdAC5EWCqSWWTv0ib63RjR7dTBdg==", - "dev": 
true, - "requires": { - "@babel/helper-validator-identifier": "^7.14.0", - "chalk": "^2.0.0", - "js-tokens": "^4.0.0" - }, - "dependencies": { - "chalk": { - "version": "2.4.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-2.4.2.tgz", - "integrity": "sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==", - "dev": true, - "requires": { - "ansi-styles": "^3.2.1", - "escape-string-regexp": "^1.0.5", - "supports-color": "^5.3.0" - } - } - } - }, - "@eslint/eslintrc": { - "version": "0.4.0", - "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-0.4.0.tgz", - "integrity": "sha512-2ZPCc+uNbjV5ERJr+aKSPRwZgKd2z11x0EgLvb1PURmUrn9QNRXFqje0Ldq454PfAVyaJYyrDvvIKSFP4NnBog==", - "dev": true, - "requires": { - "ajv": "^6.12.4", - "debug": "^4.1.1", - "espree": "^7.3.0", - "globals": "^12.1.0", - "ignore": "^4.0.6", - "import-fresh": "^3.2.1", - "js-yaml": "^3.13.1", - "minimatch": "^3.0.4", - "strip-json-comments": "^3.1.1" - }, - "dependencies": { - "globals": { - "version": "12.4.0", - "resolved": "https://registry.npmjs.org/globals/-/globals-12.4.0.tgz", - "integrity": "sha512-BWICuzzDvDoH54NHKCseDanAhE3CeDorgDL5MT6LMXXj2WCnd9UC2szdk4AWLfjdgNBCXLUanXYcpBBKOSWGwg==", - "dev": true, - "requires": { - "type-fest": "^0.8.1" - } - } - } - }, - "acorn": { - "version": "7.4.1", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-7.4.1.tgz", - "integrity": "sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==", - "dev": true - }, - "acorn-jsx": { - "version": "5.3.1", - "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.1.tgz", - "integrity": "sha512-K0Ptm/47OKfQRpNQ2J/oIN/3QYiK6FwW+eJbILhsdxh2WTLdl+30o8aGdTbm5JbffpFFAg/g+zi1E+jvJha5ng==", - "dev": true - }, - "ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.1", - "fast-json-stable-stringify": "^2.0.0", - "json-schema-traverse": "^0.4.1", - "uri-js": "^4.2.2" - } - }, - "ansi-colors": { - "version": "4.1.1", - "resolved": "https://registry.npmjs.org/ansi-colors/-/ansi-colors-4.1.1.tgz", - "integrity": "sha512-JoX0apGbHaUJBNl6yF+p6JAFYZ666/hhCGKN5t9QFjbJQKUU/g8MNbFDbvfrgKXvI1QpZplPOnwIo99lX/AAmA==", - "dev": true - }, - "ansi-regex": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.0.tgz", - "integrity": "sha512-bY6fj56OUQ0hU1KjFNDQuJFezqKdrAyFdIevADiqrWHwSlbmBNMHp5ak2f40Pm8JTFyM2mqxkG6ngkHO11f/lg==", - "dev": true - }, - "ansi-styles": { - "version": "3.2.1", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-3.2.1.tgz", - "integrity": "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==", - "dev": true, - "requires": { - "color-convert": "^1.9.0" - } - }, - "argparse": { - "version": "1.0.10", - "resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz", - "integrity": "sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==", - "dev": true, - "requires": { - "sprintf-js": "~1.0.2" - } - }, - "astral-regex": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/astral-regex/-/astral-regex-2.0.0.tgz", - "integrity": "sha512-Z7tMw1ytTXt5jqMcOP+OQteU1VuNK9Y02uuJtKQ1Sv69jXQKKg5cibLwGJow8yzZP+eAc18EmLGPal0bp36rvQ==", - "dev": true - }, - 
"balanced-match": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", - "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", - "dev": true - }, - "brace-expansion": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", - "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", - "dev": true, - "requires": { - "balanced-match": "^1.0.0", - "concat-map": "0.0.1" - } - }, - "callsites": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", - "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", - "dev": true - }, + "name": "dev", + "lockfileVersion": 2, + "requires": true, + "packages": { + "": { + "devDependencies": { + "ansi-regex": "^5.0.1", + "eslint": "^7.25.0" + } + }, + "node_modules/@babel/code-frame": { + "version": "7.12.11", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.12.11.tgz", + "integrity": "sha512-Zt1yodBx1UcyiePMSkWnU4hPqhwq7hGi2nFL1LeA3EUl+q2LQx16MISgJ0+z7dnmgvP9QtIleuETGOiOH1RcIw==", + "dev": true, + "dependencies": { + "@babel/highlight": "^7.10.4" + } + }, + "node_modules/@babel/helper-validator-identifier": { + "version": "7.14.0", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.14.0.tgz", + "integrity": "sha512-V3ts7zMSu5lfiwWDVWzRDGIN+lnCEUdaXgtVHJgLb1rGaA6jMrtB9EmE7L18foXJIE8Un/A/h6NJfGQp/e1J4A==", + "dev": true + }, + "node_modules/@babel/highlight": { + "version": "7.14.0", + "resolved": "https://registry.npmjs.org/@babel/highlight/-/highlight-7.14.0.tgz", + "integrity": "sha512-YSCOwxvTYEIMSGaBQb5kDDsCopDdiUGsqpatp3fOlI4+2HQSkTmEVWnVuySdAC5EWCqSWWTv0ib63RjR7dTBdg==", + "dev": true, + "dependencies": { + "@babel/helper-validator-identifier": "^7.14.0", + "chalk": "^2.0.0", + "js-tokens": "^4.0.0" + } + }, + "node_modules/@babel/highlight/node_modules/chalk": { + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-2.4.2.tgz", + "integrity": "sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==", + "dev": true, + "dependencies": { + "ansi-styles": "^3.2.1", + "escape-string-regexp": "^1.0.5", + "supports-color": "^5.3.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/@eslint/eslintrc": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-0.4.0.tgz", + "integrity": "sha512-2ZPCc+uNbjV5ERJr+aKSPRwZgKd2z11x0EgLvb1PURmUrn9QNRXFqje0Ldq454PfAVyaJYyrDvvIKSFP4NnBog==", + "dev": true, + "dependencies": { + "ajv": "^6.12.4", + "debug": "^4.1.1", + "espree": "^7.3.0", + "globals": "^12.1.0", + "ignore": "^4.0.6", + "import-fresh": "^3.2.1", + "js-yaml": "^3.13.1", + "minimatch": "^3.0.4", + "strip-json-comments": "^3.1.1" + }, + "engines": { + "node": "^10.12.0 || >=12.0.0" + } + }, + "node_modules/@eslint/eslintrc/node_modules/globals": { + "version": "12.4.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-12.4.0.tgz", + "integrity": "sha512-BWICuzzDvDoH54NHKCseDanAhE3CeDorgDL5MT6LMXXj2WCnd9UC2szdk4AWLfjdgNBCXLUanXYcpBBKOSWGwg==", + "dev": true, + "dependencies": { + "type-fest": "^0.8.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + 
"node_modules/acorn": { + "version": "7.4.1", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-7.4.1.tgz", + "integrity": "sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==", + "dev": true, + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-jsx": { + "version": "5.3.1", + "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.1.tgz", + "integrity": "sha512-K0Ptm/47OKfQRpNQ2J/oIN/3QYiK6FwW+eJbILhsdxh2WTLdl+30o8aGdTbm5JbffpFFAg/g+zi1E+jvJha5ng==", + "dev": true, + "peerDependencies": { + "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "dev": true, + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ansi-colors": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/ansi-colors/-/ansi-colors-4.1.1.tgz", + "integrity": "sha512-JoX0apGbHaUJBNl6yF+p6JAFYZ666/hhCGKN5t9QFjbJQKUU/g8MNbFDbvfrgKXvI1QpZplPOnwIo99lX/AAmA==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/ansi-styles": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-3.2.1.tgz", + "integrity": "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==", + "dev": true, + "dependencies": { + "color-convert": "^1.9.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/argparse": { + "version": "1.0.10", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz", + "integrity": "sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==", + "dev": true, + "dependencies": { + "sprintf-js": "~1.0.2" + } + }, + "node_modules/astral-regex": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/astral-regex/-/astral-regex-2.0.0.tgz", + "integrity": "sha512-Z7tMw1ytTXt5jqMcOP+OQteU1VuNK9Y02uuJtKQ1Sv69jXQKKg5cibLwGJow8yzZP+eAc18EmLGPal0bp36rvQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "dev": true + }, + "node_modules/brace-expansion": { + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", + "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "dev": true, + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/callsites": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", + "integrity": 
"sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/chalk": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.1.tgz", + "integrity": "sha512-diHzdDKxcU+bAsUboHLPEDQiw0qEe0qd7SYUn3HgcFlWgbDcfLGswOHYeGrHKzG9z6UYf01d9VFMfZxPM1xZSg==", + "dev": true, + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/chalk/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/chalk/node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/chalk/node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true + }, + "node_modules/chalk/node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/chalk/node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/color-convert": { + "version": "1.9.3", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz", + "integrity": "sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==", + "dev": true, + "dependencies": { + "color-name": "1.1.3" + } + }, + "node_modules/color-name": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz", + "integrity": "sha1-p9BVi9icQveV3UIyj3QIMcpTvCU=", + "dev": true + }, + "node_modules/concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha1-2Klr13/Wjfd5OnMDajug1UBdR3s=", + "dev": true + }, + "node_modules/cross-spawn": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", + "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "dev": true, + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + 
"node_modules/debug": { + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.1.tgz", + "integrity": "sha512-doEwdvm4PCeK4K3RQN2ZC2BYUBaxwLARCqZmMjtF8a51J2Rb0xpVloFRnCODwqjpwnAoao4pelN8l3RJdv3gRQ==", + "dev": true, + "dependencies": { + "ms": "2.1.2" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/deep-is": { + "version": "0.1.3", + "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.3.tgz", + "integrity": "sha1-s2nW+128E+7PUk+RsHD+7cNXzzQ=", + "dev": true + }, + "node_modules/doctrine": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/doctrine/-/doctrine-3.0.0.tgz", + "integrity": "sha512-yS+Q5i3hBf7GBkd4KG8a7eBNNWNGLTaEwwYWUijIYM7zrlYDM0BFXHjjPWlWZ1Rg7UaddZeIDmi9jF3HmqiQ2w==", + "dev": true, + "dependencies": { + "esutils": "^2.0.2" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true + }, + "node_modules/enquirer": { + "version": "2.3.6", + "resolved": "https://registry.npmjs.org/enquirer/-/enquirer-2.3.6.tgz", + "integrity": "sha512-yjNnPr315/FjS4zIsUxYguYUPP2e1NK4d7E7ZOLiyYCcbFBiTMyID+2wvm2w6+pZ/odMA7cRkjhsPbltwBOrLg==", + "dev": true, + "dependencies": { + "ansi-colors": "^4.1.1" + }, + "engines": { + "node": ">=8.6" + } + }, + "node_modules/escape-string-regexp": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz", + "integrity": "sha1-G2HAViGQqN/2rjuyzwIAyhMLhtQ=", + "dev": true, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/eslint": { + "version": "7.25.0", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-7.25.0.tgz", + "integrity": "sha512-TVpSovpvCNpLURIScDRB6g5CYu/ZFq9GfX2hLNIV4dSBKxIWojeDODvYl3t0k0VtMxYeR8OXPCFE5+oHMlGfhw==", + "dev": true, + "dependencies": { + "@babel/code-frame": "7.12.11", + "@eslint/eslintrc": "^0.4.0", + "ajv": "^6.10.0", + "chalk": "^4.0.0", + "cross-spawn": "^7.0.2", + "debug": "^4.0.1", + "doctrine": "^3.0.0", + "enquirer": "^2.3.5", + "eslint-scope": "^5.1.1", + "eslint-utils": "^2.1.0", + "eslint-visitor-keys": "^2.0.0", + "espree": "^7.3.1", + "esquery": "^1.4.0", + "esutils": "^2.0.2", + "file-entry-cache": "^6.0.1", + "functional-red-black-tree": "^1.0.1", + "glob-parent": "^5.0.0", + "globals": "^13.6.0", + "ignore": "^4.0.6", + "import-fresh": "^3.0.0", + "imurmurhash": "^0.1.4", + "is-glob": "^4.0.0", + "js-yaml": "^3.13.1", + "json-stable-stringify-without-jsonify": "^1.0.1", + "levn": "^0.4.1", + "lodash": "^4.17.21", + "minimatch": "^3.0.4", + "natural-compare": "^1.4.0", + "optionator": "^0.9.1", + "progress": "^2.0.0", + "regexpp": "^3.1.0", + "semver": "^7.2.1", + "strip-ansi": "^6.0.0", + "strip-json-comments": "^3.1.0", + "table": "^6.0.4", + "text-table": "^0.2.0", + "v8-compile-cache": "^2.0.3" + }, + "bin": { + "eslint": "bin/eslint.js" + }, + "engines": { + "node": "^10.12.0 || >=12.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint-scope": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", + "integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", + "dev": true, + "dependencies": { + 
"esrecurse": "^4.3.0", + "estraverse": "^4.1.1" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/eslint-utils": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/eslint-utils/-/eslint-utils-2.1.0.tgz", + "integrity": "sha512-w94dQYoauyvlDc43XnGB8lU3Zt713vNChgt4EWwhXAP2XkBvndfxF0AgIqKOOasjPIPzj9JqgwkwbCYD0/V3Zg==", + "dev": true, + "dependencies": { + "eslint-visitor-keys": "^1.1.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/mysticatea" + } + }, + "node_modules/eslint-utils/node_modules/eslint-visitor-keys": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-1.3.0.tgz", + "integrity": "sha512-6J72N8UNa462wa/KFODt/PJ3IU60SDpC3QXC1Hjc1BXXpfL2C9R5+AU7jhe0F6GREqVMh4Juu+NY7xn+6dipUQ==", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/eslint-visitor-keys": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-2.1.0.tgz", + "integrity": "sha512-0rSmRBzXgDzIsD6mGdJgevzgezI534Cer5L/vyMX0kHzT/jiB43jRhd9YUlMGYLQy2zprNmoT8qasCGtY+QaKw==", + "dev": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/espree": { + "version": "7.3.1", + "resolved": "https://registry.npmjs.org/espree/-/espree-7.3.1.tgz", + "integrity": "sha512-v3JCNCE64umkFpmkFGqzVKsOT0tN1Zr+ueqLZfpV1Ob8e+CEgPWa+OxCoGH3tnhimMKIaBm4m/vaRpJ/krRz2g==", + "dev": true, + "dependencies": { + "acorn": "^7.4.0", + "acorn-jsx": "^5.3.1", + "eslint-visitor-keys": "^1.3.0" + }, + "engines": { + "node": "^10.12.0 || >=12.0.0" + } + }, + "node_modules/espree/node_modules/eslint-visitor-keys": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-1.3.0.tgz", + "integrity": "sha512-6J72N8UNa462wa/KFODt/PJ3IU60SDpC3QXC1Hjc1BXXpfL2C9R5+AU7jhe0F6GREqVMh4Juu+NY7xn+6dipUQ==", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/esprima": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/esprima/-/esprima-4.0.1.tgz", + "integrity": "sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==", + "dev": true, + "bin": { + "esparse": "bin/esparse.js", + "esvalidate": "bin/esvalidate.js" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/esquery": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.4.0.tgz", + "integrity": "sha512-cCDispWt5vHHtwMY2YrAQ4ibFkAL8RbH5YGBnZBc90MolvvfkkQcJro/aZiAQUlQ3qgrYS6D6v8Gc5G5CQsc9w==", + "dev": true, + "dependencies": { + "estraverse": "^5.1.0" + }, + "engines": { + "node": ">=0.10" + } + }, + "node_modules/esquery/node_modules/estraverse": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.2.0.tgz", + "integrity": "sha512-BxbNGGNm0RyRYvUdHpIwv9IWzeM9XClbOxwoATuFdOE7ZE6wHL+HQ5T8hoPM+zHvmKzzsEqhgy0GrQ5X13afiQ==", + "dev": true, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/esrecurse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "dev": true, + "dependencies": { + "estraverse": "^5.2.0" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/esrecurse/node_modules/estraverse": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.2.0.tgz", + "integrity": 
"sha512-BxbNGGNm0RyRYvUdHpIwv9IWzeM9XClbOxwoATuFdOE7ZE6wHL+HQ5T8hoPM+zHvmKzzsEqhgy0GrQ5X13afiQ==", + "dev": true, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estraverse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", + "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "dev": true, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/esutils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "dev": true + }, + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "dev": true + }, + "node_modules/fast-levenshtein": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", + "integrity": "sha1-PYpcZog6FqMMqGQ+hR8Zuqd5eRc=", + "dev": true + }, + "node_modules/file-entry-cache": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-6.0.1.tgz", + "integrity": "sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==", + "dev": true, + "dependencies": { + "flat-cache": "^3.0.4" + }, + "engines": { + "node": "^10.12.0 || >=12.0.0" + } + }, + "node_modules/flat-cache": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-3.0.4.tgz", + "integrity": "sha512-dm9s5Pw7Jc0GvMYbshN6zchCA9RgQlzzEZX3vylR9IqFfS8XciblUXOKfW6SiuJ0e13eDYZoZV5wdrev7P3Nwg==", + "dev": true, + "dependencies": { + "flatted": "^3.1.0", + "rimraf": "^3.0.2" + }, + "engines": { + "node": "^10.12.0 || >=12.0.0" + } + }, + "node_modules/flatted": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.1.1.tgz", + "integrity": "sha512-zAoAQiudy+r5SvnSw3KJy5os/oRJYHzrzja/tBDqrZtNhUw8bt6y8OBzMWcjWr+8liV8Eb6yOhw8WZ7VFZ5ZzA==", + "dev": true + }, + "node_modules/fs.realpath": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", + "integrity": "sha1-FQStJSMVjKpA20onh8sBQRmU6k8=", + "dev": true + }, + "node_modules/functional-red-black-tree": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/functional-red-black-tree/-/functional-red-black-tree-1.0.1.tgz", + "integrity": "sha1-GwqzvVU7Kg1jmdKcDj6gslIHgyc=", + "dev": true + }, + "node_modules/glob": { + "version": "7.1.6", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.1.6.tgz", + "integrity": "sha512-LwaxwyZ72Lk7vZINtNNrywX0ZuLyStrdDtabefZKAY5ZGJhVtgdznluResxNmPitE0SAO+O26sWTHeKSI2wMBA==", + "dev": true, + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.0.4", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" 
+ } + }, + "node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/globals": { + "version": "13.8.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-13.8.0.tgz", + "integrity": "sha512-rHtdA6+PDBIjeEvA91rpqzEvk/k3/i7EeNQiryiWuJH0Hw9cpyJMAt2jtbAwUaRdhD+573X4vWw6IcjKPasi9Q==", + "dev": true, + "dependencies": { + "type-fest": "^0.20.2" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/globals/node_modules/type-fest": { + "version": "0.20.2", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.20.2.tgz", + "integrity": "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/has-flag": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", + "integrity": "sha1-tdRU3CGZriJWmfNGfloH87lVuv0=", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/ignore": { + "version": "4.0.6", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-4.0.6.tgz", + "integrity": "sha512-cyFDKrqc/YdcWFniJhzI42+AzS+gNwmUzOSFcRCQYwySuBBBy/KjuxWLZ/FHEH6Moq1NizMOBWyTcv8O4OZIMg==", + "dev": true, + "engines": { + "node": ">= 4" + } + }, + "node_modules/import-fresh": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", + "integrity": "sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==", + "dev": true, + "dependencies": { + "parent-module": "^1.0.0", + "resolve-from": "^4.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha1-khi5srkoojixPcT7a21XbyMUU+o=", + "dev": true, + "engines": { + "node": ">=0.8.19" + } + }, + "node_modules/inflight": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", + "integrity": "sha1-Sb1jMdfQLQwJvJEKEHW6gWW1bfk=", + "dev": true, + "dependencies": { + "once": "^1.3.0", + "wrappy": "1" + } + }, + "node_modules/inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "dev": true + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha1-qIwCU1eR8C7TfHahueqXc8gz+MI=", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/is-glob": { + "version": "4.0.1", + 
"resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.1.tgz", + "integrity": "sha512-5G0tKtBTFImOqDnLB2hG6Bp2qcKEFduo4tZu9MT/H6NQv/ghhy30o55ufafxJ/LdH79LLs2Kfrn85TLKyA7BUg==", + "dev": true, + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha1-6PvzdNxVb/iUehDcsFctYz8s+hA=", + "dev": true + }, + "node_modules/js-tokens": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", + "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==", + "dev": true + }, + "node_modules/js-yaml": { + "version": "3.14.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.1.tgz", + "integrity": "sha512-okMH7OXXJ7YrN9Ok3/SXrnu4iX9yOk+25nqX4imS2npuvTYDmo/QEZoqwZkYaIDk3jVvBOTOIEgEhaLOynBS9g==", + "dev": true, + "dependencies": { + "argparse": "^1.0.7", + "esprima": "^4.0.0" + }, + "bin": { + "js-yaml": "bin/js-yaml.js" + } + }, + "node_modules/json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "dev": true + }, + "node_modules/json-stable-stringify-without-jsonify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", + "integrity": "sha1-nbe1lJatPzz+8wp1FC0tkwrXJlE=", + "dev": true + }, + "node_modules/levn": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", + "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", + "dev": true, + "dependencies": { + "prelude-ls": "^1.2.1", + "type-check": "~0.4.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/lodash": { + "version": "4.17.21", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", + "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "dev": true + }, + "node_modules/lodash.clonedeep": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz", + "integrity": "sha1-4j8/nE+Pvd6HJSnBBxhXoIblzO8=", + "dev": true + }, + "node_modules/lodash.flatten": { + "version": "4.4.0", + "resolved": "https://registry.npmjs.org/lodash.flatten/-/lodash.flatten-4.4.0.tgz", + "integrity": "sha1-8xwiIlqWMtK7+OSt2+8kCqdlph8=", + "dev": true + }, + "node_modules/lodash.truncate": { + "version": "4.4.2", + "resolved": "https://registry.npmjs.org/lodash.truncate/-/lodash.truncate-4.4.2.tgz", + "integrity": "sha1-WjUNoLERO4N+z//VgSy+WNbq4ZM=", + "dev": true + }, + "node_modules/lru-cache": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", + "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "dev": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/minimatch": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.0.4.tgz", + "integrity": "sha512-yJHVQEhyqPLUTgt9B83PXu6W3rx4MvvHvSUvToogpwoGDOUQ+yDrR0HRot+yOCdCO7u4hX3pWft6kWBBcqh0UA==", + 
"dev": true, + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/ms": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", + "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "dev": true + }, + "node_modules/natural-compare": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", + "integrity": "sha1-Sr6/7tdUHywnrPspvbvRXI1bpPc=", + "dev": true + }, + "node_modules/once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha1-WDsap3WWHUsROsF9nFC6753Xa9E=", + "dev": true, + "dependencies": { + "wrappy": "1" + } + }, + "node_modules/optionator": { + "version": "0.9.1", + "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.1.tgz", + "integrity": "sha512-74RlY5FCnhq4jRxVUPKDaRwrVNXMqsGsiW6AJw4XK8hmtm10wC0ypZBLw5IIp85NZMr91+qd1RvvENwg7jjRFw==", + "dev": true, + "dependencies": { + "deep-is": "^0.1.3", + "fast-levenshtein": "^2.0.6", + "levn": "^0.4.1", + "prelude-ls": "^1.2.1", + "type-check": "^0.4.0", + "word-wrap": "^1.2.3" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/parent-module": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", + "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", + "dev": true, + "dependencies": { + "callsites": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/path-is-absolute": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "integrity": "sha1-F0uSaHNVNP+8es5r9TpanhtcX18=", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/prelude-ls": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", + "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", + "dev": true, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/progress": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/progress/-/progress-2.0.3.tgz", + "integrity": "sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==", + "dev": true, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/punycode": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.1.1.tgz", + "integrity": "sha512-XRsRjdf+j5ml+y/6GKHPZbrF/8p2Yga0JPtdqTIY2Xe5ohJPD9saDJJLPvp9+NSBprVvevdXZybnj2cv8OEd0A==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/regexpp": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/regexpp/-/regexpp-3.1.0.tgz", + "integrity": "sha512-ZOIzd8yVsQQA7j8GCSlPGXwg5PfmA1mrq0JP4nGhh54LaKN3xdai/vHUDu74pKwV8OxseMS65u2NImosQcSD0Q==", + "dev": true, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/mysticatea" + } + }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": 
"https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/resolve-from": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", + "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/rimraf": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", + "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + "dev": true, + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/semver": { + "version": "7.3.5", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.5.tgz", + "integrity": "sha512-PoeGJYh8HK4BTO/a9Tf6ZG3veo/A7ZVsYrSA6J8ny9nb3B1VrpkuN+z9OE5wfE5p6H4LchYZsegiQgbJD94ZFQ==", + "dev": true, + "dependencies": { + "lru-cache": "^6.0.0" + }, + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dev": true, + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/slice-ansi": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-4.0.0.tgz", + "integrity": "sha512-qMCMfhY040cVHT43K9BFygqYbUPFZKHOg7K73mtTWJRb8pyP3fzf4Ixd5SzdEJQ6MRUg/WBnOLxghZtKKurENQ==", + "dev": true, + "dependencies": { + "ansi-styles": "^4.0.0", + "astral-regex": "^2.0.0", + "is-fullwidth-code-point": "^3.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/slice-ansi?sponsor=1" + } + }, + "node_modules/slice-ansi/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/slice-ansi/node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/slice-ansi/node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": 
"sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true + }, + "node_modules/sprintf-js": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.0.3.tgz", + "integrity": "sha1-BOaSb2YolTVPPdAVIDYzuFcpfiw=", + "dev": true + }, + "node_modules/string-width": { + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.2.tgz", + "integrity": "sha512-XBJbT3N4JhVumXE0eoLU9DCjcaF92KLNqTmFCnG1pf8duUxFGwtP6AD6nkjw9a3IdiRtL3E2w3JDiE/xi3vOeA==", + "dev": true, + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.0.tgz", + "integrity": "sha512-AuvKTrTfQNYNIctbR1K/YGTR1756GycPsg7b9bdV9Duqur4gv6aKqHXah67Z8ImS7WEz5QVcOtlfW2rZEugt6w==", + "dev": true, + "dependencies": { + "ansi-regex": "^5.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-json-comments": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", + "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==", + "dev": true, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/supports-color": { + "version": "5.5.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.5.0.tgz", + "integrity": "sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==", + "dev": true, + "dependencies": { + "has-flag": "^3.0.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/table": { + "version": "6.6.0", + "resolved": "https://registry.npmjs.org/table/-/table-6.6.0.tgz", + "integrity": "sha512-iZMtp5tUvcnAdtHpZTWLPF0M7AgiQsURR2DwmxnJwSy8I3+cY+ozzVvYha3BOLG2TB+L0CqjIz+91htuj6yCXg==", + "dev": true, + "dependencies": { + "ajv": "^8.0.1", + "lodash.clonedeep": "^4.5.0", + "lodash.flatten": "^4.4.0", + "lodash.truncate": "^4.4.2", + "slice-ansi": "^4.0.0", + "string-width": "^4.2.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/table/node_modules/ajv": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.2.0.tgz", + "integrity": "sha512-WSNGFuyWd//XO8n/m/EaOlNLtO0yL8EXT/74LqT4khdhpZjP7lkj/kT5uwRmGitKEVp/Oj7ZUHeGfPtgHhQ5CA==", + "dev": true, + "dependencies": { + "fast-deep-equal": "^3.1.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/table/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true + }, + "node_modules/text-table": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", + "integrity": "sha1-f17oI66AUgfACvLfSoTsP8+lcLQ=", + "dev": true + }, + "node_modules/type-check": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", + "integrity": 
"sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", + "dev": true, + "dependencies": { + "prelude-ls": "^1.2.1" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/type-fest": { + "version": "0.8.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.8.1.tgz", + "integrity": "sha512-4dbzIzqvjtgiM5rw1k5rEHtBANKmdudhGyBEajN01fEyhaAIhsoKNy6y7+IN93IfpFtwY9iqi7kD+xwKhQsNJA==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "dev": true, + "dependencies": { + "punycode": "^2.1.0" + } + }, + "node_modules/v8-compile-cache": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/v8-compile-cache/-/v8-compile-cache-2.3.0.tgz", + "integrity": "sha512-l8lCEmLcLYZh4nbunNZvQCJc5pv7+RCwa8q/LdUx8u7lsWvPDKmpodJAJNwkAhJC//dFY48KuIEmjtd4RViDrA==", + "dev": true + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dev": true, + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/word-wrap": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.3.tgz", + "integrity": "sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha1-tSQ9jz7BqjXxNkYFvA0QNuMKtp8=", + "dev": true + }, + "node_modules/yallist": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", + "dev": true + } + }, + "dependencies": { + "@babel/code-frame": { + "version": "7.12.11", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.12.11.tgz", + "integrity": "sha512-Zt1yodBx1UcyiePMSkWnU4hPqhwq7hGi2nFL1LeA3EUl+q2LQx16MISgJ0+z7dnmgvP9QtIleuETGOiOH1RcIw==", + "dev": true, + "requires": { + "@babel/highlight": "^7.10.4" + } + }, + "@babel/helper-validator-identifier": { + "version": "7.14.0", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.14.0.tgz", + "integrity": "sha512-V3ts7zMSu5lfiwWDVWzRDGIN+lnCEUdaXgtVHJgLb1rGaA6jMrtB9EmE7L18foXJIE8Un/A/h6NJfGQp/e1J4A==", + "dev": true + }, + "@babel/highlight": { + "version": "7.14.0", + "resolved": "https://registry.npmjs.org/@babel/highlight/-/highlight-7.14.0.tgz", + "integrity": "sha512-YSCOwxvTYEIMSGaBQb5kDDsCopDdiUGsqpatp3fOlI4+2HQSkTmEVWnVuySdAC5EWCqSWWTv0ib63RjR7dTBdg==", + "dev": true, + "requires": { + "@babel/helper-validator-identifier": "^7.14.0", + "chalk": "^2.0.0", + "js-tokens": "^4.0.0" + }, + "dependencies": { "chalk": { - "version": "4.1.1", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.1.tgz", - "integrity": "sha512-diHzdDKxcU+bAsUboHLPEDQiw0qEe0qd7SYUn3HgcFlWgbDcfLGswOHYeGrHKzG9z6UYf01d9VFMfZxPM1xZSg==", - "dev": true, - "requires": { - "ansi-styles": "^4.1.0", - 
"supports-color": "^7.1.0" - }, - "dependencies": { - "ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "requires": { - "color-convert": "^2.0.1" - } - }, - "color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "requires": { - "color-name": "~1.1.4" - } - }, - "color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true - }, - "supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "requires": { - "has-flag": "^4.0.0" - } - } - } + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-2.4.2.tgz", + "integrity": "sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==", + "dev": true, + "requires": { + "ansi-styles": "^3.2.1", + "escape-string-regexp": "^1.0.5", + "supports-color": "^5.3.0" + } + } + } + }, + "@eslint/eslintrc": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-0.4.0.tgz", + "integrity": "sha512-2ZPCc+uNbjV5ERJr+aKSPRwZgKd2z11x0EgLvb1PURmUrn9QNRXFqje0Ldq454PfAVyaJYyrDvvIKSFP4NnBog==", + "dev": true, + "requires": { + "ajv": "^6.12.4", + "debug": "^4.1.1", + "espree": "^7.3.0", + "globals": "^12.1.0", + "ignore": "^4.0.6", + "import-fresh": "^3.2.1", + "js-yaml": "^3.13.1", + "minimatch": "^3.0.4", + "strip-json-comments": "^3.1.1" + }, + "dependencies": { + "globals": { + "version": "12.4.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-12.4.0.tgz", + "integrity": "sha512-BWICuzzDvDoH54NHKCseDanAhE3CeDorgDL5MT6LMXXj2WCnd9UC2szdk4AWLfjdgNBCXLUanXYcpBBKOSWGwg==", + "dev": true, + "requires": { + "type-fest": "^0.8.1" + } + } + } + }, + "acorn": { + "version": "7.4.1", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-7.4.1.tgz", + "integrity": "sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==", + "dev": true + }, + "acorn-jsx": { + "version": "5.3.1", + "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.1.tgz", + "integrity": "sha512-K0Ptm/47OKfQRpNQ2J/oIN/3QYiK6FwW+eJbILhsdxh2WTLdl+30o8aGdTbm5JbffpFFAg/g+zi1E+jvJha5ng==", + "dev": true, + "requires": {} + }, + "ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "dev": true, + "requires": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + } + }, + "ansi-colors": { + "version": "4.1.1", + 
"resolved": "https://registry.npmjs.org/ansi-colors/-/ansi-colors-4.1.1.tgz", + "integrity": "sha512-JoX0apGbHaUJBNl6yF+p6JAFYZ666/hhCGKN5t9QFjbJQKUU/g8MNbFDbvfrgKXvI1QpZplPOnwIo99lX/AAmA==", + "dev": true + }, + "ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true + }, + "ansi-styles": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-3.2.1.tgz", + "integrity": "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==", + "dev": true, + "requires": { + "color-convert": "^1.9.0" + } + }, + "argparse": { + "version": "1.0.10", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz", + "integrity": "sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==", + "dev": true, + "requires": { + "sprintf-js": "~1.0.2" + } + }, + "astral-regex": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/astral-regex/-/astral-regex-2.0.0.tgz", + "integrity": "sha512-Z7tMw1ytTXt5jqMcOP+OQteU1VuNK9Y02uuJtKQ1Sv69jXQKKg5cibLwGJow8yzZP+eAc18EmLGPal0bp36rvQ==", + "dev": true + }, + "balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "dev": true + }, + "brace-expansion": { + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", + "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "dev": true, + "requires": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "callsites": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", + "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", + "dev": true + }, + "chalk": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.1.tgz", + "integrity": "sha512-diHzdDKxcU+bAsUboHLPEDQiw0qEe0qd7SYUn3HgcFlWgbDcfLGswOHYeGrHKzG9z6UYf01d9VFMfZxPM1xZSg==", + "dev": true, + "requires": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "dependencies": { + "ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "requires": { + "color-convert": "^2.0.1" + } }, "color-convert": { - "version": "1.9.3", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz", - "integrity": "sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==", - "dev": true, - "requires": { - "color-name": "1.1.3" - } + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "requires": { + "color-name": "~1.1.4" + } }, "color-name": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz", - "integrity": "sha1-p9BVi9icQveV3UIyj3QIMcpTvCU=", - 
"dev": true - }, - "concat-map": { - "version": "0.0.1", - "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", - "integrity": "sha1-2Klr13/Wjfd5OnMDajug1UBdR3s=", - "dev": true - }, - "cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", - "dev": true, - "requires": { - "path-key": "^3.1.0", - "shebang-command": "^2.0.0", - "which": "^2.0.1" - } - }, - "debug": { - "version": "4.3.1", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.1.tgz", - "integrity": "sha512-doEwdvm4PCeK4K3RQN2ZC2BYUBaxwLARCqZmMjtF8a51J2Rb0xpVloFRnCODwqjpwnAoao4pelN8l3RJdv3gRQ==", - "dev": true, - "requires": { - "ms": "2.1.2" - } - }, - "deep-is": { - "version": "0.1.3", - "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.3.tgz", - "integrity": "sha1-s2nW+128E+7PUk+RsHD+7cNXzzQ=", - "dev": true - }, - "doctrine": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/doctrine/-/doctrine-3.0.0.tgz", - "integrity": "sha512-yS+Q5i3hBf7GBkd4KG8a7eBNNWNGLTaEwwYWUijIYM7zrlYDM0BFXHjjPWlWZ1Rg7UaddZeIDmi9jF3HmqiQ2w==", - "dev": true, - "requires": { - "esutils": "^2.0.2" - } - }, - "emoji-regex": { - "version": "8.0.0", - "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", - "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", - "dev": true - }, - "enquirer": { - "version": "2.3.6", - "resolved": "https://registry.npmjs.org/enquirer/-/enquirer-2.3.6.tgz", - "integrity": "sha512-yjNnPr315/FjS4zIsUxYguYUPP2e1NK4d7E7ZOLiyYCcbFBiTMyID+2wvm2w6+pZ/odMA7cRkjhsPbltwBOrLg==", - "dev": true, - "requires": { - "ansi-colors": "^4.1.1" - } - }, - "escape-string-regexp": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz", - "integrity": "sha1-G2HAViGQqN/2rjuyzwIAyhMLhtQ=", - "dev": true - }, - "eslint": { - "version": "7.25.0", - "resolved": "https://registry.npmjs.org/eslint/-/eslint-7.25.0.tgz", - "integrity": "sha512-TVpSovpvCNpLURIScDRB6g5CYu/ZFq9GfX2hLNIV4dSBKxIWojeDODvYl3t0k0VtMxYeR8OXPCFE5+oHMlGfhw==", - "dev": true, - "requires": { - "@babel/code-frame": "7.12.11", - "@eslint/eslintrc": "^0.4.0", - "ajv": "^6.10.0", - "chalk": "^4.0.0", - "cross-spawn": "^7.0.2", - "debug": "^4.0.1", - "doctrine": "^3.0.0", - "enquirer": "^2.3.5", - "eslint-scope": "^5.1.1", - "eslint-utils": "^2.1.0", - "eslint-visitor-keys": "^2.0.0", - "espree": "^7.3.1", - "esquery": "^1.4.0", - "esutils": "^2.0.2", - "file-entry-cache": "^6.0.1", - "functional-red-black-tree": "^1.0.1", - "glob-parent": "^5.0.0", - "globals": "^13.6.0", - "ignore": "^4.0.6", - "import-fresh": "^3.0.0", - "imurmurhash": "^0.1.4", - "is-glob": "^4.0.0", - "js-yaml": "^3.13.1", - "json-stable-stringify-without-jsonify": "^1.0.1", - "levn": "^0.4.1", - "lodash": "^4.17.21", - "minimatch": "^3.0.4", - "natural-compare": "^1.4.0", - "optionator": "^0.9.1", - "progress": "^2.0.0", - "regexpp": "^3.1.0", - "semver": "^7.2.1", - "strip-ansi": "^6.0.0", - "strip-json-comments": "^3.1.0", - "table": "^6.0.4", - "text-table": "^0.2.0", - "v8-compile-cache": "^2.0.3" - } - }, - "eslint-scope": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", - "integrity": 
"sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", - "dev": true, - "requires": { - "esrecurse": "^4.3.0", - "estraverse": "^4.1.1" - } - }, - "eslint-utils": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/eslint-utils/-/eslint-utils-2.1.0.tgz", - "integrity": "sha512-w94dQYoauyvlDc43XnGB8lU3Zt713vNChgt4EWwhXAP2XkBvndfxF0AgIqKOOasjPIPzj9JqgwkwbCYD0/V3Zg==", - "dev": true, - "requires": { - "eslint-visitor-keys": "^1.1.0" - }, - "dependencies": { - "eslint-visitor-keys": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-1.3.0.tgz", - "integrity": "sha512-6J72N8UNa462wa/KFODt/PJ3IU60SDpC3QXC1Hjc1BXXpfL2C9R5+AU7jhe0F6GREqVMh4Juu+NY7xn+6dipUQ==", - "dev": true - } - } - }, - "eslint-visitor-keys": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-2.1.0.tgz", - "integrity": "sha512-0rSmRBzXgDzIsD6mGdJgevzgezI534Cer5L/vyMX0kHzT/jiB43jRhd9YUlMGYLQy2zprNmoT8qasCGtY+QaKw==", - "dev": true - }, - "espree": { - "version": "7.3.1", - "resolved": "https://registry.npmjs.org/espree/-/espree-7.3.1.tgz", - "integrity": "sha512-v3JCNCE64umkFpmkFGqzVKsOT0tN1Zr+ueqLZfpV1Ob8e+CEgPWa+OxCoGH3tnhimMKIaBm4m/vaRpJ/krRz2g==", - "dev": true, - "requires": { - "acorn": "^7.4.0", - "acorn-jsx": "^5.3.1", - "eslint-visitor-keys": "^1.3.0" - }, - "dependencies": { - "eslint-visitor-keys": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-1.3.0.tgz", - "integrity": "sha512-6J72N8UNa462wa/KFODt/PJ3IU60SDpC3QXC1Hjc1BXXpfL2C9R5+AU7jhe0F6GREqVMh4Juu+NY7xn+6dipUQ==", - "dev": true - } - } - }, - "esprima": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/esprima/-/esprima-4.0.1.tgz", - "integrity": "sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==", - "dev": true - }, - "esquery": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.4.0.tgz", - "integrity": "sha512-cCDispWt5vHHtwMY2YrAQ4ibFkAL8RbH5YGBnZBc90MolvvfkkQcJro/aZiAQUlQ3qgrYS6D6v8Gc5G5CQsc9w==", - "dev": true, - "requires": { - "estraverse": "^5.1.0" - }, - "dependencies": { - "estraverse": { - "version": "5.2.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.2.0.tgz", - "integrity": "sha512-BxbNGGNm0RyRYvUdHpIwv9IWzeM9XClbOxwoATuFdOE7ZE6wHL+HQ5T8hoPM+zHvmKzzsEqhgy0GrQ5X13afiQ==", - "dev": true - } - } - }, - "esrecurse": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", - "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", - "dev": true, - "requires": { - "estraverse": "^5.2.0" - }, - "dependencies": { - "estraverse": { - "version": "5.2.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.2.0.tgz", - "integrity": "sha512-BxbNGGNm0RyRYvUdHpIwv9IWzeM9XClbOxwoATuFdOE7ZE6wHL+HQ5T8hoPM+zHvmKzzsEqhgy0GrQ5X13afiQ==", - "dev": true - } - } - }, - "estraverse": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", - "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", - "dev": true - }, - "esutils": { - "version": "2.0.3", - "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", - "integrity": 
"sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", - "dev": true - }, - "fast-deep-equal": { - "version": "3.1.3", - "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", - "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", - "dev": true - }, - "fast-json-stable-stringify": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", - "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", - "dev": true - }, - "fast-levenshtein": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", - "integrity": "sha1-PYpcZog6FqMMqGQ+hR8Zuqd5eRc=", - "dev": true - }, - "file-entry-cache": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-6.0.1.tgz", - "integrity": "sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==", - "dev": true, - "requires": { - "flat-cache": "^3.0.4" - } - }, - "flat-cache": { - "version": "3.0.4", - "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-3.0.4.tgz", - "integrity": "sha512-dm9s5Pw7Jc0GvMYbshN6zchCA9RgQlzzEZX3vylR9IqFfS8XciblUXOKfW6SiuJ0e13eDYZoZV5wdrev7P3Nwg==", - "dev": true, - "requires": { - "flatted": "^3.1.0", - "rimraf": "^3.0.2" - } - }, - "flatted": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.1.1.tgz", - "integrity": "sha512-zAoAQiudy+r5SvnSw3KJy5os/oRJYHzrzja/tBDqrZtNhUw8bt6y8OBzMWcjWr+8liV8Eb6yOhw8WZ7VFZ5ZzA==", - "dev": true - }, - "fs.realpath": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", - "integrity": "sha1-FQStJSMVjKpA20onh8sBQRmU6k8=", - "dev": true - }, - "functional-red-black-tree": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/functional-red-black-tree/-/functional-red-black-tree-1.0.1.tgz", - "integrity": "sha1-GwqzvVU7Kg1jmdKcDj6gslIHgyc=", - "dev": true - }, - "glob": { - "version": "7.1.6", - "resolved": "https://registry.npmjs.org/glob/-/glob-7.1.6.tgz", - "integrity": "sha512-LwaxwyZ72Lk7vZINtNNrywX0ZuLyStrdDtabefZKAY5ZGJhVtgdznluResxNmPitE0SAO+O26sWTHeKSI2wMBA==", - "dev": true, - "requires": { - "fs.realpath": "^1.0.0", - "inflight": "^1.0.4", - "inherits": "2", - "minimatch": "^3.0.4", - "once": "^1.3.0", - "path-is-absolute": "^1.0.0" - } - }, - "glob-parent": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", - "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", - "dev": true, - "requires": { - "is-glob": "^4.0.1" - } - }, - "globals": { - "version": "13.8.0", - "resolved": "https://registry.npmjs.org/globals/-/globals-13.8.0.tgz", - "integrity": "sha512-rHtdA6+PDBIjeEvA91rpqzEvk/k3/i7EeNQiryiWuJH0Hw9cpyJMAt2jtbAwUaRdhD+573X4vWw6IcjKPasi9Q==", - "dev": true, - "requires": { - "type-fest": "^0.20.2" - }, - "dependencies": { - "type-fest": { - "version": "0.20.2", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.20.2.tgz", - "integrity": "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ==", - "dev": true - } - } + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + 
"integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true }, "has-flag": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", - "integrity": "sha1-tdRU3CGZriJWmfNGfloH87lVuv0=", - "dev": true - }, - "ignore": { - "version": "4.0.6", - "resolved": "https://registry.npmjs.org/ignore/-/ignore-4.0.6.tgz", - "integrity": "sha512-cyFDKrqc/YdcWFniJhzI42+AzS+gNwmUzOSFcRCQYwySuBBBy/KjuxWLZ/FHEH6Moq1NizMOBWyTcv8O4OZIMg==", - "dev": true - }, - "import-fresh": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", - "integrity": "sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==", - "dev": true, - "requires": { - "parent-module": "^1.0.0", - "resolve-from": "^4.0.0" - } - }, - "imurmurhash": { - "version": "0.1.4", - "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", - "integrity": "sha1-khi5srkoojixPcT7a21XbyMUU+o=", - "dev": true - }, - "inflight": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", - "integrity": "sha1-Sb1jMdfQLQwJvJEKEHW6gWW1bfk=", - "dev": true, - "requires": { - "once": "^1.3.0", - "wrappy": "1" - } - }, - "inherits": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", - "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", - "dev": true - }, - "is-extglob": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", - "integrity": "sha1-qIwCU1eR8C7TfHahueqXc8gz+MI=", - "dev": true - }, - "is-fullwidth-code-point": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", - "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", - "dev": true - }, - "is-glob": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.1.tgz", - "integrity": "sha512-5G0tKtBTFImOqDnLB2hG6Bp2qcKEFduo4tZu9MT/H6NQv/ghhy30o55ufafxJ/LdH79LLs2Kfrn85TLKyA7BUg==", - "dev": true, - "requires": { - "is-extglob": "^2.1.1" - } - }, - "isexe": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", - "integrity": "sha1-6PvzdNxVb/iUehDcsFctYz8s+hA=", - "dev": true - }, - "js-tokens": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", - "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==", - "dev": true - }, - "js-yaml": { - "version": "3.14.1", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.1.tgz", - "integrity": "sha512-okMH7OXXJ7YrN9Ok3/SXrnu4iX9yOk+25nqX4imS2npuvTYDmo/QEZoqwZkYaIDk3jVvBOTOIEgEhaLOynBS9g==", - "dev": true, - "requires": { - "argparse": "^1.0.7", - "esprima": "^4.0.0" - } - }, - "json-schema-traverse": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", - "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", - "dev": true - }, - "json-stable-stringify-without-jsonify": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", - "integrity": 
"sha1-nbe1lJatPzz+8wp1FC0tkwrXJlE=", - "dev": true - }, - "levn": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", - "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", - "dev": true, - "requires": { - "prelude-ls": "^1.2.1", - "type-check": "~0.4.0" - } - }, - "lodash": { - "version": "4.17.21", - "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", - "dev": true - }, - "lodash.clonedeep": { - "version": "4.5.0", - "resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz", - "integrity": "sha1-4j8/nE+Pvd6HJSnBBxhXoIblzO8=", - "dev": true - }, - "lodash.flatten": { - "version": "4.4.0", - "resolved": "https://registry.npmjs.org/lodash.flatten/-/lodash.flatten-4.4.0.tgz", - "integrity": "sha1-8xwiIlqWMtK7+OSt2+8kCqdlph8=", - "dev": true - }, - "lodash.truncate": { - "version": "4.4.2", - "resolved": "https://registry.npmjs.org/lodash.truncate/-/lodash.truncate-4.4.2.tgz", - "integrity": "sha1-WjUNoLERO4N+z//VgSy+WNbq4ZM=", - "dev": true - }, - "lru-cache": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", - "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", - "dev": true, - "requires": { - "yallist": "^4.0.0" - } - }, - "minimatch": { - "version": "3.0.4", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.0.4.tgz", - "integrity": "sha512-yJHVQEhyqPLUTgt9B83PXu6W3rx4MvvHvSUvToogpwoGDOUQ+yDrR0HRot+yOCdCO7u4hX3pWft6kWBBcqh0UA==", - "dev": true, - "requires": { - "brace-expansion": "^1.1.7" - } - }, - "ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", - "dev": true - }, - "natural-compare": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", - "integrity": "sha1-Sr6/7tdUHywnrPspvbvRXI1bpPc=", - "dev": true - }, - "once": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", - "integrity": "sha1-WDsap3WWHUsROsF9nFC6753Xa9E=", - "dev": true, - "requires": { - "wrappy": "1" - } - }, - "optionator": { - "version": "0.9.1", - "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.1.tgz", - "integrity": "sha512-74RlY5FCnhq4jRxVUPKDaRwrVNXMqsGsiW6AJw4XK8hmtm10wC0ypZBLw5IIp85NZMr91+qd1RvvENwg7jjRFw==", - "dev": true, - "requires": { - "deep-is": "^0.1.3", - "fast-levenshtein": "^2.0.6", - "levn": "^0.4.1", - "prelude-ls": "^1.2.1", - "type-check": "^0.4.0", - "word-wrap": "^1.2.3" - } - }, - "parent-module": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", - "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", - "dev": true, - "requires": { - "callsites": "^3.0.0" - } - }, - "path-is-absolute": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", - "integrity": "sha1-F0uSaHNVNP+8es5r9TpanhtcX18=", - "dev": true - }, - "path-key": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", - "integrity": 
"sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", - "dev": true - }, - "prelude-ls": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", - "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", - "dev": true - }, - "progress": { - "version": "2.0.3", - "resolved": "https://registry.npmjs.org/progress/-/progress-2.0.3.tgz", - "integrity": "sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==", - "dev": true - }, - "punycode": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.1.1.tgz", - "integrity": "sha512-XRsRjdf+j5ml+y/6GKHPZbrF/8p2Yga0JPtdqTIY2Xe5ohJPD9saDJJLPvp9+NSBprVvevdXZybnj2cv8OEd0A==", - "dev": true - }, - "regexpp": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/regexpp/-/regexpp-3.1.0.tgz", - "integrity": "sha512-ZOIzd8yVsQQA7j8GCSlPGXwg5PfmA1mrq0JP4nGhh54LaKN3xdai/vHUDu74pKwV8OxseMS65u2NImosQcSD0Q==", - "dev": true - }, - "require-from-string": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", - "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", - "dev": true - }, - "resolve-from": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", - "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", - "dev": true - }, - "rimraf": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", - "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", - "dev": true, - "requires": { - "glob": "^7.1.3" - } - }, - "semver": { - "version": "7.3.5", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.5.tgz", - "integrity": "sha512-PoeGJYh8HK4BTO/a9Tf6ZG3veo/A7ZVsYrSA6J8ny9nb3B1VrpkuN+z9OE5wfE5p6H4LchYZsegiQgbJD94ZFQ==", - "dev": true, - "requires": { - "lru-cache": "^6.0.0" - } - }, - "shebang-command": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", - "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", - "dev": true, - "requires": { - "shebang-regex": "^3.0.0" - } - }, - "shebang-regex": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", - "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", - "dev": true - }, - "slice-ansi": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-4.0.0.tgz", - "integrity": "sha512-qMCMfhY040cVHT43K9BFygqYbUPFZKHOg7K73mtTWJRb8pyP3fzf4Ixd5SzdEJQ6MRUg/WBnOLxghZtKKurENQ==", - "dev": true, - "requires": { - "ansi-styles": "^4.0.0", - "astral-regex": "^2.0.0", - "is-fullwidth-code-point": "^3.0.0" - }, - "dependencies": { - "ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "requires": { - "color-convert": "^2.0.1" - } - }, - "color-convert": { - "version": "2.0.1", - "resolved": 
"https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "requires": { - "color-name": "~1.1.4" - } - }, - "color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - } - } - }, - "sprintf-js": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.0.3.tgz", - "integrity": "sha1-BOaSb2YolTVPPdAVIDYzuFcpfiw=", - "dev": true - }, - "string-width": { - "version": "4.2.2", - "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.2.tgz", - "integrity": "sha512-XBJbT3N4JhVumXE0eoLU9DCjcaF92KLNqTmFCnG1pf8duUxFGwtP6AD6nkjw9a3IdiRtL3E2w3JDiE/xi3vOeA==", - "dev": true, - "requires": { - "emoji-regex": "^8.0.0", - "is-fullwidth-code-point": "^3.0.0", - "strip-ansi": "^6.0.0" - } - }, - "strip-ansi": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.0.tgz", - "integrity": "sha512-AuvKTrTfQNYNIctbR1K/YGTR1756GycPsg7b9bdV9Duqur4gv6aKqHXah67Z8ImS7WEz5QVcOtlfW2rZEugt6w==", - "dev": true, - "requires": { - "ansi-regex": "^5.0.0" - } - }, - "strip-json-comments": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", - "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==", - "dev": true + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true }, "supports-color": { - "version": "5.5.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.5.0.tgz", - "integrity": "sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==", - "dev": true, - "requires": { - "has-flag": "^3.0.0" - } - }, - "table": { - "version": "6.6.0", - "resolved": "https://registry.npmjs.org/table/-/table-6.6.0.tgz", - "integrity": "sha512-iZMtp5tUvcnAdtHpZTWLPF0M7AgiQsURR2DwmxnJwSy8I3+cY+ozzVvYha3BOLG2TB+L0CqjIz+91htuj6yCXg==", - "dev": true, - "requires": { - "ajv": "^8.0.1", - "lodash.clonedeep": "^4.5.0", - "lodash.flatten": "^4.4.0", - "lodash.truncate": "^4.4.2", - "slice-ansi": "^4.0.0", - "string-width": "^4.2.0", - "strip-ansi": "^6.0.0" - }, - "dependencies": { - "ajv": { - "version": "8.2.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.2.0.tgz", - "integrity": "sha512-WSNGFuyWd//XO8n/m/EaOlNLtO0yL8EXT/74LqT4khdhpZjP7lkj/kT5uwRmGitKEVp/Oj7ZUHeGfPtgHhQ5CA==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - } - }, - "json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - } - } - }, - "text-table": { - "version": "0.2.0", - "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", - "integrity": "sha1-f17oI66AUgfACvLfSoTsP8+lcLQ=", - "dev": true - }, - "type-check": { - "version": "0.4.0", - "resolved": 
"https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", - "integrity": "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", - "dev": true, - "requires": { - "prelude-ls": "^1.2.1" - } - }, + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "requires": { + "has-flag": "^4.0.0" + } + } + } + }, + "color-convert": { + "version": "1.9.3", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz", + "integrity": "sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==", + "dev": true, + "requires": { + "color-name": "1.1.3" + } + }, + "color-name": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz", + "integrity": "sha1-p9BVi9icQveV3UIyj3QIMcpTvCU=", + "dev": true + }, + "concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha1-2Klr13/Wjfd5OnMDajug1UBdR3s=", + "dev": true + }, + "cross-spawn": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", + "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "dev": true, + "requires": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + } + }, + "debug": { + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.1.tgz", + "integrity": "sha512-doEwdvm4PCeK4K3RQN2ZC2BYUBaxwLARCqZmMjtF8a51J2Rb0xpVloFRnCODwqjpwnAoao4pelN8l3RJdv3gRQ==", + "dev": true, + "requires": { + "ms": "2.1.2" + } + }, + "deep-is": { + "version": "0.1.3", + "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.3.tgz", + "integrity": "sha1-s2nW+128E+7PUk+RsHD+7cNXzzQ=", + "dev": true + }, + "doctrine": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/doctrine/-/doctrine-3.0.0.tgz", + "integrity": "sha512-yS+Q5i3hBf7GBkd4KG8a7eBNNWNGLTaEwwYWUijIYM7zrlYDM0BFXHjjPWlWZ1Rg7UaddZeIDmi9jF3HmqiQ2w==", + "dev": true, + "requires": { + "esutils": "^2.0.2" + } + }, + "emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true + }, + "enquirer": { + "version": "2.3.6", + "resolved": "https://registry.npmjs.org/enquirer/-/enquirer-2.3.6.tgz", + "integrity": "sha512-yjNnPr315/FjS4zIsUxYguYUPP2e1NK4d7E7ZOLiyYCcbFBiTMyID+2wvm2w6+pZ/odMA7cRkjhsPbltwBOrLg==", + "dev": true, + "requires": { + "ansi-colors": "^4.1.1" + } + }, + "escape-string-regexp": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz", + "integrity": "sha1-G2HAViGQqN/2rjuyzwIAyhMLhtQ=", + "dev": true + }, + "eslint": { + "version": "7.25.0", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-7.25.0.tgz", + "integrity": "sha512-TVpSovpvCNpLURIScDRB6g5CYu/ZFq9GfX2hLNIV4dSBKxIWojeDODvYl3t0k0VtMxYeR8OXPCFE5+oHMlGfhw==", + "dev": true, + "requires": { + "@babel/code-frame": "7.12.11", + "@eslint/eslintrc": "^0.4.0", + "ajv": "^6.10.0", + "chalk": "^4.0.0", + "cross-spawn": "^7.0.2", + "debug": "^4.0.1", + "doctrine": "^3.0.0", + "enquirer": "^2.3.5", + "eslint-scope": 
"^5.1.1", + "eslint-utils": "^2.1.0", + "eslint-visitor-keys": "^2.0.0", + "espree": "^7.3.1", + "esquery": "^1.4.0", + "esutils": "^2.0.2", + "file-entry-cache": "^6.0.1", + "functional-red-black-tree": "^1.0.1", + "glob-parent": "^5.0.0", + "globals": "^13.6.0", + "ignore": "^4.0.6", + "import-fresh": "^3.0.0", + "imurmurhash": "^0.1.4", + "is-glob": "^4.0.0", + "js-yaml": "^3.13.1", + "json-stable-stringify-without-jsonify": "^1.0.1", + "levn": "^0.4.1", + "lodash": "^4.17.21", + "minimatch": "^3.0.4", + "natural-compare": "^1.4.0", + "optionator": "^0.9.1", + "progress": "^2.0.0", + "regexpp": "^3.1.0", + "semver": "^7.2.1", + "strip-ansi": "^6.0.0", + "strip-json-comments": "^3.1.0", + "table": "^6.0.4", + "text-table": "^0.2.0", + "v8-compile-cache": "^2.0.3" + } + }, + "eslint-scope": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", + "integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", + "dev": true, + "requires": { + "esrecurse": "^4.3.0", + "estraverse": "^4.1.1" + } + }, + "eslint-utils": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/eslint-utils/-/eslint-utils-2.1.0.tgz", + "integrity": "sha512-w94dQYoauyvlDc43XnGB8lU3Zt713vNChgt4EWwhXAP2XkBvndfxF0AgIqKOOasjPIPzj9JqgwkwbCYD0/V3Zg==", + "dev": true, + "requires": { + "eslint-visitor-keys": "^1.1.0" + }, + "dependencies": { + "eslint-visitor-keys": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-1.3.0.tgz", + "integrity": "sha512-6J72N8UNa462wa/KFODt/PJ3IU60SDpC3QXC1Hjc1BXXpfL2C9R5+AU7jhe0F6GREqVMh4Juu+NY7xn+6dipUQ==", + "dev": true + } + } + }, + "eslint-visitor-keys": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-2.1.0.tgz", + "integrity": "sha512-0rSmRBzXgDzIsD6mGdJgevzgezI534Cer5L/vyMX0kHzT/jiB43jRhd9YUlMGYLQy2zprNmoT8qasCGtY+QaKw==", + "dev": true + }, + "espree": { + "version": "7.3.1", + "resolved": "https://registry.npmjs.org/espree/-/espree-7.3.1.tgz", + "integrity": "sha512-v3JCNCE64umkFpmkFGqzVKsOT0tN1Zr+ueqLZfpV1Ob8e+CEgPWa+OxCoGH3tnhimMKIaBm4m/vaRpJ/krRz2g==", + "dev": true, + "requires": { + "acorn": "^7.4.0", + "acorn-jsx": "^5.3.1", + "eslint-visitor-keys": "^1.3.0" + }, + "dependencies": { + "eslint-visitor-keys": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-1.3.0.tgz", + "integrity": "sha512-6J72N8UNa462wa/KFODt/PJ3IU60SDpC3QXC1Hjc1BXXpfL2C9R5+AU7jhe0F6GREqVMh4Juu+NY7xn+6dipUQ==", + "dev": true + } + } + }, + "esprima": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/esprima/-/esprima-4.0.1.tgz", + "integrity": "sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==", + "dev": true + }, + "esquery": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.4.0.tgz", + "integrity": "sha512-cCDispWt5vHHtwMY2YrAQ4ibFkAL8RbH5YGBnZBc90MolvvfkkQcJro/aZiAQUlQ3qgrYS6D6v8Gc5G5CQsc9w==", + "dev": true, + "requires": { + "estraverse": "^5.1.0" + }, + "dependencies": { + "estraverse": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.2.0.tgz", + "integrity": "sha512-BxbNGGNm0RyRYvUdHpIwv9IWzeM9XClbOxwoATuFdOE7ZE6wHL+HQ5T8hoPM+zHvmKzzsEqhgy0GrQ5X13afiQ==", + "dev": true + } + } + }, + "esrecurse": { + "version": "4.3.0", + "resolved": 
"https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "dev": true, + "requires": { + "estraverse": "^5.2.0" + }, + "dependencies": { + "estraverse": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.2.0.tgz", + "integrity": "sha512-BxbNGGNm0RyRYvUdHpIwv9IWzeM9XClbOxwoATuFdOE7ZE6wHL+HQ5T8hoPM+zHvmKzzsEqhgy0GrQ5X13afiQ==", + "dev": true + } + } + }, + "estraverse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", + "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "dev": true + }, + "esutils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "dev": true + }, + "fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "dev": true + }, + "fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "dev": true + }, + "fast-levenshtein": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", + "integrity": "sha1-PYpcZog6FqMMqGQ+hR8Zuqd5eRc=", + "dev": true + }, + "file-entry-cache": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-6.0.1.tgz", + "integrity": "sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==", + "dev": true, + "requires": { + "flat-cache": "^3.0.4" + } + }, + "flat-cache": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-3.0.4.tgz", + "integrity": "sha512-dm9s5Pw7Jc0GvMYbshN6zchCA9RgQlzzEZX3vylR9IqFfS8XciblUXOKfW6SiuJ0e13eDYZoZV5wdrev7P3Nwg==", + "dev": true, + "requires": { + "flatted": "^3.1.0", + "rimraf": "^3.0.2" + } + }, + "flatted": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.1.1.tgz", + "integrity": "sha512-zAoAQiudy+r5SvnSw3KJy5os/oRJYHzrzja/tBDqrZtNhUw8bt6y8OBzMWcjWr+8liV8Eb6yOhw8WZ7VFZ5ZzA==", + "dev": true + }, + "fs.realpath": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", + "integrity": "sha1-FQStJSMVjKpA20onh8sBQRmU6k8=", + "dev": true + }, + "functional-red-black-tree": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/functional-red-black-tree/-/functional-red-black-tree-1.0.1.tgz", + "integrity": "sha1-GwqzvVU7Kg1jmdKcDj6gslIHgyc=", + "dev": true + }, + "glob": { + "version": "7.1.6", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.1.6.tgz", + "integrity": "sha512-LwaxwyZ72Lk7vZINtNNrywX0ZuLyStrdDtabefZKAY5ZGJhVtgdznluResxNmPitE0SAO+O26sWTHeKSI2wMBA==", + "dev": true, + "requires": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.0.4", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + } + }, + "glob-parent": { + "version": "5.1.2", + "resolved": 
"https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "requires": { + "is-glob": "^4.0.1" + } + }, + "globals": { + "version": "13.8.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-13.8.0.tgz", + "integrity": "sha512-rHtdA6+PDBIjeEvA91rpqzEvk/k3/i7EeNQiryiWuJH0Hw9cpyJMAt2jtbAwUaRdhD+573X4vWw6IcjKPasi9Q==", + "dev": true, + "requires": { + "type-fest": "^0.20.2" + }, + "dependencies": { "type-fest": { - "version": "0.8.1", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.8.1.tgz", - "integrity": "sha512-4dbzIzqvjtgiM5rw1k5rEHtBANKmdudhGyBEajN01fEyhaAIhsoKNy6y7+IN93IfpFtwY9iqi7kD+xwKhQsNJA==", - "dev": true - }, - "uri-js": { - "version": "4.4.1", - "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", - "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", - "dev": true, - "requires": { - "punycode": "^2.1.0" - } - }, - "v8-compile-cache": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/v8-compile-cache/-/v8-compile-cache-2.3.0.tgz", - "integrity": "sha512-l8lCEmLcLYZh4nbunNZvQCJc5pv7+RCwa8q/LdUx8u7lsWvPDKmpodJAJNwkAhJC//dFY48KuIEmjtd4RViDrA==", - "dev": true - }, - "which": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", - "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", - "dev": true, - "requires": { - "isexe": "^2.0.0" - } + "version": "0.20.2", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.20.2.tgz", + "integrity": "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ==", + "dev": true + } + } + }, + "has-flag": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", + "integrity": "sha1-tdRU3CGZriJWmfNGfloH87lVuv0=", + "dev": true + }, + "ignore": { + "version": "4.0.6", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-4.0.6.tgz", + "integrity": "sha512-cyFDKrqc/YdcWFniJhzI42+AzS+gNwmUzOSFcRCQYwySuBBBy/KjuxWLZ/FHEH6Moq1NizMOBWyTcv8O4OZIMg==", + "dev": true + }, + "import-fresh": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", + "integrity": "sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==", + "dev": true, + "requires": { + "parent-module": "^1.0.0", + "resolve-from": "^4.0.0" + } + }, + "imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha1-khi5srkoojixPcT7a21XbyMUU+o=", + "dev": true + }, + "inflight": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", + "integrity": "sha1-Sb1jMdfQLQwJvJEKEHW6gWW1bfk=", + "dev": true, + "requires": { + "once": "^1.3.0", + "wrappy": "1" + } + }, + "inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "dev": true + }, + "is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha1-qIwCU1eR8C7TfHahueqXc8gz+MI=", + "dev": true + }, + "is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": 
"https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "dev": true + }, + "is-glob": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.1.tgz", + "integrity": "sha512-5G0tKtBTFImOqDnLB2hG6Bp2qcKEFduo4tZu9MT/H6NQv/ghhy30o55ufafxJ/LdH79LLs2Kfrn85TLKyA7BUg==", + "dev": true, + "requires": { + "is-extglob": "^2.1.1" + } + }, + "isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha1-6PvzdNxVb/iUehDcsFctYz8s+hA=", + "dev": true + }, + "js-tokens": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", + "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==", + "dev": true + }, + "js-yaml": { + "version": "3.14.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.1.tgz", + "integrity": "sha512-okMH7OXXJ7YrN9Ok3/SXrnu4iX9yOk+25nqX4imS2npuvTYDmo/QEZoqwZkYaIDk3jVvBOTOIEgEhaLOynBS9g==", + "dev": true, + "requires": { + "argparse": "^1.0.7", + "esprima": "^4.0.0" + } + }, + "json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "dev": true + }, + "json-stable-stringify-without-jsonify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", + "integrity": "sha1-nbe1lJatPzz+8wp1FC0tkwrXJlE=", + "dev": true + }, + "levn": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", + "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", + "dev": true, + "requires": { + "prelude-ls": "^1.2.1", + "type-check": "~0.4.0" + } + }, + "lodash": { + "version": "4.17.21", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", + "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "dev": true + }, + "lodash.clonedeep": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz", + "integrity": "sha1-4j8/nE+Pvd6HJSnBBxhXoIblzO8=", + "dev": true + }, + "lodash.flatten": { + "version": "4.4.0", + "resolved": "https://registry.npmjs.org/lodash.flatten/-/lodash.flatten-4.4.0.tgz", + "integrity": "sha1-8xwiIlqWMtK7+OSt2+8kCqdlph8=", + "dev": true + }, + "lodash.truncate": { + "version": "4.4.2", + "resolved": "https://registry.npmjs.org/lodash.truncate/-/lodash.truncate-4.4.2.tgz", + "integrity": "sha1-WjUNoLERO4N+z//VgSy+WNbq4ZM=", + "dev": true + }, + "lru-cache": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", + "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "dev": true, + "requires": { + "yallist": "^4.0.0" + } + }, + "minimatch": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.0.4.tgz", + "integrity": "sha512-yJHVQEhyqPLUTgt9B83PXu6W3rx4MvvHvSUvToogpwoGDOUQ+yDrR0HRot+yOCdCO7u4hX3pWft6kWBBcqh0UA==", + "dev": true, + "requires": { + "brace-expansion": "^1.1.7" + } + }, + "ms": { 
+ "version": "2.1.2", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", + "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "dev": true + }, + "natural-compare": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", + "integrity": "sha1-Sr6/7tdUHywnrPspvbvRXI1bpPc=", + "dev": true + }, + "once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha1-WDsap3WWHUsROsF9nFC6753Xa9E=", + "dev": true, + "requires": { + "wrappy": "1" + } + }, + "optionator": { + "version": "0.9.1", + "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.1.tgz", + "integrity": "sha512-74RlY5FCnhq4jRxVUPKDaRwrVNXMqsGsiW6AJw4XK8hmtm10wC0ypZBLw5IIp85NZMr91+qd1RvvENwg7jjRFw==", + "dev": true, + "requires": { + "deep-is": "^0.1.3", + "fast-levenshtein": "^2.0.6", + "levn": "^0.4.1", + "prelude-ls": "^1.2.1", + "type-check": "^0.4.0", + "word-wrap": "^1.2.3" + } + }, + "parent-module": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", + "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", + "dev": true, + "requires": { + "callsites": "^3.0.0" + } + }, + "path-is-absolute": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "integrity": "sha1-F0uSaHNVNP+8es5r9TpanhtcX18=", + "dev": true + }, + "path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "dev": true + }, + "prelude-ls": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", + "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", + "dev": true + }, + "progress": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/progress/-/progress-2.0.3.tgz", + "integrity": "sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==", + "dev": true + }, + "punycode": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.1.1.tgz", + "integrity": "sha512-XRsRjdf+j5ml+y/6GKHPZbrF/8p2Yga0JPtdqTIY2Xe5ohJPD9saDJJLPvp9+NSBprVvevdXZybnj2cv8OEd0A==", + "dev": true + }, + "regexpp": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/regexpp/-/regexpp-3.1.0.tgz", + "integrity": "sha512-ZOIzd8yVsQQA7j8GCSlPGXwg5PfmA1mrq0JP4nGhh54LaKN3xdai/vHUDu74pKwV8OxseMS65u2NImosQcSD0Q==", + "dev": true + }, + "require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "dev": true + }, + "resolve-from": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", + "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", + "dev": true + }, + "rimraf": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", + "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + 
"dev": true, + "requires": { + "glob": "^7.1.3" + } + }, + "semver": { + "version": "7.3.5", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.5.tgz", + "integrity": "sha512-PoeGJYh8HK4BTO/a9Tf6ZG3veo/A7ZVsYrSA6J8ny9nb3B1VrpkuN+z9OE5wfE5p6H4LchYZsegiQgbJD94ZFQ==", + "dev": true, + "requires": { + "lru-cache": "^6.0.0" + } + }, + "shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dev": true, + "requires": { + "shebang-regex": "^3.0.0" + } + }, + "shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "dev": true + }, + "slice-ansi": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-4.0.0.tgz", + "integrity": "sha512-qMCMfhY040cVHT43K9BFygqYbUPFZKHOg7K73mtTWJRb8pyP3fzf4Ixd5SzdEJQ6MRUg/WBnOLxghZtKKurENQ==", + "dev": true, + "requires": { + "ansi-styles": "^4.0.0", + "astral-regex": "^2.0.0", + "is-fullwidth-code-point": "^3.0.0" + }, + "dependencies": { + "ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "requires": { + "color-convert": "^2.0.1" + } }, - "word-wrap": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.3.tgz", - "integrity": "sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==", - "dev": true + "color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "requires": { + "color-name": "~1.1.4" + } }, - "wrappy": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", - "integrity": "sha1-tSQ9jz7BqjXxNkYFvA0QNuMKtp8=", - "dev": true + "color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true + } + } + }, + "sprintf-js": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.0.3.tgz", + "integrity": "sha1-BOaSb2YolTVPPdAVIDYzuFcpfiw=", + "dev": true + }, + "string-width": { + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.2.tgz", + "integrity": "sha512-XBJbT3N4JhVumXE0eoLU9DCjcaF92KLNqTmFCnG1pf8duUxFGwtP6AD6nkjw9a3IdiRtL3E2w3JDiE/xi3vOeA==", + "dev": true, + "requires": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.0" + } + }, + "strip-ansi": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.0.tgz", + "integrity": "sha512-AuvKTrTfQNYNIctbR1K/YGTR1756GycPsg7b9bdV9Duqur4gv6aKqHXah67Z8ImS7WEz5QVcOtlfW2rZEugt6w==", + "dev": true, + "requires": { + "ansi-regex": "^5.0.0" + } + }, + "strip-json-comments": { + "version": "3.1.1", + "resolved": 
"https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", + "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==", + "dev": true + }, + "supports-color": { + "version": "5.5.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.5.0.tgz", + "integrity": "sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==", + "dev": true, + "requires": { + "has-flag": "^3.0.0" + } + }, + "table": { + "version": "6.6.0", + "resolved": "https://registry.npmjs.org/table/-/table-6.6.0.tgz", + "integrity": "sha512-iZMtp5tUvcnAdtHpZTWLPF0M7AgiQsURR2DwmxnJwSy8I3+cY+ozzVvYha3BOLG2TB+L0CqjIz+91htuj6yCXg==", + "dev": true, + "requires": { + "ajv": "^8.0.1", + "lodash.clonedeep": "^4.5.0", + "lodash.flatten": "^4.4.0", + "lodash.truncate": "^4.4.2", + "slice-ansi": "^4.0.0", + "string-width": "^4.2.0", + "strip-ansi": "^6.0.0" + }, + "dependencies": { + "ajv": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.2.0.tgz", + "integrity": "sha512-WSNGFuyWd//XO8n/m/EaOlNLtO0yL8EXT/74LqT4khdhpZjP7lkj/kT5uwRmGitKEVp/Oj7ZUHeGfPtgHhQ5CA==", + "dev": true, + "requires": { + "fast-deep-equal": "^3.1.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2", + "uri-js": "^4.2.2" + } }, - "yallist": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", - "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", - "dev": true + "json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true } + } + }, + "text-table": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", + "integrity": "sha1-f17oI66AUgfACvLfSoTsP8+lcLQ=", + "dev": true + }, + "type-check": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", + "integrity": "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", + "dev": true, + "requires": { + "prelude-ls": "^1.2.1" + } + }, + "type-fest": { + "version": "0.8.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.8.1.tgz", + "integrity": "sha512-4dbzIzqvjtgiM5rw1k5rEHtBANKmdudhGyBEajN01fEyhaAIhsoKNy6y7+IN93IfpFtwY9iqi7kD+xwKhQsNJA==", + "dev": true + }, + "uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "dev": true, + "requires": { + "punycode": "^2.1.0" + } + }, + "v8-compile-cache": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/v8-compile-cache/-/v8-compile-cache-2.3.0.tgz", + "integrity": "sha512-l8lCEmLcLYZh4nbunNZvQCJc5pv7+RCwa8q/LdUx8u7lsWvPDKmpodJAJNwkAhJC//dFY48KuIEmjtd4RViDrA==", + "dev": true + }, + "which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dev": true, + "requires": { + "isexe": "^2.0.0" + } + }, + "word-wrap": { + "version": "1.2.3", + "resolved": 
"https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.3.tgz", + "integrity": "sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==", + "dev": true + }, + "wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha1-tSQ9jz7BqjXxNkYFvA0QNuMKtp8=", + "dev": true + }, + "yallist": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", + "dev": true } + } } diff --git a/dev/package.json b/dev/package.json index 0391a3983f78f..f975bdde8319a 100644 --- a/dev/package.json +++ b/dev/package.json @@ -1,5 +1,6 @@ { "devDependencies": { - "eslint": "^7.25.0" + "eslint": "^7.25.0", + "ansi-regex": "^5.0.1" } } diff --git a/dev/run-tests.py b/dev/run-tests.py index d943277e1d516..570ee4c8169cf 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -653,14 +653,14 @@ def main(): run_scala_tests(build_tool, extra_profiles, test_modules, excluded_tags, included_tags) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] - if modules_with_python_tests: + if modules_with_python_tests and not os.environ.get("SKIP_PYTHON"): run_python_tests( modules_with_python_tests, opts.parallelism, with_coverage=os.environ.get("PYSPARK_CODECOV", "false") == "true", ) run_python_packaging_tests() - if any(m.should_run_r_tests for m in test_modules): + if any(m.should_run_r_tests for m in test_modules) and not os.environ.get("SKIP_R"): run_sparkr_tests() diff --git a/dev/scalastyle b/dev/scalastyle index 212ef900eb9b4..5f958b8fb0a7b 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,7 +17,7 @@ # limitations under the License. # -SPARK_PROFILES=${1:-"-Pmesos -Pkubernetes -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive"} +SPARK_PROFILES=${1:-"-Pmesos -Pkubernetes -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive -Pvolcano"} # NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file # with failure (either resolution or compilation); the "q" makes SBT quit. diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 7cd5bd15752ae..6e668bba8c803 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -453,6 +453,7 @@ def __hash__(self): "pyspark.sql.tests.test_pandas_udf_grouped_agg", "pyspark.sql.tests.test_pandas_udf_scalar", "pyspark.sql.tests.test_pandas_udf_typehints", + "pyspark.sql.tests.test_pandas_udf_typehints_with_future_annotations", "pyspark.sql.tests.test_pandas_udf_window", "pyspark.sql.tests.test_readwriter", "pyspark.sql.tests.test_serde", diff --git a/dev/tox.ini b/dev/tox.ini index df4dfce5dcaa0..464b9b959fa14 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -18,21 +18,25 @@ ignore = E203, # Skip as black formatter adds a whitespace around ':'. E402, # Module top level import is disabled for optional import check, etc. - F403, # Using wildcard discouraged but F401 can detect. Disabled to reduce the usage of noqa. # 1. Type hints with def are treated as redefinition (e.g., functions.log). # 2. Some are used for testing. F811, # There are too many instances to fix. Ignored for now. W503, W504, - - # Below rules should be enabled in the future. - E731, per-file-ignores = - # F405 and E501 are ignored as shared.py is auto-generated. 
- python/pyspark/ml/param/shared.py: F405 E501, + # E501 is ignored as shared.py is auto-generated. + python/pyspark/ml/param/shared.py: E501, # Examples contain some unused variables. examples/src/main/python/sql/datasource.py: F841, + # Exclude * imports in test files + python/pyspark/ml/tests/*.py: F403, + python/pyspark/mllib/tests/*.py: F403, + python/pyspark/pandas/tests/*.py: F401 F403, + python/pyspark/resource/tests/*.py: F403, + python/pyspark/sql/tests/*.py: F403, + python/pyspark/streaming/tests/*.py: F403, + python/pyspark/tests/*.py: F403 exclude = */target/*, docs/.local_ruby_bundle/, diff --git a/docs/README.md b/docs/README.md index 5e9a187ea3ab6..6bb83d8953057 100644 --- a/docs/README.md +++ b/docs/README.md @@ -48,23 +48,9 @@ $ bundle install Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0. -### R Documentation +### SQL and Python API Documentation (Optional) -If you'd like to generate R documentation, you'll need to [install Pandoc](https://pandoc.org/installing.html) -and install these libraries: - -```sh -$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "testthat", "rmarkdown"), repos="https://cloud.r-project.org/")' -$ sudo Rscript -e 'devtools::install_version("roxygen2", version = "7.1.1", repos="https://cloud.r-project.org/")' -$ sudo Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" -$ sudo Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" -``` - -Note: Other versions of roxygen2 might work in SparkR documentation generation but `RoxygenNote` field in `$SPARK_HOME/R/pkg/DESCRIPTION` is 7.1.1, which is updated if the version is mismatched. - -### API Documentation - -To generate API docs for any language, you'll need to install these libraries: +To generate SQL and Python API docs, you'll need to install these libraries: prob=%s, prediction=%f" % ( rid, text, str(prob), prediction # type: ignore diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 0ab7249a82185..c0233461d119f 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -25,18 +25,20 @@ import re import sys from operator import add +from typing import Iterable, Tuple +from pyspark.resultiterable import ResultIterable from pyspark.sql import SparkSession -def computeContribs(urls, rank): +def computeContribs(urls: ResultIterable[str], rank: float) -> Iterable[Tuple[str, float]]: """Calculates URL contributions to the rank of other URLs.""" num_urls = len(urls) for url in urls: yield (url, rank / num_urls) -def parseNeighbors(urls): +def parseNeighbors(urls: str) -> Tuple[str, str]: """Parses a urls pair string into urls pair.""" parts = re.split(r'\s+', urls) return parts[0], parts[1] @@ -73,8 +75,9 @@ def parseNeighbors(urls): # Calculates and updates URL ranks continuously using PageRank algorithm. for iteration in range(int(sys.argv[2])): # Calculates URL contributions to the rank of other URLs. - contribs = links.join(ranks).flatMap( - lambda url_urls_rank: computeContribs(url_urls_rank[1][0], url_urls_rank[1][1])) + contribs = links.join(ranks).flatMap(lambda url_urls_rank: computeContribs( + url_urls_rank[1][0], url_urls_rank[1][1] # type: ignore[arg-type] + )) # Re-calculates URL ranks based on neighbor contributions. 
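        # The update below is the example's damped PageRank step,
        # new_rank = 0.15 + 0.85 * sum(contributions): reduceByKey(add) sums each
        # URL's incoming contributions, and mapValues applies the 0.85 damping
        # factor plus the 0.15 baseline.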
ranks = contribs.reduceByKey(add).mapValues(lambda rank: rank * 0.85 + 0.15) diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index e646722533f68..e61740ad58832 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -34,7 +34,7 @@ partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2 n = 100000 * partitions - def f(_): + def f(_: int) -> float: x = random() * 2 - 1 y = random() * 2 - 1 return 1 if x ** 2 + y ** 2 <= 1 else 0 diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 298830ca26751..76dfbc4d73a0c 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -23,6 +23,8 @@ # NOTE that this file is imported in user guide in PySpark documentation. # The codes are referred via line numbers. See also `literalinclude` directive in Sphinx. +import pandas as pd +from typing import Iterable from pyspark.sql import SparkSession from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version @@ -31,7 +33,7 @@ require_minimum_pyarrow_version() -def dataframe_with_arrow_example(spark): +def dataframe_with_arrow_example(spark: SparkSession) -> None: import numpy as np import pandas as pd @@ -50,12 +52,12 @@ def dataframe_with_arrow_example(spark): print("Pandas DataFrame result statistics:\n%s\n" % str(result_pdf.describe())) -def ser_to_frame_pandas_udf_example(spark): +def ser_to_frame_pandas_udf_example(spark: SparkSession) -> None: import pandas as pd from pyspark.sql.functions import pandas_udf - @pandas_udf("col1 string, col2 long") + @pandas_udf("col1 string, col2 long") # type: ignore[call-overload] def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame: s3['col2'] = s1 + s2.str.len() return s3 @@ -78,7 +80,7 @@ def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame: # | |-- col2: long (nullable = true) -def ser_to_ser_pandas_udf_example(spark): +def ser_to_ser_pandas_udf_example(spark: SparkSession) -> None: import pandas as pd from pyspark.sql.functions import col, pandas_udf @@ -88,7 +90,7 @@ def ser_to_ser_pandas_udf_example(spark): def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series: return a * b - multiply = pandas_udf(multiply_func, returnType=LongType()) + multiply = pandas_udf(multiply_func, returnType=LongType()) # type: ignore[call-overload] # The function for a pandas_udf should be able to execute with local Pandas data x = pd.Series([1, 2, 3]) @@ -112,7 +114,7 @@ def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series: # +-------------------+ -def iter_ser_to_iter_ser_pandas_udf_example(spark): +def iter_ser_to_iter_ser_pandas_udf_example(spark: SparkSession) -> None: from typing import Iterator import pandas as pd @@ -123,7 +125,7 @@ def iter_ser_to_iter_ser_pandas_udf_example(spark): df = spark.createDataFrame(pdf) # Declare the function and create the UDF - @pandas_udf("long") + @pandas_udf("long") # type: ignore[call-overload] def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: for x in iterator: yield x + 1 @@ -138,7 +140,7 @@ def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: # +-----------+ -def iter_sers_to_iter_ser_pandas_udf_example(spark): +def iter_sers_to_iter_ser_pandas_udf_example(spark: SparkSession) -> None: from typing import Iterator, Tuple import pandas as pd @@ -149,7 +151,7 @@ def iter_sers_to_iter_ser_pandas_udf_example(spark): df = spark.createDataFrame(pdf) # Declare the function and create the 
UDF - @pandas_udf("long") + @pandas_udf("long") # type: ignore[call-overload] def multiply_two_cols( iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]: for a, b in iterator: @@ -165,7 +167,7 @@ def multiply_two_cols( # +-----------------------+ -def ser_to_scalar_pandas_udf_example(spark): +def ser_to_scalar_pandas_udf_example(spark: SparkSession) -> None: import pandas as pd from pyspark.sql.functions import pandas_udf @@ -176,7 +178,7 @@ def ser_to_scalar_pandas_udf_example(spark): ("id", "v")) # Declare the function and create the UDF - @pandas_udf("double") + @pandas_udf("double") # type: ignore[call-overload] def mean_udf(v: pd.Series) -> float: return v.mean() @@ -210,12 +212,12 @@ def mean_udf(v: pd.Series) -> float: # +---+----+------+ -def grouped_apply_in_pandas_example(spark): +def grouped_apply_in_pandas_example(spark: SparkSession) -> None: df = spark.createDataFrame( [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) - def subtract_mean(pdf): + def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame: # pdf is a pandas.DataFrame v = pdf.v return pdf.assign(v=v - v.mean()) @@ -232,10 +234,10 @@ def subtract_mean(pdf): # +---+----+ -def map_in_pandas_example(spark): +def map_in_pandas_example(spark: SparkSession) -> None: df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")) - def filter_func(iterator): + def filter_func(iterator: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: for pdf in iterator: yield pdf[pdf.id == 1] @@ -247,7 +249,7 @@ def filter_func(iterator): # +---+---+ -def cogrouped_apply_in_pandas_example(spark): +def cogrouped_apply_in_pandas_example(spark: SparkSession) -> None: import pandas as pd df1 = spark.createDataFrame( @@ -258,7 +260,7 @@ def cogrouped_apply_in_pandas_example(spark): [(20000101, 1, "x"), (20000101, 2, "y")], ("time", "id", "v2")) - def asof_join(left, right): + def asof_join(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: return pd.merge_asof(left, right, on="time", by="id") df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas( diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index cc63e9d71a254..4f7ec7ba267df 100644 --- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -34,7 +34,7 @@ # $example off:programmatic_schema$ -def basic_df_example(spark): +def basic_df_example(spark: SparkSession) -> None: # $example on:create_df$ # spark is an existing SparkSession df = spark.read.json("examples/src/main/resources/people.json") @@ -137,7 +137,7 @@ def basic_df_example(spark): # $example off:global_temp_view$ -def schema_inference_example(spark): +def schema_inference_example(spark: SparkSession) -> None: # $example on:schema_inferring$ sc = spark.sparkContext @@ -162,7 +162,7 @@ def schema_inference_example(spark): # $example off:schema_inferring$ -def programmatic_schema_example(spark): +def programmatic_schema_example(spark: SparkSession) -> None: # $example on:programmatic_schema$ sc = spark.sparkContext diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index 4d7aa045b4b87..fd312dbf16476 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -26,7 +26,7 @@ # $example off:schema_merging$ -def generic_file_source_options_example(spark): +def generic_file_source_options_example(spark: SparkSession) -> None: # $example on:ignore_corrupt_files$ # enable ignore corrupt files spark.sql("set 
spark.sql.files.ignoreCorruptFiles=true") @@ -88,7 +88,7 @@ def generic_file_source_options_example(spark): # $example off:load_with_modified_time_filter$ -def basic_datasource_example(spark): +def basic_datasource_example(spark: SparkSession) -> None: # $example on:generic_load_save_functions$ df = spark.read.load("examples/src/main/resources/users.parquet") df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") @@ -148,7 +148,7 @@ def basic_datasource_example(spark): spark.sql("DROP TABLE IF EXISTS users_partitioned_bucketed") -def parquet_example(spark): +def parquet_example(spark: SparkSession) -> None: # $example on:basic_parquet_example$ peopleDF = spark.read.json("examples/src/main/resources/people.json") @@ -172,7 +172,7 @@ def parquet_example(spark): # $example off:basic_parquet_example$ -def parquet_schema_merging_example(spark): +def parquet_schema_merging_example(spark: SparkSession) -> None: # $example on:schema_merging$ # spark is from the previous example. # Create a simple DataFrame, stored into a partition directory @@ -202,7 +202,7 @@ def parquet_schema_merging_example(spark): # $example off:schema_merging$ -def json_dataset_example(spark): +def json_dataset_example(spark: SparkSession) -> None: # $example on:json_dataset$ # spark is from the previous example. sc = spark.sparkContext @@ -244,7 +244,7 @@ def json_dataset_example(spark): # $example off:json_dataset$ -def csv_dataset_example(spark): +def csv_dataset_example(spark: SparkSession) -> None: # $example on:csv_dataset$ # spark is from the previous example sc = spark.sparkContext @@ -264,7 +264,7 @@ def csv_dataset_example(spark): # +------------------+ # Read a csv with delimiter, the default delimiter is "," - df2 = spark.read.option(delimiter=';').csv(path) + df2 = spark.read.option("delimiter", ";").csv(path) df2.show() # +-----+---+---------+ # | _c0|_c1| _c2| @@ -308,7 +308,7 @@ def csv_dataset_example(spark): # $example off:csv_dataset$ -def text_dataset_example(spark): +def text_dataset_example(spark: SparkSession) -> None: # $example on:text_dataset$ # spark is from the previous example sc = spark.sparkContext @@ -358,7 +358,7 @@ def text_dataset_example(spark): # $example off:text_dataset$ -def jdbc_dataset_example(spark): +def jdbc_dataset_example(spark: SparkSession) -> None: # $example on:jdbc_dataset$ # Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods # Loading data from a JDBC source diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index 4aa44955d9d30..cc39d8afa6be9 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -87,7 +87,7 @@ windowedCounts = words.groupBy( window(words.timestamp, windowDuration, slideDuration), words.word - ).count().orderBy('window') # type: ignore[arg-type] + ).count().orderBy('window') # Start running the query that prints the windowed word counts to the console query = windowedCounts\ diff --git a/examples/src/main/python/status_api_demo.py b/examples/src/main/python/status_api_demo.py index fca733034b93e..3bf96ca4466fa 100644 --- a/examples/src/main/python/status_api_demo.py +++ b/examples/src/main/python/status_api_demo.py @@ -18,30 +18,31 @@ import time import threading import queue as Queue +from typing import Any, Callable, List, Tuple from pyspark import 
SparkConf, SparkContext -def delayed(seconds): - def f(x): +def delayed(seconds: int) -> Callable[[Any], Any]: + def f(x: int) -> int: time.sleep(seconds) return x return f -def call_in_background(f, *args): - result = Queue.Queue(1) +def call_in_background(f: Callable[..., Any], *args: Any) -> Queue.Queue: + result: Queue.Queue = Queue.Queue(1) t = threading.Thread(target=lambda: result.put(f(*args))) t.daemon = True t.start() return result -def main(): +def main() -> None: conf = SparkConf().set("spark.ui.showConsoleProgress", "false") sc = SparkContext(appName="PythonStatusAPIDemo", conf=conf) - def run(): + def run() -> List[Tuple[int, int]]: rdd = sc.parallelize(range(10), 10).map(delayed(2)) reduced = rdd.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) return reduced.map(delayed(2)).collect() @@ -52,6 +53,8 @@ def run(): ids = status.getJobIdsForGroup() for id in ids: job = status.getJobInfo(id) + assert job is not None + print("Job", id, "status: ", job.status) for sid in job.stageIds: info = status.getStageInfo(sid) diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index 15f727b0f28cd..b3f2114b9e8c2 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -34,10 +34,11 @@ from typing import Tuple from pyspark import SparkContext +from pyspark.rdd import RDD from pyspark.streaming import DStream, StreamingContext -def print_happiest_words(rdd): +def print_happiest_words(rdd: RDD[Tuple[float, str]]) -> None: top_list = rdd.take(5) print("Happiest topics in the last 5 seconds (%d total):" % rdd.count()) for tuple in top_list: diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index 212b6f605780d..9d3fe4c30ec61 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -35,28 +35,33 @@ checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from the checkpoint data. 
""" +import datetime import os import sys +from typing import List, Tuple from pyspark import SparkContext +from pyspark.accumulators import Accumulator +from pyspark.broadcast import Broadcast +from pyspark.rdd import RDD from pyspark.streaming import StreamingContext # Get or register a Broadcast variable -def getWordExcludeList(sparkContext): +def getWordExcludeList(sparkContext: SparkContext) -> Broadcast[List[str]]: if ('wordExcludeList' not in globals()): globals()['wordExcludeList'] = sparkContext.broadcast(["a", "b", "c"]) return globals()['wordExcludeList'] # Get or register an Accumulator -def getDroppedWordsCounter(sparkContext): +def getDroppedWordsCounter(sparkContext: SparkContext) -> Accumulator[int]: if ('droppedWordsCounter' not in globals()): globals()['droppedWordsCounter'] = sparkContext.accumulator(0) return globals()['droppedWordsCounter'] -def createContext(host, port, outputPath): +def createContext(host: str, port: int, outputPath: str) -> StreamingContext: # If you do not see this printed, that means the StreamingContext has been loaded # from the new checkpoint print("Creating new context") @@ -71,14 +76,14 @@ def createContext(host, port, outputPath): words = lines.flatMap(lambda line: line.split(" ")) wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) - def echo(time, rdd): + def echo(time: datetime.datetime, rdd: RDD[Tuple[str, int]]) -> None: # Get or register the excludeList Broadcast excludeList = getWordExcludeList(rdd.context) # Get or register the droppedWordsCounter Accumulator droppedWordsCounter = getDroppedWordsCounter(rdd.context) # Use excludeList to drop words and use droppedWordsCounter to count them - def filterFunc(wordCount): + def filterFunc(wordCount: Tuple[str, int]) -> bool: if wordCount[0] in excludeList.value: droppedWordsCounter.add(wordCount[1]) return False diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 10bacbe0e6c6d..9518cb70ba784 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -28,13 +28,15 @@ `$ bin/spark-submit examples/src/main/python/streaming/sql_network_wordcount.py localhost 9999` """ import sys +import datetime -from pyspark import SparkContext +from pyspark import SparkConf, SparkContext +from pyspark.rdd import RDD from pyspark.streaming import StreamingContext from pyspark.sql import Row, SparkSession -def getSparkSessionInstance(sparkConf): +def getSparkSessionInstance(sparkConf: SparkConf) -> SparkSession: if ('sparkSessionSingletonInstance' not in globals()): globals()['sparkSessionSingletonInstance'] = SparkSession\ .builder\ @@ -57,7 +59,7 @@ def getSparkSessionInstance(sparkConf): words = lines.flatMap(lambda line: line.split(" ")) # Convert RDDs of the words DStream to DataFrame and run SQL query - def process(time, rdd): + def process(time: datetime.datetime, rdd: RDD[str]) -> None: print("========= %s =========" % str(time)) try: diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index 7a45be663a765..553af5f48eda4 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -30,6 +30,7 @@ localhost 9999` """ import sys +from typing import Iterable, Optional from pyspark import SparkContext from pyspark.streaming import 
StreamingContext @@ -45,7 +46,7 @@ # RDD with initial state (key, value) pairs initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)]) - def updateFunc(new_values, last_sum): + def updateFunc(new_values: Iterable[int], last_sum: Optional[int]) -> int: return sum(new_values) + (last_sum or 0) lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 9f543daecd3dd..e1f3b66af82d9 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -17,6 +17,7 @@ import sys from random import Random +from typing import Set, Tuple from pyspark.sql import SparkSession @@ -25,8 +26,8 @@ rand = Random(42) -def generateGraph(): - edges = set() +def generateGraph() -> Set[Tuple[int, int]]: + edges: Set[Tuple[int, int]] = set() while len(edges) < numEdges: src = rand.randrange(0, numVertices) dst = rand.randrange(0, numVertices) diff --git a/examples/src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider b/examples/src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider index c239843a3b502..7a65f53236933 100644 --- a/examples/src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider +++ b/examples/src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.examples.extensions.SessionExtensionsWithLoader diff --git a/examples/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider b/examples/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider index 776948cc04de7..ccfca6eafce86 100644 --- a/examples/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider +++ b/examples/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.examples.sql.jdbc.ExampleJdbcConnectionProvider \ No newline at end of file diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index ed56108f4b624..94fc755e0ca0f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -41,7 +41,7 @@ object DriverSubmissionTest { env.asScala.filter { case (k, _) => k.contains("SPARK_TEST")}.foreach(println) println("System properties containing spark.test:") - properties.filter { case (k, _) => k.toString.contains("spark.test") }.foreach(println) + properties.filter { case (k, _) => k.contains("spark.test") }.foreach(println) for (i <- 1 until numSecondsToSleep) { println(s"Alive for $i out of $numSecondsToSleep seconds") diff --git a/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala b/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala index d25f2204994c7..e4840241006db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala @@ -18,14 +18,15 @@ package org.apache.spark.examples.extensions import org.apache.spark.sql.catalyst.expressions.{CurrentDate, Expression, RuntimeReplaceable, SubtractDates} +import org.apache.spark.sql.catalyst.trees.UnaryLike /** * How old are you in days? */ -case class AgeExample(birthday: Expression, child: Expression) extends RuntimeReplaceable { - - def this(birthday: Expression) = this(birthday, SubtractDates(CurrentDate(), birthday)) - override def exprsReplaced: Seq[Expression] = Seq(birthday) - - override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) +case class AgeExample(birthday: Expression) extends RuntimeReplaceable with UnaryLike[Expression] { + override lazy val replacement: Expression = SubtractDates(CurrentDate(), birthday) + override def child: Expression = birthday + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(birthday = newChild) + } } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index a3006a1fa2be0..d80f54d18476f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -158,17 +158,18 @@ object LDAExample { println(s"Finished training LDA model. Summary:") println(s"\t Training time: $elapsed sec") - if (ldaModel.isInstanceOf[DistributedLDAModel]) { - val distLDAModel = ldaModel.asInstanceOf[DistributedLDAModel] - val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble - println(s"\t Training data average log likelihood: $avgLogLikelihood") - println() + ldaModel match { + case distLDAModel: DistributedLDAModel => + val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble + println(s"\t Training data average log likelihood: $avgLogLikelihood") + println() + case _ => // do nothing } // Print the topics, showing the top-weighted terms for each topic. 
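The META-INF/services entries touched above (for SparkSessionExtensionsProvider and JdbcConnectionProvider) are plain java.util.ServiceLoader registrations: the resource file is named after the interface and its contents list the implementations to discover at runtime. A minimal sketch of that discovery mechanism, not part of the patch; it assumes spark-sql on the classpath, and the example provider shows up only when the examples jar is present too.

import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.sql.jdbc.JdbcConnectionProvider

object ServiceLoaderSketch {
  def main(args: Array[String]): Unit = {
    // ServiceLoader scans every META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider
    // resource on the classpath and instantiates the classes listed there.
    val providers = ServiceLoader.load(classOf[JdbcConnectionProvider]).asScala
    providers.foreach(p => println(p.getClass.getName))
    // With the examples jar present, the output would include
    // org.apache.spark.examples.sql.jdbc.ExampleJdbcConnectionProvider alongside the built-ins.
  }
}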
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) val topics = topicIndices.map { case (terms, termWeights) => - terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) } + terms.zip(termWeights).map { case (term, weight) => (vocabArray(term), weight) } } println(s"${params.k} topics:") topics.zipWithIndex.foreach { case (topic, i) => diff --git a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index d89f963059642..c61270406c3f3 100644 --- a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.v2.avro.AvroDataSourceV2 diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index f2f754aabd3ed..4a82df6ba0dce 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -262,7 +262,7 @@ private[sql] class AvroSerializer( avroStruct, catalystStruct, avroPath, catalystPath, positionalFieldMatch) avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false) - avroSchemaHelper.validateNoExtraAvroFields() + avroSchemaHelper.validateNoExtraRequiredAvroFields() val (avroIndices, fieldConverters) = avroSchemaHelper.matchedFields.map { case AvroMatchedField(catalystField, _, avroField) => diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 149d0b6e73de6..ef9d22f35d048 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -270,10 +270,12 @@ private[sql] object AvroUtils extends Logging { /** * Validate that there are no Avro fields which don't have a matching Catalyst field, throwing - * [[IncompatibleSchemaException]] if such extra fields are found. + * [[IncompatibleSchemaException]] if such extra fields are found. Only required (non-nullable) + * fields are checked; nullable fields are ignored. 
*/ - def validateNoExtraAvroFields(): Unit = { - (avroFieldArray.toSet -- matchedFields.map(_.avroField)).foreach { extraField => + def validateNoExtraRequiredAvroFields(): Unit = { + val extraFields = avroFieldArray.toSet -- matchedFields.map(_.avroField) + extraFields.filterNot(isNullable).foreach { extraField => if (positionalFieldMatch) { throw new IncompatibleSchemaException(s"Found field '${extraField.name()}' at position " + s"${extraField.pos()} of ${toFieldStr(avroPath)} from Avro schema but there is no " + @@ -328,4 +330,9 @@ private[sql] object AvroUtils extends Logging { case Seq() => "top-level record" case n => s"field '${n.mkString(".")}'" } + + /** Return true iff `avroField` is nullable, i.e. `UNION` type and has `NULL` as an option. */ + private[avro] def isNullable(avroField: Schema.Field): Boolean = + avroField.schema().getType == Schema.Type.UNION && + avroField.schema().getTypes.asScala.exists(_.getType == Schema.Type.NULL) } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala index 604b4e80d89e3..8ad06492fa5d9 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala @@ -104,7 +104,7 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession { AvroMatchedField(catalystSchema("shared2"), 3, avroSchema.getField("shared2")) )) assertThrows[IncompatibleSchemaException] { - helper.validateNoExtraAvroFields() + helper.validateNoExtraRequiredAvroFields() } helper.validateNoExtraCatalystFields(ignoreNullable = true) assertThrows[IncompatibleSchemaException] { @@ -133,4 +133,27 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession { helperNullable.validateNoExtraCatalystFields(ignoreNullable = false) } } + + test("SPARK-34378: validateNoExtraRequiredAvroFields detects required and ignores nullable") { + val avroSchema = SchemaBuilder.record("record").fields() + .requiredInt("foo") + .nullableInt("bar", 1) + .optionalInt("baz") + .endRecord() + + val catalystFull = + new StructType().add("foo", IntegerType).add("bar", IntegerType).add("baz", IntegerType) + + def testValidation(catalystFieldToRemove: String): Unit = { + val filteredSchema = StructType(catalystFull.filterNot(_.name == catalystFieldToRemove)) + new AvroUtils.AvroSchemaHelper(avroSchema, filteredSchema, Seq(""), Seq(""), false) + .validateNoExtraRequiredAvroFields() + } + + assertThrows[IncompatibleSchemaException] { + testValidation("foo") + } + testValidation("bar") + testValidation("baz") + } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index 6d0a734f381ee..bfd56613fd64c 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -121,30 +121,41 @@ class AvroSerdeSuite extends SparkFunSuite { } } - test("Fail to convert for serialization with field count mismatch") { - // Note that this is allowed for deserialization, but not serialization - val tooManyFields = - createAvroSchemaWithTopLevelFields(_.optionalInt("foo").optionalLong("bar")) - assertFailedConversionMessage(tooManyFields, Serializer, BY_NAME, + test("Fail to convert with missing Catalyst fields") { + val nestedFooField = 
SchemaBuilder.record("foo").fields().optionalInt("bar").endRecord() + val avroExtraOptional = createAvroSchemaWithTopLevelFields( + _.name("foo").`type`(nestedFooField).noDefault().optionalLong("bar")) + val avroExtraRequired = createAvroSchemaWithTopLevelFields( + _.name("foo").`type`(nestedFooField).noDefault().requiredLong("bar")) + + // serializing with extra _nullable_ Avro field is okay, but fails if extra field is required + withFieldMatchType(Serializer.create(CATALYST_STRUCT, avroExtraOptional, _)) + assertFailedConversionMessage(avroExtraRequired, Serializer, BY_NAME, "Found field 'bar' in Avro schema but there is no match in the SQL schema") - assertFailedConversionMessage(tooManyFields, Serializer, BY_POSITION, + assertFailedConversionMessage(avroExtraRequired, Serializer, BY_POSITION, "Found field 'bar' at position 1 of top-level record from Avro schema but there is no " + "match in the SQL schema at top-level record (using positional matching)") - val tooManyFieldsNested = + // deserializing should work regardless of whether the extra field is required or not + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, avroExtraOptional, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, avroExtraRequired, _)) + + val avroExtraNestedOptional = createNestedAvroSchemaWithFields("foo", _.optionalInt("bar").optionalInt("baz")) - assertFailedConversionMessage(tooManyFieldsNested, Serializer, BY_NAME, + val avroExtraNestedRequired = + createNestedAvroSchemaWithFields("foo", _.optionalInt("bar").requiredInt("baz")) + + // serializing with extra _nullable_ Avro field is okay, but fails if extra field is required + withFieldMatchType(Serializer.create(CATALYST_STRUCT, avroExtraNestedOptional, _)) + assertFailedConversionMessage(avroExtraNestedRequired, Serializer, BY_NAME, "Found field 'foo.baz' in Avro schema but there is no match in the SQL schema") - assertFailedConversionMessage(tooManyFieldsNested, Serializer, BY_POSITION, + assertFailedConversionMessage(avroExtraNestedRequired, Serializer, BY_POSITION, s"Found field 'baz' at position 1 of field 'foo' from Avro schema but there is no match " + s"in the SQL schema at field 'foo' (using positional matching)") - val tooFewFields = createAvroSchemaWithTopLevelFields(f => f) - assertFailedConversionMessage(tooFewFields, Serializer, BY_NAME, - "Cannot find field 'foo' in Avro schema") - assertFailedConversionMessage(tooFewFields, Serializer, BY_POSITION, - "Cannot find field at position 0 of top-level record from Avro schema " + - "(using positional matching)") + // deserializing should work regardless of whether the extra field is required or not + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, avroExtraNestedOptional, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, avroExtraNestedRequired, _)) } /** diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index d85baeb9386f2..a70fbc0d833e8 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -68,6 +68,8 @@ abstract class AvroSuite override protected def beforeAll(): Unit = { super.beforeAll() + // initialize SessionCatalog here so it has a clean hadoopConf + spark.sessionState.catalog spark.conf.set(SQLConf.FILES_MAX_PARTITION_BYTES.key, 1024) } @@ -1359,6 +1361,32 @@ abstract class AvroSuite } } + test("SPARK-34378: support writing user provided 
avro schema with missing optional fields") { + withTempDir { tempDir => + val avroSchema = SchemaBuilder.builder().record("test").fields() + .requiredString("f1").optionalString("f2").endRecord().toString() + + val data = Seq("foo", "bar") + + // Fail if required field f1 is missing + val e = intercept[SparkException] { + data.toDF("f2").write.option("avroSchema", avroSchema).format("avro").save(s"$tempDir/fail") + } + assertExceptionMsg[IncompatibleSchemaException](e, + "Found field 'f1' in Avro schema but there is no match in the SQL schema") + + val tempSaveDir = s"$tempDir/save/" + // Succeed if optional field f2 is missing + data.toDF("f1").write.option("avroSchema", avroSchema).format("avro").save(tempSaveDir) + + val newDf = spark.read.format("avro").load(tempSaveDir) + assert(newDf.schema === new StructType().add("f1", StringType).add("f2", StringType)) + val rows = newDf.collect() + assert(rows.map(_.getAs[String]("f1")).sorted === data.sorted) + rows.foreach(row => assert(row.isNullAt(1))) + } + } + test("SPARK-34133: Reading user provided schema respects case sensitivity for field matching") { val wrongCaseSchema = new StructType() .add("STRING", StringType, nullable = false) @@ -2301,9 +2329,10 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { } assert(filterCondition.isDefined) // The partitions filters should be pushed down and no need to be reevaluated. - assert(filterCondition.get.collectFirst { - case a: AttributeReference if a.name == "p1" || a.name == "p2" => a - }.isEmpty) + assert(!filterCondition.get.exists { + case a: AttributeReference => a.name == "p1" || a.name == "p2" + case _ => false + }) val fileScan = df.queryExecution.executedPlan collectFirst { case BatchScanExec(_, f: AvroScan, _) => f diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index bb39e5fde6d08..e3070f462c1ff 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -162,5 +162,10 @@ mssql-jdbc test + + mysql + mysql-connector-java + test + diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala index 59eb49dc303df..6cee6622e1c1f 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -23,7 +23,7 @@ import java.util.Properties import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.types.{BooleanType, ByteType, ShortType, StructType} import org.apache.spark.tags.DockerTest @@ -198,4 +198,23 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { """.stripMargin.replaceAll("\n", " ")) assert(sql("select x, y from queryOption").collect.toSet == expectedResult) } + + test("SPARK-30062") { + val expectedResult = Set( + (42, "fred"), + (17, "dave") + ).map { case (x, y) => + Row(Integer.valueOf(x), String.valueOf(y)) + } + val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) + for (_ <- 0 to 2) { + df.write.mode(SaveMode.Append).jdbc(jdbcUrl, "tblcopy", new Properties) + } + assert(sqlContext.read.jdbc(jdbcUrl, "tblcopy", new Properties).count === 6) + 
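The Avro changes above (the rename to validateNoExtraRequiredAvroFields, the new isNullable helper, and the SPARK-34378 write test) all hinge on how Avro encodes optionality: an optional field is a UNION that contains NULL, while a required field is a plain type. A small self-contained sketch of that distinction, using the same SchemaBuilder calls the tests use; the object and method names here are illustrative, not part of the patch.

import scala.collection.JavaConverters._

import org.apache.avro.{Schema, SchemaBuilder}

object AvroNullabilitySketch {
  // Same shape as the helper added to AvroUtils: a field is nullable iff its
  // schema is a UNION with NULL as one of the branches.
  private def isNullable(field: Schema.Field): Boolean =
    field.schema().getType == Schema.Type.UNION &&
      field.schema().getTypes.asScala.exists(_.getType == Schema.Type.NULL)

  def main(args: Array[String]): Unit = {
    val schema = SchemaBuilder.record("test").fields()
      .requiredString("f1")   // plain STRING: an unmatched extra field still fails serialization
      .optionalString("f2")   // UNION(NULL, STRING): ignored by validateNoExtraRequiredAvroFields
      .endRecord()

    println(isNullable(schema.getField("f1")))  // false
    println(isNullable(schema.getField("f2")))  // true
  }
}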
df.write.mode(SaveMode.Overwrite).option("truncate", true) + .jdbc(jdbcUrl, "tblcopy", new Properties) + val actual = sqlContext.read.jdbc(jdbcUrl, "tblcopy", new Properties).collect + assert(actual.length === 2) + assert(actual.toSet === expectedResult) + } } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index 5ac9a5191b010..4b2bbbdd8494c 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection +import java.util.Locale import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -36,8 +37,9 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "db2" + override val namespaceOpt: Option[String] = Some("DB2INST1") override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("DB2_DOCKER_IMAGE_NAME", "ibmcom/db2:11.5.6.0a") override val env = Map( @@ -59,8 +61,13 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.db2", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.db2.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.db2.pushDownAggregate", "true") - override def dataPreparation(conn: Connection): Unit = {} + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)") + .executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -86,4 +93,17 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) assert(t.schema === expectedSchema) } + + override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) + + testVarPop() + testVarPop(true) + testVarSamp() + testVarSamp(true) + testStddevPop() + testStddevPop(true) + testStddevSamp() + testStddevSamp(true) + testCovarPop() + testCovarSamp() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala new file mode 100644 index 0000000000000..f0e98fc2722b0 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * To run this test suite for a specific version (e.g., ibmcom/db2:11.5.6.0a): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.6.0a + * ./build/sbt -Pdocker-integration-tests "testOnly *v2.DB2NamespaceSuite" + * }}} + */ +@DockerTest +class DB2NamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("DB2_DOCKER_IMAGE_NAME", "ibmcom/db2:11.5.6.0a") + override val env = Map( + "DB2INST1_PASSWORD" -> "rootpass", + "LICENSE" -> "accept", + "DBNAME" -> "db2foo", + "ARCHIVE_LOGS" -> "false", + "AUTOCONFIG" -> "false" + ) + override val usesIpc = false + override val jdbcPort: Int = 50000 + override val privileged = true + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:db2://$ip:$port/db2foo:user=db2inst1;password=rootpass;retrieveMessagesFromServerOnGetMessage=true;" //scalastyle:ignore + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "com.ibm.db2.jcc.DB2Driver").asJava) + + catalog.initialize("db2", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("NULLID"), Array("SQLJ"), Array("SYSCAT"), Array("SYSFUN"), + Array("SYSIBM"), Array("SYSIBMADM"), Array("SYSIBMINTERNAL"), Array("SYSIBMTS"), + Array("SYSPROC"), Array("SYSPUBLIC"), Array("SYSSTAT"), Array("SYSTOOLS")) + + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + builtinNamespaces ++ Array(namespace) + } + + override val supportsDropSchemaCascade: Boolean = false + + testListNamespaces() + testDropNamespaces() +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala new file mode 100644 index 0000000000000..72edfc9f1bf1c --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import org.apache.spark.sql.jdbc.DockerJDBCIntegrationSuite + +abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite { + + /** + * Prepare databases and tables for testing. + */ + override def dataPreparation(connection: Connection): Unit = { + tablePreparation(connection) + connection.prepareStatement("INSERT INTO employee VALUES (1, 'amy', 10000, 1000)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (2, 'alex', 12000, 1200)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (1, 'cathy', 9000, 1200)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (2, 'david', 10000, 1300)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)") + .executeUpdate() + } + + def tablePreparation(connection: Connection): Unit +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 75446fb50e45b..a527c6f8cb5b6 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -37,7 +37,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "mssql" @@ -58,10 +58,15 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBC override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.mssql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.mssql.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.mssql.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) - override def dataPreparation(conn: Connection): Unit = {} + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)") + .executeUpdate() + } override def notSupportsTableComment: Boolean = true @@ -91,4 +96,13 @@ class 
MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBC assert(msg.contains("UpdateColumnNullability is not supported")) } + + testVarPop() + testVarPop(true) + testVarSamp() + testVarSamp(true) + testStddevPop() + testStddevPop(true) + testStddevSamp() + testStddevSamp(true) } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala new file mode 100644 index 0000000000000..aa8dac266380a --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * To run this test suite for a specific version (e.g., 2019-CU13-ubuntu-20.04): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04 + * ./build/sbt -Pdocker-integration-tests "testOnly *v2.MsSqlServerNamespaceSuite" + * }}} + */ +@DockerTest +class MsSqlServerNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", + "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") + override val env = Map( + "SA_PASSWORD" -> "Sapass123", + "ACCEPT_EULA" -> "Y" + ) + override val usesIpc = false + override val jdbcPort: Int = 1433 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "com.microsoft.sqlserver.jdbc.SQLServerDriver").asJava) + + catalog.initialize("mssql", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("db_accessadmin"), Array("db_backupoperator"), Array("db_datareader"), + Array("db_datawriter"), Array("db_ddladmin"), Array("db_denydatareader"), + Array("db_denydatawriter"), Array("db_owner"), Array("db_securityadmin"), Array("dbo"), + Array("guest"), Array("INFORMATION_SCHEMA"), Array("sys")) + + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + builtinNamespaces ++ Array(namespace) + } + + override val supportsSchemaComment: Boolean = false + 
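The new pushDownAggregate catalog option and the testVarPop()/testStddevPop() calls these suites register exercise DS v2 aggregate pushdown. A rough sketch of what that looks like outside the Docker harness; the URL and credentials are placeholders, and it assumes a reachable SQL Server with the employee table created as in tablePreparation above.

import org.apache.spark.sql.SparkSession

object AggregatePushdownSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("aggregate-pushdown-sketch")
      // Mirror the suite's sparkConf: register a JDBC catalog and enable aggregate pushdown.
      .config("spark.sql.catalog.mssql",
        "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
      .config("spark.sql.catalog.mssql.url",
        "jdbc:sqlserver://localhost:1433;user=sa;password=<placeholder>")
      .config("spark.sql.catalog.mssql.pushDownAggregate", "true")
      .getOrCreate()

    // With pushdown enabled, aggregates such as VAR_POP/STDDEV_SAMP can be computed by the
    // remote database; the plan's scan then reports them as pushed aggregates.
    val df = spark.sql("SELECT dept, VAR_POP(bonus) FROM mssql.employee GROUP BY dept")
    df.explain()
    spark.stop()
  }
}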
+ override val supportsDropSchemaCascade: Boolean = false + + testListNamespaces() + testDropNamespaces() +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 71adc51b87441..97f521a378eb7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -24,22 +24,19 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * * To run this test suite for a specific version (e.g., mysql:5.7.36): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.36 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLIntegrationSuite" - * * }}} - * */ @DockerTest -class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "mysql" override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") @@ -57,13 +54,17 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.mysql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.mysql.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.mysql.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) private var mySQLVersion = -1 - override def dataPreparation(conn: Connection): Unit = { - mySQLVersion = conn.getMetaData.getDatabaseMajorVersion + override def tablePreparation(connection: Connection): Unit = { + mySQLVersion = connection.getMetaData.getDatabaseMajorVersion + connection.prepareStatement( + "CREATE TABLE employee (dept INT, name VARCHAR(32), salary DECIMAL(20, 2)," + + " bonus DOUBLE)").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -119,4 +120,9 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def supportsIndex: Boolean = true override def indexOptions: String = "KEY_BLOCK_SIZE=10" + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala new file mode 100644 index 0000000000000..d8dee61d70ea6 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.{Connection, SQLFeatureNotSupportedException} + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.catalog.NamespaceChange +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * To run this test suite for a specific version (e.g., mysql:5.7.36): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.36 + * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLNamespaceSuite" + * }}} + */ +@DockerTest +class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val usesIpc = false + override val jdbcPort: Int = 3306 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/" + + s"mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true&useSSL=false" + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "com.mysql.jdbc.Driver").asJava) + + catalog.initialize("mysql", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("information_schema"), Array("mysql"), Array("performance_schema"), Array("sys")) + + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + Array(builtinNamespaces.head, namespace) ++ builtinNamespaces.tail + } + + override val supportsSchemaComment: Boolean = false + + override val supportsDropSchemaRestrict: Boolean = false + + testListNamespaces() + testDropNamespaces() + + test("Create or remove comment of namespace unsupported") { + val e1 = intercept[AnalysisException] { + catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) + } + assert(e1.getMessage.contains("Failed create name space: foo")) + assert(e1.getCause.isInstanceOf[SQLFeatureNotSupportedException]) + assert(e1.getCause.asInstanceOf[SQLFeatureNotSupportedException].getMessage + .contains("Create namespace comment is not supported")) + assert(catalog.namespaceExists(Array("foo")) === false) + catalog.createNamespace(Array("foo"), Map.empty[String, String].asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + val e2 = intercept[AnalysisException] { + catalog.alterNamespace(Array("foo"), NamespaceChange + .setProperty("comment", "comment for foo")) + } + assert(e2.getMessage.contains("Failed create comment on name space: foo")) + assert(e2.getCause.isInstanceOf[SQLFeatureNotSupportedException]) + 
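These namespace suites drive the JDBCTableCatalog API directly rather than going through SQL. A condensed, hedged sketch of that usage against an in-memory H2 database; the H2 URL and the catalog name are stand-ins for the dockerized engines and are not taken from the suites.

import java.util.Collections

import scala.collection.JavaConverters._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object JdbcNamespaceSketch {
  def main(args: Array[String]): Unit = {
    // The real suites run under SharedSparkSession; keep an active session here as well.
    val spark = SparkSession.builder()
      .master("local[*]").appName("jdbc-namespace-sketch").getOrCreate()

    val options = new CaseInsensitiveStringMap(Map(
      "url" -> "jdbc:h2:mem:namespace_sketch;DB_CLOSE_DELAY=-1",
      "driver" -> "org.h2.Driver").asJava)

    val catalog = new JDBCTableCatalog()
    catalog.initialize("h2", options)

    catalog.createNamespace(Array("foo"), Collections.emptyMap[String, String])
    println(catalog.namespaceExists(Array("foo")))                     // true
    catalog.listNamespaces().foreach(ns => println(ns.mkString(".")))  // new schema plus H2 built-ins
    catalog.dropNamespace(Array("foo"), cascade = false)
    println(catalog.namespaceExists(Array("foo")))                     // false

    spark.stop()
  }
}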
assert(e2.getCause.asInstanceOf[SQLFeatureNotSupportedException].getMessage + .contains("Create namespace comment is not supported")) + val e3 = intercept[AnalysisException] { + catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) + } + assert(e3.getMessage.contains("Failed remove comment on name space: foo")) + assert(e3.getCause.isInstanceOf[SQLFeatureNotSupportedException]) + assert(e3.getCause.asInstanceOf[SQLFeatureNotSupportedException].getMessage + .contains("Remove namespace comment is not supported")) + catalog.dropNamespace(Array("foo"), cascade = true) + assert(catalog.namespaceExists(Array("foo")) === false) + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index ef8fe5354c540..2669924dc28c0 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection +import java.util.Locale import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -54,8 +55,9 @@ import org.apache.spark.tags.DockerTest * This procedure has been validated with Oracle 18.4.0 Express Edition. 
*/ @DockerTest -class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "oracle" + override val namespaceOpt: Option[String] = Some("SYSTEM") override val db = new DatabaseOnDocker { lazy override val imageName = sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "gvenzl/oracle-xe:18.4.0") @@ -73,9 +75,15 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.oracle", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.oracle.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.oracle.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) - override def dataPreparation(conn: Connection): Unit = {} + + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," + + " bonus BINARY_DOUBLE)").executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -93,4 +101,14 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest assert(msg1.contains( s"Cannot update $catalogName.alt_table field ID: string cannot be cast to int")) } + + override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() + testCovarPop() + testCovarSamp() + testCorr() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala new file mode 100644 index 0000000000000..31f26d2990666 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * The following are the steps to test this: + * + * 1. 
Choose to use a prebuilt image or build Oracle database in a container + * - The documentation on how to build Oracle RDBMS in a container is at + * https://github.com/oracle/docker-images/blob/master/OracleDatabase/SingleInstance/README.md + * - Official Oracle container images can be found at https://container-registry.oracle.com + * - A trustable and streamlined Oracle XE database image can be found on Docker Hub at + * https://hub.docker.com/r/gvenzl/oracle-xe see also https://github.com/gvenzl/oci-oracle-xe + * 2. Run: export ORACLE_DOCKER_IMAGE_NAME=image_you_want_to_use_for_testing + * - Example: export ORACLE_DOCKER_IMAGE_NAME=gvenzl/oracle-xe:latest + * 3. Run: export ENABLE_DOCKER_INTEGRATION_TESTS=1 + * 4. Start docker: sudo service docker start + * - Optionally, docker pull $ORACLE_DOCKER_IMAGE_NAME + * 5. Run Spark integration tests for Oracle with: ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.v2.OracleNamespaceSuite" + * + * A sequence of commands to build the Oracle XE database container image: + * $ git clone https://github.com/oracle/docker-images.git + * $ cd docker-images/OracleDatabase/SingleInstance/dockerfiles + * $ ./buildContainerImage.sh -v 18.4.0 -x + * $ export ORACLE_DOCKER_IMAGE_NAME=oracle/database:18.4.0-xe + * + * This procedure has been validated with Oracle 18.4.0 Express Edition. + */ +@DockerTest +class OracleNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + lazy override val imageName = + sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "gvenzl/oracle-xe:18.4.0") + val oracle_password = "Th1s1sThe0racle#Pass" + override val env = Map( + "ORACLE_PWD" -> oracle_password, // oracle images uses this + "ORACLE_PASSWORD" -> oracle_password // gvenzl/oracle-xe uses this + ) + override val usesIpc = false + override val jdbcPort: Int = 1521 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:oracle:thin:system/$oracle_password@//$ip:$port/xe" + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "oracle.jdbc.OracleDriver").asJava) + + catalog.initialize("system", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("ANONYMOUS"), Array("APEX_030200"), Array("APEX_PUBLIC_USER"), Array("APPQOSSYS"), + Array("BI"), Array("DIP"), Array("FLOWS_FILES"), Array("HR"), Array("OE"), Array("PM"), + Array("SCOTT"), Array("SH"), Array("SPATIAL_CSW_ADMIN_USR"), Array("SPATIAL_WFS_ADMIN_USR"), + Array("XS$NULL")) + + // Cannot create schema dynamically + // TODO testListNamespaces() + // TODO testDropNamespaces() +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 7fba6671ffe71..77ace3f3f4ea7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -22,7 +22,7 @@ import java.sql.Connection import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import 
org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -34,7 +34,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:14.0-alpine") @@ -51,8 +51,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes .set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort)) .set("spark.sql.catalog.postgresql.pushDownTableSample", "true") .set("spark.sql.catalog.postgresql.pushDownLimit", "true") + .set("spark.sql.catalog.postgresql.pushDownAggregate", "true") - override def dataPreparation(conn: Connection): Unit = {} + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," + + " bonus double precision)").executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -84,4 +89,19 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes override def supportsIndex: Boolean = true override def indexOptions: String = "FILLFACTOR=70" + + testVarPop() + testVarPop(true) + testVarSamp() + testVarSamp(true) + testStddevPop() + testStddevPop(true) + testStddevSamp() + testStddevSamp(true) + testCovarPop() + testCovarPop(true) + testCovarSamp() + testCovarSamp(true) + testCorr() + testCorr(true) } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index a7744d18433f1..33190103d6a9a 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -53,7 +53,9 @@ class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNames override def dataPreparation(conn: Connection): Unit = {} - override def builtinNamespaces: Array[Array[String]] = { + override def builtinNamespaces: Array[Array[String]] = Array(Array("information_schema"), Array("pg_catalog"), Array("public")) - } + + testListNamespaces() + testDropNamespaces() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala index 284b05c1cc120..bae0d7c361635 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala @@ -17,47 +17,117 @@ package org.apache.spark.sql.jdbc.v2 +import java.util +import java.util.Collections + import scala.collection.JavaConverters._ import org.apache.logging.log4j.Level import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.NamespaceChange +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException +import 
org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.apache.spark.tags.DockerTest @DockerTest private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerIntegrationFunSuite { val catalog = new JDBCTableCatalog() + private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + private val schema: StructType = new StructType() + .add("id", IntegerType) + .add("data", StringType) + def builtinNamespaces: Array[Array[String]] - test("listNamespaces: basic behavior") { - catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) - assert(catalog.listNamespaces() === Array(Array("foo")) ++ builtinNamespaces) - assert(catalog.listNamespaces(Array("foo")) === Array()) - assert(catalog.namespaceExists(Array("foo")) === true) - - val logAppender = new LogAppender("catalog comment") - withLogAppender(logAppender) { - catalog.alterNamespace(Array("foo"), NamespaceChange - .setProperty("comment", "comment for foo")) - catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) + def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + Array(namespace) ++ builtinNamespaces + } + + def supportsSchemaComment: Boolean = true + + def supportsDropSchemaCascade: Boolean = true + + def supportsDropSchemaRestrict: Boolean = true + + def testListNamespaces(): Unit = { + test("listNamespaces: basic behavior") { + val commentMap = if (supportsSchemaComment) { + Map("comment" -> "test comment") + } else { + Map.empty[String, String] + } + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.listNamespaces() === listNamespaces(Array("foo"))) + assert(catalog.listNamespaces(Array("foo")) === Array()) + assert(catalog.namespaceExists(Array("foo")) === true) + + if (supportsSchemaComment) { + val logAppender = new LogAppender("catalog comment") + withLogAppender(logAppender) { + catalog.alterNamespace(Array("foo"), NamespaceChange + .setProperty("comment", "comment for foo")) + catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) + } + val createCommentWarning = logAppender.loggingEvents + .filter(_.getLevel == Level.WARN) + .map(_.getMessage.getFormattedMessage) + .exists(_.contains("catalog comment")) + assert(createCommentWarning === false) + } + + if (supportsDropSchemaRestrict) { + catalog.dropNamespace(Array("foo"), cascade = false) + } else { + catalog.dropNamespace(Array("foo"), cascade = true) + } + assert(catalog.namespaceExists(Array("foo")) === false) + assert(catalog.listNamespaces() === builtinNamespaces) + val msg = intercept[AnalysisException] { + catalog.listNamespaces(Array("foo")) + }.getMessage + assert(msg.contains("Namespace 'foo' not found")) + } + } + + def testDropNamespaces(): Unit = { + test("Drop namespace") { + val ident1 = Identifier.of(Array("foo"), "tab") + // Drop empty namespace without cascade + val commentMap = if (supportsSchemaComment) { + Map("comment" -> "test comment") + } else { + Map.empty[String, String] + } + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + if (supportsDropSchemaRestrict) { + catalog.dropNamespace(Array("foo"), cascade = false) + } else { + 
catalog.dropNamespace(Array("foo"), cascade = true) + } + assert(catalog.namespaceExists(Array("foo")) === false) + + // Drop non empty namespace without cascade + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.createTable(ident1, schema, Array.empty, emptyProps) + if (supportsDropSchemaRestrict) { + intercept[NonEmptyNamespaceException] { + catalog.dropNamespace(Array("foo"), cascade = false) + } + } + + // Drop non empty namespace with cascade + if (supportsDropSchemaCascade) { + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.dropNamespace(Array("foo"), cascade = true) + assert(catalog.namespaceExists(Array("foo")) === false) + } } - val createCommentWarning = logAppender.loggingEvents - .filter(_.getLevel == Level.WARN) - .map(_.getMessage.getFormattedMessage) - .exists(_.contains("catalog comment")) - assert(createCommentWarning === false) - - catalog.dropNamespace(Array("foo")) - assert(catalog.namespaceExists(Array("foo")) === false) - assert(catalog.listNamespaces() === builtinNamespaces) - val msg = intercept[AnalysisException] { - catalog.listNamespaces(Array("foo")) - }.getMessage - assert(msg.contains("Namespace 'foo' not found")) } } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 49aa20387e38e..ebd5b844cbc9b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.jdbc.v2 import org.apache.logging.log4j.Level -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sample} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample} import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} import org.apache.spark.sql.connector.catalog.index.SupportsIndex +import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite import org.apache.spark.sql.test.SharedSparkSession @@ -36,6 +36,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu import testImplicits._ val catalogName: String + + val namespaceOpt: Option[String] = None + + private def catalogAndNamespace = + namespaceOpt.map(namespace => s"$catalogName.$namespace").getOrElse(catalogName) + // dialect specific update column type test def testUpdateColumnType(tbl: String): Unit @@ -246,22 +252,30 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu def supportsTableSample: Boolean = false - private def samplePushed(df: DataFrame): Boolean = { + private def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { val sample = df.queryExecution.optimizedPlan.collect { case s: Sample => s } - sample.isEmpty + if (pushed) { + assert(sample.isEmpty) + } else { + assert(sample.nonEmpty) + } } - private def filterPushed(df: DataFrame): Boolean = { + 
private def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { val filter = df.queryExecution.optimizedPlan.collect { case f: Filter => f } - filter.isEmpty + if (pushed) { + assert(filter.isEmpty) + } else { + assert(filter.nonEmpty) + } } private def limitPushed(df: DataFrame, limit: Int): Boolean = { - val filter = df.queryExecution.optimizedPlan.collect { + df.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => return v1.pushedDownOperators.limit == Some(limit) @@ -270,11 +284,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu false } - private def columnPruned(df: DataFrame, col: String): Boolean = { + private def checkColumnPruned(df: DataFrame, col: String): Unit = { val scan = df.queryExecution.optimizedPlan.collectFirst { case s: DataSourceV2ScanRelation => s }.get - scan.schema.names.sameElements(Seq(col)) + assert(scan.schema.names.sameElements(Seq(col))) } test("SPARK-37038: Test TABLESAMPLE") { @@ -286,37 +300,37 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu // sample push down + column pruning val df1 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + " REPEATABLE (12345)") - assert(samplePushed(df1)) - assert(columnPruned(df1, "col1")) + checkSamplePushed(df1) + checkColumnPruned(df1, "col1") assert(df1.collect().length < 10) // sample push down only val df2 = sql(s"SELECT * FROM $catalogName.new_table TABLESAMPLE (50 PERCENT)" + " REPEATABLE (12345)") - assert(samplePushed(df2)) + checkSamplePushed(df2) assert(df2.collect().length < 10) // sample(BUCKET ... OUT OF) push down + limit push down + column pruning val df3 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + " LIMIT 2") - assert(samplePushed(df3)) + checkSamplePushed(df3) assert(limitPushed(df3, 2)) - assert(columnPruned(df3, "col1")) + checkColumnPruned(df3, "col1") assert(df3.collect().length <= 2) // sample(... 
PERCENT) push down + limit push down + column pruning val df4 = sql(s"SELECT col1 FROM $catalogName.new_table" + " TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 2") - assert(samplePushed(df4)) + checkSamplePushed(df4) assert(limitPushed(df4, 2)) - assert(columnPruned(df4, "col1")) + checkColumnPruned(df4, "col1") assert(df4.collect().length <= 2) // sample push down + filter push down + limit push down val df5 = sql(s"SELECT * FROM $catalogName.new_table" + " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") - assert(samplePushed(df5)) - assert(filterPushed(df5)) + checkSamplePushed(df5) + checkFilterPushed(df5) assert(limitPushed(df5, 2)) assert(df5.collect().length <= 2) @@ -325,27 +339,168 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu // Todo: push down filter/limit val df6 = sql(s"SELECT col1 FROM $catalogName.new_table" + " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") - assert(samplePushed(df6)) - assert(!filterPushed(df6)) + checkSamplePushed(df6) + checkFilterPushed(df6, false) assert(!limitPushed(df6, 2)) - assert(columnPruned(df6, "col1")) + checkColumnPruned(df6, "col1") assert(df6.collect().length <= 2) // sample + limit // Push down order is sample -> filter -> limit // only limit is pushed down because in this test sample is after limit val df7 = spark.read.table(s"$catalogName.new_table").limit(2).sample(0.5) - assert(!samplePushed(df7)) + checkSamplePushed(df7, false) assert(limitPushed(df7, 2)) // sample + filter // Push down order is sample -> filter -> limit // only filter is pushed down because in this test sample is after filter val df8 = spark.read.table(s"$catalogName.new_table").where($"col1" > 1).sample(0.5) - assert(!samplePushed(df8)) - assert(filterPushed(df8)) + checkSamplePushed(df8, false) + checkFilterPushed(df8) assert(df8.collect().length < 10) } } } + + protected def checkAggregateRemoved(df: DataFrame): Unit = { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg + } + assert(aggregates.isEmpty) + } + + private def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, _) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions.length == 1) + assert(aggregationExpressions(0).isInstanceOf[GeneralAggregateFunc]) + assert(aggregationExpressions(0).asInstanceOf[GeneralAggregateFunc].name() == funcName) + } + } + + protected def caseConvert(tableName: String): String = tableName + + protected def testVarPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: VAR_POP with distinct: $isDistinct") { + val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "VAR_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 10000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testVarSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: VAR_SAMP with distinct: $isDistinct") { + val df = sql( + s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "VAR_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 20000d) + assert(row(1).getDouble(0) === 5000d) + assert(row(2).isNullAt(0)) + } + } + + protected def testStddevPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: STDDEV_POP with distinct: $isDistinct") { + val df = sql( + s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "STDDEV_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 100d) + assert(row(1).getDouble(0) === 50d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testStddevSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: STDDEV_SAMP with distinct: $isDistinct") { + val df = sql( + s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "STDDEV_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 141.4213562373095d) + assert(row(1).getDouble(0) === 70.71067811865476d) + assert(row(2).isNullAt(0)) + } + } + + protected def testCovarPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: COVAR_POP with distinct: $isDistinct") { + val df = sql( + s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "COVAR_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 10000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testCovarSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: COVAR_SAMP with distinct: $isDistinct") { + val df = sql( + s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "COVAR_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 20000d) + assert(row(1).getDouble(0) === 5000d) + assert(row(2).isNullAt(0)) + } + } + + protected def testCorr(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: CORR with distinct: $isDistinct") { + val df = sql( + s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "CORR") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 1d) + assert(row(1).getDouble(0) === 1d) + assert(row(2).isNullAt(0)) + } + } } diff --git a/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 2f9e9fc0396d5..e096f120b8926 100644 --- a/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.kafka010.KafkaSourceProvider diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 3b73896d631c6..77bc658a1ef20 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} import org.apache.spark.sql.connector.read.streaming._ import org.apache.spark.sql.kafka010.KafkaSourceProvider._ +import org.apache.spark.sql.kafka010.MockedSystemClock.currentMockSystemTime import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.util.{UninterruptibleThread, Utils} +import org.apache.spark.util.{Clock, ManualClock, SystemClock, UninterruptibleThread, Utils} /** * A [[MicroBatchStream]] that reads data from Kafka. 
@@ -57,7 +58,7 @@ private[kafka010] class KafkaMicroBatchStream( metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends SupportsAdmissionControl with ReportsSourceMetrics with MicroBatchStream with Logging { + extends SupportsTriggerAvailableNow with ReportsSourceMetrics with MicroBatchStream with Logging { private[kafka010] val pollTimeoutMs = options.getLong( KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, @@ -73,6 +74,13 @@ private[kafka010] class KafkaMicroBatchStream( Utils.timeStringAsMs(Option(options.get( KafkaSourceProvider.MAX_TRIGGER_DELAY)).getOrElse(DEFAULT_MAX_TRIGGER_DELAY)) + // this allows us to mock system clock for testing purposes + private[kafka010] val clock: Clock = if (options.containsKey(MOCK_SYSTEM_TIME)) { + new MockedSystemClock + } else { + new SystemClock + } + private var lastTriggerMillis = 0L private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) @@ -81,6 +89,8 @@ private[kafka010] class KafkaMicroBatchStream( private var latestPartitionOffsets: PartitionOffsetMap = _ + private var allDataForTriggerAvailableNow: PartitionOffsetMap = _ + /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running @@ -98,7 +108,8 @@ private[kafka010] class KafkaMicroBatchStream( } else if (minOffsetPerTrigger.isDefined) { ReadLimit.minRows(minOffsetPerTrigger.get, maxTriggerDelayMs) } else { - maxOffsetsPerTrigger.map(ReadLimit.maxRows).getOrElse(super.getDefaultReadLimit) + // TODO (SPARK-37973) Directly call super.getDefaultReadLimit when scala issue 12523 is fixed + maxOffsetsPerTrigger.map(ReadLimit.maxRows).getOrElse(ReadLimit.allAvailable()) } } @@ -113,7 +124,13 @@ private[kafka010] class KafkaMicroBatchStream( override def latestOffset(start: Offset, readLimit: ReadLimit): Offset = { val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets - latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets(Some(startPartitionOffsets)) + + // Use the pre-fetched list of partition offsets when Trigger.AvailableNow is enabled. + latestPartitionOffsets = if (allDataForTriggerAvailableNow != null) { + allDataForTriggerAvailableNow + } else { + kafkaOffsetReader.fetchLatestOffsets(Some(startPartitionOffsets)) + } val limits: Seq[ReadLimit] = readLimit match { case rows: CompositeReadLimit => rows.getReadLimits @@ -157,9 +174,9 @@ private[kafka010] class KafkaMicroBatchStream( currentOffsets: Map[TopicPartition, Long], maxTriggerDelayMs: Long): Boolean = { // Checking first if the maxbatchDelay time has passed - if ((System.currentTimeMillis() - lastTriggerMillis) >= maxTriggerDelayMs) { + if ((clock.getTimeMillis() - lastTriggerMillis) >= maxTriggerDelayMs) { logDebug("Maximum wait time is passed, triggering batch") - lastTriggerMillis = System.currentTimeMillis() + lastTriggerMillis = clock.getTimeMillis() false } else { val newRecords = latestOffsets.flatMap { @@ -167,7 +184,7 @@ private[kafka010] class KafkaMicroBatchStream( Some(topic -> (offset - currentOffsets.getOrElse(topic, 0L))) }.values.sum.toDouble if (newRecords < minLimit) true else { - lastTriggerMillis = System.currentTimeMillis() + lastTriggerMillis = clock.getTimeMillis() false } } @@ -298,6 +315,11 @@ private[kafka010] class KafkaMicroBatchStream( logWarning(message + s". 
$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") } } + + override def prepareForTriggerAvailableNow(): Unit = { + allDataForTriggerAvailableNow = kafkaOffsetReader.fetchLatestOffsets( + Some(getOrCreateInitialPartitionOffsets())) + } } object KafkaMicroBatchStream extends Logging { @@ -333,3 +355,24 @@ object KafkaMicroBatchStream extends Logging { ju.Collections.emptyMap() } } + +/** + * To return a mocked system clock for testing purposes + */ +private[kafka010] class MockedSystemClock extends ManualClock { + override def getTimeMillis(): Long = { + currentMockSystemTime + } +} + +private[kafka010] object MockedSystemClock { + var currentMockSystemTime = 0L + + def advanceCurrentSystemTime(advanceByMillis: Long): Unit = { + currentMockSystemTime += advanceByMillis + } + + def reset(): Unit = { + currentMockSystemTime = 0L + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 87cef02d0d8f2..c82fda85eb4e8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.read.streaming.{Offset => _, _} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * A [[Source]] that reads data from Kafka using the following design. @@ -77,7 +77,7 @@ private[kafka010] class KafkaSource( metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends SupportsAdmissionControl with Source with Logging { + extends SupportsTriggerAvailableNow with Source with Logging { private val sc = sqlContext.sparkContext @@ -94,11 +94,20 @@ private[kafka010] class KafkaSource( private[kafka010] val maxTriggerDelayMs = Utils.timeStringAsMs(sourceOptions.get(MAX_TRIGGER_DELAY).getOrElse(DEFAULT_MAX_TRIGGER_DELAY)) + // this allows us to mock system clock for testing purposes + private[kafka010] val clock: Clock = if (sourceOptions.contains(MOCK_SYSTEM_TIME)) { + new MockedSystemClock + } else { + new SystemClock + } + private val includeHeaders = sourceOptions.getOrElse(INCLUDE_HEADERS, "false").toBoolean private var lastTriggerMillis = 0L + private var allDataForTriggerAvailableNow: PartitionOffsetMap = _ + /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running @@ -130,7 +139,8 @@ private[kafka010] class KafkaSource( } else if (minOffsetPerTrigger.isDefined) { ReadLimit.minRows(minOffsetPerTrigger.get, maxTriggerDelayMs) } else { - maxOffsetsPerTrigger.map(ReadLimit.maxRows).getOrElse(super.getDefaultReadLimit) + // TODO (SPARK-37973) Directly call super.getDefaultReadLimit when scala issue 12523 is fixed + maxOffsetsPerTrigger.map(ReadLimit.maxRows).getOrElse(ReadLimit.allAvailable()) } } @@ -159,7 +169,14 @@ private[kafka010] class KafkaSource( // Make sure initialPartitionOffsets is initialized initialPartitionOffsets val currentOffsets = currentPartitionOffsets.orElse(Some(initialPartitionOffsets)) - val latest = kafkaReader.fetchLatestOffsets(currentOffsets) + + // Use the pre-fetched list of partition offsets when Trigger.AvailableNow is enabled. 
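As a usage-level sketch of what this pre-fetch enables (illustrative only, not part of the patch; the end-to-end coverage is the new "Trigger.AvailableNow" test added to KafkaMicroBatchSourceSuite further down in this diff): a query started with Trigger.AvailableNow snapshots the latest offsets once, drains everything up to that snapshot in rate-limited micro-batches, and then terminates on its own. The sketch assumes an active SparkSession named `spark`; the broker address and topic name are placeholders.

    import org.apache.spark.sql.{Dataset, Row}
    import org.apache.spark.sql.streaming.Trigger

    val topic = "events"                                      // hypothetical topic with existing data
    val df = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092")    // placeholder broker address
      .option("subscribe", topic)
      .option("startingOffsets", "earliest")
      .option("maxOffsetsPerTrigger", 5)                      // read limits still apply per micro-batch
      .load()

    val query = df.writeStream
      .foreachBatch((_: Dataset[Row], _: Long) => ())         // no-op sink, enough for the sketch
      .trigger(Trigger.AvailableNow)
      .start()
    query.awaitTermination()                                  // returns once the pre-fetched offsets are drained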
+ val latest = if (allDataForTriggerAvailableNow != null) { + allDataForTriggerAvailableNow + } else { + kafkaReader.fetchLatestOffsets(currentOffsets) + } + latestPartitionOffsets = Some(latest) val limits: Seq[ReadLimit] = limit match { @@ -206,9 +223,9 @@ private[kafka010] class KafkaSource( currentOffsets: Map[TopicPartition, Long], maxTriggerDelayMs: Long): Boolean = { // Checking first if the maxbatchDelay time has passed - if ((System.currentTimeMillis() - lastTriggerMillis) >= maxTriggerDelayMs) { + if ((clock.getTimeMillis() - lastTriggerMillis) >= maxTriggerDelayMs) { logDebug("Maximum wait time is passed, triggering batch") - lastTriggerMillis = System.currentTimeMillis() + lastTriggerMillis = clock.getTimeMillis() false } else { val newRecords = latestOffsets.flatMap { @@ -216,7 +233,7 @@ private[kafka010] class KafkaSource( Some(topic -> (offset - currentOffsets.getOrElse(topic, 0L))) }.values.sum.toDouble if (newRecords < minLimit) true else { - lastTriggerMillis = System.currentTimeMillis() + lastTriggerMillis = clock.getTimeMillis() false } } @@ -331,6 +348,10 @@ private[kafka010] class KafkaSource( logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") } } + + override def prepareForTriggerAvailableNow(): Unit = { + allDataForTriggerAvailableNow = kafkaReader.fetchLatestOffsets(Some(initialPartitionOffsets)) + } } /** Companion object for the [[KafkaSource]]. */ diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 640996da67bca..de78992533b22 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -562,6 +562,8 @@ private[kafka010] object KafkaSourceProvider extends Logging { "startingoffsetsbytimestampstrategy" private val GROUP_ID_PREFIX = "groupidprefix" private[kafka010] val INCLUDE_HEADERS = "includeheaders" + // This is only for internal testing and should not be used otherwise. 
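A minimal sketch of how this internal hook is meant to be exercised, mirroring the updated "compositeReadLimit" test later in this diff. Only the presence of the option matters (its value is ignored), and `spark`, `testUtils`, and `topic` are assumed test-scope values.

    // Opting in swaps SystemClock for MockedSystemClock, making maxTriggerDelay checks deterministic.
    MockedSystemClock.reset()
    val reader = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
      .option("subscribe", topic)
      .option("startingOffsets", "earliest")
      .option("maxOffsetsPerTrigger", 20)
      .option("_mockSystemTime", "")                  // enable the mocked clock
    // Instead of Thread.sleep, a test advances the mocked time to let maxTriggerDelay expire:
    MockedSystemClock.advanceCurrentSystemTime(5000)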
+ private[kafka010] val MOCK_SYSTEM_TIME = "_mockSystemTime" private[kafka010] object StrategyOnNoMatchStartingOffset extends Enumeration { val ERROR, LATEST = Value @@ -726,7 +728,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { parameters .keySet .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } + .map { k => k.drop(6) -> parameters(k) } .toMap } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index f61696f6485e6..db71f0fd9184a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.streaming.{StreamingQuery, StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -195,6 +195,45 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { true } + test("Trigger.AvailableNow") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + testUtils.sendMessages(topic, (0 until 15).map { case x => + s"foo-$x" + }.toArray, Some(0)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 5) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + + var index: Int = 0 + def startTriggerAvailableNowQuery(): StreamingQuery = { + reader.writeStream + .foreachBatch((_: Dataset[Row], _: Long) => { + index += 1 + }) + .trigger(Trigger.AvailableNow) + .start() + } + + val query = startTriggerAvailableNowQuery() + try { + assert(query.awaitTermination(streamingTimeout.toMillis)) + } finally { + query.stop() + } + + // should have 3 batches now i.e. 
15 / 5 = 3 + assert(index == 3) + } + test("(de)serialization of initial offsets") { val topic = newTopic() testUtils.createTopic(topic, partitions = 5) @@ -419,6 +458,8 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { } test("compositeReadLimit") { + MockedSystemClock.reset() + val topic = newTopic() testUtils.createTopic(topic, partitions = 3) testUtils.sendMessages(topic, (100 to 120).map(_.toString).toArray, Some(0)) @@ -435,6 +476,9 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { .option("maxOffsetsPerTrigger", 20) .option("subscribe", topic) .option("startingOffsets", "earliest") + // mock system time to ensure deterministic behavior + // in determining if maxOffsetsPerTrigger is satisfied + .option("_mockSystemTime", "") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] @@ -442,6 +486,10 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { val clock = new StreamManualClock + def advanceSystemClock(mills: Long): ExternalAction = () => { + MockedSystemClock.advanceCurrentSystemTime(mills) + } + testStream(mapped)( StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed(clock), @@ -453,6 +501,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { // No data is processed for next batch as data is less than minOffsetsPerTrigger // and maxTriggerDelay is not expired AdvanceManualClock(100), + advanceSystemClock(100), waitUntilBatchProcessed(clock), CheckNewAnswer(), Assert { @@ -462,6 +511,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { true }, AdvanceManualClock(100), + advanceSystemClock(100), waitUntilBatchProcessed(clock), // Running batch now as number of new records is greater than minOffsetsPerTrigger // but reading limited data as per maxOffsetsPerTrigger @@ -473,14 +523,11 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { // Testing maxTriggerDelay // No data is processed for next batch till maxTriggerDelay is expired AdvanceManualClock(100), + advanceSystemClock(100), waitUntilBatchProcessed(clock), CheckNewAnswer(), - // Sleeping for 5s to let maxTriggerDelay expire - Assert { - Thread.sleep(5 * 1000) - true - }, AdvanceManualClock(100), + advanceSystemClock(5000), // Running batch as maxTriggerDelay is expired waitUntilBatchProcessed(clock), CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, @@ -1369,10 +1416,10 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { testStream(kafka)( makeSureGetOffsetCalled, AssertOnQuery { query => - query.logicalPlan.find { + query.logicalPlan.exists { case r: StreamingDataSourceV2Relation => r.stream.isInstanceOf[KafkaMicroBatchStream] case _ => false - }.isDefined + } } ) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 058563dfa167d..c5d2a99d156f8 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -44,6 +44,7 @@ import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.security.auth.SecurityProtocol.{PLAINTEXT, SASL_PLAINTEXT} import org.apache.kafka.common.serialization.StringSerializer import 
org.apache.kafka.common.utils.SystemTime +import org.apache.zookeeper.client.ZKClientConfig import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.apache.zookeeper.server.auth.SASLAuthenticationProvider import org.scalatest.Assertions._ @@ -266,7 +267,7 @@ class KafkaTestUtils( // Get the actual zookeeper binding port zkPort = zookeeper.actualPort zkClient = KafkaZkClient(s"$zkHost:$zkPort", isSecure = false, zkSessionTimeout, - zkConnectionTimeout, 1, new SystemTime()) + zkConnectionTimeout, 1, new SystemTime(), "test", new ZKClientConfig) zkReady = true } @@ -488,9 +489,7 @@ class KafkaTestUtils( protected def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") - props.put("host.name", "127.0.0.1") - props.put("advertised.host.name", "127.0.0.1") - props.put("port", brokerPort.toString) + props.put("listeners", s"PLAINTEXT://127.0.0.1:$brokerPort") props.put("log.dir", Utils.createTempDir().getAbsolutePath) props.put("zookeeper.connect", zkAddress) props.put("zookeeper.connection.timeout.ms", "60000") diff --git a/external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider b/external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider index 34014016584de..ff1987503183f 100644 --- a/external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider +++ b/external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + org.apache.spark.kafka010.KafkaDelegationTokenProvider diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index b9ef16fb58cb9..9c57663b3d8ef 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -21,15 +21,17 @@ import java.{ util => ju } import java.io.File import scala.collection.JavaConverters._ +import scala.concurrent.duration._ import scala.util.Random -import kafka.log.{CleanerConfig, Log, LogCleaner, LogConfig, ProducerStateManager} +import kafka.log.{CleanerConfig, LogCleaner, LogConfig, UnifiedLog} import kafka.server.{BrokerTopicStats, LogDirFailureChannel} import kafka.utils.Pool import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} import org.apache.spark._ import org.apache.spark.scheduler.ExecutorCacheTaskLocation @@ -84,7 +86,7 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]): Unit = { val mockTime = new MockTime() - val logs = new Pool[TopicPartition, Log]() + val logs = new Pool[TopicPartition, UnifiedLog]() val logDir = kafkaTestUtils.brokerLogDir val dir = new File(logDir, topic + "-" + partition) dir.mkdirs() @@ -93,7 +95,7 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { logProps.put(LogConfig.MinCleanableDirtyRatioProp, java.lang.Float.valueOf(0.1f)) val logDirFailureChannel = new LogDirFailureChannel(1) val topicPartition = new TopicPartition(topic, partition) - val log = new Log( + val log = UnifiedLog( dir, LogConfig(logProps), 0L, @@ -103,9 +105,10 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { mockTime, Int.MaxValue, Int.MaxValue, - topicPartition, - new ProducerStateManager(topicPartition, dir), - logDirFailureChannel + logDirFailureChannel, + lastShutdownClean = false, + topicId = None, + keepPartitionMetadataFile = false ) messages.foreach { case (k, v) => val record = new SimpleRecord(k.getBytes, v.getBytes) @@ -201,6 +204,11 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { sc, kafkaParams, offsetRanges, preferredHosts ).map(m => m.key -> m.value) + // To make it sure that the compaction happens + eventually(timeout(20.second), interval(1.seconds)) { + val dir = new File(kafkaTestUtils.brokerLogDir, topic + "-0") + assert(dir.listFiles().exists(_.getName.endsWith(".deleted"))) + } val received = rdd.collect.toSet assert(received === compactedMessages.toSet) diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 0783e591def51..dd8d66f1fc08f 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -35,6 +35,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.network.ListenerName import 
org.apache.kafka.common.serialization.StringSerializer import org.apache.kafka.common.utils.{Time => KTime} +import org.apache.zookeeper.client.ZKClientConfig import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.apache.spark.{SparkConf, SparkException} @@ -106,7 +107,7 @@ private[kafka010] class KafkaTestUtils extends Logging { // Get the actual zookeeper binding port zkPort = zookeeper.actualPort zkClient = KafkaZkClient(s"$zkHost:$zkPort", isSecure = false, zkSessionTimeout, - zkConnectionTimeout, 1, KTime.SYSTEM) + zkConnectionTimeout, 1, KTime.SYSTEM, "test", new ZKClientConfig) admClient = new AdminZkClient(zkClient) zkReady = true } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 8564597f4f135..4a790878cf9dc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -21,7 +21,6 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.HashPartitioner import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.BytecodeUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -265,14 +264,6 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } } - /** Test whether the closure accesses the attribute with name `attrName`. */ - private def accessesVertexAttr(closure: AnyRef, attrName: String): Boolean = { - try { - BytecodeUtils.invokedMethod(closure, classOf[EdgeTriplet[VD, ED]], attrName) - } catch { - case _: ClassNotFoundException => true // if we don't know, be conservative - } - } } // end of class GraphImpl diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 976ce02e9ea8d..3ba96055ae05f 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -267,6 +267,13 @@ com.google.guava guava + + + org.jacoco + org.jacoco.agent + + 3.3.2 2.5.0 ${hadoop.version} 3.6.2 @@ -133,12 +134,12 @@ 2.3 - 2.8.1 + 3.1.0 10.14.2.0 1.12.2 - 1.7.2 - 9.4.43.v20210629 + 1.7.3 + 9.4.44.v20210927 4.0.3 0.10.0 2.5.0 @@ -147,7 +148,7 @@ If you changes codahale.metrics.version, you also need to change the link to metrics.dropwizard.io in docs/monitoring.md. 
--> - 4.2.2 + 4.2.7 1.11.0 1.12.0 @@ -157,17 +158,24 @@ 4.5.13 4.4.14 - 3.4.1 + 3.6.1 - 3.2.2 + 4.4 2.12.15 2.12 2.0.2 + + + 4.4.0 --test true 1.9.13 - 2.13.1 + 2.13.2 1.1.8.4 1.1.2 2.2.1 @@ -184,7 +192,7 @@ 14.0.1 3.0.16 2.34 - 2.10.12 + 2.10.13 3.5.2 3.0.0 0.12.0 @@ -196,15 +204,15 @@ 1.1.0 1.5.0 1.60 - 1.6.0 + 1.6.1 - 6.0.1 + 7.0.0 org.fusesource.leveldbjni - 5.10.2 + 5.12.1 ${java.home} @@ -393,7 +401,7 @@ org.scalatestplus - mockito-3-12_${scala.binary.version} + mockito-4-2_${scala.binary.version} test @@ -407,7 +415,7 @@ test - com.novocode + com.github.sbt junit-interface test @@ -584,7 +592,7 @@ org.apache.commons commons-text - 1.6 + 1.9 commons-lang @@ -612,8 +620,8 @@ ${commons.math3.version} - commons-collections - commons-collections + org.apache.commons + commons-collections4 ${commons.collections.version} @@ -737,13 +745,11 @@ org.apache.logging.log4j log4j-api ${log4j.version} - ${hadoop.deps.scope} org.apache.logging.log4j log4j-core ${log4j.version} - ${hadoop.deps.scope} @@ -758,7 +764,7 @@ com.ning compress-lzf - 1.0.3 + 1.1 org.xerial.snappy @@ -773,7 +779,7 @@ com.github.luben zstd-jni - 1.5.1-1 + 1.5.2-1 com.clearspring.analytics @@ -801,17 +807,12 @@ org.roaringbitmap RoaringBitmap - 0.9.23 - - - commons-net - commons-net - 3.1 + 0.9.25 io.netty netty-all - 4.1.72.Final + 4.1.74.Final io.netty @@ -1026,6 +1027,11 @@ + + com.chuusai + shapeless_${scala.binary.version} + 2.3.7 + org.json4s json4s-jackson_${scala.binary.version} @@ -1057,11 +1063,6 @@ scala-library ${scala.version} - - org.scala-lang - scala-actors - ${scala.version} - org.scala-lang.modules scala-parser-combinators_${scala.binary.version} @@ -1086,8 +1087,8 @@ org.scalatestplus - mockito-3-12_${scala.binary.version} - 3.2.10.0 + mockito-4-2_${scala.binary.version} + 3.2.11.0 test @@ -1099,13 +1100,13 @@ org.mockito mockito-core - 3.12.4 + 4.2.0 test org.mockito mockito-inline - 3.12.4 + 4.2.0 test @@ -1123,25 +1124,13 @@ junit junit - 4.13.1 - test - - - org.hamcrest - hamcrest-core - 1.3 - test - - - org.hamcrest - hamcrest-library - 1.3 + 4.13.2 test - com.novocode + com.github.sbt junit-interface - 0.11 + 0.13.3 test @@ -1176,7 +1165,7 @@ org.postgresql postgresql - 42.3.0 + 42.3.3 test @@ -2756,44 +2745,43 @@ - - org.codehaus.mojo - build-helper-maven-plugin - 3.2.0 - - - module-timestamp-property - validate - - timestamp-property - - - module.build.timestamp - ${maven.build.timestamp.format} - current - America/Los_Angeles - - - - local-timestamp-property - validate - - timestamp-property - - - local.build.timestamp - ${maven.build.timestamp.format} - build - America/Los_Angeles - - - - + + org.codehaus.mojo + build-helper-maven-plugin + 3.2.0 + + + module-timestamp-property + validate + + timestamp-property + + + module.build.timestamp + ${maven.build.timestamp.format} + current + America/Los_Angeles + + + + local-timestamp-property + validate + + timestamp-property + + + local.build.timestamp + ${maven.build.timestamp.format} + build + America/Los_Angeles + + + + net.alchim31.maven scala-maven-plugin - - 4.3.0 + ${scala-maven-plugin.version} eclipse-add-source @@ -3133,58 +3121,6 @@ - - - - org.eclipse.m2e - lifecycle-mapping - 1.0.0 - - - - - - org.apache.maven.plugins - maven-dependency-plugin - [2.8,) - - build-classpath - - - - - - - - - org.apache.maven.plugins - maven-jar-plugin - 3.1.2 - - test-jar - - - - - - - - - org.apache.maven.plugins - maven-antrun-plugin - [${maven-antrun.version},) - - run - - - - - - - - - - @@ -3490,6 +3426,7 @@ hadoop-2 + 2.7.4 2.7.1 2.4 @@ -3499,6 +3436,8 
@@ hadoop-client hadoop-yarn-api hadoop-client + + 4.3.0 @@ -3564,9 +3503,9 @@ scala-2.12 - 2.12.15 @@ -3581,7 +3520,7 @@ scala-2.13 - 2.13.7 + 2.13.8 2.13 @@ -3758,5 +3697,72 @@ + + only-eclipse + + + + m2e.version + + + + + + + + + org.eclipse.m2e + lifecycle-mapping + 1.0.0 + + + + + + org.apache.maven.plugins + maven-dependency-plugin + [2.8,) + + build-classpath + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.1.2 + + test-jar + + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + [${maven-antrun.version},) + + run + + + + + + + + + + + + + + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b985f95b85c6d..b045d4615d3c4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -48,7 +48,12 @@ object MimaExcludes { // [SPARK-37780][SQL] QueryExecutionListener support SQLConf as constructor parameter ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this"), // [SPARK-37786][SQL] StreamingQueryListener support use SQLConf.get to get corresponding SessionState's SQLConf - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.this") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.this"), + + // [SPARK-37600][BUILD] Upgrade to Hadoop 3.3.2 + ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4Compressor"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4Factory"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4SafeDecompressor") ) // Exclude rules for 3.2.x from 3.1.1 @@ -66,6 +71,12 @@ object MimaExcludes { ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.errors.*"), + // DSv2 catalog and expression APIs are unstable yet. We should enable this back. + ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.catalog.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.expressions.*"), + // Avro source implementation is internal. 
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.v2.avro.*"), // [SPARK-34848][CORE] Add duration to TaskMetricDistributions ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4130c6a1c73d6..b536b50532a05 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -376,8 +376,8 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( - spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, tokenProviderKafka010, sqlKafka010, kvstore, avro + spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn, + unsafe, tags, tokenProviderKafka010, sqlKafka010 ).contains(x) } @@ -421,6 +421,11 @@ object SparkBuild extends PomBuild { // SPARK-14738 - Remove docker tests from main Spark build // enable(DockerIntegrationTests.settings)(dockerIntegrationTests) + if (!profiles.contains("volcano")) { + enable(Volcano.settings)(kubernetes) + enable(Volcano.settings)(kubernetesIntegrationTests) + } + enable(KubernetesIntegrationTests.settings)(kubernetesIntegrationTests) enable(YARN.settings)(yarn) @@ -604,8 +609,8 @@ object DockerIntegrationTests { } /** - * These settings run a hardcoded configuration of the Kubernetes integration tests using - * minikube. Docker images will have the "dev" tag, and will be overwritten every time the + * These settings run the Kubernetes integration tests. + * Docker images will have the "dev" tag, and will be overwritten every time the * integration tests are run. The integration tests are actually bound to the "test" phase, * so running "test" on this module will run the integration tests. * @@ -623,8 +628,10 @@ object KubernetesIntegrationTests { val dockerBuild = TaskKey[Unit]("docker-imgs", "Build the docker images for ITs.") val runITs = TaskKey[Unit]("run-its", "Only run ITs, skip image build.") - val imageTag = settingKey[String]("Tag to use for images built during the test.") - val namespace = settingKey[String]("Namespace where to run pods.") + val imageRepo = sys.props.getOrElse("spark.kubernetes.test.imageRepo", "docker.io/kubespark") + val imageTag = sys.props.get("spark.kubernetes.test.imageTag") + val namespace = sys.props.get("spark.kubernetes.test.namespace") + val deployMode = sys.props.get("spark.kubernetes.test.deployMode") // Hack: this variable is used to control whether to build docker images. 
It's updated by // the tasks below in a non-obvious way, so that you get the functionality described in @@ -632,29 +639,41 @@ object KubernetesIntegrationTests { private var shouldBuildImage = true lazy val settings = Seq( - imageTag := "dev", - namespace := "default", dockerBuild := { if (shouldBuildImage) { val dockerTool = s"$sparkHome/bin/docker-image-tool.sh" val bindingsDir = s"$sparkHome/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings" - val dockerFile = sys.props.get("spark.kubernetes.test.dockerFile") - val javaImageTag = sys.props.getOrElse("spark.kubernetes.test.javaImageTag", "8-jre-slim") - val extraOptions = if (dockerFile.isDefined) { - Seq("-f", s"${dockerFile.get}") - } else { + val javaImageTag = sys.props.get("spark.kubernetes.test.javaImageTag") + val dockerFile = sys.props.getOrElse("spark.kubernetes.test.dockerFile", + s"$sparkHome/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile.java17") + val pyDockerFile = sys.props.getOrElse("spark.kubernetes.test.pyDockerFile", + s"$bindingsDir/python/Dockerfile") + val rDockerFile = sys.props.getOrElse("spark.kubernetes.test.rDockerFile", + s"$bindingsDir/R/Dockerfile") + val extraOptions = if (javaImageTag.isDefined) { Seq("-b", s"java_image_tag=$javaImageTag") + } else { + Seq("-f", s"$dockerFile") } - val cmd = Seq(dockerTool, "-m", - "-t", imageTag.value, - "-p", s"$bindingsDir/python/Dockerfile", - "-R", s"$bindingsDir/R/Dockerfile") ++ + val cmd = Seq(dockerTool, + "-r", imageRepo, + "-t", imageTag.getOrElse("dev"), + "-p", pyDockerFile, + "-R", rDockerFile) ++ + (if (deployMode != Some("minikube")) Seq.empty else Seq("-m")) ++ extraOptions :+ "build" val ec = Process(cmd).! if (ec != 0) { throw new IllegalStateException(s"Process '${cmd.mkString(" ")}' exited with $ec.") } + if (deployMode == Some("cloud")) { + val cmd = Seq(dockerTool, "-r", imageRepo, "-t", imageTag.getOrElse("dev"), "push") + val ret = Process(cmd).! + if (ret != 0) { + throw new IllegalStateException(s"Process '${cmd.mkString(" ")}' exited with $ret.") + } + } } shouldBuildImage = true }, @@ -666,11 +685,12 @@ object KubernetesIntegrationTests { }.value, (Test / test) := (Test / test).dependsOn(dockerBuild).value, (Test / javaOptions) ++= Seq( - "-Dspark.kubernetes.test.deployMode=minikube", - s"-Dspark.kubernetes.test.imageTag=${imageTag.value}", - s"-Dspark.kubernetes.test.namespace=${namespace.value}", + s"-Dspark.kubernetes.test.deployMode=${deployMode.getOrElse("minikube")}", + s"-Dspark.kubernetes.test.imageRepo=${imageRepo}", + s"-Dspark.kubernetes.test.imageTag=${imageTag.getOrElse("dev")}", s"-Dspark.kubernetes.test.unpackSparkDir=$sparkHome" ), + (Test / javaOptions) ++= namespace.map("-Dspark.kubernetes.test.namespace=" + _), // Force packaging before building images, so that the latest code is tested. dockerBuild := dockerBuild .dependsOn(assembly / Compile / packageBin) @@ -945,6 +965,13 @@ object SparkR { ) } +object Volcano { + // Exclude all volcano file for Compile and Test + lazy val settings = Seq( + unmanagedSources / excludeFilter := HiddenFileFilter || "*Volcano*.scala" + ) +} + object Unidoc { import BuildCommons._ @@ -1204,7 +1231,7 @@ object TestSettings { (Test / testOptions) += Tests.Argument(TestFrameworks.ScalaTest, "-W", "120", "300"), (Test / testOptions) += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. 
- libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", + libraryDependencies += "com.github.sbt" % "junit-interface" % "0.13.3" % "test", // `parallelExecutionInTest` controls whether test suites belonging to the same SBT project // can run in parallel with one another. It does NOT control whether tests execute in parallel // within the same JVM (which is controlled by `testForkedParallel`) or whether test cases diff --git a/project/build.properties b/project/build.properties index d434f8eead721..8599f07ab2b6f 100644 --- a/project/build.properties +++ b/project/build.properties @@ -15,4 +15,4 @@ # limitations under the License. # # Please update the version in appveyor-install-dependencies.ps1 together. -sbt.version=1.6.1 +sbt.version=1.6.2 diff --git a/python/docs/Makefile b/python/docs/Makefile index 9cb1a17ef584f..2628530cb20b3 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -21,7 +21,7 @@ SPHINXBUILD ?= sphinx-build SOURCEDIR ?= source BUILDDIR ?= build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.9.3-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.9.4-src.zip) # Put it first so that "make" without argument is like "make help". help: diff --git a/python/docs/make2.bat b/python/docs/make2.bat index 2e4e2b543ab24..26ef220309c48 100644 --- a/python/docs/make2.bat +++ b/python/docs/make2.bat @@ -25,7 +25,7 @@ if "%SPHINXBUILD%" == "" ( set SOURCEDIR=source set BUILDDIR=build -set PYTHONPATH=..;..\lib\py4j-0.10.9.3-src.zip +set PYTHONPATH=..;..\lib\py4j-0.10.9.4-src.zip if "%1" == "" goto help diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 15a12403128d9..3503be03339fe 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -157,7 +157,7 @@ Package Minimum supported version Note `pandas` 1.0.5 Optional for Spark SQL `NumPy` 1.7 Required for MLlib DataFrame-based API `pyarrow` 1.0.0 Optional for Spark SQL -`Py4J` 0.10.9.3 Required +`Py4J` 0.10.9.4 Required `pandas` 1.0.5 Required for pandas API on Spark `pyarrow` 1.0.0 Required for pandas API on Spark `Numpy` 1.14 Required for pandas API on Spark diff --git a/python/docs/source/reference/pyspark.rst b/python/docs/source/reference/pyspark.rst index 6d4d0b55477c1..f0997255bb911 100644 --- a/python/docs/source/reference/pyspark.rst +++ b/python/docs/source/reference/pyspark.rst @@ -53,6 +53,7 @@ Spark Context APIs SparkContext.PACKAGE_EXTENSIONS SparkContext.accumulator + SparkContext.addArchive SparkContext.addFile SparkContext.addPyFile SparkContext.applicationId @@ -111,6 +112,7 @@ RDD APIs RDD.cache RDD.cartesian RDD.checkpoint + RDD.cleanShuffleDependencies RDD.coalesce RDD.cogroup RDD.collect diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst index 818814ca0a147..1d34961a91a61 100644 --- a/python/docs/source/reference/pyspark.sql.rst +++ b/python/docs/source/reference/pyspark.sql.rst @@ -201,6 +201,7 @@ DataFrame APIs DataFrame.show DataFrame.sort DataFrame.sortWithinPartitions + DataFrame.sparkSession DataFrame.stat DataFrame.storageLevel DataFrame.subtract diff --git a/python/lib/py4j-0.10.9.3-src.zip b/python/lib/py4j-0.10.9.3-src.zip deleted file mode 100644 index 428f3acd62b3c..0000000000000 Binary files a/python/lib/py4j-0.10.9.3-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.10.9.4-src.zip b/python/lib/py4j-0.10.9.4-src.zip new file mode 100644 index 
0000000000000..51b3404d5ab3e Binary files /dev/null and b/python/lib/py4j-0.10.9.4-src.zip differ diff --git a/python/mypy.ini b/python/mypy.ini index 8a4c92eaebcf5..efaa3dc97d3c4 100644 --- a/python/mypy.ini +++ b/python/mypy.ini @@ -20,6 +20,8 @@ strict_optional = True no_implicit_optional = True disallow_untyped_defs = True show_error_codes = True +warn_unused_ignores = True +warn_redundant_casts = True ; Allow untyped def in internal modules and tests diff --git a/python/pyspark/_typing.pyi b/python/pyspark/_typing.pyi index 9a36c8945bf96..6cc09263684d5 100644 --- a/python/pyspark/_typing.pyi +++ b/python/pyspark/_typing.pyi @@ -17,17 +17,27 @@ # under the License. from typing import Callable, Iterable, Sized, TypeVar, Union -from typing_extensions import Protocol +from typing_extensions import Literal, Protocol + +from numpy import int32, int64, float32, float64, ndarray F = TypeVar("F", bound=Callable) T_co = TypeVar("T_co", covariant=True) PrimitiveType = Union[bool, float, int, str] +NonUDFType = Literal[0] + class SupportsIAdd(Protocol): def __iadd__(self, other: SupportsIAdd) -> SupportsIAdd: ... class SupportsOrdering(Protocol): - def __le__(self, other: SupportsOrdering) -> bool: ... + def __lt__(self, other: SupportsOrdering) -> bool: ... class SizedIterable(Protocol, Sized, Iterable[T_co]): ... + +S = TypeVar("S", bound=SupportsOrdering) + +NumberOrArray = TypeVar( + "NumberOrArray", float, int, complex, int32, int64, float32, float64, ndarray +) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index d3dc2e91c4fad..fe775a37ed8e9 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -20,20 +20,30 @@ import struct import socketserver as SocketServer import threading +from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union + from pyspark.serializers import read_int, CPickleSerializer +if TYPE_CHECKING: + from pyspark._typing import SupportsIAdd # noqa: F401 + import socketserver.BaseRequestHandler # type: ignore[import] + __all__ = ["Accumulator", "AccumulatorParam"] +T = TypeVar("T") +U = TypeVar("U", bound="SupportsIAdd") pickleSer = CPickleSerializer() # Holds accumulators registered on the current machine, keyed by ID. This is then used to send # the local accumulator updates back to the driver program at the end of a task. -_accumulatorRegistry = {} +_accumulatorRegistry: Dict[int, "Accumulator"] = {} -def _deserialize_accumulator(aid, zero_value, accum_param): +def _deserialize_accumulator( + aid: int, zero_value: T, accum_param: "AccumulatorParam[T]" +) -> "Accumulator[T]": from pyspark.accumulators import _accumulatorRegistry # If this certain accumulator was deserialized, don't overwrite it. @@ -46,7 +56,7 @@ def _deserialize_accumulator(aid, zero_value, accum_param): return accum -class Accumulator: +class Accumulator(Generic[T]): """ A shared variable that can be accumulated, i.e., has a commutative and associative "add" @@ -106,7 +116,7 @@ class Accumulator: TypeError: ... 
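The _typing.pyi hunk above switches SupportsOrdering from __le__ to __lt__ and the accumulators hunk makes Accumulator generic in its value type. A minimal, standalone sketch of that combination follows (a __lt__-based Protocol bounding a TypeVar plus a tiny Generic container); the names Ordered, S, Box and smallest are illustrative only and not part of the patch.

    from typing import Generic, Iterable, Protocol, TypeVar


    class Ordered(Protocol):
        # Ordering helpers such as min()/sorted() only require "<",
        # which is why __lt__ (rather than __le__) is the natural member.
        def __lt__(self, other: "Ordered") -> bool: ...


    S = TypeVar("S", bound=Ordered)
    T = TypeVar("T")


    class Box(Generic[T]):
        """A tiny generic container: Box[int], Box[str], ..."""

        def __init__(self, value: T) -> None:
            self.value = value


    def smallest(items: Iterable[S]) -> S:
        it = iter(items)
        best = next(it)
        for item in it:
            if item < best:  # only "<" is needed, matching the Protocol above
                best = item
        return best


    print(smallest([3, 1, 2]))   # 1
    print(Box("hello").value)    # hello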
""" - def __init__(self, aid, value, accum_param): + def __init__(self, aid: int, value: T, accum_param: "AccumulatorParam[T]"): """Create a new Accumulator with a given initial value and AccumulatorParam object""" from pyspark.accumulators import _accumulatorRegistry @@ -116,42 +126,47 @@ def __init__(self, aid, value, accum_param): self._deserialized = False _accumulatorRegistry[aid] = self - def __reduce__(self): + def __reduce__( + self, + ) -> Tuple[ + Callable[[int, T, "AccumulatorParam[T]"], "Accumulator[T]"], + Tuple[int, T, "AccumulatorParam[T]"], + ]: """Custom serialization; saves the zero value from our AccumulatorParam""" param = self.accum_param return (_deserialize_accumulator, (self.aid, param.zero(self._value), param)) @property - def value(self): + def value(self) -> T: """Get the accumulator's value; only usable in driver program""" if self._deserialized: raise RuntimeError("Accumulator.value cannot be accessed inside tasks") return self._value @value.setter - def value(self, value): + def value(self, value: T) -> None: """Sets the accumulator's value; only usable in driver program""" if self._deserialized: raise RuntimeError("Accumulator.value cannot be accessed inside tasks") self._value = value - def add(self, term): + def add(self, term: T) -> None: """Adds a term to this accumulator's value""" self._value = self.accum_param.addInPlace(self._value, term) - def __iadd__(self, term): + def __iadd__(self, term: T) -> "Accumulator[T]": """The += operator; adds a term to this accumulator's value""" self.add(term) return self - def __str__(self): + def __str__(self) -> str: return str(self._value) - def __repr__(self): + def __repr__(self) -> str: return "Accumulator" % (self.aid, self._value) -class AccumulatorParam: +class AccumulatorParam(Generic[T]): """ Helper object that defines how to accumulate values of a given type. @@ -178,14 +193,14 @@ class AccumulatorParam: [7.0, 8.0, 9.0] """ - def zero(self, value): + def zero(self, value: T) -> T: """ Provide a "zero value" for the type, compatible in dimensions with the provided `value` (e.g., a zero vector) """ raise NotImplementedError - def addInPlace(self, value1, value2): + def addInPlace(self, value1: T, value2: T) -> T: """ Add two values of the accumulator's data type, returning a new value; for efficiency, can also update `value1` in place and return it. @@ -193,7 +208,7 @@ def addInPlace(self, value1, value2): raise NotImplementedError -class AddingAccumulatorParam(AccumulatorParam): +class AddingAccumulatorParam(AccumulatorParam[U]): """ An AccumulatorParam that uses the + operators to add values. Designed for simple types @@ -201,21 +216,21 @@ class AddingAccumulatorParam(AccumulatorParam): as a parameter. 
""" - def __init__(self, zero_value): + def __init__(self, zero_value: U): self.zero_value = zero_value - def zero(self, value): + def zero(self, value: U) -> U: return self.zero_value - def addInPlace(self, value1, value2): - value1 += value2 + def addInPlace(self, value1: U, value2: U) -> U: + value1 += value2 # type: ignore[operator] return value1 # Singleton accumulator params for some standard types -INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) -FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) -COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) +INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) # type: ignore[type-var] +FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) # type: ignore[type-var] +COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) # type: ignore[type-var] class _UpdateRequestHandler(SocketServer.StreamRequestHandler): @@ -225,20 +240,20 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): server is shutdown. """ - def handle(self): + def handle(self) -> None: from pyspark.accumulators import _accumulatorRegistry - auth_token = self.server.auth_token + auth_token = self.server.auth_token # type: ignore[attr-defined] - def poll(func): - while not self.server.server_shutdown: + def poll(func: Callable[[], bool]) -> None: + while not self.server.server_shutdown: # type: ignore[attr-defined] # Poll every 1 second for new data -- don't block in case of shutdown. r, _, _ = select.select([self.rfile], [], [], 1) if self.rfile in r: if func(): break - def accum_updates(): + def accum_updates() -> bool: num_updates = read_int(self.rfile) for _ in range(num_updates): (aid, update) = pickleSer._read_with_length(self.rfile) @@ -247,8 +262,8 @@ def accum_updates(): self.wfile.write(struct.pack("!b", 1)) return False - def authenticate_and_accum_updates(): - received_token = self.rfile.read(len(auth_token)) + def authenticate_and_accum_updates() -> bool: + received_token: Union[bytes, str] = self.rfile.read(len(auth_token)) if isinstance(received_token, bytes): received_token = received_token.decode("utf-8") if received_token == auth_token: @@ -267,7 +282,12 @@ def authenticate_and_accum_updates(): class AccumulatorServer(SocketServer.TCPServer): - def __init__(self, server_address, RequestHandlerClass, auth_token): + def __init__( + self, + server_address: Tuple[str, int], + RequestHandlerClass: Type["socketserver.BaseRequestHandler"], + auth_token: str, + ): SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass) self.auth_token = auth_token @@ -277,13 +297,13 @@ def __init__(self, server_address, RequestHandlerClass, auth_token): """ server_shutdown = False - def shutdown(self): + def shutdown(self) -> None: self.server_shutdown = True SocketServer.TCPServer.shutdown(self) self.server_close() -def _start_update_server(auth_token): +def _start_update_server(auth_token: str) -> AccumulatorServer: """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token) thread = threading.Thread(target=server.serve_forever) diff --git a/python/pyspark/accumulators.pyi b/python/pyspark/accumulators.pyi deleted file mode 100644 index 315979218cee6..0000000000000 --- a/python/pyspark/accumulators.pyi +++ /dev/null @@ -1,71 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Callable, Dict, Generic, Tuple, Type, TypeVar - -import socketserver.BaseRequestHandler # type: ignore - -from pyspark._typing import SupportsIAdd - -T = TypeVar("T") -U = TypeVar("U", bound=SupportsIAdd) - -import socketserver as SocketServer - -_accumulatorRegistry: Dict[int, Accumulator] - -class Accumulator(Generic[T]): - aid: int - accum_param: AccumulatorParam[T] - def __init__(self, aid: int, value: T, accum_param: AccumulatorParam[T]) -> None: ... - def __reduce__( - self, - ) -> Tuple[ - Callable[[int, int, AccumulatorParam[T]], Accumulator[T]], - Tuple[int, int, AccumulatorParam[T]], - ]: ... - @property - def value(self) -> T: ... - @value.setter - def value(self, value: T) -> None: ... - def add(self, term: T) -> None: ... - def __iadd__(self, term: T) -> Accumulator[T]: ... - -class AccumulatorParam(Generic[T]): - def zero(self, value: T) -> T: ... - def addInPlace(self, value1: T, value2: T) -> T: ... - -class AddingAccumulatorParam(AccumulatorParam[U]): - zero_value: U - def __init__(self, zero_value: U) -> None: ... - def zero(self, value: U) -> U: ... - def addInPlace(self, value1: U, value2: U) -> U: ... - -class _UpdateRequestHandler(SocketServer.StreamRequestHandler): - def handle(self) -> None: ... - -class AccumulatorServer(SocketServer.TCPServer): - auth_token: str - def __init__( - self, - server_address: Tuple[str, int], - RequestHandlerClass: Type[socketserver.BaseRequestHandler], - auth_token: str, - ) -> None: ... - server_shutdown: bool - def shutdown(self) -> None: ... diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 903e4ea4b0851..edd282de92f64 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -21,20 +21,40 @@ from tempfile import NamedTemporaryFile import threading import pickle +from typing import ( + overload, + Any, + Callable, + Dict, + Generic, + IO, + Iterator, + Optional, + Tuple, + TypeVar, + TYPE_CHECKING, + Union, +) +from typing.io import BinaryIO # type: ignore[import] from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ChunkedStream, pickle_protocol from pyspark.util import print_exec +if TYPE_CHECKING: + from pyspark import SparkContext + __all__ = ["Broadcast"] +T = TypeVar("T") + # Holds broadcasted data received from Java, keyed by its id. -_broadcastRegistry = {} +_broadcastRegistry: Dict[int, "Broadcast[Any]"] = {} -def _from_id(bid): +def _from_id(bid: int) -> "Broadcast[Any]": from pyspark.broadcast import _broadcastRegistry if bid not in _broadcastRegistry: @@ -42,7 +62,7 @@ def _from_id(bid): return _broadcastRegistry[bid] -class Broadcast: +class Broadcast(Generic[T]): """ A broadcast variable created with :meth:`SparkContext.broadcast`. 
@@ -62,7 +82,31 @@ class Broadcast: >>> large_broadcast = sc.broadcast(range(10000)) """ - def __init__(self, sc=None, value=None, pickle_registry=None, path=None, sock_file=None): + @overload # On driver + def __init__( + self: "Broadcast[T]", + sc: "SparkContext", + value: T, + pickle_registry: "BroadcastPickleRegistry", + ): + ... + + @overload # On worker without decryption server + def __init__(self: "Broadcast[Any]", *, path: str): + ... + + @overload # On worker with decryption server + def __init__(self: "Broadcast[Any]", *, sock_file: str): + ... + + def __init__( + self, + sc: Optional["SparkContext"] = None, + value: Optional[T] = None, + pickle_registry: Optional["BroadcastPickleRegistry"] = None, + path: Optional[str] = None, + sock_file: Optional[BinaryIO] = None, + ): """ Should not be called directly by users -- use :meth:`SparkContext.broadcast` instead. @@ -71,8 +115,10 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None, sock_fi # we're on the driver. We want the pickled data to end up in a file (maybe encrypted) f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) self._path = f.name - self._sc = sc + self._sc: Optional["SparkContext"] = sc + assert sc._jvm is not None self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) + broadcast_out: Union[ChunkedStream, IO[bytes]] if sc._encryption_enabled: # with encryption, we ask the jvm to do the encryption for us, we send it data # over a socket @@ -82,7 +128,7 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None, sock_fi else: # no encryption, we can just write pickled data directly to the file from python broadcast_out = f - self.dump(value, broadcast_out) + self.dump(value, broadcast_out) # type: ignore[arg-type] if sc._encryption_enabled: self._python_broadcast.waitTillDataReceived() self._jbroadcast = sc._jsc.broadcast(self._python_broadcast) @@ -102,7 +148,7 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None, sock_fi assert path is not None self._path = path - def dump(self, value, f): + def dump(self, value: T, f: BinaryIO) -> None: try: pickle.dump(value, f, pickle_protocol) except pickle.PickleError: @@ -113,11 +159,11 @@ def dump(self, value, f): raise pickle.PicklingError(msg) f.close() - def load_from_path(self, path): + def load_from_path(self, path: str) -> T: with open(path, "rb", 1 << 20) as f: return self.load(f) - def load(self, file): + def load(self, file: BinaryIO) -> T: # "file" could also be a socket gc.disable() try: @@ -126,7 +172,7 @@ def load(self, file): gc.enable() @property - def value(self): + def value(self) -> T: """Return the broadcasted value""" if not hasattr(self, "_value") and self._path is not None: # we only need to decrypt it here when encryption is enabled and @@ -140,7 +186,7 @@ def value(self): self._value = self.load_from_path(self._path) return self._value - def unpersist(self, blocking=False): + def unpersist(self, blocking: bool = False) -> None: """ Delete cached copies of this broadcast on the executors. If the broadcast is used after this is called, it will need to be @@ -155,7 +201,7 @@ def unpersist(self, blocking=False): raise RuntimeError("Broadcast can only be unpersisted in driver") self._jbroadcast.unpersist(blocking) - def destroy(self, blocking=False): + def destroy(self, blocking: bool = False) -> None: """ Destroy all data and metadata related to this broadcast variable. 
Use this with caution; once a broadcast variable has been destroyed, @@ -175,9 +221,10 @@ def destroy(self, blocking=False): self._jbroadcast.destroy(blocking) os.unlink(self._path) - def __reduce__(self): + def __reduce__(self) -> Tuple[Callable[[int], "Broadcast[T]"], Tuple[int]]: if self._jbroadcast is None: raise RuntimeError("Broadcast can only be serialized in driver") + assert self._pickle_registry is not None self._pickle_registry.add(self) return _from_id, (self._jbroadcast.id(),) @@ -185,17 +232,17 @@ def __reduce__(self): class BroadcastPickleRegistry(threading.local): """Thread-local registry for broadcast variables that have been pickled""" - def __init__(self): + def __init__(self) -> None: self.__dict__.setdefault("_registry", set()) - def __iter__(self): + def __iter__(self) -> Iterator[Broadcast[Any]]: for bcast in self._registry: yield bcast - def add(self, bcast): + def add(self, bcast: Broadcast[Any]) -> None: self._registry.add(bcast) - def clear(self): + def clear(self) -> None: self._registry.clear() diff --git a/python/pyspark/broadcast.pyi b/python/pyspark/broadcast.pyi deleted file mode 100644 index 944cb06d4178c..0000000000000 --- a/python/pyspark/broadcast.pyi +++ /dev/null @@ -1,48 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import threading -from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar - -T = TypeVar("T") - -_broadcastRegistry: Dict[int, Broadcast] - -class Broadcast(Generic[T]): - def __init__( - self, - sc: Optional[Any] = ..., - value: Optional[T] = ..., - pickle_registry: Optional[Any] = ..., - path: Optional[Any] = ..., - sock_file: Optional[Any] = ..., - ) -> None: ... - def dump(self, value: T, f: Any) -> None: ... - def load_from_path(self, path: Any) -> T: ... - def load(self, file: Any) -> T: ... - @property - def value(self) -> T: ... - def unpersist(self, blocking: bool = ...) -> None: ... - def destroy(self, blocking: bool = ...) -> None: ... - def __reduce__(self) -> Tuple[Callable[[int], T], Tuple[int]]: ... - -class BroadcastPickleRegistry(threading.local): - def __init__(self) -> None: ... - def __iter__(self) -> None: ... - def add(self, bcast: Any) -> None: ... - def clear(self) -> None: ... 
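For the now-generic Broadcast[T] API annotated above (value, unpersist, destroy), a compact local-mode usage sketch, assuming only a local pyspark installation:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("broadcast-demo").getOrCreate()
    sc = spark.sparkContext

    lookup = sc.broadcast({"a": 1, "b": 2})   # inferred as Broadcast[Dict[str, int]]
    total = sc.parallelize(["a", "b", "a"]).map(lambda k: lookup.value[k]).sum()
    print(total)  # 4

    lookup.unpersist()  # drop cached executor copies; re-sent if the variable is used again
    lookup.destroy()    # release all data and metadata; must not be used afterwards

    spark.stop()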
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 536e1f89cff3f..1ddc8f5ddaa92 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -20,7 +20,7 @@ import sys from typing import Dict, List, Optional, Tuple, cast, overload -from py4j.java_gateway import JVMView, JavaObject # type: ignore[import] +from py4j.java_gateway import JVMView, JavaObject class SparkConf: @@ -124,7 +124,7 @@ def __init__( else: from pyspark.context import SparkContext - _jvm = _jvm or SparkContext._jvm # type: ignore[attr-defined] + _jvm = _jvm or SparkContext._jvm if _jvm is not None: # JVM is created, so create self._jconf directly through JVM @@ -203,6 +203,18 @@ def setAll(self, pairs: List[Tuple[str, str]]) -> "SparkConf": self.set(k, v) return self + @overload + def get(self, key: str) -> Optional[str]: + ... + + @overload + def get(self, key: str, defaultValue: None) -> Optional[str]: + ... + + @overload + def get(self, key: str, defaultValue: str) -> str: + ... + def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: """Get the configured value for some key, or return a default otherwise.""" if defaultValue is None: # Py4J doesn't call the right get() if we pass None diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1002716ae2453..59b5fa7f3a434 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -35,6 +35,7 @@ List, NoReturn, Optional, + Sequence, Tuple, Type, TYPE_CHECKING, @@ -62,7 +63,7 @@ ) from pyspark.storagelevel import StorageLevel from pyspark.resource.information import ResourceInformation -from pyspark.rdd import RDD, _load_from_socket # type: ignore[attr-defined] +from pyspark.rdd import RDD, _load_from_socket from pyspark.taskcontext import TaskContext from pyspark.traceback_utils import CallSite, first_spark_call from pyspark.status import StatusTracker @@ -180,10 +181,7 @@ def __init__( udf_profiler_cls: Type[UDFBasicProfiler] = UDFBasicProfiler, ): - if ( - conf is None - or cast(str, conf.get("spark.executor.allowSparkContext", "false")).lower() != "true" - ): + if conf is None or conf.get("spark.executor.allowSparkContext", "false").lower() != "true": # In order to prevent SparkContext from being created in executors. 
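The overloads added to SparkConf.get in the hunk above let type checkers narrow the return type from Optional[str] to str whenever a non-None default is supplied. A self-contained illustration of the same @overload pattern, with an illustrative Config class that is not part of the patch:

    from typing import Dict, Optional, overload


    class Config:
        def __init__(self) -> None:
            self._data: Dict[str, str] = {}

        def set(self, key: str, value: str) -> "Config":
            self._data[key] = value
            return self

        @overload
        def get(self, key: str) -> Optional[str]: ...
        @overload
        def get(self, key: str, default: None) -> Optional[str]: ...
        @overload
        def get(self, key: str, default: str) -> str: ...

        def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
            return self._data.get(key, default)


    conf = Config().set("spark.app.name", "demo")
    name: str = conf.get("spark.app.name", "fallback")  # checker sees str
    maybe = conf.get("missing.key")                      # checker sees Optional[str]
    print(name, maybe)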
SparkContext._assert_on_driver() @@ -289,7 +287,7 @@ def _do_init( # they will be passed back to us through a TCP server assert self._gateway is not None auth_token = self._gateway.gateway_parameters.auth_token - start_update_server = accumulators._start_update_server # type: ignore[attr-defined] + start_update_server = accumulators._start_update_server self._accumulatorServer = start_update_server(auth_token) (host, port) = self._accumulatorServer.server_address assert self._jvm is not None @@ -325,7 +323,7 @@ def _do_init( # Deploy code dependencies set by spark-submit; these will already have been added # with SparkContext.addFile, so we just need to add them to the PYTHONPATH - for path in cast(str, self._conf.get("spark.submit.pyFiles", "")).split(","): + for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) try: @@ -570,11 +568,11 @@ def stop(self) -> None: self._jsc = None if getattr(self, "_accumulatorServer", None): self._accumulatorServer.shutdown() - self._accumulatorServer = None + self._accumulatorServer = None # type: ignore[assignment] with SparkContext._lock: - SparkContext._active_spark_context = None # type: ignore[assignment] + SparkContext._active_spark_context = None - def emptyRDD(self) -> "RDD[Any]": + def emptyRDD(self) -> RDD[Any]: """ Create an RDD that has no partitions or elements. """ @@ -582,7 +580,7 @@ def emptyRDD(self) -> "RDD[Any]": def range( self, start: int, end: Optional[int] = None, step: int = 1, numSlices: Optional[int] = None - ) -> "RDD[int]": + ) -> RDD[int]: """ Create a new RDD of int containing elements from `start` to `end` (exclusive), increased by `step` every element. Can be called the same @@ -620,7 +618,7 @@ def range( return self.parallelize(range(start, end, step), numSlices) - def parallelize(self, c: Iterable[T], numSlices: Optional[int] = None) -> "RDD[T]": + def parallelize(self, c: Iterable[T], numSlices: Optional[int] = None) -> RDD[T]: """ Distribute a local Python collection to form an RDD. Using range is recommended if the input represents a range for performance. @@ -724,7 +722,7 @@ def _serialize_to_jvm( # we eagerly reads the file so we can delete right after. os.unlink(tempFile.name) - def pickleFile(self, name: str, minPartitions: Optional[int] = None) -> "RDD[Any]": + def pickleFile(self, name: str, minPartitions: Optional[int] = None) -> RDD[Any]: """ Load an RDD previously saved using :meth:`RDD.saveAsPickleFile` method. 
@@ -741,7 +739,7 @@ def pickleFile(self, name: str, minPartitions: Optional[int] = None) -> "RDD[Any def textFile( self, name: str, minPartitions: Optional[int] = None, use_unicode: bool = True - ) -> "RDD[str]": + ) -> RDD[str]: """ Read a text file from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI, and return it as an @@ -766,7 +764,7 @@ def textFile( def wholeTextFiles( self, path: str, minPartitions: Optional[int] = None, use_unicode: bool = True - ) -> "RDD[Tuple[str, str]]": + ) -> RDD[Tuple[str, str]]: """ Read a directory of text files from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system @@ -821,9 +819,7 @@ def wholeTextFiles( PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode)), ) - def binaryFiles( - self, path: str, minPartitions: Optional[int] = None - ) -> "RDD[Tuple[str, bytes]]": + def binaryFiles(self, path: str, minPartitions: Optional[int] = None) -> RDD[Tuple[str, bytes]]: """ Read a directory of binary files from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI @@ -842,7 +838,7 @@ def binaryFiles( PairDeserializer(UTF8Deserializer(), NoOpSerializer()), ) - def binaryRecords(self, path: str, recordLength: int) -> "RDD[bytes]": + def binaryRecords(self, path: str, recordLength: int) -> RDD[bytes]: """ Load data from a flat binary file, assuming each record is a set of numbers with the specified numerical format (see ByteBuffer), and the number of @@ -875,7 +871,7 @@ def sequenceFile( valueConverter: Optional[str] = None, minSplits: Optional[int] = None, batchSize: int = 0, - ) -> "RDD[Tuple[T, U]]": + ) -> RDD[Tuple[T, U]]: """ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -930,7 +926,7 @@ def newAPIHadoopFile( valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0, - ) -> "RDD[Tuple[T, U]]": + ) -> RDD[Tuple[T, U]]: """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -989,7 +985,7 @@ def newAPIHadoopRDD( valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0, - ) -> "RDD[Tuple[T, U]]": + ) -> RDD[Tuple[T, U]]: """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -1042,7 +1038,7 @@ def hadoopFile( valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0, - ) -> "RDD[Tuple[T, U]]": + ) -> RDD[Tuple[T, U]]: """ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -1097,7 +1093,7 @@ def hadoopRDD( valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0, - ) -> "RDD[Tuple[T, U]]": + ) -> RDD[Tuple[T, U]]: """ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. 
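The hunks above are annotation-only changes to the RDD-returning constructors and file readers (range, parallelize, textFile, wholeTextFiles, and friends). A compact usage sketch in local mode, assuming only a local pyspark installation:

    import os
    import tempfile

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("rdd-io-demo").getOrCreate()
    sc = spark.sparkContext

    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "data.txt")
        with open(path, "w") as f:
            f.write("Hello\nWorld!\n")

        print(sc.range(0, 10, 2).collect())            # RDD[int]  -> [0, 2, 4, 6, 8]
        print(sc.parallelize(["a", "b"]).collect())    # RDD[str]  -> ['a', 'b']
        print(sorted(sc.textFile(path).collect()))     # RDD[str]  -> ['Hello', 'World!']
        print(sc.wholeTextFiles(d).keys().count())     # RDD[Tuple[str, str]] -> 1

    spark.stop()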
@@ -1144,7 +1140,7 @@ def _checkpointFile(self, name: str, input_deserializer: PairDeserializer) -> RD jrdd = self._jsc.checkpointFile(name) return RDD(jrdd, self, input_deserializer) - def union(self, rdds: List["RDD[T]"]) -> "RDD[T]": + def union(self, rdds: List[RDD[T]]) -> RDD[T]: """ Build the union of a list of RDDs. @@ -1164,12 +1160,9 @@ def union(self, rdds: List["RDD[T]"]) -> "RDD[T]": >>> sorted(sc.union([textFile, parallelized]).collect()) ['Hello', 'World!'] """ - first_jrdd_deserializer = rdds[0]._jrdd_deserializer # type: ignore[attr-defined] - if any( - x._jrdd_deserializer != first_jrdd_deserializer # type: ignore[attr-defined] - for x in rdds - ): - rdds = [x._reserialize() for x in rdds] # type: ignore[attr-defined] + first_jrdd_deserializer = rdds[0]._jrdd_deserializer + if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): + rdds = [x._reserialize() for x in rdds] gw = SparkContext._gateway assert gw is not None jvm = SparkContext._jvm @@ -1177,21 +1170,19 @@ def union(self, rdds: List["RDD[T]"]) -> "RDD[T]": jrdd_cls = jvm.org.apache.spark.api.java.JavaRDD jpair_rdd_cls = jvm.org.apache.spark.api.java.JavaPairRDD jdouble_rdd_cls = jvm.org.apache.spark.api.java.JavaDoubleRDD - if is_instance_of(gw, rdds[0]._jrdd, jrdd_cls): # type: ignore[attr-defined] + if is_instance_of(gw, rdds[0]._jrdd, jrdd_cls): cls = jrdd_cls - elif is_instance_of(gw, rdds[0]._jrdd, jpair_rdd_cls): # type: ignore[attr-defined] + elif is_instance_of(gw, rdds[0]._jrdd, jpair_rdd_cls): cls = jpair_rdd_cls - elif is_instance_of(gw, rdds[0]._jrdd, jdouble_rdd_cls): # type: ignore[attr-defined] + elif is_instance_of(gw, rdds[0]._jrdd, jdouble_rdd_cls): cls = jdouble_rdd_cls else: - cls_name = rdds[0]._jrdd.getClass().getCanonicalName() # type: ignore[attr-defined] + cls_name = rdds[0]._jrdd.getClass().getCanonicalName() raise TypeError("Unsupported Java RDD class %s" % cls_name) jrdds = gw.new_array(cls, len(rdds)) for i in range(0, len(rdds)): - jrdds[i] = rdds[i]._jrdd # type: ignore[attr-defined] - return RDD( - self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer # type: ignore[attr-defined] - ) + jrdds[i] = rdds[i]._jrdd + return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer) def broadcast(self, value: T) -> "Broadcast[T]": """ @@ -1213,11 +1204,11 @@ def accumulator( """ if accum_param is None: if isinstance(value, int): - accum_param = accumulators.INT_ACCUMULATOR_PARAM # type: ignore[attr-defined] + accum_param = cast("AccumulatorParam[T]", accumulators.INT_ACCUMULATOR_PARAM) elif isinstance(value, float): - accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM # type: ignore[attr-defined] + accum_param = cast("AccumulatorParam[T]", accumulators.FLOAT_ACCUMULATOR_PARAM) elif isinstance(value, complex): - accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM # type: ignore[attr-defined] + accum_param = cast("AccumulatorParam[T]", accumulators.COMPLEX_ACCUMULATOR_PARAM) else: raise TypeError("No default accumulator param for type %s" % type(value)) SparkContext._next_accum_id += 1 @@ -1277,6 +1268,50 @@ def addPyFile(self, path: str) -> None: importlib.invalidate_caches() + def addArchive(self, path: str) -> None: + """ + Add an archive to be downloaded with this Spark job on every node. + The `path` passed can be either a local file, a file in HDFS + (or other Hadoop-supported filesystems), or an HTTP, HTTPS or + FTP URI. + + To access the file in Spark jobs, use :meth:`SparkFiles.get` with the + filename to find its download/unpacked location. 
The given path should + be one of .zip, .tar, .tar.gz, .tgz and .jar. + + .. versionadded:: 3.3.0 + + Notes + ----- + A path can be added only once. Subsequent additions of the same path are ignored. + This API is experimental. + + Examples + -------- + Creates a zipped file that contains a text file written '100'. + + >>> import zipfile + >>> from pyspark import SparkFiles + >>> path = os.path.join(tempdir, "test.txt") + >>> zip_path = os.path.join(tempdir, "test.zip") + >>> with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipped: + ... with open(path, "w") as f: + ... _ = f.write("100") + ... zipped.write(path, os.path.basename(path)) + >>> sc.addArchive(zip_path) + + Reads the '100' as an integer in the zipped file, and processes + it with the data in the RDD. + + >>> def func(iterator): + ... with open("%s/test.txt" % SparkFiles.get("test.zip")) as f: + ... v = int(f.readline()) + ... return [x * int(v) for x in iterator] + >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect() + [100, 200, 300, 400] + """ + self._jsc.sc().addArchive(path) + def setCheckpointDir(self, dirName: str) -> None: """ Set the directory under which RDDs are going to be checkpointed. The @@ -1330,7 +1365,7 @@ def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = to HDFS-1208, where HDFS may respond to Thread.interrupt() by marking nodes as dead. If you run jobs in parallel, use :class:`pyspark.InheritableThread` for thread - local inheritance, and preventing resource leak. + local inheritance. Examples -------- @@ -1370,7 +1405,7 @@ def setLocalProperty(self, key: str, value: str) -> None: Notes ----- If you run jobs in parallel, use :class:`pyspark.InheritableThread` for thread - local inheritance, and preventing resource leak. + local inheritance. """ self._jsc.setLocalProperty(key, value) @@ -1388,7 +1423,7 @@ def setJobDescription(self, value: str) -> None: Notes ----- If you run jobs in parallel, use :class:`pyspark.InheritableThread` for thread - local inheritance, and preventing resource leak. + local inheritance. """ self._jsc.setJobDescription(value) @@ -1419,9 +1454,9 @@ def statusTracker(self) -> StatusTracker: def runJob( self, - rdd: "RDD[T]", + rdd: RDD[T], partitionFunc: Callable[[Iterable[T]], Iterable[U]], - partitions: Optional[List[int]] = None, + partitions: Optional[Sequence[int]] = None, allowLocal: bool = False, ) -> List[U]: """ @@ -1441,19 +1476,15 @@ def runJob( [0, 1, 16, 25] """ if partitions is None: - partitions = list(range(rdd._jrdd.partitions().size())) # type: ignore[attr-defined] + partitions = list(range(rdd._jrdd.partitions().size())) # Implementation note: This is implemented as a mapPartitions followed # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. 
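runJob's partitions parameter is widened above from List[int] to Sequence[int]. A small local-mode sketch of evaluating only selected partitions (the partition sums shown in the comments follow from how parallelize slices a range into equal chunks):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("runjob-demo").getOrCreate()
    sc = spark.sparkContext

    rdd = sc.parallelize(range(8), 4)   # partitions: [0,1] [2,3] [4,5] [6,7]

    # Evaluate only partitions 0 and 2; any Sequence[int] is accepted per the new annotation.
    print(sc.runJob(rdd, lambda it: [sum(it)], partitions=[0, 2]))  # [1, 9]

    spark.stop()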
mappedRDD = rdd.mapPartitions(partitionFunc) assert self._jvm is not None - sock_info = self._jvm.PythonRDD.runJob( - self._jsc.sc(), mappedRDD._jrdd, partitions # type: ignore[attr-defined] - ) - return list( - _load_from_socket(sock_info, mappedRDD._jrdd_deserializer) # type: ignore[attr-defined] - ) + sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) + return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer)) def show_profiles(self) -> None: """Print the profile stats to stdout""" diff --git a/python/pyspark/instrumentation_utils.py b/python/pyspark/instrumentation_utils.py new file mode 100644 index 0000000000000..908f5cbb3d473 --- /dev/null +++ b/python/pyspark/instrumentation_utils.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +import inspect +import threading +import importlib +import time +from types import ModuleType +from typing import Tuple, Union, List, Callable, Any, Type + + +__all__: List[str] = [] + +_local = threading.local() + + +def _wrap_function(class_name: str, function_name: str, func: Callable, logger: Any) -> Callable: + + signature = inspect.signature(func) + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if hasattr(_local, "logging") and _local.logging: + # no need to log since this should be internal call. + return func(*args, **kwargs) + _local.logging = True + try: + start = time.perf_counter() + try: + res = func(*args, **kwargs) + logger.log_success( + class_name, function_name, time.perf_counter() - start, signature + ) + return res + except Exception as ex: + logger.log_failure( + class_name, function_name, ex, time.perf_counter() - start, signature + ) + raise + finally: + _local.logging = False + + return wrapper + + +def _wrap_property(class_name: str, property_name: str, prop: Any, logger: Any) -> Any: + @property # type: ignore[misc] + def wrapper(self: Any) -> Any: + if hasattr(_local, "logging") and _local.logging: + # no need to log since this should be internal call. 
+ return prop.fget(self) + _local.logging = True + try: + start = time.perf_counter() + try: + res = prop.fget(self) + logger.log_success(class_name, property_name, time.perf_counter() - start) + return res + except Exception as ex: + logger.log_failure(class_name, property_name, ex, time.perf_counter() - start) + raise + finally: + _local.logging = False + + wrapper.__doc__ = prop.__doc__ + + if prop.fset is not None: + wrapper = wrapper.setter( # type: ignore[attr-defined] + _wrap_function(class_name, prop.fset.__name__, prop.fset, logger) + ) + + return wrapper + + +def _wrap_missing_function( + class_name: str, function_name: str, func: Callable, original: Any, logger: Any +) -> Any: + + if not hasattr(original, function_name): + return func + + signature = inspect.signature(getattr(original, function_name)) + + is_deprecated = func.__name__ == "deprecated_function" + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return func(*args, **kwargs) + finally: + logger.log_missing(class_name, function_name, is_deprecated, signature) + + return wrapper + + +def _wrap_missing_property(class_name: str, property_name: str, prop: Any, logger: Any) -> Any: + + is_deprecated = prop.fget.__name__ == "deprecated_property" + + @property # type: ignore[misc] + def wrapper(self: Any) -> Any: + try: + return prop.fget(self) + finally: + logger.log_missing(class_name, property_name, is_deprecated) + + return wrapper + + +def _attach( + logger_module: Union[str, ModuleType], + modules: List[ModuleType], + classes: List[Type[Any]], + missings: List[Tuple[Type[Any], Type[Any]]], +) -> None: + if isinstance(logger_module, str): + logger_module = importlib.import_module(logger_module) + + logger = getattr(logger_module, "get_logger")() + + special_functions = set( + [ + "__init__", + "__repr__", + "__str__", + "_repr_html_", + "__len__", + "__getitem__", + "__setitem__", + "__getattr__", + "__enter__", + "__exit__", + ] + ) + + # Modules + for target_module in modules: + target_name = target_module.__name__.split(".")[-1] + for name in getattr(target_module, "__all__"): + func = getattr(target_module, name) + if not inspect.isfunction(func): + continue + setattr(target_module, name, _wrap_function(target_name, name, func, logger)) + + # Classes + for target_class in classes: + for name, func in inspect.getmembers(target_class, inspect.isfunction): + if name.startswith("_") and name not in special_functions: + continue + setattr(target_class, name, _wrap_function(target_class.__name__, name, func, logger)) + + for name, prop in inspect.getmembers(target_class, lambda o: isinstance(o, property)): + if name.startswith("_"): + continue + setattr(target_class, name, _wrap_property(target_class.__name__, name, prop, logger)) + + # Missings + for original, missing in missings: + for name, func in inspect.getmembers(missing, inspect.isfunction): + setattr( + missing, + name, + _wrap_missing_function(original.__name__, name, func, original, logger), + ) + + for name, prop in inspect.getmembers(missing, lambda o: isinstance(o, property)): + setattr(missing, name, _wrap_missing_property(original.__name__, name, prop, logger)) diff --git a/python/pyspark/ml/_typing.pyi b/python/pyspark/ml/_typing.pyi index b51aa9634fe77..12d831f1e8c7e 100644 --- a/python/pyspark/ml/_typing.pyi +++ b/python/pyspark/ml/_typing.pyi @@ -16,12 +16,15 @@ # specific language governing permissions and limitations # under the License. 
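The new instrumentation_utils module above wraps public functions in a timing/logging decorator guarded by a thread-local flag, so nested internal calls are not logged twice. A stripped-down, standalone version of that pattern, with a print-based logger standing in for the real one and illustrative _inner/_outer functions:

    import functools
    import threading
    import time
    from typing import Any, Callable

    _local = threading.local()


    def _instrument(name: str, func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if getattr(_local, "logging", False):
                # Already inside an instrumented call: run silently so
                # re-entrant internal calls are not logged twice.
                return func(*args, **kwargs)
            _local.logging = True
            try:
                start = time.perf_counter()
                try:
                    result = func(*args, **kwargs)
                    print(f"success {name} {time.perf_counter() - start:.6f}s")
                    return result
                except Exception as ex:
                    print(f"failure {name} {type(ex).__name__} {time.perf_counter() - start:.6f}s")
                    raise
            finally:
                _local.logging = False

        return wrapper


    def _inner(x: int) -> int:
        return x * 2


    def _outer(x: int) -> int:
        return _inner(x) + 1


    _inner = _instrument("inner", _inner)
    _outer = _instrument("outer", _outer)

    print(_outer(3))   # logs only "success outer ..." (inner runs silently), then prints 7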
-from typing import Any, Dict, TypeVar, Union +from typing import Any, Dict, List, TypeVar, Tuple, Union from typing_extensions import Literal +from numpy import ndarray + import pyspark.ml.base import pyspark.ml.param import pyspark.ml.util +from pyspark.ml.linalg import Vector import pyspark.ml.wrapper from py4j.java_gateway import JavaObject @@ -68,6 +71,8 @@ MultilabelClassificationEvaluatorMetricType = Union[ Literal["microF1Measure"], ] ClusteringEvaluatorMetricType = Literal["silhouette"] +ClusteringEvaluatorDistanceMeasureType = Union[Literal["squaredEuclidean"], Literal["cosine"]] + RankingEvaluatorMetricType = Union[ Literal["meanAveragePrecision"], Literal["meanAveragePrecisionAtK"], @@ -75,3 +80,5 @@ RankingEvaluatorMetricType = Union[ Literal["ndcgAtK"], Literal["recallAtK"], ] + +VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]] diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py index 9d2a1917d9f0e..20540ebbef65a 100644 --- a/python/pyspark/ml/base.py +++ b/python/pyspark/ml/base.py @@ -15,12 +15,29 @@ # limitations under the License. # -from abc import ABCMeta, abstractmethod, abstractproperty +from abc import ABCMeta, abstractmethod import copy import threading +from typing import ( + Any, + Callable, + Generic, + Iterator, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, + overload, + TYPE_CHECKING, +) + from pyspark import since +from pyspark.ml.param import P from pyspark.ml.common import inherit_doc from pyspark.ml.param.shared import ( HasInputCol, @@ -30,11 +47,18 @@ HasPredictionCol, Params, ) +from pyspark.sql.dataframe import DataFrame from pyspark.sql.functions import udf -from pyspark.sql.types import StructField, StructType +from pyspark.sql.types import DataType, StructField, StructType + +if TYPE_CHECKING: + from pyspark.ml._typing import ParamMap + +T = TypeVar("T") +M = TypeVar("M", bound="Transformer") -class _FitMultipleIterator: +class _FitMultipleIterator(Generic[M]): """ Used by default implementation of Estimator.fitMultiple to produce models in a thread safe iterator. This class handles the simple case of fitMultiple where each param map should be @@ -55,17 +79,17 @@ class _FitMultipleIterator: See :py:meth:`Estimator.fitMultiple` for more info. """ - def __init__(self, fitSingleModel, numModels): + def __init__(self, fitSingleModel: Callable[[int], M], numModels: int): """ """ self.fitSingleModel = fitSingleModel self.numModel = numModels self.counter = 0 self.lock = threading.Lock() - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[int, M]]: return self - def __next__(self): + def __next__(self) -> Tuple[int, M]: with self.lock: index = self.counter if index >= self.numModel: @@ -73,13 +97,13 @@ def __next__(self): self.counter += 1 return index, self.fitSingleModel(index) - def next(self): + def next(self) -> Tuple[int, M]: """For python2 compatibility.""" return self.__next__() @inherit_doc -class Estimator(Params, metaclass=ABCMeta): +class Estimator(Params, Generic[M], metaclass=ABCMeta): """ Abstract class for estimators that fit models to data. @@ -89,7 +113,7 @@ class Estimator(Params, metaclass=ABCMeta): pass @abstractmethod - def _fit(self, dataset): + def _fit(self, dataset: DataFrame) -> M: """ Fits a model to the input dataset. This is called by the default implementation of fit. 
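The _FitMultipleIterator above hands out (index, model) pairs under a lock so several threads can pull work safely while the expensive build happens outside the lock. A self-contained sketch of the same idea, with a trivial "model" function and an illustrative IndexedIterator class:

    import threading
    from concurrent.futures import ThreadPoolExecutor
    from typing import Callable, Iterator, Tuple, TypeVar

    M = TypeVar("M")


    class IndexedIterator(Iterator[Tuple[int, M]]):
        """Thread-safe iterator yielding (index, build(index)) exactly once per index."""

        def __init__(self, build: Callable[[int], M], count: int) -> None:
            self._build = build
            self._count = count
            self._counter = 0
            self._lock = threading.Lock()

        def __next__(self) -> Tuple[int, M]:
            with self._lock:
                index = self._counter
                if index >= self._count:
                    raise StopIteration
                self._counter += 1
            # Building happens outside the lock, so builds can run in parallel.
            return index, self._build(index)


    it = IndexedIterator(lambda i: i * i, 5)
    with ThreadPoolExecutor(max_workers=3) as pool:
        results = sorted(pool.map(lambda _: next(it), range(5)))
    print(results)   # [(0, 0), (1, 1), (2, 4), (3, 9), (4, 16)]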
@@ -106,7 +130,9 @@ def _fit(self, dataset): """ raise NotImplementedError() - def fitMultiple(self, dataset, paramMaps): + def fitMultiple( + self, dataset: DataFrame, paramMaps: Sequence["ParamMap"] + ) -> Iterator[Tuple[int, M]]: """ Fits a model to the input dataset for each param map in `paramMaps`. @@ -128,12 +154,26 @@ def fitMultiple(self, dataset, paramMaps): """ estimator = self.copy() - def fitSingleModel(index): + def fitSingleModel(index: int) -> M: return estimator.fit(dataset, paramMaps[index]) return _FitMultipleIterator(fitSingleModel, len(paramMaps)) - def fit(self, dataset, params=None): + @overload + def fit(self, dataset: DataFrame, params: Optional["ParamMap"] = ...) -> M: + ... + + @overload + def fit( + self, dataset: DataFrame, params: Union[List["ParamMap"], Tuple["ParamMap"]] + ) -> List[M]: + ... + + def fit( + self, + dataset: DataFrame, + params: Optional[Union["ParamMap", List["ParamMap"], Tuple["ParamMap"]]] = None, + ) -> Union[M, List[M]]: """ Fits a model to the input dataset with optional parameters. @@ -156,10 +196,10 @@ def fit(self, dataset, params=None): if params is None: params = dict() if isinstance(params, (list, tuple)): - models = [None] * len(params) + models: List[Optional[M]] = [None] * len(params) for index, model in self.fitMultiple(dataset, params): models[index] = model - return models + return cast(List[M], models) elif isinstance(params, dict): if params: return self.copy(params)._fit(dataset) @@ -183,7 +223,7 @@ class Transformer(Params, metaclass=ABCMeta): pass @abstractmethod - def _transform(self, dataset): + def _transform(self, dataset: DataFrame) -> DataFrame: """ Transforms the input dataset. @@ -199,7 +239,7 @@ def _transform(self, dataset): """ raise NotImplementedError() - def transform(self, dataset, params=None): + def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = None) -> DataFrame: """ Transforms the input dataset with optional parameters. @@ -248,20 +288,20 @@ class UnaryTransformer(HasInputCol, HasOutputCol, Transformer): .. versionadded:: 2.3.0 """ - def setInputCol(self, value): + def setInputCol(self: P, value: str) -> P: """ Sets the value of :py:attr:`inputCol`. """ return self._set(inputCol=value) - def setOutputCol(self, value): + def setOutputCol(self: P, value: str) -> P: """ Sets the value of :py:attr:`outputCol`. """ return self._set(outputCol=value) @abstractmethod - def createTransformFunc(self): + def createTransformFunc(self) -> Callable[..., Any]: """ Creates the transform function using the given param map. The input param map already takes account of the embedded param map. So the param values should be determined @@ -270,20 +310,20 @@ def createTransformFunc(self): raise NotImplementedError() @abstractmethod - def outputDataType(self): + def outputDataType(self) -> DataType: """ Returns the data type of the output column. """ raise NotImplementedError() @abstractmethod - def validateInputType(self, inputType): + def validateInputType(self, inputType: DataType) -> None: """ Validates the input type. Throw an exception if it is invalid. 
""" raise NotImplementedError() - def transformSchema(self, schema): + def transformSchema(self, schema: StructType) -> StructType: inputType = schema[self.getInputCol()].dataType self.validateInputType(inputType) if self.getOutputCol() in schema.names: @@ -292,7 +332,7 @@ def transformSchema(self, schema): outputFields.append(StructField(self.getOutputCol(), self.outputDataType(), nullable=False)) return StructType(outputFields) - def _transform(self, dataset): + def _transform(self, dataset: DataFrame) -> DataFrame: self.transformSchema(dataset.schema) transformUDF = udf(self.createTransformFunc(), self.outputDataType()) transformedDataset = dataset.withColumn( @@ -313,27 +353,27 @@ class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol): @inherit_doc -class Predictor(Estimator, _PredictorParams, metaclass=ABCMeta): +class Predictor(Estimator[M], _PredictorParams, metaclass=ABCMeta): """ Estimator for prediction tasks (regression and classification). """ @since("3.0.0") - def setLabelCol(self, value): + def setLabelCol(self: P, value: str) -> P: """ Sets the value of :py:attr:`labelCol`. """ return self._set(labelCol=value) @since("3.0.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self: P, value: str) -> P: """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) @since("3.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self: P, value: str) -> P: """ Sets the value of :py:attr:`predictionCol`. """ @@ -341,28 +381,29 @@ def setPredictionCol(self, value): @inherit_doc -class PredictionModel(Model, _PredictorParams, metaclass=ABCMeta): +class PredictionModel(Model, _PredictorParams, Generic[T], metaclass=ABCMeta): """ Model for prediction tasks (regression and classification). """ @since("3.0.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self: P, value: str) -> P: """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) @since("3.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self: P, value: str) -> P: """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) - @abstractproperty + @property # type: ignore[misc] + @abstractmethod @since("2.1.0") - def numFeatures(self): + def numFeatures(self) -> int: """ Returns the number of features the model was trained on. If unknown, returns -1 """ @@ -370,7 +411,7 @@ def numFeatures(self): @abstractmethod @since("3.0.0") - def predict(self, value): + def predict(self, value: T) -> float: """ Predict label for the given features. """ diff --git a/python/pyspark/ml/base.pyi b/python/pyspark/ml/base.pyi deleted file mode 100644 index 37ae6de7ed9a5..0000000000000 --- a/python/pyspark/ml/base.pyi +++ /dev/null @@ -1,103 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import overload -from typing import ( - Callable, - Generic, - Iterable, - List, - Optional, - Sequence, - Tuple, - Union, -) -from pyspark.ml._typing import M, P, T, ParamMap - -import _thread - -import abc -from abc import abstractmethod -from pyspark import since as since # noqa: F401 -from pyspark.ml.common import inherit_doc as inherit_doc # noqa: F401 -from pyspark.ml.param.shared import ( - HasFeaturesCol as HasFeaturesCol, - HasInputCol as HasInputCol, - HasLabelCol as HasLabelCol, - HasOutputCol as HasOutputCol, - HasPredictionCol as HasPredictionCol, - Params as Params, -) -from pyspark.sql.functions import udf as udf # noqa: F401 -from pyspark.sql.types import ( # noqa: F401 - DataType, - StructField as StructField, - StructType as StructType, -) - -from pyspark.sql.dataframe import DataFrame - -class _FitMultipleIterator: - fitSingleModel: Callable[[int], Transformer] - numModel: int - counter: int = ... - lock: _thread.LockType - def __init__(self, fitSingleModel: Callable[[int], Transformer], numModels: int) -> None: ... - def __iter__(self) -> _FitMultipleIterator: ... - def __next__(self) -> Tuple[int, Transformer]: ... - def next(self) -> Tuple[int, Transformer]: ... - -class Estimator(Generic[M], Params, metaclass=abc.ABCMeta): - @overload - def fit(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> M: ... - @overload - def fit( - self, dataset: DataFrame, params: Union[List[ParamMap], Tuple[ParamMap]] - ) -> List[M]: ... - def fitMultiple( - self, dataset: DataFrame, params: Sequence[ParamMap] - ) -> Iterable[Tuple[int, M]]: ... - -class Transformer(Params, metaclass=abc.ABCMeta): - def transform(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> DataFrame: ... - -class Model(Transformer, metaclass=abc.ABCMeta): ... - -class UnaryTransformer(HasInputCol, HasOutputCol, Transformer, metaclass=abc.ABCMeta): - def createTransformFunc(self) -> Callable: ... - def outputDataType(self) -> DataType: ... - def validateInputType(self, inputType: DataType) -> None: ... - def transformSchema(self, schema: StructType) -> StructType: ... - def setInputCol(self: M, value: str) -> M: ... - def setOutputCol(self: M, value: str) -> M: ... - -class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol): ... - -class Predictor(Estimator[M], _PredictorParams, metaclass=abc.ABCMeta): - def setLabelCol(self: P, value: str) -> P: ... - def setFeaturesCol(self: P, value: str) -> P: ... - def setPredictionCol(self: P, value: str) -> P: ... - -class PredictionModel(Generic[T], Model, _PredictorParams, metaclass=abc.ABCMeta): - def setFeaturesCol(self: M, value: str) -> M: ... - def setPredictionCol(self: M, value: str) -> M: ... - @property - @abc.abstractmethod - def numFeatures(self) -> int: ... - @abstractmethod - def predict(self, value: T) -> float: ... 
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index e6ce3e0b9ae89..b791e6f169d44 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -20,7 +20,7 @@ import sys import uuid import warnings -from abc import ABCMeta, abstractmethod, abstractproperty +from abc import ABCMeta, abstractmethod from multiprocessing.pool import ThreadPool from pyspark import keyword_only, since, SparkContext, inheritable_thread_target @@ -155,7 +155,8 @@ def setRawPredictionCol(self, value): """ return self._set(rawPredictionCol=value) - @abstractproperty + @property + @abstractmethod @since("2.1.0") def numClasses(self): """ @@ -998,8 +999,7 @@ def getThreshold(self): raise ValueError( "Logistic Regression getThreshold only applies to" + " binary classification, but thresholds has length != 2." - + " thresholds: " - + ",".join(ts) + + " thresholds: {ts}".format(ts=ts) ) return 1.0 / (1.0 + ts[0] / ts[1]) else: diff --git a/python/pyspark/ml/classification.pyi b/python/pyspark/ml/classification.pyi index bb4fb056a95d0..16c31924defde 100644 --- a/python/pyspark/ml/classification.pyi +++ b/python/pyspark/ml/classification.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, List, Optional, Type +from typing import Any, Generic, List, Optional, Type from pyspark.ml._typing import JM, M, P, T, ParamMap import abc @@ -69,6 +69,8 @@ from pyspark.ml.param import Param from pyspark.ml.regression import DecisionTreeRegressionModel from pyspark.sql.dataframe import DataFrame +from py4j.java_gateway import JavaObject + class _ClassifierParams(HasRawPredictionCol, _PredictorParams): ... class Classifier(Predictor, _ClassifierParams, metaclass=abc.ABCMeta): @@ -96,7 +98,7 @@ class ProbabilisticClassificationModel( @abstractmethod def predictProbability(self, value: Vector) -> Vector: ... -class _JavaClassifier(Classifier, JavaPredictor[JM], metaclass=abc.ABCMeta): +class _JavaClassifier(Classifier, JavaPredictor[JM], Generic[JM], metaclass=abc.ABCMeta): def setRawPredictionCol(self: P, value: str) -> P: ... class _JavaClassificationModel(ClassificationModel, JavaPredictionModel[T]): @@ -105,7 +107,7 @@ class _JavaClassificationModel(ClassificationModel, JavaPredictionModel[T]): def predictRaw(self, value: Vector) -> Vector: ... class _JavaProbabilisticClassifier( - ProbabilisticClassifier, _JavaClassifier[JM], metaclass=abc.ABCMeta + ProbabilisticClassifier, _JavaClassifier[JM], Generic[JM], metaclass=abc.ABCMeta ): ... class _JavaProbabilisticClassificationModel( @@ -231,6 +233,7 @@ class LinearSVC( def setWeightCol(self, value: str) -> LinearSVC: ... def setAggregationDepth(self, value: int) -> LinearSVC: ... def setMaxBlockSizeInMB(self, value: float) -> LinearSVC: ... + def _create_model(self, java_model: JavaObject) -> LinearSVCModel: ... class LinearSVCModel( _JavaClassificationModel[Vector], @@ -350,6 +353,7 @@ class LogisticRegression( def setWeightCol(self, value: str) -> LogisticRegression: ... def setAggregationDepth(self, value: int) -> LogisticRegression: ... def setMaxBlockSizeInMB(self, value: float) -> LogisticRegression: ... + def _create_model(self, java_model: JavaObject) -> LogisticRegressionModel: ... class LogisticRegressionModel( _JavaProbabilisticClassificationModel[Vector], @@ -444,6 +448,7 @@ class DecisionTreeClassifier( def setCheckpointInterval(self, value: int) -> DecisionTreeClassifier: ... 
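The classification hunk above replaces the long-deprecated abstractproperty with the stacked @property/@abstractmethod form. A minimal standalone illustration of the pattern, using illustrative class names unrelated to the patch:

    from abc import ABC, abstractmethod


    class ClassifierBase(ABC):
        @property
        @abstractmethod
        def numClasses(self) -> int:
            """Number of classes; concrete subclasses must provide this."""


    class BinaryClassifier(ClassifierBase):
        @property
        def numClasses(self) -> int:
            return 2


    print(BinaryClassifier().numClasses)  # 2
    # ClassifierBase() would raise TypeError: can't instantiate an abstract class.

Note that @property must be the outermost decorator so the abstractness of the underlying method propagates to the property object.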
def setSeed(self, value: int) -> DecisionTreeClassifier: ... def setWeightCol(self, value: str) -> DecisionTreeClassifier: ... + def _create_model(self, java_model: JavaObject) -> DecisionTreeClassificationModel: ... class DecisionTreeClassificationModel( _DecisionTreeModel, @@ -529,6 +534,7 @@ class RandomForestClassifier( def setCheckpointInterval(self, value: int) -> RandomForestClassifier: ... def setWeightCol(self, value: str) -> RandomForestClassifier: ... def setMinWeightFractionPerNode(self, value: float) -> RandomForestClassifier: ... + def _create_model(self, java_model: JavaObject) -> RandomForestClassificationModel: ... class RandomForestClassificationModel( _TreeEnsembleModel, @@ -633,6 +639,7 @@ class GBTClassifier( def setStepSize(self, value: float) -> GBTClassifier: ... def setWeightCol(self, value: str) -> GBTClassifier: ... def setMinWeightFractionPerNode(self, value: float) -> GBTClassifier: ... + def _create_model(self, java_model: JavaObject) -> GBTClassificationModel: ... class GBTClassificationModel( _TreeEnsembleModel, @@ -691,6 +698,7 @@ class NaiveBayes( def setSmoothing(self, value: float) -> NaiveBayes: ... def setModelType(self, value: str) -> NaiveBayes: ... def setWeightCol(self, value: str) -> NaiveBayes: ... + def _create_model(self, java_model: JavaObject) -> NaiveBayesModel: ... class NaiveBayesModel( _JavaProbabilisticClassificationModel[Vector], @@ -769,6 +777,7 @@ class MultilayerPerceptronClassifier( def setTol(self, value: float) -> MultilayerPerceptronClassifier: ... def setStepSize(self, value: float) -> MultilayerPerceptronClassifier: ... def setSolver(self, value: str) -> MultilayerPerceptronClassifier: ... + def _create_model(self, java_model: JavaObject) -> MultilayerPerceptronClassificationModel: ... class MultilayerPerceptronClassificationModel( _JavaProbabilisticClassificationModel[Vector], @@ -820,6 +829,7 @@ class OneVsRest( weightCol: Optional[str] = ..., parallelism: int = ..., ) -> OneVsRest: ... + def _fit(self, dataset: DataFrame) -> OneVsRestModel: ... def setClassifier(self, value: Estimator[M]) -> OneVsRest: ... def setLabelCol(self, value: str) -> OneVsRest: ... def setFeaturesCol(self, value: str) -> OneVsRest: ... @@ -832,6 +842,7 @@ class OneVsRest( class OneVsRestModel(Model, _OneVsRestParams, MLReadable[OneVsRestModel], MLWritable): models: List[Transformer] def __init__(self, models: List[Transformer]) -> None: ... + def _transform(self, dataset: DataFrame) -> DataFrame: ... def setFeaturesCol(self, value: str) -> OneVsRestModel: ... def setPredictionCol(self, value: str) -> OneVsRestModel: ... def setRawPredictionCol(self, value: str) -> OneVsRestModel: ... @@ -919,6 +930,7 @@ class FMClassifier( def setSeed(self, value: int) -> FMClassifier: ... def setFitIntercept(self, value: bool) -> FMClassifier: ... def setRegParam(self, value: float) -> FMClassifier: ... + def _create_model(self, java_model: JavaObject) -> FMClassificationModel: ... 
class FMClassificationModel( _JavaProbabilisticClassificationModel[Vector], diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 11fbdf5cf9246..9d2384ffe35b4 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -18,6 +18,10 @@ import sys import warnings +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +import numpy as np + from pyspark import since, keyword_only from pyspark.ml.param.shared import ( HasMaxIter, @@ -45,6 +49,12 @@ from pyspark.ml.common import inherit_doc, _java2py from pyspark.ml.stat import MultivariateGaussian from pyspark.sql import DataFrame +from pyspark.ml.linalg import Vector, Matrix + +if TYPE_CHECKING: + from pyspark.ml._typing import M + from py4j.java_gateway import JavaObject + __all__ = [ "BisectingKMeans", @@ -71,57 +81,57 @@ class ClusteringSummary(JavaWrapper): .. versionadded:: 2.1.0 """ - @property + @property # type: ignore[misc] @since("2.1.0") - def predictionCol(self): + def predictionCol(self) -> str: """ Name for column of predicted clusters in `predictions`. """ return self._call_java("predictionCol") - @property + @property # type: ignore[misc] @since("2.1.0") - def predictions(self): + def predictions(self) -> DataFrame: """ DataFrame produced by the model's `transform` method. """ return self._call_java("predictions") - @property + @property # type: ignore[misc] @since("2.1.0") - def featuresCol(self): + def featuresCol(self) -> str: """ Name for column of features in `predictions`. """ return self._call_java("featuresCol") - @property + @property # type: ignore[misc] @since("2.1.0") - def k(self): + def k(self) -> int: """ The number of clusters the model was trained with. """ return self._call_java("k") - @property + @property # type: ignore[misc] @since("2.1.0") - def cluster(self): + def cluster(self) -> DataFrame: """ DataFrame of predicted cluster centers for each training data point. """ return self._call_java("cluster") - @property + @property # type: ignore[misc] @since("2.1.0") - def clusterSizes(self): + def clusterSizes(self) -> List[int]: """ Size of (number of data points in) each cluster. """ return self._call_java("clusterSizes") - @property + @property # type: ignore[misc] @since("2.4.0") - def numIter(self): + def numIter(self) -> int: """ Number of iterations. """ @@ -145,19 +155,19 @@ class _GaussianMixtureParams( .. versionadded:: 3.0.0 """ - k = Param( + k: Param[int] = Param( Params._dummy(), "k", "Number of independent Gaussians in the mixture model. " + "Must be > 1.", typeConverter=TypeConverters.toInt, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_GaussianMixtureParams, self).__init__(*args) self._setDefault(k=2, tol=0.01, maxIter=100, aggregationDepth=2) @since("2.0.0") - def getK(self): + def getK(self) -> int: """ Gets the value of `k` """ @@ -165,7 +175,11 @@ def getK(self): class GaussianMixtureModel( - JavaModel, _GaussianMixtureParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary + JavaModel, + _GaussianMixtureParams, + JavaMLWritable, + JavaMLReadable["GaussianMixtureModel"], + HasTrainingSummary["GaussianMixtureSummary"], ): """ Model fitted by GaussianMixture. @@ -174,29 +188,29 @@ class GaussianMixtureModel( """ @since("3.0.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self, value: str) -> "GaussianMixtureModel": """ Sets the value of :py:attr:`featuresCol`. 
""" return self._set(featuresCol=value) @since("3.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "GaussianMixtureModel": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) @since("3.0.0") - def setProbabilityCol(self, value): + def setProbabilityCol(self, value: str) -> "GaussianMixtureModel": """ Sets the value of :py:attr:`probabilityCol`. """ return self._set(probabilityCol=value) - @property + @property # type: ignore[misc] @since("2.0.0") - def weights(self): + def weights(self) -> List[float]: """ Weight for each Gaussian distribution in the mixture. This is a multinomial probability distribution over the k Gaussians, @@ -204,23 +218,25 @@ def weights(self): """ return self._call_java("weights") - @property + @property # type: ignore[misc] @since("3.0.0") - def gaussians(self): + def gaussians(self) -> List[MultivariateGaussian]: """ Array of :py:class:`MultivariateGaussian` where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i """ sc = SparkContext._active_spark_context + assert sc is not None and self._java_obj is not None + jgaussians = self._java_obj.gaussians() return [ MultivariateGaussian(_java2py(sc, jgaussian.mean()), _java2py(sc, jgaussian.cov())) for jgaussian in jgaussians ] - @property + @property # type: ignore[misc] @since("2.0.0") - def gaussiansDF(self): + def gaussiansDF(self) -> DataFrame: """ Retrieve Gaussian distributions as a DataFrame. Each row represents a Gaussian Distribution. @@ -228,9 +244,9 @@ def gaussiansDF(self): """ return self._call_java("gaussiansDF") - @property + @property # type: ignore[misc] @since("2.1.0") - def summary(self): + def summary(self) -> "GaussianMixtureSummary": """ Gets summary (cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists. @@ -243,14 +259,14 @@ def summary(self): ) @since("3.0.0") - def predict(self, value): + def predict(self, value: Vector) -> int: """ Predict label for the given features. """ return self._call_java("predict", value) @since("3.0.0") - def predictProbability(self, value): + def predictProbability(self, value: Vector) -> Vector: """ Predict probability for the given features. """ @@ -258,7 +274,12 @@ def predictProbability(self, value): @inherit_doc -class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, JavaMLReadable): +class GaussianMixture( + JavaEstimator[GaussianMixtureModel], + _GaussianMixtureParams, + JavaMLWritable, + JavaMLReadable["GaussianMixture"], +): """ GaussianMixture clustering. This class performs expectation maximization for multivariate Gaussian @@ -379,19 +400,21 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav GaussianMixture... 
""" + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - predictionCol="prediction", - k=2, - probabilityCol="probability", - tol=0.01, - maxIter=100, - seed=None, - aggregationDepth=2, - weightCol=None, + featuresCol: str = "features", + predictionCol: str = "prediction", + k: int = 2, + probabilityCol: str = "probability", + tol: float = 0.01, + maxIter: int = 100, + seed: Optional[int] = None, + aggregationDepth: int = 2, + weightCol: Optional[str] = None, ): """ __init__(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \ @@ -405,7 +428,7 @@ def __init__( kwargs = self._input_kwargs self.setParams(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> "GaussianMixtureModel": return GaussianMixtureModel(java_model) @keyword_only @@ -413,16 +436,16 @@ def _create_model(self, java_model): def setParams( self, *, - featuresCol="features", - predictionCol="prediction", - k=2, - probabilityCol="probability", - tol=0.01, - maxIter=100, - seed=None, - aggregationDepth=2, - weightCol=None, - ): + featuresCol: str = "features", + predictionCol: str = "prediction", + k: int = 2, + probabilityCol: str = "probability", + tol: float = 0.01, + maxIter: int = 100, + seed: Optional[int] = None, + aggregationDepth: int = 2, + weightCol: Optional[str] = None, + ) -> "GaussianMixture": """ setParams(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \ probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \ @@ -434,63 +457,63 @@ def setParams( return self._set(**kwargs) @since("2.0.0") - def setK(self, value): + def setK(self, value: int) -> "GaussianMixture": """ Sets the value of :py:attr:`k`. """ return self._set(k=value) @since("2.0.0") - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "GaussianMixture": """ Sets the value of :py:attr:`maxIter`. """ return self._set(maxIter=value) @since("2.0.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self, value: str) -> "GaussianMixture": """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) @since("2.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "GaussianMixture": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) @since("2.0.0") - def setProbabilityCol(self, value): + def setProbabilityCol(self, value: str) -> "GaussianMixture": """ Sets the value of :py:attr:`probabilityCol`. """ return self._set(probabilityCol=value) @since("3.0.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "GaussianMixture": """ Sets the value of :py:attr:`weightCol`. """ return self._set(weightCol=value) @since("2.0.0") - def setSeed(self, value): + def setSeed(self, value: int) -> "GaussianMixture": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) @since("2.0.0") - def setTol(self, value): + def setTol(self, value: float) -> "GaussianMixture": """ Sets the value of :py:attr:`tol`. """ return self._set(tol=value) @since("3.0.0") - def setAggregationDepth(self, value): + def setAggregationDepth(self, value: int) -> "GaussianMixture": """ Sets the value of :py:attr:`aggregationDepth`. """ @@ -504,25 +527,25 @@ class GaussianMixtureSummary(ClusteringSummary): .. 
versionadded:: 2.1.0 """ - @property + @property # type: ignore[misc] @since("2.1.0") - def probabilityCol(self): + def probabilityCol(self) -> str: """ Name for column of predicted probability of each cluster in `predictions`. """ return self._call_java("probabilityCol") - @property + @property # type: ignore[misc] @since("2.1.0") - def probability(self): + def probability(self) -> DataFrame: """ DataFrame of probabilities of each cluster for each training data point. """ return self._call_java("probability") - @property + @property # type: ignore[misc] @since("2.2.0") - def logLikelihood(self): + def logLikelihood(self) -> float: """ Total log-likelihood for this model on the given data. """ @@ -536,9 +559,9 @@ class KMeansSummary(ClusteringSummary): .. versionadded:: 2.1.0 """ - @property + @property # type: ignore[misc] @since("2.4.0") - def trainingCost(self): + def trainingCost(self) -> float: """ K-means cost (sum of squared distances to the nearest centroid for all points in the training dataset). This is equivalent to sklearn's inertia. @@ -556,13 +579,13 @@ class _KMeansParams( .. versionadded:: 3.0.0 """ - k = Param( + k: Param[int] = Param( Params._dummy(), "k", "The number of clusters to create. Must be > 1.", typeConverter=TypeConverters.toInt, ) - initMode = Param( + initMode: Param[str] = Param( Params._dummy(), "initMode", 'The initialization algorithm. This can be either "random" to ' @@ -570,14 +593,14 @@ class _KMeansParams( + "to use a parallel variant of k-means++", typeConverter=TypeConverters.toString, ) - initSteps = Param( + initSteps: Param[int] = Param( Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_KMeansParams, self).__init__(*args) self._setDefault( k=2, @@ -589,21 +612,21 @@ def __init__(self, *args): ) @since("1.5.0") - def getK(self): + def getK(self) -> int: """ Gets the value of `k` """ return self.getOrDefault(self.k) @since("1.5.0") - def getInitMode(self): + def getInitMode(self) -> str: """ Gets the value of `initMode` """ return self.getOrDefault(self.initMode) @since("1.5.0") - def getInitSteps(self): + def getInitSteps(self) -> int: """ Gets the value of `initSteps` """ @@ -611,7 +634,11 @@ def getInitSteps(self): class KMeansModel( - JavaModel, _KMeansParams, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary + JavaModel, + _KMeansParams, + GeneralJavaMLWritable, + JavaMLReadable["KMeansModel"], + HasTrainingSummary["KMeansSummary"], ): """ Model fitted by KMeans. @@ -620,27 +647,27 @@ class KMeansModel( """ @since("3.0.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self, value: str) -> "KMeansModel": """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) @since("3.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "KMeansModel": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) @since("1.5.0") - def clusterCenters(self): + def clusterCenters(self) -> List[np.ndarray]: """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] - @property + @property # type: ignore[misc] @since("2.1.0") - def summary(self): + def summary(self) -> KMeansSummary: """ Gets summary (cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists. 
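With the annotations above, KMeansModel.clusterCenters() is declared to return a list of NumPy arrays and predict() to map a Vector to an int. A short usage sketch of the typed API (illustrative only, not part of this patch):

    from pyspark.sql import SparkSession
    from pyspark.ml.clustering import KMeans
    from pyspark.ml.linalg import Vectors

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
         (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)],
        ["features"],
    )
    model = KMeans(k=2, seed=1).fit(df)                 # KMeansModel
    centers = model.clusterCenters()                    # List[np.ndarray]
    label = model.predict(Vectors.dense([0.5, 0.5]))    # int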
@@ -653,7 +680,7 @@ def summary(self): ) @since("3.0.0") - def predict(self, value): + def predict(self, value: Vector) -> int: """ Predict label for the given features. """ @@ -661,7 +688,7 @@ def predict(self, value): @inherit_doc -class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable): +class KMeans(JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable, JavaMLReadable["KMeans"]): """ K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). @@ -727,20 +754,22 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable): True """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - predictionCol="prediction", - k=2, - initMode="k-means||", - initSteps=2, - tol=1e-4, - maxIter=20, - seed=None, - distanceMeasure="euclidean", - weightCol=None, + featuresCol: str = "features", + predictionCol: str = "prediction", + k: int = 2, + initMode: str = "k-means||", + initSteps: int = 2, + tol: float = 1e-4, + maxIter: int = 20, + seed: Optional[int] = None, + distanceMeasure: str = "euclidean", + weightCol: Optional[str] = None, ): """ __init__(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \ @@ -752,7 +781,7 @@ def __init__( kwargs = self._input_kwargs self.setParams(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> KMeansModel: return KMeansModel(java_model) @keyword_only @@ -760,17 +789,17 @@ def _create_model(self, java_model): def setParams( self, *, - featuresCol="features", - predictionCol="prediction", - k=2, - initMode="k-means||", - initSteps=2, - tol=1e-4, - maxIter=20, - seed=None, - distanceMeasure="euclidean", - weightCol=None, - ): + featuresCol: str = "features", + predictionCol: str = "prediction", + k: int = 2, + initMode: str = "k-means||", + initSteps: int = 2, + tol: float = 1e-4, + maxIter: int = 20, + seed: Optional[int] = None, + distanceMeasure: str = "euclidean", + weightCol: Optional[str] = None, + ) -> "KMeans": """ setParams(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \ initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ @@ -782,70 +811,70 @@ def setParams( return self._set(**kwargs) @since("1.5.0") - def setK(self, value): + def setK(self, value: int) -> "KMeans": """ Sets the value of :py:attr:`k`. """ return self._set(k=value) @since("1.5.0") - def setInitMode(self, value): + def setInitMode(self, value: str) -> "KMeans": """ Sets the value of :py:attr:`initMode`. """ return self._set(initMode=value) @since("1.5.0") - def setInitSteps(self, value): + def setInitSteps(self, value: int) -> "KMeans": """ Sets the value of :py:attr:`initSteps`. """ return self._set(initSteps=value) @since("2.4.0") - def setDistanceMeasure(self, value): + def setDistanceMeasure(self, value: str) -> "KMeans": """ Sets the value of :py:attr:`distanceMeasure`. """ return self._set(distanceMeasure=value) @since("1.5.0") - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "KMeans": """ Sets the value of :py:attr:`maxIter`. """ return self._set(maxIter=value) @since("1.5.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self, value: str) -> "KMeans": """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) @since("1.5.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "KMeans": """ Sets the value of :py:attr:`predictionCol`. 
""" return self._set(predictionCol=value) @since("1.5.0") - def setSeed(self, value): + def setSeed(self, value: int) -> "KMeans": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) @since("1.5.0") - def setTol(self, value): + def setTol(self, value: float) -> "KMeans": """ Sets the value of :py:attr:`tol`. """ return self._set(tol=value) @since("3.0.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "KMeans": """ Sets the value of :py:attr:`weightCol`. """ @@ -854,7 +883,12 @@ def setWeightCol(self, value): @inherit_doc class _BisectingKMeansParams( - HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, HasDistanceMeasure, HasWeightCol + HasMaxIter, + HasFeaturesCol, + HasSeed, + HasPredictionCol, + HasDistanceMeasure, + HasWeightCol, ): """ Params for :py:class:`BisectingKMeans` and :py:class:`BisectingKMeansModel`. @@ -862,13 +896,13 @@ class _BisectingKMeansParams( .. versionadded:: 3.0.0 """ - k = Param( + k: Param[int] = Param( Params._dummy(), "k", "The desired number of leaf clusters. Must be > 1.", typeConverter=TypeConverters.toInt, ) - minDivisibleClusterSize = Param( + minDivisibleClusterSize: Param[float] = Param( Params._dummy(), "minDivisibleClusterSize", "The minimum number of points (if >= 1.0) or the minimum " @@ -876,19 +910,19 @@ class _BisectingKMeansParams( typeConverter=TypeConverters.toFloat, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_BisectingKMeansParams, self).__init__(*args) self._setDefault(maxIter=20, k=4, minDivisibleClusterSize=1.0) @since("2.0.0") - def getK(self): + def getK(self) -> int: """ Gets the value of `k` or its default value. """ return self.getOrDefault(self.k) @since("2.0.0") - def getMinDivisibleClusterSize(self): + def getMinDivisibleClusterSize(self) -> float: """ Gets the value of `minDivisibleClusterSize` or its default value. """ @@ -896,7 +930,11 @@ def getMinDivisibleClusterSize(self): class BisectingKMeansModel( - JavaModel, _BisectingKMeansParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary + JavaModel, + _BisectingKMeansParams, + JavaMLWritable, + JavaMLReadable["BisectingKMeansModel"], + HasTrainingSummary["BisectingKMeansSummary"], ): """ Model fitted by BisectingKMeans. @@ -905,26 +943,26 @@ class BisectingKMeansModel( """ @since("3.0.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self, value: str) -> "BisectingKMeansModel": """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) @since("3.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "BisectingKMeansModel": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) @since("2.0.0") - def clusterCenters(self): + def clusterCenters(self) -> List[np.ndarray]: """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] @since("2.0.0") - def computeCost(self, dataset): + def computeCost(self, dataset: DataFrame) -> float: """ Computes the sum of squared distances between the input points and their corresponding cluster centers. @@ -941,9 +979,9 @@ def computeCost(self, dataset): ) return self._call_java("computeCost", dataset) - @property + @property # type: ignore[misc] @since("2.1.0") - def summary(self): + def summary(self) -> "BisectingKMeansSummary": """ Gets summary (cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists. 
@@ -956,7 +994,7 @@ def summary(self): ) @since("3.0.0") - def predict(self, value): + def predict(self, value: Vector) -> int: """ Predict label for the given features. """ @@ -964,7 +1002,12 @@ def predict(self, value): @inherit_doc -class BisectingKMeans(JavaEstimator, _BisectingKMeansParams, JavaMLWritable, JavaMLReadable): +class BisectingKMeans( + JavaEstimator[BisectingKMeansModel], + _BisectingKMeansParams, + JavaMLWritable, + JavaMLReadable["BisectingKMeans"], +): """ A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark. @@ -1043,18 +1086,20 @@ class BisectingKMeans(JavaEstimator, _BisectingKMeansParams, JavaMLWritable, Jav True """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - predictionCol="prediction", - maxIter=20, - seed=None, - k=4, - minDivisibleClusterSize=1.0, - distanceMeasure="euclidean", - weightCol=None, + featuresCol: str = "features", + predictionCol: str = "prediction", + maxIter: int = 20, + seed: Optional[int] = None, + k: int = 4, + minDivisibleClusterSize: float = 1.0, + distanceMeasure: str = "euclidean", + weightCol: Optional[str] = None, ): """ __init__(self, \\*, featuresCol="features", predictionCol="prediction", maxIter=20, \ @@ -1073,15 +1118,15 @@ def __init__( def setParams( self, *, - featuresCol="features", - predictionCol="prediction", - maxIter=20, - seed=None, - k=4, - minDivisibleClusterSize=1.0, - distanceMeasure="euclidean", - weightCol=None, - ): + featuresCol: str = "features", + predictionCol: str = "prediction", + maxIter: int = 20, + seed: Optional[int] = None, + k: int = 4, + minDivisibleClusterSize: float = 1.0, + distanceMeasure: str = "euclidean", + weightCol: Optional[str] = None, + ) -> "BisectingKMeans": """ setParams(self, \\*, featuresCol="features", predictionCol="prediction", maxIter=20, \ seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", \ @@ -1092,62 +1137,62 @@ def setParams( return self._set(**kwargs) @since("2.0.0") - def setK(self, value): + def setK(self, value: int) -> "BisectingKMeans": """ Sets the value of :py:attr:`k`. """ return self._set(k=value) @since("2.0.0") - def setMinDivisibleClusterSize(self, value): + def setMinDivisibleClusterSize(self, value: float) -> "BisectingKMeans": """ Sets the value of :py:attr:`minDivisibleClusterSize`. """ return self._set(minDivisibleClusterSize=value) @since("2.4.0") - def setDistanceMeasure(self, value): + def setDistanceMeasure(self, value: str) -> "BisectingKMeans": """ Sets the value of :py:attr:`distanceMeasure`. """ return self._set(distanceMeasure=value) @since("2.0.0") - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "BisectingKMeans": """ Sets the value of :py:attr:`maxIter`. """ return self._set(maxIter=value) @since("2.0.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self, value: str) -> "BisectingKMeans": """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) @since("2.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "BisectingKMeans": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) @since("2.0.0") - def setSeed(self, value): + def setSeed(self, value: int) -> "BisectingKMeans": """ Sets the value of :py:attr:`seed`. 
""" return self._set(seed=value) @since("3.0.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "BisectingKMeans": """ Sets the value of :py:attr:`weightCol`. """ return self._set(weightCol=value) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> BisectingKMeansModel: return BisectingKMeansModel(java_model) @@ -1158,9 +1203,9 @@ class BisectingKMeansSummary(ClusteringSummary): .. versionadded:: 2.1.0 """ - @property + @property # type: ignore[misc] @since("3.0.0") - def trainingCost(self): + def trainingCost(self) -> float: """ Sum of squared distances to the nearest centroid for all points in the training dataset. This is equivalent to sklearn's inertia. @@ -1176,27 +1221,27 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval): .. versionadded:: 3.0.0 """ - k = Param( + k: Param[int] = Param( Params._dummy(), "k", "The number of topics (clusters) to infer. Must be > 1.", typeConverter=TypeConverters.toInt, ) - optimizer = Param( + optimizer: Param[str] = Param( Params._dummy(), "optimizer", "Optimizer or inference algorithm used to estimate the LDA model. " "Supported: online, em", typeConverter=TypeConverters.toString, ) - learningOffset = Param( + learningOffset: Param[float] = Param( Params._dummy(), "learningOffset", "A (positive) learning parameter that downweights early iterations." " Larger values make early iterations count less", typeConverter=TypeConverters.toFloat, ) - learningDecay = Param( + learningDecay: Param[float] = Param( Params._dummy(), "learningDecay", "Learning rate, set as an" @@ -1204,14 +1249,14 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval): "guarantee asymptotic convergence.", typeConverter=TypeConverters.toFloat, ) - subsamplingRate = Param( + subsamplingRate: Param[float] = Param( Params._dummy(), "subsamplingRate", "Fraction of the corpus to be sampled and used in each iteration " "of mini-batch gradient descent, in range (0, 1].", typeConverter=TypeConverters.toFloat, ) - optimizeDocConcentration = Param( + optimizeDocConcentration: Param[bool] = Param( Params._dummy(), "optimizeDocConcentration", "Indicates whether the docConcentration (Dirichlet parameter " @@ -1219,21 +1264,21 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval): "training.", typeConverter=TypeConverters.toBoolean, ) - docConcentration = Param( + docConcentration: Param[List[float]] = Param( Params._dummy(), "docConcentration", 'Concentration parameter (commonly named "alpha") for the ' 'prior placed on documents\' distributions over topics ("theta").', typeConverter=TypeConverters.toListFloat, ) - topicConcentration = Param( + topicConcentration: Param[float] = Param( Params._dummy(), "topicConcentration", 'Concentration parameter (commonly named "beta" or "eta") for ' "the prior placed on topic' distributions over terms.", typeConverter=TypeConverters.toFloat, ) - topicDistributionCol = Param( + topicDistributionCol: Param[str] = Param( Params._dummy(), "topicDistributionCol", "Output column with estimates of the topic mixture distribution " @@ -1241,7 +1286,7 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval): "Returns a vector of zeros for an empty document.", typeConverter=TypeConverters.toString, ) - keepLastCheckpoint = Param( + keepLastCheckpoint: Param[bool] = Param( Params._dummy(), "keepLastCheckpoint", "(For EM optimizer) If using checkpointing, this indicates whether" @@ -1251,7 +1296,7 @@ 
class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval): TypeConverters.toBoolean, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_LDAParams, self).__init__(*args) self._setDefault( maxIter=20, @@ -1267,70 +1312,70 @@ def __init__(self, *args): ) @since("2.0.0") - def getK(self): + def getK(self) -> int: """ Gets the value of :py:attr:`k` or its default value. """ return self.getOrDefault(self.k) @since("2.0.0") - def getOptimizer(self): + def getOptimizer(self) -> str: """ Gets the value of :py:attr:`optimizer` or its default value. """ return self.getOrDefault(self.optimizer) @since("2.0.0") - def getLearningOffset(self): + def getLearningOffset(self) -> float: """ Gets the value of :py:attr:`learningOffset` or its default value. """ return self.getOrDefault(self.learningOffset) @since("2.0.0") - def getLearningDecay(self): + def getLearningDecay(self) -> float: """ Gets the value of :py:attr:`learningDecay` or its default value. """ return self.getOrDefault(self.learningDecay) @since("2.0.0") - def getSubsamplingRate(self): + def getSubsamplingRate(self) -> float: """ Gets the value of :py:attr:`subsamplingRate` or its default value. """ return self.getOrDefault(self.subsamplingRate) @since("2.0.0") - def getOptimizeDocConcentration(self): + def getOptimizeDocConcentration(self) -> bool: """ Gets the value of :py:attr:`optimizeDocConcentration` or its default value. """ return self.getOrDefault(self.optimizeDocConcentration) @since("2.0.0") - def getDocConcentration(self): + def getDocConcentration(self) -> List[float]: """ Gets the value of :py:attr:`docConcentration` or its default value. """ return self.getOrDefault(self.docConcentration) @since("2.0.0") - def getTopicConcentration(self): + def getTopicConcentration(self) -> float: """ Gets the value of :py:attr:`topicConcentration` or its default value. """ return self.getOrDefault(self.topicConcentration) @since("2.0.0") - def getTopicDistributionCol(self): + def getTopicDistributionCol(self) -> str: """ Gets the value of :py:attr:`topicDistributionCol` or its default value. """ return self.getOrDefault(self.topicDistributionCol) @since("2.0.0") - def getKeepLastCheckpoint(self): + def getKeepLastCheckpoint(self) -> bool: """ Gets the value of :py:attr:`keepLastCheckpoint` or its default value. """ @@ -1348,40 +1393,40 @@ class LDAModel(JavaModel, _LDAParams): """ @since("3.0.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self: "M", value: str) -> "M": """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) @since("3.0.0") - def setSeed(self, value): + def setSeed(self: "M", value: int) -> "M": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) @since("3.0.0") - def setTopicDistributionCol(self, value): + def setTopicDistributionCol(self: "M", value: str) -> "M": """ Sets the value of :py:attr:`topicDistributionCol`. """ return self._set(topicDistributionCol=value) @since("2.0.0") - def isDistributed(self): + def isDistributed(self) -> bool: """ Indicates whether this instance is of type DistributedLDAModel """ return self._call_java("isDistributed") @since("2.0.0") - def vocabSize(self): + def vocabSize(self) -> int: """Vocabulary size (number of terms or words in the vocabulary)""" return self._call_java("vocabSize") @since("2.0.0") - def topicsMatrix(self): + def topicsMatrix(self) -> Matrix: """ Inferred topics, where each topic is represented by a distribution over terms. 
This is a matrix of size vocabSize x k, where each column is a topic. @@ -1395,7 +1440,7 @@ def topicsMatrix(self): return self._call_java("topicsMatrix") @since("2.0.0") - def logLikelihood(self, dataset): + def logLikelihood(self, dataset: DataFrame) -> float: """ Calculates a lower bound on the log likelihood of the entire corpus. See Equation (16) in the Online LDA paper (Hoffman et al., 2010). @@ -1407,7 +1452,7 @@ def logLikelihood(self, dataset): return self._call_java("logLikelihood", dataset) @since("2.0.0") - def logPerplexity(self, dataset): + def logPerplexity(self, dataset: DataFrame) -> float: """ Calculate an upper bound on perplexity. (Lower is better.) See Equation (16) in the Online LDA paper (Hoffman et al., 2010). @@ -1419,14 +1464,14 @@ def logPerplexity(self, dataset): return self._call_java("logPerplexity", dataset) @since("2.0.0") - def describeTopics(self, maxTermsPerTopic=10): + def describeTopics(self, maxTermsPerTopic: int = 10) -> DataFrame: """ Return the topics described by their top-weighted terms. """ return self._call_java("describeTopics", maxTermsPerTopic) @since("2.0.0") - def estimatedDocConcentration(self): + def estimatedDocConcentration(self) -> Vector: """ Value for :py:attr:`LDA.docConcentration` estimated from data. If Online LDA was used and :py:attr:`LDA.optimizeDocConcentration` was set to false, @@ -1436,7 +1481,7 @@ def estimatedDocConcentration(self): @inherit_doc -class DistributedLDAModel(LDAModel, JavaMLReadable, JavaMLWritable): +class DistributedLDAModel(LDAModel, JavaMLReadable["DistributedLDAModel"], JavaMLWritable): """ Distributed model fitted by :py:class:`LDA`. This type of model is currently only produced by Expectation-Maximization (EM). @@ -1448,7 +1493,7 @@ class DistributedLDAModel(LDAModel, JavaMLReadable, JavaMLWritable): """ @since("2.0.0") - def toLocal(self): + def toLocal(self) -> "LocalLDAModel": """ Convert this distributed model to a local representation. This discards info about the training dataset. @@ -1464,7 +1509,7 @@ def toLocal(self): return model @since("2.0.0") - def trainingLogLikelihood(self): + def trainingLogLikelihood(self) -> float: """ Log likelihood of the observed tokens in the training set, given the current parameter estimates: @@ -1482,14 +1527,14 @@ def trainingLogLikelihood(self): return self._call_java("trainingLogLikelihood") @since("2.0.0") - def logPrior(self): + def logPrior(self) -> float: """ Log probability of the current parameter estimate: log P(topics, topic distributions for docs | alpha, eta) """ return self._call_java("logPrior") - def getCheckpointFiles(self): + def getCheckpointFiles(self) -> List[str]: """ If using checkpointing and :py:attr:`LDA.keepLastCheckpoint` is set to true, then there may be saved checkpoint files. This method is provided so that users can manage those files. @@ -1511,7 +1556,7 @@ def getCheckpointFiles(self): @inherit_doc -class LocalLDAModel(LDAModel, JavaMLReadable, JavaMLWritable): +class LocalLDAModel(LDAModel, JavaMLReadable["LocalLDAModel"], JavaMLWritable): """ Local (non-distributed) model fitted by :py:class:`LDA`. This model stores the inferred topics only; it does not store info about the training dataset. @@ -1523,7 +1568,7 @@ class LocalLDAModel(LDAModel, JavaMLReadable, JavaMLWritable): @inherit_doc -class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable): +class LDA(JavaEstimator[LDAModel], _LDAParams, JavaMLReadable["LDA"], JavaMLWritable): """ Latent Dirichlet Allocation (LDA), a topic model designed for text documents. 
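The hunks above type the two concrete LDA models: DistributedLDAModel (produced by the EM optimizer and convertible with toLocal()) and LocalLDAModel (produced by the default online optimizer). A rough sketch of how that surfaces in user code (illustrative only, not part of this patch):

    from pyspark.sql import SparkSession
    from pyspark.ml.clustering import LDA, DistributedLDAModel
    from pyspark.ml.linalg import Vectors

    spark = SparkSession.builder.getOrCreate()
    corpus = spark.createDataFrame(
        [(0, Vectors.dense([1.0, 2.0, 0.0])), (1, Vectors.dense([0.0, 1.0, 3.0]))],
        ["id", "features"],
    )
    em_model = LDA(k=2, maxIter=5, optimizer="em").fit(corpus)
    assert isinstance(em_model, DistributedLDAModel)
    local = em_model.toLocal()                          # typed as LocalLDAModel above
    topics = local.describeTopics(maxTermsPerTopic=3)   # DataFrame of top terms per topic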
@@ -1593,24 +1638,26 @@ class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable): True """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - maxIter=20, - seed=None, - checkpointInterval=10, - k=10, - optimizer="online", - learningOffset=1024.0, - learningDecay=0.51, - subsamplingRate=0.05, - optimizeDocConcentration=True, - docConcentration=None, - topicConcentration=None, - topicDistributionCol="topicDistribution", - keepLastCheckpoint=True, + featuresCol: str = "features", + maxIter: int = 20, + seed: Optional[int] = None, + checkpointInterval: int = 10, + k: int = 10, + optimizer: str = "online", + learningOffset: float = 1024.0, + learningDecay: float = 0.51, + subsamplingRate: float = 0.05, + optimizeDocConcentration: bool = True, + docConcentration: Optional[List[float]] = None, + topicConcentration: Optional[float] = None, + topicDistributionCol: str = "topicDistribution", + keepLastCheckpoint: bool = True, ): """ __init__(self, \\*, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\ @@ -1624,7 +1671,7 @@ def __init__( kwargs = self._input_kwargs self.setParams(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> LDAModel: if self.getOptimizer() == "em": return DistributedLDAModel(java_model) else: @@ -1635,21 +1682,21 @@ def _create_model(self, java_model): def setParams( self, *, - featuresCol="features", - maxIter=20, - seed=None, - checkpointInterval=10, - k=10, - optimizer="online", - learningOffset=1024.0, - learningDecay=0.51, - subsamplingRate=0.05, - optimizeDocConcentration=True, - docConcentration=None, - topicConcentration=None, - topicDistributionCol="topicDistribution", - keepLastCheckpoint=True, - ): + featuresCol: str = "features", + maxIter: int = 20, + seed: Optional[int] = None, + checkpointInterval: int = 10, + k: int = 10, + optimizer: str = "online", + learningOffset: float = 1024.0, + learningDecay: float = 0.51, + subsamplingRate: float = 0.05, + optimizeDocConcentration: bool = True, + docConcentration: Optional[List[float]] = None, + topicConcentration: Optional[float] = None, + topicDistributionCol: str = "topicDistribution", + keepLastCheckpoint: bool = True, + ) -> "LDA": """ setParams(self, \\*, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\ k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\ @@ -1663,21 +1710,21 @@ def setParams( return self._set(**kwargs) @since("2.0.0") - def setCheckpointInterval(self, value): + def setCheckpointInterval(self, value: int) -> "LDA": """ Sets the value of :py:attr:`checkpointInterval`. """ return self._set(checkpointInterval=value) @since("2.0.0") - def setSeed(self, value): + def setSeed(self, value: int) -> "LDA": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) @since("2.0.0") - def setK(self, value): + def setK(self, value: int) -> "LDA": """ Sets the value of :py:attr:`k`. @@ -1688,7 +1735,7 @@ def setK(self, value): return self._set(k=value) @since("2.0.0") - def setOptimizer(self, value): + def setOptimizer(self, value: str) -> "LDA": """ Sets the value of :py:attr:`optimizer`. Currently only support 'em' and 'online'. @@ -1702,7 +1749,7 @@ def setOptimizer(self, value): return self._set(optimizer=value) @since("2.0.0") - def setLearningOffset(self, value): + def setLearningOffset(self, value: float) -> "LDA": """ Sets the value of :py:attr:`learningOffset`. 
@@ -1715,7 +1762,7 @@ def setLearningOffset(self, value): return self._set(learningOffset=value) @since("2.0.0") - def setLearningDecay(self, value): + def setLearningDecay(self, value: float) -> "LDA": """ Sets the value of :py:attr:`learningDecay`. @@ -1728,7 +1775,7 @@ def setLearningDecay(self, value): return self._set(learningDecay=value) @since("2.0.0") - def setSubsamplingRate(self, value): + def setSubsamplingRate(self, value: float) -> "LDA": """ Sets the value of :py:attr:`subsamplingRate`. @@ -1741,7 +1788,7 @@ def setSubsamplingRate(self, value): return self._set(subsamplingRate=value) @since("2.0.0") - def setOptimizeDocConcentration(self, value): + def setOptimizeDocConcentration(self, value: bool) -> "LDA": """ Sets the value of :py:attr:`optimizeDocConcentration`. @@ -1754,7 +1801,7 @@ def setOptimizeDocConcentration(self, value): return self._set(optimizeDocConcentration=value) @since("2.0.0") - def setDocConcentration(self, value): + def setDocConcentration(self, value: List[float]) -> "LDA": """ Sets the value of :py:attr:`docConcentration`. @@ -1767,7 +1814,7 @@ def setDocConcentration(self, value): return self._set(docConcentration=value) @since("2.0.0") - def setTopicConcentration(self, value): + def setTopicConcentration(self, value: float) -> "LDA": """ Sets the value of :py:attr:`topicConcentration`. @@ -1780,7 +1827,7 @@ def setTopicConcentration(self, value): return self._set(topicConcentration=value) @since("2.0.0") - def setTopicDistributionCol(self, value): + def setTopicDistributionCol(self, value: str) -> "LDA": """ Sets the value of :py:attr:`topicDistributionCol`. @@ -1793,7 +1840,7 @@ def setTopicDistributionCol(self, value): return self._set(topicDistributionCol=value) @since("2.0.0") - def setKeepLastCheckpoint(self, value): + def setKeepLastCheckpoint(self, value: bool) -> "LDA": """ Sets the value of :py:attr:`keepLastCheckpoint`. @@ -1806,14 +1853,14 @@ def setKeepLastCheckpoint(self, value): return self._set(keepLastCheckpoint=value) @since("2.0.0") - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "LDA": """ Sets the value of :py:attr:`maxIter`. """ return self._set(maxIter=value) @since("2.0.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self, value: str) -> "LDA": """ Sets the value of :py:attr:`featuresCol`. """ @@ -1828,13 +1875,13 @@ class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol): .. versionadded:: 3.0.0 """ - k = Param( + k: Param[int] = Param( Params._dummy(), "k", "The number of clusters to create. Must be > 1.", typeConverter=TypeConverters.toInt, ) - initMode = Param( + initMode: Param[str] = Param( Params._dummy(), "initMode", "The initialization algorithm. This can be either " @@ -1843,46 +1890,46 @@ class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol): + "'random' and 'degree'.", typeConverter=TypeConverters.toString, ) - srcCol = Param( + srcCol: Param[str] = Param( Params._dummy(), "srcCol", "Name of the input column for source vertex IDs.", typeConverter=TypeConverters.toString, ) - dstCol = Param( + dstCol: Param[str] = Param( Params._dummy(), "dstCol", "Name of the input column for destination vertex IDs.", typeConverter=TypeConverters.toString, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_PowerIterationClusteringParams, self).__init__(*args) self._setDefault(k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst") @since("2.4.0") - def getK(self): + def getK(self) -> int: """ Gets the value of :py:attr:`k` or its default value. 
""" return self.getOrDefault(self.k) @since("2.4.0") - def getInitMode(self): + def getInitMode(self) -> str: """ Gets the value of :py:attr:`initMode` or its default value. """ return self.getOrDefault(self.initMode) @since("2.4.0") - def getSrcCol(self): + def getSrcCol(self) -> str: """ Gets the value of :py:attr:`srcCol` or its default value. """ return self.getOrDefault(self.srcCol) @since("2.4.0") - def getDstCol(self): + def getDstCol(self) -> str: """ Gets the value of :py:attr:`dstCol` or its default value. """ @@ -1891,7 +1938,10 @@ def getDstCol(self): @inherit_doc class PowerIterationClustering( - _PowerIterationClusteringParams, JavaParams, JavaMLReadable, JavaMLWritable + _PowerIterationClusteringParams, + JavaParams, + JavaMLReadable["PowerIterationClustering"], + JavaMLWritable, ): """ Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by @@ -1943,9 +1993,18 @@ class PowerIterationClustering( True """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( - self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", weightCol=None + self, + *, + k: int = 2, + maxIter: int = 20, + initMode: str = "random", + srcCol: str = "src", + dstCol: str = "dst", + weightCol: Optional[str] = None, ): """ __init__(self, \\*, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ @@ -1961,8 +2020,15 @@ def __init__( @keyword_only @since("2.4.0") def setParams( - self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", weightCol=None - ): + self, + *, + k: int = 2, + maxIter: int = 20, + initMode: str = "random", + srcCol: str = "src", + dstCol: str = "dst", + weightCol: Optional[str] = None, + ) -> "PowerIterationClustering": """ setParams(self, \\*, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ weightCol=None) @@ -1972,49 +2038,49 @@ def setParams( return self._set(**kwargs) @since("2.4.0") - def setK(self, value): + def setK(self, value: int) -> "PowerIterationClustering": """ Sets the value of :py:attr:`k`. """ return self._set(k=value) @since("2.4.0") - def setInitMode(self, value): + def setInitMode(self, value: str) -> "PowerIterationClustering": """ Sets the value of :py:attr:`initMode`. """ return self._set(initMode=value) @since("2.4.0") - def setSrcCol(self, value): + def setSrcCol(self, value: str) -> "PowerIterationClustering": """ Sets the value of :py:attr:`srcCol`. """ return self._set(srcCol=value) @since("2.4.0") - def setDstCol(self, value): + def setDstCol(self, value: str) -> "PowerIterationClustering": """ Sets the value of :py:attr:`dstCol`. """ return self._set(dstCol=value) @since("2.4.0") - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "PowerIterationClustering": """ Sets the value of :py:attr:`maxIter`. """ return self._set(maxIter=value) @since("2.4.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "PowerIterationClustering": """ Sets the value of :py:attr:`weightCol`. """ return self._set(weightCol=value) @since("2.4.0") - def assignClusters(self, dataset): + def assignClusters(self, dataset: DataFrame) -> DataFrame: """ Run the PIC algorithm and returns a cluster assignment for each input vertex. 
@@ -2038,8 +2104,10 @@ def assignClusters(self, dataset): - cluster: Int """ self._transfer_params_to_java() + assert self._java_obj is not None + jdf = self._java_obj.assignClusters(dataset._jdf) - return DataFrame(jdf, dataset.sql_ctx) + return DataFrame(jdf, dataset.sparkSession) if __name__ == "__main__": diff --git a/python/pyspark/ml/clustering.pyi b/python/pyspark/ml/clustering.pyi deleted file mode 100644 index 81074fc285273..0000000000000 --- a/python/pyspark/ml/clustering.pyi +++ /dev/null @@ -1,433 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Any, List, Optional - -from pyspark.ml.linalg import Matrix, Vector -from pyspark.ml.util import ( - GeneralJavaMLWritable, - HasTrainingSummary, - JavaMLReadable, - JavaMLWritable, -) -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper -from pyspark.ml.param.shared import ( - HasAggregationDepth, - HasCheckpointInterval, - HasDistanceMeasure, - HasFeaturesCol, - HasMaxIter, - HasPredictionCol, - HasProbabilityCol, - HasSeed, - HasTol, - HasWeightCol, -) - -from pyspark.ml.param import Param -from pyspark.ml.stat import MultivariateGaussian -from pyspark.sql.dataframe import DataFrame - -from numpy import ndarray - -class ClusteringSummary(JavaWrapper): - @property - def predictionCol(self) -> str: ... - @property - def predictions(self) -> DataFrame: ... - @property - def featuresCol(self) -> str: ... - @property - def k(self) -> int: ... - @property - def cluster(self) -> DataFrame: ... - @property - def clusterSizes(self) -> List[int]: ... - @property - def numIter(self) -> int: ... - -class _GaussianMixtureParams( - HasMaxIter, - HasFeaturesCol, - HasSeed, - HasPredictionCol, - HasProbabilityCol, - HasTol, - HasAggregationDepth, - HasWeightCol, -): - k: Param[int] - def __init__(self, *args: Any): ... - def getK(self) -> int: ... - -class GaussianMixtureModel( - JavaModel, - _GaussianMixtureParams, - JavaMLWritable, - JavaMLReadable[GaussianMixtureModel], - HasTrainingSummary[GaussianMixtureSummary], -): - def setFeaturesCol(self, value: str) -> GaussianMixtureModel: ... - def setPredictionCol(self, value: str) -> GaussianMixtureModel: ... - def setProbabilityCol(self, value: str) -> GaussianMixtureModel: ... - @property - def weights(self) -> List[float]: ... - @property - def gaussians(self) -> List[MultivariateGaussian]: ... - @property - def gaussiansDF(self) -> DataFrame: ... - @property - def summary(self) -> GaussianMixtureSummary: ... - def predict(self, value: Vector) -> int: ... - def predictProbability(self, value: Vector) -> Vector: ... 
- -class GaussianMixture( - JavaEstimator[GaussianMixtureModel], - _GaussianMixtureParams, - JavaMLWritable, - JavaMLReadable[GaussianMixture], -): - def __init__( - self, - *, - featuresCol: str = ..., - predictionCol: str = ..., - k: int = ..., - probabilityCol: str = ..., - tol: float = ..., - maxIter: int = ..., - seed: Optional[int] = ..., - aggregationDepth: int = ..., - weightCol: Optional[str] = ..., - ) -> None: ... - def setParams( - self, - *, - featuresCol: str = ..., - predictionCol: str = ..., - k: int = ..., - probabilityCol: str = ..., - tol: float = ..., - maxIter: int = ..., - seed: Optional[int] = ..., - aggregationDepth: int = ..., - weightCol: Optional[str] = ..., - ) -> GaussianMixture: ... - def setK(self, value: int) -> GaussianMixture: ... - def setMaxIter(self, value: int) -> GaussianMixture: ... - def setFeaturesCol(self, value: str) -> GaussianMixture: ... - def setPredictionCol(self, value: str) -> GaussianMixture: ... - def setProbabilityCol(self, value: str) -> GaussianMixture: ... - def setWeightCol(self, value: str) -> GaussianMixture: ... - def setSeed(self, value: int) -> GaussianMixture: ... - def setTol(self, value: float) -> GaussianMixture: ... - def setAggregationDepth(self, value: int) -> GaussianMixture: ... - -class GaussianMixtureSummary(ClusteringSummary): - @property - def probabilityCol(self) -> str: ... - @property - def probability(self) -> DataFrame: ... - @property - def logLikelihood(self) -> float: ... - -class KMeansSummary(ClusteringSummary): - def trainingCost(self) -> float: ... - -class _KMeansParams( - HasMaxIter, - HasFeaturesCol, - HasSeed, - HasPredictionCol, - HasTol, - HasDistanceMeasure, - HasWeightCol, -): - k: Param[int] - initMode: Param[str] - initSteps: Param[int] - def __init__(self, *args: Any): ... - def getK(self) -> int: ... - def getInitMode(self) -> str: ... - def getInitSteps(self) -> int: ... - -class KMeansModel( - JavaModel, - _KMeansParams, - GeneralJavaMLWritable, - JavaMLReadable[KMeansModel], - HasTrainingSummary[KMeansSummary], -): - def setFeaturesCol(self, value: str) -> KMeansModel: ... - def setPredictionCol(self, value: str) -> KMeansModel: ... - def clusterCenters(self) -> List[ndarray]: ... - @property - def summary(self) -> KMeansSummary: ... - def predict(self, value: Vector) -> int: ... - -class KMeans(JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable, JavaMLReadable[KMeans]): - def __init__( - self, - *, - featuresCol: str = ..., - predictionCol: str = ..., - k: int = ..., - initMode: str = ..., - initSteps: int = ..., - tol: float = ..., - maxIter: int = ..., - seed: Optional[int] = ..., - distanceMeasure: str = ..., - weightCol: Optional[str] = ..., - ) -> None: ... - def setParams( - self, - *, - featuresCol: str = ..., - predictionCol: str = ..., - k: int = ..., - initMode: str = ..., - initSteps: int = ..., - tol: float = ..., - maxIter: int = ..., - seed: Optional[int] = ..., - distanceMeasure: str = ..., - weightCol: Optional[str] = ..., - ) -> KMeans: ... - def setK(self, value: int) -> KMeans: ... - def setInitMode(self, value: str) -> KMeans: ... - def setInitSteps(self, value: int) -> KMeans: ... - def setDistanceMeasure(self, value: str) -> KMeans: ... - def setMaxIter(self, value: int) -> KMeans: ... - def setFeaturesCol(self, value: str) -> KMeans: ... - def setPredictionCol(self, value: str) -> KMeans: ... - def setSeed(self, value: int) -> KMeans: ... - def setTol(self, value: float) -> KMeans: ... - def setWeightCol(self, value: str) -> KMeans: ... 
- -class _BisectingKMeansParams( - HasMaxIter, - HasFeaturesCol, - HasSeed, - HasPredictionCol, - HasDistanceMeasure, - HasWeightCol, -): - k: Param[int] - minDivisibleClusterSize: Param[float] - def __init__(self, *args: Any): ... - def getK(self) -> int: ... - def getMinDivisibleClusterSize(self) -> float: ... - -class BisectingKMeansModel( - JavaModel, - _BisectingKMeansParams, - JavaMLWritable, - JavaMLReadable[BisectingKMeansModel], - HasTrainingSummary[BisectingKMeansSummary], -): - def setFeaturesCol(self, value: str) -> BisectingKMeansModel: ... - def setPredictionCol(self, value: str) -> BisectingKMeansModel: ... - def clusterCenters(self) -> List[ndarray]: ... - def computeCost(self, dataset: DataFrame) -> float: ... - @property - def summary(self) -> BisectingKMeansSummary: ... - def predict(self, value: Vector) -> int: ... - -class BisectingKMeans( - JavaEstimator[BisectingKMeansModel], - _BisectingKMeansParams, - JavaMLWritable, - JavaMLReadable[BisectingKMeans], -): - def __init__( - self, - *, - featuresCol: str = ..., - predictionCol: str = ..., - maxIter: int = ..., - seed: Optional[int] = ..., - k: int = ..., - minDivisibleClusterSize: float = ..., - distanceMeasure: str = ..., - weightCol: Optional[str] = ..., - ) -> None: ... - def setParams( - self, - *, - featuresCol: str = ..., - predictionCol: str = ..., - maxIter: int = ..., - seed: Optional[int] = ..., - k: int = ..., - minDivisibleClusterSize: float = ..., - distanceMeasure: str = ..., - weightCol: Optional[str] = ..., - ) -> BisectingKMeans: ... - def setK(self, value: int) -> BisectingKMeans: ... - def setMinDivisibleClusterSize(self, value: float) -> BisectingKMeans: ... - def setDistanceMeasure(self, value: str) -> BisectingKMeans: ... - def setMaxIter(self, value: int) -> BisectingKMeans: ... - def setFeaturesCol(self, value: str) -> BisectingKMeans: ... - def setPredictionCol(self, value: str) -> BisectingKMeans: ... - def setSeed(self, value: int) -> BisectingKMeans: ... - def setWeightCol(self, value: str) -> BisectingKMeans: ... - -class BisectingKMeansSummary(ClusteringSummary): - @property - def trainingCost(self) -> float: ... - -class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval): - k: Param[int] - optimizer: Param[str] - learningOffset: Param[float] - learningDecay: Param[float] - subsamplingRate: Param[float] - optimizeDocConcentration: Param[bool] - docConcentration: Param[List[float]] - topicConcentration: Param[float] - topicDistributionCol: Param[str] - keepLastCheckpoint: Param[bool] - def __init__(self, *args: Any): ... - def setK(self, value: int) -> LDA: ... - def getOptimizer(self) -> str: ... - def getLearningOffset(self) -> float: ... - def getLearningDecay(self) -> float: ... - def getSubsamplingRate(self) -> float: ... - def getOptimizeDocConcentration(self) -> bool: ... - def getDocConcentration(self) -> List[float]: ... - def getTopicConcentration(self) -> float: ... - def getTopicDistributionCol(self) -> str: ... - def getKeepLastCheckpoint(self) -> bool: ... - -class LDAModel(JavaModel, _LDAParams): - def setFeaturesCol(self, value: str) -> LDAModel: ... - def setSeed(self, value: int) -> LDAModel: ... - def setTopicDistributionCol(self, value: str) -> LDAModel: ... - def isDistributed(self) -> bool: ... - def vocabSize(self) -> int: ... - def topicsMatrix(self) -> Matrix: ... - def logLikelihood(self, dataset: DataFrame) -> float: ... - def logPerplexity(self, dataset: DataFrame) -> float: ... - def describeTopics(self, maxTermsPerTopic: int = ...) 
-> DataFrame: ... - def estimatedDocConcentration(self) -> Vector: ... - -class DistributedLDAModel(LDAModel, JavaMLReadable[DistributedLDAModel], JavaMLWritable): - def toLocal(self) -> LDAModel: ... - def trainingLogLikelihood(self) -> float: ... - def logPrior(self) -> float: ... - def getCheckpointFiles(self) -> List[str]: ... - -class LocalLDAModel(LDAModel, JavaMLReadable[LocalLDAModel], JavaMLWritable): ... - -class LDA(JavaEstimator[LDAModel], _LDAParams, JavaMLReadable[LDA], JavaMLWritable): - def __init__( - self, - *, - featuresCol: str = ..., - maxIter: int = ..., - seed: Optional[int] = ..., - checkpointInterval: int = ..., - k: int = ..., - optimizer: str = ..., - learningOffset: float = ..., - learningDecay: float = ..., - subsamplingRate: float = ..., - optimizeDocConcentration: bool = ..., - docConcentration: Optional[List[float]] = ..., - topicConcentration: Optional[float] = ..., - topicDistributionCol: str = ..., - keepLastCheckpoint: bool = ..., - ) -> None: ... - def setParams( - self, - *, - featuresCol: str = ..., - maxIter: int = ..., - seed: Optional[int] = ..., - checkpointInterval: int = ..., - k: int = ..., - optimizer: str = ..., - learningOffset: float = ..., - learningDecay: float = ..., - subsamplingRate: float = ..., - optimizeDocConcentration: bool = ..., - docConcentration: Optional[List[float]] = ..., - topicConcentration: Optional[float] = ..., - topicDistributionCol: str = ..., - keepLastCheckpoint: bool = ..., - ) -> LDA: ... - def setCheckpointInterval(self, value: int) -> LDA: ... - def setSeed(self, value: int) -> LDA: ... - def setK(self, value: int) -> LDA: ... - def setOptimizer(self, value: str) -> LDA: ... - def setLearningOffset(self, value: float) -> LDA: ... - def setLearningDecay(self, value: float) -> LDA: ... - def setSubsamplingRate(self, value: float) -> LDA: ... - def setOptimizeDocConcentration(self, value: bool) -> LDA: ... - def setDocConcentration(self, value: List[float]) -> LDA: ... - def setTopicConcentration(self, value: float) -> LDA: ... - def setTopicDistributionCol(self, value: str) -> LDA: ... - def setKeepLastCheckpoint(self, value: bool) -> LDA: ... - def setMaxIter(self, value: int) -> LDA: ... - def setFeaturesCol(self, value: str) -> LDA: ... - -class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol): - k: Param[int] - initMode: Param[str] - srcCol: Param[str] - dstCol: Param[str] - def __init__(self, *args: Any): ... - def getK(self) -> int: ... - def getInitMode(self) -> str: ... - def getSrcCol(self) -> str: ... - def getDstCol(self) -> str: ... - -class PowerIterationClustering( - _PowerIterationClusteringParams, - JavaParams, - JavaMLReadable[PowerIterationClustering], - JavaMLWritable, -): - def __init__( - self, - *, - k: int = ..., - maxIter: int = ..., - initMode: str = ..., - srcCol: str = ..., - dstCol: str = ..., - weightCol: Optional[str] = ..., - ) -> None: ... - def setParams( - self, - *, - k: int = ..., - maxIter: int = ..., - initMode: str = ..., - srcCol: str = ..., - dstCol: str = ..., - weightCol: Optional[str] = ..., - ) -> PowerIterationClustering: ... - def setK(self, value: int) -> PowerIterationClustering: ... - def setInitMode(self, value: str) -> PowerIterationClustering: ... - def setSrcCol(self, value: str) -> str: ... - def setDstCol(self, value: str) -> PowerIterationClustering: ... - def setMaxIter(self, value: int) -> PowerIterationClustering: ... - def setWeightCol(self, value: str) -> PowerIterationClustering: ... 
- def assignClusters(self, dataset: DataFrame) -> DataFrame: ... diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py index 2329421b9ee09..dd6fee467e699 100644 --- a/python/pyspark/ml/common.py +++ b/python/pyspark/ml/common.py @@ -65,11 +65,9 @@ def _to_java_object_rdd(rdd: RDD) -> JavaObject: It will convert each Python object into Java object by Pickle, whenever the RDD is serialized in batch or not. """ - rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer())) # type: ignore[attr-defined] + rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer())) assert rdd.ctx._jvm is not None - return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava( - rdd._jrdd, True # type: ignore[attr-defined] - ) + return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True) def _py2java(sc: SparkContext, obj: Any) -> JavaObject: @@ -79,7 +77,7 @@ def _py2java(sc: SparkContext, obj: Any) -> JavaObject: elif isinstance(obj, DataFrame): obj = obj._jdf elif isinstance(obj, SparkContext): - obj = obj._jsc # type: ignore[attr-defined] + obj = obj._jsc elif isinstance(obj, list): obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): @@ -108,7 +106,7 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt return RDD(jrdd, sc) if clsName == "Dataset": - return DataFrame(r, SparkSession(sc)._wrapped) + return DataFrame(r, SparkSession._getActiveSessionOrCreate()) if clsName in _picklable_classes: r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index be63a8f8ce972..ff0e5b91e424d 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -18,6 +18,8 @@ import sys from abc import abstractmethod, ABCMeta +from typing import Any, Dict, Optional, TYPE_CHECKING + from pyspark import since, keyword_only from pyspark.ml.wrapper import JavaParams from pyspark.ml.param import Param, Params, TypeConverters @@ -31,6 +33,20 @@ ) from pyspark.ml.common import inherit_doc from pyspark.ml.util import JavaMLReadable, JavaMLWritable +from pyspark.sql.dataframe import DataFrame + +if TYPE_CHECKING: + from pyspark.ml._typing import ( + ParamMap, + BinaryClassificationEvaluatorMetricType, + ClusteringEvaluatorDistanceMeasureType, + ClusteringEvaluatorMetricType, + MulticlassClassificationEvaluatorMetricType, + MultilabelClassificationEvaluatorMetricType, + RankingEvaluatorMetricType, + RegressionEvaluatorMetricType, + ) + __all__ = [ "Evaluator", @@ -54,7 +70,7 @@ class Evaluator(Params, metaclass=ABCMeta): pass @abstractmethod - def _evaluate(self, dataset): + def _evaluate(self, dataset: DataFrame) -> float: """ Evaluates the output. @@ -70,7 +86,7 @@ def _evaluate(self, dataset): """ raise NotImplementedError() - def evaluate(self, dataset, params=None): + def evaluate(self, dataset: DataFrame, params: Optional["ParamMap"] = None) -> float: """ Evaluates the output with optional parameters. @@ -99,7 +115,7 @@ def evaluate(self, dataset, params=None): raise TypeError("Params must be a param map but got %s." % type(params)) @since("1.5.0") - def isLargerBetter(self): + def isLargerBetter(self) -> bool: """ Indicates whether the metric returned by :py:meth:`evaluate` should be maximized (True, default) or minimized (False). @@ -115,7 +131,7 @@ class JavaEvaluator(JavaParams, Evaluator, metaclass=ABCMeta): implementations. 
""" - def _evaluate(self, dataset): + def _evaluate(self, dataset: DataFrame) -> float: """ Evaluates the output. @@ -130,16 +146,23 @@ def _evaluate(self, dataset): evaluation metric """ self._transfer_params_to_java() + assert self._java_obj is not None return self._java_obj.evaluate(dataset._jdf) - def isLargerBetter(self): + def isLargerBetter(self) -> bool: self._transfer_params_to_java() + assert self._java_obj is not None return self._java_obj.isLargerBetter() @inherit_doc class BinaryClassificationEvaluator( - JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol, JavaMLReadable, JavaMLWritable + JavaEvaluator, + HasLabelCol, + HasRawPredictionCol, + HasWeightCol, + JavaMLReadable["BinaryClassificationEvaluator"], + JavaMLWritable, ): """ Evaluator for binary classification, which expects input columns rawPrediction, label @@ -182,14 +205,14 @@ class BinaryClassificationEvaluator( 1000 """ - metricName = Param( + metricName: Param["BinaryClassificationEvaluatorMetricType"] = Param( Params._dummy(), "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", - typeConverter=TypeConverters.toString, + typeConverter=TypeConverters.toString, # type: ignore[arg-type] ) - numBins = Param( + numBins: Param[int] = Param( Params._dummy(), "numBins", "Number of bins to down-sample the curves " @@ -198,15 +221,17 @@ class BinaryClassificationEvaluator( typeConverter=TypeConverters.toInt, ) + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - rawPredictionCol="rawPrediction", - labelCol="label", - metricName="areaUnderROC", - weightCol=None, - numBins=1000, + rawPredictionCol: str = "rawPrediction", + labelCol: str = "label", + metricName: "BinaryClassificationEvaluatorMetricType" = "areaUnderROC", + weightCol: Optional[str] = None, + numBins: int = 1000, ): """ __init__(self, \\*, rawPredictionCol="rawPrediction", labelCol="label", \ @@ -221,47 +246,49 @@ def __init__( self._set(**kwargs) @since("1.4.0") - def setMetricName(self, value): + def setMetricName( + self, value: "BinaryClassificationEvaluatorMetricType" + ) -> "BinaryClassificationEvaluator": """ Sets the value of :py:attr:`metricName`. """ return self._set(metricName=value) @since("1.4.0") - def getMetricName(self): + def getMetricName(self) -> str: """ Gets the value of metricName or its default value. """ return self.getOrDefault(self.metricName) @since("3.0.0") - def setNumBins(self, value): + def setNumBins(self, value: int) -> "BinaryClassificationEvaluator": """ Sets the value of :py:attr:`numBins`. """ return self._set(numBins=value) @since("3.0.0") - def getNumBins(self): + def getNumBins(self) -> int: """ Gets the value of numBins or its default value. """ return self.getOrDefault(self.numBins) - def setLabelCol(self, value): + def setLabelCol(self, value: str) -> "BinaryClassificationEvaluator": """ Sets the value of :py:attr:`labelCol`. """ return self._set(labelCol=value) - def setRawPredictionCol(self, value): + def setRawPredictionCol(self, value: str) -> "BinaryClassificationEvaluator": """ Sets the value of :py:attr:`rawPredictionCol`. """ return self._set(rawPredictionCol=value) @since("3.0.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "BinaryClassificationEvaluator": """ Sets the value of :py:attr:`weightCol`. 
""" @@ -272,12 +299,12 @@ def setWeightCol(self, value): def setParams( self, *, - rawPredictionCol="rawPrediction", - labelCol="label", - metricName="areaUnderROC", - weightCol=None, - numBins=1000, - ): + rawPredictionCol: str = "rawPrediction", + labelCol: str = "label", + metricName: "BinaryClassificationEvaluatorMetricType" = "areaUnderROC", + weightCol: Optional[str] = None, + numBins: int = 1000, + ) -> "BinaryClassificationEvaluator": """ setParams(self, \\*, rawPredictionCol="rawPrediction", labelCol="label", \ metricName="areaUnderROC", weightCol=None, numBins=1000) @@ -289,7 +316,12 @@ def setParams( @inherit_doc class RegressionEvaluator( - JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol, JavaMLReadable, JavaMLWritable + JavaEvaluator, + HasLabelCol, + HasPredictionCol, + HasWeightCol, + JavaMLReadable["RegressionEvaluator"], + JavaMLWritable, ): """ Evaluator for Regression, which expects input columns prediction, label @@ -328,7 +360,7 @@ class RegressionEvaluator( False """ - metricName = Param( + metricName: Param["RegressionEvaluatorMetricType"] = Param( Params._dummy(), "metricName", """metric name in evaluation - one of: @@ -337,25 +369,27 @@ class RegressionEvaluator( r2 - r^2 metric mae - mean absolute error var - explained variance.""", - typeConverter=TypeConverters.toString, + typeConverter=TypeConverters.toString, # type: ignore[arg-type] ) - throughOrigin = Param( + throughOrigin: Param[bool] = Param( Params._dummy(), "throughOrigin", "whether the regression is through the origin.", typeConverter=TypeConverters.toBoolean, ) + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - predictionCol="prediction", - labelCol="label", - metricName="rmse", - weightCol=None, - throughOrigin=False, + predictionCol: str = "prediction", + labelCol: str = "label", + metricName: "RegressionEvaluatorMetricType" = "rmse", + weightCol: Optional[str] = None, + throughOrigin: bool = False, ): """ __init__(self, \\*, predictionCol="prediction", labelCol="label", \ @@ -370,47 +404,47 @@ def __init__( self._set(**kwargs) @since("1.4.0") - def setMetricName(self, value): + def setMetricName(self, value: "RegressionEvaluatorMetricType") -> "RegressionEvaluator": """ Sets the value of :py:attr:`metricName`. """ return self._set(metricName=value) @since("1.4.0") - def getMetricName(self): + def getMetricName(self) -> "RegressionEvaluatorMetricType": """ Gets the value of metricName or its default value. """ return self.getOrDefault(self.metricName) @since("3.0.0") - def setThroughOrigin(self, value): + def setThroughOrigin(self, value: bool) -> "RegressionEvaluator": """ Sets the value of :py:attr:`throughOrigin`. """ return self._set(throughOrigin=value) @since("3.0.0") - def getThroughOrigin(self): + def getThroughOrigin(self) -> bool: """ Gets the value of throughOrigin or its default value. """ return self.getOrDefault(self.throughOrigin) - def setLabelCol(self, value): + def setLabelCol(self, value: str) -> "RegressionEvaluator": """ Sets the value of :py:attr:`labelCol`. """ return self._set(labelCol=value) - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "RegressionEvaluator": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) @since("3.0.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "RegressionEvaluator": """ Sets the value of :py:attr:`weightCol`. 
""" @@ -421,12 +455,12 @@ def setWeightCol(self, value): def setParams( self, *, - predictionCol="prediction", - labelCol="label", - metricName="rmse", - weightCol=None, - throughOrigin=False, - ): + predictionCol: str = "prediction", + labelCol: str = "label", + metricName: "RegressionEvaluatorMetricType" = "rmse", + weightCol: Optional[str] = None, + throughOrigin: bool = False, + ) -> "RegressionEvaluator": """ setParams(self, \\*, predictionCol="prediction", labelCol="label", \ metricName="rmse", weightCol=None, throughOrigin=False) @@ -443,7 +477,7 @@ class MulticlassClassificationEvaluator( HasPredictionCol, HasWeightCol, HasProbabilityCol, - JavaMLReadable, + JavaMLReadable["MulticlassClassificationEvaluator"], JavaMLWritable, ): """ @@ -499,7 +533,7 @@ class MulticlassClassificationEvaluator( 0.9682... """ - metricName = Param( + metricName: Param["MulticlassClassificationEvaluatorMetricType"] = Param( Params._dummy(), "metricName", "metric name in evaluation " @@ -507,9 +541,9 @@ class MulticlassClassificationEvaluator( "weightedFalsePositiveRate|weightedFMeasure|truePositiveRateByLabel| " "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel| " "logLoss|hammingLoss)", - typeConverter=TypeConverters.toString, + typeConverter=TypeConverters.toString, # type: ignore[arg-type] ) - metricLabel = Param( + metricLabel: Param[float] = Param( Params._dummy(), "metricLabel", "The class whose metric will be computed in truePositiveRateByLabel|" @@ -517,14 +551,14 @@ class MulticlassClassificationEvaluator( " Must be >= 0. The default value is 0.", typeConverter=TypeConverters.toFloat, ) - beta = Param( + beta: Param[float] = Param( Params._dummy(), "beta", "The beta value used in weightedFMeasure|fMeasureByLabel." " Must be > 0. The default value is 1.", typeConverter=TypeConverters.toFloat, ) - eps = Param( + eps: Param[float] = Param( Params._dummy(), "eps", "log-loss is undefined for p=0 or p=1, so probabilities are clipped to " @@ -533,18 +567,20 @@ class MulticlassClassificationEvaluator( typeConverter=TypeConverters.toFloat, ) + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - predictionCol="prediction", - labelCol="label", - metricName="f1", - weightCol=None, - metricLabel=0.0, - beta=1.0, - probabilityCol="probability", - eps=1e-15, + predictionCol: str = "prediction", + labelCol: str = "label", + metricName: "MulticlassClassificationEvaluatorMetricType" = "f1", + weightCol: Optional[str] = None, + metricLabel: float = 0.0, + beta: float = 1.0, + probabilityCol: str = "probability", + eps: float = 1e-15, ): """ __init__(self, \\*, predictionCol="prediction", labelCol="label", \ @@ -560,82 +596,84 @@ def __init__( self._set(**kwargs) @since("1.5.0") - def setMetricName(self, value): + def setMetricName( + self, value: "MulticlassClassificationEvaluatorMetricType" + ) -> "MulticlassClassificationEvaluator": """ Sets the value of :py:attr:`metricName`. """ return self._set(metricName=value) @since("1.5.0") - def getMetricName(self): + def getMetricName(self) -> "MulticlassClassificationEvaluatorMetricType": """ Gets the value of metricName or its default value. """ return self.getOrDefault(self.metricName) @since("3.0.0") - def setMetricLabel(self, value): + def setMetricLabel(self, value: float) -> "MulticlassClassificationEvaluator": """ Sets the value of :py:attr:`metricLabel`. 
""" return self._set(metricLabel=value) @since("3.0.0") - def getMetricLabel(self): + def getMetricLabel(self) -> float: """ Gets the value of metricLabel or its default value. """ return self.getOrDefault(self.metricLabel) @since("3.0.0") - def setBeta(self, value): + def setBeta(self, value: float) -> "MulticlassClassificationEvaluator": """ Sets the value of :py:attr:`beta`. """ return self._set(beta=value) @since("3.0.0") - def getBeta(self): + def getBeta(self) -> float: """ Gets the value of beta or its default value. """ return self.getOrDefault(self.beta) @since("3.0.0") - def setEps(self, value): + def setEps(self, value: float) -> "MulticlassClassificationEvaluator": """ Sets the value of :py:attr:`eps`. """ return self._set(eps=value) @since("3.0.0") - def getEps(self): + def getEps(self) -> float: """ Gets the value of eps or its default value. """ return self.getOrDefault(self.eps) - def setLabelCol(self, value): + def setLabelCol(self, value: str) -> "MulticlassClassificationEvaluator": """ Sets the value of :py:attr:`labelCol`. """ return self._set(labelCol=value) - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "MulticlassClassificationEvaluator": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) @since("3.0.0") - def setProbabilityCol(self, value): + def setProbabilityCol(self, value: str) -> "MulticlassClassificationEvaluator": """ Sets the value of :py:attr:`probabilityCol`. """ return self._set(probabilityCol=value) @since("3.0.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "MulticlassClassificationEvaluator": """ Sets the value of :py:attr:`weightCol`. """ @@ -646,15 +684,15 @@ def setWeightCol(self, value): def setParams( self, *, - predictionCol="prediction", - labelCol="label", - metricName="f1", - weightCol=None, - metricLabel=0.0, - beta=1.0, - probabilityCol="probability", - eps=1e-15, - ): + predictionCol: str = "prediction", + labelCol: str = "label", + metricName: "MulticlassClassificationEvaluatorMetricType" = "f1", + weightCol: Optional[str] = None, + metricLabel: float = 0.0, + beta: float = 1.0, + probabilityCol: str = "probability", + eps: float = 1e-15, + ) -> "MulticlassClassificationEvaluator": """ setParams(self, \\*, predictionCol="prediction", labelCol="label", \ metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0, \ @@ -667,7 +705,11 @@ def setParams( @inherit_doc class MultilabelClassificationEvaluator( - JavaEvaluator, HasLabelCol, HasPredictionCol, JavaMLReadable, JavaMLWritable + JavaEvaluator, + HasLabelCol, + HasPredictionCol, + JavaMLReadable["MultilabelClassificationEvaluator"], + JavaMLWritable, ): """ Evaluator for Multilabel Classification, which expects two input @@ -700,16 +742,16 @@ class MultilabelClassificationEvaluator( 'prediction' """ - metricName = Param( + metricName: Param["MultilabelClassificationEvaluatorMetricType"] = Param( Params._dummy(), "metricName", "metric name in evaluation " "(subsetAccuracy|accuracy|hammingLoss|precision|recall|f1Measure|" "precisionByLabel|recallByLabel|f1MeasureByLabel|microPrecision|" "microRecall|microF1Measure)", - typeConverter=TypeConverters.toString, + typeConverter=TypeConverters.toString, # type: ignore[arg-type] ) - metricLabel = Param( + metricLabel: Param[float] = Param( Params._dummy(), "metricLabel", "The class whose metric will be computed in precisionByLabel|" @@ -718,15 +760,17 @@ class MultilabelClassificationEvaluator( typeConverter=TypeConverters.toFloat, ) + 
_input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - predictionCol="prediction", - labelCol="label", - metricName="f1Measure", - metricLabel=0.0, - ): + predictionCol: str = "prediction", + labelCol: str = "label", + metricName: "MultilabelClassificationEvaluatorMetricType" = "f1Measure", + metricLabel: float = 0.0, + ) -> None: """ __init__(self, \\*, predictionCol="prediction", labelCol="label", \ metricName="f1Measure", metricLabel=0.0) @@ -740,42 +784,44 @@ def __init__( self._set(**kwargs) @since("3.0.0") - def setMetricName(self, value): + def setMetricName( + self, value: "MultilabelClassificationEvaluatorMetricType" + ) -> "MultilabelClassificationEvaluator": """ Sets the value of :py:attr:`metricName`. """ return self._set(metricName=value) @since("3.0.0") - def getMetricName(self): + def getMetricName(self) -> "MultilabelClassificationEvaluatorMetricType": """ Gets the value of metricName or its default value. """ return self.getOrDefault(self.metricName) @since("3.0.0") - def setMetricLabel(self, value): + def setMetricLabel(self, value: float) -> "MultilabelClassificationEvaluator": """ Sets the value of :py:attr:`metricLabel`. """ return self._set(metricLabel=value) @since("3.0.0") - def getMetricLabel(self): + def getMetricLabel(self) -> float: """ Gets the value of metricLabel or its default value. """ return self.getOrDefault(self.metricLabel) @since("3.0.0") - def setLabelCol(self, value): + def setLabelCol(self, value: str) -> "MultilabelClassificationEvaluator": """ Sets the value of :py:attr:`labelCol`. """ return self._set(labelCol=value) @since("3.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "MultilabelClassificationEvaluator": """ Sets the value of :py:attr:`predictionCol`. """ @@ -786,11 +832,11 @@ def setPredictionCol(self, value): def setParams( self, *, - predictionCol="prediction", - labelCol="label", - metricName="f1Measure", - metricLabel=0.0, - ): + predictionCol: str = "prediction", + labelCol: str = "label", + metricName: "MultilabelClassificationEvaluatorMetricType" = "f1Measure", + metricLabel: float = 0.0, + ) -> "MultilabelClassificationEvaluator": """ setParams(self, \\*, predictionCol="prediction", labelCol="label", \ metricName="f1Measure", metricLabel=0.0) @@ -802,7 +848,12 @@ def setParams( @inherit_doc class ClusteringEvaluator( - JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWeightCol, JavaMLReadable, JavaMLWritable + JavaEvaluator, + HasPredictionCol, + HasFeaturesCol, + HasWeightCol, + JavaMLReadable["ClusteringEvaluator"], + JavaMLWritable, ): """ Evaluator for Clustering results, which expects two input @@ -848,28 +899,30 @@ class ClusteringEvaluator( 'prediction' """ - metricName = Param( + metricName: Param["ClusteringEvaluatorMetricType"] = Param( Params._dummy(), "metricName", "metric name in evaluation (silhouette)", - typeConverter=TypeConverters.toString, + typeConverter=TypeConverters.toString, # type: ignore[arg-type] ) - distanceMeasure = Param( + distanceMeasure: Param["ClusteringEvaluatorDistanceMeasureType"] = Param( Params._dummy(), "distanceMeasure", "The distance measure. 
" + "Supported options: 'squaredEuclidean' and 'cosine'.", - typeConverter=TypeConverters.toString, + typeConverter=TypeConverters.toString, # type: ignore[arg-type] ) + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - predictionCol="prediction", - featuresCol="features", - metricName="silhouette", - distanceMeasure="squaredEuclidean", - weightCol=None, + predictionCol: str = "prediction", + featuresCol: str = "features", + metricName: "ClusteringEvaluatorMetricType" = "silhouette", + distanceMeasure: str = "squaredEuclidean", + weightCol: Optional[str] = None, ): """ __init__(self, \\*, predictionCol="prediction", featuresCol="features", \ @@ -888,12 +941,12 @@ def __init__( def setParams( self, *, - predictionCol="prediction", - featuresCol="features", - metricName="silhouette", - distanceMeasure="squaredEuclidean", - weightCol=None, - ): + predictionCol: str = "prediction", + featuresCol: str = "features", + metricName: "ClusteringEvaluatorMetricType" = "silhouette", + distanceMeasure: str = "squaredEuclidean", + weightCol: Optional[str] = None, + ) -> "ClusteringEvaluator": """ setParams(self, \\*, predictionCol="prediction", featuresCol="features", \ metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None) @@ -903,47 +956,49 @@ def setParams( return self._set(**kwargs) @since("2.3.0") - def setMetricName(self, value): + def setMetricName(self, value: "ClusteringEvaluatorMetricType") -> "ClusteringEvaluator": """ Sets the value of :py:attr:`metricName`. """ return self._set(metricName=value) @since("2.3.0") - def getMetricName(self): + def getMetricName(self) -> "ClusteringEvaluatorMetricType": """ Gets the value of metricName or its default value. """ return self.getOrDefault(self.metricName) @since("2.4.0") - def setDistanceMeasure(self, value): + def setDistanceMeasure( + self, value: "ClusteringEvaluatorDistanceMeasureType" + ) -> "ClusteringEvaluator": """ Sets the value of :py:attr:`distanceMeasure`. """ return self._set(distanceMeasure=value) @since("2.4.0") - def getDistanceMeasure(self): + def getDistanceMeasure(self) -> "ClusteringEvaluatorDistanceMeasureType": """ Gets the value of `distanceMeasure` """ return self.getOrDefault(self.distanceMeasure) - def setFeaturesCol(self, value): + def setFeaturesCol(self, value: "str") -> "ClusteringEvaluator": """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "ClusteringEvaluator": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) @since("3.1.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "ClusteringEvaluator": """ Sets the value of :py:attr:`weightCol`. 
""" @@ -952,7 +1007,7 @@ def setWeightCol(self, value): @inherit_doc class RankingEvaluator( - JavaEvaluator, HasLabelCol, HasPredictionCol, JavaMLReadable, JavaMLWritable + JavaEvaluator, HasLabelCol, HasPredictionCol, JavaMLReadable["RankingEvaluator"], JavaMLWritable ): """ Evaluator for Ranking, which expects two input @@ -986,15 +1041,15 @@ class RankingEvaluator( 'prediction' """ - metricName = Param( + metricName: Param["RankingEvaluatorMetricType"] = Param( Params._dummy(), "metricName", "metric name in evaluation " "(meanAveragePrecision|meanAveragePrecisionAtK|" "precisionAtK|ndcgAtK|recallAtK)", - typeConverter=TypeConverters.toString, + typeConverter=TypeConverters.toString, # type: ignore[arg-type] ) - k = Param( + k: Param[int] = Param( Params._dummy(), "k", "The ranking position value used in meanAveragePrecisionAtK|precisionAtK|" @@ -1002,14 +1057,16 @@ class RankingEvaluator( typeConverter=TypeConverters.toInt, ) + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - predictionCol="prediction", - labelCol="label", - metricName="meanAveragePrecision", - k=10, + predictionCol: str = "prediction", + labelCol: str = "label", + metricName: "RankingEvaluatorMetricType" = "meanAveragePrecision", + k: int = 10, ): """ __init__(self, \\*, predictionCol="prediction", labelCol="label", \ @@ -1024,42 +1081,42 @@ def __init__( self._set(**kwargs) @since("3.0.0") - def setMetricName(self, value): + def setMetricName(self, value: "RankingEvaluatorMetricType") -> "RankingEvaluator": """ Sets the value of :py:attr:`metricName`. """ return self._set(metricName=value) @since("3.0.0") - def getMetricName(self): + def getMetricName(self) -> "RankingEvaluatorMetricType": """ Gets the value of metricName or its default value. """ return self.getOrDefault(self.metricName) @since("3.0.0") - def setK(self, value): + def setK(self, value: int) -> "RankingEvaluator": """ Sets the value of :py:attr:`k`. """ return self._set(k=value) @since("3.0.0") - def getK(self): + def getK(self) -> int: """ Gets the value of k or its default value. """ return self.getOrDefault(self.k) @since("3.0.0") - def setLabelCol(self, value): + def setLabelCol(self, value: str) -> "RankingEvaluator": """ Sets the value of :py:attr:`labelCol`. """ return self._set(labelCol=value) @since("3.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "RankingEvaluator": """ Sets the value of :py:attr:`predictionCol`. """ @@ -1070,11 +1127,11 @@ def setPredictionCol(self, value): def setParams( self, *, - predictionCol="prediction", - labelCol="label", - metricName="meanAveragePrecision", - k=10, - ): + predictionCol: str = "prediction", + labelCol: str = "label", + metricName: "RankingEvaluatorMetricType" = "meanAveragePrecision", + k: int = 10, + ) -> "RankingEvaluator": """ setParams(self, \\*, predictionCol="prediction", labelCol="label", \ metricName="meanAveragePrecision", k=10) diff --git a/python/pyspark/ml/evaluation.pyi b/python/pyspark/ml/evaluation.pyi deleted file mode 100644 index d7883f4e1b1aa..0000000000000 --- a/python/pyspark/ml/evaluation.pyi +++ /dev/null @@ -1,277 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import abc -from typing import Optional -from pyspark.ml._typing import ( - ParamMap, - BinaryClassificationEvaluatorMetricType, - ClusteringEvaluatorMetricType, - MulticlassClassificationEvaluatorMetricType, - MultilabelClassificationEvaluatorMetricType, - RankingEvaluatorMetricType, - RegressionEvaluatorMetricType, -) - -from pyspark.ml.wrapper import JavaParams -from pyspark.ml.param import Param, Params -from pyspark.ml.param.shared import ( - HasFeaturesCol, - HasLabelCol, - HasPredictionCol, - HasProbabilityCol, - HasRawPredictionCol, - HasWeightCol, -) -from pyspark.ml.util import JavaMLReadable, JavaMLWritable -from pyspark.sql.dataframe import DataFrame - -class Evaluator(Params, metaclass=abc.ABCMeta): - def evaluate(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> float: ... - def isLargerBetter(self) -> bool: ... - -class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta): - def isLargerBetter(self) -> bool: ... - -class BinaryClassificationEvaluator( - JavaEvaluator, - HasLabelCol, - HasRawPredictionCol, - HasWeightCol, - JavaMLReadable[BinaryClassificationEvaluator], - JavaMLWritable, -): - metricName: Param[BinaryClassificationEvaluatorMetricType] - numBins: Param[int] - def __init__( - self, - *, - rawPredictionCol: str = ..., - labelCol: str = ..., - metricName: BinaryClassificationEvaluatorMetricType = ..., - weightCol: Optional[str] = ..., - numBins: int = ..., - ) -> None: ... - def setMetricName( - self, value: BinaryClassificationEvaluatorMetricType - ) -> BinaryClassificationEvaluator: ... - def getMetricName(self) -> BinaryClassificationEvaluatorMetricType: ... - def setNumBins(self, value: int) -> BinaryClassificationEvaluator: ... - def getNumBins(self) -> int: ... - def setLabelCol(self, value: str) -> BinaryClassificationEvaluator: ... - def setRawPredictionCol(self, value: str) -> BinaryClassificationEvaluator: ... - def setWeightCol(self, value: str) -> BinaryClassificationEvaluator: ... - def setParams( - self, - *, - rawPredictionCol: str = ..., - labelCol: str = ..., - metricName: BinaryClassificationEvaluatorMetricType = ..., - weightCol: Optional[str] = ..., - numBins: int = ..., - ) -> BinaryClassificationEvaluator: ... - -class RegressionEvaluator( - JavaEvaluator, - HasLabelCol, - HasPredictionCol, - HasWeightCol, - JavaMLReadable[RegressionEvaluator], - JavaMLWritable, -): - metricName: Param[RegressionEvaluatorMetricType] - throughOrigin: Param[bool] - def __init__( - self, - *, - predictionCol: str = ..., - labelCol: str = ..., - metricName: RegressionEvaluatorMetricType = ..., - weightCol: Optional[str] = ..., - throughOrigin: bool = ..., - ) -> None: ... - def setMetricName(self, value: RegressionEvaluatorMetricType) -> RegressionEvaluator: ... - def getMetricName(self) -> RegressionEvaluatorMetricType: ... - def setThroughOrigin(self, value: bool) -> RegressionEvaluator: ... - def getThroughOrigin(self) -> bool: ... - def setLabelCol(self, value: str) -> RegressionEvaluator: ... - def setPredictionCol(self, value: str) -> RegressionEvaluator: ... - def setWeightCol(self, value: str) -> RegressionEvaluator: ... 
- def setParams( - self, - *, - predictionCol: str = ..., - labelCol: str = ..., - metricName: RegressionEvaluatorMetricType = ..., - weightCol: Optional[str] = ..., - throughOrigin: bool = ..., - ) -> RegressionEvaluator: ... - -class MulticlassClassificationEvaluator( - JavaEvaluator, - HasLabelCol, - HasPredictionCol, - HasWeightCol, - HasProbabilityCol, - JavaMLReadable[MulticlassClassificationEvaluator], - JavaMLWritable, -): - metricName: Param[MulticlassClassificationEvaluatorMetricType] - metricLabel: Param[float] - beta: Param[float] - eps: Param[float] - def __init__( - self, - *, - predictionCol: str = ..., - labelCol: str = ..., - metricName: MulticlassClassificationEvaluatorMetricType = ..., - weightCol: Optional[str] = ..., - metricLabel: float = ..., - beta: float = ..., - probabilityCol: str = ..., - eps: float = ..., - ) -> None: ... - def setMetricName( - self, value: MulticlassClassificationEvaluatorMetricType - ) -> MulticlassClassificationEvaluator: ... - def getMetricName(self) -> MulticlassClassificationEvaluatorMetricType: ... - def setMetricLabel(self, value: float) -> MulticlassClassificationEvaluator: ... - def getMetricLabel(self) -> float: ... - def setBeta(self, value: float) -> MulticlassClassificationEvaluator: ... - def getBeta(self) -> float: ... - def setEps(self, value: float) -> MulticlassClassificationEvaluator: ... - def getEps(self) -> float: ... - def setLabelCol(self, value: str) -> MulticlassClassificationEvaluator: ... - def setPredictionCol(self, value: str) -> MulticlassClassificationEvaluator: ... - def setProbabilityCol(self, value: str) -> MulticlassClassificationEvaluator: ... - def setWeightCol(self, value: str) -> MulticlassClassificationEvaluator: ... - def setParams( - self, - *, - predictionCol: str = ..., - labelCol: str = ..., - metricName: MulticlassClassificationEvaluatorMetricType = ..., - weightCol: Optional[str] = ..., - metricLabel: float = ..., - beta: float = ..., - probabilityCol: str = ..., - eps: float = ..., - ) -> MulticlassClassificationEvaluator: ... - -class MultilabelClassificationEvaluator( - JavaEvaluator, - HasLabelCol, - HasPredictionCol, - JavaMLReadable[MultilabelClassificationEvaluator], - JavaMLWritable, -): - metricName: Param[MultilabelClassificationEvaluatorMetricType] - metricLabel: Param[float] - def __init__( - self, - *, - predictionCol: str = ..., - labelCol: str = ..., - metricName: MultilabelClassificationEvaluatorMetricType = ..., - metricLabel: float = ..., - ) -> None: ... - def setMetricName( - self, value: MultilabelClassificationEvaluatorMetricType - ) -> MultilabelClassificationEvaluator: ... - def getMetricName(self) -> MultilabelClassificationEvaluatorMetricType: ... - def setMetricLabel(self, value: float) -> MultilabelClassificationEvaluator: ... - def getMetricLabel(self) -> float: ... - def setLabelCol(self, value: str) -> MultilabelClassificationEvaluator: ... - def setPredictionCol(self, value: str) -> MultilabelClassificationEvaluator: ... - def setParams( - self, - *, - predictionCol: str = ..., - labelCol: str = ..., - metricName: MultilabelClassificationEvaluatorMetricType = ..., - metricLabel: float = ..., - ) -> MultilabelClassificationEvaluator: ... 
- -class ClusteringEvaluator( - JavaEvaluator, - HasPredictionCol, - HasFeaturesCol, - HasWeightCol, - JavaMLReadable[ClusteringEvaluator], - JavaMLWritable, -): - metricName: Param[ClusteringEvaluatorMetricType] - distanceMeasure: Param[str] - def __init__( - self, - *, - predictionCol: str = ..., - featuresCol: str = ..., - metricName: ClusteringEvaluatorMetricType = ..., - distanceMeasure: str = ..., - weightCol: Optional[str] = ..., - ) -> None: ... - def setParams( - self, - *, - predictionCol: str = ..., - featuresCol: str = ..., - metricName: ClusteringEvaluatorMetricType = ..., - distanceMeasure: str = ..., - weightCol: Optional[str] = ..., - ) -> ClusteringEvaluator: ... - def setMetricName(self, value: ClusteringEvaluatorMetricType) -> ClusteringEvaluator: ... - def getMetricName(self) -> ClusteringEvaluatorMetricType: ... - def setDistanceMeasure(self, value: str) -> ClusteringEvaluator: ... - def getDistanceMeasure(self) -> str: ... - def setFeaturesCol(self, value: str) -> ClusteringEvaluator: ... - def setPredictionCol(self, value: str) -> ClusteringEvaluator: ... - def setWeightCol(self, value: str) -> ClusteringEvaluator: ... - -class RankingEvaluator( - JavaEvaluator, - HasLabelCol, - HasPredictionCol, - JavaMLReadable[RankingEvaluator], - JavaMLWritable, -): - metricName: Param[RankingEvaluatorMetricType] - k: Param[int] - def __init__( - self, - *, - predictionCol: str = ..., - labelCol: str = ..., - metricName: RankingEvaluatorMetricType = ..., - k: int = ..., - ) -> None: ... - def setMetricName(self, value: RankingEvaluatorMetricType) -> RankingEvaluator: ... - def getMetricName(self) -> RankingEvaluatorMetricType: ... - def setK(self, value: int) -> RankingEvaluator: ... - def getK(self) -> int: ... - def setLabelCol(self, value: str) -> RankingEvaluator: ... - def setPredictionCol(self, value: str) -> RankingEvaluator: ... - def setParams( - self, - *, - predictionCol: str = ..., - labelCol: str = ..., - metricName: RankingEvaluatorMetricType = ..., - k: int = ..., - ) -> RankingEvaluator: ... diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi index 6efc304b897f4..6545bcd1c516a 100644 --- a/python/pyspark/ml/feature.pyi +++ b/python/pyspark/ml/feature.pyi @@ -42,6 +42,8 @@ from pyspark.ml.linalg import Vector, DenseVector, DenseMatrix from pyspark.sql.dataframe import DataFrame from pyspark.ml.param import Param +from py4j.java_gateway import JavaObject + class Binarizer( JavaTransformer, HasThreshold, @@ -103,6 +105,7 @@ class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, JavaMLReadable, JavaMLWri def setNumHashTables(self: P, value: int) -> P: ... def setInputCol(self: P, value: str) -> P: ... def setOutputCol(self: P, value: str) -> P: ... + def _create_model(self, java_model: JavaObject) -> JM: ... class _LSHModel(JavaModel, _LSHParams): def setInputCol(self: P, value: str) -> P: ... @@ -268,6 +271,7 @@ class CountVectorizer( def setBinary(self, value: bool) -> CountVectorizer: ... def setInputCol(self, value: str) -> CountVectorizer: ... def setOutputCol(self, value: str) -> CountVectorizer: ... + def _create_model(self, java_model: JavaObject) -> CountVectorizerModel: ... class CountVectorizerModel(JavaModel, JavaMLReadable[CountVectorizerModel], JavaMLWritable): def setInputCol(self, value: str) -> CountVectorizerModel: ... @@ -412,6 +416,7 @@ class IDF(JavaEstimator[IDFModel], _IDFParams, JavaMLReadable[IDF], JavaMLWritab def setMinDocFreq(self, value: int) -> IDF: ... def setInputCol(self, value: str) -> IDF: ... 
def setOutputCol(self, value: str) -> IDF: ... + def _create_model(self, java_model: JavaObject) -> IDFModel: ... class IDFModel(JavaModel, _IDFParams, JavaMLReadable[IDFModel], JavaMLWritable): def setInputCol(self, value: str) -> IDFModel: ... @@ -477,6 +482,7 @@ class Imputer(JavaEstimator[ImputerModel], _ImputerParams, JavaMLReadable[Impute def setInputCol(self, value: str) -> Imputer: ... def setOutputCol(self, value: str) -> Imputer: ... def setRelativeError(self, value: float) -> Imputer: ... + def _create_model(self, java_model: JavaObject) -> ImputerModel: ... class ImputerModel(JavaModel, _ImputerParams, JavaMLReadable[ImputerModel], JavaMLWritable): def setInputCols(self, value: List[str]) -> ImputerModel: ... @@ -518,6 +524,7 @@ class MaxAbsScaler( ) -> MaxAbsScaler: ... def setInputCol(self, value: str) -> MaxAbsScaler: ... def setOutputCol(self, value: str) -> MaxAbsScaler: ... + def _create_model(self, java_model: JavaObject) -> MaxAbsScalerModel: ... class MaxAbsScalerModel( JavaModel, _MaxAbsScalerParams, JavaMLReadable[MaxAbsScalerModel], JavaMLWritable @@ -588,6 +595,7 @@ class MinMaxScaler( def setMax(self, value: float) -> MinMaxScaler: ... def setInputCol(self, value: str) -> MinMaxScaler: ... def setOutputCol(self, value: str) -> MinMaxScaler: ... + def _create_model(self, java_model: JavaObject) -> MinMaxScalerModel: ... class MinMaxScalerModel( JavaModel, _MinMaxScalerParams, JavaMLReadable[MinMaxScalerModel], JavaMLWritable @@ -687,6 +695,7 @@ class OneHotEncoder( def setHandleInvalid(self, value: str) -> OneHotEncoder: ... def setInputCol(self, value: str) -> OneHotEncoder: ... def setOutputCol(self, value: str) -> OneHotEncoder: ... + def _create_model(self, java_model: JavaObject) -> OneHotEncoderModel: ... class OneHotEncoderModel( JavaModel, _OneHotEncoderParams, JavaMLReadable[OneHotEncoderModel], JavaMLWritable @@ -783,6 +792,7 @@ class QuantileDiscretizer( def setOutputCol(self, value: str) -> QuantileDiscretizer: ... def setOutputCols(self, value: List[str]) -> QuantileDiscretizer: ... def setHandleInvalid(self, value: str) -> QuantileDiscretizer: ... + def _create_model(self, java_model: JavaObject) -> Bucketizer: ... class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError): lower: Param[float] @@ -827,6 +837,7 @@ class RobustScaler( def setInputCol(self, value: str) -> RobustScaler: ... def setOutputCol(self, value: str) -> RobustScaler: ... def setRelativeError(self, value: float) -> RobustScaler: ... + def _create_model(self, java_model: JavaObject) -> RobustScalerModel: ... class RobustScalerModel( JavaModel, _RobustScalerParams, JavaMLReadable[RobustScalerModel], JavaMLWritable @@ -920,6 +931,7 @@ class StandardScaler( def setWithStd(self, value: bool) -> StandardScaler: ... def setInputCol(self, value: str) -> StandardScaler: ... def setOutputCol(self, value: str) -> StandardScaler: ... + def _create_model(self, java_model: JavaObject) -> StandardScalerModel: ... class StandardScalerModel( JavaModel, @@ -990,6 +1002,7 @@ class StringIndexer( def setOutputCol(self, value: str) -> StringIndexer: ... def setOutputCols(self, value: List[str]) -> StringIndexer: ... def setHandleInvalid(self, value: str) -> StringIndexer: ... + def _create_model(self, java_model: JavaObject) -> StringIndexerModel: ... class StringIndexerModel( JavaModel, _StringIndexerParams, JavaMLReadable[StringIndexerModel], JavaMLWritable @@ -1186,6 +1199,7 @@ class VectorIndexer( def setInputCol(self, value: str) -> VectorIndexer: ... 
def setOutputCol(self, value: str) -> VectorIndexer: ... def setHandleInvalid(self, value: str) -> VectorIndexer: ... + def _create_model(self, java_model: JavaObject) -> VectorIndexerModel: ... class VectorIndexerModel( JavaModel, _VectorIndexerParams, JavaMLReadable[VectorIndexerModel], JavaMLWritable @@ -1286,6 +1300,7 @@ class Word2Vec( def setOutputCol(self, value: str) -> Word2Vec: ... def setSeed(self, value: int) -> Word2Vec: ... def setStepSize(self, value: float) -> Word2Vec: ... + def _create_model(self, java_model: JavaObject) -> Word2VecModel: ... class Word2VecModel(JavaModel, _Word2VecParams, JavaMLReadable[Word2VecModel], JavaMLWritable): def getVectors(self) -> DataFrame: ... @@ -1322,6 +1337,7 @@ class PCA(JavaEstimator[PCAModel], _PCAParams, JavaMLReadable[PCA], JavaMLWritab def setK(self, value: int) -> PCA: ... def setInputCol(self, value: str) -> PCA: ... def setOutputCol(self, value: str) -> PCA: ... + def _create_model(self, java_model: JavaObject) -> PCAModel: ... class PCAModel(JavaModel, _PCAParams, JavaMLReadable[PCAModel], JavaMLWritable): def setInputCol(self, value: str) -> PCAModel: ... @@ -1373,6 +1389,7 @@ class RFormula( def setFeaturesCol(self, value: str) -> RFormula: ... def setLabelCol(self, value: str) -> RFormula: ... def setHandleInvalid(self, value: str) -> RFormula: ... + def _create_model(self, java_model: JavaObject) -> RFormulaModel: ... class RFormulaModel(JavaModel, _RFormulaParams, JavaMLReadable[RFormulaModel], JavaMLWritable): ... @@ -1391,7 +1408,7 @@ class _SelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol): def getFdr(self) -> float: ... def getFwe(self) -> float: ... -class _Selector(JavaEstimator[JM], _SelectorParams, JavaMLReadable, JavaMLWritable): +class _Selector(JavaEstimator[JM], _SelectorParams, JavaMLReadable, JavaMLWritable, Generic[JM]): def setSelectorType(self: P, value: str) -> P: ... def setNumTopFeatures(self: P, value: int) -> P: ... def setPercentile(self: P, value: float) -> P: ... @@ -1401,6 +1418,7 @@ class _Selector(JavaEstimator[JM], _SelectorParams, JavaMLReadable, JavaMLWritab def setFeaturesCol(self: P, value: str) -> P: ... def setOutputCol(self: P, value: str) -> P: ... def setLabelCol(self: P, value: str) -> P: ... + def _create_model(self, java_model: JavaObject) -> JM: ... class _SelectorModel(JavaModel, _SelectorParams): def setFeaturesCol(self: P, value: str) -> P: ... @@ -1448,6 +1466,7 @@ class ChiSqSelector( def setFeaturesCol(self, value: str) -> ChiSqSelector: ... def setOutputCol(self, value: str) -> ChiSqSelector: ... def setLabelCol(self, value: str) -> ChiSqSelector: ... + def _create_model(self, java_model: JavaObject) -> ChiSqSelectorModel: ... class ChiSqSelectorModel(_SelectorModel, JavaMLReadable[ChiSqSelectorModel], JavaMLWritable): def setFeaturesCol(self, value: str) -> ChiSqSelectorModel: ... @@ -1500,6 +1519,7 @@ class VarianceThresholdSelector( def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector: ... def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ... def setOutputCol(self, value: str) -> VarianceThresholdSelector: ... + def _create_model(self, java_model: JavaObject) -> VarianceThresholdSelectorModel: ... class VarianceThresholdSelectorModel( JavaModel, @@ -1552,6 +1572,7 @@ class UnivariateFeatureSelector( def setFeaturesCol(self, value: str) -> UnivariateFeatureSelector: ... def setOutputCol(self, value: str) -> UnivariateFeatureSelector: ... def setLabelCol(self, value: str) -> UnivariateFeatureSelector: ... 
+ def _create_model(self, java_model: JavaObject) -> UnivariateFeatureSelectorModel: ... class UnivariateFeatureSelectorModel( JavaModel, diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index 9cfd3afb386d1..5848b5baca3e5 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -16,6 +16,7 @@ # import sys +from typing import Any, Dict, Optional, TYPE_CHECKING from pyspark import keyword_only, since from pyspark.sql import DataFrame @@ -23,6 +24,9 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.param.shared import HasPredictionCol, Param, TypeConverters, Params +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + __all__ = ["FPGrowth", "FPGrowthModel", "PrefixSpan"] @@ -33,10 +37,10 @@ class _FPGrowthParams(HasPredictionCol): .. versionadded:: 3.0.0 """ - itemsCol = Param( + itemsCol: Param[str] = Param( Params._dummy(), "itemsCol", "items column name", typeConverter=TypeConverters.toString ) - minSupport = Param( + minSupport: Param[float] = Param( Params._dummy(), "minSupport", "Minimal support level of the frequent pattern. [0.0, 1.0]. " @@ -44,7 +48,7 @@ class _FPGrowthParams(HasPredictionCol): + "times will be output in the frequent itemsets.", typeConverter=TypeConverters.toFloat, ) - numPartitions = Param( + numPartitions: Param[int] = Param( Params._dummy(), "numPartitions", "Number of partitions (at least 1) used by parallel FP-growth. " @@ -52,7 +56,7 @@ class _FPGrowthParams(HasPredictionCol): + "and partition number of the input dataset is used.", typeConverter=TypeConverters.toInt, ) - minConfidence = Param( + minConfidence: Param[float] = Param( Params._dummy(), "minConfidence", "Minimal confidence for generating Association Rule. [0.0, 1.0]. " @@ -61,38 +65,38 @@ class _FPGrowthParams(HasPredictionCol): typeConverter=TypeConverters.toFloat, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_FPGrowthParams, self).__init__(*args) self._setDefault( minSupport=0.3, minConfidence=0.8, itemsCol="items", predictionCol="prediction" ) - def getItemsCol(self): + def getItemsCol(self) -> str: """ Gets the value of itemsCol or its default value. """ return self.getOrDefault(self.itemsCol) - def getMinSupport(self): + def getMinSupport(self) -> float: """ Gets the value of minSupport or its default value. """ return self.getOrDefault(self.minSupport) - def getNumPartitions(self): + def getNumPartitions(self) -> int: """ Gets the value of :py:attr:`numPartitions` or its default value. """ return self.getOrDefault(self.numPartitions) - def getMinConfidence(self): + def getMinConfidence(self) -> float: """ Gets the value of minConfidence or its default value. """ return self.getOrDefault(self.minConfidence) -class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable): +class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable["FPGrowthModel"]): """ Model fitted by FPGrowth. @@ -100,29 +104,29 @@ class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable): """ @since("3.0.0") - def setItemsCol(self, value): + def setItemsCol(self, value: str) -> "FPGrowthModel": """ Sets the value of :py:attr:`itemsCol`. """ return self._set(itemsCol=value) @since("3.0.0") - def setMinConfidence(self, value): + def setMinConfidence(self, value: float) -> "FPGrowthModel": """ Sets the value of :py:attr:`minConfidence`. 
""" return self._set(minConfidence=value) @since("3.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "FPGrowthModel": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) - @property + @property # type: ignore[misc] @since("2.2.0") - def freqItemsets(self): + def freqItemsets(self) -> DataFrame: """ DataFrame with two columns: * `items` - Itemset of the same type as the input column. @@ -130,9 +134,9 @@ def freqItemsets(self): """ return self._call_java("freqItemsets") - @property + @property # type: ignore[misc] @since("2.2.0") - def associationRules(self): + def associationRules(self) -> DataFrame: """ DataFrame with four columns: * `antecedent` - Array of the same type as the input column. @@ -143,7 +147,9 @@ def associationRules(self): return self._call_java("associationRules") -class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable): +class FPGrowth( + JavaEstimator[FPGrowthModel], _FPGrowthParams, JavaMLWritable, JavaMLReadable["FPGrowth"] +): r""" A parallel FP-growth algorithm to mine frequent itemsets. @@ -229,16 +235,17 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable): >>> fpm.transform(data).take(1) == model2.transform(data).take(1) True """ + _input_kwargs: Dict[str, Any] @keyword_only def __init__( self, *, - minSupport=0.3, - minConfidence=0.8, - itemsCol="items", - predictionCol="prediction", - numPartitions=None, + minSupport: float = 0.3, + minConfidence: float = 0.8, + itemsCol: str = "items", + predictionCol: str = "prediction", + numPartitions: Optional[int] = None, ): """ __init__(self, \\*, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ @@ -254,12 +261,12 @@ def __init__( def setParams( self, *, - minSupport=0.3, - minConfidence=0.8, - itemsCol="items", - predictionCol="prediction", - numPartitions=None, - ): + minSupport: float = 0.3, + minConfidence: float = 0.8, + itemsCol: str = "items", + predictionCol: str = "prediction", + numPartitions: Optional[int] = None, + ) -> "FPGrowth": """ setParams(self, \\*, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ predictionCol="prediction", numPartitions=None) @@ -267,37 +274,37 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) - def setItemsCol(self, value): + def setItemsCol(self, value: str) -> "FPGrowth": """ Sets the value of :py:attr:`itemsCol`. """ return self._set(itemsCol=value) - def setMinSupport(self, value): + def setMinSupport(self, value: float) -> "FPGrowth": """ Sets the value of :py:attr:`minSupport`. """ return self._set(minSupport=value) - def setNumPartitions(self, value): + def setNumPartitions(self, value: int) -> "FPGrowth": """ Sets the value of :py:attr:`numPartitions`. """ return self._set(numPartitions=value) - def setMinConfidence(self, value): + def setMinConfidence(self, value: float) -> "FPGrowth": """ Sets the value of :py:attr:`minConfidence`. """ return self._set(minConfidence=value) - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "FPGrowth": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> FPGrowthModel: return FPGrowthModel(java_model) @@ -347,7 +354,9 @@ class PrefixSpan(JavaParams): ... 
""" - minSupport = Param( + _input_kwargs: Dict[str, Any] + + minSupport: Param[float] = Param( Params._dummy(), "minSupport", "The minimal support level of the " @@ -356,14 +365,14 @@ class PrefixSpan(JavaParams): typeConverter=TypeConverters.toFloat, ) - maxPatternLength = Param( + maxPatternLength: Param[int] = Param( Params._dummy(), "maxPatternLength", "The maximal length of the sequential pattern. Must be > 0.", typeConverter=TypeConverters.toInt, ) - maxLocalProjDBSize = Param( + maxLocalProjDBSize: Param[int] = Param( Params._dummy(), "maxLocalProjDBSize", "The maximum number of items (including delimiters used in the " @@ -374,7 +383,7 @@ class PrefixSpan(JavaParams): typeConverter=TypeConverters.toInt, ) - sequenceCol = Param( + sequenceCol: Param[str] = Param( Params._dummy(), "sequenceCol", "The name of the sequence column in " @@ -386,10 +395,10 @@ class PrefixSpan(JavaParams): def __init__( self, *, - minSupport=0.1, - maxPatternLength=10, - maxLocalProjDBSize=32000000, - sequenceCol="sequence", + minSupport: float = 0.1, + maxPatternLength: int = 10, + maxLocalProjDBSize: int = 32000000, + sequenceCol: str = "sequence", ): """ __init__(self, \\*, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ @@ -408,11 +417,11 @@ def __init__( def setParams( self, *, - minSupport=0.1, - maxPatternLength=10, - maxLocalProjDBSize=32000000, - sequenceCol="sequence", - ): + minSupport: float = 0.1, + maxPatternLength: int = 10, + maxLocalProjDBSize: int = 32000000, + sequenceCol: str = "sequence", + ) -> "PrefixSpan": """ setParams(self, \\*, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ sequenceCol="sequence") @@ -421,62 +430,62 @@ def setParams( return self._set(**kwargs) @since("3.0.0") - def setMinSupport(self, value): + def setMinSupport(self, value: float) -> "PrefixSpan": """ Sets the value of :py:attr:`minSupport`. """ return self._set(minSupport=value) @since("3.0.0") - def getMinSupport(self): + def getMinSupport(self) -> float: """ Gets the value of minSupport or its default value. """ return self.getOrDefault(self.minSupport) @since("3.0.0") - def setMaxPatternLength(self, value): + def setMaxPatternLength(self, value: int) -> "PrefixSpan": """ Sets the value of :py:attr:`maxPatternLength`. """ return self._set(maxPatternLength=value) @since("3.0.0") - def getMaxPatternLength(self): + def getMaxPatternLength(self) -> int: """ Gets the value of maxPatternLength or its default value. """ return self.getOrDefault(self.maxPatternLength) @since("3.0.0") - def setMaxLocalProjDBSize(self, value): + def setMaxLocalProjDBSize(self, value: int) -> "PrefixSpan": """ Sets the value of :py:attr:`maxLocalProjDBSize`. """ return self._set(maxLocalProjDBSize=value) @since("3.0.0") - def getMaxLocalProjDBSize(self): + def getMaxLocalProjDBSize(self) -> int: """ Gets the value of maxLocalProjDBSize or its default value. """ return self.getOrDefault(self.maxLocalProjDBSize) @since("3.0.0") - def setSequenceCol(self, value): + def setSequenceCol(self, value: str) -> "PrefixSpan": """ Sets the value of :py:attr:`sequenceCol`. """ return self._set(sequenceCol=value) @since("3.0.0") - def getSequenceCol(self): + def getSequenceCol(self) -> str: """ Gets the value of sequenceCol or its default value. 
""" return self.getOrDefault(self.sequenceCol) - def findFrequentSequentialPatterns(self, dataset): + def findFrequentSequentialPatterns(self, dataset: DataFrame) -> DataFrame: """ Finds the complete set of frequent sequential patterns in the input sequences of itemsets. @@ -499,8 +508,9 @@ def findFrequentSequentialPatterns(self, dataset): """ self._transfer_params_to_java() + assert self._java_obj is not None jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf) - return DataFrame(jdf, dataset.sql_ctx) + return DataFrame(jdf, dataset.sparkSession) if __name__ == "__main__": diff --git a/python/pyspark/ml/fpm.pyi b/python/pyspark/ml/fpm.pyi deleted file mode 100644 index 609bc447735b7..0000000000000 --- a/python/pyspark/ml/fpm.pyi +++ /dev/null @@ -1,107 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Any, Optional - -from pyspark.ml.util import JavaMLReadable, JavaMLWritable -from pyspark.ml.wrapper import JavaEstimator, JavaParams, JavaModel -from pyspark.ml.param.shared import HasPredictionCol -from pyspark.sql.dataframe import DataFrame - -from pyspark.ml.param import Param - -class _FPGrowthParams(HasPredictionCol): - itemsCol: Param[str] - minSupport: Param[float] - numPartitions: Param[int] - minConfidence: Param[float] - def __init__(self, *args: Any): ... - def getItemsCol(self) -> str: ... - def getMinSupport(self) -> float: ... - def getNumPartitions(self) -> int: ... - def getMinConfidence(self) -> float: ... - -class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable[FPGrowthModel]): - def setItemsCol(self, value: str) -> FPGrowthModel: ... - def setMinConfidence(self, value: float) -> FPGrowthModel: ... - def setPredictionCol(self, value: str) -> FPGrowthModel: ... - @property - def freqItemsets(self) -> DataFrame: ... - @property - def associationRules(self) -> DataFrame: ... - -class FPGrowth( - JavaEstimator[FPGrowthModel], - _FPGrowthParams, - JavaMLWritable, - JavaMLReadable[FPGrowth], -): - def __init__( - self, - *, - minSupport: float = ..., - minConfidence: float = ..., - itemsCol: str = ..., - predictionCol: str = ..., - numPartitions: Optional[int] = ..., - ) -> None: ... - def setParams( - self, - *, - minSupport: float = ..., - minConfidence: float = ..., - itemsCol: str = ..., - predictionCol: str = ..., - numPartitions: Optional[int] = ..., - ) -> FPGrowth: ... - def setItemsCol(self, value: str) -> FPGrowth: ... - def setMinSupport(self, value: float) -> FPGrowth: ... - def setNumPartitions(self, value: int) -> FPGrowth: ... - def setMinConfidence(self, value: float) -> FPGrowth: ... - def setPredictionCol(self, value: str) -> FPGrowth: ... 
- -class PrefixSpan(JavaParams): - minSupport: Param[float] - maxPatternLength: Param[int] - maxLocalProjDBSize: Param[int] - sequenceCol: Param[str] - def __init__( - self, - *, - minSupport: float = ..., - maxPatternLength: int = ..., - maxLocalProjDBSize: int = ..., - sequenceCol: str = ..., - ) -> None: ... - def setParams( - self, - *, - minSupport: float = ..., - maxPatternLength: int = ..., - maxLocalProjDBSize: int = ..., - sequenceCol: str = ..., - ) -> PrefixSpan: ... - def setMinSupport(self, value: float) -> PrefixSpan: ... - def getMinSupport(self) -> float: ... - def setMaxPatternLength(self, value: int) -> PrefixSpan: ... - def getMaxPatternLength(self) -> int: ... - def setMaxLocalProjDBSize(self, value: int) -> PrefixSpan: ... - def getMaxLocalProjDBSize(self) -> int: ... - def setSequenceCol(self, value: str) -> PrefixSpan: ... - def getSequenceCol(self) -> str: ... - def findFrequentSequentialPatterns(self, dataset: DataFrame) -> DataFrame: ... diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 7188ef3d10963..6dc97ac246ab3 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -25,12 +25,13 @@ """ import sys +from typing import Any, Dict, List, NoReturn, Optional, cast import numpy as np from distutils.version import LooseVersion from pyspark import SparkContext -from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string +from pyspark.sql.types import Row, StructType, _create_row, _parse_datatype_json_string from pyspark.sql import SparkSession __all__ = ["ImageSchema"] @@ -43,15 +44,15 @@ class _ImageSchema: APIs of this class. """ - def __init__(self): - self._imageSchema = None - self._ocvTypes = None - self._columnSchema = None - self._imageFields = None - self._undefinedImageType = None + def __init__(self) -> None: + self._imageSchema: Optional[StructType] = None + self._ocvTypes: Optional[Dict[str, int]] = None + self._columnSchema: Optional[StructType] = None + self._imageFields: Optional[List[str]] = None + self._undefinedImageType: Optional[str] = None @property - def imageSchema(self): + def imageSchema(self) -> StructType: """ Returns the image schema. @@ -66,12 +67,13 @@ def imageSchema(self): if self._imageSchema is None: ctx = SparkContext._active_spark_context + assert ctx is not None and ctx._jvm is not None jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema() - self._imageSchema = _parse_datatype_json_string(jschema.json()) + self._imageSchema = cast(StructType, _parse_datatype_json_string(jschema.json())) return self._imageSchema @property - def ocvTypes(self): + def ocvTypes(self) -> Dict[str, int]: """ Returns the OpenCV type mapping supported. @@ -85,11 +87,12 @@ def ocvTypes(self): if self._ocvTypes is None: ctx = SparkContext._active_spark_context + assert ctx is not None and ctx._jvm is not None self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) return self._ocvTypes @property - def columnSchema(self): + def columnSchema(self) -> StructType: """ Returns the schema for the image column. 
@@ -104,12 +107,13 @@ def columnSchema(self): if self._columnSchema is None: ctx = SparkContext._active_spark_context + assert ctx is not None and ctx._jvm is not None jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.columnSchema() - self._columnSchema = _parse_datatype_json_string(jschema.json()) + self._columnSchema = cast(StructType, _parse_datatype_json_string(jschema.json())) return self._columnSchema @property - def imageFields(self): + def imageFields(self) -> List[str]: """ Returns field names of image columns. @@ -123,11 +127,12 @@ def imageFields(self): if self._imageFields is None: ctx = SparkContext._active_spark_context + assert ctx is not None and ctx._jvm is not None self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields()) return self._imageFields @property - def undefinedImageType(self): + def undefinedImageType(self) -> str: """ Returns the name of undefined image type for the invalid image. @@ -136,12 +141,13 @@ def undefinedImageType(self): if self._undefinedImageType is None: ctx = SparkContext._active_spark_context + assert ctx is not None and ctx._jvm is not None self._undefinedImageType = ( ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType() ) return self._undefinedImageType - def toNDArray(self, image): + def toNDArray(self, image: Row) -> np.ndarray: """ Converts an image to an array with metadata. @@ -181,7 +187,7 @@ def toNDArray(self, image): strides=(width * nChannels, nChannels, 1), ) - def toImage(self, array, origin=""): + def toImage(self, array: np.ndarray, origin: str = "") -> Row: """ Converts an array with metadata to a two-dimensional image. @@ -238,14 +244,14 @@ def toImage(self, array, origin=""): # Monkey patch to disallow instantiation of this class. -def _disallow_instance(_): +def _disallow_instance(_: Any) -> NoReturn: raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") -_ImageSchema.__init__ = _disallow_instance +_ImageSchema.__init__ = _disallow_instance # type: ignore[assignment] -def _test(): +def _test() -> None: import doctest import pyspark.ml.image diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index b361925712818..d3d2cbdaa0a01 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -40,6 +40,22 @@ BooleanType, ) +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + overload, + Sequence, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + __all__ = [ "Vector", @@ -52,6 +68,11 @@ "Matrices", ] +if TYPE_CHECKING: + from pyspark.mllib._typing import NormType + from pyspark.ml._typing import VectorLike + from scipy.sparse import spmatrix + # Check whether we have SciPy. MLlib works without it too, but if we have it, some methods, # such as _dot and _serialize_double_vector, start to support scipy.sparse matrices. @@ -65,23 +86,23 @@ _have_scipy = False -def _convert_to_vector(d): +def _convert_to_vector(d: Union["VectorLike", "spmatrix", range]) -> "Vector": if isinstance(d, Vector): return d elif type(d) in (array.array, np.array, np.ndarray, list, tuple, range): return DenseVector(d) elif _have_scipy and scipy.sparse.issparse(d): - assert d.shape[1] == 1, "Expected column vector" + assert cast("spmatrix", d).shape[1] == 1, "Expected column vector" # Make sure the converted csc_matrix has sorted indices. 
- csc = d.tocsc() + csc = cast("spmatrix", d).tocsc() if not csc.has_sorted_indices: csc.sort_indices() - return SparseVector(d.shape[0], csc.indices, csc.data) + return SparseVector(cast("spmatrix", d).shape[0], csc.indices, csc.data) else: raise TypeError("Cannot convert type %s into Vector" % type(d)) -def _vector_size(v): +def _vector_size(v: Union["VectorLike", "spmatrix", range]) -> int: """ Returns the size of the vector. @@ -112,24 +133,24 @@ def _vector_size(v): else: raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape)) elif _have_scipy and scipy.sparse.issparse(v): - assert v.shape[1] == 1, "Expected column vector" - return v.shape[0] + assert cast("spmatrix", v).shape[1] == 1, "Expected column vector" + return cast("spmatrix", v).shape[0] else: raise TypeError("Cannot treat type %s as a vector" % type(v)) -def _format_float(f, digits=4): +def _format_float(f: float, digits: int = 4) -> str: s = str(round(f, digits)) if "." in s: s = s[: s.index(".") + 1 + digits] return s -def _format_float_list(xs): +def _format_float_list(xs: Iterable[float]) -> List[str]: return [_format_float(x) for x in xs] -def _double_to_long_bits(value): +def _double_to_long_bits(value: float) -> int: if np.isnan(value): value = float("nan") # pack double into 64 bits, then unpack as long int @@ -142,7 +163,7 @@ class VectorUDT(UserDefinedType): """ @classmethod - def sqlType(cls): + def sqlType(cls) -> StructType: return StructType( [ StructField("type", ByteType(), False), @@ -153,37 +174,41 @@ def sqlType(cls): ) @classmethod - def module(cls): + def module(cls) -> str: return "pyspark.ml.linalg" @classmethod - def scalaUDT(cls): + def scalaUDT(cls) -> str: return "org.apache.spark.ml.linalg.VectorUDT" - def serialize(self, obj): + def serialize( + self, obj: "Vector" + ) -> Tuple[int, Optional[int], Optional[List[int]], List[float]]: if isinstance(obj, SparseVector): indices = [int(i) for i in obj.indices] values = [float(v) for v in obj.values] return (0, obj.size, indices, values) elif isinstance(obj, DenseVector): - values = [float(v) for v in obj] + values = [float(v) for v in obj] # type: ignore[attr-defined] return (1, None, None, values) else: raise TypeError("cannot serialize %r of type %r" % (obj, type(obj))) - def deserialize(self, datum): + def deserialize( + self, datum: Tuple[int, Optional[int], Optional[List[int]], List[float]] + ) -> "Vector": assert ( len(datum) == 4 ), "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) tpe = datum[0] if tpe == 0: - return SparseVector(datum[1], datum[2], datum[3]) + return SparseVector(cast(int, datum[1]), cast(List[int], datum[2]), datum[3]) elif tpe == 1: return DenseVector(datum[3]) else: raise ValueError("do not recognize type %r" % tpe) - def simpleString(self): + def simpleString(self) -> str: return "vector" @@ -193,7 +218,7 @@ class MatrixUDT(UserDefinedType): """ @classmethod - def sqlType(cls): + def sqlType(cls) -> StructType: return StructType( [ StructField("type", ByteType(), False), @@ -207,14 +232,16 @@ def sqlType(cls): ) @classmethod - def module(cls): + def module(cls) -> str: return "pyspark.ml.linalg" @classmethod - def scalaUDT(cls): + def scalaUDT(cls) -> str: return "org.apache.spark.ml.linalg.MatrixUDT" - def serialize(self, obj): + def serialize( + self, obj: "Matrix" + ) -> Tuple[int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool]: if isinstance(obj, SparseMatrix): colPtrs = [int(i) for i in obj.colPtrs] rowIndices = [int(i) for i in 
obj.rowIndices] @@ -234,19 +261,22 @@ def serialize(self, obj): else: raise TypeError("cannot serialize type %r" % (type(obj))) - def deserialize(self, datum): + def deserialize( + self, + datum: Tuple[int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool], + ) -> "Matrix": assert ( len(datum) == 7 ), "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum) tpe = datum[0] if tpe == 0: - return SparseMatrix(*datum[1:]) + return SparseMatrix(*datum[1:]) # type: ignore[arg-type] elif tpe == 1: return DenseMatrix(datum[1], datum[2], datum[5], datum[6]) else: raise ValueError("do not recognize type %r" % tpe) - def simpleString(self): + def simpleString(self) -> str: return "matrix" @@ -258,7 +288,7 @@ class Vector: Abstract class for DenseVector and SparseVector """ - def toArray(self): + def toArray(self) -> np.ndarray: """ Convert the vector into an numpy.ndarray @@ -266,6 +296,9 @@ def toArray(self): """ raise NotImplementedError + def __len__(self) -> int: + raise NotImplementedError + class DenseVector(Vector): """ @@ -293,25 +326,26 @@ class DenseVector(Vector): DenseVector([-1.0, -2.0]) """ - def __init__(self, ar): + def __init__(self, ar: Union[bytes, np.ndarray, Iterable[float]]): + ar_: np.ndarray if isinstance(ar, bytes): - ar = np.frombuffer(ar, dtype=np.float64) + ar_ = np.frombuffer(ar, dtype=np.float64) elif not isinstance(ar, np.ndarray): - ar = np.array(ar, dtype=np.float64) - if ar.dtype != np.float64: - ar = ar.astype(np.float64) - self.array = ar + ar_ = np.array(ar, dtype=np.float64) + else: + ar_ = ar.astype(np.float64) if ar.dtype != np.float64 else ar + self.array = ar_ - def __reduce__(self): - return DenseVector, (self.array.tostring(),) + def __reduce__(self) -> Tuple[Type["DenseVector"], Tuple[bytes]]: + return DenseVector, (self.array.tobytes(),) - def numNonzeros(self): + def numNonzeros(self) -> int: """ Number of nonzero elements. This scans all active values and count non zeros """ return np.count_nonzero(self.array) - def norm(self, p): + def norm(self, p: "NormType") -> np.float64: """ Calculates the norm of a DenseVector. @@ -325,7 +359,7 @@ def norm(self, p): """ return np.linalg.norm(self.array, p) - def dot(self, other): + def dot(self, other: Iterable[float]) -> np.float64: """ Compute the dot product of two Vectors. We support (Numpy array, list, SparseVector, or SciPy sparse) @@ -359,8 +393,8 @@ def dot(self, other): assert len(self) == other.shape[0], "dimension mismatch" return np.dot(self.array, other) elif _have_scipy and scipy.sparse.issparse(other): - assert len(self) == other.shape[0], "dimension mismatch" - return other.transpose().dot(self.toArray()) + assert len(self) == cast("spmatrix", other).shape[0], "dimension mismatch" + return cast("spmatrix", other).transpose().dot(self.toArray()) else: assert len(self) == _vector_size(other), "dimension mismatch" if isinstance(other, SparseVector): @@ -368,9 +402,9 @@ def dot(self, other): elif isinstance(other, Vector): return np.dot(self.toArray(), other.toArray()) else: - return np.dot(self.toArray(), other) + return np.dot(self.toArray(), other) # type: ignore[call-overload] - def squared_distance(self, other): + def squared_distance(self, other: Iterable[float]) -> np.float64: """ Squared distance of two Vectors. 
@@ -401,41 +435,49 @@ def squared_distance(self, other): if isinstance(other, SparseVector): return other.squared_distance(self) elif _have_scipy and scipy.sparse.issparse(other): - return _convert_to_vector(other).squared_distance(self) + return _convert_to_vector(other).squared_distance(self) # type: ignore[attr-defined] if isinstance(other, Vector): other = other.toArray() elif not isinstance(other, np.ndarray): other = np.array(other) - diff = self.toArray() - other + diff: np.ndarray = self.toArray() - other return np.dot(diff, diff) - def toArray(self): + def toArray(self) -> np.ndarray: """ Returns the underlying numpy.ndarray """ return self.array @property - def values(self): + def values(self) -> np.ndarray: """ Returns the underlying numpy.ndarray """ return self.array - def __getitem__(self, item): + @overload + def __getitem__(self, item: int) -> np.float64: + ... + + @overload + def __getitem__(self, item: slice) -> np.ndarray: + ... + + def __getitem__(self, item: Union[int, slice]) -> Union[np.float64, np.ndarray]: return self.array[item] - def __len__(self): + def __len__(self) -> int: return len(self.array) - def __str__(self): + def __str__(self) -> str: return "[" + ",".join([str(v) for v in self.array]) + "]" - def __repr__(self): + def __repr__(self) -> str: return "DenseVector([%s])" % (", ".join(_format_float(i) for i in self.array)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, DenseVector): return np.array_equal(self.array, other.array) elif isinstance(other, SparseVector): @@ -444,10 +486,10 @@ def __eq__(self, other): return Vectors._equals(list(range(len(self))), self.array, other.indices, other.values) return False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: size = len(self) result = 31 + size nnz = 0 @@ -461,14 +503,14 @@ def __hash__(self): i += 1 return result - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: return getattr(self.array, item) - def __neg__(self): + def __neg__(self) -> "DenseVector": return DenseVector(-self.array) - def _delegate(op): - def func(self, other): + def _delegate(op: str) -> Callable[["DenseVector", Any], "DenseVector"]: # type: ignore[misc] + def func(self: "DenseVector", other: Any) -> "DenseVector": if isinstance(other, DenseVector): other = other.array return DenseVector(getattr(self.array, op)(other)) @@ -495,7 +537,33 @@ class SparseVector(Vector): alternatively pass SciPy's {scipy.sparse} data types. """ - def __init__(self, size, *args): + @overload + def __init__(self, size: int, __indices: bytes, __values: bytes): + ... + + @overload + def __init__(self, size: int, *args: Tuple[int, float]): + ... + + @overload + def __init__(self, size: int, __indices: Iterable[int], __values: Iterable[float]): + ... + + @overload + def __init__(self, size: int, __pairs: Iterable[Tuple[int, float]]): + ... + + @overload + def __init__(self, size: int, __map: Dict[int, float]): + ... 
+ + def __init__( + self, + size: int, + *args: Union[ + bytes, Tuple[int, float], Iterable[float], Iterable[Tuple[int, float]], Dict[int, float] + ], + ): """ Create a sparse vector, using either a dictionary, a list of (index, value) pairs, or two separate arrays of indices and @@ -535,7 +603,7 @@ def __init__(self, size, *args): pairs = args[0] if type(pairs) == dict: pairs = pairs.items() - pairs = sorted(pairs) + pairs = cast(Iterable[Tuple[int, float]], sorted(pairs)) self.indices = np.array([p[0] for p in pairs], dtype=np.int32) """ A list of indices corresponding to active entries. """ self.values = np.array([p[1] for p in pairs], dtype=np.float64) @@ -570,13 +638,13 @@ def __init__(self, size, *args): ) assert np.min(self.indices) >= 0, "Contains negative index %d" % (np.min(self.indices)) - def numNonzeros(self): + def numNonzeros(self) -> int: """ Number of nonzero elements. This scans all active values and count non zeros. """ return np.count_nonzero(self.values) - def norm(self, p): + def norm(self, p: "NormType") -> np.float64: """ Calculates the norm of a SparseVector. @@ -590,10 +658,10 @@ def norm(self, p): """ return np.linalg.norm(self.values, p) - def __reduce__(self): - return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring())) + def __reduce__(self) -> Tuple[Type["SparseVector"], Tuple[int, bytes, bytes]]: + return (SparseVector, (self.size, self.indices.tobytes(), self.values.tobytes())) - def dot(self, other): + def dot(self, other: Iterable[float]) -> np.float64: """ Dot product with a SparseVector or 1- or 2-dimensional Numpy array. @@ -643,15 +711,15 @@ def dot(self, other): self_cmind = np.in1d(self.indices, other.indices, assume_unique=True) self_values = self.values[self_cmind] if self_values.size == 0: - return 0.0 + return np.float64(0.0) else: other_cmind = np.in1d(other.indices, self.indices, assume_unique=True) return np.dot(self_values, other.values[other_cmind]) else: - return self.dot(_convert_to_vector(other)) + return self.dot(_convert_to_vector(other)) # type: ignore[arg-type] - def squared_distance(self, other): + def squared_distance(self, other: Iterable[float]) -> np.float64: """ Squared distance from a SparseVector or 1-dimensional NumPy array. @@ -719,9 +787,9 @@ def squared_distance(self, other): j += 1 return result else: - return self.squared_distance(_convert_to_vector(other)) + return self.squared_distance(_convert_to_vector(other)) # type: ignore[arg-type] - def toArray(self): + def toArray(self) -> np.ndarray: """ Returns a copy of this SparseVector as a 1-dimensional numpy.ndarray. 
""" @@ -729,15 +797,15 @@ def toArray(self): arr[self.indices] = self.values return arr - def __len__(self): + def __len__(self) -> int: return self.size - def __str__(self): + def __str__(self) -> str: inds = "[" + ",".join([str(i) for i in self.indices]) + "]" vals = "[" + ",".join([str(v) for v in self.values]) + "]" return "(" + ",".join((str(self.size), inds, vals)) + ")" - def __repr__(self): + def __repr__(self) -> str: inds = self.indices vals = self.values entries = ", ".join( @@ -745,7 +813,7 @@ def __repr__(self): ) return "SparseVector({0}, {{{1}}})".format(self.size, entries) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, SparseVector): return ( other.size == self.size @@ -758,7 +826,7 @@ def __eq__(self, other): return Vectors._equals(self.indices, self.values, list(range(len(other))), other.array) return False - def __getitem__(self, index): + def __getitem__(self, index: int) -> np.float64: inds = self.indices vals = self.values if not isinstance(index, int): @@ -770,18 +838,18 @@ def __getitem__(self, index): index += self.size if (inds.size == 0) or (index > inds.item(-1)): - return 0.0 + return np.float64(0.0) insert_index = np.searchsorted(inds, index) row_ind = inds[insert_index] if row_ind == index: return vals[insert_index] - return 0.0 + return np.float64(0.0) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: result = 31 + self.size nnz = 0 i = 0 @@ -809,7 +877,37 @@ class Vectors: """ @staticmethod - def sparse(size, *args): + @overload + def sparse(size: int, __indices: bytes, __values: bytes) -> SparseVector: + ... + + @staticmethod + @overload + def sparse(size: int, *args: Tuple[int, float]) -> SparseVector: + ... + + @staticmethod + @overload + def sparse(size: int, __indices: Iterable[int], __values: Iterable[float]) -> SparseVector: + ... + + @staticmethod + @overload + def sparse(size: int, __pairs: Iterable[Tuple[int, float]]) -> SparseVector: + ... + + @staticmethod + @overload + def sparse(size: int, __map: Dict[int, float]) -> SparseVector: + ... + + @staticmethod + def sparse( + size: int, + *args: Union[ + bytes, Tuple[int, float], Iterable[float], Iterable[Tuple[int, float]], Dict[int, float] + ], + ) -> SparseVector: """ Create a sparse vector, using either a dictionary, a list of (index, value) pairs, or two separate arrays of indices and @@ -832,10 +930,25 @@ def sparse(size, *args): >>> Vectors.sparse(4, [1, 3], [1.0, 5.5]) SparseVector(4, {1: 1.0, 3: 5.5}) """ - return SparseVector(size, *args) + return SparseVector(size, *args) # type: ignore[arg-type] + + @overload + @staticmethod + def dense(*elements: float) -> DenseVector: + ... + + @overload + @staticmethod + def dense(__arr: bytes) -> DenseVector: + ... + + @overload + @staticmethod + def dense(__arr: Iterable[float]) -> DenseVector: + ... @staticmethod - def dense(*elements): + def dense(*elements: Union[float, bytes, np.ndarray, Iterable[float]]) -> DenseVector: """ Create a dense vector of 64-bit floats from a Python list or numbers. @@ -848,11 +961,11 @@ def dense(*elements): """ if len(elements) == 1 and not isinstance(elements[0], (float, int)): # it's list, numpy.array or other iterable object. 
- elements = elements[0] - return DenseVector(elements) + elements = elements[0] # type: ignore[assignment] + return DenseVector(cast(Iterable[float], elements)) @staticmethod - def squared_distance(v1, v2): + def squared_distance(v1: Vector, v2: Vector) -> np.float64: """ Squared distance between two vectors. a and b can be of type SparseVector, DenseVector, np.ndarray @@ -866,21 +979,26 @@ def squared_distance(v1, v2): 51.0 """ v1, v2 = _convert_to_vector(v1), _convert_to_vector(v2) - return v1.squared_distance(v2) + return v1.squared_distance(v2) # type: ignore[attr-defined] @staticmethod - def norm(vector, p): + def norm(vector: Vector, p: "NormType") -> np.float64: """ Find norm of the given vector. """ - return _convert_to_vector(vector).norm(p) + return _convert_to_vector(vector).norm(p) # type: ignore[attr-defined] @staticmethod - def zeros(size): + def zeros(size: int) -> DenseVector: return DenseVector(np.zeros(size)) @staticmethod - def _equals(v1_indices, v1_values, v2_indices, v2_values): + def _equals( + v1_indices: Union[Sequence[int], np.ndarray], + v1_values: Union[Sequence[float], np.ndarray], + v2_indices: Union[Sequence[int], np.ndarray], + v2_values: Union[Sequence[float], np.ndarray], + ) -> bool: """ Check equality between sparse/dense vectors, v1_indices and v2_indices assume to be strictly increasing. @@ -913,19 +1031,19 @@ class Matrix: Represents a local matrix. """ - def __init__(self, numRows, numCols, isTransposed=False): + def __init__(self, numRows: int, numCols: int, isTransposed: bool = False): self.numRows = numRows self.numCols = numCols self.isTransposed = isTransposed - def toArray(self): + def toArray(self) -> np.ndarray: """ Returns its elements in a numpy.ndarray. """ raise NotImplementedError @staticmethod - def _convert_to_array(array_like, dtype): + def _convert_to_array(array_like: Union[bytes, Iterable[float]], dtype: Any) -> np.ndarray: """ Convert Matrix attributes which are array-like or buffer to array. """ @@ -939,21 +1057,27 @@ class DenseMatrix(Matrix): Column-major dense matrix. 
""" - def __init__(self, numRows, numCols, values, isTransposed=False): + def __init__( + self, + numRows: int, + numCols: int, + values: Union[bytes, Iterable[float]], + isTransposed: bool = False, + ): Matrix.__init__(self, numRows, numCols, isTransposed) values = self._convert_to_array(values, np.float64) assert len(values) == numRows * numCols self.values = values - def __reduce__(self): + def __reduce__(self) -> Tuple[Type["DenseMatrix"], Tuple[int, int, bytes, int]]: return DenseMatrix, ( self.numRows, self.numCols, - self.values.tostring(), + self.values.tobytes(), int(self.isTransposed), ) - def __str__(self): + def __str__(self) -> str: """ Pretty printing of a DenseMatrix @@ -976,7 +1100,7 @@ def __str__(self): x = "\n".join([(" " * 6 + line) for line in array_lines[1:]]) return array_lines[0].replace("array", "DenseMatrix") + "\n" + x - def __repr__(self): + def __repr__(self) -> str: """ Representation of a DenseMatrix @@ -995,12 +1119,12 @@ def __repr__(self): _format_float_list(self.values[:8]) + ["..."] + _format_float_list(self.values[-8:]) ) - entries = ", ".join(entries) + entries = ", ".join(entries) # type: ignore[assignment] return "DenseMatrix({0}, {1}, [{2}], {3})".format( self.numRows, self.numCols, entries, self.isTransposed ) - def toArray(self): + def toArray(self) -> np.ndarray: """ Return a :py:class:`numpy.ndarray` @@ -1016,7 +1140,7 @@ def toArray(self): else: return self.values.reshape((self.numRows, self.numCols), order="F") - def toSparse(self): + def toSparse(self) -> "SparseMatrix": """Convert to SparseMatrix""" if self.isTransposed: values = np.ravel(self.toArray(), order="F") @@ -1030,7 +1154,7 @@ def toSparse(self): return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) - def __getitem__(self, indices): + def __getitem__(self, indices: Tuple[int, int]) -> np.float64: i, j = indices if i < 0 or i >= self.numRows: raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) @@ -1042,21 +1166,29 @@ def __getitem__(self, indices): else: return self.values[i + j * self.numRows] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if self.numRows != other.numRows or self.numCols != other.numCols: return False if isinstance(other, SparseMatrix): - return np.all(self.toArray() == other.toArray()) + return np.all(self.toArray() == other.toArray()).tolist() self_values = np.ravel(self.toArray(), order="F") other_values = np.ravel(other.toArray(), order="F") - return np.all(self_values == other_values) + return np.all(self_values == other_values).tolist() class SparseMatrix(Matrix): """Sparse Matrix stored in CSC format.""" - def __init__(self, numRows, numCols, colPtrs, rowIndices, values, isTransposed=False): + def __init__( + self, + numRows: int, + numCols: int, + colPtrs: Union[bytes, Iterable[int]], + rowIndices: Union[bytes, Iterable[int]], + values: Union[bytes, Iterable[float]], + isTransposed: bool = False, + ): Matrix.__init__(self, numRows, numCols, isTransposed) self.colPtrs = self._convert_to_array(colPtrs, np.int32) self.rowIndices = self._convert_to_array(rowIndices, np.int32) @@ -1078,7 +1210,7 @@ def __init__(self, numRows, numCols, colPtrs, rowIndices, values, isTransposed=F % (self.rowIndices.size, self.values.size) ) - def __str__(self): + def __str__(self) -> str: """ Pretty printing of a SparseMatrix @@ -1124,7 +1256,7 @@ def __str__(self): spstr += "\n.." 
* 2 return spstr - def __repr__(self): + def __repr__(self) -> str: """ Representation of a SparseMatrix @@ -1149,24 +1281,24 @@ def __repr__(self): if len(self.colPtrs) > 16: colPtrs = colPtrs[:8] + ["..."] + colPtrs[-8:] - values = ", ".join(values) - rowIndices = ", ".join([str(ind) for ind in rowIndices]) - colPtrs = ", ".join([str(ptr) for ptr in colPtrs]) + values = ", ".join(values) # type: ignore[assignment] + rowIndices = ", ".join([str(ind) for ind in rowIndices]) # type: ignore[assignment] + colPtrs = ", ".join([str(ptr) for ptr in colPtrs]) # type: ignore[assignment] return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format( self.numRows, self.numCols, colPtrs, rowIndices, values, self.isTransposed ) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type["SparseMatrix"], Tuple[int, int, bytes, bytes, bytes, int]]: return SparseMatrix, ( self.numRows, self.numCols, - self.colPtrs.tostring(), - self.rowIndices.tostring(), - self.values.tostring(), + self.colPtrs.tobytes(), + self.rowIndices.tobytes(), + self.values.tobytes(), int(self.isTransposed), ) - def __getitem__(self, indices): + def __getitem__(self, indices: Tuple[int, int]) -> np.float64: i, j = indices if i < 0 or i >= self.numRows: raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) @@ -1186,9 +1318,9 @@ def __getitem__(self, indices): if ind < colEnd and self.rowIndices[ind] == i: return self.values[ind] else: - return 0.0 + return np.float64(0.0) - def toArray(self): + def toArray(self) -> np.ndarray: """ Return a numpy.ndarray """ @@ -1202,32 +1334,38 @@ def toArray(self): A[self.rowIndices[startptr:endptr], k] = self.values[startptr:endptr] return A - def toDense(self): + def toDense(self) -> "DenseMatrix": densevals = np.ravel(self.toArray(), order="F") return DenseMatrix(self.numRows, self.numCols, densevals) # TODO: More efficient implementation: - def __eq__(self, other): - return np.all(self.toArray() == other.toArray()) + def __eq__(self, other: Any) -> bool: + return np.all(self.toArray() == other.toArray()).tolist() class Matrices: @staticmethod - def dense(numRows, numCols, values): + def dense(numRows: int, numCols: int, values: Union[bytes, Iterable[float]]) -> DenseMatrix: """ Create a DenseMatrix """ return DenseMatrix(numRows, numCols, values) @staticmethod - def sparse(numRows, numCols, colPtrs, rowIndices, values): + def sparse( + numRows: int, + numCols: int, + colPtrs: Union[bytes, Iterable[int]], + rowIndices: Union[bytes, Iterable[int]], + values: Union[bytes, Iterable[float]], + ) -> SparseMatrix: """ Create a SparseMatrix """ return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) -def _test(): +def _test() -> None: import doctest try: diff --git a/python/pyspark/ml/linalg/__init__.pyi b/python/pyspark/ml/linalg/__init__.pyi deleted file mode 100644 index bb0939771b1be..0000000000000 --- a/python/pyspark/ml/linalg/__init__.pyi +++ /dev/null @@ -1,243 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import overload -from typing import Any, Dict, Iterable, List, NoReturn, Optional, Tuple, Type, Union - -from pyspark.ml import linalg as newlinalg # noqa: F401 -from pyspark.sql.types import StructType, UserDefinedType - -from numpy import float64, ndarray - -class VectorUDT(UserDefinedType): - @classmethod - def sqlType(cls) -> StructType: ... - @classmethod - def module(cls) -> str: ... - @classmethod - def scalaUDT(cls) -> str: ... - def serialize( - self, obj: Vector - ) -> Tuple[int, Optional[int], Optional[List[int]], List[float]]: ... - def deserialize(self, datum: Any) -> Vector: ... - def simpleString(self) -> str: ... - -class MatrixUDT(UserDefinedType): - @classmethod - def sqlType(cls) -> StructType: ... - @classmethod - def module(cls) -> str: ... - @classmethod - def scalaUDT(cls) -> str: ... - def serialize( - self, obj: Matrix - ) -> Tuple[int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool]: ... - def deserialize(self, datum: Any) -> Matrix: ... - def simpleString(self) -> str: ... - -class Vector: - __UDT__: VectorUDT - def toArray(self) -> ndarray: ... - -class DenseVector(Vector): - array: ndarray - @overload - def __init__(self, *elements: float) -> None: ... - @overload - def __init__(self, __arr: bytes) -> None: ... - @overload - def __init__(self, __arr: Iterable[float]) -> None: ... - def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ... - def numNonzeros(self) -> int: ... - def norm(self, p: Union[float, str]) -> float64: ... - def dot(self, other: Iterable[float]) -> float64: ... - def squared_distance(self, other: Iterable[float]) -> float64: ... - def toArray(self) -> ndarray: ... - @property - def values(self) -> ndarray: ... - def __getitem__(self, item: int) -> float64: ... - def __len__(self) -> int: ... - def __eq__(self, other: Any) -> bool: ... - def __ne__(self, other: Any) -> bool: ... - def __hash__(self) -> int: ... - def __getattr__(self, item: str) -> Any: ... - def __neg__(self) -> DenseVector: ... - def __add__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __sub__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __mul__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __div__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __truediv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __mod__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __radd__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __rsub__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __rmul__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __rdiv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __rtruediv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __rmod__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - -class SparseVector(Vector): - size: int - indices: ndarray - values: ndarray - @overload - def __init__(self, size: int, *args: Tuple[int, float]) -> None: ... 
- @overload - def __init__(self, size: int, __indices: bytes, __values: bytes) -> None: ... - @overload - def __init__(self, size: int, __indices: Iterable[int], __values: Iterable[float]) -> None: ... - @overload - def __init__(self, size: int, __pairs: Iterable[Tuple[int, float]]) -> None: ... - @overload - def __init__(self, size: int, __map: Dict[int, float]) -> None: ... - def numNonzeros(self) -> int: ... - def norm(self, p: Union[float, str]) -> float64: ... - def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ... - def dot(self, other: Iterable[float]) -> float64: ... - def squared_distance(self, other: Iterable[float]) -> float64: ... - def toArray(self) -> ndarray: ... - def __len__(self) -> int: ... - def __eq__(self, other: Any) -> bool: ... - def __getitem__(self, index: int) -> float64: ... - def __ne__(self, other: Any) -> bool: ... - def __hash__(self) -> int: ... - -class Vectors: - @overload - @staticmethod - def sparse(size: int, *args: Tuple[int, float]) -> SparseVector: ... - @overload - @staticmethod - def sparse(size: int, __indices: bytes, __values: bytes) -> SparseVector: ... - @overload - @staticmethod - def sparse(size: int, __indices: Iterable[int], __values: Iterable[float]) -> SparseVector: ... - @overload - @staticmethod - def sparse(size: int, __pairs: Iterable[Tuple[int, float]]) -> SparseVector: ... - @overload - @staticmethod - def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ... - @overload - @staticmethod - def dense(*elements: float) -> DenseVector: ... - @overload - @staticmethod - def dense(__arr: bytes) -> DenseVector: ... - @overload - @staticmethod - def dense(__arr: Iterable[float]) -> DenseVector: ... - @staticmethod - def stringify(vector: Vector) -> str: ... - @staticmethod - def squared_distance(v1: Vector, v2: Vector) -> float64: ... - @staticmethod - def norm(vector: Vector, p: Union[float, str]) -> float64: ... - @staticmethod - def zeros(size: int) -> DenseVector: ... - -class Matrix: - __UDT__: MatrixUDT - numRows: int - numCols: int - isTransposed: bool - def __init__(self, numRows: int, numCols: int, isTransposed: bool = ...) -> None: ... - def toArray(self) -> ndarray: ... - -class DenseMatrix(Matrix): - values: Any - @overload - def __init__( - self, numRows: int, numCols: int, values: bytes, isTransposed: bool = ... - ) -> None: ... - @overload - def __init__( - self, - numRows: int, - numCols: int, - values: Iterable[float], - isTransposed: bool = ..., - ) -> None: ... - def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ... - def toArray(self) -> ndarray: ... - def toSparse(self) -> SparseMatrix: ... - def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def __eq__(self, other: Any) -> bool: ... - -class SparseMatrix(Matrix): - colPtrs: ndarray - rowIndices: ndarray - values: ndarray - @overload - def __init__( - self, - numRows: int, - numCols: int, - colPtrs: bytes, - rowIndices: bytes, - values: bytes, - isTransposed: bool = ..., - ) -> None: ... - @overload - def __init__( - self, - numRows: int, - numCols: int, - colPtrs: Iterable[int], - rowIndices: Iterable[int], - values: Iterable[float], - isTransposed: bool = ..., - ) -> None: ... - def __reduce__( - self, - ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ... - def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def toArray(self) -> ndarray: ... - def toDense(self) -> DenseMatrix: ... - def __eq__(self, other: Any) -> bool: ... 
- -class Matrices: - @overload - @staticmethod - def dense( - numRows: int, numCols: int, values: bytes, isTransposed: bool = ... - ) -> DenseMatrix: ... - @overload - @staticmethod - def dense( - numRows: int, numCols: int, values: Iterable[float], isTransposed: bool = ... - ) -> DenseMatrix: ... - @overload - @staticmethod - def sparse( - numRows: int, - numCols: int, - colPtrs: bytes, - rowIndices: bytes, - values: bytes, - isTransposed: bool = ..., - ) -> SparseMatrix: ... - @overload - @staticmethod - def sparse( - numRows: int, - numCols: int, - colPtrs: Iterable[int], - rowIndices: Iterable[int], - values: Iterable[float], - isTransposed: bool = ..., - ) -> SparseMatrix: ... diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 092f79f50f4d2..74f7b0bc3c7a8 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -20,7 +20,6 @@ from typing import ( Any, Callable, - cast, Generic, List, Optional, @@ -43,6 +42,7 @@ __all__ = ["Param", "Params", "TypeConverters"] T = TypeVar("T") +P = TypeVar("P", bound="Params") class Param(Generic[T]): @@ -303,7 +303,7 @@ def explainParam(self, param: Union[str, Param]) -> str: Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string. """ - param = cast(Param, self._resolveParam(param)) + param = self._resolveParam(param) values = [] if self.isDefined(param): if param in self._defaultParamMap: @@ -409,7 +409,7 @@ def extractParamMap(self, extra: Optional["ParamMap"] = None) -> "ParamMap": paramMap.update(extra) return paramMap - def copy(self, extra: Optional["ParamMap"] = None) -> "Params": + def copy(self: P, extra: Optional["ParamMap"] = None) -> P: """ Creates a copy of this instance with the same uid and some extra params. The default implementation creates a @@ -492,7 +492,7 @@ def _dummy() -> "Params": dummy.uid = "undefined" return dummy - def _set(self, **kwargs: Any) -> "Params": + def _set(self: P, **kwargs: Any) -> P: """ Sets user-supplied params. """ @@ -513,7 +513,7 @@ def clear(self, param: Param) -> None: if self.isSet(param): del self._paramMap[param] - def _setDefault(self, **kwargs: Any) -> "Params": + def _setDefault(self: P, **kwargs: Any) -> P: """ Sets default params. """ @@ -529,7 +529,7 @@ def _setDefault(self, **kwargs: Any) -> "Params": self._defaultParamMap[p] = value return self - def _copyValues(self, to: "Params", extra: Optional["ParamMap"] = None) -> "Params": + def _copyValues(self, to: P, extra: Optional["ParamMap"] = None) -> P: """ Copies param values from this instance to another instance for params shared by them. @@ -568,7 +568,7 @@ def _copyValues(self, to: "Params", extra: Optional["ParamMap"] = None) -> "Para to._set(**{param.name: paramMap[param]}) return to - def _resetUid(self, newUid: Any) -> "Params": + def _resetUid(self: P, newUid: Any) -> P: """ Changes the uid of this instance. This updates both the stored uid and the parent uid of params and param maps. diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 82c752187f9ca..5df1782084a81 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -102,7 +102,7 @@ def get{name[0].upper()}{name[1:]}(self) -> {paramType}: print(header) print("\n# DO NOT MODIFY THIS FILE! 
It was generated by _shared_params_code_gen.py.\n") print("from typing import List\n") - print("from pyspark.ml.param import *\n\n") + print("from pyspark.ml.param import Param, Params, TypeConverters\n\n") shared = [ ( "maxIter", diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index f24d1796a72db..fcfced2e566df 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -19,7 +19,7 @@ from typing import List -from pyspark.ml.param import * +from pyspark.ml.param import Param, Params, TypeConverters class HasMaxIter(Params): diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 57ca47ec3bca8..24653d1d919ee 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -16,6 +16,8 @@ # import os +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast, TYPE_CHECKING + from pyspark import keyword_only, since, SparkContext from pyspark.ml.base import Estimator, Model, Transformer from pyspark.ml.param import Param, Params @@ -28,14 +30,20 @@ DefaultParamsWriter, MLWriter, MLReader, + JavaMLReadable, JavaMLWritable, ) from pyspark.ml.wrapper import JavaParams from pyspark.ml.common import inherit_doc +from pyspark.sql.dataframe import DataFrame + +if TYPE_CHECKING: + from pyspark.ml._typing import ParamMap, PipelineStage + from py4j.java_gateway import JavaObject @inherit_doc -class Pipeline(Estimator, MLReadable, MLWritable): +class Pipeline(Estimator["PipelineModel"], MLReadable["Pipeline"], MLWritable): """ A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each of which is either an @@ -56,10 +64,14 @@ class Pipeline(Estimator, MLReadable, MLWritable): .. versionadded:: 1.3.0 """ - stages = Param(Params._dummy(), "stages", "a list of pipeline stages") + stages: Param[List["PipelineStage"]] = Param( + Params._dummy(), "stages", "a list of pipeline stages" + ) + + _input_kwargs: Dict[str, Any] @keyword_only - def __init__(self, *, stages=None): + def __init__(self, *, stages: Optional[List["PipelineStage"]] = None): """ __init__(self, \\*, stages=None) """ @@ -67,7 +79,7 @@ def __init__(self, *, stages=None): kwargs = self._input_kwargs self.setParams(**kwargs) - def setStages(self, value): + def setStages(self, value: List["PipelineStage"]) -> "Pipeline": """ Set pipeline stages. @@ -87,7 +99,7 @@ def setStages(self, value): return self._set(stages=value) @since("1.3.0") - def getStages(self): + def getStages(self) -> List["PipelineStage"]: """ Get pipeline stages. """ @@ -95,7 +107,7 @@ def getStages(self): @keyword_only @since("1.3.0") - def setParams(self, *, stages=None): + def setParams(self, *, stages: Optional[List["PipelineStage"]] = None) -> "Pipeline": """ setParams(self, \\*, stages=None) Sets params for Pipeline. 
@@ -103,7 +115,7 @@ def setParams(self, *, stages=None): kwargs = self._input_kwargs return self._set(**kwargs) - def _fit(self, dataset): + def _fit(self, dataset: DataFrame) -> "PipelineModel": stages = self.getStages() for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): @@ -112,7 +124,7 @@ def _fit(self, dataset): for i, stage in enumerate(stages): if isinstance(stage, Estimator): indexOfLastEstimator = i - transformers = [] + transformers: List[Transformer] = [] for i, stage in enumerate(stages): if i <= indexOfLastEstimator: if isinstance(stage, Transformer): @@ -124,10 +136,10 @@ def _fit(self, dataset): if i < indexOfLastEstimator: dataset = model.transform(dataset) else: - transformers.append(stage) + transformers.append(cast(Transformer, stage)) return PipelineModel(transformers) - def copy(self, extra=None): + def copy(self, extra: Optional["ParamMap"] = None) -> "Pipeline": """ Creates a copy of this instance. @@ -150,21 +162,21 @@ def copy(self, extra=None): return that.setStages(stages) @since("2.0.0") - def write(self): + def write(self) -> MLWriter: """Returns an MLWriter instance for this ML instance.""" allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages()) if allStagesAreJava: - return JavaMLWriter(self) + return JavaMLWriter(self) # type: ignore[arg-type] return PipelineWriter(self) @classmethod @since("2.0.0") - def read(cls): + def read(cls) -> "PipelineReader": """Returns an MLReader instance for this class.""" return PipelineReader(cls) @classmethod - def _from_java(cls, java_stage): + def _from_java(cls, java_stage: "JavaObject") -> "Pipeline": """ Given a Java Pipeline, create and return a Python wrapper of it. Used for ML persistence. @@ -172,12 +184,14 @@ def _from_java(cls, java_stage): # Create a new instance of this stage. py_stage = cls() # Load information from java_stage to the instance. - py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()] + py_stages: List["PipelineStage"] = [ + JavaParams._from_java(s) for s in java_stage.getStages() + ] py_stage.setStages(py_stages) py_stage._resetUid(java_stage.uid()) return py_stage - def _to_java(self): + def _to_java(self) -> "JavaObject": """ Transfer this instance to a Java Pipeline. Used for ML persistence. 
@@ -188,10 +202,12 @@ def _to_java(self): """ gateway = SparkContext._gateway + assert gateway is not None and SparkContext._jvm is not None + cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage java_stages = gateway.new_array(cls, len(self.getStages())) for idx, stage in enumerate(self.getStages()): - java_stages[idx] = stage._to_java() + java_stages[idx] = cast(JavaParams, stage)._to_java() _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) _java_obj.setStages(java_stages) @@ -205,30 +221,30 @@ class PipelineWriter(MLWriter): (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types """ - def __init__(self, instance): + def __init__(self, instance: Pipeline): super(PipelineWriter, self).__init__() self.instance = instance - def saveImpl(self, path): + def saveImpl(self, path: str) -> None: stages = self.instance.getStages() PipelineSharedReadWrite.validateStages(stages) PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path) @inherit_doc -class PipelineReader(MLReader): +class PipelineReader(MLReader[Pipeline]): """ (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types """ - def __init__(self, cls): + def __init__(self, cls: Type[Pipeline]): super(PipelineReader, self).__init__() self.cls = cls - def load(self, path): + def load(self, path: str) -> Pipeline: metadata = DefaultParamsReader.loadMetadata(path, self.sc) if "language" not in metadata["paramMap"] or metadata["paramMap"]["language"] != "Python": - return JavaMLReader(self.cls).load(path) + return JavaMLReader(cast(Type["JavaMLReadable[Pipeline]"], self.cls)).load(path) else: uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path) return Pipeline(stages=stages)._resetUid(uid) @@ -240,53 +256,55 @@ class PipelineModelWriter(MLWriter): (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types """ - def __init__(self, instance): + def __init__(self, instance: "PipelineModel"): super(PipelineModelWriter, self).__init__() self.instance = instance - def saveImpl(self, path): + def saveImpl(self, path: str) -> None: stages = self.instance.stages - PipelineSharedReadWrite.validateStages(stages) - PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path) + PipelineSharedReadWrite.validateStages(cast(List["PipelineStage"], stages)) + PipelineSharedReadWrite.saveImpl( + self.instance, cast(List["PipelineStage"], stages), self.sc, path + ) @inherit_doc -class PipelineModelReader(MLReader): +class PipelineModelReader(MLReader["PipelineModel"]): """ (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types """ - def __init__(self, cls): + def __init__(self, cls: Type["PipelineModel"]): super(PipelineModelReader, self).__init__() self.cls = cls - def load(self, path): + def load(self, path: str) -> "PipelineModel": metadata = DefaultParamsReader.loadMetadata(path, self.sc) if "language" not in metadata["paramMap"] or metadata["paramMap"]["language"] != "Python": - return JavaMLReader(self.cls).load(path) + return JavaMLReader(cast(Type["JavaMLReadable[PipelineModel]"], self.cls)).load(path) else: uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path) - return PipelineModel(stages=stages)._resetUid(uid) + return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid) @inherit_doc -class PipelineModel(Model, MLReadable, MLWritable): +class PipelineModel(Model, MLReadable["PipelineModel"], MLWritable): """ Represents a compiled pipeline with 
transformers and fitted models. .. versionadded:: 1.3.0 """ - def __init__(self, stages): + def __init__(self, stages: List[Transformer]): super(PipelineModel, self).__init__() self.stages = stages - def _transform(self, dataset): + def _transform(self, dataset: DataFrame) -> DataFrame: for t in self.stages: dataset = t.transform(dataset) return dataset - def copy(self, extra=None): + def copy(self, extra: Optional["ParamMap"] = None) -> "PipelineModel": """ Creates a copy of this instance. @@ -301,33 +319,35 @@ def copy(self, extra=None): return PipelineModel(stages) @since("2.0.0") - def write(self): + def write(self) -> MLWriter: """Returns an MLWriter instance for this ML instance.""" - allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages) + allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava( + cast(List["PipelineStage"], self.stages) + ) if allStagesAreJava: - return JavaMLWriter(self) + return JavaMLWriter(self) # type: ignore[arg-type] return PipelineModelWriter(self) @classmethod @since("2.0.0") - def read(cls): + def read(cls) -> PipelineModelReader: """Returns an MLReader instance for this class.""" return PipelineModelReader(cls) @classmethod - def _from_java(cls, java_stage): + def _from_java(cls, java_stage: "JavaObject") -> "PipelineModel": """ Given a Java PipelineModel, create and return a Python wrapper of it. Used for ML persistence. """ # Load information from java_stage to the instance. - py_stages = [JavaParams._from_java(s) for s in java_stage.stages()] + py_stages: List[Transformer] = [JavaParams._from_java(s) for s in java_stage.stages()] # Create a new instance of this stage. py_stage = cls(py_stages) py_stage._resetUid(java_stage.uid()) return py_stage - def _to_java(self): + def _to_java(self) -> "JavaObject": """ Transfer this instance to a Java PipelineModel. Used for ML persistence. 
@@ -335,10 +355,12 @@ def _to_java(self): """ gateway = SparkContext._gateway + assert gateway is not None and SparkContext._jvm is not None + cls = SparkContext._jvm.org.apache.spark.ml.Transformer java_stages = gateway.new_array(cls, len(self.stages)) for idx, stage in enumerate(self.stages): - java_stages[idx] = stage._to_java() + java_stages[idx] = cast(JavaParams, stage)._to_java() _java_obj = JavaParams._new_java_obj( "org.apache.spark.ml.PipelineModel", self.uid, java_stages @@ -357,11 +379,11 @@ class PipelineSharedReadWrite: """ @staticmethod - def checkStagesForJava(stages): + def checkStagesForJava(stages: List["PipelineStage"]) -> bool: return all(isinstance(stage, JavaMLWritable) for stage in stages) @staticmethod - def validateStages(stages): + def validateStages(stages: List["PipelineStage"]) -> None: """ Check that all stages are Writable """ @@ -375,7 +397,12 @@ def validateStages(stages): ) @staticmethod - def saveImpl(instance, stages, sc, path): + def saveImpl( + instance: Union[Pipeline, PipelineModel], + stages: List["PipelineStage"], + sc: SparkContext, + path: str, + ) -> None: """ Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel` - save metadata to path/metadata @@ -386,12 +413,14 @@ def saveImpl(instance, stages, sc, path): DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams) stagesDir = os.path.join(path, "stages") for index, stage in enumerate(stages): - stage.write().save( + cast(MLWritable, stage).write().save( PipelineSharedReadWrite.getStagePath(stage.uid, index, len(stages), stagesDir) ) @staticmethod - def load(metadata, sc, path): + def load( + metadata: Dict[str, Any], sc: SparkContext, path: str + ) -> Tuple[str, List["PipelineStage"]]: """ Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel` @@ -407,12 +436,12 @@ def load(metadata, sc, path): stagePath = PipelineSharedReadWrite.getStagePath( stageUid, index, len(stageUids), stagesDir ) - stage = DefaultParamsReader.loadParamsInstance(stagePath, sc) + stage: "PipelineStage" = DefaultParamsReader.loadParamsInstance(stagePath, sc) stages.append(stage) return (metadata["uid"], stages) @staticmethod - def getStagePath(stageUid, stageIdx, numStages, stagesDir): + def getStagePath(stageUid: str, stageIdx: int, numStages: int, stagesDir: str) -> str: """ Get path for saving the given stage. """ diff --git a/python/pyspark/ml/pipeline.pyi b/python/pyspark/ml/pipeline.pyi deleted file mode 100644 index f55b1e3e1ea47..0000000000000 --- a/python/pyspark/ml/pipeline.pyi +++ /dev/null @@ -1,95 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import Any, Dict, List, Optional, Tuple, Type, Union - -from pyspark.ml._typing import PipelineStage -from pyspark.context import SparkContext -from pyspark.ml.base import Estimator, Model, Transformer -from pyspark.ml.param import Param -from pyspark.ml.util import ( # noqa: F401 - DefaultParamsReader as DefaultParamsReader, - DefaultParamsWriter as DefaultParamsWriter, - JavaMLReader as JavaMLReader, - JavaMLWritable as JavaMLWritable, - JavaMLWriter as JavaMLWriter, - MLReadable as MLReadable, - MLReader as MLReader, - MLWritable as MLWritable, - MLWriter as MLWriter, -) - -class Pipeline(Estimator[PipelineModel], MLReadable[Pipeline], MLWritable): - stages: List[PipelineStage] - def __init__(self, *, stages: Optional[List[PipelineStage]] = ...) -> None: ... - def setStages(self, stages: List[PipelineStage]) -> Pipeline: ... - def getStages(self) -> List[PipelineStage]: ... - def setParams(self, *, stages: Optional[List[PipelineStage]] = ...) -> Pipeline: ... - def copy(self, extra: Optional[Dict[Param, str]] = ...) -> Pipeline: ... - def write(self) -> JavaMLWriter: ... - def save(self, path: str) -> None: ... - @classmethod - def read(cls) -> PipelineReader: ... - -class PipelineWriter(MLWriter): - instance: Pipeline - def __init__(self, instance: Pipeline) -> None: ... - def saveImpl(self, path: str) -> None: ... - -class PipelineReader(MLReader[Pipeline]): - cls: Type[Pipeline] - def __init__(self, cls: Type[Pipeline]) -> None: ... - def load(self, path: str) -> Pipeline: ... - -class PipelineModelWriter(MLWriter): - instance: PipelineModel - def __init__(self, instance: PipelineModel) -> None: ... - def saveImpl(self, path: str) -> None: ... - -class PipelineModelReader(MLReader[PipelineModel]): - cls: Type[PipelineModel] - def __init__(self, cls: Type[PipelineModel]) -> None: ... - def load(self, path: str) -> PipelineModel: ... - -class PipelineModel(Model, MLReadable[PipelineModel], MLWritable): - stages: List[PipelineStage] - def __init__(self, stages: List[Transformer]) -> None: ... - def copy(self, extra: Optional[Dict[Param, Any]] = ...) -> PipelineModel: ... - def write(self) -> JavaMLWriter: ... - def save(self, path: str) -> None: ... - @classmethod - def read(cls) -> PipelineModelReader: ... - -class PipelineSharedReadWrite: - @staticmethod - def checkStagesForJava(stages: List[PipelineStage]) -> bool: ... - @staticmethod - def validateStages(stages: List[PipelineStage]) -> None: ... - @staticmethod - def saveImpl( - instance: Union[Pipeline, PipelineModel], - stages: List[PipelineStage], - sc: SparkContext, - path: str, - ) -> None: ... - @staticmethod - def load( - metadata: Dict[str, Any], sc: SparkContext, path: str - ) -> Tuple[str, List[PipelineStage]]: ... - @staticmethod - def getStagePath(stageUid: str, stageIdx: int, numStages: int, stagesDir: str) -> str: ... 
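For context, a minimal usage sketch of the annotated Pipeline API now that the stub is gone (illustrative only — the toy DataFrame, its column names, and the Tokenizer/HashingTF/LogisticRegression stages are assumptions for this sketch, not something this patch adds):

from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy training data; the schema is made up purely for this sketch.
training = spark.createDataFrame(
    [(0, "a b c d e spark", 1.0), (1, "b d", 0.0), (2, "spark f g h", 1.0)],
    ["id", "text", "label"],
)

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol="words", outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.001)

# Pipeline is declared as Estimator["PipelineModel"] above, so type checkers
# can infer the fitted type directly from pipeline.py without a .pyi stub.
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
model: PipelineModel = pipeline.fit(training)
model.transform(training).select("id", "prediction").show()

The same holds for PipelineModel.transform, which now carries an explicit DataFrame -> DataFrame signature inline.
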
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index f0628fb9221cf..f13fb721b9a88 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -16,6 +16,7 @@ # import sys +from typing import Any, Dict, Optional, TYPE_CHECKING from pyspark import since, keyword_only from pyspark.ml.param.shared import ( @@ -30,6 +31,10 @@ from pyspark.ml.common import inherit_doc from pyspark.ml.param import Params, TypeConverters, Param from pyspark.ml.util import JavaMLWritable, JavaMLReadable +from pyspark.sql import DataFrame + +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject __all__ = ["ALS", "ALSModel"] @@ -43,19 +48,19 @@ class _ALSModelParams(HasPredictionCol, HasBlockSize): .. versionadded:: 3.0.0 """ - userCol = Param( + userCol: Param[str] = Param( Params._dummy(), "userCol", "column name for user ids. Ids must be within " + "the integer value range.", typeConverter=TypeConverters.toString, ) - itemCol = Param( + itemCol: Param[str] = Param( Params._dummy(), "itemCol", "column name for item ids. Ids must be within " + "the integer value range.", typeConverter=TypeConverters.toString, ) - coldStartStrategy = Param( + coldStartStrategy: Param[str] = Param( Params._dummy(), "coldStartStrategy", "strategy for dealing with " @@ -66,26 +71,26 @@ class _ALSModelParams(HasPredictionCol, HasBlockSize): typeConverter=TypeConverters.toString, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_ALSModelParams, self).__init__(*args) self._setDefault(blockSize=4096) @since("1.4.0") - def getUserCol(self): + def getUserCol(self) -> str: """ Gets the value of userCol or its default value. """ return self.getOrDefault(self.userCol) @since("1.4.0") - def getItemCol(self): + def getItemCol(self) -> str: """ Gets the value of itemCol or its default value. """ return self.getOrDefault(self.itemCol) @since("2.2.0") - def getColdStartStrategy(self): + def getColdStartStrategy(self) -> str: """ Gets the value of coldStartStrategy or its default value. """ @@ -100,60 +105,60 @@ class _ALSParams(_ALSModelParams, HasMaxIter, HasRegParam, HasCheckpointInterval .. versionadded:: 3.0.0 """ - rank = Param( + rank: Param[int] = Param( Params._dummy(), "rank", "rank of the factorization", typeConverter=TypeConverters.toInt ) - numUserBlocks = Param( + numUserBlocks: Param[int] = Param( Params._dummy(), "numUserBlocks", "number of user blocks", typeConverter=TypeConverters.toInt, ) - numItemBlocks = Param( + numItemBlocks: Param[int] = Param( Params._dummy(), "numItemBlocks", "number of item blocks", typeConverter=TypeConverters.toInt, ) - implicitPrefs = Param( + implicitPrefs: Param[bool] = Param( Params._dummy(), "implicitPrefs", "whether to use implicit preference", typeConverter=TypeConverters.toBoolean, ) - alpha = Param( + alpha: Param[float] = Param( Params._dummy(), "alpha", "alpha for implicit preference", typeConverter=TypeConverters.toFloat, ) - ratingCol = Param( + ratingCol: Param[str] = Param( Params._dummy(), "ratingCol", "column name for ratings", typeConverter=TypeConverters.toString, ) - nonnegative = Param( + nonnegative: Param[bool] = Param( Params._dummy(), "nonnegative", "whether to use nonnegative constraint for least squares", typeConverter=TypeConverters.toBoolean, ) - intermediateStorageLevel = Param( + intermediateStorageLevel: Param[str] = Param( Params._dummy(), "intermediateStorageLevel", "StorageLevel for intermediate datasets. 
Cannot be 'NONE'.", typeConverter=TypeConverters.toString, ) - finalStorageLevel = Param( + finalStorageLevel: Param[str] = Param( Params._dummy(), "finalStorageLevel", "StorageLevel for ALS model factors.", typeConverter=TypeConverters.toString, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_ALSParams, self).__init__(*args) self._setDefault( rank=10, @@ -174,63 +179,63 @@ def __init__(self, *args): ) @since("1.4.0") - def getRank(self): + def getRank(self) -> int: """ Gets the value of rank or its default value. """ return self.getOrDefault(self.rank) @since("1.4.0") - def getNumUserBlocks(self): + def getNumUserBlocks(self) -> int: """ Gets the value of numUserBlocks or its default value. """ return self.getOrDefault(self.numUserBlocks) @since("1.4.0") - def getNumItemBlocks(self): + def getNumItemBlocks(self) -> int: """ Gets the value of numItemBlocks or its default value. """ return self.getOrDefault(self.numItemBlocks) @since("1.4.0") - def getImplicitPrefs(self): + def getImplicitPrefs(self) -> bool: """ Gets the value of implicitPrefs or its default value. """ return self.getOrDefault(self.implicitPrefs) @since("1.4.0") - def getAlpha(self): + def getAlpha(self) -> float: """ Gets the value of alpha or its default value. """ return self.getOrDefault(self.alpha) @since("1.4.0") - def getRatingCol(self): + def getRatingCol(self) -> str: """ Gets the value of ratingCol or its default value. """ return self.getOrDefault(self.ratingCol) @since("1.4.0") - def getNonnegative(self): + def getNonnegative(self) -> bool: """ Gets the value of nonnegative or its default value. """ return self.getOrDefault(self.nonnegative) @since("2.0.0") - def getIntermediateStorageLevel(self): + def getIntermediateStorageLevel(self) -> str: """ Gets the value of intermediateStorageLevel or its default value. """ return self.getOrDefault(self.intermediateStorageLevel) @since("2.0.0") - def getFinalStorageLevel(self): + def getFinalStorageLevel(self) -> str: """ Gets the value of finalStorageLevel or its default value. """ @@ -238,7 +243,7 @@ def getFinalStorageLevel(self): @inherit_doc -class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable): +class ALS(JavaEstimator["ALSModel"], _ALSParams, JavaMLWritable, JavaMLReadable["ALS"]): """ Alternating Least Squares (ALS) matrix factorization. @@ -320,7 +325,7 @@ class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable): >>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] - Row(user=0, item=2, newPrediction=0.69291...) + Row(user=0, item=2, newPrediction=0.6929...) >>> predictions[1] Row(user=1, item=0, newPrediction=3.47356...) 
>>> predictions[2] @@ -359,27 +364,29 @@ class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable): True """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - rank=10, - maxIter=10, - regParam=0.1, - numUserBlocks=10, - numItemBlocks=10, - implicitPrefs=False, - alpha=1.0, - userCol="user", - itemCol="item", - seed=None, - ratingCol="rating", - nonnegative=False, - checkpointInterval=10, - intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK", - coldStartStrategy="nan", - blockSize=4096, + rank: int = 10, + maxIter: int = 10, + regParam: float = 0.1, + numUserBlocks: int = 10, + numItemBlocks: int = 10, + implicitPrefs: bool = False, + alpha: float = 1.0, + userCol: str = "user", + itemCol: str = "item", + seed: Optional[int] = None, + ratingCol: str = "rating", + nonnegative: bool = False, + checkpointInterval: int = 10, + intermediateStorageLevel: str = "MEMORY_AND_DISK", + finalStorageLevel: str = "MEMORY_AND_DISK", + coldStartStrategy: str = "nan", + blockSize: int = 4096, ): """ __init__(self, \\*, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, @@ -398,24 +405,24 @@ def __init__( def setParams( self, *, - rank=10, - maxIter=10, - regParam=0.1, - numUserBlocks=10, - numItemBlocks=10, - implicitPrefs=False, - alpha=1.0, - userCol="user", - itemCol="item", - seed=None, - ratingCol="rating", - nonnegative=False, - checkpointInterval=10, - intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK", - coldStartStrategy="nan", - blockSize=4096, - ): + rank: int = 10, + maxIter: int = 10, + regParam: float = 0.1, + numUserBlocks: int = 10, + numItemBlocks: int = 10, + implicitPrefs: bool = False, + alpha: float = 1.0, + userCol: str = "user", + itemCol: str = "item", + seed: Optional[int] = None, + ratingCol: str = "rating", + nonnegative: bool = False, + checkpointInterval: int = 10, + intermediateStorageLevel: str = "MEMORY_AND_DISK", + finalStorageLevel: str = "MEMORY_AND_DISK", + coldStartStrategy: str = "nan", + blockSize: int = 4096, + ) -> "ALS": """ setParams(self, \\*, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, \ numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", \ @@ -427,32 +434,32 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> "ALSModel": return ALSModel(java_model) @since("1.4.0") - def setRank(self, value): + def setRank(self, value: int) -> "ALS": """ Sets the value of :py:attr:`rank`. """ return self._set(rank=value) @since("1.4.0") - def setNumUserBlocks(self, value): + def setNumUserBlocks(self, value: int) -> "ALS": """ Sets the value of :py:attr:`numUserBlocks`. """ return self._set(numUserBlocks=value) @since("1.4.0") - def setNumItemBlocks(self, value): + def setNumItemBlocks(self, value: int) -> "ALS": """ Sets the value of :py:attr:`numItemBlocks`. """ return self._set(numItemBlocks=value) @since("1.4.0") - def setNumBlocks(self, value): + def setNumBlocks(self, value: int) -> "ALS": """ Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. """ @@ -460,107 +467,107 @@ def setNumBlocks(self, value): return self._set(numItemBlocks=value) @since("1.4.0") - def setImplicitPrefs(self, value): + def setImplicitPrefs(self, value: bool) -> "ALS": """ Sets the value of :py:attr:`implicitPrefs`. 
""" return self._set(implicitPrefs=value) @since("1.4.0") - def setAlpha(self, value): + def setAlpha(self, value: float) -> "ALS": """ Sets the value of :py:attr:`alpha`. """ return self._set(alpha=value) @since("1.4.0") - def setUserCol(self, value): + def setUserCol(self, value: str) -> "ALS": """ Sets the value of :py:attr:`userCol`. """ return self._set(userCol=value) @since("1.4.0") - def setItemCol(self, value): + def setItemCol(self, value: str) -> "ALS": """ Sets the value of :py:attr:`itemCol`. """ return self._set(itemCol=value) @since("1.4.0") - def setRatingCol(self, value): + def setRatingCol(self, value: str) -> "ALS": """ Sets the value of :py:attr:`ratingCol`. """ return self._set(ratingCol=value) @since("1.4.0") - def setNonnegative(self, value): + def setNonnegative(self, value: bool) -> "ALS": """ Sets the value of :py:attr:`nonnegative`. """ return self._set(nonnegative=value) @since("2.0.0") - def setIntermediateStorageLevel(self, value): + def setIntermediateStorageLevel(self, value: str) -> "ALS": """ Sets the value of :py:attr:`intermediateStorageLevel`. """ return self._set(intermediateStorageLevel=value) @since("2.0.0") - def setFinalStorageLevel(self, value): + def setFinalStorageLevel(self, value: str) -> "ALS": """ Sets the value of :py:attr:`finalStorageLevel`. """ return self._set(finalStorageLevel=value) @since("2.2.0") - def setColdStartStrategy(self, value): + def setColdStartStrategy(self, value: str) -> "ALS": """ Sets the value of :py:attr:`coldStartStrategy`. """ return self._set(coldStartStrategy=value) - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "ALS": """ Sets the value of :py:attr:`maxIter`. """ return self._set(maxIter=value) - def setRegParam(self, value): + def setRegParam(self, value: float) -> "ALS": """ Sets the value of :py:attr:`regParam`. """ return self._set(regParam=value) - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "ALS": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) - def setCheckpointInterval(self, value): + def setCheckpointInterval(self, value: int) -> "ALS": """ Sets the value of :py:attr:`checkpointInterval`. """ return self._set(checkpointInterval=value) - def setSeed(self, value): + def setSeed(self, value: int) -> "ALS": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) @since("3.0.0") - def setBlockSize(self, value): + def setBlockSize(self, value: int) -> "ALS": """ Sets the value of :py:attr:`blockSize`. """ return self._set(blockSize=value) -class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable): +class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable["ALSModel"]): """ Model fitted by ALS. @@ -568,65 +575,65 @@ class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable): """ @since("3.0.0") - def setUserCol(self, value): + def setUserCol(self, value: str) -> "ALSModel": """ Sets the value of :py:attr:`userCol`. """ return self._set(userCol=value) @since("3.0.0") - def setItemCol(self, value): + def setItemCol(self, value: str) -> "ALSModel": """ Sets the value of :py:attr:`itemCol`. """ return self._set(itemCol=value) @since("3.0.0") - def setColdStartStrategy(self, value): + def setColdStartStrategy(self, value: str) -> "ALSModel": """ Sets the value of :py:attr:`coldStartStrategy`. 
""" return self._set(coldStartStrategy=value) @since("3.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "ALSModel": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) @since("3.0.0") - def setBlockSize(self, value): + def setBlockSize(self, value: int) -> "ALSModel": """ Sets the value of :py:attr:`blockSize`. """ return self._set(blockSize=value) - @property + @property # type: ignore[misc] @since("1.4.0") - def rank(self): + def rank(self) -> int: """rank of the matrix factorization model""" return self._call_java("rank") - @property + @property # type: ignore[misc] @since("1.4.0") - def userFactors(self): + def userFactors(self) -> DataFrame: """ a DataFrame that stores user factors in two columns: `id` and `features` """ return self._call_java("userFactors") - @property + @property # type: ignore[misc] @since("1.4.0") - def itemFactors(self): + def itemFactors(self) -> DataFrame: """ a DataFrame that stores item factors in two columns: `id` and `features` """ return self._call_java("itemFactors") - def recommendForAllUsers(self, numItems): + def recommendForAllUsers(self, numItems: int) -> DataFrame: """ Returns top `numItems` items recommended for each user, for all users. @@ -645,7 +652,7 @@ def recommendForAllUsers(self, numItems): """ return self._call_java("recommendForAllUsers", numItems) - def recommendForAllItems(self, numUsers): + def recommendForAllItems(self, numUsers: int) -> DataFrame: """ Returns top `numUsers` users recommended for each item, for all items. @@ -664,7 +671,7 @@ def recommendForAllItems(self, numUsers): """ return self._call_java("recommendForAllItems", numUsers) - def recommendForUserSubset(self, dataset, numItems): + def recommendForUserSubset(self, dataset: DataFrame, numItems: int) -> DataFrame: """ Returns top `numItems` items recommended for each user id in the input data set. Note that if there are duplicate ids in the input dataset, only one set of recommendations per unique @@ -687,7 +694,7 @@ def recommendForUserSubset(self, dataset, numItems): """ return self._call_java("recommendForUserSubset", dataset, numItems) - def recommendForItemSubset(self, dataset, numUsers): + def recommendForItemSubset(self, dataset: DataFrame, numUsers: int) -> DataFrame: """ Returns top `numUsers` users recommended for each item id in the input data set. Note that if there are duplicate ids in the input dataset, only one set of recommendations per unique diff --git a/python/pyspark/ml/recommendation.pyi b/python/pyspark/ml/recommendation.pyi deleted file mode 100644 index 6ce178b9d71b1..0000000000000 --- a/python/pyspark/ml/recommendation.pyi +++ /dev/null @@ -1,146 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import Any, Optional - -import sys # noqa: F401 - -from pyspark import since, keyword_only # noqa: F401 -from pyspark.ml.param.shared import ( - HasBlockSize, - HasCheckpointInterval, - HasMaxIter, - HasPredictionCol, - HasRegParam, - HasSeed, -) -from pyspark.ml.wrapper import JavaEstimator, JavaModel -from pyspark.ml.common import inherit_doc # noqa: F401 -from pyspark.ml.param import Param -from pyspark.ml.util import JavaMLWritable, JavaMLReadable - -from pyspark.sql.dataframe import DataFrame - -class _ALSModelParams(HasPredictionCol, HasBlockSize): - userCol: Param[str] - itemCol: Param[str] - coldStartStrategy: Param[str] - def getUserCol(self) -> str: ... - def getItemCol(self) -> str: ... - def getColdStartStrategy(self) -> str: ... - -class _ALSParams(_ALSModelParams, HasMaxIter, HasRegParam, HasCheckpointInterval, HasSeed): - rank: Param[int] - numUserBlocks: Param[int] - numItemBlocks: Param[int] - implicitPrefs: Param[bool] - alpha: Param[float] - ratingCol: Param[str] - nonnegative: Param[bool] - intermediateStorageLevel: Param[str] - finalStorageLevel: Param[str] - def __init__(self, *args: Any): ... - def getRank(self) -> int: ... - def getNumUserBlocks(self) -> int: ... - def getNumItemBlocks(self) -> int: ... - def getImplicitPrefs(self) -> bool: ... - def getAlpha(self) -> float: ... - def getRatingCol(self) -> str: ... - def getNonnegative(self) -> bool: ... - def getIntermediateStorageLevel(self) -> str: ... - def getFinalStorageLevel(self) -> str: ... - -class ALS(JavaEstimator[ALSModel], _ALSParams, JavaMLWritable, JavaMLReadable[ALS]): - def __init__( - self, - *, - rank: int = ..., - maxIter: int = ..., - regParam: float = ..., - numUserBlocks: int = ..., - numItemBlocks: int = ..., - implicitPrefs: bool = ..., - alpha: float = ..., - userCol: str = ..., - itemCol: str = ..., - seed: Optional[int] = ..., - ratingCol: str = ..., - nonnegative: bool = ..., - checkpointInterval: int = ..., - intermediateStorageLevel: str = ..., - finalStorageLevel: str = ..., - coldStartStrategy: str = ..., - blockSize: int = ..., - ) -> None: ... - def setParams( - self, - *, - rank: int = ..., - maxIter: int = ..., - regParam: float = ..., - numUserBlocks: int = ..., - numItemBlocks: int = ..., - implicitPrefs: bool = ..., - alpha: float = ..., - userCol: str = ..., - itemCol: str = ..., - seed: Optional[int] = ..., - ratingCol: str = ..., - nonnegative: bool = ..., - checkpointInterval: int = ..., - intermediateStorageLevel: str = ..., - finalStorageLevel: str = ..., - coldStartStrategy: str = ..., - blockSize: int = ..., - ) -> ALS: ... - def setRank(self, value: int) -> ALS: ... - def setNumUserBlocks(self, value: int) -> ALS: ... - def setNumItemBlocks(self, value: int) -> ALS: ... - def setNumBlocks(self, value: int) -> ALS: ... - def setImplicitPrefs(self, value: bool) -> ALS: ... - def setAlpha(self, value: float) -> ALS: ... - def setUserCol(self, value: str) -> ALS: ... - def setItemCol(self, value: str) -> ALS: ... - def setRatingCol(self, value: str) -> ALS: ... - def setNonnegative(self, value: bool) -> ALS: ... - def setIntermediateStorageLevel(self, value: str) -> ALS: ... - def setFinalStorageLevel(self, value: str) -> ALS: ... - def setColdStartStrategy(self, value: str) -> ALS: ... - def setMaxIter(self, value: int) -> ALS: ... - def setRegParam(self, value: float) -> ALS: ... - def setPredictionCol(self, value: str) -> ALS: ... - def setCheckpointInterval(self, value: int) -> ALS: ... - def setSeed(self, value: int) -> ALS: ... 
- def setBlockSize(self, value: int) -> ALS: ... - -class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable[ALSModel]): - def setUserCol(self, value: str) -> ALSModel: ... - def setItemCol(self, value: str) -> ALSModel: ... - def setColdStartStrategy(self, value: str) -> ALSModel: ... - def setPredictionCol(self, value: str) -> ALSModel: ... - def setBlockSize(self, value: int) -> ALSModel: ... - @property - def rank(self) -> int: ... - @property - def userFactors(self) -> DataFrame: ... - @property - def itemFactors(self) -> DataFrame: ... - def recommendForAllUsers(self, numItems: int) -> DataFrame: ... - def recommendForAllItems(self, numUsers: int) -> DataFrame: ... - def recommendForUserSubset(self, dataset: DataFrame, numItems: int) -> DataFrame: ... - def recommendForItemSubset(self, dataset: DataFrame, numUsers: int) -> DataFrame: ... diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0faca85354f58..8678ec3f31e0b 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -17,6 +17,8 @@ import sys +from typing import Any, Dict, Generic, List, Optional, TypeVar, TYPE_CHECKING + from abc import ABCMeta from pyspark import keyword_only, since @@ -52,6 +54,8 @@ _GBTParams, _TreeRegressorParams, ) +from pyspark.ml.base import Transformer +from pyspark.ml.linalg import Vector, Matrix from pyspark.ml.util import ( JavaMLWritable, JavaMLReadable, @@ -63,11 +67,19 @@ JavaModel, JavaPredictor, JavaPredictionModel, + JavaTransformer, JavaWrapper, ) from pyspark.ml.common import inherit_doc from pyspark.sql import DataFrame +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + +T = TypeVar("T") +M = TypeVar("M", bound=Transformer) +JM = TypeVar("JM", bound=JavaTransformer) + __all__ = [ "AFTSurvivalRegression", @@ -93,7 +105,7 @@ ] -class Regressor(Predictor, _PredictorParams, metaclass=ABCMeta): +class Regressor(Predictor[M], _PredictorParams, Generic[M], metaclass=ABCMeta): """ Regressor for regression tasks. @@ -103,7 +115,7 @@ class Regressor(Predictor, _PredictorParams, metaclass=ABCMeta): pass -class RegressionModel(PredictionModel, _PredictorParams, metaclass=ABCMeta): +class RegressionModel(PredictionModel[T], _PredictorParams, metaclass=ABCMeta): """ Model produced by a ``Regressor``. @@ -113,7 +125,7 @@ class RegressionModel(PredictionModel, _PredictorParams, metaclass=ABCMeta): pass -class _JavaRegressor(Regressor, JavaPredictor, metaclass=ABCMeta): +class _JavaRegressor(Regressor, JavaPredictor[JM], Generic[JM], metaclass=ABCMeta): """ Java Regressor for regression tasks. @@ -123,7 +135,7 @@ class _JavaRegressor(Regressor, JavaPredictor, metaclass=ABCMeta): pass -class _JavaRegressionModel(RegressionModel, JavaPredictionModel, metaclass=ABCMeta): +class _JavaRegressionModel(RegressionModel, JavaPredictionModel[T], metaclass=ABCMeta): """ Java Model produced by a ``_JavaRegressor``. To be mixed in with :class:`pyspark.ml.JavaModel` @@ -154,21 +166,21 @@ class _LinearRegressionParams( .. versionadded:: 3.0.0 """ - solver = Param( + solver: Param[str] = Param( Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + "options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString, ) - loss = Param( + loss: Param[str] = Param( Params._dummy(), "loss", "The loss function to be optimized. 
Supported " + "options: squaredError, huber.", typeConverter=TypeConverters.toString, ) - epsilon = Param( + epsilon: Param[float] = Param( Params._dummy(), "epsilon", "The shape parameter to control the amount of " @@ -176,7 +188,7 @@ class _LinearRegressionParams( typeConverter=TypeConverters.toFloat, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_LinearRegressionParams, self).__init__(*args) self._setDefault( maxIter=100, @@ -188,7 +200,7 @@ def __init__(self, *args): ) @since("2.3.0") - def getEpsilon(self): + def getEpsilon(self) -> float: """ Gets the value of epsilon or its default value. """ @@ -196,7 +208,12 @@ def getEpsilon(self): @inherit_doc -class LinearRegression(_JavaRegressor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable): +class LinearRegression( + _JavaRegressor["LinearRegressionModel"], + _LinearRegressionParams, + JavaMLWritable, + JavaMLReadable["LinearRegression"], +): """ Linear regression. @@ -279,25 +296,27 @@ class LinearRegression(_JavaRegressor, _LinearRegressionParams, JavaMLWritable, >>> model.write().format("pmml").save(model_path + "_2") """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - maxIter=100, - regParam=0.0, - elasticNetParam=0.0, - tol=1e-6, - fitIntercept=True, - standardization=True, - solver="auto", - weightCol=None, - aggregationDepth=2, - loss="squaredError", - epsilon=1.35, - maxBlockSizeInMB=0.0, + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + maxIter: int = 100, + regParam: float = 0.0, + elasticNetParam: float = 0.0, + tol: float = 1e-6, + fitIntercept: bool = True, + standardization: bool = True, + solver: str = "auto", + weightCol: Optional[str] = None, + aggregationDepth: int = 2, + loss: str = "squaredError", + epsilon: float = 1.35, + maxBlockSizeInMB: float = 0.0, ): """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ @@ -317,22 +336,22 @@ def __init__( def setParams( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - maxIter=100, - regParam=0.0, - elasticNetParam=0.0, - tol=1e-6, - fitIntercept=True, - standardization=True, - solver="auto", - weightCol=None, - aggregationDepth=2, - loss="squaredError", - epsilon=1.35, - maxBlockSizeInMB=0.0, - ): + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + maxIter: int = 100, + regParam: float = 0.0, + elasticNetParam: float = 0.0, + tol: float = 1e-6, + fitIntercept: bool = True, + standardization: bool = True, + solver: str = "auto", + weightCol: Optional[str] = None, + aggregationDepth: int = 2, + loss: str = "squaredError", + epsilon: float = 1.35, + maxBlockSizeInMB: float = 0.0, + ) -> "LinearRegression": """ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ @@ -343,78 +362,78 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> "LinearRegressionModel": return LinearRegressionModel(java_model) @since("2.3.0") - def setEpsilon(self, value): + def setEpsilon(self, value: float) -> "LinearRegression": """ Sets the value of :py:attr:`epsilon`. 
""" return self._set(epsilon=value) - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "LinearRegression": """ Sets the value of :py:attr:`maxIter`. """ return self._set(maxIter=value) - def setRegParam(self, value): + def setRegParam(self, value: float) -> "LinearRegression": """ Sets the value of :py:attr:`regParam`. """ return self._set(regParam=value) - def setTol(self, value): + def setTol(self, value: float) -> "LinearRegression": """ Sets the value of :py:attr:`tol`. """ return self._set(tol=value) - def setElasticNetParam(self, value): + def setElasticNetParam(self, value: float) -> "LinearRegression": """ Sets the value of :py:attr:`elasticNetParam`. """ return self._set(elasticNetParam=value) - def setFitIntercept(self, value): + def setFitIntercept(self, value: bool) -> "LinearRegression": """ Sets the value of :py:attr:`fitIntercept`. """ return self._set(fitIntercept=value) - def setStandardization(self, value): + def setStandardization(self, value: bool) -> "LinearRegression": """ Sets the value of :py:attr:`standardization`. """ return self._set(standardization=value) - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "LinearRegression": """ Sets the value of :py:attr:`weightCol`. """ return self._set(weightCol=value) - def setSolver(self, value): + def setSolver(self, value: str) -> "LinearRegression": """ Sets the value of :py:attr:`solver`. """ return self._set(solver=value) - def setAggregationDepth(self, value): + def setAggregationDepth(self, value: int) -> "LinearRegression": """ Sets the value of :py:attr:`aggregationDepth`. """ return self._set(aggregationDepth=value) - def setLoss(self, value): + def setLoss(self, value: str) -> "LinearRegression": """ Sets the value of :py:attr:`loss`. """ return self._set(lossType=value) @since("3.1.0") - def setMaxBlockSizeInMB(self, value): + def setMaxBlockSizeInMB(self, value: float) -> "LinearRegression": """ Sets the value of :py:attr:`maxBlockSizeInMB`. """ @@ -425,8 +444,8 @@ class LinearRegressionModel( _JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable, - JavaMLReadable, - HasTrainingSummary, + JavaMLReadable["LinearRegressionModel"], + HasTrainingSummary["LinearRegressionSummary"], ): """ Model fitted by :class:`LinearRegression`. @@ -434,33 +453,33 @@ class LinearRegressionModel( .. versionadded:: 1.4.0 """ - @property + @property # type: ignore[misc] @since("2.0.0") - def coefficients(self): + def coefficients(self) -> Vector: """ Model coefficients. """ return self._call_java("coefficients") - @property + @property # type: ignore[misc] @since("1.4.0") - def intercept(self): + def intercept(self) -> float: """ Model intercept. """ return self._call_java("intercept") - @property + @property # type: ignore[misc] @since("2.3.0") - def scale(self): + def scale(self) -> float: r""" The value by which :math:`\|y - X'w\|` is scaled down when loss is "huber", otherwise 1.0. """ return self._call_java("scale") - @property + @property # type: ignore[misc] @since("2.0.0") - def summary(self): + def summary(self) -> "LinearRegressionTrainingSummary": """ Gets summary (residuals, MSE, r-squared ) of model on training set. An exception is thrown if @@ -473,7 +492,7 @@ def summary(self): "No training summary available for this %s" % self.__class__.__name__ ) - def evaluate(self, dataset): + def evaluate(self, dataset: DataFrame) -> "LinearRegressionSummary": """ Evaluates the model on a test dataset. @@ -498,44 +517,44 @@ class LinearRegressionSummary(JavaWrapper): .. 
versionadded:: 2.0.0 """ - @property + @property # type: ignore[misc] @since("2.0.0") - def predictions(self): + def predictions(self) -> DataFrame: """ Dataframe outputted by the model's `transform` method. """ return self._call_java("predictions") - @property + @property # type: ignore[misc] @since("2.0.0") - def predictionCol(self): + def predictionCol(self) -> str: """ Field in "predictions" which gives the predicted value of the label at each instance. """ return self._call_java("predictionCol") - @property + @property # type: ignore[misc] @since("2.0.0") - def labelCol(self): + def labelCol(self) -> str: """ Field in "predictions" which gives the true label of each instance. """ return self._call_java("labelCol") - @property + @property # type: ignore[misc] @since("2.0.0") - def featuresCol(self): + def featuresCol(self) -> str: """ Field in "predictions" which gives the features of each instance as a vector. """ return self._call_java("featuresCol") - @property + @property # type: ignore[misc] @since("2.0.0") - def explainedVariance(self): + def explainedVariance(self) -> float: r""" Returns the explained variance regression score. explainedVariance = :math:`1 - \frac{variance(y - \hat{y})}{variance(y)}` @@ -552,9 +571,9 @@ def explainedVariance(self): """ return self._call_java("explainedVariance") - @property + @property # type: ignore[misc] @since("2.0.0") - def meanAbsoluteError(self): + def meanAbsoluteError(self) -> float: """ Returns the mean absolute error, which is a risk function corresponding to the expected value of the absolute error @@ -568,9 +587,9 @@ def meanAbsoluteError(self): """ return self._call_java("meanAbsoluteError") - @property + @property # type: ignore[misc] @since("2.0.0") - def meanSquaredError(self): + def meanSquaredError(self) -> float: """ Returns the mean squared error, which is a risk function corresponding to the expected value of the squared error @@ -584,9 +603,9 @@ def meanSquaredError(self): """ return self._call_java("meanSquaredError") - @property + @property # type: ignore[misc] @since("2.0.0") - def rootMeanSquaredError(self): + def rootMeanSquaredError(self) -> float: """ Returns the root mean squared error, which is defined as the square root of the mean squared error. @@ -599,9 +618,9 @@ def rootMeanSquaredError(self): """ return self._call_java("rootMeanSquaredError") - @property + @property # type: ignore[misc] @since("2.0.0") - def r2(self): + def r2(self) -> float: """ Returns R^2, the coefficient of determination. @@ -616,9 +635,9 @@ def r2(self): """ return self._call_java("r2") - @property + @property # type: ignore[misc] @since("2.4.0") - def r2adj(self): + def r2adj(self) -> float: """ Returns Adjusted R^2, the adjusted coefficient of determination. @@ -632,33 +651,33 @@ def r2adj(self): """ return self._call_java("r2adj") - @property + @property # type: ignore[misc] @since("2.0.0") - def residuals(self): + def residuals(self) -> DataFrame: """ Residuals (label - predicted value) """ return self._call_java("residuals") - @property + @property # type: ignore[misc] @since("2.0.0") - def numInstances(self): + def numInstances(self) -> int: """ Number of instances in DataFrame predictions """ return self._call_java("numInstances") - @property + @property # type: ignore[misc] @since("2.2.0") - def degreesOfFreedom(self): + def degreesOfFreedom(self) -> int: """ Degrees of freedom. 
""" return self._call_java("degreesOfFreedom") - @property + @property # type: ignore[misc] @since("2.0.0") - def devianceResiduals(self): + def devianceResiduals(self) -> List[float]: """ The weighted residuals, the usual residuals rescaled by the square root of the instance weights. @@ -666,7 +685,7 @@ def devianceResiduals(self): return self._call_java("devianceResiduals") @property - def coefficientStandardErrors(self): + def coefficientStandardErrors(self) -> List[float]: """ Standard error of estimated coefficients and intercept. This value is only available when using the "normal" solver. @@ -683,7 +702,7 @@ def coefficientStandardErrors(self): return self._call_java("coefficientStandardErrors") @property - def tValues(self): + def tValues(self) -> List[float]: """ T-statistic of estimated coefficients and intercept. This value is only available when using the "normal" solver. @@ -700,7 +719,7 @@ def tValues(self): return self._call_java("tValues") @property - def pValues(self): + def pValues(self) -> List[float]: """ Two-sided p-value of estimated coefficients and intercept. This value is only available when using the "normal" solver. @@ -727,7 +746,7 @@ class LinearRegressionTrainingSummary(LinearRegressionSummary): """ @property - def objectiveHistory(self): + def objectiveHistory(self) -> List[float]: """ Objective function (scaled loss + regularization) at each iteration. @@ -742,7 +761,7 @@ def objectiveHistory(self): return self._call_java("objectiveHistory") @property - def totalIterations(self): + def totalIterations(self) -> int: """ Number of training iterations until termination. This value is only available when using the "l-bfgs" solver. @@ -763,31 +782,31 @@ class _IsotonicRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, H .. versionadded:: 3.0.0 """ - isotonic = Param( + isotonic: Param[bool] = Param( Params._dummy(), "isotonic", "whether the output sequence should be isotonic/increasing (true) or" + "antitonic/decreasing (false).", typeConverter=TypeConverters.toBoolean, ) - featureIndex = Param( + featureIndex: Param[int] = Param( Params._dummy(), "featureIndex", "The index of the feature if featuresCol is a vector column, no effect otherwise.", typeConverter=TypeConverters.toInt, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_IsotonicRegressionParams, self).__init__(*args) self._setDefault(isotonic=True, featureIndex=0) - def getIsotonic(self): + def getIsotonic(self) -> bool: """ Gets the value of isotonic or its default value. """ return self.getOrDefault(self.isotonic) - def getFeatureIndex(self): + def getFeatureIndex(self) -> int: """ Gets the value of featureIndex or its default value. 
""" @@ -839,16 +858,18 @@ class IsotonicRegression( True """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - weightCol=None, - isotonic=True, - featureIndex=0, + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + weightCol: Optional[str] = None, + isotonic: bool = True, + featureIndex: int = 0, ): """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ @@ -865,13 +886,13 @@ def __init__( def setParams( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - weightCol=None, - isotonic=True, - featureIndex=0, - ): + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + weightCol: Optional[str] = None, + isotonic: bool = True, + featureIndex: int = 0, + ) -> "IsotonicRegression": """ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ weightCol=None, isotonic=True, featureIndex=0): @@ -880,51 +901,56 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> "IsotonicRegressionModel": return IsotonicRegressionModel(java_model) - def setIsotonic(self, value): + def setIsotonic(self, value: bool) -> "IsotonicRegression": """ Sets the value of :py:attr:`isotonic`. """ return self._set(isotonic=value) - def setFeatureIndex(self, value): + def setFeatureIndex(self, value: int) -> "IsotonicRegression": """ Sets the value of :py:attr:`featureIndex`. """ return self._set(featureIndex=value) @since("1.6.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self, value: str) -> "IsotonicRegression": """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) @since("1.6.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "IsotonicRegression": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) @since("1.6.0") - def setLabelCol(self, value): + def setLabelCol(self, value: str) -> "IsotonicRegression": """ Sets the value of :py:attr:`labelCol`. """ return self._set(labelCol=value) @since("1.6.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "IsotonicRegression": """ Sets the value of :py:attr:`weightCol`. """ return self._set(weightCol=value) -class IsotonicRegressionModel(JavaModel, _IsotonicRegressionParams, JavaMLWritable, JavaMLReadable): +class IsotonicRegressionModel( + JavaModel, + _IsotonicRegressionParams, + JavaMLWritable, + JavaMLReadable["IsotonicRegressionModel"], +): """ Model fitted by :class:`IsotonicRegression`. @@ -932,52 +958,52 @@ class IsotonicRegressionModel(JavaModel, _IsotonicRegressionParams, JavaMLWritab """ @since("3.0.0") - def setFeaturesCol(self, value): + def setFeaturesCol(self, value: str) -> "IsotonicRegressionModel": """ Sets the value of :py:attr:`featuresCol`. """ return self._set(featuresCol=value) @since("3.0.0") - def setPredictionCol(self, value): + def setPredictionCol(self, value: str) -> "IsotonicRegressionModel": """ Sets the value of :py:attr:`predictionCol`. """ return self._set(predictionCol=value) - def setFeatureIndex(self, value): + def setFeatureIndex(self, value: int) -> "IsotonicRegressionModel": """ Sets the value of :py:attr:`featureIndex`. 
""" return self._set(featureIndex=value) - @property + @property # type: ignore[misc] @since("1.6.0") - def boundaries(self): + def boundaries(self) -> Vector: """ Boundaries in increasing order for which predictions are known. """ return self._call_java("boundaries") - @property + @property # type: ignore[misc] @since("1.6.0") - def predictions(self): + def predictions(self) -> Vector: """ Predictions associated with the boundaries at the same index, monotone because of isotonic regression. """ return self._call_java("predictions") - @property + @property # type: ignore[misc] @since("3.0.0") - def numFeatures(self): + def numFeatures(self) -> int: """ Returns the number of features the model was trained on. If unknown, returns -1 """ return self._call_java("numFeatures") @since("3.0.0") - def predict(self, value): + def predict(self, value: float) -> float: """ Predict label for the given features. """ @@ -991,7 +1017,7 @@ class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, Ha .. versionadded:: 3.0.0 """ - def __init__(self, *args): + def __init__(self, *args: Any): super(_DecisionTreeRegressorParams, self).__init__(*args) self._setDefault( maxDepth=5, @@ -1009,7 +1035,10 @@ def __init__(self, *args): @inherit_doc class DecisionTreeRegressor( - _JavaRegressor, _DecisionTreeRegressorParams, JavaMLWritable, JavaMLReadable + _JavaRegressor["DecisionTreeRegressionModel"], + _DecisionTreeRegressorParams, + JavaMLWritable, + JavaMLReadable["DecisionTreeRegressor"], ): """ `Decision tree `_ @@ -1079,26 +1108,28 @@ class DecisionTreeRegressor( DecisionTreeRegressionModel...depth=1, numNodes=3... """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - maxDepth=5, - maxBins=32, - minInstancesPerNode=1, - minInfoGain=0.0, - maxMemoryInMB=256, - cacheNodeIds=False, - checkpointInterval=10, - impurity="variance", - seed=None, - varianceCol=None, - weightCol=None, - leafCol="", - minWeightFractionPerNode=0.0, + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + maxDepth: int = 5, + maxBins: int = 32, + minInstancesPerNode: int = 1, + minInfoGain: float = 0.0, + maxMemoryInMB: int = 256, + cacheNodeIds: bool = False, + checkpointInterval: int = 10, + impurity: str = "variance", + seed: Optional[int] = None, + varianceCol: Optional[str] = None, + weightCol: Optional[str] = None, + leafCol: str = "", + minWeightFractionPerNode: float = 0.0, ): """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ @@ -1119,23 +1150,23 @@ def __init__( def setParams( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - maxDepth=5, - maxBins=32, - minInstancesPerNode=1, - minInfoGain=0.0, - maxMemoryInMB=256, - cacheNodeIds=False, - checkpointInterval=10, - impurity="variance", - seed=None, - varianceCol=None, - weightCol=None, - leafCol="", - minWeightFractionPerNode=0.0, - ): + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + maxDepth: int = 5, + maxBins: int = 32, + minInstancesPerNode: int = 1, + minInfoGain: float = 0.0, + maxMemoryInMB: int = 256, + cacheNodeIds: bool = False, + checkpointInterval: int = 10, + impurity: str = "variance", + seed: Optional[int] = None, + varianceCol: Optional[str] = None, + weightCol: Optional[str] = None, + leafCol: str = "", + minWeightFractionPerNode: float = 0.0, + ) -> "DecisionTreeRegressor": 
""" setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ @@ -1147,87 +1178,87 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> "DecisionTreeRegressionModel": return DecisionTreeRegressionModel(java_model) @since("1.4.0") - def setMaxDepth(self, value): + def setMaxDepth(self, value: int) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`maxDepth`. """ return self._set(maxDepth=value) @since("1.4.0") - def setMaxBins(self, value): + def setMaxBins(self, value: int) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`maxBins`. """ return self._set(maxBins=value) @since("1.4.0") - def setMinInstancesPerNode(self, value): + def setMinInstancesPerNode(self, value: int) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`minInstancesPerNode`. """ return self._set(minInstancesPerNode=value) @since("3.0.0") - def setMinWeightFractionPerNode(self, value): + def setMinWeightFractionPerNode(self, value: float) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`minWeightFractionPerNode`. """ return self._set(minWeightFractionPerNode=value) @since("1.4.0") - def setMinInfoGain(self, value): + def setMinInfoGain(self, value: float) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`minInfoGain`. """ return self._set(minInfoGain=value) @since("1.4.0") - def setMaxMemoryInMB(self, value): + def setMaxMemoryInMB(self, value: int) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`maxMemoryInMB`. """ return self._set(maxMemoryInMB=value) @since("1.4.0") - def setCacheNodeIds(self, value): + def setCacheNodeIds(self, value: bool) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`cacheNodeIds`. """ return self._set(cacheNodeIds=value) @since("1.4.0") - def setImpurity(self, value): + def setImpurity(self, value: str) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`impurity`. """ return self._set(impurity=value) @since("1.4.0") - def setCheckpointInterval(self, value): + def setCheckpointInterval(self, value: int) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`checkpointInterval`. """ return self._set(checkpointInterval=value) - def setSeed(self, value): + def setSeed(self, value: int) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) @since("3.0.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`weightCol`. """ return self._set(weightCol=value) @since("2.0.0") - def setVarianceCol(self, value): + def setVarianceCol(self, value: str) -> "DecisionTreeRegressor": """ Sets the value of :py:attr:`varianceCol`. """ @@ -1240,7 +1271,7 @@ class DecisionTreeRegressionModel( _DecisionTreeModel, _DecisionTreeRegressorParams, JavaMLWritable, - JavaMLReadable, + JavaMLReadable["DecisionTreeRegressionModel"], ): """ Model fitted by :class:`DecisionTreeRegressor`. @@ -1249,14 +1280,14 @@ class DecisionTreeRegressionModel( """ @since("3.0.0") - def setVarianceCol(self, value): + def setVarianceCol(self, value: str) -> "DecisionTreeRegressionModel": """ Sets the value of :py:attr:`varianceCol`. """ return self._set(varianceCol=value) @property - def featureImportances(self): + def featureImportances(self) -> Vector: """ Estimate of the importance of each feature. 
@@ -1287,7 +1318,7 @@ class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams): .. versionadded:: 3.0.0 """ - def __init__(self, *args): + def __init__(self, *args: Any): super(_RandomForestRegressorParams, self).__init__(*args) self._setDefault( maxDepth=5, @@ -1309,7 +1340,10 @@ def __init__(self, *args): @inherit_doc class RandomForestRegressor( - _JavaRegressor, _RandomForestRegressorParams, JavaMLWritable, JavaMLReadable + _JavaRegressor["RandomForestRegressionModel"], + _RandomForestRegressorParams, + JavaMLWritable, + JavaMLReadable["RandomForestRegressor"], ): """ `Random Forest `_ @@ -1374,29 +1408,31 @@ class RandomForestRegressor( True """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - maxDepth=5, - maxBins=32, - minInstancesPerNode=1, - minInfoGain=0.0, - maxMemoryInMB=256, - cacheNodeIds=False, - checkpointInterval=10, - impurity="variance", - subsamplingRate=1.0, - seed=None, - numTrees=20, - featureSubsetStrategy="auto", - leafCol="", - minWeightFractionPerNode=0.0, - weightCol=None, - bootstrap=True, + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + maxDepth: int = 5, + maxBins: int = 32, + minInstancesPerNode: int = 1, + minInfoGain: float = 0.0, + maxMemoryInMB: int = 256, + cacheNodeIds: bool = False, + checkpointInterval: int = 10, + impurity: str = "variance", + subsamplingRate: float = 1.0, + seed: Optional[int] = None, + numTrees: int = 20, + featureSubsetStrategy: str = "auto", + leafCol: str = "", + minWeightFractionPerNode: float = 0.0, + weightCol: Optional[str] = None, + bootstrap: Optional[bool] = True, ): """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ @@ -1418,26 +1454,26 @@ def __init__( def setParams( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - maxDepth=5, - maxBins=32, - minInstancesPerNode=1, - minInfoGain=0.0, - maxMemoryInMB=256, - cacheNodeIds=False, - checkpointInterval=10, - impurity="variance", - subsamplingRate=1.0, - seed=None, - numTrees=20, - featureSubsetStrategy="auto", - leafCol="", - minWeightFractionPerNode=0.0, - weightCol=None, - bootstrap=True, - ): + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + maxDepth: int = 5, + maxBins: int = 32, + minInstancesPerNode: int = 1, + minInfoGain: float = 0.0, + maxMemoryInMB: int = 256, + cacheNodeIds: bool = False, + checkpointInterval: int = 10, + impurity: str = "variance", + subsamplingRate: float = 1.0, + seed: Optional[int] = None, + numTrees: int = 20, + featureSubsetStrategy: str = "auto", + leafCol: str = "", + minWeightFractionPerNode: float = 0.0, + weightCol: Optional[str] = None, + bootstrap: Optional[bool] = True, + ) -> "RandomForestRegressor": """ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ @@ -1450,101 +1486,101 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> "RandomForestRegressionModel": return RandomForestRegressionModel(java_model) - def setMaxDepth(self, value): + def setMaxDepth(self, value: int) -> "RandomForestRegressor": """ Sets the value of :py:attr:`maxDepth`. 
""" return self._set(maxDepth=value) - def setMaxBins(self, value): + def setMaxBins(self, value: int) -> "RandomForestRegressor": """ Sets the value of :py:attr:`maxBins`. """ return self._set(maxBins=value) - def setMinInstancesPerNode(self, value): + def setMinInstancesPerNode(self, value: int) -> "RandomForestRegressor": """ Sets the value of :py:attr:`minInstancesPerNode`. """ return self._set(minInstancesPerNode=value) - def setMinInfoGain(self, value): + def setMinInfoGain(self, value: float) -> "RandomForestRegressor": """ Sets the value of :py:attr:`minInfoGain`. """ return self._set(minInfoGain=value) - def setMaxMemoryInMB(self, value): + def setMaxMemoryInMB(self, value: int) -> "RandomForestRegressor": """ Sets the value of :py:attr:`maxMemoryInMB`. """ return self._set(maxMemoryInMB=value) - def setCacheNodeIds(self, value): + def setCacheNodeIds(self, value: bool) -> "RandomForestRegressor": """ Sets the value of :py:attr:`cacheNodeIds`. """ return self._set(cacheNodeIds=value) @since("1.4.0") - def setImpurity(self, value): + def setImpurity(self, value: str) -> "RandomForestRegressor": """ Sets the value of :py:attr:`impurity`. """ return self._set(impurity=value) @since("1.4.0") - def setNumTrees(self, value): + def setNumTrees(self, value: int) -> "RandomForestRegressor": """ Sets the value of :py:attr:`numTrees`. """ return self._set(numTrees=value) @since("3.0.0") - def setBootstrap(self, value): + def setBootstrap(self, value: bool) -> "RandomForestRegressor": """ Sets the value of :py:attr:`bootstrap`. """ return self._set(bootstrap=value) @since("1.4.0") - def setSubsamplingRate(self, value): + def setSubsamplingRate(self, value: float) -> "RandomForestRegressor": """ Sets the value of :py:attr:`subsamplingRate`. """ return self._set(subsamplingRate=value) @since("2.4.0") - def setFeatureSubsetStrategy(self, value): + def setFeatureSubsetStrategy(self, value: str) -> "RandomForestRegressor": """ Sets the value of :py:attr:`featureSubsetStrategy`. """ return self._set(featureSubsetStrategy=value) - def setCheckpointInterval(self, value): + def setCheckpointInterval(self, value: int) -> "RandomForestRegressor": """ Sets the value of :py:attr:`checkpointInterval`. """ return self._set(checkpointInterval=value) - def setSeed(self, value): + def setSeed(self, value: int) -> "RandomForestRegressor": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) @since("3.0.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "RandomForestRegressor": """ Sets the value of :py:attr:`weightCol`. """ return self._set(weightCol=value) @since("3.0.0") - def setMinWeightFractionPerNode(self, value): + def setMinWeightFractionPerNode(self, value: float) -> "RandomForestRegressor": """ Sets the value of :py:attr:`minWeightFractionPerNode`. """ @@ -1552,11 +1588,11 @@ def setMinWeightFractionPerNode(self, value): class RandomForestRegressionModel( - _JavaRegressionModel, + _JavaRegressionModel[Vector], _TreeEnsembleModel, _RandomForestRegressorParams, JavaMLWritable, - JavaMLReadable, + JavaMLReadable["RandomForestRegressionModel"], ): """ Model fitted by :class:`RandomForestRegressor`. @@ -1564,14 +1600,14 @@ class RandomForestRegressionModel( .. versionadded:: 1.4.0 """ - @property + @property # type: ignore[misc] @since("2.0.0") - def trees(self): + def trees(self) -> List[DecisionTreeRegressionModel]: """Trees in this ensemble. 
Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] @property - def featureImportances(self): + def featureImportances(self) -> Vector: """ Estimate of the importance of each feature. @@ -1596,9 +1632,9 @@ class _GBTRegressorParams(_GBTParams, _TreeRegressorParams): .. versionadded:: 3.0.0 """ - supportedLossTypes = ["squared", "absolute"] + supportedLossTypes: List[str] = ["squared", "absolute"] - lossType = Param( + lossType: Param[str] = Param( Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " @@ -1607,7 +1643,7 @@ class _GBTRegressorParams(_GBTParams, _TreeRegressorParams): typeConverter=TypeConverters.toString, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_GBTRegressorParams, self).__init__(*args) self._setDefault( maxDepth=5, @@ -1629,7 +1665,7 @@ def __init__(self, *args): ) @since("1.4.0") - def getLossType(self): + def getLossType(self) -> str: """ Gets the value of lossType or its default value. """ @@ -1637,7 +1673,12 @@ def getLossType(self): @inherit_doc -class GBTRegressor(_JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable): +class GBTRegressor( + _JavaRegressor["GBTRegressionModel"], + _GBTRegressorParams, + JavaMLWritable, + JavaMLReadable["GBTRegressor"], +): """ `Gradient-Boosted Trees (GBTs) `_ learning algorithm for regression. @@ -1710,32 +1751,34 @@ class GBTRegressor(_JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLRe 0.01 """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - maxDepth=5, - maxBins=32, - minInstancesPerNode=1, - minInfoGain=0.0, - maxMemoryInMB=256, - cacheNodeIds=False, - subsamplingRate=1.0, - checkpointInterval=10, - lossType="squared", - maxIter=20, - stepSize=0.1, - seed=None, - impurity="variance", - featureSubsetStrategy="all", - validationTol=0.01, - validationIndicatorCol=None, - leafCol="", - minWeightFractionPerNode=0.0, - weightCol=None, + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + maxDepth: int = 5, + maxBins: int = 32, + minInstancesPerNode: int = 1, + minInfoGain: float = 0.0, + maxMemoryInMB: int = 256, + cacheNodeIds: bool = False, + subsamplingRate: float = 1.0, + checkpointInterval: int = 10, + lossType: str = "squared", + maxIter: int = 20, + stepSize: float = 0.1, + seed: Optional[int] = None, + impurity: str = "variance", + featureSubsetStrategy: str = "all", + validationTol: float = 0.01, + validationIndicatorCol: Optional[str] = None, + leafCol: str = "", + minWeightFractionPerNode: float = 0.0, + weightCol: Optional[str] = None, ): """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ @@ -1756,29 +1799,29 @@ def __init__( def setParams( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - maxDepth=5, - maxBins=32, - minInstancesPerNode=1, - minInfoGain=0.0, - maxMemoryInMB=256, - cacheNodeIds=False, - subsamplingRate=1.0, - checkpointInterval=10, - lossType="squared", - maxIter=20, - stepSize=0.1, - seed=None, - impurity="variance", - featureSubsetStrategy="all", - validationTol=0.01, - validationIndicatorCol=None, - leafCol="", - minWeightFractionPerNode=0.0, - weightCol=None, - ): + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + maxDepth: int = 5, + maxBins: int = 32, +
minInstancesPerNode: int = 1, + minInfoGain: float = 0.0, + maxMemoryInMB: int = 256, + cacheNodeIds: bool = False, + subsamplingRate: float = 1.0, + checkpointInterval: int = 10, + lossType: str = "squared", + maxIter: int = 20, + stepSize: float = 0.1, + seed: Optional[int] = None, + impurity: str = "variance", + featureSubsetStrategy: str = "all", + validationTol: float = 0.01, + validationIndicatorCol: Optional[str] = None, + leafCol: str = "", + minWeightFractionPerNode: float = 0.0, + weightCol: Optional[str] = None, + ) -> "GBTRegressor": """ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ @@ -1792,123 +1835,123 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> "GBTRegressionModel": return GBTRegressionModel(java_model) @since("1.4.0") - def setMaxDepth(self, value): + def setMaxDepth(self, value: int) -> "GBTRegressor": """ Sets the value of :py:attr:`maxDepth`. """ return self._set(maxDepth=value) @since("1.4.0") - def setMaxBins(self, value): + def setMaxBins(self, value: int) -> "GBTRegressor": """ Sets the value of :py:attr:`maxBins`. """ return self._set(maxBins=value) @since("1.4.0") - def setMinInstancesPerNode(self, value): + def setMinInstancesPerNode(self, value: int) -> "GBTRegressor": """ Sets the value of :py:attr:`minInstancesPerNode`. """ return self._set(minInstancesPerNode=value) @since("1.4.0") - def setMinInfoGain(self, value): + def setMinInfoGain(self, value: float) -> "GBTRegressor": """ Sets the value of :py:attr:`minInfoGain`. """ return self._set(minInfoGain=value) @since("1.4.0") - def setMaxMemoryInMB(self, value): + def setMaxMemoryInMB(self, value: int) -> "GBTRegressor": """ Sets the value of :py:attr:`maxMemoryInMB`. """ return self._set(maxMemoryInMB=value) @since("1.4.0") - def setCacheNodeIds(self, value): + def setCacheNodeIds(self, value: bool) -> "GBTRegressor": """ Sets the value of :py:attr:`cacheNodeIds`. """ return self._set(cacheNodeIds=value) @since("1.4.0") - def setImpurity(self, value): + def setImpurity(self, value: str) -> "GBTRegressor": """ Sets the value of :py:attr:`impurity`. """ return self._set(impurity=value) @since("1.4.0") - def setLossType(self, value): + def setLossType(self, value: str) -> "GBTRegressor": """ Sets the value of :py:attr:`lossType`. """ return self._set(lossType=value) @since("1.4.0") - def setSubsamplingRate(self, value): + def setSubsamplingRate(self, value: float) -> "GBTRegressor": """ Sets the value of :py:attr:`subsamplingRate`. """ return self._set(subsamplingRate=value) @since("2.4.0") - def setFeatureSubsetStrategy(self, value): + def setFeatureSubsetStrategy(self, value: str) -> "GBTRegressor": """ Sets the value of :py:attr:`featureSubsetStrategy`. """ return self._set(featureSubsetStrategy=value) @since("3.0.0") - def setValidationIndicatorCol(self, value): + def setValidationIndicatorCol(self, value: str) -> "GBTRegressor": """ Sets the value of :py:attr:`validationIndicatorCol`. """ return self._set(validationIndicatorCol=value) @since("1.4.0") - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "GBTRegressor": """ Sets the value of :py:attr:`maxIter`. 
""" return self._set(maxIter=value) @since("1.4.0") - def setCheckpointInterval(self, value): + def setCheckpointInterval(self, value: int) -> "GBTRegressor": """ Sets the value of :py:attr:`checkpointInterval`. """ return self._set(checkpointInterval=value) @since("1.4.0") - def setSeed(self, value): + def setSeed(self, value: int) -> "GBTRegressor": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) @since("1.4.0") - def setStepSize(self, value): + def setStepSize(self, value: float) -> "GBTRegressor": """ Sets the value of :py:attr:`stepSize`. """ return self._set(stepSize=value) @since("3.0.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "GBTRegressor": """ Sets the value of :py:attr:`weightCol`. """ return self._set(weightCol=value) @since("3.0.0") - def setMinWeightFractionPerNode(self, value): + def setMinWeightFractionPerNode(self, value: float) -> "GBTRegressor": """ Sets the value of :py:attr:`minWeightFractionPerNode`. """ @@ -1916,7 +1959,11 @@ def setMinWeightFractionPerNode(self, value): class GBTRegressionModel( - _JavaRegressionModel, _TreeEnsembleModel, _GBTRegressorParams, JavaMLWritable, JavaMLReadable + _JavaRegressionModel[Vector], + _TreeEnsembleModel, + _GBTRegressorParams, + JavaMLWritable, + JavaMLReadable["GBTRegressionModel"], ): """ Model fitted by :class:`GBTRegressor`. @@ -1925,7 +1972,7 @@ class GBTRegressionModel( """ @property - def featureImportances(self): + def featureImportances(self) -> Vector: """ Estimate of the importance of each feature. @@ -1942,13 +1989,13 @@ def featureImportances(self): """ return self._call_java("featureImportances") - @property + @property # type: ignore[misc] @since("2.0.0") - def trees(self): + def trees(self) -> List[DecisionTreeRegressionModel]: """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] - def evaluateEachIteration(self, dataset, loss): + def evaluateEachIteration(self, dataset: DataFrame, loss: str) -> List[float]: """ Method to compute error or loss for every iteration of gradient boosting. @@ -1975,7 +2022,7 @@ class _AFTSurvivalRegressionParams( .. versionadded:: 3.0.0 """ - censorCol = Param( + censorCol: Param[str] = Param( Params._dummy(), "censorCol", "censor column name. The value of this column could be 0 or 1. " @@ -1983,14 +2030,14 @@ class _AFTSurvivalRegressionParams( + "uncensored; otherwise censored.", typeConverter=TypeConverters.toString, ) - quantileProbabilities = Param( + quantileProbabilities: Param[List[float]] = Param( Params._dummy(), "quantileProbabilities", "quantile probabilities array. Values of the quantile probabilities array " + "should be in the range (0, 1) and the array should be non-empty.", typeConverter=TypeConverters.toListFloat, ) - quantilesCol = Param( + quantilesCol: Param[str] = Param( Params._dummy(), "quantilesCol", "quantiles column name. This column will output quantiles of " @@ -1998,7 +2045,7 @@ class _AFTSurvivalRegressionParams( typeConverter=TypeConverters.toString, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_AFTSurvivalRegressionParams, self).__init__(*args) self._setDefault( censorCol="censor", @@ -2009,21 +2056,21 @@ def __init__(self, *args): ) @since("1.6.0") - def getCensorCol(self): + def getCensorCol(self) -> str: """ Gets the value of censorCol or its default value. 
""" return self.getOrDefault(self.censorCol) @since("1.6.0") - def getQuantileProbabilities(self): + def getQuantileProbabilities(self) -> List[float]: """ Gets the value of quantileProbabilities or its default value. """ return self.getOrDefault(self.quantileProbabilities) @since("1.6.0") - def getQuantilesCol(self): + def getQuantilesCol(self) -> str: """ Gets the value of quantilesCol or its default value. """ @@ -2032,7 +2079,10 @@ def getQuantilesCol(self): @inherit_doc class AFTSurvivalRegression( - _JavaRegressor, _AFTSurvivalRegressionParams, JavaMLWritable, JavaMLReadable + _JavaRegressor["AFTSurvivalRegressionModel"], + _AFTSurvivalRegressionParams, + JavaMLWritable, + JavaMLReadable["AFTSurvivalRegression"], ): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -2095,23 +2145,33 @@ class AFTSurvivalRegression( .. versionadded:: 1.6.0 """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - fitIntercept=True, - maxIter=100, - tol=1e-6, - censorCol="censor", - quantileProbabilities=list( - [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99] - ), # noqa: B005 - quantilesCol=None, - aggregationDepth=2, - maxBlockSizeInMB=0.0, + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + fitIntercept: bool = True, + maxIter: int = 100, + tol: float = 1e-6, + censorCol: str = "censor", + quantileProbabilities: List[float] = [ + 0.01, + 0.05, + 0.1, + 0.25, + 0.5, + 0.75, + 0.9, + 0.95, + 0.99, + ], # noqa: B005 + quantilesCol: Optional[str] = None, + aggregationDepth: int = 2, + maxBlockSizeInMB: float = 0.0, ): """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ @@ -2131,20 +2191,28 @@ def __init__( def setParams( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - fitIntercept=True, - maxIter=100, - tol=1e-6, - censorCol="censor", - quantileProbabilities=list( - [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99] - ), # noqa: B005 - quantilesCol=None, - aggregationDepth=2, - maxBlockSizeInMB=0.0, - ): + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + fitIntercept: bool = True, + maxIter: int = 100, + tol: float = 1e-6, + censorCol: str = "censor", + quantileProbabilities: List[float] = [ + 0.01, + 0.05, + 0.1, + 0.25, + 0.5, + 0.75, + 0.9, + 0.95, + 0.99, + ], # noqa: B005 + quantilesCol: Optional[str] = None, + aggregationDepth: int = 2, + maxBlockSizeInMB: float = 0.0, + ) -> "AFTSurvivalRegression": """ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ @@ -2154,60 +2222,60 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> "AFTSurvivalRegressionModel": return AFTSurvivalRegressionModel(java_model) @since("1.6.0") - def setCensorCol(self, value): + def setCensorCol(self, value: str) -> "AFTSurvivalRegression": """ Sets the value of :py:attr:`censorCol`. """ return self._set(censorCol=value) @since("1.6.0") - def setQuantileProbabilities(self, value): + def setQuantileProbabilities(self, value: List[float]) -> "AFTSurvivalRegression": """ Sets the value of :py:attr:`quantileProbabilities`. 
""" return self._set(quantileProbabilities=value) @since("1.6.0") - def setQuantilesCol(self, value): + def setQuantilesCol(self, value: str) -> "AFTSurvivalRegression": """ Sets the value of :py:attr:`quantilesCol`. """ return self._set(quantilesCol=value) @since("1.6.0") - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "AFTSurvivalRegression": """ Sets the value of :py:attr:`maxIter`. """ return self._set(maxIter=value) @since("1.6.0") - def setTol(self, value): + def setTol(self, value: float) -> "AFTSurvivalRegression": """ Sets the value of :py:attr:`tol`. """ return self._set(tol=value) @since("1.6.0") - def setFitIntercept(self, value): + def setFitIntercept(self, value: bool) -> "AFTSurvivalRegression": """ Sets the value of :py:attr:`fitIntercept`. """ return self._set(fitIntercept=value) @since("2.1.0") - def setAggregationDepth(self, value): + def setAggregationDepth(self, value: int) -> "AFTSurvivalRegression": """ Sets the value of :py:attr:`aggregationDepth`. """ return self._set(aggregationDepth=value) @since("3.1.0") - def setMaxBlockSizeInMB(self, value): + def setMaxBlockSizeInMB(self, value: int) -> "AFTSurvivalRegression": """ Sets the value of :py:attr:`maxBlockSizeInMB`. """ @@ -2215,7 +2283,10 @@ def setMaxBlockSizeInMB(self, value): class AFTSurvivalRegressionModel( - _JavaRegressionModel, _AFTSurvivalRegressionParams, JavaMLWritable, JavaMLReadable + _JavaRegressionModel[Vector], + _AFTSurvivalRegressionParams, + JavaMLWritable, + JavaMLReadable["AFTSurvivalRegressionModel"], ): """ Model fitted by :class:`AFTSurvivalRegression`. @@ -2224,45 +2295,45 @@ class AFTSurvivalRegressionModel( """ @since("3.0.0") - def setQuantileProbabilities(self, value): + def setQuantileProbabilities(self, value: List[float]) -> "AFTSurvivalRegressionModel": """ Sets the value of :py:attr:`quantileProbabilities`. """ return self._set(quantileProbabilities=value) @since("3.0.0") - def setQuantilesCol(self, value): + def setQuantilesCol(self, value: str) -> "AFTSurvivalRegressionModel": """ Sets the value of :py:attr:`quantilesCol`. """ return self._set(quantilesCol=value) - @property + @property # type: ignore[misc] @since("2.0.0") - def coefficients(self): + def coefficients(self) -> Vector: """ Model coefficients. """ return self._call_java("coefficients") - @property + @property # type: ignore[misc] @since("1.6.0") - def intercept(self): + def intercept(self) -> float: """ Model intercept. """ return self._call_java("intercept") - @property + @property # type: ignore[misc] @since("1.6.0") - def scale(self): + def scale(self) -> float: """ Model scale parameter. """ return self._call_java("scale") @since("2.0.0") - def predictQuantiles(self, features): + def predictQuantiles(self, features: Vector) -> Vector: """ Predicted Quantiles """ @@ -2286,7 +2357,7 @@ class _GeneralizedLinearRegressionParams( .. 
versionadded:: 3.0.0 """ - family = Param( + family: Param[str] = Param( Params._dummy(), "family", "The name of family which is a description of " @@ -2294,7 +2365,7 @@ class _GeneralizedLinearRegressionParams( + "gaussian (default), binomial, poisson, gamma and tweedie.", typeConverter=TypeConverters.toString, ) - link = Param( + link: Param[str] = Param( Params._dummy(), "link", "The name of link function which provides the " @@ -2303,13 +2374,13 @@ class _GeneralizedLinearRegressionParams( + "and sqrt.", typeConverter=TypeConverters.toString, ) - linkPredictionCol = Param( + linkPredictionCol: Param[str] = Param( Params._dummy(), "linkPredictionCol", "link prediction (linear " + "predictor) column name", typeConverter=TypeConverters.toString, ) - variancePower = Param( + variancePower: Param[float] = Param( Params._dummy(), "variancePower", "The power in the variance function " @@ -2318,19 +2389,19 @@ class _GeneralizedLinearRegressionParams( + "for the Tweedie family. Supported values: 0 and [1, Inf).", typeConverter=TypeConverters.toFloat, ) - linkPower = Param( + linkPower: Param[float] = Param( Params._dummy(), "linkPower", "The index in the power link function. " + "Only applicable to the Tweedie family.", typeConverter=TypeConverters.toFloat, ) - solver = Param( + solver: Param[str] = Param( Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + "options: irls.", typeConverter=TypeConverters.toString, ) - offsetCol = Param( + offsetCol: Param[str] = Param( Params._dummy(), "offsetCol", "The offset column name. If this is not set " @@ -2338,7 +2409,7 @@ class _GeneralizedLinearRegressionParams( typeConverter=TypeConverters.toString, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_GeneralizedLinearRegressionParams, self).__init__(*args) self._setDefault( family="gaussian", @@ -2351,42 +2422,42 @@ def __init__(self, *args): ) @since("2.0.0") - def getFamily(self): + def getFamily(self) -> str: """ Gets the value of family or its default value. """ return self.getOrDefault(self.family) @since("2.0.0") - def getLinkPredictionCol(self): + def getLinkPredictionCol(self) -> str: """ Gets the value of linkPredictionCol or its default value. """ return self.getOrDefault(self.linkPredictionCol) @since("2.0.0") - def getLink(self): + def getLink(self) -> str: """ Gets the value of link or its default value. """ return self.getOrDefault(self.link) @since("2.2.0") - def getVariancePower(self): + def getVariancePower(self) -> float: """ Gets the value of variancePower or its default value. """ return self.getOrDefault(self.variancePower) @since("2.2.0") - def getLinkPower(self): + def getLinkPower(self) -> float: """ Gets the value of linkPower or its default value. """ return self.getOrDefault(self.linkPower) @since("2.3.0") - def getOffsetCol(self): + def getOffsetCol(self) -> str: """ Gets the value of offsetCol or its default value. """ @@ -2395,7 +2466,10 @@ def getOffsetCol(self): @inherit_doc class GeneralizedLinearRegression( - _JavaRegressor, _GeneralizedLinearRegressionParams, JavaMLWritable, JavaMLReadable + _JavaRegressor["GeneralizedLinearRegressionModel"], + _GeneralizedLinearRegressionParams, + JavaMLWritable, + JavaMLReadable["GeneralizedLinearRegression"], ): """ Generalized Linear Regression. 
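Illustrative sketch, not part of the patch: with the annotations now inlined above, a GeneralizedLinearRegression fit is typed end to end. The local SparkSession, toy rows, and column names below are assumptions for the example.

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import GeneralizedLinearRegression

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [
        (1.0, Vectors.dense(0.0, 0.0)),
        (1.0, Vectors.dense(1.0, 2.0)),
        (2.0, Vectors.dense(0.0, 0.0)),
        (2.0, Vectors.dense(1.0, 1.0)),
    ],
    ["label", "features"],
)
glr = GeneralizedLinearRegression(family="gaussian", link="identity", maxIter=10, regParam=0.3)
model = glr.fit(df)        # GeneralizedLinearRegressionModel
print(model.coefficients)  # Vector
print(model.intercept)     # float
print(model.summary.aic)   # float, via the annotated training summary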
@@ -2476,26 +2550,28 @@ class GeneralizedLinearRegression( True """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - labelCol="label", - featuresCol="features", - predictionCol="prediction", - family="gaussian", - link=None, - fitIntercept=True, - maxIter=25, - tol=1e-6, - regParam=0.0, - weightCol=None, - solver="irls", - linkPredictionCol=None, - variancePower=0.0, - linkPower=None, - offsetCol=None, - aggregationDepth=2, + labelCol: str = "label", + featuresCol: str = "features", + predictionCol: str = "prediction", + family: str = "gaussian", + link: Optional[str] = None, + fitIntercept: bool = True, + maxIter: int = 25, + tol: float = 1e-6, + regParam: float = 0.0, + weightCol: Optional[str] = None, + solver: str = "irls", + linkPredictionCol: Optional[str] = None, + variancePower: float = 0.0, + linkPower: Optional[float] = None, + offsetCol: Optional[str] = None, + aggregationDepth: int = 2, ): """ __init__(self, \\*, labelCol="label", featuresCol="features", predictionCol="prediction", \ @@ -2516,23 +2592,23 @@ def __init__( def setParams( self, *, - labelCol="label", - featuresCol="features", - predictionCol="prediction", - family="gaussian", - link=None, - fitIntercept=True, - maxIter=25, - tol=1e-6, - regParam=0.0, - weightCol=None, - solver="irls", - linkPredictionCol=None, - variancePower=0.0, - linkPower=None, - offsetCol=None, - aggregationDepth=2, - ): + labelCol: str = "label", + featuresCol: str = "features", + predictionCol: str = "prediction", + family: str = "gaussian", + link: Optional[str] = None, + fitIntercept: bool = True, + maxIter: int = 25, + tol: float = 1e-6, + regParam: float = 0.0, + weightCol: Optional[str] = None, + solver: str = "irls", + linkPredictionCol: Optional[str] = None, + variancePower: float = 0.0, + linkPower: Optional[float] = None, + offsetCol: Optional[str] = None, + aggregationDepth: int = 2, + ) -> "GeneralizedLinearRegression": """ setParams(self, \\*, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ @@ -2543,95 +2619,95 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> "GeneralizedLinearRegressionModel": return GeneralizedLinearRegressionModel(java_model) @since("2.0.0") - def setFamily(self, value): + def setFamily(self, value: str) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`family`. """ return self._set(family=value) @since("2.0.0") - def setLinkPredictionCol(self, value): + def setLinkPredictionCol(self, value: str) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`linkPredictionCol`. """ return self._set(linkPredictionCol=value) @since("2.0.0") - def setLink(self, value): + def setLink(self, value: str) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`link`. """ return self._set(link=value) @since("2.2.0") - def setVariancePower(self, value): + def setVariancePower(self, value: float) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`variancePower`. """ return self._set(variancePower=value) @since("2.2.0") - def setLinkPower(self, value): + def setLinkPower(self, value: float) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`linkPower`. 
""" return self._set(linkPower=value) @since("2.3.0") - def setOffsetCol(self, value): + def setOffsetCol(self, value: str) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`offsetCol`. """ return self._set(offsetCol=value) @since("2.0.0") - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`maxIter`. """ return self._set(maxIter=value) @since("2.0.0") - def setRegParam(self, value): + def setRegParam(self, value: float) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`regParam`. """ return self._set(regParam=value) @since("2.0.0") - def setTol(self, value): + def setTol(self, value: float) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`tol`. """ return self._set(tol=value) @since("2.0.0") - def setFitIntercept(self, value): + def setFitIntercept(self, value: bool) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`fitIntercept`. """ return self._set(fitIntercept=value) @since("2.0.0") - def setWeightCol(self, value): + def setWeightCol(self, value: str) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`weightCol`. """ return self._set(weightCol=value) @since("2.0.0") - def setSolver(self, value): + def setSolver(self, value: str) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`solver`. """ return self._set(solver=value) @since("3.0.0") - def setAggregationDepth(self, value): + def setAggregationDepth(self, value: int) -> "GeneralizedLinearRegression": """ Sets the value of :py:attr:`aggregationDepth`. """ @@ -2639,11 +2715,11 @@ def setAggregationDepth(self, value): class GeneralizedLinearRegressionModel( - _JavaRegressionModel, + _JavaRegressionModel[Vector], _GeneralizedLinearRegressionParams, JavaMLWritable, - JavaMLReadable, - HasTrainingSummary, + JavaMLReadable["GeneralizedLinearRegressionModel"], + HasTrainingSummary["GeneralizedLinearRegressionTrainingSummary"], ): """ Model fitted by :class:`GeneralizedLinearRegression`. @@ -2652,31 +2728,31 @@ class GeneralizedLinearRegressionModel( """ @since("3.0.0") - def setLinkPredictionCol(self, value): + def setLinkPredictionCol(self, value: str) -> "GeneralizedLinearRegressionModel": """ Sets the value of :py:attr:`linkPredictionCol`. """ return self._set(linkPredictionCol=value) - @property + @property # type: ignore[misc] @since("2.0.0") - def coefficients(self): + def coefficients(self) -> Vector: """ Model coefficients. """ return self._call_java("coefficients") - @property + @property # type: ignore[misc] @since("2.0.0") - def intercept(self): + def intercept(self) -> float: """ Model intercept. """ return self._call_java("intercept") - @property + @property # type: ignore[misc] @since("2.0.0") - def summary(self): + def summary(self) -> "GeneralizedLinearRegressionTrainingSummary": """ Gets summary (residuals, deviance, p-values) of model on training set. An exception is thrown if @@ -2691,7 +2767,7 @@ def summary(self): "No training summary available for this %s" % self.__class__.__name__ ) - def evaluate(self, dataset): + def evaluate(self, dataset: DataFrame) -> "GeneralizedLinearRegressionSummary": """ Evaluates the model on a test dataset. @@ -2716,64 +2792,64 @@ class GeneralizedLinearRegressionSummary(JavaWrapper): .. versionadded:: 2.0.0 """ - @property + @property # type: ignore[misc] @since("2.0.0") - def predictions(self): + def predictions(self) -> DataFrame: """ Predictions output by the model's `transform` method. 
""" return self._call_java("predictions") - @property + @property # type: ignore[misc] @since("2.0.0") - def predictionCol(self): + def predictionCol(self) -> str: """ Field in :py:attr:`predictions` which gives the predicted value of each instance. This is set to a new column name if the original model's `predictionCol` is not set. """ return self._call_java("predictionCol") - @property + @property # type: ignore[misc] @since("2.2.0") - def numInstances(self): + def numInstances(self) -> int: """ Number of instances in DataFrame predictions. """ return self._call_java("numInstances") - @property + @property # type: ignore[misc] @since("2.0.0") - def rank(self): + def rank(self) -> int: """ The numeric rank of the fitted linear model. """ return self._call_java("rank") - @property + @property # type: ignore[misc] @since("2.0.0") - def degreesOfFreedom(self): + def degreesOfFreedom(self) -> int: """ Degrees of freedom. """ return self._call_java("degreesOfFreedom") - @property + @property # type: ignore[misc] @since("2.0.0") - def residualDegreeOfFreedom(self): + def residualDegreeOfFreedom(self) -> int: """ The residual degrees of freedom. """ return self._call_java("residualDegreeOfFreedom") - @property + @property # type: ignore[misc] @since("2.0.0") - def residualDegreeOfFreedomNull(self): + def residualDegreeOfFreedomNull(self) -> int: """ The residual degrees of freedom for the null model. """ return self._call_java("residualDegreeOfFreedomNull") - def residuals(self, residualsType="deviance"): + def residuals(self, residualsType: str = "deviance") -> DataFrame: """ Get the residuals of the fitted model by type. @@ -2787,25 +2863,25 @@ def residuals(self, residualsType="deviance"): """ return self._call_java("residuals", residualsType) - @property + @property # type: ignore[misc] @since("2.0.0") - def nullDeviance(self): + def nullDeviance(self) -> float: """ The deviance for the null model. """ return self._call_java("nullDeviance") - @property + @property # type: ignore[misc] @since("2.0.0") - def deviance(self): + def deviance(self) -> float: """ The deviance for the fitted model. """ return self._call_java("deviance") - @property + @property # type: ignore[misc] @since("2.0.0") - def dispersion(self): + def dispersion(self) -> float: """ The dispersion of the fitted model. It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise @@ -2814,9 +2890,9 @@ def dispersion(self): """ return self._call_java("dispersion") - @property + @property # type: ignore[misc] @since("2.0.0") - def aic(self): + def aic(self) -> float: """ Akaike's "An Information Criterion"(AIC) for the fitted model. """ @@ -2831,25 +2907,25 @@ class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSumm .. versionadded:: 2.0.0 """ - @property + @property # type: ignore[misc] @since("2.0.0") - def numIterations(self): + def numIterations(self) -> int: """ Number of training iterations. """ return self._call_java("numIterations") - @property + @property # type: ignore[misc] @since("2.0.0") - def solver(self): + def solver(self) -> str: """ The numeric solver used for training. """ return self._call_java("solver") - @property + @property # type: ignore[misc] @since("2.0.0") - def coefficientStandardErrors(self): + def coefficientStandardErrors(self) -> List[float]: """ Standard error of estimated coefficients and intercept. 
@@ -2858,9 +2934,9 @@ def coefficientStandardErrors(self): """ return self._call_java("coefficientStandardErrors") - @property + @property # type: ignore[misc] @since("2.0.0") - def tValues(self): + def tValues(self) -> List[float]: """ T-statistic of estimated coefficients and intercept. @@ -2869,9 +2945,9 @@ def tValues(self): """ return self._call_java("tValues") - @property + @property # type: ignore[misc] @since("2.0.0") - def pValues(self): + def pValues(self) -> List[float]: """ Two-sided p-value of estimated coefficients and intercept. @@ -2880,7 +2956,7 @@ def pValues(self): """ return self._call_java("pValues") - def __repr__(self): + def __repr__(self) -> str: return self._call_java("toString") @@ -2902,7 +2978,7 @@ class _FactorizationMachinesParams( .. versionadded:: 3.0.0 """ - factorSize = Param( + factorSize: Param[int] = Param( Params._dummy(), "factorSize", "Dimensionality of the factor vectors, " @@ -2910,14 +2986,14 @@ class _FactorizationMachinesParams( typeConverter=TypeConverters.toInt, ) - fitLinear = Param( + fitLinear: Param[bool] = Param( Params._dummy(), "fitLinear", "whether to fit linear term (aka 1-way term)", typeConverter=TypeConverters.toBoolean, ) - miniBatchFraction = Param( + miniBatchFraction: Param[float] = Param( Params._dummy(), "miniBatchFraction", "fraction of the input data " @@ -2925,7 +3001,7 @@ class _FactorizationMachinesParams( typeConverter=TypeConverters.toFloat, ) - initStd = Param( + initStd: Param[float] = Param( Params._dummy(), "initStd", "standard deviation of initial coefficients", @@ -2939,7 +3015,7 @@ class _FactorizationMachinesParams( typeConverter=TypeConverters.toString, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_FactorizationMachinesParams, self).__init__(*args) self._setDefault( factorSize=8, @@ -2955,28 +3031,28 @@ def __init__(self, *args): ) @since("3.0.0") - def getFactorSize(self): + def getFactorSize(self) -> int: """ Gets the value of factorSize or its default value. """ return self.getOrDefault(self.factorSize) @since("3.0.0") - def getFitLinear(self): + def getFitLinear(self) -> bool: """ Gets the value of fitLinear or its default value. """ return self.getOrDefault(self.fitLinear) @since("3.0.0") - def getMiniBatchFraction(self): + def getMiniBatchFraction(self) -> float: """ Gets the value of miniBatchFraction or its default value. """ return self.getOrDefault(self.miniBatchFraction) @since("3.0.0") - def getInitStd(self): + def getInitStd(self) -> float: """ Gets the value of initStd or its default value. """ @@ -2984,7 +3060,12 @@ def getInitStd(self): @inherit_doc -class FMRegressor(_JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable): +class FMRegressor( + _JavaRegressor["FMRegressionModel"], + _FactorizationMachinesParams, + JavaMLWritable, + JavaMLReadable["FMRegressor"], +): """ Factorization Machines learning algorithm for regression. 
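Illustrative sketch, not part of the patch: the FMRegressor hunks around this point receive the same inline annotations. A minimal fit, assuming a local SparkSession and toy data:

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import FMRegressor

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
    ["label", "features"],
)
fm = FMRegressor(factorSize=2)  # factorSize is annotated as int
model = fm.fit(df)              # FMRegressionModel
print(model.intercept)          # float
print(model.linear)             # Vector
print(model.factors)            # Matrix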
@@ -3044,24 +3125,26 @@ class FMRegressor(_JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, True """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - factorSize=8, - fitIntercept=True, - fitLinear=True, - regParam=0.0, - miniBatchFraction=1.0, - initStd=0.01, - maxIter=100, - stepSize=1.0, - tol=1e-6, - solver="adamW", - seed=None, + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + factorSize: int = 8, + fitIntercept: bool = True, + fitLinear: bool = True, + regParam: float = 0.0, + miniBatchFraction: float = 1.0, + initStd: float = 0.01, + maxIter: int = 100, + stepSize: float = 1.0, + tol: float = 1e-6, + solver: str = "adamW", + seed: Optional[int] = None, ): """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ @@ -3079,21 +3162,21 @@ def __init__( def setParams( self, *, - featuresCol="features", - labelCol="label", - predictionCol="prediction", - factorSize=8, - fitIntercept=True, - fitLinear=True, - regParam=0.0, - miniBatchFraction=1.0, - initStd=0.01, - maxIter=100, - stepSize=1.0, - tol=1e-6, - solver="adamW", - seed=None, - ): + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + factorSize: int = 8, + fitIntercept: bool = True, + fitLinear: bool = True, + regParam: float = 0.0, + miniBatchFraction: float = 1.0, + initStd: float = 0.01, + maxIter: int = 100, + stepSize: float = 1.0, + tol: float = 1e-6, + solver: str = "adamW", + seed: Optional[int] = None, + ) -> "FMRegressor": """ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \ @@ -3104,81 +3187,81 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> "FMRegressionModel": return FMRegressionModel(java_model) @since("3.0.0") - def setFactorSize(self, value): + def setFactorSize(self, value: int) -> "FMRegressor": """ Sets the value of :py:attr:`factorSize`. """ return self._set(factorSize=value) @since("3.0.0") - def setFitLinear(self, value): + def setFitLinear(self, value: bool) -> "FMRegressor": """ Sets the value of :py:attr:`fitLinear`. """ return self._set(fitLinear=value) @since("3.0.0") - def setMiniBatchFraction(self, value): + def setMiniBatchFraction(self, value: float) -> "FMRegressor": """ Sets the value of :py:attr:`miniBatchFraction`. """ return self._set(miniBatchFraction=value) @since("3.0.0") - def setInitStd(self, value): + def setInitStd(self, value: float) -> "FMRegressor": """ Sets the value of :py:attr:`initStd`. """ return self._set(initStd=value) @since("3.0.0") - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "FMRegressor": """ Sets the value of :py:attr:`maxIter`. """ return self._set(maxIter=value) @since("3.0.0") - def setStepSize(self, value): + def setStepSize(self, value: float) -> "FMRegressor": """ Sets the value of :py:attr:`stepSize`. """ return self._set(stepSize=value) @since("3.0.0") - def setTol(self, value): + def setTol(self, value: float) -> "FMRegressor": """ Sets the value of :py:attr:`tol`. """ return self._set(tol=value) @since("3.0.0") - def setSolver(self, value): + def setSolver(self, value: str) -> "FMRegressor": """ Sets the value of :py:attr:`solver`. 
""" return self._set(solver=value) @since("3.0.0") - def setSeed(self, value): + def setSeed(self, value: int) -> "FMRegressor": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) @since("3.0.0") - def setFitIntercept(self, value): + def setFitIntercept(self, value: bool) -> "FMRegressor": """ Sets the value of :py:attr:`fitIntercept`. """ return self._set(fitIntercept=value) @since("3.0.0") - def setRegParam(self, value): + def setRegParam(self, value: float) -> "FMRegressor": """ Sets the value of :py:attr:`regParam`. """ @@ -3186,7 +3269,10 @@ def setRegParam(self, value): class FMRegressionModel( - _JavaRegressionModel, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable + _JavaRegressionModel, + _FactorizationMachinesParams, + JavaMLWritable, + JavaMLReadable["FMRegressionModel"], ): """ Model fitted by :class:`FMRegressor`. @@ -3194,25 +3280,25 @@ class FMRegressionModel( .. versionadded:: 3.0.0 """ - @property + @property # type: ignore[misc] @since("3.0.0") - def intercept(self): + def intercept(self) -> float: """ Model intercept. """ return self._call_java("intercept") - @property + @property # type: ignore[misc] @since("3.0.0") - def linear(self): + def linear(self) -> Vector: """ Model linear term. """ return self._call_java("linear") - @property + @property # type: ignore[misc] @since("3.0.0") - def factors(self): + def factors(self) -> Matrix: """ Model factor term. """ diff --git a/python/pyspark/ml/regression.pyi b/python/pyspark/ml/regression.pyi deleted file mode 100644 index 3b553b1401819..0000000000000 --- a/python/pyspark/ml/regression.pyi +++ /dev/null @@ -1,817 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import Any, List, Optional -from pyspark.ml._typing import JM, M, T - -import abc -from pyspark.ml import PredictionModel, Predictor -from pyspark.ml.base import _PredictorParams -from pyspark.ml.param.shared import ( - HasAggregationDepth, - HasMaxBlockSizeInMB, - HasElasticNetParam, - HasFeaturesCol, - HasFitIntercept, - HasLabelCol, - HasLoss, - HasMaxIter, - HasPredictionCol, - HasRegParam, - HasSeed, - HasSolver, - HasStandardization, - HasStepSize, - HasTol, - HasVarianceCol, - HasWeightCol, -) -from pyspark.ml.tree import ( - _DecisionTreeModel, - _DecisionTreeParams, - _GBTParams, - _RandomForestParams, - _TreeEnsembleModel, - _TreeRegressorParams, -) -from pyspark.ml.util import ( - GeneralJavaMLWritable, - HasTrainingSummary, - JavaMLReadable, - JavaMLWritable, -) -from pyspark.ml.wrapper import ( - JavaEstimator, - JavaModel, - JavaPredictionModel, - JavaPredictor, - JavaWrapper, -) - -from pyspark.ml.linalg import Matrix, Vector -from pyspark.ml.param import Param -from pyspark.sql.dataframe import DataFrame - -class Regressor(Predictor[M], _PredictorParams, metaclass=abc.ABCMeta): ... -class RegressionModel(PredictionModel[T], _PredictorParams, metaclass=abc.ABCMeta): ... -class _JavaRegressor(Regressor, JavaPredictor[JM], metaclass=abc.ABCMeta): ... -class _JavaRegressionModel(RegressionModel, JavaPredictionModel[T], metaclass=abc.ABCMeta): ... - -class _LinearRegressionParams( - _PredictorParams, - HasRegParam, - HasElasticNetParam, - HasMaxIter, - HasTol, - HasFitIntercept, - HasStandardization, - HasWeightCol, - HasSolver, - HasAggregationDepth, - HasLoss, - HasMaxBlockSizeInMB, -): - solver: Param[str] - loss: Param[str] - epsilon: Param[float] - def __init__(self, *args: Any): ... - def getEpsilon(self) -> float: ... - -class LinearRegression( - _JavaRegressor[LinearRegressionModel], - _LinearRegressionParams, - JavaMLWritable, - JavaMLReadable[LinearRegression], -): - def __init__( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - maxIter: int = ..., - regParam: float = ..., - elasticNetParam: float = ..., - tol: float = ..., - fitIntercept: bool = ..., - standardization: bool = ..., - solver: str = ..., - weightCol: Optional[str] = ..., - aggregationDepth: int = ..., - epsilon: float = ..., - maxBlockSizeInMB: float = ..., - ) -> None: ... - def setParams( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - maxIter: int = ..., - regParam: float = ..., - elasticNetParam: float = ..., - tol: float = ..., - fitIntercept: bool = ..., - standardization: bool = ..., - solver: str = ..., - weightCol: Optional[str] = ..., - aggregationDepth: int = ..., - epsilon: float = ..., - maxBlockSizeInMB: float = ..., - ) -> LinearRegression: ... - def setEpsilon(self, value: float) -> LinearRegression: ... - def setMaxIter(self, value: int) -> LinearRegression: ... - def setRegParam(self, value: float) -> LinearRegression: ... - def setTol(self, value: float) -> LinearRegression: ... - def setElasticNetParam(self, value: float) -> LinearRegression: ... - def setFitIntercept(self, value: bool) -> LinearRegression: ... - def setStandardization(self, value: bool) -> LinearRegression: ... - def setWeightCol(self, value: str) -> LinearRegression: ... - def setSolver(self, value: str) -> LinearRegression: ... - def setAggregationDepth(self, value: int) -> LinearRegression: ... - def setLoss(self, value: str) -> LinearRegression: ... 
- def setMaxBlockSizeInMB(self, value: float) -> LinearRegression: ... - -class LinearRegressionModel( - _JavaRegressionModel[Vector], - _LinearRegressionParams, - GeneralJavaMLWritable, - JavaMLReadable[LinearRegressionModel], - HasTrainingSummary[LinearRegressionSummary], -): - @property - def coefficients(self) -> Vector: ... - @property - def intercept(self) -> float: ... - @property - def summary(self) -> LinearRegressionTrainingSummary: ... - def evaluate(self, dataset: DataFrame) -> LinearRegressionSummary: ... - -class LinearRegressionSummary(JavaWrapper): - @property - def predictions(self) -> DataFrame: ... - @property - def predictionCol(self) -> str: ... - @property - def labelCol(self) -> str: ... - @property - def featuresCol(self) -> str: ... - @property - def explainedVariance(self) -> float: ... - @property - def meanAbsoluteError(self) -> float: ... - @property - def meanSquaredError(self) -> float: ... - @property - def rootMeanSquaredError(self) -> float: ... - @property - def r2(self) -> float: ... - @property - def r2adj(self) -> float: ... - @property - def residuals(self) -> DataFrame: ... - @property - def numInstances(self) -> int: ... - @property - def devianceResiduals(self) -> List[float]: ... - @property - def coefficientStandardErrors(self) -> List[float]: ... - @property - def tValues(self) -> List[float]: ... - @property - def pValues(self) -> List[float]: ... - -class LinearRegressionTrainingSummary(LinearRegressionSummary): - @property - def objectiveHistory(self) -> List[float]: ... - @property - def totalIterations(self) -> int: ... - -class _IsotonicRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol): - isotonic: Param[bool] - featureIndex: Param[int] - def getIsotonic(self) -> bool: ... - def getFeatureIndex(self) -> int: ... - -class IsotonicRegression( - JavaEstimator[IsotonicRegressionModel], - _IsotonicRegressionParams, - HasWeightCol, - JavaMLWritable, - JavaMLReadable[IsotonicRegression], -): - def __init__( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - weightCol: Optional[str] = ..., - isotonic: bool = ..., - featureIndex: int = ..., - ) -> None: ... - def setParams( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - weightCol: Optional[str] = ..., - isotonic: bool = ..., - featureIndex: int = ..., - ) -> IsotonicRegression: ... - def setIsotonic(self, value: bool) -> IsotonicRegression: ... - def setFeatureIndex(self, value: int) -> IsotonicRegression: ... - def setFeaturesCol(self, value: str) -> IsotonicRegression: ... - def setPredictionCol(self, value: str) -> IsotonicRegression: ... - def setLabelCol(self, value: str) -> IsotonicRegression: ... - def setWeightCol(self, value: str) -> IsotonicRegression: ... - -class IsotonicRegressionModel( - JavaModel, - _IsotonicRegressionParams, - JavaMLWritable, - JavaMLReadable[IsotonicRegressionModel], -): - def setFeaturesCol(self, value: str) -> IsotonicRegressionModel: ... - def setPredictionCol(self, value: str) -> IsotonicRegressionModel: ... - def setFeatureIndex(self, value: int) -> IsotonicRegressionModel: ... - @property - def boundaries(self) -> Vector: ... - @property - def predictions(self) -> Vector: ... - @property - def numFeatures(self) -> int: ... - def predict(self, value: float) -> float: ... - -class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, HasVarianceCol): - def __init__(self, *args: Any): ... 
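Illustrative sketch, not part of the patch: the DecisionTreeRegressor stub that follows mirrors the annotations now inlined in regression.py. A minimal fit under an assumed local SparkSession, showing the typed varianceCol output:

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import DecisionTreeRegressor

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
    ["label", "features"],
)
dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance")
model = dt.fit(df)  # DecisionTreeRegressionModel
model.transform(df).select("prediction", "variance").show()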
- -class DecisionTreeRegressor( - _JavaRegressor[DecisionTreeRegressionModel], - _DecisionTreeRegressorParams, - JavaMLWritable, - JavaMLReadable[DecisionTreeRegressor], -): - def __init__( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - maxDepth: int = ..., - maxBins: int = ..., - minInstancesPerNode: int = ..., - minInfoGain: float = ..., - maxMemoryInMB: int = ..., - cacheNodeIds: bool = ..., - checkpointInterval: int = ..., - impurity: str = ..., - seed: Optional[int] = ..., - varianceCol: Optional[str] = ..., - weightCol: Optional[str] = ..., - leafCol: str = ..., - minWeightFractionPerNode: float = ..., - ) -> None: ... - def setParams( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - maxDepth: int = ..., - maxBins: int = ..., - minInstancesPerNode: int = ..., - minInfoGain: float = ..., - maxMemoryInMB: int = ..., - cacheNodeIds: bool = ..., - checkpointInterval: int = ..., - impurity: str = ..., - seed: Optional[int] = ..., - varianceCol: Optional[str] = ..., - weightCol: Optional[str] = ..., - leafCol: str = ..., - minWeightFractionPerNode: float = ..., - ) -> DecisionTreeRegressor: ... - def setMaxDepth(self, value: int) -> DecisionTreeRegressor: ... - def setMaxBins(self, value: int) -> DecisionTreeRegressor: ... - def setMinInstancesPerNode(self, value: int) -> DecisionTreeRegressor: ... - def setMinWeightFractionPerNode(self, value: float) -> DecisionTreeRegressor: ... - def setMinInfoGain(self, value: float) -> DecisionTreeRegressor: ... - def setMaxMemoryInMB(self, value: int) -> DecisionTreeRegressor: ... - def setCacheNodeIds(self, value: bool) -> DecisionTreeRegressor: ... - def setImpurity(self, value: str) -> DecisionTreeRegressor: ... - def setCheckpointInterval(self, value: int) -> DecisionTreeRegressor: ... - def setSeed(self, value: int) -> DecisionTreeRegressor: ... - def setWeightCol(self, value: str) -> DecisionTreeRegressor: ... - def setVarianceCol(self, value: str) -> DecisionTreeRegressor: ... - -class DecisionTreeRegressionModel( - _JavaRegressionModel[Vector], - _DecisionTreeModel, - _DecisionTreeRegressorParams, - JavaMLWritable, - JavaMLReadable[DecisionTreeRegressionModel], -): - def setVarianceCol(self, value: str) -> DecisionTreeRegressionModel: ... - @property - def featureImportances(self) -> Vector: ... - -class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams): - def __init__(self, *args: Any): ... - -class RandomForestRegressor( - _JavaRegressor[RandomForestRegressionModel], - _RandomForestRegressorParams, - JavaMLWritable, - JavaMLReadable[RandomForestRegressor], -): - def __init__( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - maxDepth: int = ..., - maxBins: int = ..., - minInstancesPerNode: int = ..., - minInfoGain: float = ..., - maxMemoryInMB: int = ..., - cacheNodeIds: bool = ..., - checkpointInterval: int = ..., - impurity: str = ..., - subsamplingRate: float = ..., - seed: Optional[int] = ..., - numTrees: int = ..., - featureSubsetStrategy: str = ..., - leafCol: str = ..., - minWeightFractionPerNode: float = ..., - weightCol: Optional[str] = ..., - bootstrap: Optional[bool] = ..., - ) -> None: ... 
- def setParams( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - maxDepth: int = ..., - maxBins: int = ..., - minInstancesPerNode: int = ..., - minInfoGain: float = ..., - maxMemoryInMB: int = ..., - cacheNodeIds: bool = ..., - checkpointInterval: int = ..., - impurity: str = ..., - subsamplingRate: float = ..., - seed: Optional[int] = ..., - numTrees: int = ..., - featureSubsetStrategy: str = ..., - leafCol: str = ..., - minWeightFractionPerNode: float = ..., - weightCol: Optional[str] = ..., - bootstrap: Optional[bool] = ..., - ) -> RandomForestRegressor: ... - def setMaxDepth(self, value: int) -> RandomForestRegressor: ... - def setMaxBins(self, value: int) -> RandomForestRegressor: ... - def setMinInstancesPerNode(self, value: int) -> RandomForestRegressor: ... - def setMinInfoGain(self, value: float) -> RandomForestRegressor: ... - def setMaxMemoryInMB(self, value: int) -> RandomForestRegressor: ... - def setCacheNodeIds(self, value: bool) -> RandomForestRegressor: ... - def setImpurity(self, value: str) -> RandomForestRegressor: ... - def setNumTrees(self, value: int) -> RandomForestRegressor: ... - def setBootstrap(self, value: bool) -> RandomForestRegressor: ... - def setSubsamplingRate(self, value: float) -> RandomForestRegressor: ... - def setFeatureSubsetStrategy(self, value: str) -> RandomForestRegressor: ... - def setCheckpointInterval(self, value: int) -> RandomForestRegressor: ... - def setSeed(self, value: int) -> RandomForestRegressor: ... - def setWeightCol(self, value: str) -> RandomForestRegressor: ... - def setMinWeightFractionPerNode(self, value: float) -> RandomForestRegressor: ... - -class RandomForestRegressionModel( - _JavaRegressionModel[Vector], - _TreeEnsembleModel, - _RandomForestRegressorParams, - JavaMLWritable, - JavaMLReadable[RandomForestRegressionModel], -): - @property - def trees(self) -> List[DecisionTreeRegressionModel]: ... - @property - def featureImportances(self) -> Vector: ... - -class _GBTRegressorParams(_GBTParams, _TreeRegressorParams): - supportedLossTypes: List[str] - lossType: Param[str] - def __init__(self, *args: Any): ... - def getLossType(self) -> str: ... - -class GBTRegressor( - _JavaRegressor[GBTRegressionModel], - _GBTRegressorParams, - JavaMLWritable, - JavaMLReadable[GBTRegressor], -): - def __init__( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - maxDepth: int = ..., - maxBins: int = ..., - minInstancesPerNode: int = ..., - minInfoGain: float = ..., - maxMemoryInMB: int = ..., - cacheNodeIds: bool = ..., - subsamplingRate: float = ..., - checkpointInterval: int = ..., - lossType: str = ..., - maxIter: int = ..., - stepSize: float = ..., - seed: Optional[int] = ..., - impurity: str = ..., - featureSubsetStrategy: str = ..., - validationTol: float = ..., - validationIndicatorCol: Optional[str] = ..., - leafCol: str = ..., - minWeightFractionPerNode: float = ..., - weightCol: Optional[str] = ..., - ) -> None: ... 
- def setParams( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - maxDepth: int = ..., - maxBins: int = ..., - minInstancesPerNode: int = ..., - minInfoGain: float = ..., - maxMemoryInMB: int = ..., - cacheNodeIds: bool = ..., - subsamplingRate: float = ..., - checkpointInterval: int = ..., - lossType: str = ..., - maxIter: int = ..., - stepSize: float = ..., - seed: Optional[int] = ..., - impurity: str = ..., - featureSubsetStrategy: str = ..., - validationTol: float = ..., - validationIndicatorCol: Optional[str] = ..., - leafCol: str = ..., - minWeightFractionPerNode: float = ..., - weightCol: Optional[str] = ..., - ) -> GBTRegressor: ... - def setMaxDepth(self, value: int) -> GBTRegressor: ... - def setMaxBins(self, value: int) -> GBTRegressor: ... - def setMinInstancesPerNode(self, value: int) -> GBTRegressor: ... - def setMinInfoGain(self, value: float) -> GBTRegressor: ... - def setMaxMemoryInMB(self, value: int) -> GBTRegressor: ... - def setCacheNodeIds(self, value: bool) -> GBTRegressor: ... - def setImpurity(self, value: str) -> GBTRegressor: ... - def setLossType(self, value: str) -> GBTRegressor: ... - def setSubsamplingRate(self, value: float) -> GBTRegressor: ... - def setFeatureSubsetStrategy(self, value: str) -> GBTRegressor: ... - def setValidationIndicatorCol(self, value: str) -> GBTRegressor: ... - def setMaxIter(self, value: int) -> GBTRegressor: ... - def setCheckpointInterval(self, value: int) -> GBTRegressor: ... - def setSeed(self, value: int) -> GBTRegressor: ... - def setStepSize(self, value: float) -> GBTRegressor: ... - def setWeightCol(self, value: str) -> GBTRegressor: ... - def setMinWeightFractionPerNode(self, value: float) -> GBTRegressor: ... - -class GBTRegressionModel( - _JavaRegressionModel[Vector], - _TreeEnsembleModel, - _GBTRegressorParams, - JavaMLWritable, - JavaMLReadable[GBTRegressionModel], -): - @property - def featureImportances(self) -> Vector: ... - @property - def trees(self) -> List[DecisionTreeRegressionModel]: ... - def evaluateEachIteration(self, dataset: DataFrame, loss: str) -> List[float]: ... - -class _AFTSurvivalRegressionParams( - _PredictorParams, - HasMaxIter, - HasTol, - HasFitIntercept, - HasAggregationDepth, - HasMaxBlockSizeInMB, -): - censorCol: Param[str] - quantileProbabilities: Param[List[float]] - quantilesCol: Param[str] - def __init__(self, *args: Any): ... - def getCensorCol(self) -> str: ... - def getQuantileProbabilities(self) -> List[float]: ... - def getQuantilesCol(self) -> str: ... - -class AFTSurvivalRegression( - _JavaRegressor[AFTSurvivalRegressionModel], - _AFTSurvivalRegressionParams, - JavaMLWritable, - JavaMLReadable[AFTSurvivalRegression], -): - def __init__( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - fitIntercept: bool = ..., - maxIter: int = ..., - tol: float = ..., - censorCol: str = ..., - quantileProbabilities: List[float] = ..., - quantilesCol: Optional[str] = ..., - aggregationDepth: int = ..., - maxBlockSizeInMB: float = ..., - ) -> None: ... - def setParams( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - fitIntercept: bool = ..., - maxIter: int = ..., - tol: float = ..., - censorCol: str = ..., - quantileProbabilities: List[float] = ..., - quantilesCol: Optional[str] = ..., - aggregationDepth: int = ..., - maxBlockSizeInMB: float = ..., - ) -> AFTSurvivalRegression: ... - def setCensorCol(self, value: str) -> AFTSurvivalRegression: ... 
- def setQuantileProbabilities(self, value: List[float]) -> AFTSurvivalRegression: ... - def setQuantilesCol(self, value: str) -> AFTSurvivalRegression: ... - def setMaxIter(self, value: int) -> AFTSurvivalRegression: ... - def setTol(self, value: float) -> AFTSurvivalRegression: ... - def setFitIntercept(self, value: bool) -> AFTSurvivalRegression: ... - def setAggregationDepth(self, value: int) -> AFTSurvivalRegression: ... - def setMaxBlockSizeInMB(self, value: float) -> AFTSurvivalRegression: ... - -class AFTSurvivalRegressionModel( - _JavaRegressionModel[Vector], - _AFTSurvivalRegressionParams, - JavaMLWritable, - JavaMLReadable[AFTSurvivalRegressionModel], -): - def setQuantileProbabilities(self, value: List[float]) -> AFTSurvivalRegressionModel: ... - def setQuantilesCol(self, value: str) -> AFTSurvivalRegressionModel: ... - @property - def coefficients(self) -> Vector: ... - @property - def intercept(self) -> float: ... - @property - def scale(self) -> float: ... - def predictQuantiles(self, features: Vector) -> Vector: ... - def predict(self, features: Vector) -> float: ... - -class _GeneralizedLinearRegressionParams( - _PredictorParams, - HasFitIntercept, - HasMaxIter, - HasTol, - HasRegParam, - HasWeightCol, - HasSolver, - HasAggregationDepth, -): - family: Param[str] - link: Param[str] - linkPredictionCol: Param[str] - variancePower: Param[float] - linkPower: Param[float] - solver: Param[str] - offsetCol: Param[str] - def __init__(self, *args: Any): ... - def getFamily(self) -> str: ... - def getLinkPredictionCol(self) -> str: ... - def getLink(self) -> str: ... - def getVariancePower(self) -> float: ... - def getLinkPower(self) -> float: ... - def getOffsetCol(self) -> str: ... - -class GeneralizedLinearRegression( - _JavaRegressor[GeneralizedLinearRegressionModel], - _GeneralizedLinearRegressionParams, - JavaMLWritable, - JavaMLReadable[GeneralizedLinearRegression], -): - def __init__( - self, - *, - labelCol: str = ..., - featuresCol: str = ..., - predictionCol: str = ..., - family: str = ..., - link: Optional[str] = ..., - fitIntercept: bool = ..., - maxIter: int = ..., - tol: float = ..., - regParam: float = ..., - weightCol: Optional[str] = ..., - solver: str = ..., - linkPredictionCol: Optional[str] = ..., - variancePower: float = ..., - linkPower: Optional[float] = ..., - offsetCol: Optional[str] = ..., - aggregationDepth: int = ..., - ) -> None: ... - def setParams( - self, - *, - labelCol: str = ..., - featuresCol: str = ..., - predictionCol: str = ..., - family: str = ..., - link: Optional[str] = ..., - fitIntercept: bool = ..., - maxIter: int = ..., - tol: float = ..., - regParam: float = ..., - weightCol: Optional[str] = ..., - solver: str = ..., - linkPredictionCol: Optional[str] = ..., - variancePower: float = ..., - linkPower: Optional[float] = ..., - offsetCol: Optional[str] = ..., - aggregationDepth: int = ..., - ) -> GeneralizedLinearRegression: ... - def setFamily(self, value: str) -> GeneralizedLinearRegression: ... - def setLinkPredictionCol(self, value: str) -> GeneralizedLinearRegression: ... - def setLink(self, value: str) -> GeneralizedLinearRegression: ... - def setVariancePower(self, value: float) -> GeneralizedLinearRegression: ... - def setLinkPower(self, value: float) -> GeneralizedLinearRegression: ... - def setOffsetCol(self, value: str) -> GeneralizedLinearRegression: ... - def setMaxIter(self, value: int) -> GeneralizedLinearRegression: ... - def setRegParam(self, value: float) -> GeneralizedLinearRegression: ... 
- def setTol(self, value: float) -> GeneralizedLinearRegression: ... - def setFitIntercept(self, value: bool) -> GeneralizedLinearRegression: ... - def setWeightCol(self, value: str) -> GeneralizedLinearRegression: ... - def setSolver(self, value: str) -> GeneralizedLinearRegression: ... - def setAggregationDepth(self, value: int) -> GeneralizedLinearRegression: ... - -class GeneralizedLinearRegressionModel( - _JavaRegressionModel[Vector], - _GeneralizedLinearRegressionParams, - JavaMLWritable, - JavaMLReadable[GeneralizedLinearRegressionModel], - HasTrainingSummary[GeneralizedLinearRegressionTrainingSummary], -): - def setLinkPredictionCol(self, value: str) -> GeneralizedLinearRegressionModel: ... - @property - def coefficients(self) -> Vector: ... - @property - def intercept(self) -> float: ... - @property - def summary(self) -> GeneralizedLinearRegressionTrainingSummary: ... - def evaluate(self, dataset: DataFrame) -> GeneralizedLinearRegressionSummary: ... - -class GeneralizedLinearRegressionSummary(JavaWrapper): - @property - def predictions(self) -> DataFrame: ... - @property - def predictionCol(self) -> str: ... - @property - def rank(self) -> int: ... - @property - def degreesOfFreedom(self) -> int: ... - @property - def residualDegreeOfFreedom(self) -> int: ... - @property - def residualDegreeOfFreedomNull(self) -> int: ... - def residuals(self, residualsType: str = ...) -> DataFrame: ... - @property - def nullDeviance(self) -> float: ... - @property - def deviance(self) -> float: ... - @property - def dispersion(self) -> float: ... - @property - def aic(self) -> float: ... - -class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSummary): - @property - def numIterations(self) -> int: ... - @property - def solver(self) -> str: ... - @property - def coefficientStandardErrors(self) -> List[float]: ... - @property - def tValues(self) -> List[float]: ... - @property - def pValues(self) -> List[float]: ... - -class _FactorizationMachinesParams( - _PredictorParams, - HasMaxIter, - HasStepSize, - HasTol, - HasSolver, - HasSeed, - HasFitIntercept, - HasRegParam, - HasWeightCol, -): - factorSize: Param[int] - fitLinear: Param[bool] - miniBatchFraction: Param[float] - initStd: Param[float] - solver: Param[str] - def __init__(self, *args: Any): ... - def getFactorSize(self) -> int: ... - def getFitLinear(self) -> bool: ... - def getMiniBatchFraction(self) -> float: ... - def getInitStd(self) -> float: ... - -class FMRegressor( - _JavaRegressor[FMRegressionModel], - _FactorizationMachinesParams, - JavaMLWritable, - JavaMLReadable[FMRegressor], -): - factorSize: Param[int] - fitLinear: Param[bool] - miniBatchFraction: Param[float] - initStd: Param[float] - solver: Param[str] - def __init__( - self, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - factorSize: int = ..., - fitIntercept: bool = ..., - fitLinear: bool = ..., - regParam: float = ..., - miniBatchFraction: float = ..., - initStd: float = ..., - maxIter: int = ..., - stepSize: float = ..., - tol: float = ..., - solver: str = ..., - seed: Optional[int] = ..., - ) -> None: ... - def setParams( - self, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - factorSize: int = ..., - fitIntercept: bool = ..., - fitLinear: bool = ..., - regParam: float = ..., - miniBatchFraction: float = ..., - initStd: float = ..., - maxIter: int = ..., - stepSize: float = ..., - tol: float = ..., - solver: str = ..., - seed: Optional[int] = ..., - ) -> FMRegressor: ... 
- def setFactorSize(self, value: int) -> FMRegressor: ... - def setFitLinear(self, value: bool) -> FMRegressor: ... - def setMiniBatchFraction(self, value: float) -> FMRegressor: ... - def setInitStd(self, value: float) -> FMRegressor: ... - def setMaxIter(self, value: int) -> FMRegressor: ... - def setStepSize(self, value: float) -> FMRegressor: ... - def setTol(self, value: float) -> FMRegressor: ... - def setSolver(self, value: str) -> FMRegressor: ... - def setSeed(self, value: int) -> FMRegressor: ... - def setFitIntercept(self, value: bool) -> FMRegressor: ... - def setRegParam(self, value: float) -> FMRegressor: ... - -class FMRegressionModel( - _JavaRegressionModel, - _FactorizationMachinesParams, - JavaMLWritable, - JavaMLReadable[FMRegressionModel], -): - @property - def intercept(self) -> float: ... - @property - def linear(self) -> Vector: ... - @property - def factors(self) -> Matrix: ... diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 15bb6ca93f179..b91ef1b6cb346 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -17,12 +17,20 @@ import sys +from typing import Optional, Tuple, TYPE_CHECKING + + from pyspark import since, SparkContext from pyspark.ml.common import _java2py, _py2java +from pyspark.ml.linalg import Matrix, Vector from pyspark.ml.wrapper import JavaWrapper, _jvm from pyspark.sql.column import Column, _to_seq +from pyspark.sql.dataframe import DataFrame from pyspark.sql.functions import lit +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + class ChiSquareTest: """ @@ -37,7 +45,9 @@ class ChiSquareTest: """ @staticmethod - def test(dataset, featuresCol, labelCol, flatten=False): + def test( + dataset: DataFrame, featuresCol: str, labelCol: str, flatten: bool = False + ) -> DataFrame: """ Perform a Pearson's independence test using dataset. @@ -95,6 +105,8 @@ def test(dataset, featuresCol, labelCol, flatten=False): 4.0 """ sc = SparkContext._active_spark_context + assert sc is not None + javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol, flatten)] return _java2py(sc, javaTestObj.test(*args)) @@ -116,7 +128,7 @@ class Correlation: """ @staticmethod - def corr(dataset, column, method="pearson"): + def corr(dataset: DataFrame, column: str, method: str = "pearson") -> DataFrame: """ Compute the correlation matrix with specified method using dataset. @@ -162,6 +174,8 @@ def corr(dataset, column, method="pearson"): [ 0.4 , 0.9486... , NaN, 1. ]]) """ sc = SparkContext._active_spark_context + assert sc is not None + javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation args = [_py2java(sc, arg) for arg in (dataset, column, method)] return _java2py(sc, javaCorrObj.corr(*args)) @@ -181,7 +195,7 @@ class KolmogorovSmirnovTest: """ @staticmethod - def test(dataset, sampleCol, distName, *params): + def test(dataset: DataFrame, sampleCol: str, distName: str, *params: float) -> DataFrame: """ Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution equality. 
Currently supports the normal distribution, taking as parameters the mean and @@ -228,9 +242,11 @@ def test(dataset, sampleCol, distName, *params): 0.175 """ sc = SparkContext._active_spark_context + assert sc is not None + javaTestObj = _jvm().org.apache.spark.ml.stat.KolmogorovSmirnovTest dataset = _py2java(sc, dataset) - params = [float(param) for param in params] + params = [float(param) for param in params] # type: ignore[assignment] return _java2py( sc, javaTestObj.test(dataset, sampleCol, distName, _jvm().PythonUtils.toSeq(params)) ) @@ -284,7 +300,7 @@ class Summarizer: @staticmethod @since("2.4.0") - def mean(col, weightCol=None): + def mean(col: Column, weightCol: Optional[Column] = None) -> Column: """ return a column of mean summary """ @@ -292,7 +308,7 @@ def mean(col, weightCol=None): @staticmethod @since("3.0.0") - def sum(col, weightCol=None): + def sum(col: Column, weightCol: Optional[Column] = None) -> Column: """ return a column of sum summary """ @@ -300,7 +316,7 @@ def sum(col, weightCol=None): @staticmethod @since("2.4.0") - def variance(col, weightCol=None): + def variance(col: Column, weightCol: Optional[Column] = None) -> Column: """ return a column of variance summary """ @@ -308,7 +324,7 @@ def variance(col, weightCol=None): @staticmethod @since("3.0.0") - def std(col, weightCol=None): + def std(col: Column, weightCol: Optional[Column] = None) -> Column: """ return a column of std summary """ @@ -316,7 +332,7 @@ def std(col, weightCol=None): @staticmethod @since("2.4.0") - def count(col, weightCol=None): + def count(col: Column, weightCol: Optional[Column] = None) -> Column: """ return a column of count summary """ @@ -324,7 +340,7 @@ def count(col, weightCol=None): @staticmethod @since("2.4.0") - def numNonZeros(col, weightCol=None): + def numNonZeros(col: Column, weightCol: Optional[Column] = None) -> Column: """ return a column of numNonZero summary """ @@ -332,7 +348,7 @@ def numNonZeros(col, weightCol=None): @staticmethod @since("2.4.0") - def max(col, weightCol=None): + def max(col: Column, weightCol: Optional[Column] = None) -> Column: """ return a column of max summary """ @@ -340,7 +356,7 @@ def max(col, weightCol=None): @staticmethod @since("2.4.0") - def min(col, weightCol=None): + def min(col: Column, weightCol: Optional[Column] = None) -> Column: """ return a column of min summary """ @@ -348,7 +364,7 @@ def min(col, weightCol=None): @staticmethod @since("2.4.0") - def normL1(col, weightCol=None): + def normL1(col: Column, weightCol: Optional[Column] = None) -> Column: """ return a column of normL1 summary """ @@ -356,14 +372,14 @@ def normL1(col, weightCol=None): @staticmethod @since("2.4.0") - def normL2(col, weightCol=None): + def normL2(col: Column, weightCol: Optional[Column] = None) -> Column: """ return a column of normL2 summary """ return Summarizer._get_single_metric(col, weightCol, "normL2") @staticmethod - def _check_param(featuresCol, weightCol): + def _check_param(featuresCol: Column, weightCol: Optional[Column]) -> Tuple[Column, Column]: if weightCol is None: weightCol = lit(1.0) if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column): @@ -371,7 +387,7 @@ def _check_param(featuresCol, weightCol): return featuresCol, weightCol @staticmethod - def _get_single_metric(col, weightCol, metric): + def _get_single_metric(col: Column, weightCol: Optional[Column], metric: str) -> Column: col, weightCol = Summarizer._check_param(col, weightCol) return Column( JavaWrapper._new_java_obj( @@ -380,7 +396,7 @@ def 
_get_single_metric(col, weightCol, metric): ) @staticmethod - def metrics(*metrics): + def metrics(*metrics: str) -> "SummaryBuilder": """ Given a list of metrics, provides a builder that it turns computes metrics from a column. @@ -415,6 +431,8 @@ def metrics(*metrics): :py:class:`pyspark.ml.stat.SummaryBuilder` """ sc = SparkContext._active_spark_context + assert sc is not None + js = JavaWrapper._new_java_obj( "org.apache.spark.ml.stat.Summarizer.metrics", _to_seq(sc, metrics) ) @@ -432,10 +450,10 @@ class SummaryBuilder(JavaWrapper): """ - def __init__(self, jSummaryBuilder): + def __init__(self, jSummaryBuilder: "JavaObject"): super(SummaryBuilder, self).__init__(jSummaryBuilder) - def summary(self, featuresCol, weightCol=None): + def summary(self, featuresCol: Column, weightCol: Optional[Column] = None) -> Column: """ Returns an aggregate object that contains the summary of the column with the requested metrics. @@ -456,6 +474,8 @@ def summary(self, featuresCol, weightCol=None): structure is determined during the creation of the builder. """ featuresCol, weightCol = Summarizer._check_param(featuresCol, weightCol) + assert self._java_obj is not None + return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc)) @@ -474,7 +494,7 @@ class MultivariateGaussian: [ 3., 2.]])) """ - def __init__(self, mean, cov): + def __init__(self, mean: Vector, cov: Matrix): self.mean = mean self.cov = cov diff --git a/python/pyspark/ml/stat.pyi b/python/pyspark/ml/stat.pyi deleted file mode 100644 index 90b0686b1c746..0000000000000 --- a/python/pyspark/ml/stat.pyi +++ /dev/null @@ -1,73 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Optional - -from pyspark.ml.linalg import Matrix, Vector -from pyspark.ml.wrapper import JavaWrapper -from pyspark.sql.column import Column -from pyspark.sql.dataframe import DataFrame - -from py4j.java_gateway import JavaObject # type: ignore[import] - -class ChiSquareTest: - @staticmethod - def test( - dataset: DataFrame, featuresCol: str, labelCol: str, flatten: bool = ... - ) -> DataFrame: ... - -class Correlation: - @staticmethod - def corr(dataset: DataFrame, column: str, method: str = ...) -> DataFrame: ... - -class KolmogorovSmirnovTest: - @staticmethod - def test(dataset: DataFrame, sampleCol: str, distName: str, *params: float) -> DataFrame: ... - -class Summarizer: - @staticmethod - def mean(col: Column, weightCol: Optional[Column] = ...) -> Column: ... - @staticmethod - def sum(col: Column, weightCol: Optional[Column] = ...) -> Column: ... - @staticmethod - def variance(col: Column, weightCol: Optional[Column] = ...) -> Column: ... - @staticmethod - def std(col: Column, weightCol: Optional[Column] = ...) -> Column: ... 
- @staticmethod - def count(col: Column, weightCol: Optional[Column] = ...) -> Column: ... - @staticmethod - def numNonZeros(col: Column, weightCol: Optional[Column] = ...) -> Column: ... - @staticmethod - def max(col: Column, weightCol: Optional[Column] = ...) -> Column: ... - @staticmethod - def min(col: Column, weightCol: Optional[Column] = ...) -> Column: ... - @staticmethod - def normL1(col: Column, weightCol: Optional[Column] = ...) -> Column: ... - @staticmethod - def normL2(col: Column, weightCol: Optional[Column] = ...) -> Column: ... - @staticmethod - def metrics(*metrics: str) -> SummaryBuilder: ... - -class SummaryBuilder(JavaWrapper): - def __init__(self, jSummaryBuilder: JavaObject) -> None: ... - def summary(self, featuresCol: Column, weightCol: Optional[Column] = ...) -> Column: ... - -class MultivariateGaussian: - mean: Vector - cov: Matrix - def __init__(self, mean: Vector, cov: Matrix) -> None: ... diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py index bf74988a7c097..08da8592c043d 100644 --- a/python/pyspark/ml/tests/test_algorithms.py +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -101,7 +101,15 @@ def test_raw_and_probability_prediction(self): expected_rawPrediction = [-11.6081922998, -8.15827998691, 22.17757045] self.assertTrue(result.prediction, expected_prediction) self.assertTrue(np.allclose(result.probability, expected_probability, atol=1e-4)) - self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, rtol=0.11)) + # Use `assert_allclose` to show the value of `result.rawPrediction` in the assertion error + # message + np.testing.assert_allclose( + result.rawPrediction, + expected_rawPrediction, + rtol=0.15, + # Use the same default value as `np.allclose` + atol=1e-08, + ) class OneVsRestTests(SparkSessionTestCase): diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py index 315a035891c28..02ce6f319241f 100644 --- a/python/pyspark/ml/tests/test_wrapper.py +++ b/python/pyspark/ml/tests/test_wrapper.py @@ -21,7 +21,7 @@ from pyspark.ml.linalg import DenseVector, Vectors from pyspark.ml.regression import LinearRegression -from pyspark.ml.wrapper import ( # type: ignore[attr-defined] +from pyspark.ml.wrapper import ( _java2py, _py2java, JavaParams, diff --git a/python/pyspark/ml/tests/typing/test_clustering.yaml b/python/pyspark/ml/tests/typing/test_clustering.yaml new file mode 100644 index 0000000000000..b208573975d7f --- /dev/null +++ b/python/pyspark/ml/tests/typing/test_clustering.yaml @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +- case: InheritedLDAMethods + main: | + from pyspark.ml.clustering import LDAModel, LocalLDAModel, DistributedLDAModel + + distributed_model = DistributedLDAModel.load("foo") + reveal_type(distributed_model) + reveal_type(distributed_model.setFeaturesCol("foo")) + + local_model = distributed_model.toLocal() + reveal_type(local_model) + reveal_type(local_model.setFeaturesCol("foo")) + out: | + main:4: note: Revealed type is "pyspark.ml.clustering.DistributedLDAModel*" + main:5: note: Revealed type is "pyspark.ml.clustering.DistributedLDAModel*" + main:8: note: Revealed type is "pyspark.ml.clustering.LocalLDAModel" + main:9: note: Revealed type is "pyspark.ml.clustering.LocalLDAModel*" diff --git a/python/pyspark/ml/tests/typing/test_evaluation.yml b/python/pyspark/ml/tests/typing/test_evaluation.yml index e9e8f20570b45..a60166dfb96fd 100644 --- a/python/pyspark/ml/tests/typing/test_evaluation.yml +++ b/python/pyspark/ml/tests/typing/test_evaluation.yml @@ -24,3 +24,5 @@ BinaryClassificationEvaluator().setMetricName("foo") # E: Argument 1 to "setMetricName" of "BinaryClassificationEvaluator" has incompatible type "Literal['foo']"; expected "Union[Literal['areaUnderROC'], Literal['areaUnderPR']]" [arg-type] BinaryClassificationEvaluator(metricName="bar") # E: Argument "metricName" to "BinaryClassificationEvaluator" has incompatible type "Literal['bar']"; expected "Union[Literal['areaUnderROC'], Literal['areaUnderPR']]" [arg-type] + + reveal_type(BinaryClassificationEvaluator.load("foo")) # N: Revealed type is "pyspark.ml.evaluation.BinaryClassificationEvaluator*" diff --git a/python/pyspark/ml/tests/typing/test_feature.yml b/python/pyspark/ml/tests/typing/test_feature.yml index 3d6b09038ab50..0d1034a44df66 100644 --- a/python/pyspark/ml/tests/typing/test_feature.yml +++ b/python/pyspark/ml/tests/typing/test_feature.yml @@ -15,6 +15,17 @@ # limitations under the License. # + +- case: featureMethodChaining + main: | + from pyspark.ml.feature import NGram + + reveal_type(NGram().setInputCol("foo").setOutputCol("bar")) + + out: | + main:3: note: Revealed type is "pyspark.ml.feature.NGram" + + - case: stringIndexerOverloads main: | from pyspark.ml.feature import StringIndexer @@ -41,4 +52,4 @@ main:15: error: No overload variant of "StringIndexer" matches argument types "List[str]", "str" [call-overload] main:15: note: Possible overload variants: main:15: note: def StringIndexer(self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer - main:15: note: def StringIndexer(self, *, inputCols: Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer \ No newline at end of file + main:15: note: def StringIndexer(self, *, inputCols: Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer diff --git a/python/pyspark/ml/tests/typing/test_regression.yml b/python/pyspark/ml/tests/typing/test_regression.yml index b045bec0d9891..4a54a565e626d 100644 --- a/python/pyspark/ml/tests/typing/test_regression.yml +++ b/python/pyspark/ml/tests/typing/test_regression.yml @@ -15,6 +15,21 @@ # limitations under the License. 
# +- case: linearRegressionMethodChaining + main: | + from pyspark.ml.regression import LinearRegression, LinearRegressionModel + + lr = LinearRegression() + reveal_type(lr.setFeaturesCol("foo").setLabelCol("bar")) + + lrm = LinearRegressionModel.load("/foo") + reveal_type(lrm.setPredictionCol("baz")) + + out: | + main:4: note: Revealed type is "pyspark.ml.regression.LinearRegression" + main:7: note: Revealed type is "pyspark.ml.regression.LinearRegressionModel" + + - case: loadFMRegressor main: | from pyspark.ml.regression import FMRegressor, FMRegressionModel diff --git a/python/pyspark/ml/tree.py b/python/pyspark/ml/tree.py index 3d607e1c943ff..ad405b742bdb1 100644 --- a/python/pyspark/ml/tree.py +++ b/python/pyspark/ml/tree.py @@ -14,8 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from typing import List, Sequence, TypeVar, TYPE_CHECKING from pyspark import since +from pyspark.ml.linalg import Vector from pyspark.ml.param import Params from pyspark.ml.param.shared import ( HasCheckpointInterval, @@ -30,35 +32,40 @@ from pyspark.ml.wrapper import JavaPredictionModel from pyspark.ml.common import inherit_doc +if TYPE_CHECKING: + from pyspark.ml._typing import P + +T = TypeVar("T") + @inherit_doc -class _DecisionTreeModel(JavaPredictionModel): +class _DecisionTreeModel(JavaPredictionModel[T]): """ Abstraction for Decision Tree models. .. versionadded:: 1.5.0 """ - @property + @property # type: ignore[misc] @since("1.5.0") - def numNodes(self): + def numNodes(self) -> int: """Return number of nodes of the decision tree.""" return self._call_java("numNodes") - @property + @property # type: ignore[misc] @since("1.5.0") - def depth(self): + def depth(self) -> int: """Return depth of the decision tree.""" return self._call_java("depth") - @property + @property # type: ignore[misc] @since("2.0.0") - def toDebugString(self): + def toDebugString(self) -> str: """Full description of model.""" return self._call_java("toDebugString") @since("3.0.0") - def predictLeaf(self, value): + def predictLeaf(self, value: Vector) -> float: """ Predict the indices of the leaves corresponding to the feature vector. """ @@ -70,7 +77,7 @@ class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): Mixin for Decision Tree parameters. """ - leafCol = Param( + leafCol: Param[str] = Param( Params._dummy(), "leafCol", "Leaf indices column name. Predicted leaf " @@ -78,7 +85,7 @@ class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): typeConverter=TypeConverters.toString, ) - maxDepth = Param( + maxDepth: Param[int] = Param( Params._dummy(), "maxDepth", "Maximum depth of the tree. 
(>= 0) E.g., " @@ -87,7 +94,7 @@ class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): typeConverter=TypeConverters.toInt, ) - maxBins = Param( + maxBins: Param[int] = Param( Params._dummy(), "maxBins", "Max number of bins for discretizing continuous " @@ -96,7 +103,7 @@ class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): typeConverter=TypeConverters.toInt, ) - minInstancesPerNode = Param( + minInstancesPerNode: Param[int] = Param( Params._dummy(), "minInstancesPerNode", "Minimum number of " @@ -107,7 +114,7 @@ class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): typeConverter=TypeConverters.toInt, ) - minWeightFractionPerNode = Param( + minWeightFractionPerNode: Param[float] = Param( Params._dummy(), "minWeightFractionPerNode", "Minimum " @@ -119,14 +126,14 @@ class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): typeConverter=TypeConverters.toFloat, ) - minInfoGain = Param( + minInfoGain: Param[float] = Param( Params._dummy(), "minInfoGain", "Minimum information gain for a split " + "to be considered at a tree node.", typeConverter=TypeConverters.toFloat, ) - maxMemoryInMB = Param( + maxMemoryInMB: Param[int] = Param( Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to " @@ -135,7 +142,7 @@ class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): typeConverter=TypeConverters.toInt, ) - cacheNodeIds = Param( + cacheNodeIds: Param[bool] = Param( Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass " @@ -146,58 +153,58 @@ class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): typeConverter=TypeConverters.toBoolean, ) - def __init__(self): + def __init__(self) -> None: super(_DecisionTreeParams, self).__init__() - def setLeafCol(self, value): + def setLeafCol(self: "P", value: str) -> "P": """ Sets the value of :py:attr:`leafCol`. """ return self._set(leafCol=value) - def getLeafCol(self): + def getLeafCol(self) -> str: """ Gets the value of leafCol or its default value. """ return self.getOrDefault(self.leafCol) - def getMaxDepth(self): + def getMaxDepth(self) -> int: """ Gets the value of maxDepth or its default value. """ return self.getOrDefault(self.maxDepth) - def getMaxBins(self): + def getMaxBins(self) -> int: """ Gets the value of maxBins or its default value. """ return self.getOrDefault(self.maxBins) - def getMinInstancesPerNode(self): + def getMinInstancesPerNode(self) -> int: """ Gets the value of minInstancesPerNode or its default value. """ return self.getOrDefault(self.minInstancesPerNode) - def getMinWeightFractionPerNode(self): + def getMinWeightFractionPerNode(self) -> float: """ Gets the value of minWeightFractionPerNode or its default value. """ return self.getOrDefault(self.minWeightFractionPerNode) - def getMinInfoGain(self): + def getMinInfoGain(self) -> float: """ Gets the value of minInfoGain or its default value. """ return self.getOrDefault(self.minInfoGain) - def getMaxMemoryInMB(self): + def getMaxMemoryInMB(self) -> int: """ Gets the value of maxMemoryInMB or its default value. """ return self.getOrDefault(self.maxMemoryInMB) - def getCacheNodeIds(self): + def getCacheNodeIds(self) -> bool: """ Gets the value of cacheNodeIds or its default value. """ @@ -205,44 +212,44 @@ def getCacheNodeIds(self): @inherit_doc -class _TreeEnsembleModel(JavaPredictionModel): +class _TreeEnsembleModel(JavaPredictionModel[T]): """ (private abstraction) Represents a tree ensemble model. 
""" - @property + @property # type: ignore[misc] @since("2.0.0") - def trees(self): + def trees(self) -> Sequence["_DecisionTreeModel"]: """Trees in this ensemble. Warning: These have null parent Estimators.""" return [_DecisionTreeModel(m) for m in list(self._call_java("trees"))] - @property + @property # type: ignore[misc] @since("2.0.0") - def getNumTrees(self): + def getNumTrees(self) -> int: """Number of trees in ensemble.""" return self._call_java("getNumTrees") - @property + @property # type: ignore[misc] @since("1.5.0") - def treeWeights(self): + def treeWeights(self) -> List[float]: """Return the weights for each tree""" return list(self._call_java("javaTreeWeights")) - @property + @property # type: ignore[misc] @since("2.0.0") - def totalNumNodes(self): + def totalNumNodes(self) -> int: """Total number of nodes, summed over all trees in the ensemble.""" return self._call_java("totalNumNodes") - @property + @property # type: ignore[misc] @since("2.0.0") - def toDebugString(self): + def toDebugString(self) -> str: """Full description of model.""" return self._call_java("toDebugString") @since("3.0.0") - def predictLeaf(self, value): + def predictLeaf(self, value: Vector) -> float: """ Predict the indices of the leaves corresponding to the feature vector. """ @@ -254,16 +261,16 @@ class _TreeEnsembleParams(_DecisionTreeParams): Mixin for Decision Tree-based ensemble algorithms parameters. """ - subsamplingRate = Param( + subsamplingRate: Param[float] = Param( Params._dummy(), "subsamplingRate", "Fraction of the training data " + "used for learning each decision tree, in range (0, 1].", typeConverter=TypeConverters.toFloat, ) - supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] + supportedFeatureSubsetStrategies: List[str] = ["auto", "all", "onethird", "sqrt", "log2"] - featureSubsetStrategy = Param( + featureSubsetStrategy: Param[str] = Param( Params._dummy(), "featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " @@ -277,18 +284,18 @@ class _TreeEnsembleParams(_DecisionTreeParams): typeConverter=TypeConverters.toString, ) - def __init__(self): + def __init__(self) -> None: super(_TreeEnsembleParams, self).__init__() @since("1.4.0") - def getSubsamplingRate(self): + def getSubsamplingRate(self) -> float: """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) @since("1.4.0") - def getFeatureSubsetStrategy(self): + def getFeatureSubsetStrategy(self) -> str: """ Gets the value of featureSubsetStrategy or its default value. """ @@ -300,32 +307,32 @@ class _RandomForestParams(_TreeEnsembleParams): Private class to track supported random forest parameters. """ - numTrees = Param( + numTrees: Param[int] = Param( Params._dummy(), "numTrees", "Number of trees to train (>= 1).", typeConverter=TypeConverters.toInt, ) - bootstrap = Param( + bootstrap: Param[bool] = Param( Params._dummy(), "bootstrap", "Whether bootstrap samples are used " "when building trees.", typeConverter=TypeConverters.toBoolean, ) - def __init__(self): + def __init__(self) -> None: super(_RandomForestParams, self).__init__() @since("1.4.0") - def getNumTrees(self): + def getNumTrees(self) -> int: """ Gets the value of numTrees or its default value. """ return self.getOrDefault(self.numTrees) @since("3.0.0") - def getBootstrap(self): + def getBootstrap(self) -> bool: """ Gets the value of bootstrap or its default value. 
""" @@ -337,7 +344,7 @@ class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndi Private class to track supported GBT params. """ - stepSize = Param( + stepSize: Param[float] = Param( Params._dummy(), "stepSize", "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " @@ -345,7 +352,7 @@ class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndi typeConverter=TypeConverters.toFloat, ) - validationTol = Param( + validationTol: Param[float] = Param( Params._dummy(), "validationTol", "Threshold for stopping early when fit with validation is used. " @@ -356,7 +363,7 @@ class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndi ) @since("3.0.0") - def getValidationTol(self): + def getValidationTol(self) -> float: """ Gets the value of validationTol or its default value. """ @@ -368,9 +375,9 @@ class _HasVarianceImpurity(Params): Private class to track supported impurity measures. """ - supportedImpurities = ["variance"] + supportedImpurities: List[str] = ["variance"] - impurity = Param( + impurity: Param[str] = Param( Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " @@ -379,11 +386,11 @@ class _HasVarianceImpurity(Params): typeConverter=TypeConverters.toString, ) - def __init__(self): + def __init__(self) -> None: super(_HasVarianceImpurity, self).__init__() @since("1.4.0") - def getImpurity(self): + def getImpurity(self) -> str: """ Gets the value of impurity or its default value. """ @@ -397,9 +404,9 @@ class _TreeClassifierParams(Params): .. versionadded:: 1.4.0 """ - supportedImpurities = ["entropy", "gini"] + supportedImpurities: List[str] = ["entropy", "gini"] - impurity = Param( + impurity: Param[str] = Param( Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " @@ -408,11 +415,11 @@ class _TreeClassifierParams(Params): typeConverter=TypeConverters.toString, ) - def __init__(self): + def __init__(self) -> None: super(_TreeClassifierParams, self).__init__() @since("1.6.0") - def getImpurity(self): + def getImpurity(self) -> str: """ Gets the value of impurity or its default value. """ diff --git a/python/pyspark/ml/tree.pyi b/python/pyspark/ml/tree.pyi deleted file mode 100644 index 5a9b70ed7e3bc..0000000000000 --- a/python/pyspark/ml/tree.pyi +++ /dev/null @@ -1,110 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import List, Sequence -from pyspark.ml._typing import P, T - -from pyspark.ml.linalg import Vector -from pyspark import since as since # noqa: F401 -from pyspark.ml.common import inherit_doc as inherit_doc # noqa: F401 -from pyspark.ml.param import Param, Params as Params -from pyspark.ml.param.shared import ( # noqa: F401 - HasCheckpointInterval as HasCheckpointInterval, - HasMaxIter as HasMaxIter, - HasSeed as HasSeed, - HasStepSize as HasStepSize, - HasValidationIndicatorCol as HasValidationIndicatorCol, - HasWeightCol as HasWeightCol, - Param as Param, - TypeConverters as TypeConverters, -) -from pyspark.ml.wrapper import JavaPredictionModel as JavaPredictionModel - -class _DecisionTreeModel(JavaPredictionModel[T]): - @property - def numNodes(self) -> int: ... - @property - def depth(self) -> int: ... - @property - def toDebugString(self) -> str: ... - def predictLeaf(self, value: Vector) -> float: ... - -class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): - leafCol: Param[str] - maxDepth: Param[int] - maxBins: Param[int] - minInstancesPerNode: Param[int] - minWeightFractionPerNode: Param[float] - minInfoGain: Param[float] - maxMemoryInMB: Param[int] - cacheNodeIds: Param[bool] - def __init__(self) -> None: ... - def setLeafCol(self: P, value: str) -> P: ... - def getLeafCol(self) -> str: ... - def getMaxDepth(self) -> int: ... - def getMaxBins(self) -> int: ... - def getMinInstancesPerNode(self) -> int: ... - def getMinInfoGain(self) -> float: ... - def getMaxMemoryInMB(self) -> int: ... - def getCacheNodeIds(self) -> bool: ... - -class _TreeEnsembleModel(JavaPredictionModel[T]): - @property - def trees(self) -> Sequence[_DecisionTreeModel]: ... - @property - def getNumTrees(self) -> int: ... - @property - def treeWeights(self) -> List[float]: ... - @property - def totalNumNodes(self) -> int: ... - @property - def toDebugString(self) -> str: ... - -class _TreeEnsembleParams(_DecisionTreeParams): - subsamplingRate: Param[float] - supportedFeatureSubsetStrategies: List[str] - featureSubsetStrategy: Param[str] - def __init__(self) -> None: ... - def getSubsamplingRate(self) -> float: ... - def getFeatureSubsetStrategy(self) -> str: ... - -class _RandomForestParams(_TreeEnsembleParams): - numTrees: Param[int] - bootstrap: Param[bool] - def __init__(self) -> None: ... - def getNumTrees(self) -> int: ... - def getBootstrap(self) -> bool: ... - -class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol): - stepSize: Param[float] - validationTol: Param[float] - def getValidationTol(self) -> float: ... - -class _HasVarianceImpurity(Params): - supportedImpurities: List[str] - impurity: Param[str] - def __init__(self) -> None: ... - def getImpurity(self) -> str: ... - -class _TreeClassifierParams(Params): - supportedImpurities: List[str] - impurity: Param[str] - def __init__(self) -> None: ... - def getImpurity(self) -> str: ... - -class _TreeRegressorParams(_HasVarianceImpurity): ... 
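
Note (not part of the patch above): with python/pyspark/ml/tree.pyi deleted, the annotations now live inline in pyspark/ml/tree.py, so type checkers resolve them from the module itself. A minimal sketch, assuming a hypothetical saved model path and feature values, of how the inline hints on _DecisionTreeModel are picked up by a checker such as mypy:

    from pyspark.ml.classification import DecisionTreeClassificationModel
    from pyspark.ml.linalg import Vectors

    # Hypothetical model path, used only for illustration.
    model = DecisionTreeClassificationModel.load("/tmp/dtc-model")

    depth: int = model.depth        # _DecisionTreeModel.depth is annotated -> int
    nodes: int = model.numNodes     # _DecisionTreeModel.numNodes is annotated -> int
    leaf: float = model.predictLeaf(Vectors.dense([0.0, 1.0]))  # Vector -> float

The same pattern applies to the other modules in this change: annotations formerly in the .pyi stubs are declared directly on the implementation classes, so method chaining and property access type-check without a separate stub file.
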
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 47805c9c2bee9..44a8b51ef8ec5 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -20,12 +20,28 @@ import itertools from multiprocessing.pool import ThreadPool +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + overload, + TYPE_CHECKING, +) + import numpy as np from pyspark import keyword_only, since, SparkContext, inheritable_thread_target from pyspark.ml import Estimator, Transformer, Model from pyspark.ml.common import inherit_doc, _py2java, _java2py -from pyspark.ml.evaluation import Evaluator +from pyspark.ml.evaluation import Evaluator, JavaEvaluator from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed from pyspark.ml.util import ( @@ -43,6 +59,13 @@ from pyspark.sql.functions import col, lit, rand, UserDefinedFunction from pyspark.sql.types import BooleanType +from pyspark.sql.dataframe import DataFrame + +if TYPE_CHECKING: + from pyspark.ml._typing import ParamMap + from py4j.java_gateway import JavaObject + from py4j.java_collections import JavaArray + __all__ = [ "ParamGridBuilder", "CrossValidator", @@ -52,7 +75,14 @@ ] -def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel): +def _parallelFitTasks( + est: Estimator, + train: DataFrame, + eva: Evaluator, + validation: DataFrame, + epm: Sequence["ParamMap"], + collectSubModel: bool, +) -> List[Callable[[], Tuple[int, float, Transformer]]]: """ Creates a list of callables which can be called from different threads to fit and evaluate an estimator in parallel. Each callable returns an `(index, metric)` pair. @@ -79,7 +109,7 @@ def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel): """ modelIter = est.fitMultiple(train, epm) - def singleTask(): + def singleTask() -> Tuple[int, float, Transformer]: index, model = next(modelIter) # TODO: duplicate evaluator to take extra params from input # Note: Supporting tuning params in evaluator need update method @@ -119,11 +149,11 @@ class ParamGridBuilder: True """ - def __init__(self): - self._param_grid = {} + def __init__(self) -> None: + self._param_grid: "ParamMap" = {} @since("1.4.0") - def addGrid(self, param, values): + def addGrid(self, param: Param[Any], values: List[Any]) -> "ParamGridBuilder": """ Sets the given parameters in this grid to fixed values. @@ -137,8 +167,16 @@ def addGrid(self, param, values): return self + @overload + def baseOn(self, __args: "ParamMap") -> "ParamGridBuilder": + ... + + @overload + def baseOn(self, *args: Tuple[Param, Any]) -> "ParamGridBuilder": + ... + @since("1.4.0") - def baseOn(self, *args): + def baseOn(self, *args: Union["ParamMap", Tuple[Param, Any]]) -> "ParamGridBuilder": """ Sets the given parameters in this grid to fixed values. Accepts either a parameter dictionary or a list of (parameter, value) pairs. @@ -152,7 +190,7 @@ def baseOn(self, *args): return self @since("1.4.0") - def build(self): + def build(self) -> List["ParamMap"]: """ Builds and returns all combinations of parameters specified by the param grid. 
@@ -160,7 +198,9 @@ def build(self): keys = self._param_grid.keys() grid_values = self._param_grid.values() - def to_key_value_pairs(keys, values): + def to_key_value_pairs( + keys: Iterable[Param], values: Iterable[Any] + ) -> Sequence[Tuple[Param, Any]]: return [(key, key.typeConverter(value)) for key, value in zip(keys, values)] return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)] @@ -171,44 +211,50 @@ class _ValidatorParams(HasSeed): Common params for TrainValidationSplit and CrossValidator. """ - estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated") - estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps") - evaluator = Param( + estimator: Param[Estimator] = Param( + Params._dummy(), "estimator", "estimator to be cross-validated" + ) + estimatorParamMaps: Param[List["ParamMap"]] = Param( + Params._dummy(), "estimatorParamMaps", "estimator param maps" + ) + evaluator: Param[Evaluator] = Param( Params._dummy(), "evaluator", "evaluator used to select hyper-parameters that maximize the validator metric", ) @since("2.0.0") - def getEstimator(self): + def getEstimator(self) -> Estimator: """ Gets the value of estimator or its default value. """ return self.getOrDefault(self.estimator) @since("2.0.0") - def getEstimatorParamMaps(self): + def getEstimatorParamMaps(self) -> List["ParamMap"]: """ Gets the value of estimatorParamMaps or its default value. """ return self.getOrDefault(self.estimatorParamMaps) @since("2.0.0") - def getEvaluator(self): + def getEvaluator(self) -> Evaluator: """ Gets the value of evaluator or its default value. """ return self.getOrDefault(self.evaluator) @classmethod - def _from_java_impl(cls, java_stage): + def _from_java_impl( + cls, java_stage: "JavaObject" + ) -> Tuple[Estimator, List["ParamMap"], Evaluator]: """ Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams. """ # Load information from java_stage to the instance. - estimator = JavaParams._from_java(java_stage.getEstimator()) - evaluator = JavaParams._from_java(java_stage.getEvaluator()) + estimator: Estimator = JavaParams._from_java(java_stage.getEstimator()) + evaluator: Evaluator = JavaParams._from_java(java_stage.getEvaluator()) if isinstance(estimator, JavaEstimator): epms = [ estimator._transfer_param_map_from_java(epm) @@ -224,19 +270,21 @@ def _from_java_impl(cls, java_stage): return estimator, epms, evaluator - def _to_java_impl(self): + def _to_java_impl(self) -> Tuple["JavaObject", "JavaObject", "JavaObject"]: """ Return Java estimator, estimatorParamMaps, and evaluator from this Python instance. 
""" gateway = SparkContext._gateway + assert gateway is not None and SparkContext._jvm is not None + cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap estimator = self.getEstimator() if isinstance(estimator, JavaEstimator): java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps())) for idx, epm in enumerate(self.getEstimatorParamMaps()): - java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm) + java_epms[idx] = estimator._transfer_param_map_to_java(epm) elif MetaAlgorithmReadWrite.isMetaEstimator(estimator): # Meta estimator such as Pipeline, OneVsRest java_epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_to_java( @@ -245,18 +293,24 @@ def _to_java_impl(self): else: raise ValueError("Unsupported estimator used in tuning: " + str(estimator)) - java_estimator = self.getEstimator()._to_java() - java_evaluator = self.getEvaluator()._to_java() + java_estimator = cast(JavaEstimator, self.getEstimator())._to_java() + java_evaluator = cast(JavaEvaluator, self.getEvaluator())._to_java() return java_estimator, java_epms, java_evaluator class _ValidatorSharedReadWrite: @staticmethod - def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps): + def meta_estimator_transfer_param_maps_to_java( + pyEstimator: Estimator, pyParamMaps: Sequence["ParamMap"] + ) -> "JavaArray": pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator) - stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages)) + stagePairs = list(map(lambda stage: (stage, cast(JavaParams, stage)._to_java()), pyStages)) sc = SparkContext._active_spark_context + assert ( + sc is not None and SparkContext._jvm is not None and SparkContext._gateway is not None + ) + paramMapCls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap javaParamMaps = SparkContext._gateway.new_array(paramMapCls, len(pyParamMaps)) @@ -271,7 +325,7 @@ def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps): if javaParam is None: raise ValueError("Resolve param in estimatorParamMaps failed: " + str(pyParam)) if isinstance(pyValue, Params) and hasattr(pyValue, "_to_java"): - javaValue = pyValue._to_java() + javaValue = cast(JavaParams, pyValue)._to_java() else: javaValue = _py2java(sc, pyValue) pair = javaParam.w(javaValue) @@ -280,10 +334,15 @@ def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps): return javaParamMaps @staticmethod - def meta_estimator_transfer_param_maps_from_java(pyEstimator, javaParamMaps): + def meta_estimator_transfer_param_maps_from_java( + pyEstimator: Estimator, javaParamMaps: "JavaArray" + ) -> List["ParamMap"]: pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator) - stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages)) + stagePairs = list(map(lambda stage: (stage, cast(JavaParams, stage)._to_java()), pyStages)) sc = SparkContext._active_spark_context + + assert sc is not None and sc._jvm is not None + pyParamMaps = [] for javaParamMap in javaParamMaps: pyParamMap = dict() @@ -301,6 +360,7 @@ def meta_estimator_transfer_param_maps_from_java(pyEstimator, javaParamMaps): + javaParam.name() ) javaValue = javaPair.value() + pyValue: Any if sc._jvm.Class.forName( "org.apache.spark.ml.util.DefaultParamsWritable" ).isInstance(javaValue): @@ -312,20 +372,25 @@ def meta_estimator_transfer_param_maps_from_java(pyEstimator, javaParamMaps): return pyParamMaps @staticmethod - def is_java_convertible(instance): + def is_java_convertible(instance: _ValidatorParams) -> bool: allNestedStages = 
MetaAlgorithmReadWrite.getAllNestedStages(instance.getEstimator()) evaluator_convertible = isinstance(instance.getEvaluator(), JavaParams) estimator_convertible = all(map(lambda stage: hasattr(stage, "_to_java"), allNestedStages)) return estimator_convertible and evaluator_convertible @staticmethod - def saveImpl(path, instance, sc, extraMetadata=None): + def saveImpl( + path: str, + instance: _ValidatorParams, + sc: SparkContext, + extraMetadata: Optional[Dict[str, Any]] = None, + ) -> None: numParamsNotJson = 0 jsonEstimatorParamMaps = [] for paramMap in instance.getEstimatorParamMaps(): jsonParamMap = [] for p, v in paramMap.items(): - jsonParam = {"parent": p.parent, "name": p.name} + jsonParam: Dict[str, Any] = {"parent": p.parent, "name": p.name} if ( (isinstance(v, Estimator) and not MetaAlgorithmReadWrite.isMetaEstimator(v)) or isinstance(v, Transformer) @@ -334,7 +399,7 @@ def saveImpl(path, instance, sc, extraMetadata=None): relative_path = f"epm_{p.name}{numParamsNotJson}" param_path = os.path.join(path, relative_path) numParamsNotJson += 1 - v.save(param_path) + cast(MLWritable, v).save(param_path) jsonParam["value"] = relative_path jsonParam["isJson"] = False elif isinstance(v, MLWritable): @@ -355,16 +420,18 @@ def saveImpl(path, instance, sc, extraMetadata=None): DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, jsonParams) evaluatorPath = os.path.join(path, "evaluator") - instance.getEvaluator().save(evaluatorPath) + cast(MLWritable, instance.getEvaluator()).save(evaluatorPath) estimatorPath = os.path.join(path, "estimator") - instance.getEstimator().save(estimatorPath) + cast(MLWritable, instance.getEstimator()).save(estimatorPath) @staticmethod - def load(path, sc, metadata): + def load( + path: str, sc: SparkContext, metadata: Dict[str, Any] + ) -> Tuple[Dict[str, Any], Estimator, Evaluator, List["ParamMap"]]: evaluatorPath = os.path.join(path, "evaluator") - evaluator = DefaultParamsReader.loadParamsInstance(evaluatorPath, sc) + evaluator: Evaluator = DefaultParamsReader.loadParamsInstance(evaluatorPath, sc) estimatorPath = os.path.join(path, "estimator") - estimator = DefaultParamsReader.loadParamsInstance(estimatorPath, sc) + estimator: Estimator = DefaultParamsReader.loadParamsInstance(estimatorPath, sc) uidToParams = MetaAlgorithmReadWrite.getUidMap(estimator) uidToParams[evaluator.uid] = evaluator @@ -389,12 +456,12 @@ def load(path, sc, metadata): return metadata, estimator, evaluator, estimatorParamMaps @staticmethod - def validateParams(instance): + def validateParams(instance: _ValidatorParams) -> None: estiamtor = instance.getEstimator() evaluator = instance.getEvaluator() uidMap = MetaAlgorithmReadWrite.getUidMap(estiamtor) - for elem in [evaluator] + list(uidMap.values()): + for elem in [evaluator] + list(uidMap.values()): # type: ignore[arg-type] if not isinstance(elem, MLWritable): raise ValueError( f"Validator write will fail because it contains {elem.uid} " @@ -412,7 +479,7 @@ def validateParams(instance): raise ValueError(paramErr + repr(param)) @staticmethod - def getValidatorModelWriterPersistSubModelsParam(writer): + def getValidatorModelWriterPersistSubModelsParam(writer: MLWriter) -> bool: if "persistsubmodels" in writer.optionMap: persistSubModelsParam = writer.optionMap["persistsubmodels"].lower() if persistSubModelsParam == "true": @@ -425,10 +492,10 @@ def getValidatorModelWriterPersistSubModelsParam(writer): f"the possible values are True, 'True' or False, 'False'" ) else: - return writer.instance.subModels is not None + 
return writer.instance.subModels is not None # type: ignore[attr-defined] -_save_with_persist_submodels_no_submodels_found_err = ( +_save_with_persist_submodels_no_submodels_found_err: str = ( "When persisting tuning models, you can only set persistSubModels to true if the tuning " "was done with collectSubModels set to true. To save the sub-models, try rerunning fitting " "with collectSubModels set to true." @@ -436,15 +503,15 @@ def getValidatorModelWriterPersistSubModelsParam(writer): @inherit_doc -class CrossValidatorReader(MLReader): - def __init__(self, cls): +class CrossValidatorReader(MLReader["CrossValidator"]): + def __init__(self, cls: Type["CrossValidator"]): super(CrossValidatorReader, self).__init__() self.cls = cls - def load(self, path): + def load(self, path: str) -> "CrossValidator": metadata = DefaultParamsReader.loadMetadata(path, self.sc) if not DefaultParamsReader.isPythonParamsInstance(metadata): - return JavaMLReader(self.cls).load(path) + return JavaMLReader(self.cls).load(path) # type: ignore[arg-type] else: metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load( path, self.sc, metadata @@ -459,32 +526,32 @@ def load(self, path): @inherit_doc class CrossValidatorWriter(MLWriter): - def __init__(self, instance): + def __init__(self, instance: "CrossValidator"): super(CrossValidatorWriter, self).__init__() self.instance = instance - def saveImpl(self, path): + def saveImpl(self, path: str) -> None: _ValidatorSharedReadWrite.validateParams(self.instance) _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc) @inherit_doc -class CrossValidatorModelReader(MLReader): - def __init__(self, cls): +class CrossValidatorModelReader(MLReader["CrossValidatorModel"]): + def __init__(self, cls: Type["CrossValidatorModel"]): super(CrossValidatorModelReader, self).__init__() self.cls = cls - def load(self, path): + def load(self, path: str) -> "CrossValidatorModel": metadata = DefaultParamsReader.loadMetadata(path, self.sc) if not DefaultParamsReader.isPythonParamsInstance(metadata): - return JavaMLReader(self.cls).load(path) + return JavaMLReader(self.cls).load(path) # type: ignore[arg-type] else: metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load( path, self.sc, metadata ) numFolds = metadata["paramMap"]["numFolds"] bestModelPath = os.path.join(path, "bestModel") - bestModel = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc) + bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc) avgMetrics = metadata["avgMetrics"] if "stdMetrics" in metadata: stdMetrics = metadata["stdMetrics"] @@ -506,7 +573,10 @@ def load(self, path): subModels = None cvModel = CrossValidatorModel( - bestModel, avgMetrics=avgMetrics, subModels=subModels, stdMetrics=stdMetrics + bestModel, + avgMetrics=avgMetrics, + subModels=cast(List[List[Model]], subModels), + stdMetrics=stdMetrics, ) cvModel = cvModel._resetUid(metadata["uid"]) cvModel.set(cvModel.estimator, estimator) @@ -520,11 +590,11 @@ def load(self, path): @inherit_doc class CrossValidatorModelWriter(MLWriter): - def __init__(self, instance): + def __init__(self, instance: "CrossValidatorModel"): super(CrossValidatorModelWriter, self).__init__() self.instance = instance - def saveImpl(self, path): + def saveImpl(self, path: str) -> None: _ValidatorSharedReadWrite.validateParams(self.instance) instance = self.instance persistSubModels = _ValidatorSharedReadWrite.getValidatorModelWriterPersistSubModelsParam( @@ -536,7 +606,7 @@ def 
saveImpl(self, path): _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata) bestModelPath = os.path.join(path, "bestModel") - instance.bestModel.save(bestModelPath) + cast(MLWritable, instance.bestModel).save(bestModelPath) if persistSubModels: if instance.subModels is None: raise ValueError(_save_with_persist_submodels_no_submodels_found_err) @@ -545,7 +615,7 @@ def saveImpl(self, path): splitPath = os.path.join(subModelsPath, f"fold{splitIndex}") for paramIndex in range(len(instance.getEstimatorParamMaps())): modelPath = os.path.join(splitPath, f"{paramIndex}") - instance.subModels[splitIndex][paramIndex].save(modelPath) + cast(MLWritable, instance.subModels[splitIndex][paramIndex]).save(modelPath) class _CrossValidatorParams(_ValidatorParams): @@ -555,14 +625,14 @@ class _CrossValidatorParams(_ValidatorParams): .. versionadded:: 3.0.0 """ - numFolds = Param( + numFolds: Param[int] = Param( Params._dummy(), "numFolds", "number of folds for cross validation", typeConverter=TypeConverters.toInt, ) - foldCol = Param( + foldCol: Param[str] = Param( Params._dummy(), "foldCol", "Param for the column name of user " @@ -573,19 +643,19 @@ class _CrossValidatorParams(_ValidatorParams): typeConverter=TypeConverters.toString, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_CrossValidatorParams, self).__init__(*args) self._setDefault(numFolds=3, foldCol="") @since("1.4.0") - def getNumFolds(self): + def getNumFolds(self) -> int: """ Gets the value of numFolds or its default value. """ return self.getOrDefault(self.numFolds) @since("3.1.0") - def getFoldCol(self): + def getFoldCol(self) -> str: """ Gets the value of foldCol or its default value. """ @@ -593,7 +663,12 @@ def getFoldCol(self): class CrossValidator( - Estimator, _CrossValidatorParams, HasParallelism, HasCollectSubModels, MLReadable, MLWritable + Estimator["CrossValidatorModel"], + _CrossValidatorParams, + HasParallelism, + HasCollectSubModels, + MLReadable["CrossValidator"], + MLWritable, ): """ @@ -641,19 +716,21 @@ class CrossValidator( 0.8333... 
""" + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - estimator=None, - estimatorParamMaps=None, - evaluator=None, - numFolds=3, - seed=None, - parallelism=1, - collectSubModels=False, - foldCol="", - ): + estimator: Optional[Estimator] = None, + estimatorParamMaps: Optional[List["ParamMap"]] = None, + evaluator: Optional[Evaluator] = None, + numFolds: int = 3, + seed: Optional[int] = None, + parallelism: int = 1, + collectSubModels: bool = False, + foldCol: str = "", + ) -> None: """ __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ seed=None, parallelism=1, collectSubModels=False, foldCol="") @@ -668,15 +745,15 @@ def __init__( def setParams( self, *, - estimator=None, - estimatorParamMaps=None, - evaluator=None, - numFolds=3, - seed=None, - parallelism=1, - collectSubModels=False, - foldCol="", - ): + estimator: Optional[Estimator] = None, + estimatorParamMaps: Optional[List["ParamMap"]] = None, + evaluator: Optional[Evaluator] = None, + numFolds: int = 3, + seed: Optional[int] = None, + parallelism: int = 1, + collectSubModels: bool = False, + foldCol: str = "", + ) -> "CrossValidator": """ setParams(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ seed=None, parallelism=1, collectSubModels=False, foldCol=""): @@ -686,65 +763,65 @@ def setParams( return self._set(**kwargs) @since("2.0.0") - def setEstimator(self, value): + def setEstimator(self, value: Estimator) -> "CrossValidator": """ Sets the value of :py:attr:`estimator`. """ return self._set(estimator=value) @since("2.0.0") - def setEstimatorParamMaps(self, value): + def setEstimatorParamMaps(self, value: List["ParamMap"]) -> "CrossValidator": """ Sets the value of :py:attr:`estimatorParamMaps`. """ return self._set(estimatorParamMaps=value) @since("2.0.0") - def setEvaluator(self, value): + def setEvaluator(self, value: Evaluator) -> "CrossValidator": """ Sets the value of :py:attr:`evaluator`. """ return self._set(evaluator=value) @since("1.4.0") - def setNumFolds(self, value): + def setNumFolds(self, value: int) -> "CrossValidator": """ Sets the value of :py:attr:`numFolds`. """ return self._set(numFolds=value) @since("3.1.0") - def setFoldCol(self, value): + def setFoldCol(self, value: str) -> "CrossValidator": """ Sets the value of :py:attr:`foldCol`. """ return self._set(foldCol=value) - def setSeed(self, value): + def setSeed(self, value: int) -> "CrossValidator": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) - def setParallelism(self, value): + def setParallelism(self, value: int) -> "CrossValidator": """ Sets the value of :py:attr:`parallelism`. """ return self._set(parallelism=value) - def setCollectSubModels(self, value): + def setCollectSubModels(self, value: bool) -> "CrossValidator": """ Sets the value of :py:attr:`collectSubModels`. 
""" return self._set(collectSubModels=value) @staticmethod - def _gen_avg_and_std_metrics(metrics_all): + def _gen_avg_and_std_metrics(metrics_all: List[List[float]]) -> Tuple[List[float], List[float]]: avg_metrics = np.mean(metrics_all, axis=0) std_metrics = np.std(metrics_all, axis=0) return list(avg_metrics), list(std_metrics) - def _fit(self, dataset): + def _fit(self, dataset: DataFrame) -> "CrossValidatorModel": est = self.getOrDefault(self.estimator) epm = self.getOrDefault(self.estimatorParamMaps) numModels = len(epm) @@ -770,6 +847,7 @@ def _fit(self, dataset): for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics_all[i][j] = metric if collectSubModelsParam: + assert subModels is not None subModels[i][j] = subModel validation.unpersist() @@ -782,9 +860,11 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels, std_metrics)) + return self._copyValues( + CrossValidatorModel(bestModel, metrics, cast(List[List[Model]], subModels), std_metrics) + ) - def _kFold(self, dataset): + def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]: nFolds = self.getOrDefault(self.numFolds) foldCol = self.getOrDefault(self.foldCol) @@ -804,7 +884,7 @@ def _kFold(self, dataset): datasets.append((train, validation)) else: # Use user-specified fold numbers. - def checker(foldNum): + def checker(foldNum: int) -> bool: if foldNum < 0 or foldNum >= nFolds: raise ValueError( "Fold number must be in range [0, %s), but got %s." % (nFolds, foldNum) @@ -825,7 +905,7 @@ def checker(foldNum): return datasets - def copy(self, extra=None): + def copy(self, extra: Optional["ParamMap"] = None) -> "CrossValidator": """ Creates a copy of this instance with a randomly generated uid and some extra params. This copies creates a deep copy of @@ -855,20 +935,20 @@ def copy(self, extra=None): return newCV @since("2.3.0") - def write(self): + def write(self) -> MLWriter: """Returns an MLWriter instance for this ML instance.""" if _ValidatorSharedReadWrite.is_java_convertible(self): - return JavaMLWriter(self) + return JavaMLWriter(self) # type: ignore[arg-type] return CrossValidatorWriter(self) @classmethod @since("2.3.0") - def read(cls): + def read(cls) -> CrossValidatorReader: """Returns an MLReader instance for this class.""" return CrossValidatorReader(cls) @classmethod - def _from_java(cls, java_stage): + def _from_java(cls, java_stage: "JavaObject") -> "CrossValidator": """ Given a Java CrossValidator, create and return a Python wrapper of it. Used for ML persistence. @@ -894,7 +974,7 @@ def _from_java(cls, java_stage): py_stage._resetUid(java_stage.uid()) return py_stage - def _to_java(self): + def _to_java(self) -> "JavaObject": """ Transfer this instance to a Java CrossValidator. Used for ML persistence. @@ -919,7 +999,9 @@ def _to_java(self): return _java_obj -class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable): +class CrossValidatorModel( + Model, _CrossValidatorParams, MLReadable["CrossValidatorModel"], MLWritable +): """ CrossValidatorModel contains the model with the highest average cross-validation metric across folds and uses this model to transform input data. CrossValidatorModel @@ -934,7 +1016,13 @@ class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable): CrossValidator.estimatorParamMaps. 
""" - def __init__(self, bestModel, avgMetrics=None, subModels=None, stdMetrics=None): + def __init__( + self, + bestModel: Model, + avgMetrics: Optional[List[float]] = None, + subModels: Optional[List[List[Model]]] = None, + stdMetrics: Optional[List[float]] = None, + ): super(CrossValidatorModel, self).__init__() #: best model from cross validation self.bestModel = bestModel @@ -947,10 +1035,10 @@ def __init__(self, bestModel, avgMetrics=None, subModels=None, stdMetrics=None): #: CrossValidator.estimatorParamMaps, in the corresponding order. self.stdMetrics = stdMetrics or [] - def _transform(self, dataset): + def _transform(self, dataset: DataFrame) -> DataFrame: return self.bestModel.transform(dataset) - def copy(self, extra=None): + def copy(self, extra: Optional["ParamMap"] = None) -> "CrossValidatorModel": """ Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, @@ -974,6 +1062,7 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) avgMetrics = list(self.avgMetrics) + assert self.subModels is not None subModels = [ [sub_model.copy() for sub_model in fold_sub_models] for fold_sub_models in self.subModels @@ -984,26 +1073,28 @@ def copy(self, extra=None): ) @since("2.3.0") - def write(self): + def write(self) -> MLWriter: """Returns an MLWriter instance for this ML instance.""" if _ValidatorSharedReadWrite.is_java_convertible(self): - return JavaMLWriter(self) + return JavaMLWriter(self) # type: ignore[arg-type] return CrossValidatorModelWriter(self) @classmethod @since("2.3.0") - def read(cls): + def read(cls) -> CrossValidatorModelReader: """Returns an MLReader instance for this class.""" return CrossValidatorModelReader(cls) @classmethod - def _from_java(cls, java_stage): + def _from_java(cls, java_stage: "JavaObject") -> "CrossValidatorModel": """ Given a Java CrossValidatorModel, create and return a Python wrapper of it. Used for ML persistence. """ sc = SparkContext._active_spark_context - bestModel = JavaParams._from_java(java_stage.bestModel()) + assert sc is not None + + bestModel: Model = JavaParams._from_java(java_stage.bestModel()) avgMetrics = _java2py(sc, java_stage.avgMetrics()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) @@ -1028,7 +1119,7 @@ def _from_java(cls, java_stage): py_stage._resetUid(java_stage.uid()) return py_stage - def _to_java(self): + def _to_java(self) -> "JavaObject": """ Transfer this instance to a Java CrossValidatorModel. Used for ML persistence. 
@@ -1039,10 +1130,12 @@ def _to_java(self): """ sc = SparkContext._active_spark_context + assert sc is not None + _java_obj = JavaParams._new_java_obj( "org.apache.spark.ml.tuning.CrossValidatorModel", self.uid, - self.bestModel._to_java(), + cast(JavaParams, self.bestModel)._to_java(), _py2java(sc, self.avgMetrics), ) estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() @@ -1062,7 +1155,7 @@ def _to_java(self): if self.subModels is not None: java_sub_models = [ - [sub_model._to_java() for sub_model in fold_sub_models] + [cast(JavaParams, sub_model)._to_java() for sub_model in fold_sub_models] for fold_sub_models in self.subModels ] _java_obj.setSubModels(java_sub_models) @@ -1070,15 +1163,15 @@ def _to_java(self): @inherit_doc -class TrainValidationSplitReader(MLReader): - def __init__(self, cls): +class TrainValidationSplitReader(MLReader["TrainValidationSplit"]): + def __init__(self, cls: Type["TrainValidationSplit"]): super(TrainValidationSplitReader, self).__init__() self.cls = cls - def load(self, path): + def load(self, path: str) -> "TrainValidationSplit": metadata = DefaultParamsReader.loadMetadata(path, self.sc) if not DefaultParamsReader.isPythonParamsInstance(metadata): - return JavaMLReader(self.cls).load(path) + return JavaMLReader(self.cls).load(path) # type: ignore[arg-type] else: metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load( path, self.sc, metadata @@ -1093,31 +1186,31 @@ def load(self, path): @inherit_doc class TrainValidationSplitWriter(MLWriter): - def __init__(self, instance): + def __init__(self, instance: "TrainValidationSplit"): super(TrainValidationSplitWriter, self).__init__() self.instance = instance - def saveImpl(self, path): + def saveImpl(self, path: str) -> None: _ValidatorSharedReadWrite.validateParams(self.instance) _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc) @inherit_doc -class TrainValidationSplitModelReader(MLReader): - def __init__(self, cls): +class TrainValidationSplitModelReader(MLReader["TrainValidationSplitModel"]): + def __init__(self, cls: Type["TrainValidationSplitModel"]): super(TrainValidationSplitModelReader, self).__init__() self.cls = cls - def load(self, path): + def load(self, path: str) -> "TrainValidationSplitModel": metadata = DefaultParamsReader.loadMetadata(path, self.sc) if not DefaultParamsReader.isPythonParamsInstance(metadata): - return JavaMLReader(self.cls).load(path) + return JavaMLReader(self.cls).load(path) # type: ignore[arg-type] else: metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load( path, self.sc, metadata ) bestModelPath = os.path.join(path, "bestModel") - bestModel = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc) + bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc) validationMetrics = metadata["validationMetrics"] persistSubModels = ("persistSubModels" in metadata) and metadata["persistSubModels"] @@ -1132,7 +1225,9 @@ def load(self, path): subModels = None tvsModel = TrainValidationSplitModel( - bestModel, validationMetrics=validationMetrics, subModels=subModels + bestModel, + validationMetrics=validationMetrics, + subModels=cast(Optional[List[Model]], subModels), ) tvsModel = tvsModel._resetUid(metadata["uid"]) tvsModel.set(tvsModel.estimator, estimator) @@ -1146,11 +1241,11 @@ def load(self, path): @inherit_doc class TrainValidationSplitModelWriter(MLWriter): - def __init__(self, instance): + def __init__(self, instance: "TrainValidationSplitModel"): 
super(TrainValidationSplitModelWriter, self).__init__() self.instance = instance - def saveImpl(self, path): + def saveImpl(self, path: str) -> None: _ValidatorSharedReadWrite.validateParams(self.instance) instance = self.instance persistSubModels = _ValidatorSharedReadWrite.getValidatorModelWriterPersistSubModelsParam( @@ -1163,14 +1258,14 @@ def saveImpl(self, path): } _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata) bestModelPath = os.path.join(path, "bestModel") - instance.bestModel.save(bestModelPath) + cast(MLWritable, instance.bestModel).save(bestModelPath) if persistSubModels: if instance.subModels is None: raise ValueError(_save_with_persist_submodels_no_submodels_found_err) subModelsPath = os.path.join(path, "subModels") for paramIndex in range(len(instance.getEstimatorParamMaps())): modelPath = os.path.join(subModelsPath, f"{paramIndex}") - instance.subModels[paramIndex].save(modelPath) + cast(MLWritable, instance.subModels[paramIndex]).save(modelPath) class _TrainValidationSplitParams(_ValidatorParams): @@ -1180,7 +1275,7 @@ class _TrainValidationSplitParams(_ValidatorParams): .. versionadded:: 3.0.0 """ - trainRatio = Param( + trainRatio: Param[float] = Param( Params._dummy(), "trainRatio", "Param for ratio between train and\ @@ -1188,12 +1283,12 @@ class _TrainValidationSplitParams(_ValidatorParams): typeConverter=TypeConverters.toFloat, ) - def __init__(self, *args): + def __init__(self, *args: Any): super(_TrainValidationSplitParams, self).__init__(*args) self._setDefault(trainRatio=0.75) @since("2.0.0") - def getTrainRatio(self): + def getTrainRatio(self) -> float: """ Gets the value of trainRatio or its default value. """ @@ -1201,11 +1296,11 @@ def getTrainRatio(self): class TrainValidationSplit( - Estimator, + Estimator["TrainValidationSplitModel"], _TrainValidationSplitParams, HasParallelism, HasCollectSubModels, - MLReadable, + MLReadable["TrainValidationSplit"], MLWritable, ): """ @@ -1252,18 +1347,20 @@ class TrainValidationSplit( 0.833... """ + _input_kwargs: Dict[str, Any] + @keyword_only def __init__( self, *, - estimator=None, - estimatorParamMaps=None, - evaluator=None, - trainRatio=0.75, - parallelism=1, - collectSubModels=False, - seed=None, - ): + estimator: Optional[Estimator] = None, + estimatorParamMaps: Optional[List["ParamMap"]] = None, + evaluator: Optional[Evaluator] = None, + trainRatio: float = 0.75, + parallelism: int = 1, + collectSubModels: bool = False, + seed: Optional[int] = None, + ) -> None: """ __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \ trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None) @@ -1278,14 +1375,14 @@ def __init__( def setParams( self, *, - estimator=None, - estimatorParamMaps=None, - evaluator=None, - trainRatio=0.75, - parallelism=1, - collectSubModels=False, - seed=None, - ): + estimator: Optional[Estimator] = None, + estimatorParamMaps: Optional[List["ParamMap"]] = None, + evaluator: Optional[Evaluator] = None, + trainRatio: float = 0.75, + parallelism: int = 1, + collectSubModels: bool = False, + seed: Optional[int] = None, + ) -> "TrainValidationSplit": """ setParams(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \ trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None): @@ -1295,52 +1392,52 @@ def setParams( return self._set(**kwargs) @since("2.0.0") - def setEstimator(self, value): + def setEstimator(self, value: Estimator) -> "TrainValidationSplit": """ Sets the value of :py:attr:`estimator`. 
""" return self._set(estimator=value) @since("2.0.0") - def setEstimatorParamMaps(self, value): + def setEstimatorParamMaps(self, value: List["ParamMap"]) -> "TrainValidationSplit": """ Sets the value of :py:attr:`estimatorParamMaps`. """ return self._set(estimatorParamMaps=value) @since("2.0.0") - def setEvaluator(self, value): + def setEvaluator(self, value: Evaluator) -> "TrainValidationSplit": """ Sets the value of :py:attr:`evaluator`. """ return self._set(evaluator=value) @since("2.0.0") - def setTrainRatio(self, value): + def setTrainRatio(self, value: float) -> "TrainValidationSplit": """ Sets the value of :py:attr:`trainRatio`. """ return self._set(trainRatio=value) - def setSeed(self, value): + def setSeed(self, value: int) -> "TrainValidationSplit": """ Sets the value of :py:attr:`seed`. """ return self._set(seed=value) - def setParallelism(self, value): + def setParallelism(self, value: int) -> "TrainValidationSplit": """ Sets the value of :py:attr:`parallelism`. """ return self._set(parallelism=value) - def setCollectSubModels(self, value): + def setCollectSubModels(self, value: bool) -> "TrainValidationSplit": """ Sets the value of :py:attr:`collectSubModels`. """ return self._set(collectSubModels=value) - def _fit(self, dataset): + def _fit(self, dataset: DataFrame) -> "TrainValidationSplitModel": est = self.getOrDefault(self.estimator) epm = self.getOrDefault(self.estimatorParamMaps) numModels = len(epm) @@ -1367,19 +1464,26 @@ def _fit(self, dataset): for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] = metric if collectSubModelsParam: + assert subModels is not None subModels[j] = subModel train.unpersist() validation.unpersist() if eva.isLargerBetter(): - bestIndex = np.argmax(metrics) + bestIndex = np.argmax(cast(List[float], metrics)) else: - bestIndex = np.argmin(metrics) + bestIndex = np.argmin(cast(List[float], metrics)) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels)) + return self._copyValues( + TrainValidationSplitModel( + bestModel, + cast(List[float], metrics), + subModels, # type: ignore[arg-type] + ) + ) - def copy(self, extra=None): + def copy(self, extra: Optional["ParamMap"] = None) -> "TrainValidationSplit": """ Creates a copy of this instance with a randomly generated uid and some extra params. This copies creates a deep copy of @@ -1408,20 +1512,20 @@ def copy(self, extra=None): return newTVS @since("2.3.0") - def write(self): + def write(self) -> MLWriter: """Returns an MLWriter instance for this ML instance.""" if _ValidatorSharedReadWrite.is_java_convertible(self): - return JavaMLWriter(self) + return JavaMLWriter(self) # type: ignore[arg-type] return TrainValidationSplitWriter(self) @classmethod @since("2.3.0") - def read(cls): + def read(cls) -> TrainValidationSplitReader: """Returns an MLReader instance for this class.""" return TrainValidationSplitReader(cls) @classmethod - def _from_java(cls, java_stage): + def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplit": """ Given a Java TrainValidationSplit, create and return a Python wrapper of it. Used for ML persistence. @@ -1445,7 +1549,7 @@ def _from_java(cls, java_stage): py_stage._resetUid(java_stage.uid()) return py_stage - def _to_java(self): + def _to_java(self) -> "JavaObject": """ Transfer this instance to a Java TrainValidationSplit. Used for ML persistence. 
@@ -1470,14 +1574,21 @@ def _to_java(self): return _java_obj -class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable, MLWritable): +class TrainValidationSplitModel( + Model, _TrainValidationSplitParams, MLReadable["TrainValidationSplitModel"], MLWritable +): """ Model from train validation split. .. versionadded:: 2.0.0 """ - def __init__(self, bestModel, validationMetrics=None, subModels=None): + def __init__( + self, + bestModel: Model, + validationMetrics: Optional[List[float]] = None, + subModels: Optional[List[Model]] = None, + ): super(TrainValidationSplitModel, self).__init__() #: best model from train validation split self.bestModel = bestModel @@ -1486,10 +1597,10 @@ def __init__(self, bestModel, validationMetrics=None, subModels=None): #: sub models from train validation split self.subModels = subModels - def _transform(self, dataset): + def _transform(self, dataset: DataFrame) -> DataFrame: return self.bestModel.transform(dataset) - def copy(self, extra=None): + def copy(self, extra: Optional["ParamMap"] = None) -> "TrainValidationSplitModel": """ Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, @@ -1514,26 +1625,27 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) validationMetrics = list(self.validationMetrics) + assert self.subModels is not None subModels = [model.copy() for model in self.subModels] return self._copyValues( TrainValidationSplitModel(bestModel, validationMetrics, subModels), extra=extra ) @since("2.3.0") - def write(self): + def write(self) -> MLWriter: """Returns an MLWriter instance for this ML instance.""" if _ValidatorSharedReadWrite.is_java_convertible(self): - return JavaMLWriter(self) + return JavaMLWriter(self) # type: ignore[arg-type] return TrainValidationSplitModelWriter(self) @classmethod @since("2.3.0") - def read(cls): + def read(cls) -> TrainValidationSplitModelReader: """Returns an MLReader instance for this class.""" return TrainValidationSplitModelReader(cls) @classmethod - def _from_java(cls, java_stage): + def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplitModel": """ Given a Java TrainValidationSplitModel, create and return a Python wrapper of it. Used for ML persistence. @@ -1541,7 +1653,9 @@ def _from_java(cls, java_stage): # Load information from java_stage to the instance. sc = SparkContext._active_spark_context - bestModel = JavaParams._from_java(java_stage.bestModel()) + assert sc is not None + + bestModel: Model = JavaParams._from_java(java_stage.bestModel()) validationMetrics = _java2py(sc, java_stage.validationMetrics()) estimator, epms, evaluator = super(TrainValidationSplitModel, cls)._from_java_impl( java_stage @@ -1566,7 +1680,7 @@ def _from_java(cls, java_stage): py_stage._resetUid(java_stage.uid()) return py_stage - def _to_java(self): + def _to_java(self) -> "JavaObject": """ Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence. 
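Persistence round trip for the model above, to show what the MLReadable["TrainValidationSplitModel"] parameterisation buys type checkers (the temporary path is illustrative):

    import tempfile

    from pyspark.ml.tuning import TrainValidationSplitModel

    path = tempfile.mkdtemp() + "/tvs_model"
    tvs_model.write().overwrite().save(path)        # write() -> MLWriter

    # load() now resolves to the concrete class, so `loaded` is typed as
    # TrainValidationSplitModel rather than a bare MLReadable.
    loaded = TrainValidationSplitModel.load(path)
    assert isinstance(loaded, TrainValidationSplitModel)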
@@ -1577,10 +1691,12 @@ def _to_java(self): """ sc = SparkContext._active_spark_context + assert sc is not None + _java_obj = JavaParams._new_java_obj( "org.apache.spark.ml.tuning.TrainValidationSplitModel", self.uid, - self.bestModel._to_java(), + cast(JavaParams, self.bestModel)._to_java(), _py2java(sc, self.validationMetrics), ) estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl() @@ -1598,7 +1714,9 @@ def _to_java(self): _java_obj.set(pair) if self.subModels is not None: - java_sub_models = [sub_model._to_java() for sub_model in self.subModels] + java_sub_models = [ + cast(JavaParams, sub_model)._to_java() for sub_model in self.subModels + ] _java_obj.setSubModels(java_sub_models) return _java_obj diff --git a/python/pyspark/ml/tuning.pyi b/python/pyspark/ml/tuning.pyi deleted file mode 100644 index 75da80bec83c6..0000000000000 --- a/python/pyspark/ml/tuning.pyi +++ /dev/null @@ -1,223 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import overload -from typing import Any, List, Optional, Tuple, Type -from pyspark.ml._typing import ParamMap - -from pyspark.ml import Estimator, Model -from pyspark.ml.evaluation import Evaluator -from pyspark.ml.param import Param -from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed -from pyspark.ml.util import MLReader, MLReadable, MLWriter, MLWritable - -class ParamGridBuilder: - def __init__(self) -> None: ... - def addGrid(self, param: Param, values: List[Any]) -> ParamGridBuilder: ... - @overload - def baseOn(self, __args: ParamMap) -> ParamGridBuilder: ... - @overload - def baseOn(self, *args: Tuple[Param, Any]) -> ParamGridBuilder: ... - def build(self) -> List[ParamMap]: ... - -class _ValidatorParams(HasSeed): - estimator: Param[Estimator] - estimatorParamMaps: Param[List[ParamMap]] - evaluator: Param[Evaluator] - def getEstimator(self) -> Estimator: ... - def getEstimatorParamMaps(self) -> List[ParamMap]: ... - def getEvaluator(self) -> Evaluator: ... - -class _CrossValidatorParams(_ValidatorParams): - numFolds: Param[int] - foldCol: Param[str] - def __init__(self, *args: Any): ... - def getNumFolds(self) -> int: ... - def getFoldCol(self) -> str: ... - -class CrossValidator( - Estimator[CrossValidatorModel], - _CrossValidatorParams, - HasParallelism, - HasCollectSubModels, - MLReadable[CrossValidator], - MLWritable, -): - def __init__( - self, - *, - estimator: Optional[Estimator] = ..., - estimatorParamMaps: Optional[List[ParamMap]] = ..., - evaluator: Optional[Evaluator] = ..., - numFolds: int = ..., - seed: Optional[int] = ..., - parallelism: int = ..., - collectSubModels: bool = ..., - foldCol: str = ..., - ) -> None: ... 
- def setParams( - self, - *, - estimator: Optional[Estimator] = ..., - estimatorParamMaps: Optional[List[ParamMap]] = ..., - evaluator: Optional[Evaluator] = ..., - numFolds: int = ..., - seed: Optional[int] = ..., - parallelism: int = ..., - collectSubModels: bool = ..., - foldCol: str = ..., - ) -> CrossValidator: ... - def setEstimator(self, value: Estimator) -> CrossValidator: ... - def setEstimatorParamMaps(self, value: List[ParamMap]) -> CrossValidator: ... - def setEvaluator(self, value: Evaluator) -> CrossValidator: ... - def setNumFolds(self, value: int) -> CrossValidator: ... - def setFoldCol(self, value: str) -> CrossValidator: ... - def setSeed(self, value: int) -> CrossValidator: ... - def setParallelism(self, value: int) -> CrossValidator: ... - def setCollectSubModels(self, value: bool) -> CrossValidator: ... - def copy(self, extra: Optional[ParamMap] = ...) -> CrossValidator: ... - def write(self) -> MLWriter: ... - @classmethod - def read(cls: Type[CrossValidator]) -> MLReader: ... - -class CrossValidatorModel( - Model, _CrossValidatorParams, MLReadable[CrossValidatorModel], MLWritable -): - bestModel: Model - avgMetrics: List[float] - subModels: List[List[Model]] - def __init__( - self, - bestModel: Model, - avgMetrics: Optional[List[float]] = ..., - subModels: Optional[List[List[Model]]] = ..., - ) -> None: ... - def copy(self, extra: Optional[ParamMap] = ...) -> CrossValidatorModel: ... - def write(self) -> MLWriter: ... - @classmethod - def read(cls: Type[CrossValidatorModel]) -> MLReader: ... - -class _TrainValidationSplitParams(_ValidatorParams): - trainRatio: Param[float] - def __init__(self, *args: Any): ... - def getTrainRatio(self) -> float: ... - -class TrainValidationSplit( - Estimator[TrainValidationSplitModel], - _TrainValidationSplitParams, - HasParallelism, - HasCollectSubModels, - MLReadable[TrainValidationSplit], - MLWritable, -): - def __init__( - self, - *, - estimator: Optional[Estimator] = ..., - estimatorParamMaps: Optional[List[ParamMap]] = ..., - evaluator: Optional[Evaluator] = ..., - trainRatio: float = ..., - parallelism: int = ..., - collectSubModels: bool = ..., - seed: Optional[int] = ..., - ) -> None: ... - def setParams( - self, - *, - estimator: Optional[Estimator] = ..., - estimatorParamMaps: Optional[List[ParamMap]] = ..., - evaluator: Optional[Evaluator] = ..., - trainRatio: float = ..., - parallelism: int = ..., - collectSubModels: bool = ..., - seed: Optional[int] = ..., - ) -> TrainValidationSplit: ... - def setEstimator(self, value: Estimator) -> TrainValidationSplit: ... - def setEstimatorParamMaps(self, value: List[ParamMap]) -> TrainValidationSplit: ... - def setEvaluator(self, value: Evaluator) -> TrainValidationSplit: ... - def setTrainRatio(self, value: float) -> TrainValidationSplit: ... - def setSeed(self, value: int) -> TrainValidationSplit: ... - def setParallelism(self, value: int) -> TrainValidationSplit: ... - def setCollectSubModels(self, value: bool) -> TrainValidationSplit: ... - def copy(self, extra: Optional[ParamMap] = ...) -> TrainValidationSplit: ... - def write(self) -> MLWriter: ... - @classmethod - def read(cls: Type[TrainValidationSplit]) -> MLReader: ... 
- -class TrainValidationSplitModel( - Model, - _TrainValidationSplitParams, - MLReadable[TrainValidationSplitModel], - MLWritable, -): - bestModel: Model - validationMetrics: List[float] - subModels: List[Model] - def __init__( - self, - bestModel: Model, - validationMetrics: Optional[List[float]] = ..., - subModels: Optional[List[Model]] = ..., - ) -> None: ... - def setEstimator(self, value: Estimator) -> TrainValidationSplitModel: ... - def setEstimatorParamMaps(self, value: List[ParamMap]) -> TrainValidationSplitModel: ... - def setEvaluator(self, value: Evaluator) -> TrainValidationSplitModel: ... - def copy(self, extra: Optional[ParamMap] = ...) -> TrainValidationSplitModel: ... - def write(self) -> MLWriter: ... - @classmethod - def read(cls: Type[TrainValidationSplitModel]) -> MLReader: ... - -class CrossValidatorWriter(MLWriter): - instance: CrossValidator - def __init__(self, instance: CrossValidator) -> None: ... - def saveImpl(self, path: str) -> None: ... - -class CrossValidatorReader(MLReader[CrossValidator]): - cls: Type[CrossValidator] - def __init__(self, cls: Type[CrossValidator]) -> None: ... - def load(self, path: str) -> CrossValidator: ... - -class CrossValidatorModelWriter(MLWriter): - instance: CrossValidatorModel - def __init__(self, instance: CrossValidatorModel) -> None: ... - def saveImpl(self, path: str) -> None: ... - -class CrossValidatorModelReader(MLReader[CrossValidatorModel]): - cls: Type[CrossValidatorModel] - def __init__(self, cls: Type[CrossValidatorModel]) -> None: ... - def load(self, path: str) -> CrossValidatorModel: ... - -class TrainValidationSplitWriter(MLWriter): - instance: TrainValidationSplit - def __init__(self, instance: TrainValidationSplit) -> None: ... - def saveImpl(self, path: str) -> None: ... - -class TrainValidationSplitReader(MLReader[TrainValidationSplit]): - cls: Type[TrainValidationSplit] - def __init__(self, cls: Type[TrainValidationSplit]) -> None: ... - def load(self, path: str) -> TrainValidationSplit: ... - -class TrainValidationSplitModelWriter(MLWriter): - instance: TrainValidationSplitModel - def __init__(self, instance: TrainValidationSplitModel) -> None: ... - def saveImpl(self, path: str) -> None: ... - -class TrainValidationSplitModelReader(MLReader[TrainValidationSplitModel]): - cls: Type[TrainValidationSplitModel] - def __init__(self, cls: Type[TrainValidationSplitModel]) -> None: ... - def load(self, path: str) -> TrainValidationSplitModel: ... diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index ac60deda53c46..14e62ce6217c8 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -20,13 +20,29 @@ import time import uuid +from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast, TYPE_CHECKING + + from pyspark import SparkContext, since from pyspark.ml.common import inherit_doc from pyspark.sql import SparkSession from pyspark.util import VersionUtils +if TYPE_CHECKING: + from py4j.java_gateway import JavaGateway, JavaObject + from pyspark.ml._typing import PipelineStage + from pyspark.ml.base import Params + from pyspark.ml.wrapper import JavaWrapper + +T = TypeVar("T") +RW = TypeVar("RW", bound="BaseReadWrite") +W = TypeVar("W", bound="MLWriter") +JW = TypeVar("JW", bound="JavaMLWriter") +RL = TypeVar("RL", bound="MLReadable") +JR = TypeVar("JR", bound="JavaMLReader") + -def _jvm(): +def _jvm() -> "JavaGateway": """ Returns the JVM view associated with SparkContext. Must be called after SparkContext is initialized. 
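The TypeVars added here (in particular RL, bound to MLReadable) are what let load() return the concrete subclass. A sketch of a user-defined stage picking that up through the generic mixins typed later in this file; NoopTransformer is a hypothetical example, not part of the patch:

    from pyspark.ml import Transformer
    from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
    from pyspark.sql import DataFrame

    class NoopTransformer(
        Transformer, DefaultParamsReadable["NoopTransformer"], DefaultParamsWritable
    ):
        """Toy pass-through stage, used only to exercise the typed persistence mixins."""

        def _transform(self, dataset: DataFrame) -> DataFrame:
            return dataset

    # NoopTransformer.read() is a DefaultParamsReader["NoopTransformer"], so
    # NoopTransformer.load(path) is inferred as NoopTransformer, not MLReadable.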
@@ -43,15 +59,15 @@ class Identifiable: Object with a unique ID. """ - def __init__(self): + def __init__(self) -> None: #: A unique id for the object. self.uid = self._randomUID() - def __repr__(self): + def __repr__(self) -> str: return self.uid @classmethod - def _randomUID(cls): + def _randomUID(cls) -> str: """ Generate a unique string id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. @@ -68,10 +84,10 @@ class BaseReadWrite: .. versionadded:: 2.3.0 """ - def __init__(self): - self._sparkSession = None + def __init__(self) -> None: + self._sparkSession: Optional[SparkSession] = None - def session(self, sparkSession): + def session(self: RW, sparkSession: SparkSession) -> RW: """ Sets the Spark Session to use for saving/loading. """ @@ -79,19 +95,21 @@ def session(self, sparkSession): return self @property - def sparkSession(self): + def sparkSession(self) -> SparkSession: """ Returns the user-specified Spark Session or the default. """ if self._sparkSession is None: self._sparkSession = SparkSession._getActiveSessionOrCreate() + assert self._sparkSession is not None return self._sparkSession @property - def sc(self): + def sc(self) -> SparkContext: """ Returns the underlying `SparkContext`. """ + assert self.sparkSession is not None return self.sparkSession.sparkContext @@ -103,37 +121,37 @@ class MLWriter(BaseReadWrite): .. versionadded:: 2.0.0 """ - def __init__(self): + def __init__(self) -> None: super(MLWriter, self).__init__() - self.shouldOverwrite = False - self.optionMap = {} + self.shouldOverwrite: bool = False + self.optionMap: Dict[str, Any] = {} - def _handleOverwrite(self, path): + def _handleOverwrite(self, path: str) -> None: from pyspark.ml.wrapper import JavaWrapper _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite") wrapper = JavaWrapper(_java_obj) wrapper._call_java("handleOverwrite", path, True, self.sparkSession._jsparkSession) - def save(self, path): + def save(self, path: str) -> None: """Save the ML instance to the input path.""" if self.shouldOverwrite: self._handleOverwrite(path) self.saveImpl(path) - def saveImpl(self, path): + def saveImpl(self, path: str) -> None: """ save() handles overwriting and then calls this method. Subclasses should override this method to implement the actual saving of the instance. """ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) - def overwrite(self): + def overwrite(self) -> "MLWriter": """Overwrites if the output path already exists.""" self.shouldOverwrite = True return self - def option(self, key, value): + def option(self, key: str, value: Any) -> "MLWriter": """ Adds an option to the underlying MLWriter. See the documentation for the specific model's writer for possible options. The option name (key) is case-insensitive. @@ -150,7 +168,7 @@ class GeneralMLWriter(MLWriter): .. versionadded:: 2.4.0 """ - def format(self, source): + def format(self, source: str) -> "GeneralMLWriter": """ Specifies the format of ML export ("pmml", "internal", or the fully qualified class name for export). 
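The writer chain keeps its fluent shape under the new annotations: overwrite() and option() both advertise MLWriter returns, so existing call patterns still type-check. Illustration only, with cv_model from the earlier sketch and placeholder /tmp paths; "persistSubModels" is the key the validator model writers earlier in this diff consult (see TrainValidationSplitModelWriter.saveImpl):

    cv_model.write().overwrite().save("/tmp/cv_model")
    cv_model.write().option("persistSubModels", "false").overwrite().save("/tmp/cv_model_slim")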
@@ -165,27 +183,29 @@ class JavaMLWriter(MLWriter): (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types """ - def __init__(self, instance): + _jwrite: "JavaObject" + + def __init__(self, instance: "JavaMLWritable"): super(JavaMLWriter, self).__init__() - _java_obj = instance._to_java() + _java_obj = instance._to_java() # type: ignore[attr-defined] self._jwrite = _java_obj.write() - def save(self, path): + def save(self, path: str) -> None: """Save the ML instance to the input path.""" if not isinstance(path, str): raise TypeError("path should be a string, got type %s" % type(path)) self._jwrite.save(path) - def overwrite(self): + def overwrite(self) -> "JavaMLWriter": """Overwrites if the output path already exists.""" self._jwrite.overwrite() return self - def option(self, key, value): + def option(self, key: str, value: str) -> "JavaMLWriter": self._jwrite.option(key, value) return self - def session(self, sparkSession): + def session(self, sparkSession: SparkSession) -> "JavaMLWriter": """Sets the Spark Session to use for saving.""" self._jwrite.session(sparkSession._jsparkSession) return self @@ -197,10 +217,10 @@ class GeneralJavaMLWriter(JavaMLWriter): (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types """ - def __init__(self, instance): + def __init__(self, instance: "JavaMLWritable"): super(GeneralJavaMLWriter, self).__init__(instance) - def format(self, source): + def format(self, source: str) -> "GeneralJavaMLWriter": """ Specifies the format of ML export ("pmml", "internal", or the fully qualified class name for export). @@ -217,11 +237,11 @@ class MLWritable: .. versionadded:: 2.0.0 """ - def write(self): + def write(self) -> MLWriter: """Returns an MLWriter instance for this ML instance.""" raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self)) - def save(self, path): + def save(self, path: str) -> None: """Save this ML instance to the given path, a shortcut of 'write().save(path)'.""" self.write().save(path) @@ -232,7 +252,7 @@ class JavaMLWritable(MLWritable): (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`. """ - def write(self): + def write(self) -> JavaMLWriter: """Returns an MLWriter instance for this ML instance.""" return JavaMLWriter(self) @@ -243,39 +263,39 @@ class GeneralJavaMLWritable(JavaMLWritable): (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`. """ - def write(self): + def write(self) -> GeneralJavaMLWriter: """Returns an GeneralMLWriter instance for this ML instance.""" return GeneralJavaMLWriter(self) @inherit_doc -class MLReader(BaseReadWrite): +class MLReader(BaseReadWrite, Generic[RL]): """ Utility class that can load ML instances. .. 
versionadded:: 2.0.0 """ - def __init__(self): + def __init__(self) -> None: super(MLReader, self).__init__() - def load(self, path): + def load(self, path: str) -> RL: """Load the ML instance from the input path.""" raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) @inherit_doc -class JavaMLReader(MLReader): +class JavaMLReader(MLReader[RL]): """ (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types """ - def __init__(self, clazz): + def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None: super(JavaMLReader, self).__init__() self._clazz = clazz self._jread = self._load_java_obj(clazz).read() - def load(self, path): + def load(self, path: str) -> RL: """Load the ML instance from the input path.""" if not isinstance(path, str): raise TypeError("path should be a string, got type %s" % type(path)) @@ -284,15 +304,15 @@ def load(self, path): raise NotImplementedError( "This Java ML type cannot be loaded into Python currently: %r" % self._clazz ) - return self._clazz._from_java(java_obj) + return self._clazz._from_java(java_obj) # type: ignore[attr-defined] - def session(self, sparkSession): + def session(self: JR, sparkSession: SparkSession) -> JR: """Sets the Spark Session to use for loading.""" self._jread.session(sparkSession._jsparkSession) return self @classmethod - def _java_loader_class(cls, clazz): + def _java_loader_class(cls, clazz: Type["JavaMLReadable[RL]"]) -> str: """ Returns the full class name of the Java ML instance. The default implementation replaces "pyspark" by "org.apache.spark" in @@ -305,7 +325,7 @@ def _java_loader_class(cls, clazz): return java_package + "." + clazz.__name__ @classmethod - def _load_java_obj(cls, clazz): + def _load_java_obj(cls, clazz: Type["JavaMLReadable[RL]"]) -> "JavaObject": """Load the peer Java object of the ML instance.""" java_class = cls._java_loader_class(clazz) java_obj = _jvm() @@ -315,7 +335,7 @@ def _load_java_obj(cls, clazz): @inherit_doc -class MLReadable: +class MLReadable(Generic[RL]): """ Mixin for instances that provide :py:class:`MLReader`. @@ -323,24 +343,24 @@ class MLReadable: """ @classmethod - def read(cls): + def read(cls) -> MLReader[RL]: """Returns an MLReader instance for this class.""" raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls) @classmethod - def load(cls, path): + def load(cls, path: str) -> RL: """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" return cls.read().load(path) @inherit_doc -class JavaMLReadable(MLReadable): +class JavaMLReadable(MLReadable[RL]): """ (Private) Mixin for instances that provide JavaMLReader. """ @classmethod - def read(cls): + def read(cls) -> JavaMLReader[RL]: """Returns an MLReader instance for this class.""" return JavaMLReader(cls) @@ -358,7 +378,7 @@ class stores all data as :py:class:`Param` values, then extending this trait wil .. versionadded:: 2.3.0 """ - def write(self): + def write(self) -> MLWriter: """Returns a DefaultParamsWriter instance for this class.""" from pyspark.ml.param import Params @@ -382,15 +402,15 @@ class DefaultParamsWriter(MLWriter): .. 
versionadded:: 2.3.0 """ - def __init__(self, instance): + def __init__(self, instance: "Params"): super(DefaultParamsWriter, self).__init__() self.instance = instance - def saveImpl(self, path): + def saveImpl(self, path: str) -> None: DefaultParamsWriter.saveMetadata(self.instance, path, self.sc) @staticmethod - def extractJsonParams(instance, skipParams): + def extractJsonParams(instance: "Params", skipParams: Sequence[str]) -> Dict[str, Any]: paramMap = instance.extractParamMap() jsonParams = { param.name: value for param, value in paramMap.items() if param.name not in skipParams @@ -398,7 +418,13 @@ def extractJsonParams(instance, skipParams): return jsonParams @staticmethod - def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None): + def saveMetadata( + instance: "Params", + path: str, + sc: SparkContext, + extraMetadata: Optional[Dict[str, Any]] = None, + paramMap: Optional[Dict[str, Any]] = None, + ) -> None: """ Saves metadata + Params to: path + "/metadata" @@ -424,7 +450,12 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None): sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath) @staticmethod - def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): + def _get_metadata_to_save( + instance: "Params", + sc: SparkContext, + extraMetadata: Optional[Dict[str, Any]] = None, + paramMap: Optional[Dict[str, Any]] = None, + ) -> str: """ Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts the JSON to save. This is useful for ensemble models which need to save metadata for many sub-models. @@ -460,11 +491,11 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): } if extraMetadata is not None: basicMetadata.update(extraMetadata) - return json.dumps(basicMetadata, separators=[",", ":"]) + return json.dumps(basicMetadata, separators=(",", ":")) @inherit_doc -class DefaultParamsReadable(MLReadable): +class DefaultParamsReadable(MLReadable[RL]): """ Helper trait for making simple :py:class:`Params` types readable. If a :py:class:`Params` class stores all data as :py:class:`Param` values, @@ -477,13 +508,13 @@ class DefaultParamsReadable(MLReadable): """ @classmethod - def read(cls): + def read(cls) -> "DefaultParamsReader[RL]": """Returns a DefaultParamsReader instance for this class.""" return DefaultParamsReader(cls) @inherit_doc -class DefaultParamsReader(MLReader): +class DefaultParamsReader(MLReader[RL]): """ Specialization of :py:class:`MLReader` for :py:class:`Params` types @@ -494,12 +525,12 @@ class DefaultParamsReader(MLReader): .. versionadded:: 2.3.0 """ - def __init__(self, cls): + def __init__(self, cls: Type[DefaultParamsReadable[RL]]): super(DefaultParamsReader, self).__init__() self.cls = cls @staticmethod - def __get_class(clazz): + def __get_class(clazz: str) -> Type[RL]: """ Loads Python class from its name. 
""" @@ -510,16 +541,16 @@ def __get_class(clazz): m = getattr(m, comp) return m - def load(self, path): + def load(self, path: str) -> RL: metadata = DefaultParamsReader.loadMetadata(path, self.sc) - py_type = DefaultParamsReader.__get_class(metadata["class"]) + py_type: Type[RL] = DefaultParamsReader.__get_class(metadata["class"]) instance = py_type() - instance._resetUid(metadata["uid"]) + cast("Params", instance)._resetUid(metadata["uid"]) DefaultParamsReader.getAndSetParams(instance, metadata) return instance @staticmethod - def loadMetadata(path, sc, expectedClassName=""): + def loadMetadata(path: str, sc: SparkContext, expectedClassName: str = "") -> Dict[str, Any]: """ Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata` @@ -536,7 +567,7 @@ def loadMetadata(path, sc, expectedClassName=""): return loadedVals @staticmethod - def _parseMetaData(metadataStr, expectedClassName=""): + def _parseMetaData(metadataStr: str, expectedClassName: str = "") -> Dict[str, Any]: """ Parse metadata JSON string produced by :py:meth`DefaultParamsWriter._get_metadata_to_save`. This is a helper function for :py:meth:`DefaultParamsReader.loadMetadata`. @@ -558,16 +589,18 @@ def _parseMetaData(metadataStr, expectedClassName=""): return metadata @staticmethod - def getAndSetParams(instance, metadata, skipParams=None): + def getAndSetParams( + instance: RL, metadata: Dict[str, Any], skipParams: Optional[List[str]] = None + ) -> None: """ Extract Params from metadata, and set them in the instance. """ # Set user-supplied param values for paramName in metadata["paramMap"]: - param = instance.getParam(paramName) + param = cast("Params", instance).getParam(paramName) if skipParams is None or paramName not in skipParams: paramValue = metadata["paramMap"][paramName] - instance.set(param, paramValue) + cast("Params", instance).set(param, paramValue) # Set default param values majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata["sparkVersion"]) @@ -582,14 +615,14 @@ def getAndSetParams(instance, metadata, skipParams=None): for paramName in metadata["defaultParamMap"]: paramValue = metadata["defaultParamMap"][paramName] - instance._setDefault(**{paramName: paramValue}) + cast("Params", instance)._setDefault(**{paramName: paramValue}) @staticmethod - def isPythonParamsInstance(metadata): + def isPythonParamsInstance(metadata: Dict[str, Any]) -> bool: return metadata["class"].startswith("pyspark.ml.") @staticmethod - def loadParamsInstance(path, sc): + def loadParamsInstance(path: str, sc: SparkContext) -> RL: """ Load a :py:class:`Params` instance from the given path, and return it. This assumes the instance inherits from :py:class:`MLReadable`. @@ -599,41 +632,41 @@ def loadParamsInstance(path, sc): pythonClassName = metadata["class"] else: pythonClassName = metadata["class"].replace("org.apache.spark", "pyspark") - py_type = DefaultParamsReader.__get_class(pythonClassName) + py_type: Type[RL] = DefaultParamsReader.__get_class(pythonClassName) instance = py_type.load(path) return instance @inherit_doc -class HasTrainingSummary: +class HasTrainingSummary(Generic[T]): """ Base class for models that provides Training summary. .. versionadded:: 3.0.0 """ - @property + @property # type: ignore[misc] @since("2.1.0") - def hasSummary(self): + def hasSummary(self) -> bool: """ Indicates whether a training summary exists for this model instance. 
""" - return self._call_java("hasSummary") + return cast("JavaWrapper", self)._call_java("hasSummary") - @property + @property # type: ignore[misc] @since("2.1.0") - def summary(self): + def summary(self) -> T: """ Gets summary of the model trained on the training set. An exception is thrown if no summary exists. """ - return self._call_java("summary") + return cast("JavaWrapper", self)._call_java("summary") class MetaAlgorithmReadWrite: @staticmethod - def isMetaEstimator(pyInstance): + def isMetaEstimator(pyInstance: Any) -> bool: from pyspark.ml import Estimator, Pipeline from pyspark.ml.tuning import _ValidatorParams from pyspark.ml.classification import OneVsRest @@ -645,23 +678,27 @@ def isMetaEstimator(pyInstance): ) @staticmethod - def getAllNestedStages(pyInstance): + def getAllNestedStages(pyInstance: Any) -> List["Params"]: from pyspark.ml import Pipeline, PipelineModel from pyspark.ml.tuning import _ValidatorParams from pyspark.ml.classification import OneVsRest, OneVsRestModel # TODO: We need to handle `RFormulaModel.pipelineModel` here after Pyspark RFormulaModel # support pipelineModel property. + pySubStages: Sequence["Params"] + if isinstance(pyInstance, Pipeline): pySubStages = pyInstance.getStages() elif isinstance(pyInstance, PipelineModel): - pySubStages = pyInstance.stages + pySubStages = cast(List["PipelineStage"], pyInstance.stages) elif isinstance(pyInstance, _ValidatorParams): raise ValueError("PySpark does not support nested validator.") elif isinstance(pyInstance, OneVsRest): pySubStages = [pyInstance.getClassifier()] elif isinstance(pyInstance, OneVsRestModel): - pySubStages = [pyInstance.getClassifier()] + pyInstance.models + pySubStages = [ + pyInstance.getClassifier() + ] + pyInstance.models # type: ignore[assignment, operator] else: pySubStages = [] @@ -672,7 +709,7 @@ def getAllNestedStages(pyInstance): return [pyInstance] + nestedStages @staticmethod - def getUidMap(instance): + def getUidMap(instance: Any) -> Dict[str, "Params"]: nestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance) uidMap = {stage.uid: stage for stage in nestedStages} if len(nestedStages) != len(uidMap): diff --git a/python/pyspark/ml/util.pyi b/python/pyspark/ml/util.pyi deleted file mode 100644 index db28c095a5568..0000000000000 --- a/python/pyspark/ml/util.pyi +++ /dev/null @@ -1,136 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union - -from pyspark import SparkContext as SparkContext, since as since # noqa: F401 -from pyspark.ml.common import inherit_doc as inherit_doc # noqa: F401 -from pyspark.sql import SparkSession as SparkSession -from pyspark.util import VersionUtils as VersionUtils # noqa: F401 - -S = TypeVar("S") -R = TypeVar("R", bound=MLReadable) - -class Identifiable: - uid: str - def __init__(self) -> None: ... - -class BaseReadWrite: - def __init__(self) -> None: ... - def session(self, sparkSession: SparkSession) -> Union[MLWriter, MLReader]: ... - @property - def sparkSession(self) -> SparkSession: ... - @property - def sc(self) -> SparkContext: ... - -class MLWriter(BaseReadWrite): - shouldOverwrite: bool = ... - def __init__(self) -> None: ... - def save(self, path: str) -> None: ... - def saveImpl(self, path: str) -> None: ... - def overwrite(self) -> MLWriter: ... - -class GeneralMLWriter(MLWriter): - source: str - def format(self, source: str) -> MLWriter: ... - -class JavaMLWriter(MLWriter): - def __init__(self, instance: JavaMLWritable) -> None: ... - def save(self, path: str) -> None: ... - def overwrite(self) -> JavaMLWriter: ... - def option(self, key: str, value: Any) -> JavaMLWriter: ... - def session(self, sparkSession: SparkSession) -> JavaMLWriter: ... - -class GeneralJavaMLWriter(JavaMLWriter): - def __init__(self, instance: MLWritable) -> None: ... - def format(self, source: str) -> GeneralJavaMLWriter: ... - -class MLWritable: - def write(self) -> MLWriter: ... - def save(self, path: str) -> None: ... - -class JavaMLWritable(MLWritable): - def write(self) -> JavaMLWriter: ... - -class GeneralJavaMLWritable(JavaMLWritable): - def write(self) -> GeneralJavaMLWriter: ... - -class MLReader(BaseReadWrite, Generic[R]): - def load(self, path: str) -> R: ... - -class JavaMLReader(MLReader[R]): - def __init__(self, clazz: Type[JavaMLReadable]) -> None: ... - def load(self, path: str) -> R: ... - def session(self, sparkSession: SparkSession) -> JavaMLReader[R]: ... - -class MLReadable(Generic[R]): - @classmethod - def read(cls: Type[R]) -> MLReader[R]: ... - @classmethod - def load(cls: Type[R], path: str) -> R: ... - -class JavaMLReadable(MLReadable[R]): - @classmethod - def read(cls: Type[R]) -> JavaMLReader[R]: ... - -class DefaultParamsWritable(MLWritable): - def write(self) -> MLWriter: ... - -class DefaultParamsWriter(MLWriter): - instance: DefaultParamsWritable - def __init__(self, instance: DefaultParamsWritable) -> None: ... - def saveImpl(self, path: str) -> None: ... - @staticmethod - def saveMetadata( - instance: DefaultParamsWritable, - path: str, - sc: SparkContext, - extraMetadata: Optional[Dict[str, Any]] = ..., - paramMap: Optional[Dict[str, Any]] = ..., - ) -> None: ... - -class DefaultParamsReadable(MLReadable[R]): - @classmethod - def read(cls: Type[R]) -> MLReader[R]: ... - -class DefaultParamsReader(MLReader[R]): - cls: Type[R] - def __init__(self, cls: Type[MLReadable]) -> None: ... - def load(self, path: str) -> R: ... - @staticmethod - def loadMetadata( - path: str, sc: SparkContext, expectedClassName: str = ... - ) -> Dict[str, Any]: ... - @staticmethod - def getAndSetParams(instance: R, metadata: Dict[str, Any]) -> None: ... - @staticmethod - def loadParamsInstance(path: str, sc: SparkContext) -> R: ... - -class HasTrainingSummary(Generic[S]): - @property - def hasSummary(self) -> bool: ... - @property - def summary(self) -> S: ... 
- -class MetaAlgorithmReadWrite: - @staticmethod - def isMetaEstimator(pyInstance: Any) -> bool: ... - @staticmethod - def getAllNestedStages(pyInstance: Any) -> list: ... - @staticmethod - def getUidMap(instance: Any) -> dict: ... diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index c35df2e5b6ef1..7853e76624464 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -17,49 +17,68 @@ from abc import ABCMeta, abstractmethod +from typing import Any, Generic, Optional, List, Type, TypeVar, TYPE_CHECKING + from pyspark import since from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, Model from pyspark.ml.base import _PredictorParams -from pyspark.ml.param import Params +from pyspark.ml.param import Param, Params from pyspark.ml.util import _jvm from pyspark.ml.common import inherit_doc, _java2py, _py2java +if TYPE_CHECKING: + from pyspark.ml._typing import ParamMap + from py4j.java_gateway import JavaObject, JavaClass + + +T = TypeVar("T") +JW = TypeVar("JW", bound="JavaWrapper") +JM = TypeVar("JM", bound="JavaTransformer") +JP = TypeVar("JP", bound="JavaParams") + + class JavaWrapper: """ Wrapper class for a Java companion object """ - def __init__(self, java_obj=None): + def __init__(self, java_obj: Optional["JavaObject"] = None): super(JavaWrapper, self).__init__() self._java_obj = java_obj - def __del__(self): + def __del__(self) -> None: if SparkContext._active_spark_context and self._java_obj is not None: - SparkContext._active_spark_context._gateway.detach(self._java_obj) + SparkContext._active_spark_context._gateway.detach( # type: ignore[union-attr] + self._java_obj + ) @classmethod - def _create_from_java_class(cls, java_class, *args): + def _create_from_java_class(cls: Type[JW], java_class: str, *args: Any) -> JW: """ Construct this object from given Java classname and arguments """ java_obj = JavaWrapper._new_java_obj(java_class, *args) return cls(java_obj) - def _call_java(self, name, *args): + def _call_java(self, name: str, *args: Any) -> Any: m = getattr(self._java_obj, name) sc = SparkContext._active_spark_context + assert sc is not None + java_args = [_py2java(sc, arg) for arg in args] return _java2py(sc, m(*java_args)) @staticmethod - def _new_java_obj(java_class, *args): + def _new_java_obj(java_class: str, *args: Any) -> "JavaObject": """ Returns a new Java object. """ sc = SparkContext._active_spark_context + assert sc is not None + java_obj = _jvm() for name in java_class.split("."): java_obj = getattr(java_obj, name) @@ -67,7 +86,7 @@ def _new_java_obj(java_class, *args): return java_obj(*java_args) @staticmethod - def _new_java_array(pylist, java_class): + def _new_java_array(pylist: List[Any], java_class: "JavaClass") -> "JavaObject": """ Create a Java array of given java_class type. Useful for calling a method with a Scala Array from Python with Py4J. @@ -97,6 +116,9 @@ def _new_java_array(pylist, java_class): Java Array of converted pylist. """ sc = SparkContext._active_spark_context + assert sc is not None + assert sc._gateway is not None + java_array = None if len(pylist) > 0 and isinstance(pylist[0], list): # If pylist is a 2D array, then a 2D java array will be created. @@ -125,20 +147,24 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta): #: The param values in the Java object should be #: synced with the Python wrapper in fit/transform/evaluate/copy. 
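A recurring pattern in the wrapper.py and util.py hunks is narrowing Optional handles (SparkContext._active_spark_context, self._java_obj, sc._gateway) with an assert before attribute access, so mypy accepts the code without any behavioural change. A hypothetical helper, shown only to make the idiom explicit; the patch itself keeps the asserts inline:

    from typing import Optional

    from pyspark import SparkContext

    def _require_active_sc() -> SparkContext:
        # SparkContext._active_spark_context is Optional; the assert narrows the
        # type for mypy and fails fast if no context has been created yet.
        sc: Optional[SparkContext] = SparkContext._active_spark_context
        assert sc is not None, "SparkContext must be initialized before JVM access"
        return sc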
- def _make_java_param_pair(self, param, value): + def _make_java_param_pair(self, param: Param[T], value: T) -> "JavaObject": """ Makes a Java param pair. """ sc = SparkContext._active_spark_context + assert sc is not None and self._java_obj is not None + param = self._resolveParam(param) java_param = self._java_obj.getParam(param.name) java_value = _py2java(sc, value) return java_param.w(java_value) - def _transfer_params_to_java(self): + def _transfer_params_to_java(self) -> None: """ Transforms the embedded params to the companion Java object. """ + assert self._java_obj is not None + pair_defaults = [] for param in self.params: if self.isSet(param): @@ -149,10 +175,12 @@ def _transfer_params_to_java(self): pair_defaults.append(pair) if len(pair_defaults) > 0: sc = SparkContext._active_spark_context + assert sc is not None and sc._jvm is not None + pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults) self._java_obj.setDefault(pair_defaults_seq) - def _transfer_param_map_to_java(self, pyParamMap): + def _transfer_param_map_to_java(self, pyParamMap: "ParamMap") -> "JavaObject": """ Transforms a Python ParamMap into a Java ParamMap. """ @@ -163,26 +191,30 @@ def _transfer_param_map_to_java(self, pyParamMap): paramMap.put([pair]) return paramMap - def _create_params_from_java(self): + def _create_params_from_java(self) -> None: """ SPARK-10931: Temporary fix to create params that are defined in the Java obj but not here """ + assert self._java_obj is not None + java_params = list(self._java_obj.params()) from pyspark.ml.param import Param for java_param in java_params: java_param_name = java_param.name() if not hasattr(self, java_param_name): - param = Param(self, java_param_name, java_param.doc()) + param: Param[Any] = Param(self, java_param_name, java_param.doc()) setattr(param, "created_from_java_param", True) setattr(self, java_param_name, param) self._params = None # need to reset so self.params will discover new params - def _transfer_params_from_java(self): + def _transfer_params_from_java(self) -> None: """ Transforms the embedded params from the companion Java object. """ sc = SparkContext._active_spark_context + assert sc is not None and self._java_obj is not None + for param in self.params: if self._java_obj.hasParam(param.name): java_param = self._java_obj.getParam(param.name) @@ -195,11 +227,13 @@ def _transfer_params_from_java(self): value = _java2py(sc, self._java_obj.getDefault(java_param)).get() self._setDefault(**{param.name: value}) - def _transfer_param_map_from_java(self, javaParamMap): + def _transfer_param_map_from_java(self, javaParamMap: "JavaObject") -> "ParamMap": """ Transforms a Java ParamMap into a Python ParamMap. """ sc = SparkContext._active_spark_context + assert sc is not None + paramMap = dict() for pair in javaParamMap.toList(): param = pair.param() @@ -208,13 +242,13 @@ def _transfer_param_map_from_java(self, javaParamMap): return paramMap @staticmethod - def _empty_java_param_map(): + def _empty_java_param_map() -> "JavaObject": """ Returns an empty Java ParamMap reference. """ return _jvm().org.apache.spark.ml.param.ParamMap() - def _to_java(self): + def _to_java(self) -> "JavaObject": """ Transfer this instance's Params to the wrapped Java object, and return the Java object. Used for ML persistence. @@ -230,7 +264,7 @@ def _to_java(self): return self._java_obj @staticmethod - def _from_java(java_stage): + def _from_java(java_stage: "JavaObject") -> "JP": """ Given a Java object, create and return a Python wrapper of it. 
Used for ML persistence. @@ -238,7 +272,7 @@ def _from_java(java_stage): Meta-algorithms such as Pipeline should override this method as a classmethod. """ - def __get_class(clazz): + def __get_class(clazz: str) -> Type[JP]: """ Loads Python class from its name. """ @@ -271,7 +305,7 @@ def __get_class(clazz): ) return py_stage - def copy(self, extra=None): + def copy(self: "JP", extra: Optional["ParamMap"] = None) -> "JP": """ Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and @@ -297,30 +331,32 @@ def copy(self, extra=None): that._transfer_params_to_java() return that - def clear(self, param): + def clear(self, param: Param) -> None: """ Clears a param from the param map if it has been explicitly set. """ + assert self._java_obj is not None + super(JavaParams, self).clear(param) java_param = self._java_obj.getParam(param.name) self._java_obj.clear(java_param) @inherit_doc -class JavaEstimator(JavaParams, Estimator, metaclass=ABCMeta): +class JavaEstimator(JavaParams, Estimator[JM], metaclass=ABCMeta): """ Base class for :py:class:`Estimator`s that wrap Java/Scala implementations. """ @abstractmethod - def _create_model(self, java_model): + def _create_model(self, java_model: "JavaObject") -> JM: """ Creates a model from the input Java model reference. """ raise NotImplementedError() - def _fit_java(self, dataset): + def _fit_java(self, dataset: DataFrame) -> "JavaObject": """ Fits a Java model to the input dataset. @@ -334,10 +370,12 @@ def _fit_java(self, dataset): py4j.java_gateway.JavaObject fitted Java model """ + assert self._java_obj is not None + self._transfer_params_to_java() return self._java_obj.fit(dataset._jdf) - def _fit(self, dataset): + def _fit(self, dataset: DataFrame) -> JM: java_model = self._fit_java(dataset) model = self._create_model(java_model) return self._copyValues(model) @@ -351,9 +389,11 @@ class JavaTransformer(JavaParams, Transformer, metaclass=ABCMeta): available as _java_obj. """ - def _transform(self, dataset): + def _transform(self, dataset: DataFrame) -> DataFrame: + assert self._java_obj is not None + self._transfer_params_to_java() - return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx) + return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sparkSession) @inherit_doc @@ -364,7 +404,7 @@ class JavaModel(JavaTransformer, Model, metaclass=ABCMeta): param mix-ins, because this sets the UID from the Java model. """ - def __init__(self, java_model=None): + def __init__(self, java_model: Optional["JavaObject"] = None): """ Initialize this instance with a Java model object. Subclasses should call this constructor, initialize params, @@ -388,12 +428,12 @@ def __init__(self, java_model=None): self._resetUid(java_model.uid()) - def __repr__(self): + def __repr__(self) -> str: return self._call_java("toString") @inherit_doc -class JavaPredictor(Predictor, JavaEstimator, _PredictorParams, metaclass=ABCMeta): +class JavaPredictor(Predictor, JavaEstimator[JM], _PredictorParams, Generic[JM], metaclass=ABCMeta): """ (Private) Java Estimator for prediction tasks (regression and classification). """ @@ -402,21 +442,21 @@ class JavaPredictor(Predictor, JavaEstimator, _PredictorParams, metaclass=ABCMet @inherit_doc -class JavaPredictionModel(PredictionModel, JavaModel, _PredictorParams): +class JavaPredictionModel(PredictionModel[T], JavaModel, _PredictorParams): """ (Private) Java Model for prediction tasks (regression and classification). 
""" - @property + @property # type: ignore[misc] @since("2.1.0") - def numFeatures(self): + def numFeatures(self) -> int: """ Returns the number of features the model was trained on. If unknown, returns -1 """ return self._call_java("numFeatures") @since("3.0.0") - def predict(self, value): + def predict(self, value: T) -> float: """ Predict label for the given features. """ diff --git a/python/pyspark/ml/wrapper.pyi b/python/pyspark/ml/wrapper.pyi deleted file mode 100644 index 7c3406a6d3438..0000000000000 --- a/python/pyspark/ml/wrapper.pyi +++ /dev/null @@ -1,46 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import abc -from typing import Any, Optional -from pyspark.ml._typing import P, T, JM, ParamMap - -from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, Model -from pyspark.ml.base import _PredictorParams -from pyspark.ml.param import Param, Params - -class JavaWrapper: - def __init__(self, java_obj: Optional[Any] = ...) -> None: ... - def __del__(self) -> None: ... - -class JavaParams(JavaWrapper, Params, metaclass=abc.ABCMeta): - def copy(self: P, extra: Optional[ParamMap] = ...) -> P: ... - def clear(self, param: Param) -> None: ... - -class JavaEstimator(JavaParams, Estimator[JM], metaclass=abc.ABCMeta): ... -class JavaTransformer(JavaParams, Transformer, metaclass=abc.ABCMeta): ... - -class JavaModel(JavaTransformer, Model, metaclass=abc.ABCMeta): - def __init__(self, java_model: Optional[Any] = ...) -> None: ... - -class JavaPredictor(Predictor[JM], JavaEstimator, _PredictorParams, metaclass=abc.ABCMeta): ... - -class JavaPredictionModel(PredictionModel[T], JavaModel, _PredictorParams): - @property - def numFeatures(self) -> int: ... - def predict(self, value: T) -> float: ... diff --git a/python/pyspark/mllib/_typing.pyi b/python/pyspark/mllib/_typing.pyi index 51a98cb0b016b..6a1a0f53a5950 100644 --- a/python/pyspark/mllib/_typing.pyi +++ b/python/pyspark/mllib/_typing.pyi @@ -17,6 +17,7 @@ # under the License. 
from typing import List, Tuple, TypeVar, Union +from typing_extensions import Literal from pyspark.mllib.linalg import Vector from numpy import ndarray # noqa: F401 from py4j.java_gateway import JavaObject @@ -24,3 +25,4 @@ from py4j.java_gateway import JavaObject VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]] C = TypeVar("C", bound=type) JavaObjectOrPickleDump = Union[JavaObject, bytearray, bytes] +NormType = Union[None, float, Literal["fro"], Literal["nuc"]] diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index f302634882ef5..1a3b3581e969f 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -18,10 +18,12 @@ from math import exp import sys import warnings +from typing import Any, Iterable, Optional, Union, overload, TYPE_CHECKING import numpy -from pyspark import RDD, since +from pyspark import RDD, SparkContext, since +from pyspark.streaming.dstream import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import ( @@ -31,6 +33,11 @@ StreamingLinearAlgorithm, ) from pyspark.mllib.util import Saveable, Loader, inherit_doc +from pyspark.mllib.linalg import Vector +from pyspark.mllib.regression import LabeledPoint + +if TYPE_CHECKING: + from pyspark.mllib._typing import VectorLike __all__ = [ @@ -51,12 +58,12 @@ class LinearClassificationModel(LinearModel): model. The categories are represented by int values: 0, 1, 2, etc. """ - def __init__(self, weights, intercept): + def __init__(self, weights: Vector, intercept: float) -> None: super(LinearClassificationModel, self).__init__(weights, intercept) - self._threshold = None + self._threshold: Optional[float] = None @since("1.4.0") - def setThreshold(self, value): + def setThreshold(self, value: float) -> None: """ Sets the threshold that separates positive predictions from negative predictions. An example with prediction score greater @@ -66,9 +73,9 @@ def setThreshold(self, value): """ self._threshold = value - @property + @property # type: ignore[misc] @since("1.4.0") - def threshold(self): + def threshold(self) -> Optional[float]: """ Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. It is used for @@ -77,18 +84,29 @@ def threshold(self): return self._threshold @since("1.4.0") - def clearThreshold(self): + def clearThreshold(self) -> None: """ Clears the threshold so that `predict` will output raw prediction scores. It is used for binary classification only. """ self._threshold = None - @since("1.4.0") - def predict(self, test): + @overload + def predict(self, test: "VectorLike") -> Union[int, float]: + ... + + @overload + def predict(self, test: RDD["VectorLike"]) -> RDD[Union[int, float]]: + ... + + def predict( + self, test: Union["VectorLike", RDD["VectorLike"]] + ) -> Union[RDD[Union[int, float]], Union[int, float]]: """ Predict values for a single data point or an RDD of points using the model trained. + + .. 
versionadded:: 1.4.0 """ raise NotImplementedError @@ -178,7 +196,9 @@ class LogisticRegressionModel(LinearClassificationModel): 2 """ - def __init__(self, weights, intercept, numFeatures, numClasses): + def __init__( + self, weights: Vector, intercept: float, numFeatures: int, numClasses: int + ) -> None: super(LogisticRegressionModel, self).__init__(weights, intercept) self._numFeatures = int(numFeatures) self._numClasses = int(numClasses) @@ -187,40 +207,53 @@ def __init__(self, weights, intercept, numFeatures, numClasses): self._dataWithBiasSize = None self._weightsMatrix = None else: - self._dataWithBiasSize = self._coeff.size // (self._numClasses - 1) + self._dataWithBiasSize = self._coeff.size // ( # type: ignore[attr-defined] + self._numClasses - 1 + ) self._weightsMatrix = self._coeff.toArray().reshape( self._numClasses - 1, self._dataWithBiasSize ) - @property + @property # type: ignore[misc] @since("1.4.0") - def numFeatures(self): + def numFeatures(self) -> int: """ Dimension of the features. """ return self._numFeatures - @property + @property # type: ignore[misc] @since("1.4.0") - def numClasses(self): + def numClasses(self) -> int: """ Number of possible outcomes for k classes classification problem in Multinomial Logistic Regression. """ return self._numClasses - @since("0.9.0") - def predict(self, x): + @overload + def predict(self, x: "VectorLike") -> Union[int, float]: + ... + + @overload + def predict(self, x: RDD["VectorLike"]) -> RDD[Union[int, float]]: + ... + + def predict( + self, x: Union["VectorLike", RDD["VectorLike"]] + ) -> Union[RDD[Union[int, float]], Union[int, float]]: """ Predict values for a single data point or an RDD of points using the model trained. + + .. versionadded:: 0.9.0 """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) if self.numClasses == 2: - margin = self.weights.dot(x) + self._intercept + margin = self.weights.dot(x) + self._intercept # type: ignore[attr-defined] if margin > 0: prob = 1 / (1 + exp(-margin)) else: @@ -231,29 +264,34 @@ def predict(self, x): else: return 1 if prob > self._threshold else 0 else: + assert self._weightsMatrix is not None + best_class = 0 max_margin = 0.0 - if x.size + 1 == self._dataWithBiasSize: + if x.size + 1 == self._dataWithBiasSize: # type: ignore[attr-defined] for i in range(0, self._numClasses - 1): margin = ( - x.dot(self._weightsMatrix[i][0 : x.size]) + self._weightsMatrix[i][x.size] + x.dot(self._weightsMatrix[i][0 : x.size]) # type: ignore[attr-defined] + + self._weightsMatrix[i][x.size] # type: ignore[attr-defined] ) if margin > max_margin: max_margin = margin best_class = i + 1 else: for i in range(0, self._numClasses - 1): - margin = x.dot(self._weightsMatrix[i]) + margin = x.dot(self._weightsMatrix[i]) # type: ignore[attr-defined] if margin > max_margin: max_margin = margin best_class = i + 1 return best_class @since("1.4.0") - def save(self, sc, path): + def save(self, sc: SparkContext, path: str) -> None: """ Save this model to the given path. """ + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel( _py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses ) @@ -261,10 +299,12 @@ def save(self, sc, path): @classmethod @since("1.4.0") - def load(cls, sc, path): + def load(cls, sc: SparkContext, path: str) -> "LogisticRegressionModel": """ Load a model from the given path. 
""" + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load( sc._jsc.sc(), path ) @@ -277,8 +317,11 @@ def load(cls, sc, path): model.setThreshold(threshold) return model - def __repr__(self): - return self._call_java("toString") + def __repr__(self) -> str: + return ( + "pyspark.mllib.LogisticRegressionModel: intercept = {}, " + "numFeatures = {}, numClasses = {}, threshold = {}" + ).format(self._intercept, self._numFeatures, self._numClasses, self._threshold) class LogisticRegressionWithSGD: @@ -293,17 +336,17 @@ class LogisticRegressionWithSGD: @classmethod def train( cls, - data, - iterations=100, - step=1.0, - miniBatchFraction=1.0, - initialWeights=None, - regParam=0.01, - regType="l2", - intercept=False, - validateData=True, - convergenceTol=0.001, - ): + data: RDD[LabeledPoint], + iterations: int = 100, + step: float = 1.0, + miniBatchFraction: float = 1.0, + initialWeights: Optional["VectorLike"] = None, + regParam: float = 0.01, + regType: str = "l2", + intercept: bool = False, + validateData: bool = True, + convergenceTol: float = 0.001, + ) -> LogisticRegressionModel: """ Train a logistic regression model on the given data. @@ -355,7 +398,7 @@ def train( FutureWarning, ) - def train(rdd, i): + def train(rdd: RDD[LabeledPoint], i: Vector) -> Iterable[Any]: return callMLlibFunc( "trainLogisticRegressionModelWithSGD", rdd, @@ -385,17 +428,17 @@ class LogisticRegressionWithLBFGS: @classmethod def train( cls, - data, - iterations=100, - initialWeights=None, - regParam=0.0, - regType="l2", - intercept=False, - corrections=10, - tolerance=1e-6, - validateData=True, - numClasses=2, - ): + data: RDD[LabeledPoint], + iterations: int = 100, + initialWeights: Optional["VectorLike"] = None, + regParam: float = 0.0, + regType: str = "l2", + intercept: bool = False, + corrections: int = 10, + tolerance: float = 1e-6, + validateData: bool = True, + numClasses: int = 2, + ) -> LogisticRegressionModel: """ Train a logistic regression model on the given data. @@ -457,7 +500,7 @@ def train( 0 """ - def train(rdd, i): + def train(rdd: RDD[LabeledPoint], i: Vector) -> Iterable[Any]: return callMLlibFunc( "trainLogisticRegressionModelWithLBFGS", rdd, @@ -541,31 +584,44 @@ class SVMModel(LinearClassificationModel): ... pass """ - def __init__(self, weights, intercept): + def __init__(self, weights: Vector, intercept: float) -> None: super(SVMModel, self).__init__(weights, intercept) self._threshold = 0.0 - @since("0.9.0") - def predict(self, x): + @overload + def predict(self, x: "VectorLike") -> Union[int, float]: + ... + + @overload + def predict(self, x: RDD["VectorLike"]) -> RDD[Union[int, float]]: + ... + + def predict( + self, x: Union["VectorLike", RDD["VectorLike"]] + ) -> Union[RDD[Union[int, float]], Union[int, float]]: """ Predict values for a single data point or an RDD of points using the model trained. + + .. versionadded:: 0.9.0 """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) - margin = self.weights.dot(x) + self.intercept + margin = self.weights.dot(x) + self.intercept # type: ignore[attr-defined] if self._threshold is None: return margin else: return 1 if margin > self._threshold else 0 @since("1.4.0") - def save(self, sc, path): + def save(self, sc: SparkContext, path: str) -> None: """ Save this model to the given path. 
""" + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel( _py2java(sc, self._coeff), self.intercept ) @@ -573,10 +629,12 @@ def save(self, sc, path): @classmethod @since("1.4.0") - def load(cls, sc, path): + def load(cls, sc: SparkContext, path: str) -> "SVMModel": """ Load a model from the given path. """ + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load(sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) intercept = java_model.intercept() @@ -596,17 +654,17 @@ class SVMWithSGD: @classmethod def train( cls, - data, - iterations=100, - step=1.0, - regParam=0.01, - miniBatchFraction=1.0, - initialWeights=None, - regType="l2", - intercept=False, - validateData=True, - convergenceTol=0.001, - ): + data: RDD[LabeledPoint], + iterations: int = 100, + step: float = 1.0, + regParam: float = 0.01, + miniBatchFraction: float = 1.0, + initialWeights: Optional["VectorLike"] = None, + regType: str = "l2", + intercept: bool = False, + validateData: bool = True, + convergenceTol: float = 0.001, + ) -> SVMModel: """ Train a support vector machine on the given data. @@ -653,7 +711,7 @@ def train( (default: 0.001) """ - def train(rdd, i): + def train(rdd: RDD[LabeledPoint], i: Vector) -> Iterable[Any]: return callMLlibFunc( "trainSVMModelWithSGD", rdd, @@ -672,7 +730,7 @@ def train(rdd, i): @inherit_doc -class NaiveBayesModel(Saveable, Loader): +class NaiveBayesModel(Saveable, Loader["NaiveBayesModel"]): """ Model for Naive Bayes classifiers. @@ -727,13 +785,23 @@ class NaiveBayesModel(Saveable, Loader): ... pass """ - def __init__(self, labels, pi, theta): + def __init__(self, labels: numpy.ndarray, pi: numpy.ndarray, theta: numpy.ndarray) -> None: self.labels = labels self.pi = pi self.theta = theta + @overload + def predict(self, x: "VectorLike") -> numpy.float64: + ... + + @overload + def predict(self, x: RDD["VectorLike"]) -> RDD[numpy.float64]: + ... + @since("0.9.0") - def predict(self, x): + def predict( + self, x: Union["VectorLike", RDD["VectorLike"]] + ) -> Union[numpy.float64, RDD[numpy.float64]]: """ Return the most likely class for a data vector or an RDD of vectors @@ -741,12 +809,16 @@ def predict(self, x): if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) - return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))] + return self.labels[ + numpy.argmax(self.pi + x.dot(self.theta.transpose())) # type: ignore[attr-defined] + ] - def save(self, sc, path): + def save(self, sc: SparkContext, path: str) -> None: """ Save this model to the given path. """ + assert sc._jvm is not None + java_labels = _py2java(sc, self.labels.tolist()) java_pi = _py2java(sc, self.pi.tolist()) java_theta = _py2java(sc, self.theta.tolist()) @@ -757,10 +829,12 @@ def save(self, sc, path): @classmethod @since("1.4.0") - def load(cls, sc, path): + def load(cls, sc: SparkContext, path: str) -> "NaiveBayesModel": """ Load a model from the given path. """ + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load( sc._jsc.sc(), path ) @@ -779,7 +853,7 @@ class NaiveBayes: """ @classmethod - def train(cls, data, lambda_=1.0): + def train(cls, data: RDD[LabeledPoint], lambda_: float = 1.0) -> NaiveBayesModel: """ Train a Naive Bayes model given an RDD of (label, features) vectors. 
@@ -843,22 +917,24 @@ class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): def __init__( self, - stepSize=0.1, - numIterations=50, - miniBatchFraction=1.0, - regParam=0.0, - convergenceTol=0.001, - ): + stepSize: float = 0.1, + numIterations: int = 50, + miniBatchFraction: float = 1.0, + regParam: float = 0.0, + convergenceTol: float = 0.001, + ) -> None: self.stepSize = stepSize self.numIterations = numIterations self.regParam = regParam self.miniBatchFraction = miniBatchFraction self.convergenceTol = convergenceTol - self._model = None + self._model: Optional[LogisticRegressionModel] = None super(StreamingLogisticRegressionWithSGD, self).__init__(model=self._model) @since("1.5.0") - def setInitialWeights(self, initialWeights): + def setInitialWeights( + self, initialWeights: "VectorLike" + ) -> "StreamingLogisticRegressionWithSGD": """ Set the initial value of weights. @@ -867,15 +943,17 @@ def setInitialWeights(self, initialWeights): initialWeights = _convert_to_vector(initialWeights) # LogisticRegressionWithSGD does only binary classification. - self._model = LogisticRegressionModel(initialWeights, 0, initialWeights.size, 2) + self._model = LogisticRegressionModel( + initialWeights, 0, initialWeights.size, 2 # type: ignore[attr-defined] + ) return self @since("1.5.0") - def trainOn(self, dstream): + def trainOn(self, dstream: "DStream[LabeledPoint]") -> None: """Train the model on the incoming dstream.""" self._validate(dstream) - def update(rdd): + def update(rdd: RDD[LabeledPoint]) -> None: # LogisticRegressionWithSGD.train raises an error for an empty RDD. if not rdd.isEmpty(): self._model = LogisticRegressionWithSGD.train( @@ -883,7 +961,7 @@ def update(rdd): self.numIterations, self.stepSize, self.miniBatchFraction, - self._model.weights, + self._model.weights, # type: ignore[union-attr] regParam=self.regParam, convergenceTol=self.convergenceTol, ) @@ -891,7 +969,7 @@ def update(rdd): dstream.foreachRDD(update) -def _test(): +def _test() -> None: import doctest from pyspark.sql import SparkSession import pyspark.mllib.classification diff --git a/python/pyspark/mllib/classification.pyi b/python/pyspark/mllib/classification.pyi deleted file mode 100644 index ba88f6dcb2dda..0000000000000 --- a/python/pyspark/mllib/classification.pyi +++ /dev/null @@ -1,151 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import overload -from typing import Optional, Union - -from pyspark.context import SparkContext -from pyspark.rdd import RDD -from pyspark.mllib._typing import VectorLike -from pyspark.mllib.linalg import Vector -from pyspark.mllib.regression import LabeledPoint, LinearModel, StreamingLinearAlgorithm -from pyspark.mllib.util import Saveable, Loader -from pyspark.streaming.dstream import DStream - -from numpy import float64, ndarray - -class LinearClassificationModel(LinearModel): - def __init__(self, weights: Vector, intercept: float) -> None: ... - def setThreshold(self, value: float) -> None: ... - @property - def threshold(self) -> Optional[float]: ... - def clearThreshold(self) -> None: ... - @overload - def predict(self, test: VectorLike) -> Union[int, float, float64]: ... - @overload - def predict(self, test: RDD[VectorLike]) -> RDD[Union[int, float]]: ... - -class LogisticRegressionModel(LinearClassificationModel): - def __init__( - self, weights: Vector, intercept: float, numFeatures: int, numClasses: int - ) -> None: ... - @property - def numFeatures(self) -> int: ... - @property - def numClasses(self) -> int: ... - @overload - def predict(self, x: VectorLike) -> Union[int, float]: ... - @overload - def predict(self, x: RDD[VectorLike]) -> RDD[Union[int, float]]: ... - def save(self, sc: SparkContext, path: str) -> None: ... - @classmethod - def load(cls, sc: SparkContext, path: str) -> LogisticRegressionModel: ... - -class LogisticRegressionWithSGD: - @classmethod - def train( - cls, - data: RDD[LabeledPoint], - iterations: int = ..., - step: float = ..., - miniBatchFraction: float = ..., - initialWeights: Optional[VectorLike] = ..., - regParam: float = ..., - regType: str = ..., - intercept: bool = ..., - validateData: bool = ..., - convergenceTol: float = ..., - ) -> LogisticRegressionModel: ... - -class LogisticRegressionWithLBFGS: - @classmethod - def train( - cls, - data: RDD[LabeledPoint], - iterations: int = ..., - initialWeights: Optional[VectorLike] = ..., - regParam: float = ..., - regType: str = ..., - intercept: bool = ..., - corrections: int = ..., - tolerance: float = ..., - validateData: bool = ..., - numClasses: int = ..., - ) -> LogisticRegressionModel: ... - -class SVMModel(LinearClassificationModel): - def __init__(self, weights: Vector, intercept: float) -> None: ... - @overload # type: ignore - def predict(self, x: VectorLike) -> float64: ... - @overload - def predict(self, x: RDD[VectorLike]) -> RDD[float64]: ... - def save(self, sc: SparkContext, path: str) -> None: ... - @classmethod - def load(cls, sc: SparkContext, path: str) -> SVMModel: ... - -class SVMWithSGD: - @classmethod - def train( - cls, - data: RDD[LabeledPoint], - iterations: int = ..., - step: float = ..., - regParam: float = ..., - miniBatchFraction: float = ..., - initialWeights: Optional[VectorLike] = ..., - regType: str = ..., - intercept: bool = ..., - validateData: bool = ..., - convergenceTol: float = ..., - ) -> SVMModel: ... - -class NaiveBayesModel(Saveable, Loader[NaiveBayesModel]): - labels: ndarray - pi: ndarray - theta: ndarray - def __init__(self, labels: ndarray, pi: ndarray, theta: ndarray) -> None: ... - @overload - def predict(self, x: VectorLike) -> float64: ... - @overload - def predict(self, x: RDD[VectorLike]) -> RDD[float64]: ... - def save(self, sc: SparkContext, path: str) -> None: ... - @classmethod - def load(cls, sc: SparkContext, path: str) -> NaiveBayesModel: ... 
- -class NaiveBayes: - @classmethod - def train(cls, data: RDD[VectorLike], lambda_: float = ...) -> NaiveBayesModel: ... - -class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): - stepSize: float - numIterations: int - regParam: float - miniBatchFraction: float - convergenceTol: float - def __init__( - self, - stepSize: float = ..., - numIterations: int = ..., - miniBatchFraction: float = ..., - regParam: float = ..., - convergenceTol: float = ..., - ) -> None: ... - def setInitialWeights( - self, initialWeights: VectorLike - ) -> StreamingLogisticRegressionWithSGD: ... - def trainOn(self, dstream: DStream[LabeledPoint]) -> None: ... diff --git a/python/pyspark/mllib/clustering.pyi b/python/pyspark/mllib/clustering.pyi index f98348066b090..8a8401d35657f 100644 --- a/python/pyspark/mllib/clustering.pyi +++ b/python/pyspark/mllib/clustering.pyi @@ -22,7 +22,7 @@ from typing import List, NamedTuple, Optional, Tuple, TypeVar import array from numpy import float64, int64, ndarray -from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_gateway import JavaObject from pyspark.mllib._typing import VectorLike from pyspark.context import SparkContext diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 24a3f411946d6..c5e1a7e8580c0 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -67,11 +67,9 @@ def _to_java_object_rdd(rdd: RDD) -> JavaObject: It will convert each Python object into Java object by Pickle, whenever the RDD is serialized in batch or not. """ - rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer())) # type: ignore[attr-defined] + rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer())) assert rdd.ctx._jvm is not None - return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava( - rdd._jrdd, True # type: ignore[attr-defined] - ) + return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True) def _py2java(sc: SparkContext, obj: Any) -> JavaObject: @@ -81,7 +79,7 @@ def _py2java(sc: SparkContext, obj: Any) -> JavaObject: elif isinstance(obj, DataFrame): obj = obj._jdf elif isinstance(obj, SparkContext): - obj = obj._jsc # type: ignore[attr-defined] + obj = obj._jsc elif isinstance(obj, list): obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): @@ -110,7 +108,7 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt return RDD(jrdd, sc) if clsName == "Dataset": - return DataFrame(r, SparkSession(sc)._wrapped) + return DataFrame(r, SparkSession._getActiveSessionOrCreate()) if clsName in _picklable_classes: r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r) @@ -127,13 +125,13 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt def callJavaFunc( sc: pyspark.context.SparkContext, func: Callable[..., "JavaObjectOrPickleDump"], *args: Any -) -> "JavaObjectOrPickleDump": +) -> Any: """Call Java Function""" java_args = [_py2java(sc, a) for a in args] return _java2py(sc, func(*java_args)) -def callMLlibFunc(name: str, *args: Any) -> "JavaObjectOrPickleDump": +def callMLlibFunc(name: str, *args: Any) -> Any: """Call API in PythonMLLibAPI""" sc = SparkContext.getOrCreate() assert sc._jvm is not None @@ -154,7 +152,7 @@ def __del__(self) -> None: assert self._sc._gateway is not None self._sc._gateway.detach(self._java_model) - def call(self, name: str, *a: Any) -> "JavaObjectOrPickleDump": + def call(self, name: str, *a: Any) -> Any: """Call method of 
java_model""" return callJavaFunc(self._sc, getattr(self._java_model, name), *a) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index b09783458510c..1003ba68c5fa0 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -15,12 +15,16 @@ # limitations under the License. # +from typing import Generic, List, Optional, Tuple, TypeVar + import sys from pyspark import since +from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.mllib.linalg import Matrix from pyspark.sql import SQLContext -from pyspark.sql.types import ArrayType, StructField, StructType, DoubleType +from pyspark.sql.types import ArrayType, DoubleType, StructField, StructType __all__ = [ "BinaryClassificationMetrics", @@ -29,6 +33,8 @@ "RankingMetrics", ] +T = TypeVar("T") + class BinaryClassificationMetrics(JavaModelWrapper): """ @@ -61,7 +67,7 @@ class BinaryClassificationMetrics(JavaModelWrapper): 0.88... """ - def __init__(self, scoreAndLabels): + def __init__(self, scoreAndLabels: RDD[Tuple[float, float]]): sc = scoreAndLabels.ctx sql_ctx = SQLContext.getOrCreate(sc) numCol = len(scoreAndLabels.first()) @@ -74,29 +80,30 @@ def __init__(self, scoreAndLabels): if numCol == 3: schema.add("weight", DoubleType(), False) df = sql_ctx.createDataFrame(scoreAndLabels, schema=schema) + assert sc._jvm is not None java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics java_model = java_class(df._jdf) super(BinaryClassificationMetrics, self).__init__(java_model) - @property + @property # type: ignore[misc] @since("1.4.0") - def areaUnderROC(self): + def areaUnderROC(self) -> float: """ Computes the area under the receiver operating characteristic (ROC) curve. """ return self.call("areaUnderROC") - @property + @property # type: ignore[misc] @since("1.4.0") - def areaUnderPR(self): + def areaUnderPR(self) -> float: """ Computes the area under the precision-recall curve. """ return self.call("areaUnderPR") @since("1.4.0") - def unpersist(self): + def unpersist(self) -> None: """ Unpersists intermediate RDDs used in the computation. """ @@ -136,7 +143,7 @@ class RegressionMetrics(JavaModelWrapper): 0.68... """ - def __init__(self, predictionAndObservations): + def __init__(self, predictionAndObservations: RDD[Tuple[float, float]]): sc = predictionAndObservations.ctx sql_ctx = SQLContext.getOrCreate(sc) numCol = len(predictionAndObservations.first()) @@ -149,49 +156,50 @@ def __init__(self, predictionAndObservations): if numCol == 3: schema.add("weight", DoubleType(), False) df = sql_ctx.createDataFrame(predictionAndObservations, schema=schema) + assert sc._jvm is not None java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics java_model = java_class(df._jdf) super(RegressionMetrics, self).__init__(java_model) - @property + @property # type: ignore[misc] @since("1.4.0") - def explainedVariance(self): + def explainedVariance(self) -> float: r""" Returns the explained variance regression score. explainedVariance = :math:`1 - \frac{variance(y - \hat{y})}{variance(y)}` """ return self.call("explainedVariance") - @property + @property # type: ignore[misc] @since("1.4.0") - def meanAbsoluteError(self): + def meanAbsoluteError(self) -> float: """ Returns the mean absolute error, which is a risk function corresponding to the expected value of the absolute error loss or l1-norm loss. 
""" return self.call("meanAbsoluteError") - @property + @property # type: ignore[misc] @since("1.4.0") - def meanSquaredError(self): + def meanSquaredError(self) -> float: """ Returns the mean squared error, which is a risk function corresponding to the expected value of the squared error loss or quadratic loss. """ return self.call("meanSquaredError") - @property + @property # type: ignore[misc] @since("1.4.0") - def rootMeanSquaredError(self): + def rootMeanSquaredError(self) -> float: """ Returns the root mean squared error, which is defined as the square root of the mean squared error. """ return self.call("rootMeanSquaredError") - @property + @property # type: ignore[misc] @since("1.4.0") - def r2(self): + def r2(self) -> float: """ Returns R^2^, the coefficient of determination. """ @@ -274,7 +282,7 @@ class MulticlassMetrics(JavaModelWrapper): 0.9682... """ - def __init__(self, predictionAndLabels): + def __init__(self, predictionAndLabels: RDD[Tuple[float, float]]): sc = predictionAndLabels.ctx sql_ctx = SQLContext.getOrCreate(sc) numCol = len(predictionAndLabels.first()) @@ -289,12 +297,13 @@ def __init__(self, predictionAndLabels): if numCol == 4: schema.add("probability", ArrayType(DoubleType(), False), False) df = sql_ctx.createDataFrame(predictionAndLabels, schema) + assert sc._jvm is not None java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics java_model = java_class(df._jdf) super(MulticlassMetrics, self).__init__(java_model) @since("1.4.0") - def confusionMatrix(self): + def confusionMatrix(self) -> Matrix: """ Returns confusion matrix: predicted classes are in columns, they are ordered by class label ascending, as in "labels". @@ -302,35 +311,35 @@ def confusionMatrix(self): return self.call("confusionMatrix") @since("1.4.0") - def truePositiveRate(self, label): + def truePositiveRate(self, label: float) -> float: """ Returns true positive rate for a given label (category). """ return self.call("truePositiveRate", label) @since("1.4.0") - def falsePositiveRate(self, label): + def falsePositiveRate(self, label: float) -> float: """ Returns false positive rate for a given label (category). """ return self.call("falsePositiveRate", label) @since("1.4.0") - def precision(self, label): + def precision(self, label: float) -> float: """ Returns precision. """ return self.call("precision", float(label)) @since("1.4.0") - def recall(self, label): + def recall(self, label: float) -> float: """ Returns recall. """ return self.call("recall", float(label)) @since("1.4.0") - def fMeasure(self, label, beta=None): + def fMeasure(self, label: float, beta: Optional[float] = None) -> float: """ Returns f-measure. """ @@ -339,51 +348,51 @@ def fMeasure(self, label, beta=None): else: return self.call("fMeasure", label, beta) - @property + @property # type: ignore[misc] @since("2.0.0") - def accuracy(self): + def accuracy(self) -> float: """ Returns accuracy (equals to the total number of correctly classified instances out of the total number of instances). """ return self.call("accuracy") - @property + @property # type: ignore[misc] @since("1.4.0") - def weightedTruePositiveRate(self): + def weightedTruePositiveRate(self) -> float: """ Returns weighted true positive rate. (equals to precision, recall and f-measure) """ return self.call("weightedTruePositiveRate") - @property + @property # type: ignore[misc] @since("1.4.0") - def weightedFalsePositiveRate(self): + def weightedFalsePositiveRate(self) -> float: """ Returns weighted false positive rate. 
""" return self.call("weightedFalsePositiveRate") - @property + @property # type: ignore[misc] @since("1.4.0") - def weightedRecall(self): + def weightedRecall(self) -> float: """ Returns weighted averaged recall. (equals to precision, recall and f-measure) """ return self.call("weightedRecall") - @property + @property # type: ignore[misc] @since("1.4.0") - def weightedPrecision(self): + def weightedPrecision(self) -> float: """ Returns weighted averaged precision. """ return self.call("weightedPrecision") @since("1.4.0") - def weightedFMeasure(self, beta=None): + def weightedFMeasure(self, beta: Optional[float] = None) -> float: """ Returns weighted averaged f-measure. """ @@ -393,14 +402,14 @@ def weightedFMeasure(self, beta=None): return self.call("weightedFMeasure", beta) @since("3.0.0") - def logLoss(self, eps=1e-15): + def logLoss(self, eps: float = 1e-15) -> float: """ Returns weighted logLoss. """ return self.call("logLoss", eps) -class RankingMetrics(JavaModelWrapper): +class RankingMetrics(JavaModelWrapper, Generic[T]): """ Evaluator for ranking algorithms. @@ -442,7 +451,7 @@ class RankingMetrics(JavaModelWrapper): 0.66... """ - def __init__(self, predictionAndLabels): + def __init__(self, predictionAndLabels: RDD[Tuple[List[T], List[T]]]): sc = predictionAndLabels.ctx sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame( @@ -452,7 +461,7 @@ def __init__(self, predictionAndLabels): super(RankingMetrics, self).__init__(java_model) @since("1.4.0") - def precisionAt(self, k): + def precisionAt(self, k: int) -> float: """ Compute the average precision of all the queries, truncated at ranking position k. @@ -465,9 +474,9 @@ def precisionAt(self, k): """ return self.call("precisionAt", int(k)) - @property + @property # type: ignore[misc] @since("1.4.0") - def meanAveragePrecision(self): + def meanAveragePrecision(self) -> float: """ Returns the mean average precision (MAP) of all the queries. If a query has an empty ground truth set, the average precision will be zero and @@ -476,7 +485,7 @@ def meanAveragePrecision(self): return self.call("meanAveragePrecision") @since("3.0.0") - def meanAveragePrecisionAt(self, k): + def meanAveragePrecisionAt(self, k: int) -> float: """ Returns the mean average precision (MAP) at first k ranking of all the queries. If a query has an empty ground truth set, the average precision will be zero and @@ -485,7 +494,7 @@ def meanAveragePrecisionAt(self, k): return self.call("meanAveragePrecisionAt", int(k)) @since("1.4.0") - def ndcgAt(self, k): + def ndcgAt(self, k: int) -> float: """ Compute the average NDCG value of all the queries, truncated at ranking position k. The discounted cumulative gain at position k is computed as: @@ -498,7 +507,7 @@ def ndcgAt(self, k): return self.call("ndcgAt", int(k)) @since("3.0.0") - def recallAt(self, k): + def recallAt(self, k: int) -> float: """ Compute the average recall of all the queries, truncated at ranking position k. @@ -556,18 +565,19 @@ class MultilabelMetrics(JavaModelWrapper): 0.54... 
""" - def __init__(self, predictionAndLabels): + def __init__(self, predictionAndLabels: RDD[Tuple[List[float], List[float]]]): sc = predictionAndLabels.ctx sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame( predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels) ) + assert sc._jvm is not None java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics java_model = java_class(df._jdf) super(MultilabelMetrics, self).__init__(java_model) @since("1.4.0") - def precision(self, label=None): + def precision(self, label: Optional[float] = None) -> float: """ Returns precision or precision for a given label (category) if specified. """ @@ -577,7 +587,7 @@ def precision(self, label=None): return self.call("precision", float(label)) @since("1.4.0") - def recall(self, label=None): + def recall(self, label: Optional[float] = None) -> float: """ Returns recall or recall for a given label (category) if specified. """ @@ -587,7 +597,7 @@ def recall(self, label=None): return self.call("recall", float(label)) @since("1.4.0") - def f1Measure(self, label=None): + def f1Measure(self, label: Optional[float] = None) -> float: """ Returns f1Measure or f1Measure for a given label (category) if specified. """ @@ -596,60 +606,60 @@ def f1Measure(self, label=None): else: return self.call("f1Measure", float(label)) - @property + @property # type: ignore[misc] @since("1.4.0") - def microPrecision(self): + def microPrecision(self) -> float: """ Returns micro-averaged label-based precision. (equals to micro-averaged document-based precision) """ return self.call("microPrecision") - @property + @property # type: ignore[misc] @since("1.4.0") - def microRecall(self): + def microRecall(self) -> float: """ Returns micro-averaged label-based recall. (equals to micro-averaged document-based recall) """ return self.call("microRecall") - @property + @property # type: ignore[misc] @since("1.4.0") - def microF1Measure(self): + def microF1Measure(self) -> float: """ Returns micro-averaged label-based f1-measure. (equals to micro-averaged document-based f1-measure) """ return self.call("microF1Measure") - @property + @property # type: ignore[misc] @since("1.4.0") - def hammingLoss(self): + def hammingLoss(self) -> float: """ Returns Hamming-loss. """ return self.call("hammingLoss") - @property + @property # type: ignore[misc] @since("1.4.0") - def subsetAccuracy(self): + def subsetAccuracy(self) -> float: """ Returns subset accuracy. (for equal sets of labels) """ return self.call("subsetAccuracy") - @property + @property # type: ignore[misc] @since("1.4.0") - def accuracy(self): + def accuracy(self) -> float: """ Returns accuracy. """ return self.call("accuracy") -def _test(): +def _test() -> None: import doctest import numpy from pyspark.sql import SparkSession diff --git a/python/pyspark/mllib/evaluation.pyi b/python/pyspark/mllib/evaluation.pyi deleted file mode 100644 index bbe0eebf33594..0000000000000 --- a/python/pyspark/mllib/evaluation.pyi +++ /dev/null @@ -1,92 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import List, Optional, Tuple, TypeVar -from pyspark.rdd import RDD -from pyspark.mllib.common import JavaModelWrapper -from pyspark.mllib.linalg import Matrix - -T = TypeVar("T") - -class BinaryClassificationMetrics(JavaModelWrapper): - def __init__(self, scoreAndLabels: RDD[Tuple[float, float]]) -> None: ... - @property - def areaUnderROC(self) -> float: ... - @property - def areaUnderPR(self) -> float: ... - def unpersist(self) -> None: ... - -class RegressionMetrics(JavaModelWrapper): - def __init__(self, predictionAndObservations: RDD[Tuple[float, float]]) -> None: ... - @property - def explainedVariance(self) -> float: ... - @property - def meanAbsoluteError(self) -> float: ... - @property - def meanSquaredError(self) -> float: ... - @property - def rootMeanSquaredError(self) -> float: ... - @property - def r2(self) -> float: ... - -class MulticlassMetrics(JavaModelWrapper): - def __init__(self, predictionAndLabels: RDD[Tuple[float, float]]) -> None: ... - def confusionMatrix(self) -> Matrix: ... - def truePositiveRate(self, label: float) -> float: ... - def falsePositiveRate(self, label: float) -> float: ... - def precision(self, label: float = ...) -> float: ... - def recall(self, label: float = ...) -> float: ... - def fMeasure(self, label: float = ..., beta: Optional[float] = ...) -> float: ... - @property - def accuracy(self) -> float: ... - @property - def weightedTruePositiveRate(self) -> float: ... - @property - def weightedFalsePositiveRate(self) -> float: ... - @property - def weightedRecall(self) -> float: ... - @property - def weightedPrecision(self) -> float: ... - def weightedFMeasure(self, beta: Optional[float] = ...) -> float: ... - -class RankingMetrics(JavaModelWrapper): - def __init__(self, predictionAndLabels: RDD[Tuple[List[T], List[T]]]) -> None: ... - def precisionAt(self, k: int) -> float: ... - @property - def meanAveragePrecision(self) -> float: ... - def meanAveragePrecisionAt(self, k: int) -> float: ... - def ndcgAt(self, k: int) -> float: ... - def recallAt(self, k: int) -> float: ... - -class MultilabelMetrics(JavaModelWrapper): - def __init__(self, predictionAndLabels: RDD[Tuple[List[float], List[float]]]) -> None: ... - def precision(self, label: Optional[float] = ...) -> float: ... - def recall(self, label: Optional[float] = ...) -> float: ... - def f1Measure(self, label: Optional[float] = ...) -> float: ... - @property - def microPrecision(self) -> float: ... - @property - def microRecall(self) -> float: ... - @property - def microF1Measure(self) -> float: ... - @property - def hammingLoss(self) -> float: ... - @property - def subsetAccuracy(self) -> float: ... - @property - def accuracy(self) -> float: ... 
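Aside (editor's note, not part of the patch): with evaluation.pyi removed, the signatures above now live inline in evaluation.py, and the same pattern continues in feature.py below, where methods that accept either a single vector or an RDD of vectors gain typing.overload declarations. A minimal, self-contained sketch of that overload pattern follows; the Scaler class and its names are hypothetical, invented purely to illustrate the shape of the annotations.

    from typing import List, Union, overload

    class Scaler:
        """Hypothetical example mirroring the overload style used throughout this patch."""

        def __init__(self, factor: float) -> None:
            self.factor = factor

        @overload
        def transform(self, value: float) -> float:
            ...

        @overload
        def transform(self, value: List[float]) -> List[float]:
            ...

        def transform(self, value: Union[float, List[float]]) -> Union[float, List[float]]:
            # Single runtime implementation; the overloads above exist only for the
            # type checker, which infers float for a scalar input and List[float]
            # for a list input, the same shape as predict/transform on RDDs here.
            if isinstance(value, list):
                return [v * self.factor for v in value]
            return value * self.factor

Under mypy, Scaler(2.0).transform([1.0, 2.0]) is inferred as List[float] while Scaler(2.0).transform(3.0) is inferred as float, which is the behaviour the scalar/RDD overloads in these modules aim for.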
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 320ba0029a0c8..17dab6ac057e0 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -20,6 +20,8 @@ """ import sys import warnings +from typing import Dict, Hashable, Iterable, List, Optional, Tuple, Union, overload, TYPE_CHECKING + from py4j.protocol import Py4JJavaError from pyspark import since @@ -28,6 +30,15 @@ from pyspark.mllib.linalg import Vectors, _convert_to_vector from pyspark.mllib.util import JavaLoader, JavaSaveable +from pyspark.context import SparkContext +from pyspark.mllib.linalg import Vector +from pyspark.mllib.regression import LabeledPoint +from py4j.java_collections import JavaMap + +if TYPE_CHECKING: + from pyspark.mllib._typing import VectorLike + from py4j.java_collections import JavaMap + __all__ = [ "Normalizer", "StandardScalerModel", @@ -48,7 +59,17 @@ class VectorTransformer: Base class for transformation of a vector or RDD of vector """ - def transform(self, vector): + @overload + def transform(self, vector: "VectorLike") -> Vector: + ... + + @overload + def transform(self, vector: RDD["VectorLike"]) -> RDD[Vector]: + ... + + def transform( + self, vector: Union["VectorLike", RDD["VectorLike"]] + ) -> Union[Vector, RDD[Vector]]: """ Applies transformation on a vector. @@ -94,11 +115,21 @@ class Normalizer(VectorTransformer): DenseVector([0.0, 0.5, 1.0]) """ - def __init__(self, p=2.0): + def __init__(self, p: float = 2.0): assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) - def transform(self, vector): + @overload + def transform(self, vector: "VectorLike") -> Vector: + ... + + @overload + def transform(self, vector: RDD["VectorLike"]) -> RDD[Vector]: + ... + + def transform( + self, vector: Union["VectorLike", RDD["VectorLike"]] + ) -> Union[Vector, RDD[Vector]]: """ Applies unit length normalization on a vector. @@ -127,7 +158,17 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): Wrapper for the model in JVM """ - def transform(self, vector): + @overload + def transform(self, vector: "VectorLike") -> Vector: + ... + + @overload + def transform(self, vector: RDD["VectorLike"]) -> RDD[Vector]: + ... + + def transform( + self, vector: Union["VectorLike", RDD["VectorLike"]] + ) -> Union[Vector, RDD[Vector]]: """ Applies transformation on a vector or an RDD[Vector]. @@ -156,7 +197,17 @@ class StandardScalerModel(JavaVectorTransformer): .. versionadded:: 1.2.0 """ - def transform(self, vector): + @overload + def transform(self, vector: "VectorLike") -> Vector: + ... + + @overload + def transform(self, vector: RDD["VectorLike"]) -> RDD[Vector]: + ... + + def transform( + self, vector: Union["VectorLike", RDD["VectorLike"]] + ) -> Union[Vector, RDD[Vector]]: """ Applies standardization transformation on a vector. 
@@ -183,7 +234,7 @@ def transform(self, vector): return JavaVectorTransformer.transform(self, vector) @since("1.4.0") - def setWithMean(self, withMean): + def setWithMean(self, withMean: bool) -> "StandardScalerModel": """ Setter of the boolean which decides whether it uses mean or not @@ -192,7 +243,7 @@ def setWithMean(self, withMean): return self @since("1.4.0") - def setWithStd(self, withStd): + def setWithStd(self, withStd: bool) -> "StandardScalerModel": """ Setter of the boolean which decides whether it uses std or not @@ -200,33 +251,33 @@ def setWithStd(self, withStd): self.call("setWithStd", withStd) return self - @property + @property # type: ignore[misc] @since("2.0.0") - def withStd(self): + def withStd(self) -> bool: """ Returns if the model scales the data to unit standard deviation. """ return self.call("withStd") - @property + @property # type: ignore[misc] @since("2.0.0") - def withMean(self): + def withMean(self) -> bool: """ Returns if the model centers the data before scaling. """ return self.call("withMean") - @property + @property # type: ignore[misc] @since("2.0.0") - def std(self): + def std(self) -> Vector: """ Return the column standard deviation values. """ return self.call("std") - @property + @property # type: ignore[misc] @since("2.0.0") - def mean(self): + def mean(self) -> Vector: """ Return the column mean values. """ @@ -271,13 +322,13 @@ class StandardScaler: True """ - def __init__(self, withMean=False, withStd=True): + def __init__(self, withMean: bool = False, withStd: bool = True): if not (withMean or withStd): warnings.warn("Both withMean and withStd are false. The model does nothing.") self.withMean = withMean self.withStd = withStd - def fit(self, dataset): + def fit(self, dataset: RDD["VectorLike"]) -> "StandardScalerModel": """ Computes the mean and variance and stores as a model to be used for later scaling. @@ -306,7 +357,17 @@ class ChiSqSelectorModel(JavaVectorTransformer): .. versionadded:: 1.4.0 """ - def transform(self, vector): + @overload + def transform(self, vector: "VectorLike") -> Vector: + ... + + @overload + def transform(self, vector: RDD["VectorLike"]) -> RDD[Vector]: + ... + + def transform( + self, vector: Union["VectorLike", RDD["VectorLike"]] + ) -> Union[Vector, RDD[Vector]]: """ Applies transformation on a vector. @@ -379,12 +440,12 @@ class ChiSqSelector: def __init__( self, - numTopFeatures=50, - selectorType="numTopFeatures", - percentile=0.1, - fpr=0.05, - fdr=0.05, - fwe=0.05, + numTopFeatures: int = 50, + selectorType: str = "numTopFeatures", + percentile: float = 0.1, + fpr: float = 0.05, + fdr: float = 0.05, + fwe: float = 0.05, ): self.numTopFeatures = numTopFeatures self.selectorType = selectorType @@ -394,7 +455,7 @@ def __init__( self.fwe = fwe @since("2.1.0") - def setNumTopFeatures(self, numTopFeatures): + def setNumTopFeatures(self, numTopFeatures: int) -> "ChiSqSelector": """ set numTopFeature for feature selection by number of top features. Only applicable when selectorType = "numTopFeatures". @@ -403,7 +464,7 @@ def setNumTopFeatures(self, numTopFeatures): return self @since("2.1.0") - def setPercentile(self, percentile): + def setPercentile(self, percentile: float) -> "ChiSqSelector": """ set percentile [0.0, 1.0] for feature selection by percentile. Only applicable when selectorType = "percentile". 
@@ -412,7 +473,7 @@ def setPercentile(self, percentile): return self @since("2.1.0") - def setFpr(self, fpr): + def setFpr(self, fpr: float) -> "ChiSqSelector": """ set FPR [0.0, 1.0] for feature selection by FPR. Only applicable when selectorType = "fpr". @@ -421,7 +482,7 @@ def setFpr(self, fpr): return self @since("2.2.0") - def setFdr(self, fdr): + def setFdr(self, fdr: float) -> "ChiSqSelector": """ set FDR [0.0, 1.0] for feature selection by FDR. Only applicable when selectorType = "fdr". @@ -430,7 +491,7 @@ def setFdr(self, fdr): return self @since("2.2.0") - def setFwe(self, fwe): + def setFwe(self, fwe: float) -> "ChiSqSelector": """ set FWE [0.0, 1.0] for feature selection by FWE. Only applicable when selectorType = "fwe". @@ -439,7 +500,7 @@ def setFwe(self, fwe): return self @since("2.1.0") - def setSelectorType(self, selectorType): + def setSelectorType(self, selectorType: str) -> "ChiSqSelector": """ set the selector type of the ChisqSelector. Supported options: "numTopFeatures" (default), "percentile", "fpr", "fdr", "fwe". @@ -447,7 +508,7 @@ def setSelectorType(self, selectorType): self.selectorType = str(selectorType) return self - def fit(self, data): + def fit(self, data: RDD[LabeledPoint]) -> "ChiSqSelectorModel": """ Returns a ChiSquared feature selector. @@ -500,7 +561,7 @@ class PCA: -4.013... """ - def __init__(self, k): + def __init__(self, k: int): """ Parameters ---------- @@ -509,7 +570,7 @@ def __init__(self, k): """ self.k = int(k) - def fit(self, data): + def fit(self, data: RDD["VectorLike"]) -> PCAModel: """ Computes a [[PCAModel]] that contains the principal components of the input vectors. @@ -548,12 +609,12 @@ class HashingTF: SparseVector(100, {...}) """ - def __init__(self, numFeatures=1 << 20): + def __init__(self, numFeatures: int = 1 << 20): self.numFeatures = numFeatures self.binary = False @since("2.0.0") - def setBinary(self, value): + def setBinary(self, value: bool) -> "HashingTF": """ If True, term frequency vector will be binary such that non-zero term counts will be set to 1 @@ -563,12 +624,22 @@ def setBinary(self, value): return self @since("1.2.0") - def indexOf(self, term): + def indexOf(self, term: Hashable) -> int: """Returns the index of the input term.""" return hash(term) % self.numFeatures + @overload + def transform(self, document: Iterable[Hashable]) -> Vector: + ... + + @overload + def transform(self, document: RDD[Iterable[Hashable]]) -> RDD[Vector]: + ... + @since("1.2.0") - def transform(self, document): + def transform( + self, document: Union[Iterable[Hashable], RDD[Iterable[Hashable]]] + ) -> Union[Vector, RDD[Vector]]: """ Transforms the input document (list of terms) to term frequency vectors, or transform the RDD of document to RDD of term @@ -577,7 +648,7 @@ def transform(self, document): if isinstance(document, RDD): return document.map(self.transform) - freq = {} + freq: Dict[int, float] = {} for term in document: i = self.indexOf(term) freq[i] = 1.0 if self.binary else freq.get(i, 0) + 1.0 @@ -591,7 +662,15 @@ class IDFModel(JavaVectorTransformer): .. versionadded:: 1.2.0 """ - def transform(self, x): + @overload + def transform(self, x: "VectorLike") -> Vector: + ... + + @overload + def transform(self, x: RDD["VectorLike"]) -> RDD[Vector]: + ... + + def transform(self, x: Union["VectorLike", RDD["VectorLike"]]) -> Union[Vector, RDD[Vector]]: """ Transforms term frequency (TF) vectors to TF-IDF vectors. 
@@ -621,21 +700,21 @@ def transform(self, x): return JavaVectorTransformer.transform(self, x) @since("1.4.0") - def idf(self): + def idf(self) -> Vector: """ Returns the current IDF vector. """ return self.call("idf") @since("3.0.0") - def docFreq(self): + def docFreq(self) -> List[int]: """ Returns the document frequency. """ return self.call("docFreq") @since("3.0.0") - def numDocs(self): + def numDocs(self) -> int: """ Returns number of documents evaluated to compute idf """ @@ -684,10 +763,10 @@ class IDF: SparseVector(4, {1: 0.0, 3: 0.5754}) """ - def __init__(self, minDocFreq=0): + def __init__(self, minDocFreq: int = 0): self.minDocFreq = minDocFreq - def fit(self, dataset): + def fit(self, dataset: RDD["VectorLike"]) -> IDFModel: """ Computes the inverse document frequency. @@ -704,12 +783,12 @@ def fit(self, dataset): return IDFModel(jmodel) -class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): +class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader["Word2VecModel"]): """ class for Word2Vec model """ - def transform(self, word): + def transform(self, word: str) -> Vector: # type: ignore[override] """ Transforms a word to its vector representation @@ -734,7 +813,7 @@ def transform(self, word): except Py4JJavaError: raise ValueError("%s not found" % word) - def findSynonyms(self, word, num): + def findSynonyms(self, word: Union[str, "VectorLike"], num: int) -> Iterable[Tuple[str, float]]: """ Find synonyms of a word @@ -763,7 +842,7 @@ def findSynonyms(self, word, num): return zip(words, similarity) @since("1.4.0") - def getVectors(self): + def getVectors(self) -> "JavaMap": """ Returns a map of words to their vector representations. """ @@ -771,10 +850,12 @@ def getVectors(self): @classmethod @since("1.5.0") - def load(cls, sc, path): + def load(cls, sc: SparkContext, path: str) -> "Word2VecModel": """ Load a model from the given path. """ + assert sc._jvm is not None + jmodel = sc._jvm.org.apache.spark.mllib.feature.Word2VecModel.load(sc._jsc.sc(), path) model = sc._jvm.org.apache.spark.mllib.api.python.Word2VecModelWrapper(jmodel) return Word2VecModel(model) @@ -837,7 +918,7 @@ class Word2Vec: ... pass """ - def __init__(self): + def __init__(self) -> None: """ Construct Word2Vec instance """ @@ -845,12 +926,12 @@ def __init__(self): self.learningRate = 0.025 self.numPartitions = 1 self.numIterations = 1 - self.seed = None + self.seed: Optional[int] = None self.minCount = 5 self.windowSize = 5 @since("1.2.0") - def setVectorSize(self, vectorSize): + def setVectorSize(self, vectorSize: int) -> "Word2Vec": """ Sets vector size (default: 100). """ @@ -858,7 +939,7 @@ def setVectorSize(self, vectorSize): return self @since("1.2.0") - def setLearningRate(self, learningRate): + def setLearningRate(self, learningRate: float) -> "Word2Vec": """ Sets initial learning rate (default: 0.025). """ @@ -866,7 +947,7 @@ def setLearningRate(self, learningRate): return self @since("1.2.0") - def setNumPartitions(self, numPartitions): + def setNumPartitions(self, numPartitions: int) -> "Word2Vec": """ Sets number of partitions (default: 1). Use a small number for accuracy. @@ -875,7 +956,7 @@ def setNumPartitions(self, numPartitions): return self @since("1.2.0") - def setNumIterations(self, numIterations): + def setNumIterations(self, numIterations: int) -> "Word2Vec": """ Sets number of iterations (default: 1), which should be smaller than or equal to number of partitions. 
@@ -884,7 +965,7 @@ def setNumIterations(self, numIterations): return self @since("1.2.0") - def setSeed(self, seed): + def setSeed(self, seed: int) -> "Word2Vec": """ Sets random seed. """ @@ -892,7 +973,7 @@ def setSeed(self, seed): return self @since("1.4.0") - def setMinCount(self, minCount): + def setMinCount(self, minCount: int) -> "Word2Vec": """ Sets minCount, the minimum number of times a token must appear to be included in the word2vec model's vocabulary (default: 5). @@ -901,14 +982,14 @@ def setMinCount(self, minCount): return self @since("2.0.0") - def setWindowSize(self, windowSize): + def setWindowSize(self, windowSize: int) -> "Word2Vec": """ Sets window size (default: 5). """ self.windowSize = windowSize return self - def fit(self, data): + def fit(self, data: RDD[List[str]]) -> "Word2VecModel": """ Computes the vector representation of each word in vocabulary. @@ -959,13 +1040,24 @@ class ElementwiseProduct(VectorTransformer): [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])] """ - def __init__(self, scalingVector): + def __init__(self, scalingVector: Vector) -> None: self.scalingVector = _convert_to_vector(scalingVector) - @since("1.5.0") - def transform(self, vector): + @overload + def transform(self, vector: "VectorLike") -> Vector: + ... + + @overload + def transform(self, vector: RDD["VectorLike"]) -> RDD[Vector]: + ... + + def transform( + self, vector: Union["VectorLike", RDD["VectorLike"]] + ) -> Union[Vector, RDD[Vector]]: """ Computes the Hadamard product of the vector. + + .. versionadded:: 1.5.0 """ if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) @@ -975,7 +1067,7 @@ def transform(self, vector): return callMLlibFunc("elementwiseProductVector", self.scalingVector, vector) -def _test(): +def _test() -> None: import doctest from pyspark.sql import SparkSession diff --git a/python/pyspark/mllib/feature.pyi b/python/pyspark/mllib/feature.pyi deleted file mode 100644 index e7ab7fc81a8ff..0000000000000 --- a/python/pyspark/mllib/feature.pyi +++ /dev/null @@ -1,169 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import overload -from typing import Iterable, Hashable, List, Tuple, Union - -from pyspark.mllib._typing import VectorLike -from pyspark.context import SparkContext -from pyspark.rdd import RDD -from pyspark.mllib.common import JavaModelWrapper -from pyspark.mllib.linalg import Vector -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import JavaLoader, JavaSaveable - -from py4j.java_collections import JavaMap # type: ignore[import] - -class VectorTransformer: - @overload - def transform(self, vector: VectorLike) -> Vector: ... - @overload - def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... 
- -class Normalizer(VectorTransformer): - p: float - def __init__(self, p: float = ...) -> None: ... - @overload - def transform(self, vector: VectorLike) -> Vector: ... - @overload - def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... - -class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): - @overload - def transform(self, vector: VectorLike) -> Vector: ... - @overload - def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... - -class StandardScalerModel(JavaVectorTransformer): - @overload - def transform(self, vector: VectorLike) -> Vector: ... - @overload - def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... - def setWithMean(self, withMean: bool) -> StandardScalerModel: ... - def setWithStd(self, withStd: bool) -> StandardScalerModel: ... - @property - def withStd(self) -> bool: ... - @property - def withMean(self) -> bool: ... - @property - def std(self) -> Vector: ... - @property - def mean(self) -> Vector: ... - -class StandardScaler: - withMean: bool - withStd: bool - def __init__(self, withMean: bool = ..., withStd: bool = ...) -> None: ... - def fit(self, dataset: RDD[VectorLike]) -> StandardScalerModel: ... - -class ChiSqSelectorModel(JavaVectorTransformer): - @overload - def transform(self, vector: VectorLike) -> Vector: ... - @overload - def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... - -class ChiSqSelector: - numTopFeatures: int - selectorType: str - percentile: float - fpr: float - fdr: float - fwe: float - def __init__( - self, - numTopFeatures: int = ..., - selectorType: str = ..., - percentile: float = ..., - fpr: float = ..., - fdr: float = ..., - fwe: float = ..., - ) -> None: ... - def setNumTopFeatures(self, numTopFeatures: int) -> ChiSqSelector: ... - def setPercentile(self, percentile: float) -> ChiSqSelector: ... - def setFpr(self, fpr: float) -> ChiSqSelector: ... - def setFdr(self, fdr: float) -> ChiSqSelector: ... - def setFwe(self, fwe: float) -> ChiSqSelector: ... - def setSelectorType(self, selectorType: str) -> ChiSqSelector: ... - def fit(self, data: RDD[LabeledPoint]) -> ChiSqSelectorModel: ... - -class PCAModel(JavaVectorTransformer): ... - -class PCA: - k: int - def __init__(self, k: int) -> None: ... - def fit(self, data: RDD[VectorLike]) -> PCAModel: ... - -class HashingTF: - numFeatures: int - binary: bool - def __init__(self, numFeatures: int = ...) -> None: ... - def setBinary(self, value: bool) -> HashingTF: ... - def indexOf(self, term: Hashable) -> int: ... - @overload - def transform(self, document: Iterable[Hashable]) -> Vector: ... - @overload - def transform(self, document: RDD[Iterable[Hashable]]) -> RDD[Vector]: ... - -class IDFModel(JavaVectorTransformer): - @overload - def transform(self, x: VectorLike) -> Vector: ... - @overload - def transform(self, x: RDD[VectorLike]) -> RDD[Vector]: ... - def idf(self) -> Vector: ... - def docFreq(self) -> List[int]: ... - def numDocs(self) -> int: ... - -class IDF: - minDocFreq: int - def __init__(self, minDocFreq: int = ...) -> None: ... - def fit(self, dataset: RDD[VectorLike]) -> IDFModel: ... - -class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader[Word2VecModel]): - def transform(self, word: str) -> Vector: ... # type: ignore - def findSynonyms( - self, word: Union[str, VectorLike], num: int - ) -> Iterable[Tuple[str, float]]: ... - def getVectors(self) -> JavaMap: ... - @classmethod - def load(cls, sc: SparkContext, path: str) -> Word2VecModel: ... 
- -class Word2Vec: - vectorSize: int - learningRate: float - numPartitions: int - numIterations: int - seed: int - minCount: int - windowSize: int - def __init__(self) -> None: ... - def setVectorSize(self, vectorSize: int) -> Word2Vec: ... - def setLearningRate(self, learningRate: float) -> Word2Vec: ... - def setNumPartitions(self, numPartitions: int) -> Word2Vec: ... - def setNumIterations(self, numIterations: int) -> Word2Vec: ... - def setSeed(self, seed: int) -> Word2Vec: ... - def setMinCount(self, minCount: int) -> Word2Vec: ... - def setWindowSize(self, windowSize: int) -> Word2Vec: ... - def fit(self, data: RDD[List[str]]) -> Word2VecModel: ... - -class ElementwiseProduct(VectorTransformer): - scalingVector: Vector - def __init__(self, scalingVector: Vector) -> None: ... - @overload - def transform(self, vector: VectorLike) -> Vector: ... - @overload - def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 7d1818d3f2a36..20566f569bdd8 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -42,6 +42,33 @@ BooleanType, ) +from typing import ( + Any, + Callable, + cast, + Dict, + Generic, + Iterable, + List, + Optional, + overload, + Sequence, + Tuple, + Type, + TypeVar, + TYPE_CHECKING, + Union, +) + +if TYPE_CHECKING: + from pyspark.mllib._typing import VectorLike, NormType + from scipy.sparse import spmatrix + from numpy.typing import ArrayLike + + +QT = TypeVar("QT") +RT = TypeVar("RT") + __all__ = [ "Vector", @@ -68,23 +95,23 @@ _have_scipy = False -def _convert_to_vector(d): +def _convert_to_vector(d: Union["VectorLike", "spmatrix", range]) -> "Vector": if isinstance(d, Vector): return d elif type(d) in (array.array, np.array, np.ndarray, list, tuple, range): return DenseVector(d) elif _have_scipy and scipy.sparse.issparse(d): - assert d.shape[1] == 1, "Expected column vector" + assert cast("spmatrix", d).shape[1] == 1, "Expected column vector" # Make sure the converted csc_matrix has sorted indices. - csc = d.tocsc() + csc = cast("spmatrix", d).tocsc() if not csc.has_sorted_indices: csc.sort_indices() - return SparseVector(d.shape[0], csc.indices, csc.data) + return SparseVector(cast("spmatrix", d).shape[0], csc.indices, csc.data) else: raise TypeError("Cannot convert type %s into Vector" % type(d)) -def _vector_size(v): +def _vector_size(v: Union["VectorLike", "spmatrix", range]) -> int: """ Returns the size of the vector. @@ -115,24 +142,24 @@ def _vector_size(v): else: raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape)) elif _have_scipy and scipy.sparse.issparse(v): - assert v.shape[1] == 1, "Expected column vector" - return v.shape[0] + assert cast("spmatrix", v).shape[1] == 1, "Expected column vector" + return cast("spmatrix", v).shape[0] else: raise TypeError("Cannot treat type %s as a vector" % type(v)) -def _format_float(f, digits=4): +def _format_float(f: float, digits: int = 4) -> str: s = str(round(f, digits)) if "." 
in s: s = s[: s.index(".") + 1 + digits] return s -def _format_float_list(xs): +def _format_float_list(xs: Iterable[float]) -> List[str]: return [_format_float(x) for x in xs] -def _double_to_long_bits(value): +def _double_to_long_bits(value: float) -> int: if np.isnan(value): value = float("nan") # pack double into 64 bits, then unpack as long int @@ -145,7 +172,7 @@ class VectorUDT(UserDefinedType): """ @classmethod - def sqlType(cls): + def sqlType(cls) -> StructType: return StructType( [ StructField("type", ByteType(), False), @@ -156,37 +183,41 @@ def sqlType(cls): ) @classmethod - def module(cls): + def module(cls) -> str: return "pyspark.mllib.linalg" @classmethod - def scalaUDT(cls): + def scalaUDT(cls) -> str: return "org.apache.spark.mllib.linalg.VectorUDT" - def serialize(self, obj): + def serialize( + self, obj: "Vector" + ) -> Tuple[int, Optional[int], Optional[List[int]], List[float]]: if isinstance(obj, SparseVector): indices = [int(i) for i in obj.indices] values = [float(v) for v in obj.values] return (0, obj.size, indices, values) elif isinstance(obj, DenseVector): - values = [float(v) for v in obj] + values = [float(v) for v in obj] # type: ignore[attr-defined] return (1, None, None, values) else: raise TypeError("cannot serialize %r of type %r" % (obj, type(obj))) - def deserialize(self, datum): + def deserialize( + self, datum: Tuple[int, Optional[int], Optional[List[int]], List[float]] + ) -> "Vector": assert ( len(datum) == 4 ), "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) tpe = datum[0] if tpe == 0: - return SparseVector(datum[1], datum[2], datum[3]) + return SparseVector(cast(int, datum[1]), cast(List[int], datum[2]), datum[3]) elif tpe == 1: return DenseVector(datum[3]) else: raise ValueError("do not recognize type %r" % tpe) - def simpleString(self): + def simpleString(self) -> str: return "vector" @@ -196,7 +227,7 @@ class MatrixUDT(UserDefinedType): """ @classmethod - def sqlType(cls): + def sqlType(cls) -> StructType: return StructType( [ StructField("type", ByteType(), False), @@ -210,14 +241,16 @@ def sqlType(cls): ) @classmethod - def module(cls): + def module(cls) -> str: return "pyspark.mllib.linalg" @classmethod - def scalaUDT(cls): + def scalaUDT(cls) -> str: return "org.apache.spark.mllib.linalg.MatrixUDT" - def serialize(self, obj): + def serialize( + self, obj: "Matrix" + ) -> Tuple[int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool]: if isinstance(obj, SparseMatrix): colPtrs = [int(i) for i in obj.colPtrs] rowIndices = [int(i) for i in obj.rowIndices] @@ -237,19 +270,29 @@ def serialize(self, obj): else: raise TypeError("cannot serialize type %r" % (type(obj))) - def deserialize(self, datum): + def deserialize( + self, + datum: Tuple[int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool], + ) -> "Matrix": assert ( len(datum) == 7 ), "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum) tpe = datum[0] if tpe == 0: - return SparseMatrix(*datum[1:]) + return SparseMatrix( + datum[1], + datum[2], + cast(List[int], datum[3]), + cast(List[int], datum[4]), + datum[5], + datum[6], + ) elif tpe == 1: return DenseMatrix(datum[1], datum[2], datum[5], datum[6]) else: raise ValueError("do not recognize type %r" % tpe) - def simpleString(self): + def simpleString(self) -> str: return "matrix" @@ -261,7 +304,7 @@ class Vector: Abstract class for DenseVector and SparseVector """ - def toArray(self): + def toArray(self) -> np.ndarray: """ Convert the vector into an 
numpy.ndarray @@ -271,7 +314,7 @@ def toArray(self): """ raise NotImplementedError - def asML(self): + def asML(self) -> newlinalg.Vector: """ Convert this vector to the new mllib-local representation. This does NOT copy the data; it copies references. @@ -282,6 +325,9 @@ def asML(self): """ raise NotImplementedError + def __len__(self) -> int: + raise NotImplementedError + class DenseVector(Vector): """ @@ -309,17 +355,18 @@ class DenseVector(Vector): DenseVector([-1.0, -2.0]) """ - def __init__(self, ar): + def __init__(self, ar: Union[bytes, np.ndarray, Iterable[float]]): + ar_: np.ndarray if isinstance(ar, bytes): - ar = np.frombuffer(ar, dtype=np.float64) + ar_ = np.frombuffer(ar, dtype=np.float64) elif not isinstance(ar, np.ndarray): - ar = np.array(ar, dtype=np.float64) - if ar.dtype != np.float64: - ar = ar.astype(np.float64) - self.array = ar + ar_ = np.array(ar, dtype=np.float64) + else: + ar_ = ar.astype(np.float64) if ar.dtype != np.float64 else ar + self.array = ar_ @staticmethod - def parse(s): + def parse(s: str) -> "DenseVector": """ Parse string representation back into the DenseVector. @@ -342,16 +389,16 @@ def parse(s): raise ValueError("Unable to parse values from %s" % s) return DenseVector(values) - def __reduce__(self): - return DenseVector, (self.array.tostring(),) + def __reduce__(self) -> Tuple[Type["DenseVector"], Tuple[bytes]]: + return DenseVector, (self.array.tobytes(),) - def numNonzeros(self): + def numNonzeros(self) -> int: """ Number of nonzero elements. This scans all active values and count non zeros """ return np.count_nonzero(self.array) - def norm(self, p): + def norm(self, p: "NormType") -> np.float64: """ Calculates the norm of a DenseVector. @@ -365,7 +412,7 @@ def norm(self, p): """ return np.linalg.norm(self.array, p) - def dot(self, other): + def dot(self, other: Iterable[float]) -> np.float64: """ Compute the dot product of two Vectors. We support (Numpy array, list, SparseVector, or SciPy sparse) @@ -399,8 +446,8 @@ def dot(self, other): assert len(self) == other.shape[0], "dimension mismatch" return np.dot(self.array, other) elif _have_scipy and scipy.sparse.issparse(other): - assert len(self) == other.shape[0], "dimension mismatch" - return other.transpose().dot(self.toArray()) + assert len(self) == cast("spmatrix", other).shape[0], "dimension mismatch" + return cast("spmatrix", other).transpose().dot(self.toArray()) else: assert len(self) == _vector_size(other), "dimension mismatch" if isinstance(other, SparseVector): @@ -408,9 +455,9 @@ def dot(self, other): elif isinstance(other, Vector): return np.dot(self.toArray(), other.toArray()) else: - return np.dot(self.toArray(), other) + return np.dot(self.toArray(), cast("ArrayLike", other)) - def squared_distance(self, other): + def squared_distance(self, other: Iterable[float]) -> np.float64: """ Squared distance of two Vectors. 
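The DenseVector hunks above mostly tighten return annotations (norm, dot and squared_distance now advertise numpy.float64); a small, hedged usage sketch, not part of the patch:

import numpy as np
from pyspark.mllib.linalg import DenseVector

v = DenseVector([1.0, 2.0, 3.0])
print(v.norm(2))                             # Euclidean norm, a numpy.float64
print(v.dot(np.array([1.0, 0.0, 1.0])))      # 4.0
print(v.numNonzeros())                       # 3
print(v.squared_distance([0.0, 0.0, 0.0]))   # 14.0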
@@ -441,22 +488,22 @@ def squared_distance(self, other): if isinstance(other, SparseVector): return other.squared_distance(self) elif _have_scipy and scipy.sparse.issparse(other): - return _convert_to_vector(other).squared_distance(self) + return _convert_to_vector(other).squared_distance(self) # type: ignore[attr-defined] if isinstance(other, Vector): other = other.toArray() elif not isinstance(other, np.ndarray): other = np.array(other) - diff = self.toArray() - other + diff: np.ndarray = self.toArray() - other return np.dot(diff, diff) - def toArray(self): + def toArray(self) -> np.ndarray: """ Returns an numpy.ndarray """ return self.array - def asML(self): + def asML(self) -> newlinalg.DenseVector: """ Convert this vector to the new mllib-local representation. This does NOT copy the data; it copies references. @@ -470,25 +517,33 @@ def asML(self): return newlinalg.DenseVector(self.array) @property - def values(self): + def values(self) -> np.ndarray: """ Returns a list of values """ return self.array - def __getitem__(self, item): + @overload + def __getitem__(self, item: int) -> np.float64: + ... + + @overload + def __getitem__(self, item: slice) -> np.ndarray: + ... + + def __getitem__(self, item: Union[int, slice]) -> Union[np.float64, np.ndarray]: return self.array[item] - def __len__(self): + def __len__(self) -> int: return len(self.array) - def __str__(self): + def __str__(self) -> str: return "[" + ",".join([str(v) for v in self.array]) + "]" - def __repr__(self): + def __repr__(self) -> str: return "DenseVector([%s])" % (", ".join(_format_float(i) for i in self.array)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, DenseVector): return np.array_equal(self.array, other.array) elif isinstance(other, SparseVector): @@ -497,10 +552,10 @@ def __eq__(self, other): return Vectors._equals(list(range(len(self))), self.array, other.indices, other.values) return False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: size = len(self) result = 31 + size nnz = 0 @@ -514,14 +569,14 @@ def __hash__(self): i += 1 return result - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: return getattr(self.array, item) - def __neg__(self): + def __neg__(self) -> "DenseVector": return DenseVector(-self.array) - def _delegate(op): - def func(self, other): + def _delegate(op: str) -> Callable[["DenseVector", Any], "DenseVector"]: # type: ignore[misc] + def func(self: "DenseVector", other: Any) -> "DenseVector": if isinstance(other, DenseVector): other = other.array return DenseVector(getattr(self.array, op)(other)) @@ -548,7 +603,33 @@ class SparseVector(Vector): alternatively pass SciPy's {scipy.sparse} data types. """ - def __init__(self, size, *args): + @overload + def __init__(self, size: int, __indices: bytes, __values: bytes): + ... + + @overload + def __init__(self, size: int, *args: Tuple[int, float]): + ... + + @overload + def __init__(self, size: int, __indices: Iterable[int], __values: Iterable[float]): + ... + + @overload + def __init__(self, size: int, __pairs: Iterable[Tuple[int, float]]): + ... + + @overload + def __init__(self, size: int, __map: Dict[int, float]): + ... 
+ + def __init__( + self, + size: int, + *args: Union[ + bytes, Tuple[int, float], Iterable[float], Iterable[Tuple[int, float]], Dict[int, float] + ], + ): """ Create a sparse vector, using either a dictionary, a list of (index, value) pairs, or two separate arrays of indices and @@ -580,7 +661,7 @@ def __init__(self, size, *args): pairs = args[0] if type(pairs) == dict: pairs = pairs.items() - pairs = sorted(pairs) + pairs = cast(Iterable[Tuple[int, float]], sorted(pairs)) self.indices = np.array([p[0] for p in pairs], dtype=np.int32) """ A list of indices corresponding to active entries. """ self.values = np.array([p[1] for p in pairs], dtype=np.float64) @@ -606,13 +687,13 @@ def __init__(self, size, *args): % (self.indices[i], self.indices[i + 1]) ) - def numNonzeros(self): + def numNonzeros(self) -> int: """ Number of nonzero elements. This scans all active values and count non zeros. """ return np.count_nonzero(self.values) - def norm(self, p): + def norm(self, p: "NormType") -> np.float64: """ Calculates the norm of a SparseVector. @@ -626,11 +707,18 @@ def norm(self, p): """ return np.linalg.norm(self.values, p) - def __reduce__(self): - return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring())) + def __reduce__(self) -> Tuple[Type["SparseVector"], Tuple[int, bytes, bytes]]: + return ( + SparseVector, + ( + self.size, + self.indices.tobytes(), + self.values.tobytes(), + ), + ) @staticmethod - def parse(s): + def parse(s: str) -> "SparseVector": """ Parse string representation back into the SparseVector. @@ -649,7 +737,7 @@ def parse(s): size = s[: s.find(",")] try: - size = int(size) + size = int(size) # type: ignore[assignment] except ValueError: raise ValueError("Cannot parse size %s." % size) @@ -678,9 +766,9 @@ def parse(s): values = [float(val) for val in val_list if val] except ValueError: raise ValueError("Unable to parse values from %s." % s) - return SparseVector(size, indices, values) + return SparseVector(cast(int, size), indices, values) - def dot(self, other): + def dot(self, other: Iterable[float]) -> np.float64: """ Dot product with a SparseVector or 1- or 2-dimensional Numpy array. @@ -730,15 +818,15 @@ def dot(self, other): self_cmind = np.in1d(self.indices, other.indices, assume_unique=True) self_values = self.values[self_cmind] if self_values.size == 0: - return 0.0 + return np.float64(0.0) else: other_cmind = np.in1d(other.indices, self.indices, assume_unique=True) return np.dot(self_values, other.values[other_cmind]) else: - return self.dot(_convert_to_vector(other)) + return self.dot(_convert_to_vector(other)) # type: ignore[arg-type] - def squared_distance(self, other): + def squared_distance(self, other: Iterable[float]) -> np.float64: """ Squared distance from a SparseVector or 1-dimensional NumPy array. @@ -806,9 +894,9 @@ def squared_distance(self, other): j += 1 return result else: - return self.squared_distance(_convert_to_vector(other)) + return self.squared_distance(_convert_to_vector(other)) # type: ignore[arg-type] - def toArray(self): + def toArray(self) -> np.ndarray: """ Returns a copy of this SparseVector as a 1-dimensional NumPy array. """ @@ -816,7 +904,7 @@ def toArray(self): arr[self.indices] = self.values return arr - def asML(self): + def asML(self) -> newlinalg.SparseVector: """ Convert this vector to the new mllib-local representation. This does NOT copy the data; it copies references. 
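The overloads added for SparseVector.__init__ above only describe constructor forms the class already accepts at runtime; a hedged sketch of the three main forms, for illustration only:

from pyspark.mllib.linalg import SparseVector

a = SparseVector(4, [1, 3], [1.0, 5.5])    # separate index and value sequences
b = SparseVector(4, [(1, 1.0), (3, 5.5)])  # (index, value) pairs
c = SparseVector(4, {1: 1.0, 3: 5.5})      # dict of index -> value
assert a == b == c
print(a.dot(b), a.squared_distance(c), a.toArray())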
@@ -829,15 +917,15 @@ def asML(self): """ return newlinalg.SparseVector(self.size, self.indices, self.values) - def __len__(self): + def __len__(self) -> int: return self.size - def __str__(self): + def __str__(self) -> str: inds = "[" + ",".join([str(i) for i in self.indices]) + "]" vals = "[" + ",".join([str(v) for v in self.values]) + "]" return "(" + ",".join((str(self.size), inds, vals)) + ")" - def __repr__(self): + def __repr__(self) -> str: inds = self.indices vals = self.values entries = ", ".join( @@ -845,7 +933,7 @@ def __repr__(self): ) return "SparseVector({0}, {{{1}}})".format(self.size, entries) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, SparseVector): return ( other.size == self.size @@ -858,7 +946,7 @@ def __eq__(self, other): return Vectors._equals(self.indices, self.values, list(range(len(other))), other.array) return False - def __getitem__(self, index): + def __getitem__(self, index: int) -> np.float64: inds = self.indices vals = self.values if not isinstance(index, int): @@ -870,18 +958,18 @@ def __getitem__(self, index): index += self.size if (inds.size == 0) or (index > inds.item(-1)): - return 0.0 + return np.float64(0.0) insert_index = np.searchsorted(inds, index) row_ind = inds[insert_index] if row_ind == index: return vals[insert_index] - return 0.0 + return np.float64(0.0) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: result = 31 + self.size nnz = 0 i = 0 @@ -909,7 +997,37 @@ class Vectors: """ @staticmethod - def sparse(size, *args): + @overload + def sparse(size: int, __indices: bytes, __values: bytes) -> SparseVector: + ... + + @staticmethod + @overload + def sparse(size: int, *args: Tuple[int, float]) -> SparseVector: + ... + + @staticmethod + @overload + def sparse(size: int, __indices: Iterable[int], __values: Iterable[float]) -> SparseVector: + ... + + @staticmethod + @overload + def sparse(size: int, __pairs: Iterable[Tuple[int, float]]) -> SparseVector: + ... + + @staticmethod + @overload + def sparse(size: int, __map: Dict[int, float]) -> SparseVector: + ... + + @staticmethod + def sparse( + size: int, + *args: Union[ + bytes, Tuple[int, float], Iterable[float], Iterable[Tuple[int, float]], Dict[int, float] + ], + ) -> SparseVector: """ Create a sparse vector, using either a dictionary, a list of (index, value) pairs, or two separate arrays of indices and @@ -932,10 +1050,25 @@ def sparse(size, *args): >>> Vectors.sparse(4, [1, 3], [1.0, 5.5]) SparseVector(4, {1: 1.0, 3: 5.5}) """ - return SparseVector(size, *args) + return SparseVector(size, *args) # type: ignore[arg-type] + @overload @staticmethod - def dense(*elements): + def dense(*elements: float) -> DenseVector: + ... + + @overload + @staticmethod + def dense(__arr: bytes) -> DenseVector: + ... + + @overload + @staticmethod + def dense(__arr: Iterable[float]) -> DenseVector: + ... + + @staticmethod + def dense(*elements: Union[float, bytes, np.ndarray, Iterable[float]]) -> DenseVector: """ Create a dense vector of 64-bit floats from a Python list or numbers. @@ -948,11 +1081,11 @@ def dense(*elements): """ if len(elements) == 1 and not isinstance(elements[0], (float, int)): # it's list, numpy.array or other iterable object. 
- elements = elements[0] - return DenseVector(elements) + elements = elements[0] # type: ignore[assignment] + return DenseVector(cast(Iterable[float], elements)) @staticmethod - def fromML(vec): + def fromML(vec: newlinalg.DenseVector) -> DenseVector: """ Convert a vector from the new mllib-local representation. This does NOT copy the data; it copies references. @@ -975,7 +1108,7 @@ def fromML(vec): raise TypeError("Unsupported vector type %s" % type(vec)) @staticmethod - def stringify(vector): + def stringify(vector: Vector) -> str: """ Converts a vector into a string, which can be recognized by Vectors.parse(). @@ -990,7 +1123,7 @@ def stringify(vector): return str(vector) @staticmethod - def squared_distance(v1, v2): + def squared_distance(v1: Vector, v2: Vector) -> np.float64: """ Squared distance between two vectors. a and b can be of type SparseVector, DenseVector, np.ndarray @@ -1004,17 +1137,17 @@ def squared_distance(v1, v2): 51.0 """ v1, v2 = _convert_to_vector(v1), _convert_to_vector(v2) - return v1.squared_distance(v2) + return v1.squared_distance(v2) # type: ignore[attr-defined] @staticmethod - def norm(vector, p): + def norm(vector: Vector, p: "NormType") -> np.float64: """ Find norm of the given vector. """ - return _convert_to_vector(vector).norm(p) + return _convert_to_vector(vector).norm(p) # type: ignore[attr-defined] @staticmethod - def parse(s): + def parse(s: str) -> Vector: """Parse a string representation back into the Vector. Examples @@ -1032,11 +1165,16 @@ def parse(s): raise ValueError("Cannot find tokens '[' or '(' from the input string.") @staticmethod - def zeros(size): + def zeros(size: int) -> DenseVector: return DenseVector(np.zeros(size)) @staticmethod - def _equals(v1_indices, v1_values, v2_indices, v2_values): + def _equals( + v1_indices: Union[Sequence[int], np.ndarray], + v1_values: Union[Sequence[float], np.ndarray], + v2_indices: Union[Sequence[int], np.ndarray], + v2_values: Union[Sequence[float], np.ndarray], + ) -> bool: """ Check equality between sparse/dense vectors, v1_indices and v2_indices assume to be strictly increasing. @@ -1069,18 +1207,18 @@ class Matrix: Represents a local matrix. """ - def __init__(self, numRows, numCols, isTransposed=False): + def __init__(self, numRows: int, numCols: int, isTransposed: bool = False) -> None: self.numRows = numRows self.numCols = numCols self.isTransposed = isTransposed - def toArray(self): + def toArray(self) -> np.ndarray: """ Returns its elements in a NumPy ndarray. """ raise NotImplementedError - def asML(self): + def asML(self) -> newlinalg.Matrix: """ Convert this matrix to the new mllib-local representation. This does NOT copy the data; it copies references. @@ -1088,7 +1226,7 @@ def asML(self): raise NotImplementedError @staticmethod - def _convert_to_array(array_like, dtype): + def _convert_to_array(array_like: Union[bytes, Iterable[float]], dtype: Any) -> np.ndarray: """ Convert Matrix attributes which are array-like or buffer to array. """ @@ -1102,21 +1240,27 @@ class DenseMatrix(Matrix): Column-major dense matrix. 
""" - def __init__(self, numRows, numCols, values, isTransposed=False): + def __init__( + self, + numRows: int, + numCols: int, + values: Union[bytes, Iterable[float]], + isTransposed: bool = False, + ): Matrix.__init__(self, numRows, numCols, isTransposed) values = self._convert_to_array(values, np.float64) assert len(values) == numRows * numCols self.values = values - def __reduce__(self): + def __reduce__(self) -> Tuple[Type["DenseMatrix"], Tuple[int, int, bytes, int]]: return DenseMatrix, ( self.numRows, self.numCols, - self.values.tostring(), + self.values.tobytes(), int(self.isTransposed), ) - def __str__(self): + def __str__(self) -> str: """ Pretty printing of a DenseMatrix @@ -1139,7 +1283,7 @@ def __str__(self): x = "\n".join([(" " * 6 + line) for line in array_lines[1:]]) return array_lines[0].replace("array", "DenseMatrix") + "\n" + x - def __repr__(self): + def __repr__(self) -> str: """ Representation of a DenseMatrix @@ -1158,12 +1302,11 @@ def __repr__(self): _format_float_list(self.values[:8]) + ["..."] + _format_float_list(self.values[-8:]) ) - entries = ", ".join(entries) return "DenseMatrix({0}, {1}, [{2}], {3})".format( - self.numRows, self.numCols, entries, self.isTransposed + self.numRows, self.numCols, ", ".join(entries), self.isTransposed ) - def toArray(self): + def toArray(self) -> np.ndarray: """ Return an numpy.ndarray @@ -1179,7 +1322,7 @@ def toArray(self): else: return self.values.reshape((self.numRows, self.numCols), order="F") - def toSparse(self): + def toSparse(self) -> "SparseMatrix": """Convert to SparseMatrix""" if self.isTransposed: values = np.ravel(self.toArray(), order="F") @@ -1193,7 +1336,7 @@ def toSparse(self): return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) - def asML(self): + def asML(self) -> newlinalg.DenseMatrix: """ Convert this matrix to the new mllib-local representation. This does NOT copy the data; it copies references. 
@@ -1206,7 +1349,7 @@ def asML(self): """ return newlinalg.DenseMatrix(self.numRows, self.numCols, self.values, self.isTransposed) - def __getitem__(self, indices): + def __getitem__(self, indices: Tuple[int, int]) -> np.float64: i, j = indices if i < 0 or i >= self.numRows: raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) @@ -1218,21 +1361,29 @@ def __getitem__(self, indices): else: return self.values[i + j * self.numRows] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if self.numRows != other.numRows or self.numCols != other.numCols: return False if isinstance(other, SparseMatrix): - return np.all(self.toArray() == other.toArray()) + return np.all(self.toArray() == other.toArray()).tolist() self_values = np.ravel(self.toArray(), order="F") other_values = np.ravel(other.toArray(), order="F") - return np.all(self_values == other_values) + return np.all(self_values == other_values).tolist() class SparseMatrix(Matrix): """Sparse Matrix stored in CSC format.""" - def __init__(self, numRows, numCols, colPtrs, rowIndices, values, isTransposed=False): + def __init__( + self, + numRows: int, + numCols: int, + colPtrs: Union[bytes, Iterable[int]], + rowIndices: Union[bytes, Iterable[int]], + values: Union[bytes, Iterable[float]], + isTransposed: bool = False, + ) -> None: Matrix.__init__(self, numRows, numCols, isTransposed) self.colPtrs = self._convert_to_array(colPtrs, np.int32) self.rowIndices = self._convert_to_array(rowIndices, np.int32) @@ -1254,7 +1405,7 @@ def __init__(self, numRows, numCols, colPtrs, rowIndices, values, isTransposed=F % (self.rowIndices.size, self.values.size) ) - def __str__(self): + def __str__(self) -> str: """ Pretty printing of a SparseMatrix @@ -1300,7 +1451,7 @@ def __str__(self): spstr += "\n.." 
* 2 return spstr - def __repr__(self): + def __repr__(self) -> str: """ Representation of a SparseMatrix @@ -1325,24 +1476,26 @@ def __repr__(self): if len(self.colPtrs) > 16: colPtrs = colPtrs[:8] + ["..."] + colPtrs[-8:] - values = ", ".join(values) - rowIndices = ", ".join([str(ind) for ind in rowIndices]) - colPtrs = ", ".join([str(ptr) for ptr in colPtrs]) return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format( - self.numRows, self.numCols, colPtrs, rowIndices, values, self.isTransposed + self.numRows, + self.numCols, + ", ".join([str(ptr) for ptr in colPtrs]), + ", ".join([str(ind) for ind in rowIndices]), + ", ".join(values), + self.isTransposed, ) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type["SparseMatrix"], Tuple[int, int, bytes, bytes, bytes, int]]: return SparseMatrix, ( self.numRows, self.numCols, - self.colPtrs.tostring(), - self.rowIndices.tostring(), - self.values.tostring(), + self.colPtrs.tobytes(), + self.rowIndices.tobytes(), + self.values.tobytes(), int(self.isTransposed), ) - def __getitem__(self, indices): + def __getitem__(self, indices: Tuple[int, int]) -> np.float64: i, j = indices if i < 0 or i >= self.numRows: raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) @@ -1362,9 +1515,9 @@ def __getitem__(self, indices): if ind < colEnd and self.rowIndices[ind] == i: return self.values[ind] else: - return 0.0 + return np.float64(0.0) - def toArray(self): + def toArray(self) -> np.ndarray: """ Return an numpy.ndarray """ @@ -1378,11 +1531,11 @@ def toArray(self): A[self.rowIndices[startptr:endptr], k] = self.values[startptr:endptr] return A - def toDense(self): + def toDense(self) -> "DenseMatrix": densevals = np.ravel(self.toArray(), order="F") return DenseMatrix(self.numRows, self.numCols, densevals) - def asML(self): + def asML(self) -> newlinalg.SparseMatrix: """ Convert this matrix to the new mllib-local representation. This does NOT copy the data; it copies references. @@ -1403,27 +1556,34 @@ def asML(self): ) # TODO: More efficient implementation: - def __eq__(self, other): - return np.all(self.toArray() == other.toArray()) + def __eq__(self, other: Any) -> bool: + assert isinstance(other, Matrix) + return np.all(self.toArray() == other.toArray()).tolist() class Matrices: @staticmethod - def dense(numRows, numCols, values): + def dense(numRows: int, numCols: int, values: Union[bytes, Iterable[float]]) -> DenseMatrix: """ Create a DenseMatrix """ return DenseMatrix(numRows, numCols, values) @staticmethod - def sparse(numRows, numCols, colPtrs, rowIndices, values): + def sparse( + numRows: int, + numCols: int, + colPtrs: Union[bytes, Iterable[int]], + rowIndices: Union[bytes, Iterable[int]], + values: Union[bytes, Iterable[float]], + ) -> SparseMatrix: """ Create a SparseMatrix """ return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) @staticmethod - def fromML(mat): + def fromML(mat: newlinalg.Matrix) -> Matrix: """ Convert a matrix from the new mllib-local representation. This does NOT copy the data; it copies references. @@ -1448,34 +1608,34 @@ def fromML(mat): raise TypeError("Unsupported matrix type %s" % type(mat)) -class QRDecomposition: +class QRDecomposition(Generic[QT, RT]): """ Represents QR factors. """ - def __init__(self, Q, R): + def __init__(self, Q: QT, R: RT) -> None: self._Q = Q self._R = R - @property + @property # type: ignore[misc] @since("2.0.0") - def Q(self): + def Q(self) -> QT: """ An orthogonal matrix Q in a QR decomposition. May be null if not computed. 
""" return self._Q - @property + @property # type: ignore[misc] @since("2.0.0") - def R(self): + def R(self) -> RT: """ An upper triangular matrix R in a QR decomposition. """ return self._R -def _test(): +def _test() -> None: import doctest import numpy diff --git a/python/pyspark/mllib/linalg/__init__.pyi b/python/pyspark/mllib/linalg/__init__.pyi deleted file mode 100644 index 8988e92f5c29a..0000000000000 --- a/python/pyspark/mllib/linalg/__init__.pyi +++ /dev/null @@ -1,278 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import overload -from typing import ( - Any, - Dict, - Generic, - Iterable, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, -) -from pyspark.ml import linalg as newlinalg -from pyspark.sql.types import StructType, UserDefinedType -from numpy import float64, ndarray - -QT = TypeVar("QT") -RT = TypeVar("RT") - -class VectorUDT(UserDefinedType): - @classmethod - def sqlType(cls) -> StructType: ... - @classmethod - def module(cls) -> str: ... - @classmethod - def scalaUDT(cls) -> str: ... - def serialize( - self, obj: Vector - ) -> Tuple[int, Optional[int], Optional[List[int]], List[float]]: ... - def deserialize(self, datum: Any) -> Vector: ... - def simpleString(self) -> str: ... - -class MatrixUDT(UserDefinedType): - @classmethod - def sqlType(cls) -> StructType: ... - @classmethod - def module(cls) -> str: ... - @classmethod - def scalaUDT(cls) -> str: ... - def serialize( - self, obj: Matrix - ) -> Tuple[int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool]: ... - def deserialize(self, datum: Any) -> Matrix: ... - def simpleString(self) -> str: ... - -class Vector: - __UDT__: VectorUDT - def toArray(self) -> ndarray: ... - def asML(self) -> newlinalg.Vector: ... - -class DenseVector(Vector): - array: ndarray - @overload - def __init__(self, *elements: float) -> None: ... - @overload - def __init__(self, __arr: bytes) -> None: ... - @overload - def __init__(self, __arr: Iterable[float]) -> None: ... - @staticmethod - def parse(s: str) -> DenseVector: ... - def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ... - def numNonzeros(self) -> int: ... - def norm(self, p: Union[float, str]) -> float64: ... - def dot(self, other: Iterable[float]) -> float64: ... - def squared_distance(self, other: Iterable[float]) -> float64: ... - def toArray(self) -> ndarray: ... - def asML(self) -> newlinalg.DenseVector: ... - @property - def values(self) -> ndarray: ... - def __getitem__(self, item: int) -> float64: ... - def __len__(self) -> int: ... - def __eq__(self, other: Any) -> bool: ... - def __ne__(self, other: Any) -> bool: ... - def __hash__(self) -> int: ... - def __getattr__(self, item: str) -> Any: ... - def __neg__(self) -> DenseVector: ... 
- def __add__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __sub__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __mul__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __div__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __truediv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __mod__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __radd__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __rsub__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __rmul__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __rdiv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __rtruediv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - def __rmod__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... - -class SparseVector(Vector): - size: int - indices: ndarray - values: ndarray - @overload - def __init__(self, size: int, *args: Tuple[int, float]) -> None: ... - @overload - def __init__(self, size: int, __indices: bytes, __values: bytes) -> None: ... - @overload - def __init__(self, size: int, __indices: Iterable[int], __values: Iterable[float]) -> None: ... - @overload - def __init__(self, size: int, __pairs: Iterable[Tuple[int, float]]) -> None: ... - @overload - def __init__(self, size: int, __map: Dict[int, float]) -> None: ... - def numNonzeros(self) -> int: ... - def norm(self, p: Union[float, str]) -> float64: ... - def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ... - @staticmethod - def parse(s: str) -> SparseVector: ... - def dot(self, other: Iterable[float]) -> float64: ... - def squared_distance(self, other: Iterable[float]) -> float64: ... - def toArray(self) -> ndarray: ... - def asML(self) -> newlinalg.SparseVector: ... - def __len__(self) -> int: ... - def __eq__(self, other: Any) -> bool: ... - def __getitem__(self, index: int) -> float64: ... - def __ne__(self, other: Any) -> bool: ... - def __hash__(self) -> int: ... - -class Vectors: - @overload - @staticmethod - def sparse(size: int, *args: Tuple[int, float]) -> SparseVector: ... - @overload - @staticmethod - def sparse(size: int, __indices: bytes, __values: bytes) -> SparseVector: ... - @overload - @staticmethod - def sparse(size: int, __indices: Iterable[int], __values: Iterable[float]) -> SparseVector: ... - @overload - @staticmethod - def sparse(size: int, __pairs: Iterable[Tuple[int, float]]) -> SparseVector: ... - @overload - @staticmethod - def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ... - @overload - @staticmethod - def dense(*elements: float) -> DenseVector: ... - @overload - @staticmethod - def dense(__arr: bytes) -> DenseVector: ... - @overload - @staticmethod - def dense(__arr: Iterable[float]) -> DenseVector: ... - @staticmethod - def fromML(vec: newlinalg.DenseVector) -> DenseVector: ... - @staticmethod - def stringify(vector: Vector) -> str: ... - @staticmethod - def squared_distance(v1: Vector, v2: Vector) -> float64: ... - @staticmethod - def norm(vector: Vector, p: Union[float, str]) -> float64: ... - @staticmethod - def parse(s: str) -> Vector: ... - @staticmethod - def zeros(size: int) -> DenseVector: ... - -class Matrix: - __UDT__: MatrixUDT - numRows: int - numCols: int - isTransposed: bool - def __init__(self, numRows: int, numCols: int, isTransposed: bool = ...) -> None: ... - def toArray(self) -> ndarray: ... 
- def asML(self) -> newlinalg.Matrix: ... - -class DenseMatrix(Matrix): - values: Any - @overload - def __init__( - self, numRows: int, numCols: int, values: bytes, isTransposed: bool = ... - ) -> None: ... - @overload - def __init__( - self, - numRows: int, - numCols: int, - values: Iterable[float], - isTransposed: bool = ..., - ) -> None: ... - def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ... - def toArray(self) -> ndarray: ... - def toSparse(self) -> SparseMatrix: ... - def asML(self) -> newlinalg.DenseMatrix: ... - def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def __eq__(self, other: Any) -> bool: ... - -class SparseMatrix(Matrix): - colPtrs: ndarray - rowIndices: ndarray - values: ndarray - @overload - def __init__( - self, - numRows: int, - numCols: int, - colPtrs: bytes, - rowIndices: bytes, - values: bytes, - isTransposed: bool = ..., - ) -> None: ... - @overload - def __init__( - self, - numRows: int, - numCols: int, - colPtrs: Iterable[int], - rowIndices: Iterable[int], - values: Iterable[float], - isTransposed: bool = ..., - ) -> None: ... - def __reduce__( - self, - ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ... - def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def toArray(self) -> ndarray: ... - def toDense(self) -> DenseMatrix: ... - def asML(self) -> newlinalg.SparseMatrix: ... - def __eq__(self, other: Any) -> bool: ... - -class Matrices: - @overload - @staticmethod - def dense( - numRows: int, numCols: int, values: bytes, isTransposed: bool = ... - ) -> DenseMatrix: ... - @overload - @staticmethod - def dense( - numRows: int, numCols: int, values: Iterable[float], isTransposed: bool = ... - ) -> DenseMatrix: ... - @overload - @staticmethod - def sparse( - numRows: int, - numCols: int, - colPtrs: bytes, - rowIndices: bytes, - values: bytes, - isTransposed: bool = ..., - ) -> SparseMatrix: ... - @overload - @staticmethod - def sparse( - numRows: int, - numCols: int, - colPtrs: Iterable[int], - rowIndices: Iterable[int], - values: Iterable[float], - isTransposed: bool = ..., - ) -> SparseMatrix: ... - @staticmethod - def fromML(mat: newlinalg.Matrix) -> Matrix: ... - -class QRDecomposition(Generic[QT, RT]): - def __init__(self, Q: QT, R: RT) -> None: ... - @property - def Q(self) -> QT: ... - @property - def R(self) -> RT: ... 
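With the stub file above removed, type information for pyspark.mllib.linalg now lives inline in the .py module. A hedged sketch of the QRDecomposition[QT, RT] generic it introduces; the concrete Q/R values here are chosen purely for illustration:

from pyspark.mllib.linalg import DenseMatrix, DenseVector, QRDecomposition

# QRDecomposition is a plain typed holder: Q and R keep the types they were built with.
qr: QRDecomposition[DenseVector, DenseMatrix] = QRDecomposition(
    DenseVector([1.0, 0.0]), DenseMatrix(1, 1, [1.0])
)
print(qr.Q, qr.R)   # inferred as DenseVector and DenseMatrix via Generic[QT, RT]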
diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index f892d41b12c13..d49af66479311 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -20,16 +20,22 @@ """ import sys +from typing import Any, Generic, Optional, Tuple, TypeVar, Union, TYPE_CHECKING from py4j.java_gateway import JavaObject from pyspark import RDD, since from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition +from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition, Vector from pyspark.mllib.stat import MultivariateStatisticalSummary from pyspark.sql import DataFrame from pyspark.storagelevel import StorageLevel +UT = TypeVar("UT", bound="DistributedMatrix") +VT = TypeVar("VT", bound="Matrix") + +if TYPE_CHECKING: + from pyspark.ml._typing import VectorLike __all__ = [ "BlockMatrix", @@ -50,11 +56,11 @@ class DistributedMatrix: """ - def numRows(self): + def numRows(self) -> int: """Get or compute the number of rows.""" raise NotImplementedError - def numCols(self): + def numCols(self) -> int: """Get or compute the number of cols.""" raise NotImplementedError @@ -82,7 +88,12 @@ class RowMatrix(DistributedMatrix): the first row. """ - def __init__(self, rows, numRows=0, numCols=0): + def __init__( + self, + rows: Union[RDD[Vector], DataFrame], + numRows: int = 0, + numCols: int = 0, + ): """ Note: This docstring is not shown publicly. @@ -121,7 +132,7 @@ def __init__(self, rows, numRows=0, numCols=0): self._java_matrix_wrapper = JavaModelWrapper(java_matrix) @property - def rows(self): + def rows(self) -> RDD[Vector]: """ Rows of the RowMatrix stored as an RDD of vectors. @@ -134,7 +145,7 @@ def rows(self): """ return self._java_matrix_wrapper.call("rows") - def numRows(self): + def numRows(self) -> int: """ Get or compute the number of rows. @@ -153,7 +164,7 @@ def numRows(self): """ return self._java_matrix_wrapper.call("numRows") - def numCols(self): + def numCols(self) -> int: """ Get or compute the number of cols. @@ -172,7 +183,7 @@ def numCols(self): """ return self._java_matrix_wrapper.call("numCols") - def computeColumnSummaryStatistics(self): + def computeColumnSummaryStatistics(self) -> MultivariateStatisticalSummary: """ Computes column-wise summary statistics. @@ -195,7 +206,7 @@ def computeColumnSummaryStatistics(self): java_col_stats = self._java_matrix_wrapper.call("computeColumnSummaryStatistics") return MultivariateStatisticalSummary(java_col_stats) - def computeCovariance(self): + def computeCovariance(self) -> Matrix: """ Computes the covariance matrix, treating each row as an observation. @@ -216,7 +227,7 @@ def computeCovariance(self): """ return self._java_matrix_wrapper.call("computeCovariance") - def computeGramianMatrix(self): + def computeGramianMatrix(self) -> Matrix: """ Computes the Gramian matrix `A^T A`. @@ -237,7 +248,7 @@ def computeGramianMatrix(self): return self._java_matrix_wrapper.call("computeGramianMatrix") @since("2.0.0") - def columnSimilarities(self, threshold=0.0): + def columnSimilarities(self, threshold: float = 0.0) -> "CoordinateMatrix": """ Compute similarities between columns of this matrix. 
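A hedged sketch of the RowMatrix accessors annotated above; `sc` is an assumed, pre-existing SparkContext and the data is illustrative only:

from pyspark.mllib.linalg import Vectors
from pyspark.mllib.linalg.distributed import RowMatrix

rows = sc.parallelize([Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)])
mat = RowMatrix(rows)
print(mat.numRows(), mat.numCols())            # 2 2
stats = mat.computeColumnSummaryStatistics()   # MultivariateStatisticalSummary
print(stats.mean())                            # per-column means
sims = mat.columnSimilarities()                # CoordinateMatrix of cosine similarities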
@@ -310,7 +321,9 @@ def columnSimilarities(self, threshold=0.0): java_sims_mat = self._java_matrix_wrapper.call("columnSimilarities", float(threshold)) return CoordinateMatrix(java_sims_mat) - def tallSkinnyQR(self, computeQ=False): + def tallSkinnyQR( + self, computeQ: bool = False + ) -> QRDecomposition[Optional["RowMatrix"], Matrix]: """ Compute the QR decomposition of this RowMatrix. @@ -360,7 +373,9 @@ def tallSkinnyQR(self, computeQ=False): R = decomp.call("R") return QRDecomposition(Q, R) - def computeSVD(self, k, computeU=False, rCond=1e-9): + def computeSVD( + self, k: int, computeU: bool = False, rCond: float = 1e-9 + ) -> "SingularValueDecomposition[RowMatrix, Matrix]": """ Computes the singular value decomposition of the RowMatrix. @@ -414,7 +429,7 @@ def computeSVD(self, k, computeU=False, rCond=1e-9): j_model = self._java_matrix_wrapper.call("computeSVD", int(k), bool(computeU), float(rCond)) return SingularValueDecomposition(j_model) - def computePrincipalComponents(self, k): + def computePrincipalComponents(self, k: int) -> Matrix: """ Computes the k principal components of the given row matrix @@ -450,7 +465,7 @@ def computePrincipalComponents(self, k): """ return self._java_matrix_wrapper.call("computePrincipalComponents", k) - def multiply(self, matrix): + def multiply(self, matrix: Matrix) -> "RowMatrix": """ Multiply this matrix by a local dense matrix on the right. @@ -478,16 +493,16 @@ def multiply(self, matrix): return RowMatrix(j_model) -class SingularValueDecomposition(JavaModelWrapper): +class SingularValueDecomposition(JavaModelWrapper, Generic[UT, VT]): """ Represents singular value decomposition (SVD) factors. .. versionadded:: 2.2.0 """ - @property + @property # type: ignore[misc] @since("2.2.0") - def U(self): + def U(self) -> Optional[UT]: # type: ignore[return] """ Returns a distributed matrix whose columns are the left singular vectors of the SingularValueDecomposition if computeU was set to be True. @@ -496,23 +511,23 @@ def U(self): if u is not None: mat_name = u.getClass().getSimpleName() if mat_name == "RowMatrix": - return RowMatrix(u) + return RowMatrix(u) # type: ignore[return-value] elif mat_name == "IndexedRowMatrix": - return IndexedRowMatrix(u) + return IndexedRowMatrix(u) # type: ignore[return-value] else: raise TypeError("Expected RowMatrix/IndexedRowMatrix got %s" % mat_name) - @property + @property # type: ignore[misc] @since("2.2.0") - def s(self): + def s(self) -> Vector: """ Returns a DenseVector with singular values in descending order. """ return self.call("s") - @property + @property # type: ignore[misc] @since("2.2.0") - def V(self): + def V(self) -> VT: """ Returns a DenseMatrix whose columns are the right singular vectors of the SingularValueDecomposition. @@ -534,15 +549,15 @@ class IndexedRow: The row in the matrix at the given index. """ - def __init__(self, index, vector): + def __init__(self, index: int, vector: "VectorLike") -> None: self.index = int(index) self.vector = _convert_to_vector(vector) - def __repr__(self): + def __repr__(self) -> str: return "IndexedRow(%s, %s)" % (self.index, self.vector) -def _convert_to_indexed_row(row): +def _convert_to_indexed_row(row: Any) -> IndexedRow: if isinstance(row, IndexedRow): return row elif isinstance(row, tuple) and len(row) == 2: @@ -572,7 +587,12 @@ class IndexedRowMatrix(DistributedMatrix): the first row. 
""" - def __init__(self, rows, numRows=0, numCols=0): + def __init__( + self, + rows: RDD[Union[Tuple[int, "VectorLike"], IndexedRow]], + numRows: int = 0, + numCols: int = 0, + ): """ Note: This docstring is not shown publicly. @@ -623,7 +643,7 @@ def __init__(self, rows, numRows=0, numCols=0): self._java_matrix_wrapper = JavaModelWrapper(java_matrix) @property - def rows(self): + def rows(self) -> RDD[IndexedRow]: """ Rows of the IndexedRowMatrix stored as an RDD of IndexedRows. @@ -643,7 +663,7 @@ def rows(self): rows = rows_df.rdd.map(lambda row: IndexedRow(row[0], row[1])) return rows - def numRows(self): + def numRows(self) -> int: """ Get or compute the number of rows. @@ -664,7 +684,7 @@ def numRows(self): """ return self._java_matrix_wrapper.call("numRows") - def numCols(self): + def numCols(self) -> int: """ Get or compute the number of cols. @@ -685,7 +705,7 @@ def numCols(self): """ return self._java_matrix_wrapper.call("numCols") - def columnSimilarities(self): + def columnSimilarities(self) -> "CoordinateMatrix": """ Compute all cosine similarities between columns. @@ -701,7 +721,7 @@ def columnSimilarities(self): java_coordinate_matrix = self._java_matrix_wrapper.call("columnSimilarities") return CoordinateMatrix(java_coordinate_matrix) - def computeGramianMatrix(self): + def computeGramianMatrix(self) -> Matrix: """ Computes the Gramian matrix `A^T A`. @@ -722,7 +742,7 @@ def computeGramianMatrix(self): """ return self._java_matrix_wrapper.call("computeGramianMatrix") - def toRowMatrix(self): + def toRowMatrix(self) -> RowMatrix: """ Convert this matrix to a RowMatrix. @@ -737,7 +757,7 @@ def toRowMatrix(self): java_row_matrix = self._java_matrix_wrapper.call("toRowMatrix") return RowMatrix(java_row_matrix) - def toCoordinateMatrix(self): + def toCoordinateMatrix(self) -> "CoordinateMatrix": """ Convert this matrix to a CoordinateMatrix. @@ -752,7 +772,7 @@ def toCoordinateMatrix(self): java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix") return CoordinateMatrix(java_coordinate_matrix) - def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): + def toBlockMatrix(self, rowsPerBlock: int = 1024, colsPerBlock: int = 1024) -> "BlockMatrix": """ Convert this matrix to a BlockMatrix. @@ -787,7 +807,9 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): ) return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) - def computeSVD(self, k, computeU=False, rCond=1e-9): + def computeSVD( + self, k: int, computeU: bool = False, rCond: float = 1e-9 + ) -> SingularValueDecomposition["IndexedRowMatrix", Matrix]: """ Computes the singular value decomposition of the IndexedRowMatrix. @@ -841,7 +863,7 @@ def computeSVD(self, k, computeU=False, rCond=1e-9): j_model = self._java_matrix_wrapper.call("computeSVD", int(k), bool(computeU), float(rCond)) return SingularValueDecomposition(j_model) - def multiply(self, matrix): + def multiply(self, matrix: Matrix) -> "IndexedRowMatrix": """ Multiply this matrix by a local dense matrix on the right. @@ -884,16 +906,16 @@ class MatrixEntry: The (i, j)th entry of the matrix, as a float. 
""" - def __init__(self, i, j, value): + def __init__(self, i: int, j: int, value: float) -> None: self.i = int(i) self.j = int(j) self.value = float(value) - def __repr__(self): + def __repr__(self) -> str: return "MatrixEntry(%s, %s, %s)" % (self.i, self.j, self.value) -def _convert_to_matrix_entry(entry): +def _convert_to_matrix_entry(entry: Any) -> MatrixEntry: if isinstance(entry, MatrixEntry): return entry elif isinstance(entry, tuple) and len(entry) == 3: @@ -923,7 +945,12 @@ class CoordinateMatrix(DistributedMatrix): index plus one. """ - def __init__(self, entries, numRows=0, numCols=0): + def __init__( + self, + entries: RDD[Union[Tuple[int, int, float], MatrixEntry]], + numRows: int = 0, + numCols: int = 0, + ): """ Note: This docstring is not shown publicly. @@ -975,7 +1002,7 @@ def __init__(self, entries, numRows=0, numCols=0): self._java_matrix_wrapper = JavaModelWrapper(java_matrix) @property - def entries(self): + def entries(self) -> RDD[MatrixEntry]: """ Entries of the CoordinateMatrix stored as an RDD of MatrixEntries. @@ -996,7 +1023,7 @@ def entries(self): entries = entries_df.rdd.map(lambda row: MatrixEntry(row[0], row[1], row[2])) return entries - def numRows(self): + def numRows(self) -> int: """ Get or compute the number of rows. @@ -1016,7 +1043,7 @@ def numRows(self): """ return self._java_matrix_wrapper.call("numRows") - def numCols(self): + def numCols(self) -> int: """ Get or compute the number of cols. @@ -1036,7 +1063,7 @@ def numCols(self): """ return self._java_matrix_wrapper.call("numCols") - def transpose(self): + def transpose(self) -> "CoordinateMatrix": """ Transpose this CoordinateMatrix. @@ -1059,7 +1086,7 @@ def transpose(self): java_transposed_matrix = self._java_matrix_wrapper.call("transpose") return CoordinateMatrix(java_transposed_matrix) - def toRowMatrix(self): + def toRowMatrix(self) -> RowMatrix: """ Convert this matrix to a RowMatrix. @@ -1085,7 +1112,7 @@ def toRowMatrix(self): java_row_matrix = self._java_matrix_wrapper.call("toRowMatrix") return RowMatrix(java_row_matrix) - def toIndexedRowMatrix(self): + def toIndexedRowMatrix(self) -> IndexedRowMatrix: """ Convert this matrix to an IndexedRowMatrix. @@ -1110,7 +1137,7 @@ def toIndexedRowMatrix(self): java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix") return IndexedRowMatrix(java_indexed_row_matrix) - def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): + def toBlockMatrix(self, rowsPerBlock: int = 1024, colsPerBlock: int = 1024) -> "BlockMatrix": """ Convert this matrix to a BlockMatrix. @@ -1149,7 +1176,7 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) -def _convert_to_matrix_block_tuple(block): +def _convert_to_matrix_block_tuple(block: Any) -> Tuple[Tuple[int, int], Matrix]: if ( isinstance(block, tuple) and len(block) == 2 @@ -1198,7 +1225,14 @@ class BlockMatrix(DistributedMatrix): invoked. """ - def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0): + def __init__( + self, + blocks: RDD[Tuple[Tuple[int, int], Matrix]], + rowsPerBlock: int, + colsPerBlock: int, + numRows: int = 0, + numCols: int = 0, + ): """ Note: This docstring is not shown publicly. 
@@ -1254,7 +1288,7 @@ def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0): self._java_matrix_wrapper = JavaModelWrapper(java_matrix) @property - def blocks(self): + def blocks(self) -> RDD[Tuple[Tuple[int, int], Matrix]]: """ The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that form this @@ -1279,7 +1313,7 @@ def blocks(self): return blocks @property - def rowsPerBlock(self): + def rowsPerBlock(self) -> int: """ Number of rows that make up each block. @@ -1294,7 +1328,7 @@ def rowsPerBlock(self): return self._java_matrix_wrapper.call("rowsPerBlock") @property - def colsPerBlock(self): + def colsPerBlock(self) -> int: """ Number of columns that make up each block. @@ -1309,7 +1343,7 @@ def colsPerBlock(self): return self._java_matrix_wrapper.call("colsPerBlock") @property - def numRowBlocks(self): + def numRowBlocks(self) -> int: """ Number of rows of blocks in the BlockMatrix. @@ -1324,7 +1358,7 @@ def numRowBlocks(self): return self._java_matrix_wrapper.call("numRowBlocks") @property - def numColBlocks(self): + def numColBlocks(self) -> int: """ Number of columns of blocks in the BlockMatrix. @@ -1338,7 +1372,7 @@ def numColBlocks(self): """ return self._java_matrix_wrapper.call("numColBlocks") - def numRows(self): + def numRows(self) -> int: """ Get or compute the number of rows. @@ -1357,7 +1391,7 @@ def numRows(self): """ return self._java_matrix_wrapper.call("numRows") - def numCols(self): + def numCols(self) -> int: """ Get or compute the number of cols. @@ -1377,7 +1411,7 @@ def numCols(self): return self._java_matrix_wrapper.call("numCols") @since("2.0.0") - def cache(self): + def cache(self) -> "BlockMatrix": """ Caches the underlying RDD. """ @@ -1385,7 +1419,7 @@ def cache(self): return self @since("2.0.0") - def persist(self, storageLevel): + def persist(self, storageLevel: StorageLevel) -> "BlockMatrix": """ Persists the underlying RDD with the specified storage level. """ @@ -1396,14 +1430,14 @@ def persist(self, storageLevel): return self @since("2.0.0") - def validate(self): + def validate(self) -> None: """ Validates the block matrix info against the matrix data (`blocks`) and throws an exception if any error is found. """ self._java_matrix_wrapper.call("validate") - def add(self, other): + def add(self, other: "BlockMatrix") -> "BlockMatrix": """ Adds two block matrices together. The matrices must have the same size and matching `rowsPerBlock` and `colsPerBlock` values. @@ -1438,7 +1472,7 @@ def add(self, other): java_block_matrix = self._java_matrix_wrapper.call("add", other_java_block_matrix) return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) - def subtract(self, other): + def subtract(self, other: "BlockMatrix") -> "BlockMatrix": """ Subtracts the given block matrix `other` from this block matrix: `this - other`. The matrices must have the same size and @@ -1476,7 +1510,7 @@ def subtract(self, other): java_block_matrix = self._java_matrix_wrapper.call("subtract", other_java_block_matrix) return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) - def multiply(self, other): + def multiply(self, other: "BlockMatrix") -> "BlockMatrix": """ Left multiplies this BlockMatrix by `other`, another BlockMatrix. 
The `colsPerBlock` of this matrix must equal the @@ -1513,7 +1547,7 @@ def multiply(self, other): java_block_matrix = self._java_matrix_wrapper.call("multiply", other_java_block_matrix) return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) - def transpose(self): + def transpose(self) -> "BlockMatrix": """ Transpose this BlockMatrix. Returns a new BlockMatrix instance sharing the same underlying data. Is a lazy operation. @@ -1533,7 +1567,7 @@ def transpose(self): java_transposed_matrix = self._java_matrix_wrapper.call("transpose") return BlockMatrix(java_transposed_matrix, self.colsPerBlock, self.rowsPerBlock) - def toLocalMatrix(self): + def toLocalMatrix(self) -> Matrix: """ Collect the distributed matrix on the driver as a DenseMatrix. @@ -1557,7 +1591,7 @@ def toLocalMatrix(self): """ return self._java_matrix_wrapper.call("toLocalMatrix") - def toIndexedRowMatrix(self): + def toIndexedRowMatrix(self) -> IndexedRowMatrix: """ Convert this matrix to an IndexedRowMatrix. @@ -1582,7 +1616,7 @@ def toIndexedRowMatrix(self): java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix") return IndexedRowMatrix(java_indexed_row_matrix) - def toCoordinateMatrix(self): + def toCoordinateMatrix(self) -> CoordinateMatrix: """ Convert this matrix to a CoordinateMatrix. @@ -1598,7 +1632,7 @@ def toCoordinateMatrix(self): return CoordinateMatrix(java_coordinate_matrix) -def _test(): +def _test() -> None: import doctest import numpy from pyspark.sql import SparkSession diff --git a/python/pyspark/mllib/linalg/distributed.pyi b/python/pyspark/mllib/linalg/distributed.pyi deleted file mode 100644 index 3d8a0c57b1d8c..0000000000000 --- a/python/pyspark/mllib/linalg/distributed.pyi +++ /dev/null @@ -1,145 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Generic, Sequence, Optional, Tuple, TypeVar, Union -from pyspark.rdd import RDD -from pyspark.storagelevel import StorageLevel -from pyspark.mllib.common import JavaModelWrapper -from pyspark.mllib.linalg import Vector, Matrix, QRDecomposition -from pyspark.mllib.stat import MultivariateStatisticalSummary -import pyspark.sql.dataframe -from numpy import ndarray # noqa: F401 - -VectorLike = Union[Vector, Sequence[Union[float, int]]] - -UT = TypeVar("UT") -VT = TypeVar("VT") - -class DistributedMatrix: - def numRows(self) -> int: ... - def numCols(self) -> int: ... - -class RowMatrix(DistributedMatrix): - def __init__( - self, - rows: Union[RDD[Vector], pyspark.sql.dataframe.DataFrame], - numRows: int = ..., - numCols: int = ..., - ) -> None: ... - @property - def rows(self) -> RDD[Vector]: ... - def numRows(self) -> int: ... - def numCols(self) -> int: ... - def computeColumnSummaryStatistics(self) -> MultivariateStatisticalSummary: ... 
- def computeCovariance(self) -> Matrix: ... - def computeGramianMatrix(self) -> Matrix: ... - def columnSimilarities(self, threshold: float = ...) -> CoordinateMatrix: ... - def tallSkinnyQR(self, computeQ: bool = ...) -> QRDecomposition[RowMatrix, Matrix]: ... - def computeSVD( - self, k: int, computeU: bool = ..., rCond: float = ... - ) -> SingularValueDecomposition[RowMatrix, Matrix]: ... - def computePrincipalComponents(self, k: int) -> Matrix: ... - def multiply(self, matrix: Matrix) -> RowMatrix: ... - -class SingularValueDecomposition(JavaModelWrapper, Generic[UT, VT]): - @property - def U(self) -> Optional[UT]: ... - @property - def s(self) -> Vector: ... - @property - def V(self) -> VT: ... - -class IndexedRow: - index: int - vector: VectorLike - def __init__(self, index: int, vector: VectorLike) -> None: ... - -class IndexedRowMatrix(DistributedMatrix): - def __init__( - self, - rows: RDD[Union[Tuple[int, VectorLike], IndexedRow]], - numRows: int = ..., - numCols: int = ..., - ) -> None: ... - @property - def rows(self) -> RDD[IndexedRow]: ... - def numRows(self) -> int: ... - def numCols(self) -> int: ... - def columnSimilarities(self) -> CoordinateMatrix: ... - def computeGramianMatrix(self) -> Matrix: ... - def toRowMatrix(self) -> RowMatrix: ... - def toCoordinateMatrix(self) -> CoordinateMatrix: ... - def toBlockMatrix(self, rowsPerBlock: int = ..., colsPerBlock: int = ...) -> BlockMatrix: ... - def computeSVD( - self, k: int, computeU: bool = ..., rCond: float = ... - ) -> SingularValueDecomposition[IndexedRowMatrix, Matrix]: ... - def multiply(self, matrix: Matrix) -> IndexedRowMatrix: ... - -class MatrixEntry: - i: int - j: int - value: float - def __init__(self, i: int, j: int, value: float) -> None: ... - -class CoordinateMatrix(DistributedMatrix): - def __init__( - self, - entries: RDD[Union[Tuple[int, int, float], MatrixEntry]], - numRows: int = ..., - numCols: int = ..., - ) -> None: ... - @property - def entries(self) -> RDD[MatrixEntry]: ... - def numRows(self) -> int: ... - def numCols(self) -> int: ... - def transpose(self) -> CoordinateMatrix: ... - def toRowMatrix(self) -> RowMatrix: ... - def toIndexedRowMatrix(self) -> IndexedRowMatrix: ... - def toBlockMatrix(self, rowsPerBlock: int = ..., colsPerBlock: int = ...) -> BlockMatrix: ... - -class BlockMatrix(DistributedMatrix): - def __init__( - self, - blocks: RDD[Tuple[Tuple[int, int], Matrix]], - rowsPerBlock: int, - colsPerBlock: int, - numRows: int = ..., - numCols: int = ..., - ) -> None: ... - @property - def blocks(self) -> RDD[Tuple[Tuple[int, int], Matrix]]: ... - @property - def rowsPerBlock(self) -> int: ... - @property - def colsPerBlock(self) -> int: ... - @property - def numRowBlocks(self) -> int: ... - @property - def numColBlocks(self) -> int: ... - def numRows(self) -> int: ... - def numCols(self) -> int: ... - def cache(self) -> BlockMatrix: ... - def persist(self, storageLevel: StorageLevel) -> BlockMatrix: ... - def validate(self) -> None: ... - def add(self, other: BlockMatrix) -> BlockMatrix: ... - def subtract(self, other: BlockMatrix) -> BlockMatrix: ... - def multiply(self, other: BlockMatrix) -> BlockMatrix: ... - def transpose(self) -> BlockMatrix: ... - def toLocalMatrix(self) -> Matrix: ... - def toIndexedRowMatrix(self) -> IndexedRowMatrix: ... - def toCoordinateMatrix(self) -> CoordinateMatrix: ... 
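With the stub above folded into distributed.py, the runtime API is unchanged; the following is only a minimal sketch of the now-inline signatures, assuming an active SparkSession named `spark` (the matrix entries are arbitrary toy values, not from this patch).

from pyspark.sql import SparkSession
from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

# entries may be MatrixEntry objects or plain (i, j, value) tuples,
# matching the RDD[Union[Tuple[int, int, float], MatrixEntry]] hint
entries = sc.parallelize([MatrixEntry(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)])
mat = CoordinateMatrix(entries)

mat.numRows()   # 3, max row index plus one (-> int)
mat.numCols()   # 2, max column index plus one (-> int)

# toBlockMatrix keeps the annotated integer defaults; tiny blocks here only for illustration
blocks = mat.toBlockMatrix(rowsPerBlock=2, colsPerBlock=2)
blocks.validate()                 # -> None, raises if block metadata is inconsistent
local = blocks.toLocalMatrix()    # collects to the driver as a DenseMatrix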
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index c099b4880281e..4f7da0131f6e9 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -17,6 +17,18 @@ import sys import warnings +from typing import ( + Any, + Callable, + Iterable, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, + TYPE_CHECKING, +) import numpy as np @@ -25,6 +37,16 @@ from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.util import Saveable, Loader +from pyspark.rdd import RDD +from pyspark.context import SparkContext +from pyspark.mllib.linalg import Vector + +if TYPE_CHECKING: + from pyspark.mllib._typing import VectorLike + + +LM = TypeVar("LM") +K = TypeVar("K") __all__ = [ "LabeledPoint", @@ -62,17 +84,17 @@ class LabeledPoint: 'label' and 'features' are accessible as class attributes. """ - def __init__(self, label, features): + def __init__(self, label: float, features: Iterable[float]): self.label = float(label) self.features = _convert_to_vector(features) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type["LabeledPoint"], Tuple[float, Vector]]: return (LabeledPoint, (self.label, self.features)) - def __str__(self): + def __str__(self) -> str: return "(" + ",".join((str(self.label), str(self.features))) + ")" - def __repr__(self): + def __repr__(self) -> str: return "LabeledPoint(%s, %s)" % (self.label, self.features) @@ -91,23 +113,23 @@ class LinearModel: Intercept computed for this model. """ - def __init__(self, weights, intercept): + def __init__(self, weights: Vector, intercept: float): self._coeff = _convert_to_vector(weights) self._intercept = float(intercept) - @property + @property # type: ignore[misc] @since("1.0.0") - def weights(self): + def weights(self) -> Vector: """Weights computed for every feature.""" return self._coeff - @property + @property # type: ignore[misc] @since("1.0.0") - def intercept(self): + def intercept(self) -> float: """Intercept computed for this model.""" return self._intercept - def __repr__(self): + def __repr__(self) -> str: return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept) @@ -128,16 +150,25 @@ class LinearRegressionModelBase(LinearModel): True """ - @since("0.9.0") - def predict(self, x): + @overload + def predict(self, x: "VectorLike") -> float: + ... + + @overload + def predict(self, x: RDD["VectorLike"]) -> RDD[float]: + ... + + def predict(self, x: Union["VectorLike", RDD["VectorLike"]]) -> Union[float, RDD[float]]: """ Predict the value of the dependent variable given a vector or an RDD of vectors containing values for the independent variables. + + .. 
versionadded:: 0.9.0 """ if isinstance(x, RDD): return x.map(self.predict) x = _convert_to_vector(x) - return self.weights.dot(x) + self.intercept + return self.weights.dot(x) + self.intercept # type: ignore[attr-defined] @inherit_doc @@ -204,8 +235,10 @@ class LinearRegressionModel(LinearRegressionModelBase): """ @since("1.4.0") - def save(self, sc, path): + def save(self, sc: SparkContext, path: str) -> None: """Save a LinearRegressionModel.""" + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( _py2java(sc, self._coeff), self.intercept ) @@ -213,8 +246,10 @@ def save(self, sc, path): @classmethod @since("1.4.0") - def load(cls, sc, path): + def load(cls, sc: SparkContext, path: str) -> "LinearRegressionModel": """Load a LinearRegressionModel.""" + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load( sc._jsc.sc(), path ) @@ -227,7 +262,12 @@ def load(cls, sc, path): # train_func should take two parameters, namely data and initial_weights, and # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. -def _regression_train_wrapper(train_func, modelClass, data, initial_weights): +def _regression_train_wrapper( + train_func: Callable[[RDD[LabeledPoint], Vector], Iterable[Any]], + modelClass: Type[LM], + data: RDD[LabeledPoint], + initial_weights: Optional["VectorLike"], +) -> LM: from pyspark.mllib.classification import LogisticRegressionModel first = data.first() @@ -239,10 +279,12 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): weights, intercept, numFeatures, numClasses = train_func( data, _convert_to_vector(initial_weights) ) - return modelClass(weights, intercept, numFeatures, numClasses) + return modelClass( # type: ignore[call-arg, return-value] + weights, intercept, numFeatures, numClasses + ) else: weights, intercept = train_func(data, _convert_to_vector(initial_weights)) - return modelClass(weights, intercept) + return modelClass(weights, intercept) # type: ignore[call-arg, return-value] class LinearRegressionWithSGD: @@ -257,17 +299,17 @@ class LinearRegressionWithSGD: @classmethod def train( cls, - data, - iterations=100, - step=1.0, - miniBatchFraction=1.0, - initialWeights=None, - regParam=0.0, - regType=None, - intercept=False, - validateData=True, - convergenceTol=0.001, - ): + data: RDD[LabeledPoint], + iterations: int = 100, + step: float = 1.0, + miniBatchFraction: float = 1.0, + initialWeights: Optional["VectorLike"] = None, + regParam: float = 0.0, + regType: Optional[str] = None, + intercept: bool = False, + validateData: bool = True, + convergenceTol: float = 0.001, + ) -> LinearRegressionModel: """ Train a linear regression model using Stochastic Gradient Descent (SGD). This solves the least squares regression @@ -324,7 +366,7 @@ def train( """ warnings.warn("Deprecated in 2.0.0. 
Use ml.regression.LinearRegression.", FutureWarning) - def train(rdd, i): + def train(rdd: RDD[LabeledPoint], i: Vector) -> Iterable[Any]: return callMLlibFunc( "trainLinearRegressionModelWithSGD", rdd, @@ -407,8 +449,10 @@ class LassoModel(LinearRegressionModelBase): """ @since("1.4.0") - def save(self, sc, path): + def save(self, sc: SparkContext, path: str) -> None: """Save a LassoModel.""" + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( _py2java(sc, self._coeff), self.intercept ) @@ -416,8 +460,10 @@ def save(self, sc, path): @classmethod @since("1.4.0") - def load(cls, sc, path): + def load(cls, sc: SparkContext, path: str) -> "LassoModel": """Load a LassoModel.""" + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load(sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) intercept = java_model.intercept() @@ -438,16 +484,16 @@ class LassoWithSGD: @classmethod def train( cls, - data, - iterations=100, - step=1.0, - regParam=0.01, - miniBatchFraction=1.0, - initialWeights=None, - intercept=False, - validateData=True, - convergenceTol=0.001, - ): + data: RDD[LabeledPoint], + iterations: int = 100, + step: float = 1.0, + regParam: float = 0.01, + miniBatchFraction: float = 1.0, + initialWeights: Optional["VectorLike"] = None, + intercept: bool = False, + validateData: bool = True, + convergenceTol: float = 0.001, + ) -> LassoModel: """ Train a regression model with L1-regularization using Stochastic Gradient Descent. This solves the l1-regularized least squares @@ -499,7 +545,7 @@ def train( FutureWarning, ) - def train(rdd, i): + def train(rdd: RDD[LabeledPoint], i: Vector) -> Iterable[Any]: return callMLlibFunc( "trainLassoModelWithSGD", rdd, @@ -581,8 +627,10 @@ class RidgeRegressionModel(LinearRegressionModelBase): """ @since("1.4.0") - def save(self, sc, path): + def save(self, sc: SparkContext, path: str) -> None: """Save a RidgeRegressionMode.""" + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( _py2java(sc, self._coeff), self.intercept ) @@ -590,8 +638,10 @@ def save(self, sc, path): @classmethod @since("1.4.0") - def load(cls, sc, path): + def load(cls, sc: SparkContext, path: str) -> "RidgeRegressionModel": """Load a RidgeRegressionMode.""" + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load( sc._jsc.sc(), path ) @@ -615,16 +665,16 @@ class RidgeRegressionWithSGD: @classmethod def train( cls, - data, - iterations=100, - step=1.0, - regParam=0.01, - miniBatchFraction=1.0, - initialWeights=None, - intercept=False, - validateData=True, - convergenceTol=0.001, - ): + data: RDD[LabeledPoint], + iterations: int = 100, + step: float = 1.0, + regParam: float = 0.01, + miniBatchFraction: float = 1.0, + initialWeights: Optional["VectorLike"] = None, + intercept: bool = False, + validateData: bool = True, + convergenceTol: float = 0.001, + ) -> RidgeRegressionModel: """ Train a regression model with L2-regularization using Stochastic Gradient Descent. 
This solves the l2-regularized least squares @@ -677,7 +727,7 @@ def train( FutureWarning, ) - def train(rdd, i): + def train(rdd: RDD[LabeledPoint], i: Vector) -> Iterable[Any]: return callMLlibFunc( "trainRidgeModelWithSGD", rdd, @@ -694,7 +744,7 @@ def train(rdd, i): return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) -class IsotonicRegressionModel(Saveable, Loader): +class IsotonicRegressionModel(Saveable, Loader["IsotonicRegressionModel"]): """ Regression model for isotonic regression. @@ -737,12 +787,30 @@ class IsotonicRegressionModel(Saveable, Loader): ... pass """ - def __init__(self, boundaries, predictions, isotonic): + def __init__(self, boundaries: np.ndarray, predictions: np.ndarray, isotonic: bool): self.boundaries = boundaries self.predictions = predictions self.isotonic = isotonic - def predict(self, x): + @overload + def predict(self, x: float) -> np.float64: + ... + + @overload + def predict(self, x: "VectorLike") -> np.ndarray: + ... + + @overload + def predict(self, x: RDD[float]) -> RDD[np.float64]: + ... + + @overload + def predict(self, x: RDD["VectorLike"]) -> RDD[np.ndarray]: + ... + + def predict( + self, x: Union[float, "VectorLike", RDD[float], RDD["VectorLike"]] + ) -> Union[np.float64, np.ndarray, RDD[np.float64], RDD[np.ndarray]]: """ Predict labels for provided features. Using a piecewise linear function. @@ -770,13 +838,17 @@ def predict(self, x): """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) - return np.interp(x, self.boundaries, self.predictions) + return np.interp( + x, self.boundaries, self.predictions # type: ignore[call-overload, arg-type] + ) @since("1.4.0") - def save(self, sc, path): + def save(self, sc: SparkContext, path: str) -> None: """Save an IsotonicRegressionModel.""" java_boundaries = _py2java(sc, self.boundaries.tolist()) java_predictions = _py2java(sc, self.predictions.tolist()) + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel( java_boundaries, java_predictions, self.isotonic ) @@ -784,8 +856,10 @@ def save(self, sc, path): @classmethod @since("1.4.0") - def load(cls, sc, path): + def load(cls, sc: SparkContext, path: str) -> "IsotonicRegressionModel": """Load an IsotonicRegressionModel.""" + assert sc._jvm is not None + java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel.load( sc._jsc.sc(), path ) @@ -823,7 +897,7 @@ class IsotonicRegression: """ @classmethod - def train(cls, data, isotonic=True): + def train(cls, data: RDD["VectorLike"], isotonic: bool = True) -> IsotonicRegressionModel: """ Train an isotonic regression model on the given data. @@ -852,23 +926,23 @@ class StreamingLinearAlgorithm: .. versionadded:: 1.5.0 """ - def __init__(self, model): + def __init__(self, model: Optional[LinearModel]): self._model = model @since("1.5.0") - def latestModel(self): + def latestModel(self) -> Optional[LinearModel]: """ Returns the latest model. """ return self._model - def _validate(self, dstream): + def _validate(self, dstream: Any) -> None: if not isinstance(dstream, DStream): raise TypeError("dstream should be a DStream object, got %s" % type(dstream)) if not self._model: raise ValueError("Model must be initialized using setInitialWeights") - def predictOn(self, dstream): + def predictOn(self, dstream: "DStream[VectorLike]") -> "DStream[float]": """ Use the model to make predictions on batches of data from a DStream. 
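The streaming helpers here simply map the model's overloaded `predict` over each batch. A minimal sketch of how those overloads resolve, assuming a live SparkContext `sc` (the coefficients below are made-up numbers, not part of this patch):

from pyspark.mllib.linalg import Vectors
from pyspark.mllib.regression import LinearRegressionModel

# hypothetical weights and intercept, just to exercise the two overloads
model = LinearRegressionModel(Vectors.dense([2.0, 3.0]), 1.0)

model.predict(Vectors.dense([1.0, 1.0]))      # scalar overload -> 6.0 (float)
rdd = sc.parallelize([[1.0, 1.0], [0.0, 2.0]])
model.predict(rdd).collect()                  # RDD overload -> [6.0, 7.0] (RDD[float])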
@@ -881,9 +955,11 @@ def predictOn(self, dstream): DStream containing predictions. """ self._validate(dstream) - return dstream.map(lambda x: self._model.predict(x)) + return dstream.map(lambda x: self._model.predict(x)) # type: ignore[union-attr] - def predictOnValues(self, dstream): + def predictOnValues( + self, dstream: "DStream[Tuple[K, VectorLike]]" + ) -> "DStream[Tuple[K, float]]": """ Use the model to make predictions on the values of a DStream and carry over its keys. @@ -896,7 +972,7 @@ def predictOnValues(self, dstream): DStream containing predictions. """ self._validate(dstream) - return dstream.mapValues(lambda x: self._model.predict(x)) + return dstream.mapValues(lambda x: self._model.predict(x)) # type: ignore[union-attr] @inherit_doc @@ -930,16 +1006,22 @@ class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): (default: 0.001) """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, convergenceTol=0.001): + def __init__( + self, + stepSize: float = 0.1, + numIterations: int = 50, + miniBatchFraction: float = 1.0, + convergenceTol: float = 0.001, + ): self.stepSize = stepSize self.numIterations = numIterations self.miniBatchFraction = miniBatchFraction self.convergenceTol = convergenceTol - self._model = None + self._model: Optional[LinearModel] = None super(StreamingLinearRegressionWithSGD, self).__init__(model=self._model) @since("1.5.0") - def setInitialWeights(self, initialWeights): + def setInitialWeights(self, initialWeights: "VectorLike") -> "StreamingLinearRegressionWithSGD": """ Set the initial value of weights. @@ -950,27 +1032,28 @@ def setInitialWeights(self, initialWeights): return self @since("1.5.0") - def trainOn(self, dstream): + def trainOn(self, dstream: "DStream[LabeledPoint]") -> None: """Train the model on the incoming dstream.""" self._validate(dstream) - def update(rdd): + def update(rdd: RDD[LabeledPoint]) -> None: # LinearRegressionWithSGD.train raises an error for an empty RDD. if not rdd.isEmpty(): + assert self._model is not None self._model = LinearRegressionWithSGD.train( rdd, self.numIterations, self.stepSize, self.miniBatchFraction, self._model.weights, - intercept=self._model.intercept, + intercept=self._model.intercept, # type: ignore[arg-type] convergenceTol=self.convergenceTol, ) dstream.foreachRDD(update) -def _test(): +def _test() -> None: import doctest from pyspark.sql import SparkSession import pyspark.mllib.regression diff --git a/python/pyspark/mllib/regression.pyi b/python/pyspark/mllib/regression.pyi deleted file mode 100644 index 0e5e13a53f811..0000000000000 --- a/python/pyspark/mllib/regression.pyi +++ /dev/null @@ -1,149 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import overload -from typing import Iterable, Optional, Tuple, TypeVar -from pyspark.rdd import RDD -from pyspark.mllib._typing import VectorLike -from pyspark.context import SparkContext -from pyspark.mllib.linalg import Vector -from pyspark.mllib.util import Saveable, Loader -from pyspark.streaming.dstream import DStream -from numpy import ndarray - -K = TypeVar("K") - -class LabeledPoint: - label: int - features: Vector - def __init__(self, label: float, features: Iterable[float]) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[bytes]]: ... - -class LinearModel: - def __init__(self, weights: Vector, intercept: float) -> None: ... - @property - def weights(self) -> Vector: ... - @property - def intercept(self) -> float: ... - -class LinearRegressionModelBase(LinearModel): - @overload - def predict(self, x: Vector) -> float: ... - @overload - def predict(self, x: RDD[Vector]) -> RDD[float]: ... - -class LinearRegressionModel(LinearRegressionModelBase): - def save(self, sc: SparkContext, path: str) -> None: ... - @classmethod - def load(cls, sc: SparkContext, path: str) -> LinearRegressionModel: ... - -class LinearRegressionWithSGD: - @classmethod - def train( - cls, - data: RDD[LabeledPoint], - iterations: int = ..., - step: float = ..., - miniBatchFraction: float = ..., - initialWeights: Optional[VectorLike] = ..., - regParam: float = ..., - regType: Optional[str] = ..., - intercept: bool = ..., - validateData: bool = ..., - convergenceTol: float = ..., - ) -> LinearRegressionModel: ... - -class LassoModel(LinearRegressionModelBase): - def save(self, sc: SparkContext, path: str) -> None: ... - @classmethod - def load(cls, sc: SparkContext, path: str) -> LassoModel: ... - -class LassoWithSGD: - @classmethod - def train( - cls, - data: RDD[LabeledPoint], - iterations: int = ..., - step: float = ..., - regParam: float = ..., - miniBatchFraction: float = ..., - initialWeights: Optional[VectorLike] = ..., - intercept: bool = ..., - validateData: bool = ..., - convergenceTol: float = ..., - ) -> LassoModel: ... - -class RidgeRegressionModel(LinearRegressionModelBase): - def save(self, sc: SparkContext, path: str) -> None: ... - @classmethod - def load(cls, sc: SparkContext, path: str) -> RidgeRegressionModel: ... - -class RidgeRegressionWithSGD: - @classmethod - def train( - cls, - data: RDD[LabeledPoint], - iterations: int = ..., - step: float = ..., - regParam: float = ..., - miniBatchFraction: float = ..., - initialWeights: Optional[VectorLike] = ..., - intercept: bool = ..., - validateData: bool = ..., - convergenceTol: float = ..., - ) -> RidgeRegressionModel: ... - -class IsotonicRegressionModel(Saveable, Loader[IsotonicRegressionModel]): - boundaries: ndarray - predictions: ndarray - isotonic: bool - def __init__(self, boundaries: ndarray, predictions: ndarray, isotonic: bool) -> None: ... - @overload - def predict(self, x: Vector) -> ndarray: ... - @overload - def predict(self, x: RDD[Vector]) -> RDD[ndarray]: ... - def save(self, sc: SparkContext, path: str) -> None: ... - @classmethod - def load(cls, sc: SparkContext, path: str) -> IsotonicRegressionModel: ... - -class IsotonicRegression: - @classmethod - def train(cls, data: RDD[VectorLike], isotonic: bool = ...) -> IsotonicRegressionModel: ... - -class StreamingLinearAlgorithm: - def __init__(self, model: LinearModel) -> None: ... - def latestModel(self) -> LinearModel: ... - def predictOn(self, dstream: DStream[VectorLike]) -> DStream[float]: ... 
- def predictOnValues( - self, dstream: DStream[Tuple[K, VectorLike]] - ) -> DStream[Tuple[K, float]]: ... - -class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): - stepSize: float - numIterations: int - miniBatchFraction: float - convergenceTol: float - def __init__( - self, - stepSize: float = ..., - numIterations: int = ..., - miniBatchFraction: float = ..., - convergenceTol: float = ..., - ) -> None: ... - def setInitialWeights(self, initialWeights: VectorLike) -> StreamingLinearRegressionWithSGD: ... - def trainOn(self, dstream: DStream[LabeledPoint]) -> None: ... diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py index 103c955df9bae..febf4fd53fd2f 100644 --- a/python/pyspark/mllib/stat/KernelDensity.py +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -46,7 +46,7 @@ def setBandwidth(self, bandwidth: float) -> None: """Set bandwidth of each sample. Defaults to 1.0""" self._bandwidth = bandwidth - def setSample(self, sample: "RDD[float]") -> None: + def setSample(self, sample: RDD[float]) -> None: """Set sample points from the population. Should be a RDD""" if not isinstance(sample, RDD): raise TypeError("samples should be a RDD, received %s" % type(sample)) diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py index d25d2f21202ce..007f42d3c2d09 100644 --- a/python/pyspark/mllib/tests/test_linalg.py +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -22,7 +22,7 @@ import pyspark.ml.linalg as newlinalg from pyspark.serializers import CPickleSerializer -from pyspark.mllib.linalg import ( # type: ignore[attr-defined] +from pyspark.mllib.linalg import ( Vector, SparseVector, DenseVector, diff --git a/python/pyspark/mllib/tests/test_util.py b/python/pyspark/mllib/tests/test_util.py index b45dce9f21642..aad1349c71bbc 100644 --- a/python/pyspark/mllib/tests/test_util.py +++ b/python/pyspark/mllib/tests/test_util.py @@ -19,7 +19,7 @@ import tempfile import unittest -from pyspark.mllib.common import _to_java_object_rdd # type: ignore[attr-defined] +from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.util import LinearDataGenerator from pyspark.mllib.util import MLUtils from pyspark.mllib.linalg import SparseVector, DenseVector, Vectors diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 9b477ffecfd23..e1d87e99c8a5e 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -23,6 +23,12 @@ from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.util import JavaLoader, JavaSaveable +from typing import Dict, Optional, Tuple, Union, overload, TYPE_CHECKING +from pyspark.rdd import RDD + +if TYPE_CHECKING: + from pyspark.mllib._typing import VectorLike + __all__ = [ "DecisionTreeModel", @@ -40,7 +46,15 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): .. versionadded:: 1.3.0 """ - def predict(self, x): + @overload + def predict(self, x: "VectorLike") -> float: + ... + + @overload + def predict(self, x: RDD["VectorLike"]) -> RDD[float]: + ... + + def predict(self, x: Union["VectorLike", RDD["VectorLike"]]) -> Union[float, RDD[float]]: """ Predict values for a single data point or an RDD of points using the model trained. @@ -60,37 +74,45 @@ def predict(self, x): return self.call("predict", _convert_to_vector(x)) @since("1.3.0") - def numTrees(self): + def numTrees(self) -> int: """ Get number of trees in ensemble. 
""" return self.call("numTrees") @since("1.3.0") - def totalNumNodes(self): + def totalNumNodes(self) -> int: """ Get total number of nodes, summed over all trees in the ensemble. """ return self.call("totalNumNodes") - def __repr__(self): + def __repr__(self) -> str: """Summary of model""" return self._java_model.toString() @since("1.3.0") - def toDebugString(self): + def toDebugString(self) -> str: """Full model""" return self._java_model.toDebugString() -class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): +class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader["DecisionTreeModel"]): """ A decision tree model for classification or regression. .. versionadded:: 1.1.0 """ - def predict(self, x): + @overload + def predict(self, x: "VectorLike") -> float: + ... + + @overload + def predict(self, x: RDD["VectorLike"]) -> RDD[float]: + ... + + def predict(self, x: Union["VectorLike", RDD["VectorLike"]]) -> Union[float, RDD[float]]: """ Predict the label of one or more examples. @@ -115,29 +137,29 @@ def predict(self, x): return self.call("predict", _convert_to_vector(x)) @since("1.1.0") - def numNodes(self): + def numNodes(self) -> int: """Get number of nodes in tree, including leaf nodes.""" return self._java_model.numNodes() @since("1.1.0") - def depth(self): + def depth(self) -> int: """ Get depth of tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). """ return self._java_model.depth() - def __repr__(self): + def __repr__(self) -> str: """summary of model.""" return self._java_model.toString() @since("1.2.0") - def toDebugString(self): + def toDebugString(self) -> str: """full model.""" return self._java_model.toDebugString() @classmethod - def _java_loader_class(cls): + def _java_loader_class(cls) -> str: return "org.apache.spark.mllib.tree.model.DecisionTreeModel" @@ -152,16 +174,16 @@ class DecisionTree: @classmethod def _train( cls, - data, - type, - numClasses, - features, - impurity="gini", - maxDepth=5, - maxBins=32, - minInstancesPerNode=1, - minInfoGain=0.0, - ): + data: RDD[LabeledPoint], + type: str, + numClasses: int, + features: Dict[int, int], + impurity: str = "gini", + maxDepth: int = 5, + maxBins: int = 32, + minInstancesPerNode: int = 1, + minInfoGain: float = 0.0, + ) -> DecisionTreeModel: first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" model = callMLlibFunc( @@ -181,15 +203,15 @@ def _train( @classmethod def trainClassifier( cls, - data, - numClasses, - categoricalFeaturesInfo, - impurity="gini", - maxDepth=5, - maxBins=32, - minInstancesPerNode=1, - minInfoGain=0.0, - ): + data: RDD[LabeledPoint], + numClasses: int, + categoricalFeaturesInfo: Dict[int, int], + impurity: str = "gini", + maxDepth: int = 5, + maxBins: int = 32, + minInstancesPerNode: int = 1, + minInfoGain: float = 0.0, + ) -> DecisionTreeModel: """ Train a decision tree model for classification. @@ -276,14 +298,14 @@ def trainClassifier( @since("1.1.0") def trainRegressor( cls, - data, - categoricalFeaturesInfo, - impurity="variance", - maxDepth=5, - maxBins=32, - minInstancesPerNode=1, - minInfoGain=0.0, - ): + data: RDD[LabeledPoint], + categoricalFeaturesInfo: Dict[int, int], + impurity: str = "variance", + maxDepth: int = 5, + maxBins: int = 32, + minInstancesPerNode: int = 1, + minInfoGain: float = 0.0, + ) -> DecisionTreeModel: """ Train a decision tree model for regression. 
@@ -354,7 +376,7 @@ def trainRegressor( @inherit_doc -class RandomForestModel(TreeEnsembleModel, JavaLoader): +class RandomForestModel(TreeEnsembleModel, JavaLoader["RandomForestModel"]): """ Represents a random forest model. @@ -362,7 +384,7 @@ class RandomForestModel(TreeEnsembleModel, JavaLoader): """ @classmethod - def _java_loader_class(cls): + def _java_loader_class(cls) -> str: return "org.apache.spark.mllib.tree.model.RandomForestModel" @@ -374,22 +396,22 @@ class RandomForest: .. versionadded:: 1.2.0 """ - supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") + supportedFeatureSubsetStrategies: Tuple[str, ...] = ("auto", "all", "sqrt", "log2", "onethird") @classmethod def _train( cls, - data, - algo, - numClasses, - categoricalFeaturesInfo, - numTrees, - featureSubsetStrategy, - impurity, - maxDepth, - maxBins, - seed, - ): + data: RDD[LabeledPoint], + algo: str, + numClasses: int, + categoricalFeaturesInfo: Dict[int, int], + numTrees: int, + featureSubsetStrategy: str, + impurity: str, + maxDepth: int, + maxBins: int, + seed: Optional[int], + ) -> RandomForestModel: first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" if featureSubsetStrategy not in cls.supportedFeatureSubsetStrategies: @@ -414,16 +436,16 @@ def _train( @classmethod def trainClassifier( cls, - data, - numClasses, - categoricalFeaturesInfo, - numTrees, - featureSubsetStrategy="auto", - impurity="gini", - maxDepth=4, - maxBins=32, - seed=None, - ): + data: RDD[LabeledPoint], + numClasses: int, + categoricalFeaturesInfo: Dict[int, int], + numTrees: int, + featureSubsetStrategy: str = "auto", + impurity: str = "gini", + maxDepth: int = 4, + maxBins: int = 32, + seed: Optional[int] = None, + ) -> RandomForestModel: """ Train a random forest model for binary or multiclass classification. @@ -530,15 +552,15 @@ def trainClassifier( @classmethod def trainRegressor( cls, - data, - categoricalFeaturesInfo, - numTrees, - featureSubsetStrategy="auto", - impurity="variance", - maxDepth=4, - maxBins=32, - seed=None, - ): + data: RDD[LabeledPoint], + categoricalFeaturesInfo: Dict[int, int], + numTrees: int, + featureSubsetStrategy: str = "auto", + impurity: str = "variance", + maxDepth: int = 4, + maxBins: int = 32, + seed: Optional[int] = None, + ) -> RandomForestModel: """ Train a random forest model for regression. @@ -625,7 +647,7 @@ def trainRegressor( @inherit_doc -class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): +class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader["GradientBoostedTreesModel"]): """ Represents a gradient-boosted tree model. 
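The ensemble trainers follow the same pattern; below is a small illustrative run of `RandomForest.trainClassifier` under the new hints, again assuming `sc` and made-up toy data:

from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import RandomForest

data = sc.parallelize([
    LabeledPoint(0.0, [0.0, 1.0]),
    LabeledPoint(1.0, [1.0, 0.0]),
    LabeledPoint(0.0, [0.2, 0.8]),
    LabeledPoint(1.0, [0.9, 0.1]),
])

# featureSubsetStrategy must be one of the values in the Tuple[str, ...] class attribute
model = RandomForest.trainClassifier(
    data, numClasses=2, categoricalFeaturesInfo={},
    numTrees=3, featureSubsetStrategy="auto", seed=42,
)
model.numTrees()                                          # -> int
model.predict(data.map(lambda p: p.features)).collect()   # RDD overload -> RDD[float]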
@@ -633,7 +655,7 @@ class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): """ @classmethod - def _java_loader_class(cls): + def _java_loader_class(cls) -> str: return "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" @@ -648,15 +670,15 @@ class GradientBoostedTrees: @classmethod def _train( cls, - data, - algo, - categoricalFeaturesInfo, - loss, - numIterations, - learningRate, - maxDepth, - maxBins, - ): + data: RDD[LabeledPoint], + algo: str, + categoricalFeaturesInfo: Dict[int, int], + loss: str, + numIterations: int, + learningRate: float, + maxDepth: int, + maxBins: int, + ) -> GradientBoostedTreesModel: first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" model = callMLlibFunc( @@ -675,14 +697,14 @@ def _train( @classmethod def trainClassifier( cls, - data, - categoricalFeaturesInfo, - loss="logLoss", - numIterations=100, - learningRate=0.1, - maxDepth=3, - maxBins=32, - ): + data: RDD[LabeledPoint], + categoricalFeaturesInfo: Dict[int, int], + loss: str = "logLoss", + numIterations: int = 100, + learningRate: float = 0.1, + maxDepth: int = 3, + maxBins: int = 32, + ) -> GradientBoostedTreesModel: """ Train a gradient-boosted trees model for classification. @@ -765,14 +787,14 @@ def trainClassifier( @classmethod def trainRegressor( cls, - data, - categoricalFeaturesInfo, - loss="leastSquaresError", - numIterations=100, - learningRate=0.1, - maxDepth=3, - maxBins=32, - ): + data: RDD[LabeledPoint], + categoricalFeaturesInfo: Dict[int, int], + loss: str = "leastSquaresError", + numIterations: int = 100, + learningRate: float = 0.1, + maxDepth: int = 3, + maxBins: int = 32, + ) -> GradientBoostedTreesModel: """ Train a gradient-boosted trees model for regression. @@ -851,7 +873,7 @@ def trainRegressor( ) -def _test(): +def _test() -> None: import doctest globs = globals().copy() diff --git a/python/pyspark/mllib/tree.pyi b/python/pyspark/mllib/tree.pyi deleted file mode 100644 index fedb494f19062..0000000000000 --- a/python/pyspark/mllib/tree.pyi +++ /dev/null @@ -1,124 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import overload -from typing import Dict, Optional, Tuple -from pyspark.mllib._typing import VectorLike -from pyspark.rdd import RDD -from pyspark.mllib.common import JavaModelWrapper -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import JavaLoader, JavaSaveable - -class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): - @overload - def predict(self, x: VectorLike) -> float: ... - @overload - def predict(self, x: RDD[VectorLike]) -> RDD[VectorLike]: ... - def numTrees(self) -> int: ... - def totalNumNodes(self) -> int: ... - def toDebugString(self) -> str: ... 
- -class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader[DecisionTreeModel]): - @overload - def predict(self, x: VectorLike) -> float: ... - @overload - def predict(self, x: RDD[VectorLike]) -> RDD[VectorLike]: ... - def numNodes(self) -> int: ... - def depth(self) -> int: ... - def toDebugString(self) -> str: ... - -class DecisionTree: - @classmethod - def trainClassifier( - cls, - data: RDD[LabeledPoint], - numClasses: int, - categoricalFeaturesInfo: Dict[int, int], - impurity: str = ..., - maxDepth: int = ..., - maxBins: int = ..., - minInstancesPerNode: int = ..., - minInfoGain: float = ..., - ) -> DecisionTreeModel: ... - @classmethod - def trainRegressor( - cls, - data: RDD[LabeledPoint], - categoricalFeaturesInfo: Dict[int, int], - impurity: str = ..., - maxDepth: int = ..., - maxBins: int = ..., - minInstancesPerNode: int = ..., - minInfoGain: float = ..., - ) -> DecisionTreeModel: ... - -class RandomForestModel(TreeEnsembleModel, JavaLoader[RandomForestModel]): ... - -class RandomForest: - supportedFeatureSubsetStrategies: Tuple[str, ...] - @classmethod - def trainClassifier( - cls, - data: RDD[LabeledPoint], - numClasses: int, - categoricalFeaturesInfo: Dict[int, int], - numTrees: int, - featureSubsetStrategy: str = ..., - impurity: str = ..., - maxDepth: int = ..., - maxBins: int = ..., - seed: Optional[int] = ..., - ) -> RandomForestModel: ... - @classmethod - def trainRegressor( - cls, - data: RDD[LabeledPoint], - categoricalFeaturesInfo: Dict[int, int], - numTrees: int, - featureSubsetStrategy: str = ..., - impurity: str = ..., - maxDepth: int = ..., - maxBins: int = ..., - seed: Optional[int] = ..., - ) -> RandomForestModel: ... - -class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader[GradientBoostedTreesModel]): ... - -class GradientBoostedTrees: - @classmethod - def trainClassifier( - cls, - data: RDD[LabeledPoint], - categoricalFeaturesInfo: Dict[int, int], - loss: str = ..., - numIterations: int = ..., - learningRate: float = ..., - maxDepth: int = ..., - maxBins: int = ..., - ) -> GradientBoostedTreesModel: ... - @classmethod - def trainRegressor( - cls, - data: RDD[LabeledPoint], - categoricalFeaturesInfo: Dict[int, int], - loss: str = ..., - numIterations: int = ..., - learningRate: float = ..., - maxDepth: int = ..., - maxBins: int = ..., - ) -> GradientBoostedTreesModel: ... 
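The tree models also exercise the `JavaSaveable`/`JavaLoader[T]` generics that util.py parameterizes below; a round-trip sketch, assuming `sc`, with temp-path handling mirroring the existing doctests:

import shutil
import tempfile

from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel

points = sc.parallelize([LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0])])
model = DecisionTree.trainClassifier(points, numClasses=2, categoricalFeaturesInfo={})

path = tempfile.mkdtemp()
try:
    model.save(sc, path)                        # JavaSaveable.save(sc, path) -> None
    loaded = DecisionTreeModel.load(sc, path)   # JavaLoader["DecisionTreeModel"] -> DecisionTreeModel
    loaded.predict([1.0])
finally:
    shutil.rmtree(path, ignore_errors=True)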
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index d3824e86c2618..8f28e2cfee0eb 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -16,12 +16,27 @@ # import sys +from functools import reduce import numpy as np from pyspark import SparkContext, since from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.sql import DataFrame +from typing import Generic, Iterable, List, Optional, Tuple, Type, TypeVar, cast, TYPE_CHECKING +from pyspark.context import SparkContext +from pyspark.mllib.linalg import Vector +from pyspark.rdd import RDD +from pyspark.sql.dataframe import DataFrame + +T = TypeVar("T") +L = TypeVar("L", bound="Loader") +JL = TypeVar("JL", bound="JavaLoader") + +if TYPE_CHECKING: + from pyspark.mllib._typing import VectorLike + from py4j.java_gateway import JavaObject + from pyspark.mllib.regression import LabeledPoint class MLUtils: @@ -33,7 +48,7 @@ class MLUtils: """ @staticmethod - def _parse_libsvm_line(line): + def _parse_libsvm_line(line: str) -> Tuple[float, np.ndarray, np.ndarray]: """ Parses a line in LIBSVM format into (label, indices, values). """ @@ -49,7 +64,7 @@ def _parse_libsvm_line(line): return label, indices, values @staticmethod - def _convert_labeled_point_to_libsvm(p): + def _convert_labeled_point_to_libsvm(p: "LabeledPoint") -> str: """Converts a LabeledPoint to a string in LIBSVM format.""" from pyspark.mllib.regression import LabeledPoint @@ -62,11 +77,13 @@ def _convert_labeled_point_to_libsvm(p): items.append(str(v.indices[i] + 1) + ":" + str(v.values[i])) else: for i in range(len(v)): - items.append(str(i + 1) + ":" + str(v[i])) + items.append(str(i + 1) + ":" + str(v[i])) # type: ignore[index] return " ".join(items) @staticmethod - def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): + def loadLibSVMFile( + sc: SparkContext, path: str, numFeatures: int = -1, minPartitions: Optional[int] = None + ) -> RDD["LabeledPoint"]: """ Loads labeled data in the LIBSVM format into an RDD of LabeledPoint. The LIBSVM format is a text-based format used by @@ -128,10 +145,14 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): if numFeatures <= 0: parsed.cache() numFeatures = parsed.map(lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 - return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) + return parsed.map( + lambda x: LabeledPoint( + x[0], Vectors.sparse(numFeatures, x[1], x[2]) # type: ignore[arg-type] + ) + ) @staticmethod - def saveAsLibSVMFile(data, dir): + def saveAsLibSVMFile(data: RDD["LabeledPoint"], dir: str) -> None: """ Save labeled data in LIBSVM format. @@ -163,7 +184,9 @@ def saveAsLibSVMFile(data, dir): lines.saveAsTextFile(dir) @staticmethod - def loadLabeledPoints(sc, path, minPartitions=None): + def loadLabeledPoints( + sc: SparkContext, path: str, minPartitions: Optional[int] = None + ) -> RDD["LabeledPoint"]: """ Load labeled points saved using RDD.saveAsTextFile. @@ -201,7 +224,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): @staticmethod @since("1.5.0") - def appendBias(data): + def appendBias(data: Vector) -> Vector: """ Returns a new vector with `1.0` (bias) appended to the end of the input vector. 
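A short usage sketch of the `MLUtils` helpers annotated in this hunk, assuming a live SparkContext `sc` and a scratch directory that does not yet exist (the sample points are invented for illustration):

import shutil
import tempfile

from pyspark.mllib.linalg import Vectors
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.util import MLUtils

MLUtils.appendBias(Vectors.dense([1.0, 2.0]))   # DenseVector([1.0, 2.0, 1.0])

tmp = tempfile.mkdtemp()
data = sc.parallelize([LabeledPoint(1.0, [1.0, 0.0]), LabeledPoint(0.0, [0.0, 2.0])])
try:
    MLUtils.saveAsLibSVMFile(data, tmp + "/libsvm")
    loaded = MLUtils.loadLibSVMFile(sc, tmp + "/libsvm")   # RDD[LabeledPoint], per the new hints
    loaded.collect()
finally:
    shutil.rmtree(tmp, ignore_errors=True)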
@@ -216,7 +239,7 @@ def appendBias(data): @staticmethod @since("1.5.0") - def loadVectors(sc, path): + def loadVectors(sc: SparkContext, path: str) -> RDD[Vector]: """ Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions. @@ -224,7 +247,7 @@ def loadVectors(sc, path): return callMLlibFunc("loadVectors", sc, path) @staticmethod - def convertVectorColumnsToML(dataset, *cols): + def convertVectorColumnsToML(dataset: DataFrame, *cols: str) -> DataFrame: """ Converts vector columns in an input DataFrame from the :py:class:`pyspark.mllib.linalg.Vector` type to the new @@ -273,7 +296,7 @@ def convertVectorColumnsToML(dataset, *cols): return callMLlibFunc("convertVectorColumnsToML", dataset, list(cols)) @staticmethod - def convertVectorColumnsFromML(dataset, *cols): + def convertVectorColumnsFromML(dataset: DataFrame, *cols: str) -> DataFrame: """ Converts vector columns in an input DataFrame to the :py:class:`pyspark.mllib.linalg.Vector` type from the new @@ -322,7 +345,7 @@ def convertVectorColumnsFromML(dataset, *cols): return callMLlibFunc("convertVectorColumnsFromML", dataset, list(cols)) @staticmethod - def convertMatrixColumnsToML(dataset, *cols): + def convertMatrixColumnsToML(dataset: DataFrame, *cols: str) -> DataFrame: """ Converts matrix columns in an input DataFrame from the :py:class:`pyspark.mllib.linalg.Matrix` type to the new @@ -371,7 +394,7 @@ def convertMatrixColumnsToML(dataset, *cols): return callMLlibFunc("convertMatrixColumnsToML", dataset, list(cols)) @staticmethod - def convertMatrixColumnsFromML(dataset, *cols): + def convertMatrixColumnsFromML(dataset: DataFrame, *cols: str) -> DataFrame: """ Converts matrix columns in an input DataFrame to the :py:class:`pyspark.mllib.linalg.Matrix` type from the new @@ -427,7 +450,7 @@ class Saveable: .. versionadded:: 1.3.0 """ - def save(self, sc, path): + def save(self, sc: SparkContext, path: str) -> None: """ Save this model to the given path. @@ -458,8 +481,10 @@ class JavaSaveable(Saveable): .. versionadded:: 1.3.0 """ + _java_model: "JavaObject" + @since("1.3.0") - def save(self, sc, path): + def save(self, sc: SparkContext, path: str) -> None: """Save this model to the given path.""" if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) @@ -468,7 +493,7 @@ def save(self, sc, path): self._java_model.save(sc._jsc.sc(), path) -class Loader: +class Loader(Generic[T]): """ Mixin for classes which can load saved models from files. @@ -476,7 +501,7 @@ class Loader: """ @classmethod - def load(cls, sc, path): + def load(cls: Type[L], sc: SparkContext, path: str) -> L: """ Load a model from the given path. The model should have been saved using :py:meth:`Saveable.save`. @@ -497,7 +522,7 @@ def load(cls, sc, path): @inherit_doc -class JavaLoader(Loader): +class JavaLoader(Loader[T]): """ Mixin for classes which can load saved models using its Scala implementation. @@ -506,7 +531,7 @@ class JavaLoader(Loader): """ @classmethod - def _java_loader_class(cls): + def _java_loader_class(cls) -> str: """ Returns the full class name of the Java loader. The default implementation replaces "pyspark" by "org.apache.spark" in @@ -516,22 +541,20 @@ def _java_loader_class(cls): return ".".join([java_package, cls.__name__]) @classmethod - def _load_java(cls, sc, path): + def _load_java(cls, sc: SparkContext, path: str) -> "JavaObject": """ Load a Java model from the given path. 
""" java_class = cls._java_loader_class() - java_obj = sc._jvm - for name in java_class.split("."): - java_obj = getattr(java_obj, name) + java_obj: "JavaObject" = reduce(getattr, java_class.split("."), sc._jvm) return java_obj.load(sc._jsc.sc(), path) @classmethod @since("1.3.0") - def load(cls, sc, path): + def load(cls: Type[JL], sc: SparkContext, path: str) -> JL: """Load a model from the given path.""" java_model = cls._load_java(sc, path) - return cls(java_model) + return cls(java_model) # type: ignore[call-arg] class LinearDataGenerator: @@ -541,7 +564,15 @@ class LinearDataGenerator: """ @staticmethod - def generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps): + def generateLinearInput( + intercept: float, + weights: "VectorLike", + xMean: "VectorLike", + xVariance: "VectorLike", + nPoints: int, + seed: int, + eps: float, + ) -> List["LabeledPoint"]: """ .. versionadded:: 1.5.0 @@ -568,9 +599,9 @@ def generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps list of :py:class:`pyspark.mllib.regression.LabeledPoints` of length nPoints """ - weights = [float(weight) for weight in weights] - xMean = [float(mean) for mean in xMean] - xVariance = [float(var) for var in xVariance] + weights = [float(weight) for weight in cast(Iterable[float], weights)] + xMean = [float(mean) for mean in cast(Iterable[float], xMean)] + xVariance = [float(var) for var in cast(Iterable[float], xVariance)] return list( callMLlibFunc( "generateLinearInputWrapper", @@ -586,7 +617,14 @@ def generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps @staticmethod @since("1.5.0") - def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): + def generateLinearRDD( + sc: SparkContext, + nexamples: int, + nfeatures: int, + eps: float, + nParts: int = 2, + intercept: float = 0.0, + ) -> RDD["LabeledPoint"]: """ Generate an RDD of LabeledPoints. """ @@ -601,7 +639,7 @@ def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): ) -def _test(): +def _test() -> None: import doctest from pyspark.sql import SparkSession diff --git a/python/pyspark/mllib/util.pyi b/python/pyspark/mllib/util.pyi deleted file mode 100644 index 265f765ee263a..0000000000000 --- a/python/pyspark/mllib/util.pyi +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import Generic, List, Optional, TypeVar - -from pyspark.mllib._typing import VectorLike -from pyspark.context import SparkContext -from pyspark.mllib.linalg import Vector -from pyspark.mllib.regression import LabeledPoint -from pyspark.rdd import RDD -from pyspark.sql.dataframe import DataFrame - -T = TypeVar("T") - -class MLUtils: - @staticmethod - def loadLibSVMFile( - sc: SparkContext, - path: str, - numFeatures: int = ..., - minPartitions: Optional[int] = ..., - ) -> RDD[LabeledPoint]: ... - @staticmethod - def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: str) -> None: ... - @staticmethod - def loadLabeledPoints( - sc: SparkContext, path: str, minPartitions: Optional[int] = ... - ) -> RDD[LabeledPoint]: ... - @staticmethod - def appendBias(data: Vector) -> Vector: ... - @staticmethod - def loadVectors(sc: SparkContext, path: str) -> RDD[Vector]: ... - @staticmethod - def convertVectorColumnsToML(dataset: DataFrame, *cols: str) -> DataFrame: ... - @staticmethod - def convertVectorColumnsFromML(dataset: DataFrame, *cols: str) -> DataFrame: ... - @staticmethod - def convertMatrixColumnsToML(dataset: DataFrame, *cols: str) -> DataFrame: ... - @staticmethod - def convertMatrixColumnsFromML(dataset: DataFrame, *cols: str) -> DataFrame: ... - -class Saveable: - def save(self, sc: SparkContext, path: str) -> None: ... - -class JavaSaveable(Saveable): - def save(self, sc: SparkContext, path: str) -> None: ... - -class Loader(Generic[T]): - @classmethod - def load(cls, sc: SparkContext, path: str) -> T: ... - -class JavaLoader(Loader[T]): - @classmethod - def load(cls, sc: SparkContext, path: str) -> T: ... - -class LinearDataGenerator: - @staticmethod - def generateLinearInput( - intercept: float, - weights: VectorLike, - xMean: VectorLike, - xVariance: VectorLike, - nPoints: int, - seed: int, - eps: float, - ) -> List[LabeledPoint]: ... - @staticmethod - def generateLinearRDD( - sc: SparkContext, - nexamples: int, - nfeatures: int, - eps: float, - nParts: int = ..., - intercept: float = ..., - ) -> RDD[LabeledPoint]: ... diff --git a/python/pyspark/pandas/__init__.py b/python/pyspark/pandas/__init__.py index cb503a23d3444..df84c118db5a6 100644 --- a/python/pyspark/pandas/__init__.py +++ b/python/pyspark/pandas/__init__.py @@ -149,5 +149,5 @@ def _auto_patch_pandas() -> None: # Import after the usage logger is attached. from pyspark.pandas.config import get_option, options, option_context, reset_option, set_option -from pyspark.pandas.namespace import * # F405 +from pyspark.pandas.namespace import * # noqa: F403 from pyspark.pandas.sql_formatter import sql diff --git a/python/pyspark/pandas/accessors.py b/python/pyspark/pandas/accessors.py index 22042491eb072..411ed0ee49bbf 100644 --- a/python/pyspark/pandas/accessors.py +++ b/python/pyspark/pandas/accessors.py @@ -335,14 +335,19 @@ def apply_batch( if not isinstance(func, FunctionType): assert callable(func), "the first argument should be a callable function." f = func - func = lambda *args, **kwargs: f(*args, **kwargs) + # Note that the return type hint specified here affects actual return + # type in Spark (e.g., infer_return_type). And, MyPy does not allow + # redefinition of a function. 
+ func = lambda *args, **kwargs: f(*args, **kwargs) # noqa: E731 spec = inspect.getfullargspec(func) return_sig = spec.annotations.get("return", None) should_infer_schema = return_sig is None original_func = func - func = lambda o: original_func(o, *args, **kwds) + + def new_func(o: Any) -> pd.DataFrame: + return original_func(o, *args, **kwds) self_applied: DataFrame = DataFrame(self._psdf._internal.resolved_copy) @@ -355,7 +360,7 @@ def apply_batch( ) limit = ps.get_option("compute.shortcut_limit") pdf = self_applied.head(limit + 1)._to_internal_pandas() - applied = func(pdf) + applied = new_func(pdf) if not isinstance(applied, pd.DataFrame): raise ValueError( "The given function should return a frame; however, " @@ -371,7 +376,7 @@ def apply_batch( return_schema = StructType([field.struct_field for field in index_fields + data_fields]) output_func = GroupBy._make_pandas_df_builder_func( - self_applied, func, return_schema, retain_index=True + self_applied, new_func, return_schema, retain_index=True ) sdf = self_applied._internal.spark_frame.mapInPandas( lambda iterator: map(output_func, iterator), schema=return_schema @@ -394,7 +399,7 @@ def apply_batch( return_schema = cast(DataFrameType, return_type).spark_type output_func = GroupBy._make_pandas_df_builder_func( - self_applied, func, return_schema, retain_index=should_retain_index + self_applied, new_func, return_schema, retain_index=should_retain_index ) sdf = self_applied._internal.to_internal_spark_frame.mapInPandas( lambda iterator: map(output_func, iterator), schema=return_schema @@ -570,10 +575,12 @@ def transform_batch( should_infer_schema = return_sig is None should_retain_index = should_infer_schema original_func = func - func = lambda o: original_func(o, *args, **kwargs) + + def new_func(o: Any) -> Union[pd.DataFrame, pd.Series]: + return original_func(o, *args, **kwargs) def apply_func(pdf: pd.DataFrame) -> pd.DataFrame: - return func(pdf).to_frame() + return new_func(pdf).to_frame() def pandas_series_func( f: Callable[[pd.DataFrame], pd.DataFrame], return_type: DataType @@ -595,7 +602,7 @@ def udf(pdf: pd.DataFrame) -> pd.Series: ) limit = ps.get_option("compute.shortcut_limit") pdf = self._psdf.head(limit + 1)._to_internal_pandas() - transformed = func(pdf) + transformed = new_func(pdf) if not isinstance(transformed, (pd.DataFrame, pd.Series)): raise ValueError( "The given function should return a frame; however, " @@ -606,7 +613,7 @@ def udf(pdf: pd.DataFrame) -> pd.Series: psdf_or_psser = ps.from_pandas(transformed) if isinstance(psdf_or_psser, ps.Series): - psser = cast(ps.Series, psdf_or_psser) + psser = psdf_or_psser field = psser._internal.data_fields[0].normalize_spark_type() @@ -644,7 +651,10 @@ def udf(pdf: pd.DataFrame) -> pd.Series: self_applied: DataFrame = DataFrame(self._psdf._internal.resolved_copy) output_func = GroupBy._make_pandas_df_builder_func( - self_applied, func, return_schema, retain_index=True # type: ignore[arg-type] + self_applied, + new_func, # type: ignore[arg-type] + return_schema, + retain_index=True, ) columns = self_applied._internal.spark_columns @@ -709,7 +719,10 @@ def udf(pdf: pd.DataFrame) -> pd.Series: self_applied = DataFrame(self._psdf._internal.resolved_copy) output_func = GroupBy._make_pandas_df_builder_func( - self_applied, func, return_schema, should_retain_index # type: ignore[arg-type] + self_applied, + new_func, # type: ignore[arg-type] + return_schema, + retain_index=should_retain_index, ) columns = self_applied._internal.spark_columns @@ -879,7 +892,7 @@ def 
transform_batch( "Expected the return type of this function to be of type column," " but found type {}".format(sig_return) ) - return_type = cast(SeriesType, sig_return) + return_type = sig_return return self._transform_batch(lambda c: func(c, *args, **kwargs), return_type) @@ -892,7 +905,10 @@ def _transform_batch( if not isinstance(func, FunctionType): f = func - func = lambda *args, **kwargs: f(*args, **kwargs) + # Note that the return type hint specified here affects actual return + # type in Spark (e.g., infer_return_type). And, MyPy does not allow + # redefinition of a function. + func = lambda *args, **kwargs: f(*args, **kwargs) # noqa: E731 if return_type is None: # TODO: In this case, it avoids the shortcut for now (but only infers schema) diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py index 5cb60ca9ffac3..2d2c79e7f472d 100644 --- a/python/pyspark/pandas/base.py +++ b/python/pyspark/pandas/base.py @@ -27,7 +27,7 @@ import pandas as pd from pandas.api.types import is_list_like, CategoricalDtype # type: ignore[attr-defined] from pyspark.sql import functions as F, Column, Window -from pyspark.sql.types import LongType, BooleanType +from pyspark.sql.types import LongType, BooleanType, NumericType from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. from pyspark.pandas._typing import Axis, Dtype, IndexOpsLike, Label, SeriesOrIndex @@ -148,7 +148,7 @@ def align_diff_index_ops( ], ).rename(this_index_ops.name) else: - this = cast(Index, this_index_ops).to_frame().reset_index(drop=True) + this = this_index_ops.to_frame().reset_index(drop=True) that_series = next(col for col in cols if isinstance(col, Series)) that_frame = that_series._psdf[ @@ -876,7 +876,9 @@ def isin(self: IndexOpsLike, values: Sequence[Any]) -> IndexOpsLike: " to isin(), you passed a [{values_type}]".format(values_type=type(values).__name__) ) - values = values.tolist() if isinstance(values, np.ndarray) else list(values) + values = ( + cast(np.ndarray, values).tolist() if isinstance(values, np.ndarray) else list(values) + ) other = [SF.lit(v) for v in values] scol = self.spark.column.isin(other) @@ -963,8 +965,8 @@ def notnull(self: IndexOpsLike) -> IndexOpsLike: notna = notnull - # TODO: axis, skipna, and many arguments should be implemented. - def all(self, axis: Axis = 0) -> bool: + # TODO: axis and many arguments should be implemented. + def all(self, axis: Axis = 0, skipna: bool = True) -> bool: """ Return whether all elements are True. @@ -979,6 +981,11 @@ def all(self, axis: Axis = 0) -> bool: * 0 / 'index' : reduce the index, return a Series whose index is the original column labels. + skipna : boolean, default True + Exclude NA/null values. If an entire row/column is NA and skipna is True, + then the result will be True, as for an empty row/column. + If skipna is False, then NA are treated as True, because these are not equal to zero. 
+ Examples -------- >>> ps.Series([True, True]).all() @@ -996,6 +1003,9 @@ def all(self, axis: Axis = 0) -> bool: >>> ps.Series([True, True, None]).all() True + >>> ps.Series([True, True, None]).all(skipna=False) + False + >>> ps.Series([True, False, None]).all() False @@ -1005,6 +1015,15 @@ def all(self, axis: Axis = 0) -> bool: >>> ps.Series([np.nan]).all() True + >>> ps.Series([np.nan]).all(skipna=False) + True + + >>> ps.Series([None]).all() + True + + >>> ps.Series([None]).all(skipna=False) + False + >>> df = ps.Series([True, False, None]).rename("a").to_frame() >>> df.set_index("a").index.all() False @@ -1016,11 +1035,18 @@ def all(self, axis: Axis = 0) -> bool: sdf = self._internal.spark_frame.select(self.spark.column) col = scol_for(sdf, sdf.columns[0]) - # Note that we're ignoring `None`s here for now. - # any and every was added as of Spark 3.0 + # `any` and `every` was added as of Spark 3.0. # ret = sdf.select(F.expr("every(CAST(`%s` AS BOOLEAN))" % sdf.columns[0])).collect()[0][0] - # Here we use min as its alternative: - ret = sdf.select(F.min(F.coalesce(col.cast("boolean"), SF.lit(True)))).collect()[0][0] + # We use min as its alternative as below. + if isinstance(self.spark.data_type, NumericType) or skipna: + # np.nan takes no effect to the result; None takes no effect if `skipna` + ret = sdf.select(F.min(F.coalesce(col.cast("boolean"), SF.lit(True)))).collect()[0][0] + else: + # Take None as False when not `skipna` + ret = sdf.select( + F.min(F.when(col.isNull(), SF.lit(False)).otherwise(col.cast("boolean"))) + ).collect()[0][0] + if ret is None: return True else: diff --git a/python/pyspark/pandas/data_type_ops/complex_ops.py b/python/pyspark/pandas/data_type_ops/complex_ops.py index bee09f383e127..415301e400e99 100644 --- a/python/pyspark/pandas/data_type_ops/complex_ops.py +++ b/python/pyspark/pandas/data_type_ops/complex_ops.py @@ -53,7 +53,7 @@ def add(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: ) left_type = cast(ArrayType, left.spark.data_type).elementType - right_type = cast(ArrayType, right.spark.data_type).elementType + right_type = right.spark.data_type.elementType if left_type != right_type and not ( isinstance(left_type, NumericType) and isinstance(right_type, NumericType) diff --git a/python/pyspark/pandas/data_type_ops/datetime_ops.py b/python/pyspark/pandas/data_type_ops/datetime_ops.py index 3ec4109499c6e..16613f1bb288d 100644 --- a/python/pyspark/pandas/data_type_ops/datetime_ops.py +++ b/python/pyspark/pandas/data_type_ops/datetime_ops.py @@ -160,7 +160,7 @@ class DatetimeNTZOps(DatetimeOps): """ def _cast_spark_column_timestamp_to_long(self, scol: Column) -> Column: - jvm = SparkContext._active_spark_context._jvm # type: ignore[attr-defined] + jvm = SparkContext._active_spark_context._jvm return Column(jvm.PythonSQLUtils.castTimestampNTZToLong(scol._jc)) def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: diff --git a/python/pyspark/pandas/datetimes.py b/python/pyspark/pandas/datetimes.py index f52809d5abef7..d0b3f2ff3d749 100644 --- a/python/pyspark/pandas/datetimes.py +++ b/python/pyspark/pandas/datetimes.py @@ -18,17 +18,16 @@ """ Date/Time related functions on pandas-on-Spark Series """ -from typing import Any, Optional, Union, TYPE_CHECKING, no_type_check +from typing import Any, Optional, Union, no_type_check import numpy as np import pandas as pd # noqa: F401 from pandas.tseries.offsets import DateOffset + +import pyspark.pandas as ps import pyspark.sql.functions as F from pyspark.sql.types 
import DateType, TimestampType, TimestampNTZType, LongType -if TYPE_CHECKING: - import pyspark.pandas as ps - class DatetimeMethods: """Date/Time methods for pandas-on-Spark Series""" @@ -107,8 +106,7 @@ def microsecond(self) -> "ps.Series": The microseconds of the datetime. """ - @no_type_check - def pandas_microsecond(s) -> "ps.Series[np.int64]": + def pandas_microsecond(s) -> ps.Series[np.int64]: # type: ignore[no-untyped-def] return s.dt.microsecond return self._data.pandas_on_spark.transform_batch(pandas_microsecond) @@ -167,8 +165,7 @@ def dayofweek(self) -> "ps.Series": dtype: int64 """ - @no_type_check - def pandas_dayofweek(s) -> "ps.Series[np.int64]": + def pandas_dayofweek(s) -> ps.Series[np.int64]: # type: ignore[no-untyped-def] return s.dt.dayofweek return self._data.pandas_on_spark.transform_batch(pandas_dayofweek) @@ -185,8 +182,7 @@ def dayofyear(self) -> "ps.Series": The ordinal day of the year. """ - @no_type_check - def pandas_dayofyear(s) -> "ps.Series[np.int64]": + def pandas_dayofyear(s) -> ps.Series[np.int64]: # type: ignore[no-untyped-def] return s.dt.dayofyear return self._data.pandas_on_spark.transform_batch(pandas_dayofyear) @@ -197,8 +193,7 @@ def quarter(self) -> "ps.Series": The quarter of the date. """ - @no_type_check - def pandas_quarter(s) -> "ps.Series[np.int64]": + def pandas_quarter(s) -> ps.Series[np.int64]: # type: ignore[no-untyped-def] return s.dt.quarter return self._data.pandas_on_spark.transform_batch(pandas_quarter) @@ -237,8 +232,7 @@ def is_month_start(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_is_month_start(s) -> "ps.Series[bool]": + def pandas_is_month_start(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.dt.is_month_start return self._data.pandas_on_spark.transform_batch(pandas_is_month_start) @@ -277,8 +271,7 @@ def is_month_end(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_is_month_end(s) -> "ps.Series[bool]": + def pandas_is_month_end(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.dt.is_month_end return self._data.pandas_on_spark.transform_batch(pandas_is_month_end) @@ -328,8 +321,7 @@ def is_quarter_start(self) -> "ps.Series": Name: dates, dtype: bool """ - @no_type_check - def pandas_is_quarter_start(s) -> "ps.Series[bool]": + def pandas_is_quarter_start(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.dt.is_quarter_start return self._data.pandas_on_spark.transform_batch(pandas_is_quarter_start) @@ -379,8 +371,7 @@ def is_quarter_end(self) -> "ps.Series": Name: dates, dtype: bool """ - @no_type_check - def pandas_is_quarter_end(s) -> "ps.Series[bool]": + def pandas_is_quarter_end(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.dt.is_quarter_end return self._data.pandas_on_spark.transform_batch(pandas_is_quarter_end) @@ -419,8 +410,7 @@ def is_year_start(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_is_year_start(s) -> "ps.Series[bool]": + def pandas_is_year_start(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.dt.is_year_start return self._data.pandas_on_spark.transform_batch(pandas_is_year_start) @@ -459,8 +449,7 @@ def is_year_end(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_is_year_end(s) -> "ps.Series[bool]": + def pandas_is_year_end(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.dt.is_year_end return self._data.pandas_on_spark.transform_batch(pandas_is_year_end) @@ -499,8 +488,7 @@ def is_leap_year(self) -> "ps.Series": dtype: bool """ - 
@no_type_check - def pandas_is_leap_year(s) -> "ps.Series[bool]": + def pandas_is_leap_year(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.dt.is_leap_year return self._data.pandas_on_spark.transform_batch(pandas_is_leap_year) @@ -511,8 +499,7 @@ def daysinmonth(self) -> "ps.Series": The number of days in the month. """ - @no_type_check - def pandas_daysinmonth(s) -> "ps.Series[np.int64]": + def pandas_daysinmonth(s) -> ps.Series[np.int64]: # type: ignore[no-untyped-def] return s.dt.daysinmonth return self._data.pandas_on_spark.transform_batch(pandas_daysinmonth) @@ -574,8 +561,7 @@ def normalize(self) -> "ps.Series": dtype: datetime64[ns] """ - @no_type_check - def pandas_normalize(s) -> "ps.Series[np.datetime64]": + def pandas_normalize(s) -> ps.Series[np.datetime64]: # type: ignore[no-untyped-def] return s.dt.normalize() return self._data.pandas_on_spark.transform_batch(pandas_normalize) @@ -623,8 +609,7 @@ def strftime(self, date_format: str) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_strftime(s) -> "ps.Series[str]": + def pandas_strftime(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.dt.strftime(date_format) return self._data.pandas_on_spark.transform_batch(pandas_strftime) @@ -679,8 +664,7 @@ def round(self, freq: Union[str, DateOffset], *args: Any, **kwargs: Any) -> "ps. dtype: datetime64[ns] """ - @no_type_check - def pandas_round(s) -> "ps.Series[np.datetime64]": + def pandas_round(s) -> ps.Series[np.datetime64]: # type: ignore[no-untyped-def] return s.dt.round(freq, *args, **kwargs) return self._data.pandas_on_spark.transform_batch(pandas_round) @@ -735,8 +719,7 @@ def floor(self, freq: Union[str, DateOffset], *args: Any, **kwargs: Any) -> "ps. dtype: datetime64[ns] """ - @no_type_check - def pandas_floor(s) -> "ps.Series[np.datetime64]": + def pandas_floor(s) -> ps.Series[np.datetime64]: # type: ignore[no-untyped-def] return s.dt.floor(freq, *args, **kwargs) return self._data.pandas_on_spark.transform_batch(pandas_floor) @@ -791,8 +774,7 @@ def ceil(self, freq: Union[str, DateOffset], *args: Any, **kwargs: Any) -> "ps.S dtype: datetime64[ns] """ - @no_type_check - def pandas_ceil(s) -> "ps.Series[np.datetime64]": + def pandas_ceil(s) -> ps.Series[np.datetime64]: # type: ignore[no-untyped-def] return s.dt.ceil(freq, *args, **kwargs) return self._data.pandas_on_spark.transform_batch(pandas_ceil) @@ -828,8 +810,7 @@ def month_name(self, locale: Optional[str] = None) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_month_name(s) -> "ps.Series[str]": + def pandas_month_name(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.dt.month_name(locale=locale) return self._data.pandas_on_spark.transform_batch(pandas_month_name) @@ -865,8 +846,7 @@ def day_name(self, locale: Optional[str] = None) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_day_name(s) -> "ps.Series[str]": + def pandas_day_name(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.dt.day_name(locale=locale) return self._data.pandas_on_spark.transform_batch(pandas_day_name) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 0a11a0f15f80f..41a0dde47a51b 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -2438,7 +2438,10 @@ def apply( if not isinstance(func, types.FunctionType): assert callable(func), "the first argument should be a callable function." 
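The recurring edit in datetimes.py above (and repeated later in strings.py) is one pattern: pyspark.pandas is now imported at module level so the inner helper can keep a real, runtime-evaluated ps.Series[...] return hint, which transform_batch reads to build the Spark schema, while the still-untyped parameter is silenced with a narrow "no-untyped-def" ignore. A stand-alone sketch of that pattern, assuming a plain datetime Series:

import numpy as np
import pandas as pd
import pyspark.pandas as ps

psser = ps.from_pandas(pd.Series(pd.date_range("2022-01-01", periods=3)))

# The ps.Series[np.int64] hint is what transform_batch uses for the result's
# Spark type; without it, an extra schema-inference pass would be needed.
def pandas_day(s) -> ps.Series[np.int64]:  # type: ignore[no-untyped-def]
    return s.dt.day

psser.pandas_on_spark.transform_batch(pandas_day)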
f = func - func = lambda *args, **kwargs: f(*args, **kwargs) + # Note that the return type hint specified here affects actual return + # type in Spark (e.g., infer_return_type). And, MyPy does not allow + # redefinition of a function. + func = lambda *args, **kwargs: f(*args, **kwargs) # noqa: E731 axis = validate_axis(axis) should_return_series = False @@ -2691,7 +2694,10 @@ def transform( if not isinstance(func, types.FunctionType): assert callable(func), "the first argument should be a callable function." f = func - func = lambda *args, **kwargs: f(*args, **kwargs) + # Note that the return type hint specified here affects actual return + # type in Spark (e.g., infer_return_type). And, MyPy does not allow + # redefinition of a function. + func = lambda *args, **kwargs: f(*args, **kwargs) # noqa: E731 axis = validate_axis(axis) if axis != 0: @@ -3028,8 +3034,9 @@ def between_time( psdf.index.name = verify_temp_column_name(psdf, "__index_name__") return_types = [psdf.index.dtype] + list(psdf.dtypes) - @no_type_check - def pandas_between_time(pdf) -> ps.DataFrame[return_types]: + def pandas_between_time( # type: ignore[no-untyped-def] + pdf, + ) -> ps.DataFrame[return_types]: # type: ignore[valid-type] return pdf.between_time(start_time, end_time, include_start, include_end).reset_index() # apply_batch will remove the index of the pandas-on-Spark DataFrame and attach a @@ -3106,8 +3113,9 @@ def at_time( psdf.index.name = verify_temp_column_name(psdf, "__index_name__") return_types = [psdf.index.dtype] + list(psdf.dtypes) - @no_type_check - def pandas_at_time(pdf) -> ps.DataFrame[return_types]: + def pandas_at_time( # type: ignore[no-untyped-def] + pdf, + ) -> ps.DataFrame[return_types]: # type: ignore[valid-type] return pdf.at_time(time, asof, axis).reset_index() # apply_batch will remove the index of the pandas-on-Spark DataFrame and attach @@ -5466,9 +5474,15 @@ def op(psser: ps.Series) -> ps.Series: return psser else: - op = lambda psser: psser._fillna(value=value, method=method, axis=axis, limit=limit) + + def op(psser: ps.Series) -> ps.Series: + return psser._fillna(value=value, method=method, axis=axis, limit=limit) + elif method is not None: - op = lambda psser: psser._fillna(value=value, method=method, axis=axis, limit=limit) + + def op(psser: ps.Series) -> ps.Series: + return psser._fillna(value=value, method=method, axis=axis, limit=limit) + else: raise ValueError("Must specify a fillna 'value' or 'method' parameter.") @@ -5592,7 +5606,7 @@ def replace( if isinstance(to_replace, dict) and ( value is not None or all(isinstance(i, dict) for i in to_replace.values()) ): - to_replace_dict = cast(dict, to_replace) + to_replace_dict = to_replace def op(psser: ps.Series) -> ps.Series: if psser.name in to_replace_dict: @@ -5603,7 +5617,9 @@ def op(psser: ps.Series) -> ps.Series: return psser else: - op = lambda psser: psser.replace(to_replace=to_replace, value=value, regex=regex) + + def op(psser: ps.Series) -> ps.Series: + return psser.replace(to_replace=to_replace, value=value, regex=regex) psdf = self._apply_series_op(op) if inplace: @@ -6839,6 +6855,7 @@ def sort_values( ascending: Union[bool, List[bool]] = True, inplace: bool = False, na_position: str = "last", + ignore_index: bool = False, ) -> Optional["DataFrame"]: """ Sort by the values along either axis. 
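Several hunks above in frame.py (the fillna and replace ops, and later the merge helpers) replace assigned lambdas with small named functions. The motivation is the same everywhere: an assigned lambda cannot carry annotations, flake8 flags the assignment as E731, and mypy does not allow redefining a name with differently typed callables, whereas a local def avoids all three. A generic illustration of the two forms:

# Lambda assignment: needs a noqa and stays effectively untyped for mypy.
double = lambda x: x * 2  # noqa: E731

# Equivalent named function: annotated and checkable, no lint suppression.
def double_fn(x: int) -> int:
    return x * 2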
@@ -6854,6 +6871,8 @@ def sort_values( if True, perform operation in-place na_position : {'first', 'last'}, default 'last' `first` puts NaNs at the beginning, `last` puts NaNs at the end + ignore_index : bool, default False + If True, the resulting axis will be labeled 0, 1, …, n - 1. Returns ------- @@ -6866,34 +6885,45 @@ def sort_values( ... 'col2': [2, 9, 8, 7, 4], ... 'col3': [0, 9, 4, 2, 3], ... }, - ... columns=['col1', 'col2', 'col3']) + ... columns=['col1', 'col2', 'col3'], + ... index=['a', 'b', 'c', 'd', 'e']) >>> df col1 col2 col3 - 0 A 2 0 - 1 B 9 9 - 2 None 8 4 - 3 D 7 2 - 4 C 4 3 + a A 2 0 + b B 9 9 + c None 8 4 + d D 7 2 + e C 4 3 Sort by col1 >>> df.sort_values(by=['col1']) col1 col2 col3 + a A 2 0 + b B 9 9 + e C 4 3 + d D 7 2 + c None 8 4 + + Ignore index for the resulting axis + + >>> df.sort_values(by=['col1'], ignore_index=True) + col1 col2 col3 0 A 2 0 1 B 9 9 - 4 C 4 3 + 2 C 4 3 3 D 7 2 - 2 None 8 4 + 4 None 8 4 Sort Descending >>> df.sort_values(by='col1', ascending=False) col1 col2 col3 - 3 D 7 2 - 4 C 4 3 - 1 B 9 9 - 0 A 2 0 - 2 None 8 4 + d D 7 2 + e C 4 3 + b B 9 9 + a A 2 0 + c None 8 4 Sort by multiple columns @@ -6929,11 +6959,14 @@ def sort_values( new_by.append(ser.spark.column) psdf = self._sort(by=new_by, ascending=ascending, na_position=na_position) + if inplace: + if ignore_index: + psdf.reset_index(drop=True, inplace=inplace) self._update_internal_frame(psdf._internal) return None else: - return psdf + return psdf.reset_index(drop=True) if ignore_index else psdf def sort_index( self, @@ -7267,7 +7300,7 @@ def _swaplevel_index(self, i: Union[int, Name], j: Union[int, Name]) -> Internal ) return internal - # TODO: add keep = First + # TODO: add keep = First def nlargest(self, n: int, columns: Union[Name, List[Name]]) -> "DataFrame": """ Return the first `n` rows ordered by `columns` in descending order. @@ -7324,7 +7357,7 @@ def nlargest(self, n: int, columns: Union[Name, List[Name]]) -> "DataFrame": 6 NaN 12 In the following example, we will use ``nlargest`` to select the three - rows having the largest values in column "population". + rows having the largest values in column "X". >>> df.nlargest(n=3, columns='X') X Y @@ -7332,12 +7365,14 @@ def nlargest(self, n: int, columns: Union[Name, List[Name]]) -> "DataFrame": 4 6.0 10 3 5.0 9 + To order by the largest values in column "Y" and then "X", we can + specify multiple columns like in the next example. + >>> df.nlargest(n=3, columns=['Y', 'X']) X Y 6 NaN 12 5 7.0 11 4 6.0 10 - """ return self.sort_values(by=columns, ascending=False).head(n=n) @@ -7387,7 +7422,7 @@ def nsmallest(self, n: int, columns: Union[Name, List[Name]]) -> "DataFrame": 6 NaN 12 In the following example, we will use ``nsmallest`` to select the - three rows having the smallest values in column "a". + three rows having the smallest values in column "X". >>> df.nsmallest(n=3, columns='X') # doctest: +NORMALIZE_WHITESPACE X Y @@ -7395,7 +7430,7 @@ def nsmallest(self, n: int, columns: Union[Name, List[Name]]) -> "DataFrame": 1 2.0 7 2 3.0 8 - To order by the largest values in column "a" and then "c", we can + To order by the smallest values in column "Y" and then "X", we can specify multiple columns like in the next example. 
>>> df.nsmallest(n=3, columns=['Y', 'X']) # doctest: +NORMALIZE_WHITESPACE @@ -7698,7 +7733,9 @@ def to_list(os: Optional[Union[Name, List[Name]]]) -> List[Label]: how = validate_how(how) def resolve(internal: InternalFrame, side: str) -> InternalFrame: - rename = lambda col: "__{}_{}".format(side, col) + def rename(col: str) -> str: + return "__{}_{}".format(side, col) + internal = internal.resolved_copy sdf = internal.spark_frame sdf = sdf.select( @@ -7750,12 +7787,11 @@ def resolve(internal: InternalFrame, side: str) -> InternalFrame: data_columns = [] column_labels = [] - left_scol_for = lambda label: scol_for( - left_table, left_internal.spark_column_name_for(label) - ) - right_scol_for = lambda label: scol_for( - right_table, right_internal.spark_column_name_for(label) - ) + def left_scol_for(label: Label) -> Column: + return scol_for(left_table, left_internal.spark_column_name_for(label)) + + def right_scol_for(label: Label) -> Column: + return scol_for(right_table, right_internal.spark_column_name_for(label)) for label in left_internal.column_labels: col = left_internal.spark_column_name_for(label) @@ -10547,7 +10583,7 @@ def gen_mapper_fn( mapper: Union[Dict, Callable[[Any], Any]] ) -> Tuple[Callable[[Any], Any], Dtype, DataType]: if isinstance(mapper, dict): - mapper_dict = cast(dict, mapper) + mapper_dict = mapper type_set = set(map(lambda x: type(x), mapper_dict.values())) if len(type_set) > 1: @@ -11645,8 +11681,7 @@ def eval(self, expr: str, inplace: bool = False) -> Optional[DataFrameOrSeries]: # Since `eval_func` doesn't have a type hint, inferring the schema is always preformed # in the `apply_batch`. Hence, the variables `should_return_series`, `series_name`, # and `should_return_scalar` can be updated. - @no_type_check - def eval_func(pdf): + def eval_func(pdf): # type: ignore[no-untyped-def] nonlocal should_return_series nonlocal series_name nonlocal should_return_scalar @@ -12465,7 +12500,7 @@ def _reduce_spark_multi(sdf: SparkDataFrame, aggs: List[Column]) -> Any: """ assert isinstance(sdf, SparkDataFrame) sdf0 = sdf.agg(*aggs) - lst = cast(pd.DataFrame, sdf0.limit(2).toPandas()) + lst = sdf0.limit(2).toPandas() assert len(lst) == 1, (sdf, lst) row = lst.iloc[0] lst2 = list(row) diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py index 2dac5b056aba0..49375da516629 100644 --- a/python/pyspark/pandas/generic.py +++ b/python/pyspark/pandas/generic.py @@ -24,6 +24,7 @@ from typing import ( Any, Callable, + Dict, Iterable, IO, List, @@ -581,7 +582,7 @@ def to_numpy(self) -> np.ndarray: "`to_numpy` loads all data into the driver's memory. " "It should only be used if the resulting NumPy ndarray is expected to be small." ) - return self._to_pandas().values + return cast(np.ndarray, self._to_pandas().values) @property def values(self) -> np.ndarray: @@ -905,6 +906,9 @@ def to_json( .. note:: output JSON format is different from pandas'. It always use `orient='records'` for its output. This behaviour might have to change in the near future. + .. note:: Set `ignoreNullFields` keyword argument to `True` to omit `None` or `NaN` values + when writing JSON objects. It works only when `path` is provided. + Note NaN's and None will be converted to null and datetime objects will be converted to UNIX timestamps. 
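The to_json hunk below adds a default of ignoreNullFields=False so pandas-on-Spark keeps null fields in the written JSON, as pandas does, while still letting callers opt back into Spark's behaviour. Roughly how that looks from the user side (the output paths are placeholders):

import pyspark.pandas as ps

psdf = ps.DataFrame({"a": [1.0, None], "b": ["x", None]})

# New default: nulls are written explicitly, e.g. {"a":1.0,"b":"x"} and
# {"a":null,"b":null}.
psdf.to_json(path="/tmp/psdf_json", num_files=1)

# Opt back into dropping null fields; as the note above says, the option only
# takes effect when a path is given (a distributed write), not for the
# driver-side string output.
psdf.to_json(path="/tmp/psdf_json_drop_nulls", num_files=1, ignoreNullFields=True)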
@@ -981,6 +985,9 @@ def to_json( if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: options = options.get("options") + default_options: Dict[str, Any] = {"ignoreNullFields": False} + options = {**default_options, **options} + if not lines: raise NotImplementedError("lines=False is not implemented yet.") @@ -2434,12 +2441,11 @@ def first_valid_index(self) -> Optional[Union[Scalar, Tuple[Scalar, ...]]]: with sql_conf({SPARK_CONF_ARROW_ENABLED: False}): # Disable Arrow to keep row ordering. - first_valid_row = cast( - pd.DataFrame, + first_valid_row = ( self._internal.spark_frame.filter(cond) .select(self._internal.index_spark_columns) .limit(1) - .toPandas(), + .toPandas() ) # For Empty Series or DataFrame, returns None. diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index b9b65208910f7..addb53d8cd5c1 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -1251,9 +1251,9 @@ def pandas_apply(pdf: pd.DataFrame, *a: Any, **k: Any) -> Any: ) if isinstance(return_type, DataFrameType): - data_fields = cast(DataFrameType, return_type).data_fields - return_schema = cast(DataFrameType, return_type).spark_type - index_fields = cast(DataFrameType, return_type).index_fields + data_fields = return_type.data_fields + return_schema = return_type.spark_type + index_fields = return_type.index_fields should_retain_index = len(index_fields) > 0 psdf_from_pandas = None else: @@ -2287,8 +2287,8 @@ def pandas_transform(pdf: pd.DataFrame) -> pd.DataFrame: "but found type {}".format(return_type) ) - dtype = cast(SeriesType, return_type).dtype - spark_type = cast(SeriesType, return_type).spark_type + dtype = return_type.dtype + spark_type = return_type.spark_type data_fields = [ InternalField(dtype=dtype, struct_field=StructField(name=c, dataType=spark_type)) @@ -2356,12 +2356,16 @@ def nunique(self, dropna: bool = True) -> FrameLike: Name: value1, dtype: int64 """ if dropna: - stat_function = lambda col: F.countDistinct(col) + + def stat_function(col: Column) -> Column: + return F.countDistinct(col) + else: - stat_function = lambda col: ( - F.countDistinct(col) - + F.when(F.count(F.when(col.isNull(), 1).otherwise(None)) >= 1, 1).otherwise(0) - ) + + def stat_function(col: Column) -> Column: + return F.countDistinct(col) + F.when( + F.count(F.when(col.isNull(), 1).otherwise(None)) >= 1, 1 + ).otherwise(0) return self._reduce_for_stat_function(stat_function, only_numeric=False) @@ -2563,7 +2567,9 @@ def median(self, numeric_only: bool = True, accuracy: int = 10000) -> FrameLike: "accuracy must be an integer; however, got [%s]" % type(accuracy).__name__ ) - stat_function = lambda col: F.percentile_approx(col, 0.5, accuracy) + def stat_function(col: Column) -> Column: + return F.percentile_approx(col, 0.5, accuracy) + return self._reduce_for_stat_function(stat_function, only_numeric=numeric_only) def _reduce_for_stat_function( @@ -2832,7 +2838,7 @@ def _apply_series_op( ) -> DataFrame: applied = [] for column in self._agg_columns: - applied.append(op(cast(SeriesGroupBy, column.groupby(self._groupkeys)))) + applied.append(op(column.groupby(self._groupkeys))) if numeric_only: applied = [col for col in applied if isinstance(col.spark.data_type, NumericType)] if not applied: diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index 3a7fa5e636848..1705ef83261bf 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -286,12 +286,11 @@ def 
_summary(self, name: Optional[str] = None) -> str: String with a summarized representation of the index """ head, tail, total_count = tuple( - cast( - pd.DataFrame, - self._internal.spark_frame.select( - F.first(self.spark.column), F.last(self.spark.column), F.count(F.expr("*")) - ).toPandas(), - ).iloc[0] + self._internal.spark_frame.select( + F.first(self.spark.column), F.last(self.spark.column), F.count(F.expr("*")) + ) + .toPandas() + .iloc[0] ) if total_count > 0: @@ -1652,11 +1651,10 @@ def min(self) -> Union[Scalar, Tuple[Scalar, ...]]: ('a', 'x', 1) """ sdf = self._internal.spark_frame - min_row = cast( - pd.DataFrame, + min_row = ( sdf.select(F.min(F.struct(*self._internal.index_spark_columns)).alias("min_row")) .select("min_row.*") - .toPandas(), + .toPandas() ) result = tuple(min_row.iloc[0]) @@ -1694,11 +1692,10 @@ def max(self) -> Union[Scalar, Tuple[Scalar, ...]]: ('b', 'y', 2) """ sdf = self._internal.spark_frame - max_row = cast( - pd.DataFrame, + max_row = ( sdf.select(F.max(F.struct(*self._internal.index_spark_columns)).alias("max_row")) .select("max_row.*") - .toPandas(), + .toPandas() ) result = tuple(max_row.iloc[0]) @@ -2285,7 +2282,7 @@ def asof(self, label: Any) -> Scalar: else: raise ValueError("index must be monotonic increasing or decreasing") - result = cast(pd.DataFrame, sdf.toPandas()).iloc[0, 0] + result = sdf.toPandas().iloc[0, 0] return result if result is not None else np.nan def _index_fields_for_union_like( diff --git a/python/pyspark/pandas/indexes/category.py b/python/pyspark/pandas/indexes/category.py index 40c6410cedc65..2dfc7f25eb0f3 100644 --- a/python/pyspark/pandas/indexes/category.py +++ b/python/pyspark/pandas/indexes/category.py @@ -705,6 +705,10 @@ def map( # type: ignore[override] """ return super().map(mapper) + @no_type_check + def all(self, *args, **kwargs) -> None: + raise TypeError("Cannot perform 'all' with this index type: %s" % type(self).__name__) + def _test() -> None: import os diff --git a/python/pyspark/pandas/indexes/datetimes.py b/python/pyspark/pandas/indexes/datetimes.py index abc1d8c35f5a4..b4a7c1e8356a8 100644 --- a/python/pyspark/pandas/indexes/datetimes.py +++ b/python/pyspark/pandas/indexes/datetimes.py @@ -682,8 +682,7 @@ def indexer_between_time( Int64Index([2], dtype='int64') """ - @no_type_check - def pandas_between_time(pdf) -> ps.DataFrame[int]: + def pandas_between_time(pdf) -> ps.DataFrame[int]: # type: ignore[no-untyped-def] return pdf.between_time(start_time, end_time, include_start, include_end) psdf = self.to_frame()[[]] @@ -728,8 +727,7 @@ def indexer_at_time(self, time: Union[datetime.time, str], asof: bool = False) - if asof: raise NotImplementedError("'asof' argument is not supported") - @no_type_check - def pandas_at_time(pdf) -> ps.DataFrame[int]: + def pandas_at_time(pdf) -> ps.DataFrame[int]: # type: ignore[no-untyped-def] return pdf.at_time(time, asof) psdf = self.to_frame()[[]] @@ -741,6 +739,10 @@ def pandas_at_time(pdf) -> ps.DataFrame[int]: psdf = psdf.pandas_on_spark.apply_batch(pandas_at_time) return ps.Index(first_series(psdf).rename(self.name)) + @no_type_check + def all(self, *args, **kwargs) -> None: + raise TypeError("Cannot perform 'all' with this index type: %s" % type(self).__name__) + def disallow_nanoseconds(freq: Union[str, DateOffset]) -> None: if freq in ["N", "ns"]: diff --git a/python/pyspark/pandas/indexes/timedelta.py b/python/pyspark/pandas/indexes/timedelta.py index 2888642375655..564c484d9684b 100644 --- a/python/pyspark/pandas/indexes/timedelta.py +++ 
b/python/pyspark/pandas/indexes/timedelta.py @@ -137,8 +137,7 @@ def days(self) -> Index: Number of days for each element. """ - @no_type_check - def pandas_days(x) -> int: + def pandas_days(x) -> int: # type: ignore[no-untyped-def] return x.days return Index(self.to_series().transform(pandas_days)) @@ -193,3 +192,7 @@ def get_microseconds(scol): ).cast("int") return Index(self.to_series().spark.transform(get_microseconds)) + + @no_type_check + def all(self, *args, **kwargs) -> None: + raise TypeError("Cannot perform 'all' with this index type: %s" % type(self).__name__) diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py index 2058c264f270c..76627bd0e128b 100644 --- a/python/pyspark/pandas/indexing.py +++ b/python/pyspark/pandas/indexing.py @@ -172,7 +172,7 @@ def __getitem__(self, key: Any) -> Union["Series", "DataFrame", Scalar]: if len(pdf) < 1: raise KeyError(name_like_string(row_sel)) - values = cast(pd.DataFrame, pdf).iloc[:, 0].values + values = pdf.iloc[:, 0].values return ( values if (len(row_sel) < self._internal.index_level or len(values) > 1) else values[0] ) @@ -535,7 +535,7 @@ def __getitem__(self, key: Any) -> Union["Series", "DataFrame"]: except AnalysisException: raise KeyError( "[{}] don't exist in columns".format( - [col._jc.toString() for col in data_spark_columns] # type: ignore[operator] + [col._jc.toString() for col in data_spark_columns] ) ) @@ -553,7 +553,7 @@ def __getitem__(self, key: Any) -> Union["Series", "DataFrame"]: psdf_or_psser: Union[DataFrame, Series] if returns_series: - psdf_or_psser = cast(Series, first_series(psdf)) + psdf_or_psser = first_series(psdf) if series_name is not None and series_name != psdf_or_psser.name: psdf_or_psser = psdf_or_psser.rename(series_name) else: diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 1c32c430911cd..ffc86ba4c6134 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -905,8 +905,8 @@ def attach_distributed_sequence_column(sdf: SparkDataFrame, column_name: str) -> """ if len(sdf.columns) > 0: return SparkDataFrame( - sdf._jdf.toDF().withSequenceColumn(column_name), # type: ignore[operator] - sdf.sql_ctx, + sdf._jdf.toDF().withSequenceColumn(column_name), + sdf.sparkSession, ) else: cnt = sdf.count() diff --git a/python/pyspark/pandas/missing/common.py b/python/pyspark/pandas/missing/common.py index 1ebf28bb0bbf5..e6530a00bad14 100644 --- a/python/pyspark/pandas/missing/common.py +++ b/python/pyspark/pandas/missing/common.py @@ -16,44 +16,61 @@ # -memory_usage = lambda f: f( - "memory_usage", - reason="Unlike pandas, most DataFrames are not materialized in memory in Spark " - "(and pandas-on-Spark), and as a result memory_usage() does not do what you intend it " - "to do. Use Spark's web UI to monitor disk and memory usage of your application.", -) - -array = lambda f: f( - "array", reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead." -) - -to_pickle = lambda f: f( - "to_pickle", - reason="For storage, we encourage you to use Delta or Parquet, instead of Python pickle " - "format.", -) - -to_xarray = lambda f: f( - "to_xarray", - reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.", -) - -to_list = lambda f: f( - "to_list", - reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.", -) - -tolist = lambda f: f( - "tolist", reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead." 
-) - -__iter__ = lambda f: f( - "__iter__", - reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.", -) - -duplicated = lambda f: f( - "duplicated", - reason="'duplicated' API returns np.ndarray and the data size is too large." - "You can just use DataFrame.deduplicated instead", -) +def memory_usage(f): + return f( + "memory_usage", + reason="Unlike pandas, most DataFrames are not materialized in memory in Spark " + "(and pandas-on-Spark), and as a result memory_usage() does not do what you intend it " + "to do. Use Spark's web UI to monitor disk and memory usage of your application.", + ) + + +def array(f): + return f( + "array", + reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.", + ) + + +def to_pickle(f): + return f( + "to_pickle", + reason="For storage, we encourage you to use Delta or Parquet, instead of Python pickle " + "format.", + ) + + +def to_xarray(f): + return f( + "to_xarray", + reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.", + ) + + +def to_list(f): + return f( + "to_list", + reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.", + ) + + +def tolist(f): + return f( + "tolist", + reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.", + ) + + +def __iter__(f): + return f( + "__iter__", + reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.", + ) + + +def duplicated(f): + return f( + "duplicated", + reason="'duplicated' API returns np.ndarray and the data size is too large." + "You can just use DataFrame.deduplicated instead", + ) diff --git a/python/pyspark/pandas/ml.py b/python/pyspark/pandas/ml.py index f0dbc9ab55f50..a8203f11d8d57 100644 --- a/python/pyspark/pandas/ml.py +++ b/python/pyspark/pandas/ml.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import List, Tuple, TYPE_CHECKING, cast +from typing import List, Tuple, TYPE_CHECKING import numpy as np import pandas as pd @@ -54,7 +54,7 @@ def corr(psdf: "ps.DataFrame", method: str = "pearson") -> pd.DataFrame: assert method in ("pearson", "spearman") ndf, column_labels = to_numeric_df(psdf) corr = Correlation.corr(ndf, CORRELATION_OUTPUT_COLUMN, method) - pcorr = cast(pd.DataFrame, corr.toPandas()) + pcorr = corr.toPandas() arr = pcorr.iloc[0, 0].toArray() if column_labels_level(column_labels) > 1: idx = pd.MultiIndex.from_tuples(column_labels) @@ -78,7 +78,7 @@ def to_numeric_df(psdf: "ps.DataFrame") -> Tuple[pyspark.sql.DataFrame, List[Lab """ # TODO, it should be more robust. 
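The ml.py helpers in this hunk back the user-facing DataFrame.corr: non-numeric columns are dropped by to_numeric_df, the pearson or spearman matrix is computed by pyspark.ml.stat.Correlation, and the result is collected to the driver as a pandas DataFrame. A small usage sketch with toy data:

import pyspark.pandas as ps

psdf = ps.DataFrame({"x": [1, 2, 3, 4], "y": [1.0, 4.0, 9.0, 16.0], "s": list("abcd")})

# "s" is non-numeric and is dropped before the correlation is computed.
psdf.corr(method="spearman")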
accepted_types = { - np.dtype(dt) # type: ignore[misc] + np.dtype(dt) for dt in [np.int8, np.int16, np.int32, np.int64, np.float32, np.float64, np.bool_] } numeric_column_labels = [ diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py index ae0018ca3d385..340e270ace551 100644 --- a/python/pyspark/pandas/namespace.py +++ b/python/pyspark/pandas/namespace.py @@ -816,8 +816,9 @@ def read_parquet( if index_col is None and pandas_metadata: # Try to read pandas metadata - @no_type_check - @pandas_udf("index_col array, index_names array") + @pandas_udf( # type: ignore[call-overload] + "index_col array, index_names array" + ) def read_index_metadata(pser: pd.Series) -> pd.DataFrame: binary = pser.iloc[0] metadata = pq.ParquetFile(pa.BufferReader(binary)).metadata.metadata @@ -3363,7 +3364,9 @@ def to_list(os: Optional[Union[Name, List[Name]]]) -> List[Label]: right_as_of_name = right_as_of_names[0] def resolve(internal: InternalFrame, side: str) -> InternalFrame: - rename = lambda col: "__{}_{}".format(side, col) + def rename(col: str) -> str: + return "__{}_{}".format(side, col) + internal = internal.resolved_copy sdf = internal.spark_frame sdf = sdf.select( @@ -3430,12 +3433,11 @@ def resolve(internal: InternalFrame, side: str) -> InternalFrame: data_columns = [] column_labels = [] - left_scol_for = lambda label: scol_for( - as_of_joined_table, left_internal.spark_column_name_for(label) - ) - right_scol_for = lambda label: scol_for( - as_of_joined_table, right_internal.spark_column_name_for(label) - ) + def left_scol_for(label: Label) -> Column: + return scol_for(as_of_joined_table, left_internal.spark_column_name_for(label)) + + def right_scol_for(label: Label) -> Column: + return scol_for(as_of_joined_table, right_internal.spark_column_name_for(label)) for label in left_internal.column_labels: col = left_internal.spark_column_name_for(label) diff --git a/python/pyspark/pandas/plot/__init__.py b/python/pyspark/pandas/plot/__init__.py index 8b3376e7b214f..d00e002266ebd 100644 --- a/python/pyspark/pandas/plot/__init__.py +++ b/python/pyspark/pandas/plot/__init__.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from pyspark.pandas.plot.core import * # noqa: F401 +from pyspark.pandas.plot.core import * # noqa: F401,F403 diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index 5daf92178149a..8ee959db481a6 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -20,7 +20,7 @@ import pandas as pd import numpy as np from pyspark.ml.feature import Bucketizer -from pyspark.mllib.stat import KernelDensity # type: ignore[no-redef] +from pyspark.mllib.stat import KernelDensity from pyspark.sql import functions as F from pandas.core.base import PandasObject from pandas.core.dtypes.inference import is_integer diff --git a/python/pyspark/pandas/plot/plotly.py b/python/pyspark/pandas/plot/plotly.py index dfcc13931d4bb..ebf23416344d4 100644 --- a/python/pyspark/pandas/plot/plotly.py +++ b/python/pyspark/pandas/plot/plotly.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
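Back on the @pandas_udf change in read_parquet above: the decorator's overloads do not cover the exact call form used there, so the blanket @no_type_check is replaced with a narrow "call-overload" ignore on the decorator itself. A minimal, self-contained example of the same decorator form, unrelated to read_parquet (whether mypy needs an ignore here depends on the stubs in use):

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def plus_one(s: pd.Series) -> pd.Series:
    # A scalar pandas UDF whose return type is given as the DDL string "long".
    return s + 1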
# +import inspect from typing import TYPE_CHECKING, Union import pandas as pd @@ -109,7 +110,11 @@ def plot_histogram(data: Union["ps.DataFrame", "ps.Series"], **kwargs): ) ) - fig = go.Figure(data=bars, layout=go.Layout(barmode="stack")) + layout_keys = inspect.signature(go.Layout).parameters.keys() + layout_kwargs = {k: v for k, v in kwargs.items() if k in layout_keys} + + fig = go.Figure(data=bars, layout=go.Layout(**layout_kwargs)) + fig["layout"]["barmode"] = "stack" fig["layout"]["xaxis"]["title"] = "value" fig["layout"]["yaxis"]["title"] = "count" return fig diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index d403d871d3f47..35dc5acf21bee 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -66,6 +66,7 @@ NumericType, Row, StructType, + TimestampType, ) from pyspark.sql.window import Window @@ -976,9 +977,11 @@ def cov(self, other: "Series", min_periods: Optional[int] = None) -> float: else: return sdf.select(F.covar_samp(*sdf.columns)).head(1)[0][0] - # TODO: arg should support Series - # TODO: NaN and None - def map(self, arg: Union[Dict, Callable]) -> "Series": + # TODO: NaN and None when ``arg`` is an empty dict + # TODO: Support ps.Series ``arg`` + def map( + self, arg: Union[Dict, Callable[[Any], Any], pd.Series], na_action: Optional[str] = None + ) -> "Series": """ Map values of Series according to input correspondence. @@ -992,8 +995,10 @@ def map(self, arg: Union[Dict, Callable]) -> "Series": Parameters ---------- - arg : function or dict + arg : function, dict or pd.Series Mapping correspondence. + na_action : + If `ignore`, propagate NA values, without passing them to the mapping correspondence. Returns ------- @@ -1034,6 +1039,16 @@ def map(self, arg: Union[Dict, Callable]) -> "Series": 3 None dtype: object + It also accepts a pandas Series: + + >>> pser = pd.Series(['kitten', 'puppy'], index=['cat', 'dog']) + >>> s.map(pser) + 0 kitten + 1 puppy + 2 None + 3 None + dtype: object + It also accepts a function: >>> def format(x) -> str: @@ -1045,8 +1060,18 @@ def map(self, arg: Union[Dict, Callable]) -> "Series": 2 I am a None 3 I am a rabbit dtype: object + + To avoid applying the function to missing values (and keep them as NaN) + na_action='ignore' can be used: + + >>> s.map('I am a {}'.format, na_action='ignore') + 0 I am a cat + 1 I am a dog + 2 None + 3 I am a rabbit + dtype: object """ - if isinstance(arg, dict): + if isinstance(arg, (dict, pd.Series)): is_start = True # In case dictionary is empty. current = F.when(SF.lit(False), SF.lit(None).cast(self.spark.data_type)) @@ -1067,7 +1092,7 @@ def map(self, arg: Union[Dict, Callable]) -> "Series": current = current.otherwise(SF.lit(None).cast(self.spark.data_type)) return self._with_new_scol(current) else: - return self.apply(arg) + return self.pandas_on_spark.transform_batch(lambda pser: pser.map(arg, na_action)) @property def shape(self) -> Tuple[int]: @@ -1087,16 +1112,18 @@ def name(self) -> Name: def name(self, name: Name) -> None: self.rename(name, inplace=True) - # TODO: Functionality and documentation should be matched. Currently, changing index labels - # taking dictionary and function to change index are not supported. - def rename(self, index: Optional[Name] = None, **kwargs: Any) -> "Series": + # TODO: Currently, changing index labels taking dictionary/Series is not supported. + def rename( + self, index: Optional[Union[Name, Callable[[Any], Any]]] = None, **kwargs: Any + ) -> "Series": """ - Alter Series name. 
+ Alter Series index labels or name. Parameters ---------- - index : scalar - Scalar will alter the ``Series.name`` attribute. + index : scalar or function, optional + Functions are transformations to apply to the index. + Scalar will alter the Series.name attribute. inplace : bool, default False Whether to return a new Series. If True then value of copy is @@ -1105,7 +1132,7 @@ def rename(self, index: Optional[Name] = None, **kwargs: Any) -> "Series": Returns ------- Series - Series with name altered. + Series with index labels or name altered. Examples -------- @@ -1122,9 +1149,26 @@ def rename(self, index: Optional[Name] = None, **kwargs: Any) -> "Series": 1 2 2 3 Name: my_name, dtype: int64 + + >>> s.rename(lambda x: x ** 2) # function, changes labels + 0 1 + 1 2 + 4 3 + dtype: int64 """ if index is None: pass + if callable(index): + if kwargs.get("inplace", False): + raise ValueError("inplace True is not supported yet for a function 'index'") + frame = self.to_frame() + new_index_name = verify_temp_column_name(frame, "__index_name__") + frame[new_index_name] = self.index.map(index) + frame.set_index(new_index_name, inplace=True) + frame.index.name = self.index.name + return first_series(frame).rename(self.name) + elif isinstance(index, (pd.Series, dict)): + raise ValueError("'index' of %s type is not supported yet" % type(index).__name__) elif not is_hashable(index): raise TypeError("Series.name must be a hashable type") elif not isinstance(index, tuple): @@ -3210,7 +3254,8 @@ def apply(self, func: Callable, args: Sequence[Any] = (), **kwds: Any) -> "Serie # Falls back to schema inference if it fails to get signature. should_infer_schema = True - apply_each = lambda s: s.apply(func, args=args, **kwds) + def apply_each(s: Any) -> pd.Series: + return s.apply(func, args=args, **kwds) if should_infer_schema: return self.pandas_on_spark._transform_batch(apply_each, None) @@ -3221,7 +3266,7 @@ def apply(self, func: Callable, args: Sequence[Any] = (), **kwds: Any) -> "Serie "Expected the return type of this function to be of scalar type, " "but found type {}".format(sig_return) ) - return_type = cast(ScalarType, sig_return) + return_type = sig_return return self.pandas_on_spark._transform_batch(apply_each, return_type) # TODO: not all arguments are implemented comparing to pandas' for now. @@ -3497,7 +3542,7 @@ def quantile( raise TypeError( "q must be a float or an array of floats; however, [%s] found." 
% type(q) ) - q_float = cast(float, q) + q_float = q if q_float < 0.0 or q_float > 1.0: raise ValueError("percentiles should all be in the interval [0, 1].") @@ -3611,9 +3656,9 @@ def _rank( raise NotImplementedError("rank do not support MultiIndex now") if ascending: - asc_func = lambda scol: scol.asc() + asc_func = Column.asc else: - asc_func = lambda scol: scol.desc() + asc_func = Column.desc if method == "first": window = ( @@ -3662,10 +3707,7 @@ def filter( if axis == 1: raise ValueError("Series does not support columns axis.") return first_series( - cast( - "ps.DataFrame", - self.to_frame().filter(items=items, like=like, regex=regex, axis=axis), - ) + self.to_frame().filter(items=items, like=like, regex=regex, axis=axis), ).rename(self.name) filter.__doc__ = DataFrame.filter.__doc__ @@ -4746,7 +4788,7 @@ def mask(self, cond: "Series", other: Any = np.nan) -> "Series": >>> reset_option("compute.ops_on_diff_frames") """ - return self.where(cast(Series, ~cond), other) + return self.where(~cond, other) def xs(self, key: Name, level: Optional[int] = None) -> "Series": """ @@ -5226,24 +5268,51 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: if not is_list_like(where): should_return_series = False where = [where] - index_scol = self._internal.index_spark_columns[0] - index_type = self._internal.spark_type_for(index_scol) + internal = self._internal.resolved_copy + index_scol = internal.index_spark_columns[0] + index_type = internal.spark_type_for(index_scol) + spark_column = internal.data_spark_columns[0] + monotonically_increasing_id_column = verify_temp_column_name( + internal.spark_frame, "__monotonically_increasing_id__" + ) cond = [ - F.max(F.when(index_scol <= SF.lit(index).cast(index_type), self.spark.column)) + F.max_by( + spark_column, + F.when( + (index_scol <= SF.lit(index).cast(index_type)) & spark_column.isNotNull() + if pd.notna(index) + # If index is nan and the value of the col is not null + # then return monotonically_increasing_id .This will let max by + # to return last index value , which is the behaviour of pandas + else spark_column.isNotNull(), + monotonically_increasing_id_column, + ), + ) for index in where ] - sdf = self._internal.spark_frame.select(cond) + + sdf = internal.spark_frame.withColumn( + monotonically_increasing_id_column, F.monotonically_increasing_id() + ).select(cond) + if not should_return_series: with sql_conf({SPARK_CONF_ARROW_ENABLED: False}): # Disable Arrow to keep row ordering. - result = cast(pd.DataFrame, sdf.limit(1).toPandas()).iloc[0, 0] + result = sdf.limit(1).toPandas().iloc[0, 0] return result if result is not None else np.nan # The data is expected to be small so it's fine to transpose/use default index. with ps.option_context("compute.default_index_type", "distributed", "compute.max_rows", 1): - psdf: DataFrame = DataFrame(sdf) - psdf.columns = pd.Index(where) - return first_series(psdf.transpose()).rename(self.name) + if len(where) == len(set(where)) and not isinstance(index_type, TimestampType): + psdf: DataFrame = DataFrame(sdf) + psdf.columns = pd.Index(where) + return first_series(psdf.transpose()).rename(self.name) + else: + # If `where` has duplicate items, leverage the pandas directly + # since pandas API on Spark doesn't support the duplicate column name. 
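As a side note on the builtin this asof() rewrite leans on: F.max_by(value, ordering) returns the value paired with the maximum of the ordering expression, and rows whose ordering is NULL never win, which is why wrapping the ordering in F.when(...) with no otherwise() effectively filters the search. A tiny stand-alone illustration, assuming a Spark version that exposes max_by in pyspark.sql.functions (as this code does):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([(1, "a"), (3, "c"), (2, "b")], "ord int, val string")

# max_by keeps the `val` sitting next to the largest `ord`, here "c".
sdf.select(F.max_by("val", "ord")).show()

# Rows failing the F.when(...) condition get a NULL ordering and drop out of
# the search, so only labels with ord <= 2 compete; the result is "b".
sdf.select(F.max_by("val", F.when(F.col("ord") <= 2, F.col("ord")))).show()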
+ pdf: pd.DataFrame = sdf.limit(1).toPandas() + pdf.columns = pd.Index(where) + return first_series(DataFrame(pdf.transpose())).rename(self.name) def mad(self) -> float: """ @@ -6277,7 +6346,7 @@ def _apply_series_op( if isinstance(psser_or_scol, Series): psser = psser_or_scol else: - psser = self._with_new_scol(cast(Column, psser_or_scol)) + psser = self._with_new_scol(psser_or_scol) if should_resolve: internal = psser._internal.resolved_copy return first_series(DataFrame(internal)) @@ -6429,7 +6498,7 @@ def unpack_scalar(sdf: SparkDataFrame) -> Any: Takes a dataframe that is supposed to contain a single row with a single scalar value, and returns this value. """ - lst = cast(pd.DataFrame, sdf.limit(2).toPandas()) + lst = sdf.limit(2).toPandas() assert len(lst) == 1, (sdf, lst) row = lst.iloc[0] lst2 = list(row) diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index dcc8fc8d91285..b7d57b4c3f828 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -40,7 +40,7 @@ def repeat(col: Column, n: Union[int, Column]) -> Column: """ Repeats a string column n times, and returns it as a new string column. """ - sc = SparkContext._active_spark_context # type: ignore[attr-defined] + sc = SparkContext._active_spark_context n = _to_java_column(n) if isinstance(n, Column) else _create_column_from_literal(n) return _call_udf(sc, "repeat", _to_java_column(col), n) @@ -49,7 +49,7 @@ def date_part(field: Union[str, Column], source: Column) -> Column: """ Extracts a part of the date/timestamp or interval source. """ - sc = SparkContext._active_spark_context # type: ignore[attr-defined] + sc = SparkContext._active_spark_context field = ( _to_java_column(field) if isinstance(field, Column) else _create_column_from_literal(field) ) diff --git a/python/pyspark/pandas/strings.py b/python/pyspark/pandas/strings.py index 986e3d1a0ace5..774fd6c7ca0bf 100644 --- a/python/pyspark/pandas/strings.py +++ b/python/pyspark/pandas/strings.py @@ -25,7 +25,6 @@ List, Optional, Union, - TYPE_CHECKING, cast, no_type_check, ) @@ -37,11 +36,9 @@ from pyspark.sql import functions as F from pyspark.sql.functions import pandas_udf +import pyspark.pandas as ps from pyspark.pandas.spark import functions as SF -if TYPE_CHECKING: - import pyspark.pandas as ps - class StringMethods: """String methods for pandas-on-Spark Series""" @@ -74,8 +71,7 @@ def capitalize(self) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_capitalize(s) -> "ps.Series[str]": + def pandas_capitalize(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.capitalize() return self._data.pandas_on_spark.transform_batch(pandas_capitalize) @@ -102,8 +98,7 @@ def title(self) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_title(s) -> "ps.Series[str]": + def pandas_title(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.title() return self._data.pandas_on_spark.transform_batch(pandas_title) @@ -176,8 +171,7 @@ def swapcase(self) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_swapcase(s) -> "ps.Series[str]": + def pandas_swapcase(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.swapcase() return self._data.pandas_on_spark.transform_batch(pandas_swapcase) @@ -228,8 +222,7 @@ def startswith(self, pattern: str, na: Optional[Any] = None) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_startswith(s) -> "ps.Series[bool]": + def pandas_startswith(s) -> ps.Series[bool]: # type: 
ignore[no-untyped-def] return s.str.startswith(pattern, na) return self._data.pandas_on_spark.transform_batch(pandas_startswith) @@ -280,8 +273,7 @@ def endswith(self, pattern: str, na: Optional[Any] = None) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_endswith(s) -> "ps.Series[bool]": + def pandas_endswith(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.endswith(pattern, na) return self._data.pandas_on_spark.transform_batch(pandas_endswith) @@ -333,8 +325,7 @@ def strip(self, to_strip: Optional[str] = None) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_strip(s) -> "ps.Series[str]": + def pandas_strip(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.strip(to_strip) return self._data.pandas_on_spark.transform_batch(pandas_strip) @@ -374,8 +365,7 @@ def lstrip(self, to_strip: Optional[str] = None) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_lstrip(s) -> "ps.Series[str]": + def pandas_lstrip(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.lstrip(to_strip) return self._data.pandas_on_spark.transform_batch(pandas_lstrip) @@ -415,8 +405,7 @@ def rstrip(self, to_strip: Optional[str] = None) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_rstrip(s) -> "ps.Series[str]": + def pandas_rstrip(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.rstrip(to_strip) return self._data.pandas_on_spark.transform_batch(pandas_rstrip) @@ -470,8 +459,7 @@ def get(self, i: int) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_get(s) -> "ps.Series[str]": + def pandas_get(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.get(i) return self._data.pandas_on_spark.transform_batch(pandas_get) @@ -507,8 +495,7 @@ def isalnum(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_isalnum(s) -> "ps.Series[bool]": + def pandas_isalnum(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.isalnum() return self._data.pandas_on_spark.transform_batch(pandas_isalnum) @@ -533,8 +520,7 @@ def isalpha(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_isalpha(s) -> "ps.Series[bool]": + def pandas_isalpha(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.isalpha() return self._data.pandas_on_spark.transform_batch(pandas_isalpha) @@ -584,8 +570,7 @@ def isdigit(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_isdigit(s) -> "ps.Series[bool]": + def pandas_isdigit(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.isdigit() return self._data.pandas_on_spark.transform_batch(pandas_isdigit) @@ -608,8 +593,7 @@ def isspace(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_isspace(s) -> "ps.Series[bool]": + def pandas_isspace(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.isspace() return self._data.pandas_on_spark.transform_batch(pandas_isspace) @@ -633,8 +617,7 @@ def islower(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_isspace(s) -> "ps.Series[bool]": + def pandas_isspace(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.islower() return self._data.pandas_on_spark.transform_batch(pandas_isspace) @@ -658,8 +641,7 @@ def isupper(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_isspace(s) -> "ps.Series[bool]": + def pandas_isspace(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.isupper() return self._data.pandas_on_spark.transform_batch(pandas_isspace) @@ -689,8 
+671,7 @@ def istitle(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_istitle(s) -> "ps.Series[bool]": + def pandas_istitle(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.istitle() return self._data.pandas_on_spark.transform_batch(pandas_istitle) @@ -748,8 +729,7 @@ def isnumeric(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_isnumeric(s) -> "ps.Series[bool]": + def pandas_isnumeric(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.isnumeric() return self._data.pandas_on_spark.transform_batch(pandas_isnumeric) @@ -799,8 +779,7 @@ def isdecimal(self) -> "ps.Series": dtype: bool """ - @no_type_check - def pandas_isdecimal(s) -> "ps.Series[bool]": + def pandas_isdecimal(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.isdecimal() return self._data.pandas_on_spark.transform_batch(pandas_isdecimal) @@ -843,8 +822,7 @@ def center(self, width: int, fillchar: str = " ") -> "ps.Series": dtype: object """ - @no_type_check - def pandas_center(s) -> "ps.Series[str]": + def pandas_center(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.center(width, fillchar) return self._data.pandas_on_spark.transform_batch(pandas_center) @@ -963,8 +941,7 @@ def contains( dtype: bool """ - @no_type_check - def pandas_contains(s) -> "ps.Series[bool]": + def pandas_contains(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.contains(pat, case, flags, na, regex) return self._data.pandas_on_spark.transform_batch(pandas_contains) @@ -1014,8 +991,7 @@ def count(self, pat: str, flags: int = 0) -> "ps.Series": dtype: int64 """ - @no_type_check - def pandas_count(s) -> "ps.Series[int]": + def pandas_count(s) -> ps.Series[int]: # type: ignore[no-untyped-def] return s.str.count(pat, flags) return self._data.pandas_on_spark.transform_batch(pandas_count) @@ -1098,8 +1074,7 @@ def find(self, sub: str, start: int = 0, end: Optional[int] = None) -> "ps.Serie dtype: int64 """ - @no_type_check - def pandas_find(s) -> "ps.Series[int]": + def pandas_find(s) -> ps.Series[int]: # type: ignore[no-untyped-def] return s.str.find(sub, start, end) return self._data.pandas_on_spark.transform_batch(pandas_find) @@ -1229,8 +1204,7 @@ def index(self, sub: str, start: int = 0, end: Optional[int] = None) -> "ps.Seri >>> s.str.index('a', start=2) # doctest: +SKIP """ - @no_type_check - def pandas_index(s) -> "ps.Series[np.int64]": + def pandas_index(s) -> ps.Series[np.int64]: # type: ignore[no-untyped-def] return s.str.index(sub, start, end) return self._data.pandas_on_spark.transform_batch(pandas_index) @@ -1279,8 +1253,7 @@ def join(self, sep: str) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_join(s) -> "ps.Series[str]": + def pandas_join(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.join(sep) return self._data.pandas_on_spark.transform_batch(pandas_join) @@ -1350,8 +1323,7 @@ def ljust(self, width: int, fillchar: str = " ") -> "ps.Series": dtype: object """ - @no_type_check - def pandas_ljust(s) -> "ps.Series[str]": + def pandas_ljust(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.ljust(width, fillchar) return self._data.pandas_on_spark.transform_batch(pandas_ljust) @@ -1417,8 +1389,7 @@ def match(self, pat: str, case: bool = True, flags: int = 0, na: Any = np.NaN) - dtype: object """ - @no_type_check - def pandas_match(s) -> "ps.Series[bool]": + def pandas_match(s) -> ps.Series[bool]: # type: ignore[no-untyped-def] return s.str.match(pat, case, flags, na) return 
self._data.pandas_on_spark.transform_batch(pandas_match) @@ -1441,8 +1412,7 @@ def normalize(self, form: str) -> "ps.Series": A Series of normalized strings. """ - @no_type_check - def pandas_normalize(s) -> "ps.Series[str]": + def pandas_normalize(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.normalize(form) return self._data.pandas_on_spark.transform_batch(pandas_normalize) @@ -1490,8 +1460,7 @@ def pad(self, width: int, side: str = "left", fillchar: str = " ") -> "ps.Series dtype: object """ - @no_type_check - def pandas_pad(s) -> "ps.Series[str]": + def pandas_pad(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.pad(width, side, fillchar) return self._data.pandas_on_spark.transform_batch(pandas_pad) @@ -1636,8 +1605,7 @@ def replace( dtype: object """ - @no_type_check - def pandas_replace(s) -> "ps.Series[str]": + def pandas_replace(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.replace(pat, repl, n=n, case=case, flags=flags, regex=regex) return self._data.pandas_on_spark.transform_batch(pandas_replace) @@ -1692,8 +1660,7 @@ def rfind(self, sub: str, start: int = 0, end: Optional[int] = None) -> "ps.Seri dtype: int64 """ - @no_type_check - def pandas_rfind(s) -> "ps.Series[int]": + def pandas_rfind(s) -> ps.Series[int]: # type: ignore[no-untyped-def] return s.str.rfind(sub, start, end) return self._data.pandas_on_spark.transform_batch(pandas_rfind) @@ -1736,8 +1703,7 @@ def rindex(self, sub: str, start: int = 0, end: Optional[int] = None) -> "ps.Ser >>> s.str.rindex('a', start=2) # doctest: +SKIP """ - @no_type_check - def pandas_rindex(s) -> "ps.Series[np.int64]": + def pandas_rindex(s) -> ps.Series[np.int64]: # type: ignore[no-untyped-def] return s.str.rindex(sub, start, end) return self._data.pandas_on_spark.transform_batch(pandas_rindex) @@ -1778,8 +1744,7 @@ def rjust(self, width: int, fillchar: str = " ") -> "ps.Series": dtype: object """ - @no_type_check - def pandas_rjust(s) -> "ps.Series[str]": + def pandas_rjust(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.rjust(width, fillchar) return self._data.pandas_on_spark.transform_batch(pandas_rjust) @@ -1844,8 +1809,7 @@ def slice( dtype: object """ - @no_type_check - def pandas_slice(s) -> "ps.Series[str]": + def pandas_slice(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.slice(start, stop, step) return self._data.pandas_on_spark.transform_batch(pandas_slice) @@ -1921,8 +1885,7 @@ def slice_replace( dtype: object """ - @no_type_check - def pandas_slice_replace(s) -> "ps.Series[str]": + def pandas_slice_replace(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.slice_replace(start, stop, repl) return self._data.pandas_on_spark.transform_batch(pandas_slice_replace) @@ -2259,8 +2222,7 @@ def translate(self, table: Dict) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_translate(s) -> "ps.Series[str]": + def pandas_translate(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.translate(table) return self._data.pandas_on_spark.transform_batch(pandas_translate) @@ -2311,8 +2273,7 @@ def wrap(self, width: int, **kwargs: bool) -> "ps.Series": dtype: object """ - @no_type_check - def pandas_wrap(s) -> "ps.Series[str]": + def pandas_wrap(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.wrap(width, **kwargs) return self._data.pandas_on_spark.transform_batch(pandas_wrap) @@ -2362,8 +2323,7 @@ def zfill(self, width: int) -> "ps.Series": dtype: object """ - @no_type_check - def 
pandas_zfill(s) -> "ps.Series[str]": + def pandas_zfill(s) -> ps.Series[str]: # type: ignore[no-untyped-def] return s.str.zfill(width) return self._data.pandas_on_spark.transform_batch(pandas_zfill) diff --git a/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py index 5dc7f8096855b..35fcb3705a310 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py @@ -19,11 +19,10 @@ from pandas.api.types import CategoricalDtype from pyspark import pandas as ps -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase -class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class BinaryOpsTest(OpsTestBase): @property def pser(self): return pd.Series([b"1", b"2", b"3"]) diff --git a/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py index b83b610d0cc21..02bb048ee5bc8 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py @@ -25,15 +25,14 @@ from pyspark import pandas as ps from pyspark.pandas import option_context -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase from pyspark.pandas.typedef.typehints import ( extension_float_dtypes_available, extension_object_dtypes_available, ) -from pyspark.testing.pandasutils import PandasOnSparkTestCase -class BooleanOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class BooleanOpsTest(OpsTestBase): @property def bool_pdf(self): return pd.DataFrame({"this": [True, False, True], "that": [False, True, True]}) @@ -381,7 +380,7 @@ def test_ge(self): @unittest.skipIf( not extension_object_dtypes_available, "pandas extension object dtypes are not available" ) -class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class BooleanExtensionOpsTest(OpsTestBase): @property def boolean_pdf(self): return pd.DataFrame( diff --git a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py index 0aa2e108d799a..b84c35bb104f9 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py @@ -23,11 +23,10 @@ from pyspark import pandas as ps from pyspark.pandas.config import option_context -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase -class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class CategoricalOpsTest(OpsTestBase): @property def pdf(self): return pd.DataFrame( @@ -54,10 +53,6 @@ def pdf(self): } ) - @property - def psdf(self): - return ps.from_pandas(self.pdf) - @property def pser(self): return pd.Series([1, 2, 3], dtype="category") diff --git a/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py index 91a92badf8cd6..cc9a0bf4a7430 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py @@ 
-21,11 +21,10 @@ import pandas as pd from pyspark import pandas as ps -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase -class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class ComplexOpsTest(OpsTestBase): @property def pser(self): return pd.Series([[1, 2, 3]]) diff --git a/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py index 8c196d2a715bb..f0585c3f5a14f 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py @@ -21,11 +21,10 @@ from pandas.api.types import CategoricalDtype from pyspark import pandas as ps -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase -class DateOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class DateOpsTest(OpsTestBase): @property def pser(self): return pd.Series( diff --git a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py index 5eba4855f93ae..f29f9d375e47f 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py @@ -21,11 +21,10 @@ from pandas.api.types import CategoricalDtype from pyspark import pandas as ps -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase -class DatetimeOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class DatetimeOpsTest(OpsTestBase): @property def pser(self): return pd.Series(pd.date_range("1994-1-31 10:30:15", periods=3, freq="D")) diff --git a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py index c2b6be29038bd..009d4d0aba019 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py @@ -19,11 +19,10 @@ from pandas.api.types import CategoricalDtype import pyspark.pandas as ps -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase -class NullOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class NullOpsTest(OpsTestBase): @property def pser(self): return pd.Series([None, None, None]) diff --git a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py index 785eb250a72b3..0c2c94eab8ef1 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py @@ -25,17 +25,16 @@ from pyspark import pandas as ps from pyspark.pandas.config import option_context -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase from pyspark.pandas.typedef.typehints import ( extension_dtypes_available, extension_float_dtypes_available, extension_object_dtypes_available, ) from pyspark.sql.types import DecimalType, IntegralType 
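Note on the strings.py hunks above: each helper swaps the blanket @no_type_check decorator for a concrete ps.Series[...] return annotation plus a narrow `# type: ignore[no-untyped-def]` on the untyped `s` argument. The standalone sketch below is illustrative only (simplified, not the actual accessor code; names such as startswith_sketch are made up) and shows the same pattern and why the return annotation is useful: pandas-on-Spark can read it to determine the schema of the transform_batch result rather than having to infer it from data.

# Illustrative sketch of the typed-inner-function pattern used in strings.py above.
from typing import Optional

import pandas as pd
import pyspark.pandas as ps


def startswith_sketch(psser: ps.Series, pattern: str, na: Optional[bool] = None) -> ps.Series:
    # The ps.Series[bool] return annotation lets transform_batch determine the
    # result schema; the narrow "no-untyped-def" ignore only silences mypy about
    # the unannotated `s` parameter, unlike @no_type_check, which disabled
    # checking for the whole function body.
    def pandas_startswith(s) -> ps.Series[bool]:  # type: ignore[no-untyped-def]
        return s.str.startswith(pattern, na)

    return psser.pandas_on_spark.transform_batch(pandas_startswith)


if __name__ == "__main__":
    psser = ps.from_pandas(pd.Series(["bat", "bear", "cat"]))
    print(startswith_sketch(psser, "b").to_pandas().tolist())  # [True, True, False]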
-from pyspark.testing.pandasutils import PandasOnSparkTestCase -class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class NumOpsTest(OpsTestBase): """Unit tests for arithmetic operations of numeric data types. A few test cases are disabled because pandas-on-Spark returns float64 whereas pandas @@ -450,7 +449,7 @@ def test_ge(self): @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes are not available") -class IntegralExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class IntegralExtensionOpsTest(OpsTestBase): @property def intergral_extension_psers(self): return [pd.Series([1, 2, 3, None], dtype=dtype) for dtype in self.integral_extension_dtypes] @@ -590,7 +589,7 @@ def test_rxor(self): @unittest.skipIf( not extension_float_dtypes_available, "pandas extension float dtypes are not available" ) -class FractionalExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class FractionalExtensionOpsTest(OpsTestBase): @property def fractional_extension_psers(self): return [ diff --git a/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py index f7c45cc429837..572ea7688cb7f 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py @@ -23,15 +23,14 @@ from pyspark import pandas as ps from pyspark.pandas.config import option_context -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase from pyspark.pandas.typedef.typehints import extension_object_dtypes_available -from pyspark.testing.pandasutils import PandasOnSparkTestCase if extension_object_dtypes_available: from pandas import StringDtype -class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class StringOpsTest(OpsTestBase): @property def bool_pdf(self): return pd.DataFrame({"this": ["x", "y", "z"], "that": ["z", "y", "x"]}) @@ -237,7 +236,7 @@ def test_ge(self): @unittest.skipIf( not extension_object_dtypes_available, "pandas extension object dtypes are not available" ) -class StringExtensionOpsTest(StringOpsTest, PandasOnSparkTestCase, TestCasesUtils): +class StringExtensionOpsTest(StringOpsTest): @property def pser(self): return pd.Series(["x", "y", "z", None], dtype="string") diff --git a/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py index 40882b8f24a90..16788c06c7c92 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py @@ -21,11 +21,10 @@ from pandas.api.types import CategoricalDtype import pyspark.pandas as ps -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase -class TimedeltaOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class TimedeltaOpsTest(OpsTestBase): @property def pser(self): return pd.Series([timedelta(1), timedelta(microseconds=2), timedelta(weeks=3)]) diff --git a/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py index 70175c4a97d2b..a71691c036cfe 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py @@ -19,11 +19,10 @@ import pyspark.pandas as ps from 
pyspark.ml.linalg import SparseVector -from pyspark.pandas.tests.data_type_ops.testing_utils import TestCasesUtils -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase -class UDTOpsTest(PandasOnSparkTestCase, TestCasesUtils): +class UDTOpsTest(OpsTestBase): @property def pser(self): sparse_values = {0: 0.1, 1: 1.1} diff --git a/python/pyspark/pandas/tests/data_type_ops/testing_utils.py b/python/pyspark/pandas/tests/data_type_ops/testing_utils.py index 9f57ad4832da2..222b945265264 100644 --- a/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +++ b/python/pyspark/pandas/tests/data_type_ops/testing_utils.py @@ -31,6 +31,8 @@ extension_object_dtypes_available, ) +from pyspark.testing.pandasutils import ComparisonTestBase + if extension_dtypes_available: from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype @@ -41,8 +43,8 @@ from pandas import BooleanDtype, StringDtype -class TestCasesUtils: - """A utility holding common test cases for arithmetic operations of different data types.""" +class OpsTestBase(ComparisonTestBase): + """The test base for arithmetic operations of different data types.""" @property def numeric_pdf(self): @@ -110,10 +112,6 @@ def non_numeric_df_cols(self): def pdf(self): return pd.concat([self.numeric_pdf, self.non_numeric_pdf], axis=1) - @property - def psdf(self): - return ps.from_pandas(self.pdf) - @property def df_cols(self): return self.pdf.columns diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index 88c826eea786b..dc1f26dfc4588 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -31,10 +31,10 @@ MissingPandasLikeMultiIndex, MissingPandasLikeTimedeltaIndex, ) -from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils, SPARK_CONF_ARROW_ENABLED +from pyspark.testing.pandasutils import ComparisonTestBase, TestUtils, SPARK_CONF_ARROW_ENABLED -class IndexesTest(PandasOnSparkTestCase, TestUtils): +class IndexesTest(ComparisonTestBase, TestUtils): @property def pdf(self): return pd.DataFrame( @@ -42,10 +42,6 @@ def pdf(self): index=[0, 1, 3, 5, 6, 8, 9, 9, 9], ) - @property - def psdf(self): - return ps.from_pandas(self.pdf) - def test_index_basic(self): for pdf in [ pd.DataFrame(np.random.randn(10, 5), index=np.random.randint(100, size=10)), diff --git a/python/pyspark/pandas/tests/indexes/test_datetime.py b/python/pyspark/pandas/tests/indexes/test_datetime.py index e3bf14e654616..85a2b21901774 100644 --- a/python/pyspark/pandas/tests/indexes/test_datetime.py +++ b/python/pyspark/pandas/tests/indexes/test_datetime.py @@ -120,7 +120,7 @@ def test_day_name(self): def test_month_name(self): for psidx, pidx in self.idx_pairs: - self.assert_eq(psidx.day_name(), pidx.day_name()) + self.assert_eq(psidx.month_name(), pidx.month_name()) def test_normalize(self): for psidx, pidx in self.idx_pairs: diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py index 7be00d593ee36..2937ef1813f74 100644 --- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py @@ -186,6 +186,13 @@ def check_pie_plot(psdf): # ) # check_pie_plot(psdf1) + def test_hist_layout_kwargs(self): + s = ps.Series([1, 3, 2]) + plt = s.plot.hist(title="Title", foo="xxx") + self.assertEqual(plt.layout.barmode, "stack") + 
self.assertEqual(plt.layout.title.text, "Title") + self.assertFalse(hasattr(plt.layout, "foo")) + def test_hist_plot(self): def check_hist_plot(psdf): bins = np.array([1.0, 5.9, 10.8, 15.7, 20.6, 25.5, 30.4, 35.3, 40.2, 45.1, 50.0]) diff --git a/python/pyspark/pandas/tests/test_categorical.py b/python/pyspark/pandas/tests/test_categorical.py index 2430935ecbe57..a4746cdda148e 100644 --- a/python/pyspark/pandas/tests/test_categorical.py +++ b/python/pyspark/pandas/tests/test_categorical.py @@ -16,17 +16,16 @@ # from distutils.version import LooseVersion -from typing import no_type_check import numpy as np import pandas as pd from pandas.api.types import CategoricalDtype import pyspark.pandas as ps -from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils +from pyspark.testing.pandasutils import ComparisonTestBase, TestUtils -class CategoricalTest(PandasOnSparkTestCase, TestUtils): +class CategoricalTest(ComparisonTestBase, TestUtils): @property def pdf(self): return pd.DataFrame( @@ -38,10 +37,6 @@ def pdf(self): }, ) - @property - def psdf(self): - return ps.from_pandas(self.pdf) - @property def df_pair(self): return self.pdf, self.psdf @@ -438,8 +433,7 @@ def test_groupby_transform_without_shortcut(self): pdf, psdf = self.df_pair - @no_type_check - def identity(x) -> ps.Series[psdf.b.dtype]: + def identity(x) -> ps.Series[psdf.b.dtype]: # type: ignore[name-defined, no-untyped-def] return x self.assert_eq( diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 0bf9291f1d2ed..6f3c1c41653ad 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -43,7 +43,7 @@ ) from pyspark.testing.pandasutils import ( have_tabulate, - PandasOnSparkTestCase, + ComparisonTestBase, SPARK_CONF_ARROW_ENABLED, tabulate_requirement_message, ) @@ -51,7 +51,7 @@ from pyspark.pandas.utils import name_like_string -class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils): +class DataFrameTest(ComparisonTestBase, SQLTestUtils): @property def pdf(self): return pd.DataFrame( @@ -59,10 +59,6 @@ def pdf(self): index=np.random.rand(9), ) - @property - def psdf(self): - return ps.from_pandas(self.pdf) - @property def df_pair(self): pdf = self.pdf @@ -1562,6 +1558,9 @@ def test_sort_values(self): psdf = ps.from_pandas(pdf) self.assert_eq(psdf.sort_values("b"), pdf.sort_values("b")) + self.assert_eq( + psdf.sort_values("b", ignore_index=True), pdf.sort_values("b", ignore_index=True) + ) for ascending in [True, False]: for na_position in ["first", "last"]: @@ -1571,6 +1570,10 @@ def test_sort_values(self): ) self.assert_eq(psdf.sort_values(["a", "b"]), pdf.sort_values(["a", "b"])) + self.assert_eq( + psdf.sort_values(["a", "b"], ignore_index=True), + pdf.sort_values(["a", "b"], ignore_index=True), + ) self.assert_eq( psdf.sort_values(["a", "b"], ascending=[False, True]), pdf.sort_values(["a", "b"], ascending=[False, True]), @@ -1591,6 +1594,41 @@ def test_sort_values(self): self.assert_eq(psdf, pdf) self.assert_eq(psserA, pserA) + pdf = pd.DataFrame( + {"a": [1, 2, 3, 4, 5, None, 7], "b": [7, 6, 5, 4, 3, 2, 1]}, index=np.random.rand(7) + ) + psdf = ps.from_pandas(pdf) + pserA = pdf.a + psserA = psdf.a + self.assert_eq( + psdf.sort_values("b", inplace=True, ignore_index=True), + pdf.sort_values("b", inplace=True, ignore_index=True), + ) + self.assert_eq(psdf, pdf) + self.assert_eq(psserA, pserA) + + # multi-index indexes + + pdf = pd.DataFrame( + {"a": [1, 2, 3, 4, 5, None, 7], "b": [7, 6, 5, 4, 3, 
2, 1]}, + index=pd.MultiIndex.from_tuples( + [ + ("bar", "one"), + ("bar", "two"), + ("baz", "one"), + ("baz", "two"), + ("foo", "one"), + ("foo", "two"), + ("qux", "one"), + ] + ), + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(psdf.sort_values("b"), pdf.sort_values("b")) + self.assert_eq( + psdf.sort_values("b", ignore_index=True), pdf.sort_values("b", ignore_index=True) + ) + # multi-index columns pdf = pd.DataFrame( {("X", 10): [1, 2, 3, 4, 5, None, 7], ("X", 20): [7, 6, 5, 4, 3, 2, 1]}, @@ -4543,7 +4581,7 @@ def identify3(x) -> ps.DataFrame[float, [int, List[int]]]: def identify4( x, - ) -> ps.DataFrame[float, [int, ntp.NDArray[int]]]: # type: ignore[name-defined] + ) -> ps.DataFrame[float, [int, ntp.NDArray[int]]]: return x actual = psdf.pandas_on_spark.apply_batch(identify4) diff --git a/python/pyspark/pandas/tests/test_dataframe_conversion.py b/python/pyspark/pandas/tests/test_dataframe_conversion.py index 2cc2f15e1ae08..123dd14324c13 100644 --- a/python/pyspark/pandas/tests/test_dataframe_conversion.py +++ b/python/pyspark/pandas/tests/test_dataframe_conversion.py @@ -25,11 +25,11 @@ import pandas as pd from pyspark import pandas as ps -from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils +from pyspark.testing.pandasutils import ComparisonTestBase, TestUtils from pyspark.testing.sqlutils import SQLTestUtils -class DataFrameConversionTest(PandasOnSparkTestCase, SQLTestUtils, TestUtils): +class DataFrameConversionTest(ComparisonTestBase, SQLTestUtils, TestUtils): """Test cases for "small data" conversion and I/O.""" def setUp(self): @@ -42,10 +42,6 @@ def tearDown(self): def pdf(self): return pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[0, 1, 3]) - @property - def psdf(self): - return ps.from_pandas(self.pdf) - @staticmethod def strip_all_whitespace(str): """A helper function to remove all whitespace from a string.""" diff --git a/python/pyspark/pandas/tests/test_extension.py b/python/pyspark/pandas/tests/test_extension.py index fb5f9bbc8eed3..dd2d08dded058 100644 --- a/python/pyspark/pandas/tests/test_extension.py +++ b/python/pyspark/pandas/tests/test_extension.py @@ -21,7 +21,7 @@ import pandas as pd from pyspark import pandas as ps -from pyspark.testing.pandasutils import assert_produces_warning, PandasOnSparkTestCase +from pyspark.testing.pandasutils import assert_produces_warning, ComparisonTestBase from pyspark.pandas.extensions import ( register_dataframe_accessor, register_series_accessor, @@ -66,7 +66,7 @@ def check_length(self, col=None): raise ValueError(str(e)) -class ExtensionTest(PandasOnSparkTestCase): +class ExtensionTest(ComparisonTestBase): @property def pdf(self): return pd.DataFrame( @@ -74,10 +74,6 @@ def pdf(self): index=np.random.rand(9), ) - @property - def psdf(self): - return ps.from_pandas(self.pdf) - @property def accessor(self): return CustomAccessor(self.psdf) diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py index ec6d761dddd42..661526b160050 100644 --- a/python/pyspark/pandas/tests/test_groupby.py +++ b/python/pyspark/pandas/tests/test_groupby.py @@ -49,9 +49,15 @@ def test_groupby_simple(self): for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values("a").reset_index(drop=True) + + def sort(df): + return df.sort_values("a").reset_index(drop=True) + self.assert_eq( sort(psdf.groupby("a", as_index=as_index).sum()), sort(pdf.groupby("a", 
as_index=as_index).sum()), @@ -156,9 +162,15 @@ def test_groupby_simple(self): for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values(10).reset_index(drop=True) + + def sort(df): + return df.sort_values(10).reset_index(drop=True) + self.assert_eq( sort(psdf.groupby(10, as_index=as_index).sum()), sort(pdf.groupby(10, as_index=as_index).sum()), @@ -244,9 +256,14 @@ def test_split_apply_combine_on_series(self): for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values(list(df.columns)).reset_index(drop=True) + + def sort(df): + return df.sort_values(list(df.columns)).reset_index(drop=True) for check_exact, almost, func in funcs: for kkey, pkey in [("b", "b"), (psdf.b, pdf.b)]: @@ -351,9 +368,14 @@ def test_aggregate(self): for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values(list(df.columns)).reset_index(drop=True) + + def sort(df): + return df.sort_values(list(df.columns)).reset_index(drop=True) for kkey, pkey in [("A", "A"), (psdf.A, pdf.A)]: with self.subTest(as_index=as_index, key=pkey): @@ -564,9 +586,14 @@ def test_dropna(self): for dropna in [True, False]: for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values("A").reset_index(drop=True) + + def sort(df): + return df.sort_values("A").reset_index(drop=True) self.assert_eq( sort(psdf.groupby("A", as_index=as_index, dropna=dropna).std()), @@ -598,9 +625,14 @@ def test_dropna(self): for dropna in [True, False]: for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values(["A", "B"]).reset_index(drop=True) + + def sort(df): + return df.sort_values(["A", "B"]).reset_index(drop=True) self.assert_eq( sort( @@ -624,9 +656,15 @@ def test_dropna(self): for dropna in [True, False]: for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values(("X", "A")).reset_index(drop=True) + + def sort(df): + return df.sort_values(("X", "A")).reset_index(drop=True) + sorted_stats_psdf = sort( psdf.groupby(("X", "A"), as_index=as_index, dropna=dropna).agg( {("X", "B"): "min", ("Y", "C"): "std"} @@ -642,9 +680,14 @@ def test_dropna(self): # Testing dropna=True (pandas default behavior) for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values("A").reset_index(drop=True) + + def sort(df): + return df.sort_values("A").reset_index(drop=True) self.assert_eq( sort(psdf.groupby("A", as_index=as_index, dropna=True)["B"].min()), @@ -652,9 +695,14 @@ def test_dropna(self): ) if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values(["A", "B"]).reset_index(drop=True) + + def sort(df): + return df.sort_values(["A", "B"]).reset_index(drop=True) self.assert_eq( sort( @@ -847,9 +895,15 @@ def test_all_any(self): for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = 
lambda df: df.sort_values("A").reset_index(drop=True) + + def sort(df): + return df.sort_values("A").reset_index(drop=True) + self.assert_eq( sort(psdf.groupby("A", as_index=as_index).all()), sort(pdf.groupby("A", as_index=as_index).all()), @@ -882,9 +936,15 @@ def test_all_any(self): for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values(("X", "A")).reset_index(drop=True) + + def sort(df): + return df.sort_values(("X", "A")).reset_index(drop=True) + self.assert_eq( sort(psdf.groupby(("X", "A"), as_index=as_index).all()), sort(pdf.groupby(("X", "A"), as_index=as_index).all()), diff --git a/python/pyspark/pandas/tests/test_indexing.py b/python/pyspark/pandas/tests/test_indexing.py index 0b76e9ea12912..fcce93aaafba3 100644 --- a/python/pyspark/pandas/tests/test_indexing.py +++ b/python/pyspark/pandas/tests/test_indexing.py @@ -24,7 +24,7 @@ from pyspark import pandas as ps from pyspark.pandas.exceptions import SparkPandasIndexingError -from pyspark.testing.pandasutils import ComparisonTestBase, PandasOnSparkTestCase, compare_both +from pyspark.testing.pandasutils import ComparisonTestBase, compare_both class BasicIndexingTest(ComparisonTestBase): @@ -153,7 +153,7 @@ def test_limitations(self): ) -class IndexingTest(PandasOnSparkTestCase): +class IndexingTest(ComparisonTestBase): @property def pdf(self): return pd.DataFrame( @@ -161,10 +161,6 @@ def pdf(self): index=[0, 1, 3, 5, 6, 8, 9, 9, 9], ) - @property - def psdf(self): - return ps.from_pandas(self.pdf) - @property def pdf2(self): return pd.DataFrame( diff --git a/python/pyspark/pandas/tests/test_numpy_compat.py b/python/pyspark/pandas/tests/test_numpy_compat.py index 0d6a8fb682579..c6b6e5dba9201 100644 --- a/python/pyspark/pandas/tests/test_numpy_compat.py +++ b/python/pyspark/pandas/tests/test_numpy_compat.py @@ -21,11 +21,11 @@ from pyspark import pandas as ps from pyspark.pandas import set_option, reset_option from pyspark.pandas.numpy_compat import unary_np_spark_mappings, binary_np_spark_mappings -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.pandasutils import ComparisonTestBase from pyspark.testing.sqlutils import SQLTestUtils -class NumPyCompatTest(PandasOnSparkTestCase, SQLTestUtils): +class NumPyCompatTest(ComparisonTestBase, SQLTestUtils): blacklist = [ # Koalas does not currently support "conj", @@ -55,10 +55,6 @@ def pdf(self): index=[0, 1, 3, 5, 6, 8, 9, 9, 9], ) - @property - def psdf(self): - return ps.from_pandas(self.pdf) - def test_np_add_series(self): psdf = self.psdf pdf = self.pdf diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py index dad4476975f6b..96473769475d2 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py @@ -1361,7 +1361,7 @@ def test_update(self): pser1.update(pser2) psser1.update(psser2) - self.assert_eq(psser1, pser1) + self.assert_eq(psser1.sort_index(), pser1) def test_where(self): pdf1 = pd.DataFrame({"A": [0, 1, 2, 3, 4], "B": [100, 200, 300, 400, 500]}) diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py index 3e8bcff8579f9..69621e49301f6 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py @@ -54,9 +54,15 @@ def 
test_groupby_different_lengths(self): for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values("c").reset_index(drop=True) + + def sort(df): + return df.sort_values("c").reset_index(drop=True) + self.assert_eq( sort(psdf1.groupby(psdf2.a, as_index=as_index).sum()), sort(pdf1.groupby(pdf2.a, as_index=as_index).sum()), @@ -112,9 +118,14 @@ def test_split_apply_combine_on_series(self): for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values(list(df.columns)).reset_index(drop=True) + + def sort(df): + return df.sort_values(list(df.columns)).reset_index(drop=True) with self.subTest(as_index=as_index): self.assert_eq( @@ -164,9 +175,14 @@ def test_aggregate(self): for as_index in [True, False]: if as_index: - sort = lambda df: df.sort_index() + + def sort(df): + return df.sort_index() + else: - sort = lambda df: df.sort_values(list(df.columns)).reset_index(drop=True) + + def sort(df): + return df.sort_values(list(df.columns)).reset_index(drop=True) with self.subTest(as_index=as_index): self.assert_eq( diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index cec6a475eb791..4cfd7c63e312d 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -193,14 +193,26 @@ def test_rename_method(self): with self.assertRaisesRegex(TypeError, expected_error_message): psser.rename(["0", "1"]) + # Function index + self.assert_eq(psser.rename(lambda x: x ** 2), pser.rename(lambda x: x ** 2)) + self.assert_eq((psser + 1).rename(lambda x: x ** 2), (pser + 1).rename(lambda x: x ** 2)) + + expected_error_message = "inplace True is not supported yet for a function 'index'" + with self.assertRaisesRegex(ValueError, expected_error_message): + psser.rename(lambda x: x ** 2, inplace=True) + + unsupported_index_inputs = (pd.Series([2, 3, 4, 5, 6, 7, 8]), {0: "zero", 1: "one"}) + for index in unsupported_index_inputs: + expected_error_message = ( + "'index' of %s type is not supported yet" % type(index).__name__ + ) + with self.assertRaisesRegex(ValueError, expected_error_message): + psser.rename(index) + # Series index # pser = pd.Series(['a', 'b', 'c', 'd', 'e', 'f', 'g'], name='x') # psser = ps.from_pandas(s) - # TODO: index - # res = psser.rename(lambda x: x ** 2) - # self.assert_eq(res, pser.rename(lambda x: x ** 2)) - # res = psser.rename(pser) # self.assert_eq(res, pser.rename(pser)) @@ -838,13 +850,20 @@ def test_all(self): pd.Series([True, False], name="x"), pd.Series([0, 1], name="x"), pd.Series([1, 2, 3], name="x"), + pd.Series([np.nan, 0, 1], name="x"), + pd.Series([np.nan, 1, 2, 3], name="x"), pd.Series([True, True, None], name="x"), pd.Series([True, False, None], name="x"), pd.Series([], name="x"), pd.Series([np.nan], name="x"), + pd.Series([np.nan, np.nan], name="x"), + pd.Series([None], name="x"), + pd.Series([None, None], name="x"), ]: psser = ps.from_pandas(pser) self.assert_eq(psser.all(), pser.all()) + self.assert_eq(psser.all(skipna=False), pser.all(skipna=False)) + self.assert_eq(psser.all(skipna=True), pser.all(skipna=True)) pser = pd.Series([1, 2, 3, 4], name="x") psser = ps.from_pandas(pser) @@ -1161,13 +1180,34 @@ def test_append(self): def test_map(self): pser = pd.Series(["cat", "dog", None, "rabbit"]) psser = ps.from_pandas(pser) - # Currently Koalas doesn't return NaN as pandas 
does. - self.assert_eq(psser.map({}), pser.map({}).replace({pd.np.nan: None})) + + # dict correspondence + # Currently pandas API on Spark doesn't return NaN as pandas does. + self.assert_eq(psser.map({}), pser.map({}).replace({np.nan: None})) d = defaultdict(lambda: "abc") self.assertTrue("abc" in repr(psser.map(d))) self.assert_eq(psser.map(d), pser.map(d)) + # series correspondence + pser_to_apply = pd.Series(["one", "two", "four"], index=["cat", "dog", "rabbit"]) + self.assert_eq(psser.map(pser_to_apply), pser.map(pser_to_apply)) + self.assert_eq( + psser.map(pser_to_apply, na_action="ignore"), + pser.map(pser_to_apply, na_action="ignore"), + ) + + # function correspondence + self.assert_eq( + psser.map(lambda x: x.upper(), na_action="ignore"), + pser.map(lambda x: x.upper(), na_action="ignore"), + ) + + def to_upper(string) -> str: + return string.upper() if string else "" + + self.assert_eq(psser.map(to_upper), pser.map(to_upper)) + def tomorrow(date) -> datetime: return date + timedelta(days=1) @@ -2071,6 +2111,48 @@ def test_asof(self): with ps.option_context("compute.eager_check", False): self.assert_eq(psser.asof(20), 4.0) + pser = pd.Series([2, 1, np.nan, 4], index=[10, 20, 30, 40], name="Koalas") + psser = ps.from_pandas(pser) + self.assert_eq(psser.asof([5, 20]), pser.asof([5, 20])) + + pser = pd.Series([4, np.nan, np.nan, 2], index=[10, 20, 30, 40], name="Koalas") + psser = ps.from_pandas(pser) + self.assert_eq(psser.asof([5, 100]), pser.asof([5, 100])) + + pser = pd.Series([np.nan, 4, 1, 2], index=[10, 20, 30, 40], name="Koalas") + psser = ps.from_pandas(pser) + self.assert_eq(psser.asof([5, 35]), pser.asof([5, 35])) + + pser = pd.Series([2, 1, np.nan, 4], index=[10, 20, 30, 40], name="Koalas") + psser = ps.from_pandas(pser) + self.assert_eq(psser.asof([25, 25]), pser.asof([25, 25])) + + pser = pd.Series([2, 1, np.nan, 4], index=["a", "b", "c", "d"], name="Koalas") + psser = ps.from_pandas(pser) + self.assert_eq(psser.asof(["a", "d"]), pser.asof(["a", "d"])) + + pser = pd.Series( + [2, 1, np.nan, 4], + index=[ + pd.Timestamp(2020, 1, 1), + pd.Timestamp(2020, 2, 2), + pd.Timestamp(2020, 3, 3), + pd.Timestamp(2020, 4, 4), + ], + name="Koalas", + ) + psser = ps.from_pandas(pser) + self.assert_eq( + psser.asof([pd.Timestamp(2020, 1, 1)]), + pser.asof([pd.Timestamp(2020, 1, 1)]), + ) + + pser = pd.Series([2, np.nan, 1, 4], index=[10, 20, 30, 40], name="Koalas") + psser = ps.from_pandas(pser) + self.assert_eq(psser.asof(np.nan), pser.asof(np.nan)) + self.assert_eq(psser.asof([np.nan, np.nan]), pser.asof([np.nan, np.nan])) + self.assert_eq(psser.asof([10, np.nan]), pser.asof([10, np.nan])) + def test_squeeze(self): # Single value pser = pd.Series([90]) diff --git a/python/pyspark/pandas/tests/test_series_datetime.py b/python/pyspark/pandas/tests/test_series_datetime.py index 637c1897bf544..d837c34fc7439 100644 --- a/python/pyspark/pandas/tests/test_series_datetime.py +++ b/python/pyspark/pandas/tests/test_series_datetime.py @@ -264,8 +264,8 @@ def test_floor(self): self.check_func(lambda x: x.dt.floor(freq="H")) def test_ceil(self): - self.check_func(lambda x: x.dt.floor(freq="min")) - self.check_func(lambda x: x.dt.floor(freq="H")) + self.check_func(lambda x: x.dt.ceil(freq="min")) + self.check_func(lambda x: x.dt.ceil(freq="H")) @unittest.skip("Unsupported locale setting") def test_month_name(self): diff --git a/python/pyspark/pandas/tests/test_series_string.py b/python/pyspark/pandas/tests/test_series_string.py index 832cc0bbfeb46..0b778583e735a 100644 --- 
a/python/pyspark/pandas/tests/test_series_string.py +++ b/python/pyspark/pandas/tests/test_series_string.py @@ -248,8 +248,11 @@ def test_string_replace(self): self.check_func(lambda x: x.str.replace("a.", "xx", regex=True)) self.check_func(lambda x: x.str.replace("a.", "xx", regex=False)) self.check_func(lambda x: x.str.replace("ing", "0", flags=re.IGNORECASE)) + # reverse every lowercase word - repl = lambda m: m.group(0)[::-1] + def repl(m): + return m.group(0)[::-1] + self.check_func(lambda x: x.str.replace(r"[a-z]+", repl)) # compiled regex with flags regex_pat = re.compile(r"WHITESPACE", flags=re.IGNORECASE) diff --git a/python/pyspark/pandas/tests/test_typedef.py b/python/pyspark/pandas/tests/test_typedef.py index ef331da8bec1e..1bc5c8cfdd051 100644 --- a/python/pyspark/pandas/tests/test_typedef.py +++ b/python/pyspark/pandas/tests/test_typedef.py @@ -56,10 +56,27 @@ class TypeHintTests(unittest.TestCase): - @unittest.skipIf( - sys.version_info < (3, 7), - "Type inference from pandas instances is supported with Python 3.7+", - ) + def test_infer_schema_with_no_return(self): + def try_infer_return_type(): + def f(): + pass + + infer_return_type(f) + + self.assertRaisesRegex( + ValueError, "A return value is required for the input function", try_infer_return_type + ) + + def try_infer_return_type(): + def f() -> None: + pass + + infer_return_type(f) + + self.assertRaisesRegex( + TypeError, "Type was not understood", try_infer_return_type + ) + def test_infer_schema_from_pandas_instances(self): def func() -> pd.Series[int]: pass @@ -148,10 +165,6 @@ def test_if_pandas_implements_class_getitem(self): assert not ps._frame_has_class_getitem assert not ps._series_has_class_getitem - @unittest.skipIf( - sys.version_info < (3, 7), - "Type inference from pandas instances is supported with Python 3.7+", - ) def test_infer_schema_with_names_pandas_instances(self): def func() -> 'pd.DataFrame["a" : np.float_, "b":str]': # noqa: F405 pass @@ -201,10 +214,6 @@ def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]: self.assertEqual(inferred.dtypes, [np.int64, CategoricalDtype(categories=["a", "b", "c"])]) self.assertEqual(inferred.spark_type, expected) - @unittest.skipIf( - sys.version_info < (3, 7), - "Type inference from pandas instances is supported with Python 3.7+", - ) def test_infer_schema_with_names_pandas_instances_negative(self): def try_infer_return_type(): def f() -> 'pd.DataFrame["a" : np.float_ : 1, "b":str:2]': # noqa: F405 diff --git a/python/pyspark/pandas/typedef/__init__.py b/python/pyspark/pandas/typedef/__init__.py index 5f7ea2834a52a..49490674d7291 100644 --- a/python/pyspark/pandas/typedef/__init__.py +++ b/python/pyspark/pandas/typedef/__init__.py @@ -15,4 +15,4 @@ # limitations under the License. # -from pyspark.pandas.typedef.typehints import * # noqa: F401,F405 +from pyspark.pandas.typedef.typehints import * # noqa: F401,F403,F405 diff --git a/python/pyspark/pandas/typedef/string_typehints.py b/python/pyspark/pandas/typedef/string_typehints.py deleted file mode 100644 index c7a72351ad934..0000000000000 --- a/python/pyspark/pandas/typedef/string_typehints.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from inspect import FullArgSpec -from typing import List, Optional, Type, cast as _cast # noqa: F401 - -import numpy as np # noqa: F401 -import pandas # noqa: F401 -import pandas as pd # noqa: F401 -from numpy import * # noqa: F401 -from pandas import * # type: ignore[no-redef] # noqa: F401 -from inspect import getfullargspec # noqa: F401 - - -def resolve_string_type_hint(tpe: str) -> Optional[Type]: - import pyspark.pandas as ps - from pyspark.pandas import DataFrame, Series # type: ignore[misc] - - locs = { - "ps": ps, - "pyspark.pandas": ps, - "DataFrame": DataFrame, - "Series": Series, - } - # This is a hack to resolve the forward reference string. - exec("def func() -> %s: pass\narg_spec = getfullargspec(func)" % tpe, globals(), locs) - return _cast(FullArgSpec, locs["arg_spec"]).annotations.get("return", None) diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py index eddd06d061def..695ed31af6f42 100644 --- a/python/pyspark/pandas/typedef/typehints.py +++ b/python/pyspark/pandas/typedef/typehints.py @@ -22,18 +22,10 @@ import decimal import sys import typing -from collections import Iterable +from collections.abc import Iterable from distutils.version import LooseVersion -from inspect import getfullargspec, isclass -from typing import ( - Any, - Callable, - Generic, - List, - Tuple, - Union, - Type, -) +from inspect import isclass +from typing import Any, Callable, Generic, List, Tuple, Union, Type, get_type_hints import numpy as np import pandas as pd @@ -76,7 +68,6 @@ # For running doctests and reference resolution in PyCharm. from pyspark import pandas as ps # noqa: F401 from pyspark.pandas._typing import Dtype, T -from pyspark.pandas.typedef.string_typehints import resolve_string_type_hint if typing.TYPE_CHECKING: from pyspark.pandas.internal import InternalField @@ -566,11 +557,10 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp from pyspark.pandas.typedef import SeriesType, NameTypeHolder, IndexNameTypeHolder from pyspark.pandas.utils import name_like_string - spec = getfullargspec(f) - tpe = spec.annotations.get("return", None) - if isinstance(tpe, str): - # This type hint can happen when given hints are string to avoid forward reference. - tpe = resolve_string_type_hint(tpe) + tpe = get_type_hints(f).get("return", None) + + if tpe is None: + raise ValueError("A return value is required for the input function") if hasattr(tpe, "__origin__") and issubclass(tpe.__origin__, SeriesType): tpe = tpe.__args__[0] diff --git a/python/pyspark/pandas/usage_logging/__init__.py b/python/pyspark/pandas/usage_logging/__init__.py index b350faf6b9ca5..a6f1470b9f4e4 100644 --- a/python/pyspark/pandas/usage_logging/__init__.py +++ b/python/pyspark/pandas/usage_logging/__init__.py @@ -15,11 +15,6 @@ # limitations under the License. 
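Note on the typehints.py hunk above: the exec-based string_typehints helper can be deleted because typing.get_type_hints already evaluates string and forward-reference annotations against the function's globals, and a missing annotation now surfaces as the explicit ValueError added to infer_return_type. A small self-contained illustration (not Spark code):

# Sketch of the behavioral difference behind the typehints.py refactor above.
from inspect import getfullargspec
from typing import get_type_hints


def func() -> "int":  # string (forward-reference style) annotation
    return 0


def no_hint():
    pass


print(getfullargspec(func).annotations["return"])  # 'int'  (left as a raw string)
print(get_type_hints(func)["return"])              # <class 'int'>  (resolved)

# With get_type_hints, a function without a return annotation simply yields no
# "return" key, which infer_return_type above turns into the new ValueError.
print(get_type_hints(no_hint).get("return"))       # None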
# -import functools -import importlib -import inspect -import threading -import time from types import ModuleType from typing import Union @@ -60,6 +55,7 @@ ) from pyspark.pandas.strings import StringMethods from pyspark.pandas.window import Expanding, ExpandingGroupby, Rolling, RollingGroupby +from pyspark.instrumentation_utils import _attach def attach(logger_module: Union[str, ModuleType]) -> None: @@ -76,10 +72,6 @@ def attach(logger_module: Union[str, ModuleType]) -> None: -------- usage_logger : the reference implementation of the usage logger. """ - if isinstance(logger_module, str): - logger_module = importlib.import_module(logger_module) - - logger = getattr(logger_module, "get_logger")() modules = [config, namespace] classes = [ @@ -116,42 +108,7 @@ def attach(logger_module: Union[str, ModuleType]) -> None: sql_formatter._CAPTURE_SCOPES = 4 modules.append(sql_formatter) - # Modules - for target_module in modules: - target_name = target_module.__name__.split(".")[-1] - for name in getattr(target_module, "__all__"): - func = getattr(target_module, name) - if not inspect.isfunction(func): - continue - setattr(target_module, name, _wrap_function(target_name, name, func, logger)) - - special_functions = set( - [ - "__init__", - "__repr__", - "__str__", - "_repr_html_", - "__len__", - "__getitem__", - "__setitem__", - "__getattr__", - ] - ) - - # Classes - for target_class in classes: - for name, func in inspect.getmembers(target_class, inspect.isfunction): - if name.startswith("_") and name not in special_functions: - continue - setattr(target_class, name, _wrap_function(target_class.__name__, name, func, logger)) - - for name, prop in inspect.getmembers(target_class, lambda o: isinstance(o, property)): - if name.startswith("_"): - continue - setattr(target_class, name, _wrap_property(target_class.__name__, name, prop, logger)) - - # Missings - for original, missing in [ + missings = [ (pd.DataFrame, _MissingPandasLikeDataFrame), (pd.Series, MissingPandasLikeSeries), (pd.Index, MissingPandasLikeIndex), @@ -163,105 +120,6 @@ def attach(logger_module: Union[str, ModuleType]) -> None: (pd.core.window.Rolling, MissingPandasLikeRolling), (pd.core.window.ExpandingGroupby, MissingPandasLikeExpandingGroupby), (pd.core.window.RollingGroupby, MissingPandasLikeRollingGroupby), - ]: - for name, func in inspect.getmembers(missing, inspect.isfunction): - setattr( - missing, - name, - _wrap_missing_function(original.__name__, name, func, original, logger), - ) - - for name, prop in inspect.getmembers(missing, lambda o: isinstance(o, property)): - setattr(missing, name, _wrap_missing_property(original.__name__, name, prop, logger)) - - -_local = threading.local() - - -def _wrap_function(class_name, function_name, func, logger): - - signature = inspect.signature(func) - - @functools.wraps(func) - def wrapper(*args, **kwargs): - if hasattr(_local, "logging") and _local.logging: - # no need to log since this should be internal call. 
- return func(*args, **kwargs) - _local.logging = True - try: - start = time.perf_counter() - try: - res = func(*args, **kwargs) - logger.log_success( - class_name, function_name, time.perf_counter() - start, signature - ) - return res - except Exception as ex: - logger.log_failure( - class_name, function_name, ex, time.perf_counter() - start, signature - ) - raise - finally: - _local.logging = False - - return wrapper - - -def _wrap_property(class_name, property_name, prop, logger): - @property - def wrapper(self): - if hasattr(_local, "logging") and _local.logging: - # no need to log since this should be internal call. - return prop.fget(self) - _local.logging = True - try: - start = time.perf_counter() - try: - res = prop.fget(self) - logger.log_success(class_name, property_name, time.perf_counter() - start) - return res - except Exception as ex: - logger.log_failure(class_name, property_name, ex, time.perf_counter() - start) - raise - finally: - _local.logging = False - - wrapper.__doc__ = prop.__doc__ - - if prop.fset is not None: - wrapper = wrapper.setter(_wrap_function(class_name, prop.fset.__name__, prop.fset, logger)) - - return wrapper - - -def _wrap_missing_function(class_name, function_name, func, original, logger): - - if not hasattr(original, function_name): - return func - - signature = inspect.signature(getattr(original, function_name)) - - is_deprecated = func.__name__ == "deprecated_function" - - @functools.wraps(func) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - finally: - logger.log_missing(class_name, function_name, is_deprecated, signature) - - return wrapper - - -def _wrap_missing_property(class_name, property_name, prop, logger): - - is_deprecated = prop.fget.__name__ == "deprecated_property" - - @property - def wrapper(self): - try: - return prop.fget(self) - finally: - logger.log_missing(class_name, property_name, is_deprecated) + ] - return wrapper + _attach(logger_module, modules, classes, missings) diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py index 43c203f6de5d6..a61ea7d19b3ec 100644 --- a/python/pyspark/pandas/utils.py +++ b/python/pyspark/pandas/utils.py @@ -149,7 +149,9 @@ def combine_frames( if get_option("compute.ops_on_diff_frames"): def resolve(internal: InternalFrame, side: str) -> InternalFrame: - rename = lambda col: "__{}_{}".format(side, col) + def rename(col: str) -> str: + return "__{}_{}".format(side, col) + internal = internal.resolved_copy sdf = internal.spark_frame sdf = internal.spark_frame.select( @@ -465,11 +467,22 @@ def is_testing() -> bool: def default_session() -> SparkSession: spark = SparkSession.getActiveSession() - if spark is not None: - return spark + if spark is None: + spark = SparkSession.builder.appName("pandas-on-Spark").getOrCreate() + + # Turn ANSI off when testing the pandas API on Spark since + # the behavior of pandas API on Spark follows pandas, not SQL. + if is_testing(): + spark.conf.set("spark.sql.ansi.enabled", False) # type: ignore[arg-type] + if spark.conf.get("spark.sql.ansi.enabled") == "true": + log_advice( + "The config 'spark.sql.ansi.enabled' is set to True. " + "This can cause unexpected behavior " + "from pandas API on Spark since pandas API on Spark follows " + "the behavior of pandas, not SQL." 
+ ) - builder = SparkSession.builder.appName("pandas-on-Spark") - return builder.getOrCreate() + return spark @contextmanager @@ -914,7 +927,7 @@ def spark_column_equals(left: Column, right: Column) -> bool: >>> spark_column_equals(sdf1["x"] + 1, sdf2["x"] + 1) False """ - return left._jc.equals(right._jc) # type: ignore[operator] + return left._jc.equals(right._jc) def compare_null_first( diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 6271bbc4814f0..45365cc1e79b0 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -15,6 +15,8 @@ # limitations under the License. # +from typing import Any, Callable, List, Optional, Type, TYPE_CHECKING, cast + import cProfile import pstats import os @@ -23,6 +25,9 @@ from pyspark.accumulators import AccumulatorParam +if TYPE_CHECKING: + from pyspark.context import SparkContext + class ProfilerCollector: """ @@ -31,21 +36,26 @@ class ProfilerCollector: the different stages/UDFs. """ - def __init__(self, profiler_cls, udf_profiler_cls, dump_path=None): - self.profiler_cls = profiler_cls - self.udf_profiler_cls = udf_profiler_cls - self.profile_dump_path = dump_path - self.profilers = [] - - def new_profiler(self, ctx): + def __init__( + self, + profiler_cls: Type["Profiler"], + udf_profiler_cls: Type["Profiler"], + dump_path: Optional[str] = None, + ): + self.profiler_cls: Type[Profiler] = profiler_cls + self.udf_profiler_cls: Type[Profiler] = udf_profiler_cls + self.profile_dump_path: Optional[str] = dump_path + self.profilers: List[List[Any]] = [] + + def new_profiler(self, ctx: "SparkContext") -> "Profiler": """Create a new profiler using class `profiler_cls`""" return self.profiler_cls(ctx) - def new_udf_profiler(self, ctx): + def new_udf_profiler(self, ctx: "SparkContext") -> "Profiler": """Create a new profiler using class `udf_profiler_cls`""" return self.udf_profiler_cls(ctx) - def add_profiler(self, id, profiler): + def add_profiler(self, id: int, profiler: "Profiler") -> None: """Add a profiler for RDD/UDF `id`""" if not self.profilers: if self.profile_dump_path: @@ -55,13 +65,13 @@ def add_profiler(self, id, profiler): self.profilers.append([id, profiler, False]) - def dump_profiles(self, path): + def dump_profiles(self, path: str) -> None: """Dump the profile stats into directory `path`""" for id, profiler, _ in self.profilers: profiler.dump(id, path) self.profilers = [] - def show_profiles(self): + def show_profiles(self) -> None: """Print the profile stats to stdout""" for i, (id, profiler, showed) in enumerate(self.profilers): if not showed and profiler: @@ -108,18 +118,18 @@ class Profiler: This API is a developer API. 
""" - def __init__(self, ctx): + def __init__(self, ctx: "SparkContext") -> None: pass - def profile(self, func, *args, **kwargs): + def profile(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """Do profiling on the function `func`""" raise NotImplementedError - def stats(self): + def stats(self) -> pstats.Stats: """Return the collected profiling stats (pstats.Stats)""" raise NotImplementedError - def show(self, id): + def show(self, id: int) -> None: """Print the profile stats to stdout, id is the RDD id""" stats = self.stats() if stats: @@ -128,7 +138,7 @@ def show(self, id): print("=" * 60) stats.sort_stats("time", "cumulative").print_stats() - def dump(self, id, path): + def dump(self, id: int, path: str) -> None: """Dump the profile into path, id is the RDD id""" if not os.path.exists(path): os.makedirs(path) @@ -138,15 +148,17 @@ def dump(self, id, path): stats.dump_stats(p) -class PStatsParam(AccumulatorParam): +class PStatsParam(AccumulatorParam[Optional[pstats.Stats]]): """PStatsParam is used to merge pstats.Stats""" @staticmethod - def zero(value): + def zero(value: Optional[pstats.Stats]) -> None: return None @staticmethod - def addInPlace(value1, value2): + def addInPlace( + value1: Optional[pstats.Stats], value2: Optional[pstats.Stats] + ) -> Optional[pstats.Stats]: if value1 is None: return value2 value1.add(value2) @@ -159,27 +171,27 @@ class BasicProfiler(Profiler): cProfile and Accumulator """ - def __init__(self, ctx): + def __init__(self, ctx: "SparkContext") -> None: Profiler.__init__(self, ctx) # Creates a new accumulator for combining the profiles of different # partitions of a stage - self._accumulator = ctx.accumulator(None, PStatsParam) + self._accumulator = ctx.accumulator(None, PStatsParam) # type: ignore[arg-type] - def profile(self, func, *args, **kwargs): + def profile(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """Runs and profiles the method to_profile passed in. A profile object is returned.""" pr = cProfile.Profile() ret = pr.runcall(func, *args, **kwargs) st = pstats.Stats(pr) - st.stream = None # make it picklable + st.stream = None # type: ignore[attr-defined] # make it picklable st.strip_dirs() # Adds a new profile to the existing accumulated value - self._accumulator.add(st) + self._accumulator.add(st) # type: ignore[arg-type] return ret - def stats(self): - return self._accumulator.value + def stats(self) -> pstats.Stats: + return cast(pstats.Stats, self._accumulator.value) class UDFBasicProfiler(BasicProfiler): @@ -187,7 +199,7 @@ class UDFBasicProfiler(BasicProfiler): UDFBasicProfiler is the profiler for Python/Pandas UDFs. """ - def show(self, id): + def show(self, id: int) -> None: """Print the profile stats to stdout, id is the PythonUDF id""" stats = self.stats() if stats: @@ -196,7 +208,7 @@ def show(self, id): print("=" * 60) stats.sort_stats("time", "cumulative").print_stats() - def dump(self, id, path): + def dump(self, id: int, path: str) -> None: """Dump the profile into path, id is the PythonUDF id""" if not os.path.exists(path): os.makedirs(path) diff --git a/python/pyspark/profiler.pyi b/python/pyspark/profiler.pyi deleted file mode 100644 index 85aa6a248036c..0000000000000 --- a/python/pyspark/profiler.pyi +++ /dev/null @@ -1,65 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Any, Callable, List, Optional, Tuple, Type - -import pstats - -from pyspark.accumulators import AccumulatorParam -from pyspark.context import SparkContext - -class ProfilerCollector: - profiler_cls: Type[Profiler] - udf_profiler_cls: Type[Profiler] - profile_dump_path: Optional[str] - profilers: List[Tuple[int, Profiler, bool]] - def __init__( - self, - profiler_cls: Type[Profiler], - udf_profiler_cls: Type[Profiler], - dump_path: Optional[str] = ..., - ) -> None: ... - def new_profiler(self, ctx: SparkContext) -> Profiler: ... - def new_udf_profiler(self, ctx: SparkContext) -> Profiler: ... - def add_profiler(self, id: int, profiler: Profiler) -> None: ... - def dump_profiles(self, path: str) -> None: ... - def show_profiles(self) -> None: ... - -class Profiler: - def __init__(self, ctx: SparkContext) -> None: ... - def profile(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: ... - def stats(self) -> pstats.Stats: ... - def show(self, id: int) -> None: ... - def dump(self, id: int, path: str) -> None: ... - -class PStatsParam(AccumulatorParam): - @staticmethod - def zero(value: pstats.Stats) -> None: ... - @staticmethod - def addInPlace( - value1: Optional[pstats.Stats], value2: Optional[pstats.Stats] - ) -> Optional[pstats.Stats]: ... - -class BasicProfiler(Profiler): - def __init__(self, ctx: SparkContext) -> None: ... - def profile(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: ... - def stats(self) -> pstats.Stats: ... - -class UDFBasicProfiler(BasicProfiler): - def show(self, id: int) -> None: ... - def dump(self, id: int, path: str) -> None: ... 
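For context on the profiler surface typed above, a minimal usage sketch, assuming a local master and that Python profiling is switched on via `spark.python.profile`; the app name and workload below are illustrative, not part of this patch.

from pyspark import SparkConf, SparkContext

# Route RDD jobs through BasicProfiler; ProfilerCollector gathers the pstats.
conf = SparkConf().setMaster("local[2]").setAppName("profiler-demo")
conf = conf.set("spark.python.profile", "true")
sc = SparkContext(conf=conf)

sc.parallelize(range(1000)).map(lambda x: x * x).count()
sc.show_profiles()   # prints the accumulated pstats.Stats per RDD id
sc.stop()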
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 27b6665ecf1ce..611183160a5f4 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -30,6 +30,26 @@ from itertools import chain from functools import reduce from math import sqrt, log, isinf, isnan, pow, ceil +from typing import ( + Any, + Callable, + Dict, + Generic, + Hashable, + Iterable, + Iterator, + IO, + List, + NoReturn, + Optional, + Sequence, + Tuple, + Union, + TypeVar, + cast, + overload, + TYPE_CHECKING, +) from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( @@ -40,6 +60,7 @@ CloudPickleSerializer, PairDeserializer, CPickleSerializer, + Serializer, pack_long, read_int, write_int, @@ -67,6 +88,41 @@ from pyspark.traceback_utils import SCCallSiteSync from pyspark.util import fail_on_stopiteration, _parse_memory + +if TYPE_CHECKING: + import socket + import io + + from pyspark._typing import NonUDFType + from pyspark._typing import S, NumberOrArray + from pyspark.context import SparkContext + from pyspark.sql.pandas._typing import ( + PandasScalarUDFType, + PandasGroupedMapUDFType, + PandasGroupedAggUDFType, + PandasWindowAggUDFType, + PandasScalarIterUDFType, + PandasMapIterUDFType, + PandasCogroupedMapUDFType, + ArrowMapIterUDFType, + ) + from pyspark.sql.dataframe import DataFrame + from pyspark.sql.types import AtomicType, StructType + from pyspark.sql._typing import AtomicValue, RowLike, SQLBatchedUDFType + + from py4j.java_gateway import JavaObject + from py4j.java_collections import JavaArray + +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +U = TypeVar("U") +K = TypeVar("K", bound=Hashable) +V = TypeVar("V") +V1 = TypeVar("V1") +V2 = TypeVar("V2") +V3 = TypeVar("V3") + + __all__ = ["RDD"] @@ -79,21 +135,21 @@ class PythonEvalType: These values should match values in org.apache.spark.api.python.PythonEvalType. """ - NON_UDF = 0 + NON_UDF: "NonUDFType" = 0 - SQL_BATCHED_UDF = 100 + SQL_BATCHED_UDF: "SQLBatchedUDFType" = 100 - SQL_SCALAR_PANDAS_UDF = 200 - SQL_GROUPED_MAP_PANDAS_UDF = 201 - SQL_GROUPED_AGG_PANDAS_UDF = 202 - SQL_WINDOW_AGG_PANDAS_UDF = 203 - SQL_SCALAR_PANDAS_ITER_UDF = 204 - SQL_MAP_PANDAS_ITER_UDF = 205 - SQL_COGROUPED_MAP_PANDAS_UDF = 206 - SQL_MAP_ARROW_ITER_UDF = 207 + SQL_SCALAR_PANDAS_UDF: "PandasScalarUDFType" = 200 + SQL_GROUPED_MAP_PANDAS_UDF: "PandasGroupedMapUDFType" = 201 + SQL_GROUPED_AGG_PANDAS_UDF: "PandasGroupedAggUDFType" = 202 + SQL_WINDOW_AGG_PANDAS_UDF: "PandasWindowAggUDFType" = 203 + SQL_SCALAR_PANDAS_ITER_UDF: "PandasScalarIterUDFType" = 204 + SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205 + SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206 + SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207 -def portable_hash(x): +def portable_hash(x: Hashable) -> int: """ This function returns consistent hash code for builtin types, especially for None and tuple with None. 
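As a quick illustration of what `portable_hash` (now typed `Hashable -> int`) is for, a sketch assuming PySpark is importable; note that PySpark expects a fixed `PYTHONHASHSEED`, normally exported before the interpreter starts, so string hashes agree across workers.

import os
os.environ.setdefault("PYTHONHASHSEED", "0")  # normally set in the environment, not in code

from pyspark.rdd import portable_hash

# portable_hash is the default partitionFunc for partitionBy/groupByKey:
# equal keys (including None and tuples containing None) map to the same bucket.
for key in [None, ("a", None), (1, 2)]:
    print(key, portable_hash(key) % 4)  # bucket index for a 4-partition RDD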
@@ -137,7 +193,11 @@ class BoundedFloat(float): 100.0 """ - def __new__(cls, mean, confidence, low, high): + confidence: float + low: float + high: float + + def __new__(cls, mean: float, confidence: float, low: float, high: float) -> "BoundedFloat": obj = float.__new__(cls, mean) obj.confidence = confidence obj.low = low @@ -145,7 +205,7 @@ def __new__(cls, mean, confidence, low, high): return obj -def _create_local_socket(sock_info): +def _create_local_socket(sock_info: "JavaArray") -> "io.BufferedRWPair": """ Create a local socket that can be used to load deserialized data from the JVM @@ -158,8 +218,10 @@ def _create_local_socket(sock_info): ------- sockfile file descriptor of the local socket """ - port = sock_info[0] - auth_secret = sock_info[1] + sockfile: "io.BufferedRWPair" + sock: "socket.socket" + port: int = sock_info[0] + auth_secret: str = sock_info[1] sockfile, sock = local_connect_and_auth(port, auth_secret) # The RDD materialization time is unpredictable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. @@ -167,7 +229,7 @@ def _create_local_socket(sock_info): return sockfile -def _load_from_socket(sock_info, serializer): +def _load_from_socket(sock_info: "JavaArray", serializer: Serializer) -> Iterator[Any]: """ Connect to a local socket described by sock_info and use the given serializer to yield data @@ -188,18 +250,21 @@ def _load_from_socket(sock_info, serializer): return serializer.load_stream(sockfile) -def _local_iterator_from_socket(sock_info, serializer): +def _local_iterator_from_socket(sock_info: "JavaArray", serializer: Serializer) -> Iterator[Any]: class PyLocalIterable: """Create a synchronous local iterable over a socket""" - def __init__(self, _sock_info, _serializer): + def __init__(self, _sock_info: "JavaArray", _serializer: Serializer): + port: int + auth_secret: str + jsocket_auth_server: "JavaObject" port, auth_secret, self.jsocket_auth_server = _sock_info self._sockfile = _create_local_socket((port, auth_secret)) self._serializer = _serializer - self._read_iter = iter([]) # Initialize as empty iterator + self._read_iter: Iterator[Any] = iter([]) # Initialize as empty iterator self._read_status = 1 - def __iter__(self): + def __iter__(self) -> Iterator[Any]: while self._read_status == 1: # Request next partition data from Java write_int(1, self._sockfile) @@ -218,7 +283,7 @@ def __iter__(self): elif self._read_status == -1: self.jsocket_auth_server.getResult() - def __del__(self): + def __del__(self) -> None: # If local iterator is not fully consumed, if self._read_status == 1: try: @@ -236,22 +301,22 @@ def __del__(self): class Partitioner: - def __init__(self, numPartitions, partitionFunc): + def __init__(self, numPartitions: int, partitionFunc: Callable[[Any], int]): self.numPartitions = numPartitions self.partitionFunc = partitionFunc - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, Partitioner) and self.numPartitions == other.numPartitions and self.partitionFunc == other.partitionFunc ) - def __call__(self, k): + def __call__(self, k: Any) -> int: return self.partitionFunc(k) % self.numPartitions -class RDD: +class RDD(Generic[T_co]): """ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. @@ -259,7 +324,12 @@ class RDD: operated on in parallel. 
""" - def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(CPickleSerializer())): + def __init__( + self, + jrdd: "JavaObject", + ctx: "SparkContext", + jrdd_deserializer: Serializer = AutoBatchedSerializer(CPickleSerializer()), + ): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False @@ -267,21 +337,21 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(CPickleSer self.ctx = ctx self._jrdd_deserializer = jrdd_deserializer self._id = jrdd.id() - self.partitioner = None + self.partitioner: Optional[Partitioner] = None - def _pickled(self): + def _pickled(self: "RDD[T]") -> "RDD[T]": return self._reserialize(AutoBatchedSerializer(CPickleSerializer())) - def id(self): + def id(self) -> int: """ A unique ID for this RDD (within its SparkContext). """ return self._id - def __repr__(self): + def __repr__(self) -> str: return self._jrdd.toString() - def __getnewargs__(self): + def __getnewargs__(self) -> NoReturn: # This method is called when attempting to pickle an RDD, which is always an error: raise RuntimeError( "It appears that you are attempting to broadcast an RDD or reference an RDD from an " @@ -293,13 +363,13 @@ def __getnewargs__(self): ) @property - def context(self): + def context(self) -> "SparkContext": """ The :class:`SparkContext` that this RDD was created on. """ return self.ctx - def cache(self): + def cache(self: "RDD[T]") -> "RDD[T]": """ Persist this RDD with the default storage level (`MEMORY_ONLY`). """ @@ -307,7 +377,7 @@ def cache(self): self.persist(StorageLevel.MEMORY_ONLY) return self - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): + def persist(self: "RDD[T]", storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "RDD[T]": """ Set this RDD's storage level to persist its values across operations after the first time it is computed. This can only be used to assign @@ -325,7 +395,7 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): self._jrdd.persist(javaStorageLevel) return self - def unpersist(self, blocking=False): + def unpersist(self: "RDD[T]", blocking: bool = False) -> "RDD[T]": """ Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. @@ -338,7 +408,7 @@ def unpersist(self, blocking=False): self._jrdd.unpersist(blocking) return self - def checkpoint(self): + def checkpoint(self) -> None: """ Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint directory set with :meth:`SparkContext.setCheckpointDir` and @@ -350,13 +420,13 @@ def checkpoint(self): self.is_checkpointed = True self._jrdd.rdd().checkpoint() - def isCheckpointed(self): + def isCheckpointed(self) -> bool: """ Return whether this RDD is checkpointed and materialized, either reliably or locally. """ return self._jrdd.rdd().isCheckpointed() - def localCheckpoint(self): + def localCheckpoint(self) -> None: """ Mark this RDD for local checkpointing using Spark's existing caching layer. @@ -377,7 +447,7 @@ def localCheckpoint(self): """ self._jrdd.rdd().localCheckpoint() - def isLocallyCheckpointed(self): + def isLocallyCheckpointed(self) -> bool: """ Return whether this RDD is marked for local checkpointing. @@ -385,17 +455,37 @@ def isLocallyCheckpointed(self): """ return self._jrdd.rdd().isLocallyCheckpointed() - def getCheckpointFile(self): + def getCheckpointFile(self) -> Optional[str]: """ Gets the name of the file to which this RDD was checkpointed Not defined if RDD is checkpointed locally. 
""" checkpointFile = self._jrdd.rdd().getCheckpointFile() - if checkpointFile.isDefined(): - return checkpointFile.get() - def map(self, f, preservesPartitioning=False): + return checkpointFile.get() if checkpointFile.isDefined() else None + + def cleanShuffleDependencies(self, blocking: bool = False) -> None: + """ + Removes an RDD's shuffles and it's non-persisted ancestors. + + When running without a shuffle service, cleaning up shuffle files enables downscaling. + If you use the RDD after this call, you should checkpoint and materialize it first. + + .. versionadded:: 3.3.0 + + Parameters + ---------- + blocking : bool, optional + block on shuffle cleanup tasks. Disabled by default. + + Notes + ----- + This API is a developer API. + """ + self._jrdd.rdd().cleanShuffleDependencies(blocking) + + def map(self: "RDD[T]", f: Callable[[T], U], preservesPartitioning: bool = False) -> "RDD[U]": """ Return a new RDD by applying a function to each element of this RDD. @@ -406,12 +496,14 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ - def func(_, iterator): + def func(_: int, iterator: Iterable[T]) -> Iterable[U]: return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) - def flatMap(self, f, preservesPartitioning=False): + def flatMap( + self: "RDD[T]", f: Callable[[T], Iterable[U]], preservesPartitioning: bool = False + ) -> "RDD[U]": """ Return a new RDD by first applying a function to all elements of this RDD, and then flattening the results. @@ -425,12 +517,14 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - def func(s, iterator): + def func(_: int, iterator: Iterable[T]) -> Iterable[U]: return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) - def mapPartitions(self, f, preservesPartitioning=False): + def mapPartitions( + self: "RDD[T]", f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = False + ) -> "RDD[U]": """ Return a new RDD by applying a function to each partition of this RDD. @@ -442,12 +536,16 @@ def mapPartitions(self, f, preservesPartitioning=False): [3, 7] """ - def func(s, iterator): + def func(_: int, iterator: Iterable[T]) -> Iterable[U]: return f(iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) - def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + def mapPartitionsWithIndex( + self: "RDD[T]", + f: Callable[[int, Iterable[T]], Iterable[U]], + preservesPartitioning: bool = False, + ) -> "RDD[U]": """ Return a new RDD by applying a function to each partition of this RDD, while tracking the index of the original partition. 
@@ -461,7 +559,11 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False): """ return PipelinedRDD(self, f, preservesPartitioning) - def mapPartitionsWithSplit(self, f, preservesPartitioning=False): + def mapPartitionsWithSplit( + self: "RDD[T]", + f: Callable[[int, Iterable[T]], Iterable[U]], + preservesPartitioning: bool = False, + ) -> "RDD[U]": """ Return a new RDD by applying a function to each partition of this RDD, @@ -484,7 +586,7 @@ def mapPartitionsWithSplit(self, f, preservesPartitioning=False): ) return self.mapPartitionsWithIndex(f, preservesPartitioning) - def getNumPartitions(self): + def getNumPartitions(self) -> int: """ Returns the number of partitions in RDD @@ -496,7 +598,7 @@ def getNumPartitions(self): """ return self._jrdd.partitions().size() - def filter(self, f): + def filter(self: "RDD[T]", f: Callable[[T], bool]) -> "RDD[T]": """ Return a new RDD containing only the elements that satisfy a predicate. @@ -507,12 +609,12 @@ def filter(self, f): [2, 4] """ - def func(iterator): + def func(iterator: Iterable[T]) -> Iterable[T]: return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) - def distinct(self, numPartitions=None): + def distinct(self: "RDD[T]", numPartitions: Optional[int] = None) -> "RDD[T]": """ Return a new RDD containing the distinct elements in this RDD. @@ -527,7 +629,9 @@ def distinct(self, numPartitions=None): .map(lambda x: x[0]) ) - def sample(self, withReplacement, fraction, seed=None): + def sample( + self: "RDD[T]", withReplacement: bool, fraction: float, seed: Optional[int] = None + ) -> "RDD[T]": """ Return a sampled subset of this RDD. @@ -556,7 +660,9 @@ def sample(self, withReplacement, fraction, seed=None): assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) - def randomSplit(self, weights, seed=None): + def randomSplit( + self: "RDD[T]", weights: Sequence[Union[int, float]], seed: Optional[int] = None + ) -> "List[RDD[T]]": """ Randomly splits this RDD with the provided weights. @@ -593,7 +699,9 @@ def randomSplit(self, weights, seed=None): ] # this is ported from scala/spark/RDD.scala - def takeSample(self, withReplacement, num, seed=None): + def takeSample( + self: "RDD[T]", withReplacement: bool, num: int, seed: Optional[int] = None + ) -> List[T]: """ Return a fixed-size sampled subset of this RDD. @@ -651,7 +759,9 @@ def takeSample(self, withReplacement, num, seed=None): return samples[0:num] @staticmethod - def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement): + def _computeFractionForSampleSize( + sampleSizeLowerBound: int, total: int, withReplacement: bool + ) -> float: """ Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of the time. @@ -683,7 +793,7 @@ def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement): gamma = -log(delta) / total return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction)) - def union(self, other): + def union(self: "RDD[T]", other: "RDD[U]") -> "RDD[Union[T, U]]": """ Return the union of this RDD and another one. 
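How the sampling and splitting annotations read at a call site, assuming an existing local `SparkContext` bound to `sc`:

nums = sc.parallelize(range(100))                        # RDD[int]
train, test = nums.randomSplit([0.8, 0.2], seed=17)      # List[RDD[int]]
sampled = nums.sample(withReplacement=False, fraction=0.1, seed=17)
both = train.union(test)                                 # still an RDD[int]
print(both.count(), sampled.count())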
@@ -694,7 +804,9 @@ def union(self, other): [1, 1, 2, 3, 1, 1, 2, 3] """ if self._jrdd_deserializer == other._jrdd_deserializer: - rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, self._jrdd_deserializer) + rdd: "RDD[Union[T, U]]" = RDD( + self._jrdd.union(other._jrdd), self.ctx, self._jrdd_deserializer + ) else: # These RDDs contain data in different serialized formats, so we # must normalize them to the default serializer. @@ -708,7 +820,7 @@ def union(self, other): rdd.partitioner = self.partitioner return rdd - def intersection(self, other): + def intersection(self: "RDD[T]", other: "RDD[T]") -> "RDD[T]": """ Return the intersection of this RDD and another one. The output will not contain any duplicate elements, even if the input RDDs did. @@ -731,14 +843,14 @@ def intersection(self, other): .keys() ) - def _reserialize(self, serializer=None): + def _reserialize(self: "RDD[T]", serializer: Optional[Serializer] = None) -> "RDD[T]": serializer = serializer or self.ctx.serializer if self._jrdd_deserializer != serializer: self = self.map(lambda x: x, preservesPartitioning=True) self._jrdd_deserializer = serializer return self - def __add__(self, other): + def __add__(self: "RDD[T]", other: "RDD[U]") -> "RDD[Union[T, U]]": """ Return the union of this RDD and another one. @@ -752,9 +864,43 @@ def __add__(self, other): raise TypeError return self.union(other) + @overload def repartitionAndSortWithinPartitions( - self, numPartitions=None, partitionFunc=portable_hash, ascending=True, keyfunc=lambda x: x - ): + self: "RDD[Tuple[S, V]]", + numPartitions: Optional[int] = ..., + partitionFunc: Callable[["S"], int] = ..., + ascending: bool = ..., + ) -> "RDD[Tuple[S, V]]": + ... + + @overload + def repartitionAndSortWithinPartitions( + self: "RDD[Tuple[K, V]]", + numPartitions: Optional[int], + partitionFunc: Callable[[K], int], + ascending: bool, + keyfunc: Callable[[K], "S"], + ) -> "RDD[Tuple[K, V]]": + ... + + @overload + def repartitionAndSortWithinPartitions( + self: "RDD[Tuple[K, V]]", + numPartitions: Optional[int] = ..., + partitionFunc: Callable[[K], int] = ..., + ascending: bool = ..., + *, + keyfunc: Callable[[K], "S"], + ) -> "RDD[Tuple[K, V]]": + ... + + def repartitionAndSortWithinPartitions( + self: "RDD[Tuple[Any, Any]]", + numPartitions: Optional[int] = None, + partitionFunc: Callable[[Any], int] = portable_hash, + ascending: bool = True, + keyfunc: Callable[[Any], Any] = lambda x: x, + ) -> "RDD[Tuple[Any, Any]]": """ Repartition the RDD according to the given partitioner and, within each resulting partition, sort records by their keys. @@ -772,13 +918,45 @@ def repartitionAndSortWithinPartitions( memory = self._memory_limit() serializer = self._jrdd_deserializer - def sortPartition(iterator): + def sortPartition(iterator: Iterable[Tuple[K, V]]) -> Iterable[Tuple[K, V]]: sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) - def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): + @overload + def sortByKey( + self: "RDD[Tuple[S, V]]", + ascending: bool = ..., + numPartitions: Optional[int] = ..., + ) -> "RDD[Tuple[K, V]]": + ... + + @overload + def sortByKey( + self: "RDD[Tuple[K, V]]", + ascending: bool, + numPartitions: int, + keyfunc: Callable[[K], "S"], + ) -> "RDD[Tuple[K, V]]": + ... 
+ + @overload + def sortByKey( + self: "RDD[Tuple[K, V]]", + ascending: bool = ..., + numPartitions: Optional[int] = ..., + *, + keyfunc: Callable[[K], "S"], + ) -> "RDD[Tuple[K, V]]": + ... + + def sortByKey( + self: "RDD[Tuple[K, V]]", + ascending: Optional[bool] = True, + numPartitions: Optional[int] = None, + keyfunc: Callable[[Any], Any] = lambda x: x, + ) -> "RDD[Tuple[K, V]]": """ Sorts this RDD, which is assumed to consist of (key, value) pairs. @@ -802,7 +980,7 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): memory = self._memory_limit() serializer = self._jrdd_deserializer - def sortPartition(iterator): + def sortPartition(iterator: Iterable[Tuple[K, V]]) -> Iterable[Tuple[K, V]]: sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) @@ -829,16 +1007,21 @@ def sortPartition(iterator): for i in range(0, numPartitions - 1) ] - def rangePartitioner(k): + def rangePartitioner(k: K) -> int: p = bisect.bisect_left(bounds, keyfunc(k)) if ascending: return p else: - return numPartitions - 1 - p + return numPartitions - 1 - p # type: ignore[operator] return self.partitionBy(numPartitions, rangePartitioner).mapPartitions(sortPartition, True) - def sortBy(self, keyfunc, ascending=True, numPartitions=None): + def sortBy( + self: "RDD[T]", + keyfunc: Callable[[T], "S"], + ascending: bool = True, + numPartitions: Optional[int] = None, + ) -> "RDD[T]": """ Sorts this RDD by the given keyfunc @@ -850,9 +1033,13 @@ def sortBy(self, keyfunc, ascending=True, numPartitions=None): >>> sc.parallelize(tmp).sortBy(lambda x: x[1]).collect() [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] """ - return self.keyBy(keyfunc).sortByKey(ascending, numPartitions).values() + return ( + self.keyBy(keyfunc) # type: ignore[type-var] + .sortByKey(ascending, numPartitions) + .values() + ) - def glom(self): + def glom(self: "RDD[T]") -> "RDD[List[T]]": """ Return an RDD created by coalescing all elements within each partition into a list. @@ -864,12 +1051,12 @@ def glom(self): [[1, 2], [3, 4]] """ - def func(iterator): + def func(iterator: Iterable[T]) -> Iterable[List[T]]: yield list(iterator) return self.mapPartitions(func) - def cartesian(self, other): + def cartesian(self: "RDD[T]", other: "RDD[U]") -> "RDD[Tuple[T, U]]": """ Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of elements ``(a, b)`` where ``a`` is in `self` and @@ -885,7 +1072,12 @@ def cartesian(self, other): deserializer = CartesianDeserializer(self._jrdd_deserializer, other._jrdd_deserializer) return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) - def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash): + def groupBy( + self: "RDD[T]", + f: Callable[[T], K], + numPartitions: Optional[int] = None, + partitionFunc: Callable[[K], int] = portable_hash, + ) -> "RDD[Tuple[K, Iterable[T]]]": """ Return an RDD of grouped items. @@ -898,7 +1090,9 @@ def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash): """ return self.map(lambda x: (f(x), x)).groupByKey(numPartitions, partitionFunc) - def pipe(self, command, env=None, checkCode=False): + def pipe( + self, command: str, env: Optional[Dict[str, str]] = None, checkCode: bool = False + ) -> "RDD[str]": """ Return an RDD created by piping elements to a forked external process. 
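A call-site sketch for the `sortByKey`/`sortBy` overloads, including the keyword-only `keyfunc` form, assuming an existing `sc`:

rdd = sc.parallelize([("a", 3), ("B", 1), ("c", 2)])     # RDD[Tuple[str, int]]
by_key = rdd.sortByKey(ascending=True, keyfunc=lambda k: k.lower())
by_val = rdd.sortBy(lambda kv: kv[1])
print(by_key.collect())   # [('a', 3), ('B', 1), ('c', 2)]
print(by_val.collect())   # [('B', 1), ('c', 2), ('a', 3)]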
@@ -919,10 +1113,10 @@ def pipe(self, command, env=None, checkCode=False): if env is None: env = dict() - def func(iterator): + def func(iterator: Iterable[T]) -> Iterable[str]: pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) - def pipe_objs(out): + def pipe_objs(out: IO[bytes]) -> None: for obj in iterator: s = str(obj).rstrip("\n") + "\n" out.write(s.encode("utf-8")) @@ -930,7 +1124,7 @@ def pipe_objs(out): Thread(target=pipe_objs, args=[pipe.stdin]).start() - def check_return_code(): + def check_return_code() -> Iterable[int]: pipe.wait() if checkCode and pipe.returncode: raise RuntimeError( @@ -942,13 +1136,15 @@ def check_return_code(): yield i return ( - x.rstrip(b"\n").decode("utf-8") - for x in chain(iter(pipe.stdout.readline, b""), check_return_code()) + cast(bytes, x).rstrip(b"\n").decode("utf-8") + for x in chain( + iter(cast(IO[bytes], pipe.stdout).readline, b""), check_return_code() + ) ) return self.mapPartitions(func) - def foreach(self, f): + def foreach(self: "RDD[T]", f: Callable[[T], None]) -> None: """ Applies a function to all elements of this RDD. @@ -959,14 +1155,14 @@ def foreach(self, f): """ f = fail_on_stopiteration(f) - def processPartition(iterator): + def processPartition(iterator: Iterable[T]) -> Iterable[Any]: for x in iterator: f(x) return iter([]) self.mapPartitions(processPartition).count() # Force evaluation - def foreachPartition(self, f): + def foreachPartition(self: "RDD[T]", f: Callable[[Iterable[T]], None]) -> None: """ Applies a function to each partition of this RDD. @@ -978,16 +1174,16 @@ def foreachPartition(self, f): >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ - def func(it): + def func(it: Iterable[T]) -> Iterable[Any]: r = f(it) try: - return iter(r) + return iter(r) # type: ignore[call-overload] except TypeError: return iter([]) self.mapPartitions(func).count() # Force evaluation - def collect(self): + def collect(self: "RDD[T]") -> List[T]: """ Return a list that contains all of the elements in this RDD. @@ -997,10 +1193,13 @@ def collect(self): to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context): + assert self.ctx._jvm is not None sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) return list(_load_from_socket(sock_info, self._jrdd_deserializer)) - def collectWithJobGroup(self, groupId, description, interruptOnCancel=False): + def collectWithJobGroup( + self: "RDD[T]", groupId: str, description: str, interruptOnCancel: bool = False + ) -> "List[T]": """ When collect rdd, use this method to specify job group. @@ -1015,12 +1214,13 @@ def collectWithJobGroup(self, groupId, description, interruptOnCancel=False): ) with SCCallSiteSync(self.context): + assert self.ctx._jvm is not None sock_info = self.ctx._jvm.PythonRDD.collectAndServeWithJobGroup( self._jrdd.rdd(), groupId, description, interruptOnCancel ) return list(_load_from_socket(sock_info, self._jrdd_deserializer)) - def reduce(self, f): + def reduce(self: "RDD[T]", f: Callable[[T, T], T]) -> T: """ Reduces the elements of this RDD using the specified commutative and associative binary operator. Currently reduces partitions locally. 
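With the annotations above, `reduce` is `Callable[[T, T], T] -> T` and `collect` returns `List[T]`; a minimal sketch, assuming an existing `sc`:

nums = sc.parallelize([1, 2, 3, 4, 5])
total = nums.reduce(lambda a, b: a + b)   # int
as_list = nums.collect()                  # List[int]
nums.foreach(lambda x: None)              # returns None; the side effect runs on executors
print(total, as_list)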
@@ -1039,7 +1239,7 @@ def reduce(self, f): """ f = fail_on_stopiteration(f) - def func(iterator): + def func(iterator: Iterable[T]) -> Iterable[T]: iterator = iter(iterator) try: initial = next(iterator) @@ -1052,7 +1252,7 @@ def func(iterator): return reduce(f, vals) raise ValueError("Can not reduce() empty RDD") - def treeReduce(self, f, depth=2): + def treeReduce(self: "RDD[T]", f: Callable[[T, T], T], depth: int = 2) -> T: """ Reduces the elements of this RDD in a multi-level tree pattern. @@ -1080,9 +1280,13 @@ def treeReduce(self, f, depth=2): if depth < 1: raise ValueError("Depth cannot be smaller than 1 but got %d." % depth) - zeroValue = None, True # Use the second entry to indicate whether this is a dummy value. + # Use the second entry to indicate whether this is a dummy value. + zeroValue: Tuple[T, bool] = ( # type: ignore[assignment] + None, + True, + ) - def op(x, y): + def op(x: Tuple[T, bool], y: Tuple[T, bool]) -> Tuple[T, bool]: if x[1]: return y elif y[1]: @@ -1095,7 +1299,7 @@ def op(x, y): raise ValueError("Cannot reduce empty RDD.") return reduced[0] - def fold(self, zeroValue, op): + def fold(self: "RDD[T]", zeroValue: T, op: Callable[[T, T], T]) -> T: """ Aggregate the elements of each partition, and then the results for all the partitions, using a given associative function and a neutral "zero value." @@ -1120,7 +1324,7 @@ def fold(self, zeroValue, op): """ op = fail_on_stopiteration(op) - def func(iterator): + def func(iterator: Iterable[T]) -> Iterable[T]: acc = zeroValue for obj in iterator: acc = op(acc, obj) @@ -1132,7 +1336,9 @@ def func(iterator): vals = self.mapPartitions(func).collect() return reduce(op, vals, zeroValue) - def aggregate(self, zeroValue, seqOp, combOp): + def aggregate( + self: "RDD[T]", zeroValue: U, seqOp: Callable[[U, T], U], combOp: Callable[[U, U], U] + ) -> U: """ Aggregate the elements of each partition, and then the results for all the partitions, using a given combine functions and a neutral "zero @@ -1158,7 +1364,7 @@ def aggregate(self, zeroValue, seqOp, combOp): seqOp = fail_on_stopiteration(seqOp) combOp = fail_on_stopiteration(combOp) - def func(iterator): + def func(iterator: Iterable[T]) -> Iterable[U]: acc = zeroValue for obj in iterator: acc = seqOp(acc, obj) @@ -1170,7 +1376,13 @@ def func(iterator): vals = self.mapPartitions(func).collect() return reduce(combOp, vals, zeroValue) - def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): + def treeAggregate( + self: "RDD[T]", + zeroValue: U, + seqOp: Callable[[U, T], U], + combOp: Callable[[U, U], U], + depth: int = 2, + ) -> U: """ Aggregates the elements of this RDD in a multi-level tree pattern. @@ -1199,7 +1411,7 @@ def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): if self.getNumPartitions() == 0: return zeroValue - def aggregatePartition(iterator): + def aggregatePartition(iterator: Iterable[T]) -> Iterable[U]: acc = zeroValue for obj in iterator: acc = seqOp(acc, obj) @@ -1211,10 +1423,10 @@ def aggregatePartition(iterator): # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree # aggregation. 
while numPartitions > scale + numPartitions / scale: - numPartitions /= scale + numPartitions /= scale # type: ignore[assignment] curNumPartitions = int(numPartitions) - def mapPartition(i, iterator): + def mapPartition(i: int, iterator: Iterable[U]) -> Iterable[Tuple[int, U]]: for obj in iterator: yield (i % curNumPartitions, obj) @@ -1226,7 +1438,15 @@ def mapPartition(i, iterator): return partiallyAggregated.reduce(combOp) - def max(self, key=None): + @overload + def max(self: "RDD[S]") -> "S": + ... + + @overload + def max(self: "RDD[T]", key: Callable[[T], "S"]) -> T: + ... + + def max(self: "RDD[T]", key: Optional[Callable[[T], "S"]] = None) -> T: """ Find the maximum item in this RDD. @@ -1244,10 +1464,18 @@ def max(self, key=None): 5.0 """ if key is None: - return self.reduce(max) - return self.reduce(lambda a, b: max(a, b, key=key)) + return self.reduce(max) # type: ignore[arg-type] + return self.reduce(lambda a, b: max(a, b, key=key)) # type: ignore[arg-type] + + @overload + def min(self: "RDD[S]") -> "S": + ... - def min(self, key=None): + @overload + def min(self: "RDD[T]", key: Callable[[T], "S"]) -> T: + ... + + def min(self: "RDD[T]", key: Optional[Callable[[T], "S"]] = None) -> T: """ Find the minimum item in this RDD. @@ -1265,10 +1493,10 @@ def min(self, key=None): 10.0 """ if key is None: - return self.reduce(min) - return self.reduce(lambda a, b: min(a, b, key=key)) + return self.reduce(min) # type: ignore[arg-type] + return self.reduce(lambda a, b: min(a, b, key=key)) # type: ignore[arg-type] - def sum(self): + def sum(self: "RDD[NumberOrArray]") -> "NumberOrArray": """ Add up the elements in this RDD. @@ -1277,9 +1505,11 @@ def sum(self): >>> sc.parallelize([1.0, 2.0, 3.0]).sum() 6.0 """ - return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add) + return self.mapPartitions(lambda x: [sum(x)]).fold( # type: ignore[return-value] + 0, operator.add + ) - def count(self): + def count(self) -> int: """ Return the number of elements in this RDD. @@ -1290,18 +1520,22 @@ def count(self): """ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() - def stats(self): + def stats(self: "RDD[NumberOrArray]") -> StatCounter: """ Return a :class:`StatCounter` object that captures the mean, variance and count of the RDD's elements in one operation. """ - def redFunc(left_counter, right_counter): + def redFunc(left_counter: StatCounter, right_counter: StatCounter) -> StatCounter: return left_counter.mergeStats(right_counter) - return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc) + return self.mapPartitions(lambda i: [StatCounter(i)]).reduce( # type: ignore[arg-type] + redFunc + ) - def histogram(self, buckets): + def histogram( + self: "RDD[S]", buckets: Union[int, List["S"], Tuple["S", ...]] + ) -> Tuple[Sequence["S"], List[int]]: """ Compute a histogram using the provided buckets. The buckets are all open to the right except for the last which is closed. 
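The numeric helpers are now typed over `NumberOrArray`, and `max`/`min` carry key-function overloads; a sketch mirroring the doctests above, assuming an existing `sc`:

nums = sc.parallelize([1.0, 5.0, 43.0, 10.0])
print(nums.max(), nums.min())           # 43.0 1.0
print(nums.max(key=str))                # 5.0 (lexicographic key)
print(nums.sum(), nums.count())         # 59.0 4
print(nums.stats())                     # StatCounter with count/mean/stdev/min/max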
@@ -1345,7 +1579,7 @@ def histogram(self, buckets): raise ValueError("number of buckets must be >= 1") # filter out non-comparable elements - def comparable(x): + def comparable(x: Any) -> bool: if x is None: return False if type(x) is float and isnan(x): @@ -1355,7 +1589,7 @@ def comparable(x): filtered = self.filter(comparable) # faster than stats() - def minmax(a, b): + def minmax(a: Tuple["S", "S"], b: Tuple["S", "S"]) -> Tuple["S", "S"]: return min(a[0], b[0]), max(a[1], b[1]) try: @@ -1369,7 +1603,7 @@ def minmax(a, b): return [minv, maxv], [filtered.count()] try: - inc = (maxv - minv) / buckets + inc = (maxv - minv) / buckets # type: ignore[operator] except TypeError: raise TypeError("Can not generate buckets with non-number in RDD") @@ -1378,8 +1612,8 @@ def minmax(a, b): # keep them as integer if possible inc = int(inc) - if inc * buckets != maxv - minv: - inc = (maxv - minv) * 1.0 / buckets + if inc * buckets != maxv - minv: # type: ignore[operator] + inc = (maxv - minv) * 1.0 / buckets # type: ignore[operator] buckets = [i * inc + minv for i in range(buckets)] buckets.append(maxv) # fix accumulated error @@ -1403,35 +1637,42 @@ def minmax(a, b): even = False inc = None try: - steps = [buckets[i + 1] - buckets[i] for i in range(len(buckets) - 1)] + steps = [ + buckets[i + 1] - buckets[i] # type: ignore[operator] + for i in range(len(buckets) - 1) + ] except TypeError: pass # objects in buckets do not support '-' else: if max(steps) - min(steps) < 1e-10: # handle precision errors even = True - inc = (maxv - minv) / (len(buckets) - 1) + inc = (maxv - minv) / (len(buckets) - 1) # type: ignore[operator] else: raise TypeError("buckets should be a list or tuple or number(int or long)") - def histogram(iterator): - counters = [0] * len(buckets) + def histogram(iterator: Iterable["S"]) -> Iterable[List[int]]: + counters = [0] * len(buckets) # type: ignore[arg-type] for i in iterator: - if i is None or (type(i) is float and isnan(i)) or i > maxv or i < minv: + if i is None or (isinstance(i, float) and isnan(i)) or i > maxv or i < minv: continue - t = int((i - minv) / inc) if even else bisect.bisect_right(buckets, i) - 1 + t = ( + int((i - minv) / inc) # type: ignore[operator] + if even + else bisect.bisect_right(buckets, i) - 1 # type: ignore[arg-type] + ) counters[t] += 1 # add last two together last = counters.pop() counters[-1] += last return [counters] - def mergeCounters(a, b): + def mergeCounters(a: List[int], b: List[int]) -> List[int]: return [i + j for i, j in zip(a, b)] return buckets, self.mapPartitions(histogram).reduce(mergeCounters) - def mean(self): + def mean(self: "RDD[NumberOrArray]") -> "NumberOrArray": """ Compute the mean of this RDD's elements. @@ -1440,9 +1681,9 @@ def mean(self): >>> sc.parallelize([1, 2, 3]).mean() 2.0 """ - return self.stats().mean() + return self.stats().mean() # type: ignore[return-value] - def variance(self): + def variance(self: "RDD[NumberOrArray]") -> "NumberOrArray": """ Compute the variance of this RDD's elements. @@ -1451,9 +1692,9 @@ def variance(self): >>> sc.parallelize([1, 2, 3]).variance() 0.666... """ - return self.stats().variance() + return self.stats().variance() # type: ignore[return-value] - def stdev(self): + def stdev(self: "RDD[NumberOrArray]") -> "NumberOrArray": """ Compute the standard deviation of this RDD's elements. @@ -1462,9 +1703,9 @@ def stdev(self): >>> sc.parallelize([1, 2, 3]).stdev() 0.816... 
""" - return self.stats().stdev() + return self.stats().stdev() # type: ignore[return-value] - def sampleStdev(self): + def sampleStdev(self: "RDD[NumberOrArray]") -> "NumberOrArray": """ Compute the sample standard deviation of this RDD's elements (which corrects for bias in estimating the standard deviation by dividing by @@ -1475,9 +1716,9 @@ def sampleStdev(self): >>> sc.parallelize([1, 2, 3]).sampleStdev() 1.0 """ - return self.stats().sampleStdev() + return self.stats().sampleStdev() # type: ignore[return-value] - def sampleVariance(self): + def sampleVariance(self: "RDD[NumberOrArray]") -> "NumberOrArray": """ Compute the sample variance of this RDD's elements (which corrects for bias in estimating the variance by dividing by N-1 instead of N). @@ -1487,9 +1728,9 @@ def sampleVariance(self): >>> sc.parallelize([1, 2, 3]).sampleVariance() 1.0 """ - return self.stats().sampleVariance() + return self.stats().sampleVariance() # type: ignore[return-value] - def countByValue(self): + def countByValue(self: "RDD[K]") -> Dict[K, int]: """ Return the count of each unique value in this RDD as a dictionary of (value, count) pairs. @@ -1500,20 +1741,28 @@ def countByValue(self): [(1, 2), (2, 3)] """ - def countPartition(iterator): - counts = defaultdict(int) + def countPartition(iterator: Iterable[K]) -> Iterable[Dict[K, int]]: + counts: Dict[K, int] = defaultdict(int) for obj in iterator: counts[obj] += 1 yield counts - def mergeMaps(m1, m2): + def mergeMaps(m1: Dict[K, int], m2: Dict[K, int]) -> Dict[K, int]: for k, v in m2.items(): m1[k] += v return m1 return self.mapPartitions(countPartition).reduce(mergeMaps) - def top(self, num, key=None): + @overload + def top(self: "RDD[S]", num: int) -> List["S"]: + ... + + @overload + def top(self: "RDD[T]", num: int, key: Callable[[T], "S"]) -> List[T]: + ... + + def top(self: "RDD[T]", num: int, key: Optional[Callable[[T], "S"]] = None) -> List[T]: """ Get the top N elements from an RDD. @@ -1534,15 +1783,23 @@ def top(self, num, key=None): [4, 3, 2] """ - def topIterator(iterator): + def topIterator(iterator: Iterable[T]) -> Iterable[List[T]]: yield heapq.nlargest(num, iterator, key=key) - def merge(a, b): + def merge(a: List[T], b: List[T]) -> List[T]: return heapq.nlargest(num, a + b, key=key) return self.mapPartitions(topIterator).reduce(merge) - def takeOrdered(self, num, key=None): + @overload + def takeOrdered(self: "RDD[S]", num: int) -> List["S"]: + ... + + @overload + def takeOrdered(self: "RDD[T]", num: int, key: Callable[[T], "S"]) -> List[T]: + ... + + def takeOrdered(self: "RDD[T]", num: int, key: Optional[Callable[[T], "S"]] = None) -> List[T]: """ Get the N elements from an RDD ordered in ascending order or as specified by the optional key function. @@ -1560,12 +1817,12 @@ def takeOrdered(self, num, key=None): [10, 9, 7, 6, 5, 4] """ - def merge(a, b): + def merge(a: List[T], b: List[T]) -> List[T]: return heapq.nsmallest(num, a + b, key) return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge) - def take(self, num): + def take(self: "RDD[T]", num: int) -> List[T]: """ Take the first num elements of the RDD. 
@@ -1589,7 +1846,7 @@ def take(self, num): >>> sc.parallelize(range(100), 100).filter(lambda x: x > 90).take(3) [91, 92, 93] """ - items = [] + items: List[T] = [] totalParts = self.getNumPartitions() partsScanned = 0 @@ -1612,7 +1869,7 @@ def take(self, num): left = num - len(items) - def takeUpToNumLeft(iterator): + def takeUpToNumLeft(iterator: Iterable[T]) -> Iterable[T]: iterator = iter(iterator) taken = 0 while taken < left: @@ -1630,7 +1887,7 @@ def takeUpToNumLeft(iterator): return items[:num] - def first(self): + def first(self: "RDD[T]") -> T: """ Return the first element in this RDD. @@ -1648,7 +1905,7 @@ def first(self): return rs[0] raise ValueError("RDD is empty") - def isEmpty(self): + def isEmpty(self) -> bool: """ Returns true if and only if the RDD contains no elements at all. @@ -1665,7 +1922,12 @@ def isEmpty(self): """ return self.getNumPartitions() == 0 or len(self.take(1)) == 0 - def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): + def saveAsNewAPIHadoopDataset( + self: "RDD[Tuple[K, V]]", + conf: Dict[str, str], + keyConverter: Optional[str] = None, + valueConverter: Optional[str] = None, + ) -> None: """ Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file system, using the new Hadoop OutputFormat API (mapreduce package). Keys/values are @@ -1683,20 +1945,22 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._pickled() + assert self.ctx._jvm is not None + self.ctx._jvm.PythonRDD.saveAsHadoopDataset( pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, True ) def saveAsNewAPIHadoopFile( - self, - path, - outputFormatClass, - keyClass=None, - valueClass=None, - keyConverter=None, - valueConverter=None, - conf=None, - ): + self: "RDD[Tuple[K, V]]", + path: str, + outputFormatClass: str, + keyClass: Optional[str] = None, + valueClass: Optional[str] = None, + keyConverter: Optional[str] = None, + valueConverter: Optional[str] = None, + conf: Optional[Dict[str, str]] = None, + ) -> None: """ Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file system, using the new Hadoop OutputFormat API (mapreduce package). Key and value types @@ -1725,6 +1989,8 @@ def saveAsNewAPIHadoopFile( """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._pickled() + assert self.ctx._jvm is not None + self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile( pickledRDD._jrdd, True, @@ -1737,7 +2003,12 @@ def saveAsNewAPIHadoopFile( jconf, ) - def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): + def saveAsHadoopDataset( + self: "RDD[Tuple[K, V]]", + conf: Dict[str, str], + keyConverter: Optional[str] = None, + valueConverter: Optional[str] = None, + ) -> None: """ Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file system, using the old Hadoop OutputFormat API (mapred package). 
Keys/values are @@ -1755,21 +2026,23 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._pickled() + assert self.ctx._jvm is not None + self.ctx._jvm.PythonRDD.saveAsHadoopDataset( pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, False ) def saveAsHadoopFile( - self, - path, - outputFormatClass, - keyClass=None, - valueClass=None, - keyConverter=None, - valueConverter=None, - conf=None, - compressionCodecClass=None, - ): + self: "RDD[Tuple[K, V]]", + path: str, + outputFormatClass: str, + keyClass: Optional[str] = None, + valueClass: Optional[str] = None, + keyConverter: Optional[str] = None, + valueConverter: Optional[str] = None, + conf: Optional[Dict[str, str]] = None, + compressionCodecClass: Optional[str] = None, + ) -> None: """ Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file system, using the old Hadoop OutputFormat API (mapred package). Key and value types @@ -1803,6 +2076,8 @@ def saveAsHadoopFile( """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._pickled() + assert self.ctx._jvm is not None + self.ctx._jvm.PythonRDD.saveAsHadoopFile( pickledRDD._jrdd, True, @@ -1816,7 +2091,9 @@ def saveAsHadoopFile( compressionCodecClass, ) - def saveAsSequenceFile(self, path, compressionCodecClass=None): + def saveAsSequenceFile( + self: "RDD[Tuple[K, V]]", path: str, compressionCodecClass: Optional[str] = None + ) -> None: """ Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file system, using the "org.apache.hadoop.io.Writable" types that we convert from the @@ -1834,11 +2111,13 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): i.e. "org.apache.hadoop.io.compress.GzipCodec" (None by default) """ pickledRDD = self._pickled() + assert self.ctx._jvm is not None + self.ctx._jvm.PythonRDD.saveAsSequenceFile( pickledRDD._jrdd, True, path, compressionCodecClass ) - def saveAsPickleFile(self, path, batchSize=10): + def saveAsPickleFile(self, path: str, batchSize: int = 10) -> None: """ Save this RDD as a SequenceFile of serialized objects. The serializer used is :class:`pyspark.serializers.CPickleSerializer`, default batch size @@ -1853,13 +2132,14 @@ def saveAsPickleFile(self, path, batchSize=10): >>> sorted(sc.pickleFile(tmpFile.name, 5).map(str).collect()) ['1', '2', 'rdd', 'spark'] """ + ser: Serializer if batchSize == 0: ser = AutoBatchedSerializer(CPickleSerializer()) else: ser = BatchedSerializer(CPickleSerializer(), batchSize) self._reserialize(ser)._jrdd.saveAsObjectFile(path) - def saveAsTextFile(self, path, compressionCodecClass=None): + def saveAsTextFile(self, path: str, compressionCodecClass: Optional[str] = None) -> None: """ Save this RDD as a text file, using string representations of elements. 
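The reworked `saveAsTextFile` writer encodes `str`, `bytes`, and everything else explicitly before handing off to the JVM; a sketch following the doctest's temp-file pattern, assuming an existing `sc`:

from tempfile import NamedTemporaryFile

tmp = NamedTemporaryFile(delete=True)
tmp.close()                               # free the name; Spark creates the directory itself
sc.parallelize(["foo", b"bar", 42]).saveAsTextFile(tmp.name)
print(sorted(sc.textFile(tmp.name).collect()))   # ['42', 'bar', 'foo']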
@@ -1904,16 +2184,20 @@ def saveAsTextFile(self, path, compressionCodecClass=None): 'bar\\nfoo\\n' """ - def func(split, iterator): + def func(split: int, iterator: Iterable[Any]) -> Iterable[bytes]: for x in iterator: - if not isinstance(x, (str, bytes)): - x = str(x) - if isinstance(x, str): - x = x.encode("utf-8") - yield x + if isinstance(x, bytes): + yield x + elif isinstance(x, str): + yield x.encode("utf-8") + else: + yield str(x).encode("utf-8") keyed = self.mapPartitionsWithIndex(func) - keyed._bypass_serializer = True + keyed._bypass_serializer = True # type: ignore[attr-defined] + + assert self.ctx._jvm is not None + if compressionCodecClass: compressionCodec = self.ctx._jvm.java.lang.Class.forName(compressionCodecClass) keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path, compressionCodec) @@ -1922,7 +2206,7 @@ def func(split, iterator): # Pair functions - def collectAsMap(self): + def collectAsMap(self: "RDD[Tuple[K, V]]") -> Dict[K, V]: """ Return the key-value pairs in this RDD to the master as a dictionary. @@ -1941,7 +2225,7 @@ def collectAsMap(self): """ return dict(self.collect()) - def keys(self): + def keys(self: "RDD[Tuple[K, V]]") -> "RDD[K]": """ Return an RDD with the keys of each tuple. @@ -1953,7 +2237,7 @@ def keys(self): """ return self.map(lambda x: x[0]) - def values(self): + def values(self: "RDD[Tuple[K, V]]") -> "RDD[V]": """ Return an RDD with the values of each tuple. @@ -1965,7 +2249,12 @@ def values(self): """ return self.map(lambda x: x[1]) - def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): + def reduceByKey( + self: "RDD[Tuple[K, V]]", + func: Callable[[V, V], V], + numPartitions: Optional[int] = None, + partitionFunc: Callable[[K], int] = portable_hash, + ) -> "RDD[Tuple[K, V]]": """ Merge the values for each key using an associative and commutative reduce function. @@ -1985,7 +2274,7 @@ def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): """ return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc) - def reduceByKeyLocally(self, func): + def reduceByKeyLocally(self: "RDD[Tuple[K, V]]", func: Callable[[V, V], V]) -> Dict[K, V]: """ Merge the values for each key using an associative and commutative reduce function, but return the results immediately to the master as a dictionary. @@ -2002,20 +2291,20 @@ def reduceByKeyLocally(self, func): """ func = fail_on_stopiteration(func) - def reducePartition(iterator): - m = {} + def reducePartition(iterator: Iterable[Tuple[K, V]]) -> Iterable[Dict[K, V]]: + m: Dict[K, V] = {} for k, v in iterator: m[k] = func(m[k], v) if k in m else v yield m - def mergeMaps(m1, m2): + def mergeMaps(m1: Dict[K, V], m2: Dict[K, V]) -> Dict[K, V]: for k, v in m2.items(): m1[k] = func(m1[k], v) if k in m1 else v return m1 return self.mapPartitions(reducePartition).reduce(mergeMaps) - def countByKey(self): + def countByKey(self: "RDD[Tuple[K, V]]") -> Dict[K, int]: """ Count the number of elements for each key, and return the result to the master as a dictionary. @@ -2028,7 +2317,11 @@ def countByKey(self): """ return self.map(lambda x: x[0]).countByValue() - def join(self, other, numPartitions=None): + def join( + self: "RDD[Tuple[K, V]]", + other: "RDD[Tuple[K, U]]", + numPartitions: Optional[int] = None, + ) -> "RDD[Tuple[K, Tuple[V, U]]]": """ Return an RDD containing all pairs of elements with matching keys in `self` and `other`. 
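The pair-RDD helpers are now typed over `RDD[Tuple[K, V]]`, so `keys`/`values`/`join` keep their element types; a sketch, assuming an existing `sc`:

x = sc.parallelize([("a", 1), ("b", 4), ("a", 2)])
y = sc.parallelize([("a", "left"), ("c", "right")])
print(x.reduceByKey(lambda a, b: a + b).collectAsMap())   # {'a': 3, 'b': 4}
print(sorted(x.join(y).collect()))                        # [('a', (1, 'left')), ('a', (2, 'left'))]
print(x.countByKey()["a"], sorted(x.keys().collect()))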
@@ -2047,7 +2340,11 @@ def join(self, other, numPartitions=None): """ return python_join(self, other, numPartitions) - def leftOuterJoin(self, other, numPartitions=None): + def leftOuterJoin( + self: "RDD[Tuple[K, V]]", + other: "RDD[Tuple[K, U]]", + numPartitions: Optional[int] = None, + ) -> "RDD[Tuple[K, Tuple[V, Optional[U]]]]": """ Perform a left outer join of `self` and `other`. @@ -2066,7 +2363,11 @@ def leftOuterJoin(self, other, numPartitions=None): """ return python_left_outer_join(self, other, numPartitions) - def rightOuterJoin(self, other, numPartitions=None): + def rightOuterJoin( + self: "RDD[Tuple[K, V]]", + other: "RDD[Tuple[K, U]]", + numPartitions: Optional[int] = None, + ) -> "RDD[Tuple[K, Tuple[Optional[V], U]]]": """ Perform a right outer join of `self` and `other`. @@ -2085,7 +2386,11 @@ def rightOuterJoin(self, other, numPartitions=None): """ return python_right_outer_join(self, other, numPartitions) - def fullOuterJoin(self, other, numPartitions=None): + def fullOuterJoin( + self: "RDD[Tuple[K, V]]", + other: "RDD[Tuple[K, U]]", + numPartitions: Optional[int] = None, + ) -> "RDD[Tuple[K, Tuple[Optional[V], Optional[U]]]]": """ Perform a right outer join of `self` and `other`. @@ -2111,7 +2416,11 @@ def fullOuterJoin(self, other, numPartitions=None): # TODO: add option to control map-side combining # portable_hash is used as default, because builtin hash of None is different # cross machines. - def partitionBy(self, numPartitions, partitionFunc=portable_hash): + def partitionBy( + self: "RDD[Tuple[K, V]]", + numPartitions: Optional[int], + partitionFunc: Callable[[K], int] = portable_hash, + ) -> "RDD[Tuple[K, V]]": """ Return a copy of the RDD partitioned using the specified partitioner. @@ -2138,13 +2447,13 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): limit = self._memory_limit() / 2 - def add_shuffle_key(split, iterator): + def add_shuffle_key(split: int, iterator: Iterable[Tuple[K, V]]) -> Iterable[bytes]: buckets = defaultdict(list) - c, batch = 0, min(10 * numPartitions, 1000) + c, batch = 0, min(10 * numPartitions, 1000) # type: ignore[operator] for k, v in iterator: - buckets[partitionFunc(k) % numPartitions].append((k, v)) + buckets[partitionFunc(k) % numPartitions].append((k, v)) # type: ignore[operator] c += 1 # check used memory and avg size of chunk of objects @@ -2160,7 +2469,7 @@ def add_shuffle_key(split, iterator): avg = int(size / n) >> 20 # let 1M < avg < 10M if avg < 1: - batch = min(sys.maxsize, batch * 1.5) + batch = min(sys.maxsize, batch * 1.5) # type: ignore[assignment] elif avg > 10: batch = max(int(batch / 1.5), 1) c = 0 @@ -2170,24 +2479,26 @@ def add_shuffle_key(split, iterator): yield outputSerializer.dumps(items) keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True) - keyed._bypass_serializer = True + keyed._bypass_serializer = True # type: ignore[attr-defined] + assert self.ctx._jvm is not None + with SCCallSiteSync(self.context): pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions, id(partitionFunc)) jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner)) - rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) + rdd: "RDD[Tuple[K, V]]" = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) rdd.partitioner = partitioner return rdd # TODO: add control over map-side aggregation def combineByKey( - self, - createCombiner, - mergeValue, - mergeCombiners, - 
numPartitions=None, - partitionFunc=portable_hash, - ): + self: "RDD[Tuple[K, V]]", + createCombiner: Callable[[V], U], + mergeValue: Callable[[U, V], U], + mergeCombiners: Callable[[U, U], U], + numPartitions: Optional[int] = None, + partitionFunc: Callable[[K], int] = portable_hash, + ) -> "RDD[Tuple[K, U]]": """ Generic function to combine the elements for each key using a custom set of aggregation functions. @@ -2238,7 +2549,7 @@ def combineByKey( memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) - def combineLocally(iterator): + def combineLocally(iterator: Iterable[Tuple[K, V]]) -> Iterable[Tuple[K, U]]: merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -2246,7 +2557,7 @@ def combineLocally(iterator): locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) - def _mergeCombiners(iterator): + def _mergeCombiners(iterator: Iterable[Tuple[K, U]]) -> Iterable[Tuple[K, U]]: merger = ExternalMerger(agg, memory, serializer) merger.mergeCombiners(iterator) return merger.items() @@ -2254,8 +2565,13 @@ def _mergeCombiners(iterator): return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) def aggregateByKey( - self, zeroValue, seqFunc, combFunc, numPartitions=None, partitionFunc=portable_hash - ): + self: "RDD[Tuple[K, V]]", + zeroValue: U, + seqFunc: Callable[[U, V], U], + combFunc: Callable[[U, U], U], + numPartitions: Optional[int] = None, + partitionFunc: Callable[[K], int] = portable_hash, + ) -> "RDD[Tuple[K, U]]": """ Aggregate the values of each key, using given combine functions and a neutral "zero value". This function can return a different result type, U, than the type @@ -2266,14 +2582,20 @@ def aggregateByKey( allowed to modify and return their first argument instead of creating a new U. """ - def createZero(): + def createZero() -> U: return copy.deepcopy(zeroValue) return self.combineByKey( lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions, partitionFunc ) - def foldByKey(self, zeroValue, func, numPartitions=None, partitionFunc=portable_hash): + def foldByKey( + self: "RDD[Tuple[K, V]]", + zeroValue: V, + func: Callable[[V, V], V], + numPartitions: Optional[int] = None, + partitionFunc: Callable[[K], int] = portable_hash, + ) -> "RDD[Tuple[K, V]]": """ Merge the values for each key using an associative function "func" and a neutral "zeroValue" which may be added to the result an @@ -2288,18 +2610,22 @@ def foldByKey(self, zeroValue, func, numPartitions=None, partitionFunc=portable_ [('a', 2), ('b', 1)] """ - def createZero(): + def createZero() -> V: return copy.deepcopy(zeroValue) return self.combineByKey( lambda v: func(createZero(), v), func, func, numPartitions, partitionFunc ) - def _memory_limit(self): + def _memory_limit(self) -> int: return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) # TODO: support variant with custom partitioner - def groupByKey(self, numPartitions=None, partitionFunc=portable_hash): + def groupByKey( + self: "RDD[Tuple[K, V]]", + numPartitions: Optional[int] = None, + partitionFunc: Callable[[K], int] = portable_hash, + ) -> "RDD[Tuple[K, Iterable[V]]]": """ Group the values for each key in the RDD into a single sequence. Hash-partitions the resulting RDD with numPartitions partitions. 
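`combineByKey`'s three functions are now typed `createCombiner: (V) -> U`, `mergeValue: (U, V) -> U`, `mergeCombiners: (U, U) -> U`; a (sum, count) sketch, assuming an existing `sc`:

pairs = sc.parallelize([("a", 1), ("b", 1), ("a", 2)])
sum_count = pairs.combineByKey(
    lambda v: (v, 1),                          # createCombiner: V -> U
    lambda acc, v: (acc[0] + v, acc[1] + 1),   # mergeValue: (U, V) -> U
    lambda a, b: (a[0] + b[0], a[1] + b[1]),   # mergeCombiners: (U, U) -> U
)
print(sorted(sum_count.collect()))             # [('a', (3, 2)), ('b', (1, 1))]
print(sorted(pairs.foldByKey(0, lambda a, b: a + b).collect()))   # [('a', 3), ('b', 1)]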
@@ -2319,14 +2645,14 @@ def groupByKey(self, numPartitions=None, partitionFunc=portable_hash): [('a', [1, 1]), ('b', [1])] """ - def createCombiner(x): + def createCombiner(x: V) -> List[V]: return [x] - def mergeValue(xs, x): + def mergeValue(xs: List[V], x: V) -> List[V]: xs.append(x) return xs - def mergeCombiners(a, b): + def mergeCombiners(a: List[V], b: List[V]) -> List[V]: a.extend(b) return a @@ -2334,7 +2660,7 @@ def mergeCombiners(a, b): serializer = self._jrdd_deserializer agg = Aggregator(createCombiner, mergeValue, mergeCombiners) - def combine(iterator): + def combine(iterator: Iterable[Tuple[K, V]]) -> Iterable[Tuple[K, List[V]]]: merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -2342,14 +2668,16 @@ def combine(iterator): locally_combined = self.mapPartitions(combine, preservesPartitioning=True) shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) - def groupByKey(it): + def groupByKey(it: Iterable[Tuple[K, List[V]]]) -> Iterable[Tuple[K, List[V]]]: merger = ExternalGroupBy(agg, memory, serializer) merger.mergeCombiners(it) return merger.items() return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable) - def flatMapValues(self, f): + def flatMapValues( + self: "RDD[Tuple[K, V]]", f: Callable[[V], Iterable[U]] + ) -> "RDD[Tuple[K, U]]": """ Pass each value in the key-value pair RDD through a flatMap function without changing the keys; this also retains the original RDD's @@ -2362,10 +2690,13 @@ def flatMapValues(self, f): >>> x.flatMapValues(f).collect() [('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')] """ - flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1])) + + def flat_map_fn(kv: Tuple[K, V]) -> Iterable[Tuple[K, U]]: + return ((kv[0], x) for x in f(kv[1])) + return self.flatMap(flat_map_fn, preservesPartitioning=True) - def mapValues(self, f): + def mapValues(self: "RDD[Tuple[K, V]]", f: Callable[[V], U]) -> "RDD[Tuple[K, U]]": """ Pass each value in the key-value pair RDD through a map function without changing the keys; this also retains the original RDD's @@ -2378,10 +2709,46 @@ def mapValues(self, f): >>> x.mapValues(f).collect() [('a', 3), ('b', 1)] """ - map_values_fn = lambda kv: (kv[0], f(kv[1])) + + def map_values_fn(kv: Tuple[K, V]) -> Tuple[K, U]: + return kv[0], f(kv[1]) + return self.map(map_values_fn, preservesPartitioning=True) - def groupWith(self, other, *others): + @overload + def groupWith( + self: "RDD[Tuple[K, V]]", other: "RDD[Tuple[K, V1]]" + ) -> "RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[V1]]]]": + ... + + @overload + def groupWith( + self: "RDD[Tuple[K, V]]", other: "RDD[Tuple[K, V1]]", __o1: "RDD[Tuple[K, V2]]" + ) -> "RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[V1], ResultIterable[V2]]]]": + ... + + @overload + def groupWith( + self: "RDD[Tuple[K, V]]", + other: "RDD[Tuple[K, V1]]", + _o1: "RDD[Tuple[K, V2]]", + _o2: "RDD[Tuple[K, V3]]", + ) -> """RDD[ + Tuple[ + K, + Tuple[ + ResultIterable[V], + ResultIterable[V1], + ResultIterable[V2], + ResultIterable[V3], + ], + ] + ]""": + ... + + def groupWith( # type: ignore[misc] + self: "RDD[Tuple[Any, Any]]", other: "RDD[Tuple[Any, Any]]", *others: "RDD[Tuple[Any, Any]]" + ) -> "RDD[Tuple[Any, Tuple[ResultIterable[Any], ...]]]": """ Alias for cogroup but with support for multiple RDDs. 
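A short sketch of the groupWith overloads just added (illustration only): each overload ties the arity of the result tuple to the number of RDDs passed in, while the catch-all implementation stays Any-typed. Master URL and app name are assumptions.

from pyspark import SparkContext

sc = SparkContext("local[2]", "groupWith-overload-sketch")
w = sc.parallelize([("a", 5), ("b", 6)])
x = sc.parallelize([("a", 1), ("b", 4)])
y = sc.parallelize([("a", 2)])

# Two extra RDDs -> the three-RDD overload: each value is a 3-tuple of ResultIterables.
grouped = w.groupWith(x, y)
print(sorted((k, tuple(map(list, vs))) for k, vs in grouped.collect()))
# [('a', ([5], [1], [2])), ('b', ([6], [4], []))]
sc.stop()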
@@ -2398,7 +2765,11 @@ def groupWith(self, other, *others): return python_cogroup((self, other) + others, numPartitions=None) # TODO: add variant with custom partitioner - def cogroup(self, other, numPartitions=None): + def cogroup( + self: "RDD[Tuple[K, V]]", + other: "RDD[Tuple[K, U]]", + numPartitions: Optional[int] = None, + ) -> "RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[U]]]]": """ For each key k in `self` or `other`, return a resulting RDD that contains a tuple with the list of values for that key in `self` as @@ -2413,7 +2784,12 @@ def cogroup(self, other, numPartitions=None): """ return python_cogroup((self, other), numPartitions) - def sampleByKey(self, withReplacement, fractions, seed=None): + def sampleByKey( + self: "RDD[Tuple[K, V]]", + withReplacement: bool, + fractions: Dict[K, Union[float, int]], + seed: Optional[int] = None, + ) -> "RDD[Tuple[K, V]]": """ Return a subset of this RDD sampled by key (via stratified sampling). Create a sample of this RDD using variable sampling rates for @@ -2437,7 +2813,11 @@ def sampleByKey(self, withReplacement, fractions, seed=None): RDDStratifiedSampler(withReplacement, fractions, seed).func, True ) - def subtractByKey(self, other, numPartitions=None): + def subtractByKey( + self: "RDD[Tuple[K, V]]", + other: "RDD[Tuple[K, Any]]", + numPartitions: Optional[int] = None, + ) -> "RDD[Tuple[K, V]]": """ Return each (key, value) pair in `self` that has no pair with matching key in `other`. @@ -2450,13 +2830,17 @@ def subtractByKey(self, other, numPartitions=None): [('b', 4), ('b', 5)] """ - def filter_func(pair): + def filter_func(pair: Tuple[K, Tuple[V, Any]]) -> bool: key, (val1, val2) = pair - return val1 and not val2 + return val1 and not val2 # type: ignore[return-value] - return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0]) + return ( + self.cogroup(other, numPartitions) + .filter(filter_func) # type: ignore[arg-type] + .flatMapValues(lambda x: x[0]) + ) - def subtract(self, other, numPartitions=None): + def subtract(self: "RDD[T]", other: "RDD[T]", numPartitions: Optional[int] = None) -> "RDD[T]": """ Return each value in `self` that is not contained in `other`. @@ -2471,7 +2855,7 @@ def subtract(self, other, numPartitions=None): rdd = other.map(lambda x: (x, True)) return self.map(lambda x: (x, True)).subtractByKey(rdd, numPartitions).keys() - def keyBy(self, f): + def keyBy(self: "RDD[T]", f: Callable[[T], K]) -> "RDD[Tuple[K, T]]": """ Creates tuples of the elements in this RDD by applying `f`. @@ -2484,7 +2868,7 @@ def keyBy(self, f): """ return self.map(lambda x: (f(x), x)) - def repartition(self, numPartitions): + def repartition(self: "RDD[T]", numPartitions: int) -> "RDD[T]": """ Return a new RDD that has exactly numPartitions partitions. @@ -2505,7 +2889,7 @@ def repartition(self, numPartitions): """ return self.coalesce(numPartitions, shuffle=True) - def coalesce(self, numPartitions, shuffle=False): + def coalesce(self: "RDD[T]", numPartitions: int, shuffle: bool = False) -> "RDD[T]": """ Return a new RDD that is reduced into `numPartitions` partitions. @@ -2529,7 +2913,7 @@ def coalesce(self, numPartitions, shuffle=False): jrdd = self._jrdd.coalesce(numPartitions, shuffle) return RDD(jrdd, self.ctx, jrdd_deserializer) - def zip(self, other): + def zip(self: "RDD[T]", other: "RDD[U]") -> "RDD[Tuple[T, U]]": """ Zips this RDD with another one, returning key-value pairs with the first element in each RDD second element in each RDD, etc. 
Assumes @@ -2545,12 +2929,12 @@ def zip(self, other): [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] """ - def get_batch_size(ser): + def get_batch_size(ser: Serializer) -> int: if isinstance(ser, BatchedSerializer): return ser.batchSize return 1 # not batched - def batch_as(rdd, batchSize): + def batch_as(rdd: "RDD[V]", batchSize: int) -> "RDD[V]": return rdd._reserialize(BatchedSerializer(CPickleSerializer(), batchSize)) my_batch = get_batch_size(self._jrdd_deserializer) @@ -2573,7 +2957,7 @@ def batch_as(rdd, batchSize): deserializer = PairDeserializer(self._jrdd_deserializer, other._jrdd_deserializer) return RDD(pairRDD, self.ctx, deserializer) - def zipWithIndex(self): + def zipWithIndex(self: "RDD[T]") -> "RDD[Tuple[T, int]]": """ Zips this RDD with its element indices. @@ -2596,13 +2980,13 @@ def zipWithIndex(self): for i in range(len(nums) - 1): starts.append(starts[-1] + nums[i]) - def func(k, it): + def func(k: int, it: Iterable[T]) -> Iterable[Tuple[T, int]]: for i, v in enumerate(it, starts[k]): yield v, i return self.mapPartitionsWithIndex(func) - def zipWithUniqueId(self): + def zipWithUniqueId(self: "RDD[T]") -> "RDD[Tuple[T, int]]": """ Zips this RDD with generated unique Long ids. @@ -2618,21 +3002,20 @@ def zipWithUniqueId(self): """ n = self.getNumPartitions() - def func(k, it): + def func(k: int, it: Iterable[T]) -> Iterable[Tuple[T, int]]: for i, v in enumerate(it): yield v, i * n + k return self.mapPartitionsWithIndex(func) - def name(self): + def name(self) -> Optional[str]: """ Return the name of this RDD. """ n = self._jrdd.name() - if n: - return n + return n if n else None - def setName(self, name): + def setName(self: "RDD[T]", name: str) -> "RDD[T]": """ Assign a name to this RDD. @@ -2645,15 +3028,15 @@ def setName(self, name): self._jrdd.setName(name) return self - def toDebugString(self): + def toDebugString(self) -> Optional[bytes]: """ A description of this RDD and its recursive dependencies for debugging. """ debug_string = self._jrdd.toDebugString() - if debug_string: - return debug_string.encode("utf-8") - def getStorageLevel(self): + return debug_string.encode("utf-8") if debug_string else None + + def getStorageLevel(self) -> StorageLevel: """ Get the RDD's current storage level. @@ -2675,7 +3058,7 @@ def getStorageLevel(self): ) return storage_level - def _defaultReducePartitions(self): + def _defaultReducePartitions(self) -> int: """ Returns the default number of partitions to use during reduce tasks (e.g., groupBy). If spark.default.parallelism is set, then we'll use the value from SparkContext @@ -2690,7 +3073,7 @@ def _defaultReducePartitions(self): else: return self.getNumPartitions() - def lookup(self, key): + def lookup(self: "RDD[Tuple[K, V]]", key: K) -> List[V]: """ Return the list of values in the RDD for key `key`. This operation is done efficiently if the RDD has a known partitioner by only @@ -2718,16 +3101,18 @@ def lookup(self, key): return values.collect() - def _to_java_object_rdd(self): + def _to_java_object_rdd(self) -> "JavaObject": """Return a JavaRDD of Object by unpickling It will convert each Python object into Java object by Pickle, whenever the RDD is serialized in batch or not. 
""" rdd = self._pickled() + assert self.ctx._jvm is not None + return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True) - def countApprox(self, timeout, confidence=0.95): + def countApprox(self, timeout: int, confidence: float = 0.95) -> int: """ Approximate version of count() that returns a potentially incomplete result within a timeout, even if not all tasks have finished. @@ -2741,7 +3126,9 @@ def countApprox(self, timeout, confidence=0.95): drdd = self.mapPartitions(lambda it: [float(sum(1 for i in it))]) return int(drdd.sumApprox(timeout, confidence)) - def sumApprox(self, timeout, confidence=0.95): + def sumApprox( + self: "RDD[Union[float, int]]", timeout: int, confidence: float = 0.95 + ) -> BoundedFloat: """ Approximate operation to return the sum within a timeout or meet the confidence. @@ -2754,11 +3141,14 @@ def sumApprox(self, timeout, confidence=0.95): True """ jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd() + assert self.ctx._jvm is not None jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd()) r = jdrdd.sumApprox(timeout, confidence).getFinalValue() return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high()) - def meanApprox(self, timeout, confidence=0.95): + def meanApprox( + self: "RDD[Union[float, int]]", timeout: int, confidence: float = 0.95 + ) -> BoundedFloat: """ Approximate operation to return the mean within a timeout or meet the confidence. @@ -2771,11 +3161,12 @@ def meanApprox(self, timeout, confidence=0.95): True """ jrdd = self.map(float)._to_java_object_rdd() + assert self.ctx._jvm is not None jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd()) r = jdrdd.meanApprox(timeout, confidence).getFinalValue() return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high()) - def countApproxDistinct(self, relativeSD=0.05): + def countApproxDistinct(self: "RDD[T]", relativeSD: float = 0.05) -> int: """ Return approximate number of distinct elements in the RDD. @@ -2808,7 +3199,7 @@ def countApproxDistinct(self, relativeSD=0.05): hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF) return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD) - def toLocalIterator(self, prefetchPartitions=False): + def toLocalIterator(self: "RDD[T]", prefetchPartitions: bool = False) -> Iterator[T]: """ Return an iterator that contains all of the elements in this RDD. The iterator will consume as much memory as the largest partition in this RDD. @@ -2826,13 +3217,15 @@ def toLocalIterator(self, prefetchPartitions=False): >>> [x for x in rdd.toLocalIterator()] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ + assert self.ctx._jvm is not None + with SCCallSiteSync(self.context): sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe( self._jrdd.rdd(), prefetchPartitions ) return _local_iterator_from_socket(sock_info, self._jrdd_deserializer) - def barrier(self): + def barrier(self: "RDD[T]") -> "RDDBarrier[T]": """ Marks the current stage as a barrier stage, where Spark must launch all tasks together. In case of a task failure, instead of only restarting the failed task, Spark will abort the @@ -2862,13 +3255,13 @@ def barrier(self): """ return RDDBarrier(self) - def _is_barrier(self): + def _is_barrier(self) -> bool: """ Whether this RDD is in a barrier stage. """ return self._jrdd.rdd().isBarrier() - def withResources(self, profile): + def withResources(self: "RDD[T]", profile: ResourceProfile) -> "RDD[T]": """ Specify a :class:`pyspark.resource.ResourceProfile` to use when calculating this RDD. 
This is only supported on certain cluster managers and currently requires dynamic @@ -2885,6 +3278,8 @@ def withResources(self, profile): if profile._java_resource_profile is not None: jrp = profile._java_resource_profile else: + assert self.ctx._jvm is not None + builder = self.ctx._jvm.org.apache.spark.resource.ResourceProfileBuilder() ereqs = ExecutorResourceRequests(self.ctx._jvm, profile._executor_resource_requests) treqs = TaskResourceRequests(self.ctx._jvm, profile._task_resource_requests) @@ -2895,7 +3290,7 @@ def withResources(self, profile): self._jrdd.withResources(jrp) return self - def getResourceProfile(self): + def getResourceProfile(self) -> Optional[ResourceProfile]: """ Get the :class:`pyspark.resource.ResourceProfile` specified with this RDD or None if it wasn't specified. @@ -2917,11 +3312,38 @@ def getResourceProfile(self): else: return None + @overload + def toDF( + self: "RDD[RowLike]", + schema: Optional[Union[List[str], Tuple[str, ...]]] = None, + sampleRatio: Optional[float] = None, + ) -> "DataFrame": + ... + + @overload + def toDF( + self: "RDD[RowLike]", schema: Optional[Union["StructType", str]] = None + ) -> "DataFrame": + ... + + @overload + def toDF( + self: "RDD[AtomicValue]", + schema: Union["AtomicType", str], + ) -> "DataFrame": + ... -def _prepare_for_python_RDD(sc, command): + def toDF( + self: "RDD[Any]", schema: Optional[Any] = None, sampleRatio: Optional[float] = None + ) -> "DataFrame": + raise RuntimeError("""RDD.toDF was called before SparkSession was initialized.""") + + +def _prepare_for_python_RDD(sc: "SparkContext", command: Any) -> Tuple[bytes, Any, Any, Any]: # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() pickled_command = ser.dumps(command) + assert sc._jvm is not None if len(pickled_command) > sc._jvm.PythonUtils.getBroadcastThreshold(sc._jsc): # Default 1M # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) @@ -2931,11 +3353,14 @@ def _prepare_for_python_RDD(sc, command): return pickled_command, broadcast_vars, sc.environment, sc._python_includes -def _wrap_function(sc, func, deserializer, serializer, profiler=None): +def _wrap_function( + sc: "SparkContext", func: Callable, deserializer: Any, serializer: Any, profiler: Any = None +) -> "JavaObject": assert deserializer, "deserializer should not be empty" assert serializer, "serializer should not be empty" command = (func, profiler, deserializer, serializer) pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + assert sc._jvm is not None return sc._jvm.PythonFunction( bytearray(pickled_command), env, @@ -2947,7 +3372,7 @@ def _wrap_function(sc, func, deserializer, serializer, profiler=None): ) -class RDDBarrier: +class RDDBarrier(Generic[T]): """ Wraps an RDD in a barrier stage, which forces Spark to launch tasks of this stage together. @@ -2960,10 +3385,12 @@ class RDDBarrier: This API is experimental """ - def __init__(self, rdd): + def __init__(self, rdd: RDD[T]): self.rdd = rdd - def mapPartitions(self, f, preservesPartitioning=False): + def mapPartitions( + self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = False + ) -> RDD[U]: """ Returns a new RDD by applying a function to each partition of the wrapped RDD, where tasks are launched together in a barrier stage. 
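A usage sketch for the now-generic RDDBarrier (illustration only; assumes a local master with at least as many cores as partitions, which barrier scheduling needs to launch all tasks together):

from typing import Iterable, Iterator

from pyspark import SparkContext

sc = SparkContext("local[2]", "barrier-typing-sketch")
rdd = sc.parallelize(range(8), 2)            # RDD[int]

def summarize(it: Iterable[int]) -> Iterator[str]:
    # One record per barrier task; the element type changes from int to str.
    yield "partition sum = %d" % sum(it)

print(rdd.barrier().mapPartitions(summarize).collect())
# ['partition sum = 6', 'partition sum = 22']
sc.stop()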
@@ -2977,12 +3404,16 @@ def mapPartitions(self, f, preservesPartitioning=False): This API is experimental """ - def func(s, iterator): + def func(s: int, iterator: Iterable[T]) -> Iterable[U]: return f(iterator) return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True) - def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + def mapPartitionsWithIndex( + self, + f: Callable[[int, Iterable[T]], Iterable[U]], + preservesPartitioning: bool = False, + ) -> RDD[U]: """ Returns a new RDD by applying a function to each partition of the wrapped RDD, while tracking the index of the original partition. And all tasks are launched together @@ -2999,7 +3430,7 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False): return PipelinedRDD(self.rdd, f, preservesPartitioning, isFromBarrier=True) -class PipelinedRDD(RDD): +class PipelinedRDD(RDD[U], Generic[T, U]): """ Examples @@ -3021,7 +3452,13 @@ class PipelinedRDD(RDD): 20 """ - def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False): + def __init__( + self, + prev: RDD[T], + func: Callable[[int, Iterable[T]], Iterable[U]], + preservesPartitioning: bool = False, + isFromBarrier: bool = False, + ): if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): # This transformation is the first in its stage: self.func = func @@ -3029,9 +3466,9 @@ def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False) self._prev_jrdd = prev._jrdd self._prev_jrdd_deserializer = prev._jrdd_deserializer else: - prev_func = prev.func + prev_func: Callable[[int, Iterable[V]], Iterable[T]] = prev.func - def pipeline_func(split, iterator): + def pipeline_func(split: int, iterator: Iterable[V]) -> Iterable[U]: return func(split, prev_func(split, iterator)) self.func = pipeline_func @@ -3043,18 +3480,18 @@ def pipeline_func(split, iterator): self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev - self._jrdd_val = None + self._jrdd_val: Optional["JavaObject"] = None self._id = None self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self.partitioner = prev.partitioner if self.preservesPartitioning else None self.is_barrier = isFromBarrier or prev._is_barrier() - def getNumPartitions(self): + def getNumPartitions(self) -> int: return self._prev_jrdd.partitions().size() @property - def _jrdd(self): + def _jrdd(self) -> "JavaObject": if self._jrdd_val: return self._jrdd_val if self._bypass_serializer: @@ -3068,29 +3505,32 @@ def _jrdd(self): wrapped_func = _wrap_function( self.ctx, self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer, profiler ) + + assert self.ctx._jvm is not None python_rdd = self.ctx._jvm.PythonRDD( self._prev_jrdd.rdd(), wrapped_func, self.preservesPartitioning, self.is_barrier ) self._jrdd_val = python_rdd.asJavaRDD() if profiler: + assert self._jrdd_val is not None self._id = self._jrdd_val.id() self.ctx.profiler_collector.add_profiler(self._id, profiler) return self._jrdd_val - def id(self): + def id(self) -> int: if self._id is None: self._id = self._jrdd.id() return self._id - def _is_pipelinable(self): + def _is_pipelinable(self) -> bool: return not (self.is_cached or self.is_checkpointed or self.has_resource_profile) - def _is_barrier(self): + def _is_barrier(self) -> bool: return self.is_barrier -def _test(): +def _test() -> None: import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi deleted file mode 100644 index c4eddbf150423..0000000000000 
--- a/python/pyspark/rdd.pyi +++ /dev/null @@ -1,481 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import overload -from typing import ( - Any, - Callable, - Dict, - Generic, - Hashable, - Iterable, - Iterator, - List, - Optional, - Tuple, - Union, - TypeVar, -) -from typing_extensions import Literal - -from numpy import int32, int64, float32, float64, ndarray - -from pyspark._typing import SupportsOrdering -from pyspark.sql.pandas._typing import ( - PandasScalarUDFType, - PandasScalarIterUDFType, - PandasGroupedMapUDFType, - PandasCogroupedMapUDFType, - PandasGroupedAggUDFType, - PandasMapIterUDFType, - ArrowMapIterUDFType, -) -import pyspark.context -from pyspark.resultiterable import ResultIterable -from pyspark.serializers import Serializer -from pyspark.storagelevel import StorageLevel -from pyspark.resource.requests import ( # noqa: F401 - ExecutorResourceRequests, - TaskResourceRequests, -) -from pyspark.resource.profile import ResourceProfile -from pyspark.statcounter import StatCounter -from pyspark.sql.dataframe import DataFrame -from pyspark.sql.types import AtomicType, StructType -from pyspark.sql._typing import AtomicValue, RowLike -from py4j.java_gateway import JavaObject # type: ignore[import] - -T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) -U = TypeVar("U") -K = TypeVar("K", bound=Hashable) -V = TypeVar("V") -V1 = TypeVar("V1") -V2 = TypeVar("V2") -V3 = TypeVar("V3") -O = TypeVar("O", bound=SupportsOrdering) -NumberOrArray = TypeVar( - "NumberOrArray", float, int, complex, int32, int64, float32, float64, ndarray -) - -def portable_hash(x: Hashable) -> int: ... - -class PythonEvalType: - NON_UDF: Literal[0] - SQL_BATCHED_UDF: Literal[100] - SQL_SCALAR_PANDAS_UDF: PandasScalarUDFType - SQL_GROUPED_MAP_PANDAS_UDF: PandasGroupedMapUDFType - SQL_GROUPED_AGG_PANDAS_UDF: PandasGroupedAggUDFType - SQL_WINDOW_AGG_PANDAS_UDF: Literal[203] - SQL_SCALAR_PANDAS_ITER_UDF: PandasScalarIterUDFType - SQL_MAP_PANDAS_ITER_UDF: PandasMapIterUDFType - SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType - SQL_MAP_ARROW_ITER_UDF: ArrowMapIterUDFType - -class BoundedFloat(float): - def __new__(cls, mean: float, confidence: float, low: float, high: float) -> BoundedFloat: ... - -class Partitioner: - numPartitions: int - partitionFunc: Callable[[Any], int] - def __init__(self, numPartitions: int, partitionFunc: Callable[[Any], int]) -> None: ... - def __eq__(self, other: Any) -> bool: ... - def __call__(self, k: Any) -> int: ... - -class RDD(Generic[T_co]): - is_cached: bool - is_checkpointed: bool - ctx: pyspark.context.SparkContext - partitioner: Optional[Partitioner] - def __init__( - self, - jrdd: JavaObject, - ctx: pyspark.context.SparkContext, - jrdd_deserializer: Serializer = ..., - ) -> None: ... 
- def id(self) -> int: ... - def __getnewargs__(self) -> Any: ... - @property - def context(self) -> pyspark.context.SparkContext: ... - def cache(self) -> RDD[T_co]: ... - def persist(self, storageLevel: StorageLevel = ...) -> RDD[T_co]: ... - def unpersist(self, blocking: bool = ...) -> RDD[T_co]: ... - def checkpoint(self) -> None: ... - def isCheckpointed(self) -> bool: ... - def localCheckpoint(self) -> None: ... - def isLocallyCheckpointed(self) -> bool: ... - def getCheckpointFile(self) -> Optional[str]: ... - def map(self, f: Callable[[T_co], U], preservesPartitioning: bool = ...) -> RDD[U]: ... - def flatMap( - self, f: Callable[[T_co], Iterable[U]], preservesPartitioning: bool = ... - ) -> RDD[U]: ... - def mapPartitions( - self, f: Callable[[Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ... - ) -> RDD[U]: ... - def mapPartitionsWithIndex( - self, - f: Callable[[int, Iterable[T_co]], Iterable[U]], - preservesPartitioning: bool = ..., - ) -> RDD[U]: ... - def mapPartitionsWithSplit( - self, - f: Callable[[int, Iterable[T_co]], Iterable[U]], - preservesPartitioning: bool = ..., - ) -> RDD[U]: ... - def getNumPartitions(self) -> int: ... - def filter(self, f: Callable[[T_co], bool]) -> RDD[T_co]: ... - def distinct(self, numPartitions: Optional[int] = ...) -> RDD[T_co]: ... - def sample( - self, withReplacement: bool, fraction: float, seed: Optional[int] = ... - ) -> RDD[T_co]: ... - def randomSplit( - self, weights: List[Union[int, float]], seed: Optional[int] = ... - ) -> List[RDD[T_co]]: ... - def takeSample( - self, withReplacement: bool, num: int, seed: Optional[int] = ... - ) -> List[T_co]: ... - def union(self, other: RDD[U]) -> RDD[Union[T_co, U]]: ... - def intersection(self, other: RDD[T_co]) -> RDD[T_co]: ... - def __add__(self, other: RDD[T_co]) -> RDD[T_co]: ... - @overload - def repartitionAndSortWithinPartitions( - self: RDD[Tuple[O, V]], - numPartitions: Optional[int] = ..., - partitionFunc: Callable[[O], int] = ..., - ascending: bool = ..., - ) -> RDD[Tuple[O, V]]: ... - @overload - def repartitionAndSortWithinPartitions( - self: RDD[Tuple[K, V]], - numPartitions: Optional[int], - partitionFunc: Callable[[K], int], - ascending: bool, - keyfunc: Callable[[K], O], - ) -> RDD[Tuple[K, V]]: ... - @overload - def repartitionAndSortWithinPartitions( - self: RDD[Tuple[K, V]], - numPartitions: Optional[int] = ..., - partitionFunc: Callable[[K], int] = ..., - ascending: bool = ..., - *, - keyfunc: Callable[[K], O], - ) -> RDD[Tuple[K, V]]: ... - @overload - def sortByKey( - self: RDD[Tuple[O, V]], - ascending: bool = ..., - numPartitions: Optional[int] = ..., - ) -> RDD[Tuple[K, V]]: ... - @overload - def sortByKey( - self: RDD[Tuple[K, V]], - ascending: bool, - numPartitions: int, - keyfunc: Callable[[K], O], - ) -> RDD[Tuple[K, V]]: ... - @overload - def sortByKey( - self: RDD[Tuple[K, V]], - ascending: bool = ..., - numPartitions: Optional[int] = ..., - *, - keyfunc: Callable[[K], O], - ) -> RDD[Tuple[K, V]]: ... - def sortBy( - self, - keyfunc: Callable[[T_co], O], - ascending: bool = ..., - numPartitions: Optional[int] = ..., - ) -> RDD[T_co]: ... - def glom(self) -> RDD[List[T_co]]: ... - def cartesian(self, other: RDD[U]) -> RDD[Tuple[T_co, U]]: ... - def groupBy( - self, - f: Callable[[T_co], K], - numPartitions: Optional[int] = ..., - partitionFunc: Callable[[K], int] = ..., - ) -> RDD[Tuple[K, Iterable[T_co]]]: ... - def pipe( - self, command: str, env: Optional[Dict[str, str]] = ..., checkCode: bool = ... - ) -> RDD[str]: ... 
- def foreach(self, f: Callable[[T_co], None]) -> None: ... - def foreachPartition(self, f: Callable[[Iterable[T_co]], None]) -> None: ... - def collect(self) -> List[T_co]: ... - def collectWithJobGroup( - self, groupId: str, description: str, interruptOnCancel: bool = ... - ) -> List[T_co]: ... - def reduce(self, f: Callable[[T_co, T_co], T_co]) -> T_co: ... - def treeReduce(self, f: Callable[[T_co, T_co], T_co], depth: int = ...) -> T_co: ... - def fold(self, zeroValue: T, op: Callable[[T_co, T_co], T_co]) -> T_co: ... - def aggregate( - self, zeroValue: U, seqOp: Callable[[U, T_co], U], combOp: Callable[[U, U], U] - ) -> U: ... - def treeAggregate( - self, - zeroValue: U, - seqOp: Callable[[U, T_co], U], - combOp: Callable[[U, U], U], - depth: int = ..., - ) -> U: ... - @overload - def max(self: RDD[O]) -> O: ... - @overload - def max(self, key: Callable[[T_co], O]) -> T_co: ... - @overload - def min(self: RDD[O]) -> O: ... - @overload - def min(self, key: Callable[[T_co], O]) -> T_co: ... - def sum(self: RDD[NumberOrArray]) -> NumberOrArray: ... - def count(self) -> int: ... - def stats(self: RDD[NumberOrArray]) -> StatCounter: ... - def histogram( - self, buckets: Union[int, List[T_co], Tuple[T_co, ...]] - ) -> Tuple[List[T_co], List[int]]: ... - def mean(self: RDD[NumberOrArray]) -> NumberOrArray: ... - def variance(self: RDD[NumberOrArray]) -> NumberOrArray: ... - def stdev(self: RDD[NumberOrArray]) -> NumberOrArray: ... - def sampleStdev(self: RDD[NumberOrArray]) -> NumberOrArray: ... - def sampleVariance(self: RDD[NumberOrArray]) -> NumberOrArray: ... - def countByValue(self: RDD[K]) -> Dict[K, int]: ... - @overload - def top(self: RDD[O], num: int) -> List[O]: ... - @overload - def top(self, num: int, key: Callable[[T_co], O]) -> List[T_co]: ... - @overload - def takeOrdered(self: RDD[O], num: int) -> List[O]: ... - @overload - def takeOrdered(self, num: int, key: Callable[[T_co], O]) -> List[T_co]: ... - def take(self, num: int) -> List[T_co]: ... - def first(self) -> T_co: ... - def isEmpty(self) -> bool: ... - def saveAsNewAPIHadoopDataset( - self: RDD[Tuple[K, V]], - conf: Dict[str, str], - keyConverter: Optional[str] = ..., - valueConverter: Optional[str] = ..., - ) -> None: ... - def saveAsNewAPIHadoopFile( - self: RDD[Tuple[K, V]], - path: str, - outputFormatClass: str, - keyClass: Optional[str] = ..., - valueClass: Optional[str] = ..., - keyConverter: Optional[str] = ..., - valueConverter: Optional[str] = ..., - conf: Optional[Dict[str, str]] = ..., - ) -> None: ... - def saveAsHadoopDataset( - self: RDD[Tuple[K, V]], - conf: Dict[str, str], - keyConverter: Optional[str] = ..., - valueConverter: Optional[str] = ..., - ) -> None: ... - def saveAsHadoopFile( - self: RDD[Tuple[K, V]], - path: str, - outputFormatClass: str, - keyClass: Optional[str] = ..., - valueClass: Optional[str] = ..., - keyConverter: Optional[str] = ..., - valueConverter: Optional[str] = ..., - conf: Optional[str] = ..., - compressionCodecClass: Optional[str] = ..., - ) -> None: ... - def saveAsSequenceFile( - self: RDD[Tuple[K, V]], path: str, compressionCodecClass: Optional[str] = ... - ) -> None: ... - def saveAsPickleFile(self, path: str, batchSize: int = ...) -> None: ... - def saveAsTextFile(self, path: str, compressionCodecClass: Optional[str] = ...) -> None: ... - def collectAsMap(self: RDD[Tuple[K, V]]) -> Dict[K, V]: ... - def keys(self: RDD[Tuple[K, V]]) -> RDD[K]: ... - def values(self: RDD[Tuple[K, V]]) -> RDD[V]: ... 
- def reduceByKey( - self: RDD[Tuple[K, V]], - func: Callable[[V, V], V], - numPartitions: Optional[int] = ..., - partitionFunc: Callable[[K], int] = ..., - ) -> RDD[Tuple[K, V]]: ... - def reduceByKeyLocally(self: RDD[Tuple[K, V]], func: Callable[[V, V], V]) -> Dict[K, V]: ... - def countByKey(self: RDD[Tuple[K, V]]) -> Dict[K, int]: ... - def join( - self: RDD[Tuple[K, V]], - other: RDD[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> RDD[Tuple[K, Tuple[V, U]]]: ... - def leftOuterJoin( - self: RDD[Tuple[K, V]], - other: RDD[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> RDD[Tuple[K, Tuple[V, Optional[U]]]]: ... - def rightOuterJoin( - self: RDD[Tuple[K, V]], - other: RDD[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> RDD[Tuple[K, Tuple[Optional[V], U]]]: ... - def fullOuterJoin( - self: RDD[Tuple[K, V]], - other: RDD[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> RDD[Tuple[K, Tuple[Optional[V], Optional[U]]]]: ... - def partitionBy( - self: RDD[Tuple[K, V]], - numPartitions: int, - partitionFunc: Callable[[K], int] = ..., - ) -> RDD[Tuple[K, V]]: ... - def combineByKey( - self: RDD[Tuple[K, V]], - createCombiner: Callable[[V], U], - mergeValue: Callable[[U, V], U], - mergeCombiners: Callable[[U, U], U], - numPartitions: Optional[int] = ..., - partitionFunc: Callable[[K], int] = ..., - ) -> RDD[Tuple[K, U]]: ... - def aggregateByKey( - self: RDD[Tuple[K, V]], - zeroValue: U, - seqFunc: Callable[[U, V], U], - combFunc: Callable[[U, U], U], - numPartitions: Optional[int] = ..., - partitionFunc: Callable[[K], int] = ..., - ) -> RDD[Tuple[K, U]]: ... - def foldByKey( - self: RDD[Tuple[K, V]], - zeroValue: V, - func: Callable[[V, V], V], - numPartitions: Optional[int] = ..., - partitionFunc: Callable[[K], int] = ..., - ) -> RDD[Tuple[K, V]]: ... - def groupByKey( - self: RDD[Tuple[K, V]], - numPartitions: Optional[int] = ..., - partitionFunc: Callable[[K], int] = ..., - ) -> RDD[Tuple[K, Iterable[V]]]: ... - def flatMapValues( - self: RDD[Tuple[K, V]], f: Callable[[V], Iterable[U]] - ) -> RDD[Tuple[K, U]]: ... - def mapValues(self: RDD[Tuple[K, V]], f: Callable[[V], U]) -> RDD[Tuple[K, U]]: ... - @overload - def groupWith( - self: RDD[Tuple[K, V]], __o: RDD[Tuple[K, V1]] - ) -> RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[V1]]]]: ... - @overload - def groupWith( - self: RDD[Tuple[K, V]], __o1: RDD[Tuple[K, V1]], __o2: RDD[Tuple[K, V2]] - ) -> RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[V1], ResultIterable[V2]]]]: ... - @overload - def groupWith( - self: RDD[Tuple[K, V]], - other1: RDD[Tuple[K, V1]], - other2: RDD[Tuple[K, V2]], - other3: RDD[Tuple[K, V3]], - ) -> RDD[ - Tuple[ - K, - Tuple[ - ResultIterable[V], - ResultIterable[V1], - ResultIterable[V2], - ResultIterable[V3], - ], - ] - ]: ... - def cogroup( - self: RDD[Tuple[K, V]], - other: RDD[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[U]]]]: ... - def sampleByKey( - self: RDD[Tuple[K, V]], - withReplacement: bool, - fractions: Dict[K, Union[float, int]], - seed: Optional[int] = ..., - ) -> RDD[Tuple[K, V]]: ... - def subtractByKey( - self: RDD[Tuple[K, V]], - other: RDD[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> RDD[Tuple[K, V]]: ... - def subtract(self, other: RDD[T_co], numPartitions: Optional[int] = ...) -> RDD[T_co]: ... - def keyBy(self, f: Callable[[T_co], K]) -> RDD[Tuple[K, T_co]]: ... - def repartition(self, numPartitions: int) -> RDD[T_co]: ... 
- def coalesce(self, numPartitions: int, shuffle: bool = ...) -> RDD[T_co]: ... - def zip(self, other: RDD[U]) -> RDD[Tuple[T_co, U]]: ... - def zipWithIndex(self) -> RDD[Tuple[T_co, int]]: ... - def zipWithUniqueId(self) -> RDD[Tuple[T_co, int]]: ... - def name(self) -> str: ... - def setName(self, name: str) -> RDD[T_co]: ... - def toDebugString(self) -> bytes: ... - def getStorageLevel(self) -> StorageLevel: ... - def lookup(self: RDD[Tuple[K, V]], key: K) -> List[V]: ... - def countApprox(self, timeout: int, confidence: float = ...) -> int: ... - def sumApprox( - self: RDD[Union[float, int]], timeout: int, confidence: float = ... - ) -> BoundedFloat: ... - def meanApprox( - self: RDD[Union[float, int]], timeout: int, confidence: float = ... - ) -> BoundedFloat: ... - def countApproxDistinct(self, relativeSD: float = ...) -> int: ... - def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[T_co]: ... - def barrier(self) -> RDDBarrier[T_co]: ... - def withResources(self, profile: ResourceProfile) -> RDD[T_co]: ... - def getResourceProfile(self) -> Optional[ResourceProfile]: ... - @overload - def toDF( - self: RDD[RowLike], - schema: Optional[Union[List[str], Tuple[str, ...]]] = ..., - sampleRatio: Optional[float] = ..., - ) -> DataFrame: ... - @overload - def toDF(self: RDD[RowLike], schema: Optional[Union[StructType, str]] = ...) -> DataFrame: ... - @overload - def toDF( - self: RDD[AtomicValue], - schema: Union[AtomicType, str], - ) -> DataFrame: ... - -class RDDBarrier(Generic[T]): - rdd: RDD[T] - def __init__(self, rdd: RDD[T]) -> None: ... - def mapPartitions( - self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = ... - ) -> RDD[U]: ... - def mapPartitionsWithIndex( - self, - f: Callable[[int, Iterable[T]], Iterable[U]], - preservesPartitioning: bool = ..., - ) -> RDD[U]: ... - -class PipelinedRDD(RDD[U], Generic[T, U]): - func: Callable[[T], U] - preservesPartitioning: bool - is_cached: bool - is_checkpointed: bool - ctx: pyspark.context.SparkContext - prev: RDD[T] - partitioner: Optional[Partitioner] - is_barrier: bool - def __init__( - self, - prev: RDD[T], - func: Callable[[Iterable[T]], Iterable[U]], - preservesPartitioning: bool = ..., - isFromBarrier: bool = ..., - ) -> None: ... - def getNumPartitions(self) -> int: ... - def id(self) -> int: ... 
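With the stub file above removed, checkers such as mypy read the hints inline from pyspark/rdd.py instead. A small sketch of what that buys (illustration only; the reveal_type call is left commented out so the script still runs, and the master URL is an assumption):

from pyspark import SparkContext

sc = SparkContext("local[2]", "inline-hints-sketch")
pairs = sc.parallelize([("a", 1), ("b", 2)])        # RDD[Tuple[str, int]]
names = sc.parallelize([("a", "x")])                # RDD[Tuple[str, str]]

joined = pairs.leftOuterJoin(names)
# reveal_type(joined)   # mypy: RDD[Tuple[str, Tuple[int, Optional[str]]]]
print(sorted(joined.collect()))
# [('a', (1, 'x')), ('b', (2, None))]
sc.stop()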
diff --git a/python/pyspark/resource/profile.py b/python/pyspark/resource/profile.py index 24556f4f3b339..37e8ee85ea21c 100644 --- a/python/pyspark/resource/profile.py +++ b/python/pyspark/resource/profile.py @@ -121,7 +121,7 @@ def __init__(self) -> None: from pyspark.context import SparkContext # TODO: ignore[attr-defined] will be removed, once SparkContext is inlined - _jvm = SparkContext._jvm # type: ignore[attr-defined] + _jvm = SparkContext._jvm if _jvm is not None: self._jvm = _jvm self._java_resource_profile_builder = ( @@ -138,17 +138,15 @@ def require( ) -> "ResourceProfileBuilder": if isinstance(resourceRequest, TaskResourceRequests): if self._java_resource_profile_builder is not None: - if ( - resourceRequest._java_task_resource_requests is not None - ): # type: ignore[attr-defined] + if resourceRequest._java_task_resource_requests is not None: self._java_resource_profile_builder.require( resourceRequest._java_task_resource_requests - ) # type: ignore[attr-defined] + ) else: taskReqs = TaskResourceRequests(self._jvm, resourceRequest.requests) self._java_resource_profile_builder.require( taskReqs._java_task_resource_requests - ) # type: ignore[attr-defined] + ) else: self._task_resource_requests.update( # type: ignore[union-attr] resourceRequest.requests @@ -163,7 +161,7 @@ def require( self._jvm, resourceRequest.requests # type: ignore[attr-defined] ) self._java_resource_profile_builder.require( - execReqs._java_executor_resource_requests # type: ignore[attr-defined] + execReqs._java_executor_resource_requests ) else: self._executor_resource_requests.update( # type: ignore[union-attr] diff --git a/python/pyspark/resource/requests.py b/python/pyspark/resource/requests.py index 58226116979fe..0999e4e4aeb68 100644 --- a/python/pyspark/resource/requests.py +++ b/python/pyspark/resource/requests.py @@ -18,7 +18,7 @@ from py4j.java_gateway import JavaObject, JVMView -from pyspark.util import _parse_memory # type: ignore[attr-defined] +from pyspark.util import _parse_memory class ExecutorResourceRequest: @@ -133,7 +133,7 @@ def __init__( ): from pyspark import SparkContext - _jvm = _jvm or SparkContext._jvm # type: ignore[attr-defined] + _jvm = _jvm or SparkContext._jvm if _jvm is not None: self._java_executor_resource_requests = ( _jvm.org.apache.spark.resource.ExecutorResourceRequests() @@ -302,7 +302,7 @@ def __init__( ): from pyspark import SparkContext - _jvm = _jvm or SparkContext._jvm # type: ignore[attr-defined] + _jvm = _jvm or SparkContext._jvm if _jvm is not None: self._java_task_resource_requests: Optional[ JavaObject diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a0941afd36e4f..8c5a941f376d2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -66,7 +66,7 @@ pickle_protocol = pickle.HIGHEST_PROTOCOL from pyspark import cloudpickle -from pyspark.util import print_exec # type: ignore +from pyspark.util import print_exec __all__ = [ @@ -100,6 +100,13 @@ def load_stream(self, stream): """ raise NotImplementedError + def dumps(self, obj): + """ + Serialize an object into a byte array. + When batching is used, this will be called with an array of objects. + """ + raise NotImplementedError + def _load_stream_without_unbatching(self, stream): """ Return an iterator of deserialized batches (iterable) of objects from the input stream. @@ -357,7 +364,7 @@ def dumps(self, obj): # requires namedtuple hack. # The whole hack here should be removed once we drop Python 3.7. 
- __cls = {} # type: ignore + __cls = {} def _restore(name, fields, value): """Restore an object of namedtuple""" diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index f0c487877a086..9004a94e34063 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -28,14 +28,15 @@ from pyspark.context import SparkContext from pyspark.sql import SparkSession +from pyspark.sql.context import SQLContext if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) -SparkContext._ensure_initialized() # type: ignore +SparkContext._ensure_initialized() try: - spark = SparkSession._create_shell_session() # type: ignore + spark = SparkSession._create_shell_session() except Exception: import sys import traceback @@ -46,10 +47,10 @@ sc = spark.sparkContext sql = spark.sql -atexit.register(lambda: sc.stop()) +atexit.register((lambda sc: lambda: sc.stop())(sc)) # for compatibility -sqlContext = spark._wrapped +sqlContext = SQLContext._get_or_create(sc) sqlCtx = sqlContext print( diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 0709d2de25a67..35c3397de503b 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -33,7 +33,7 @@ CompressedSerializer, AutoBatchedSerializer, ) -from pyspark.util import fail_on_stopiteration # type: ignore +from pyspark.util import fail_on_stopiteration try: diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi index 2adae6c237389..209bb70faddef 100644 --- a/python/pyspark/sql/_typing.pyi +++ b/python/pyspark/sql/_typing.pyi @@ -25,7 +25,7 @@ from typing import ( TypeVar, Union, ) -from typing_extensions import Protocol +from typing_extensions import Literal, Protocol import datetime import decimal @@ -56,6 +56,8 @@ AtomicValue = TypeVar( RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row) +SQLBatchedUDFType = Literal[100] + class SupportsOpen(Protocol): def open(self, partition_id: int, epoch_id: int) -> bool: ... 
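The shell.py change above swaps `lambda: sc.stop()` for `(lambda sc: lambda: sc.stop())(sc)`; presumably the point is that the exit hook should capture the SparkContext object itself rather than resolve the global name `sc` at interpreter shutdown. A plain-Python sketch of the difference (hypothetical FakeContext class, no Spark needed):

class FakeContext:
    def __init__(self, label: str) -> None:
        self.label = label

    def stop(self) -> None:
        print("stopping", self.label)

sc = FakeContext("original")

late_bound = lambda: sc.stop()                    # resolves `sc` when called
early_bound = (lambda sc: lambda: sc.stop())(sc)  # freezes the current object

sc = FakeContext("rebound")
late_bound()     # stopping rebound
early_bound()    # stopping original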
diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py index 57fa7fc773df2..909fe3f3bd00b 100644 --- a/python/pyspark/sql/avro/functions.py +++ b/python/pyspark/sql/avro/functions.py @@ -23,7 +23,7 @@ from typing import Dict, Optional, TYPE_CHECKING from pyspark import SparkContext from pyspark.sql.column import Column, _to_java_column -from pyspark.util import _print_missing_jar # type: ignore[attr-defined] +from pyspark.util import _print_missing_jar if TYPE_CHECKING: from pyspark.sql._typing import ColumnOrName diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index ea8bb97c3b712..b954995f857bb 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -345,8 +345,8 @@ def createTable( if path is not None: options["path"] = path if source is None: - c = self._sparkSession._wrapped._conf - source = c.defaultDataSourceName() # type: ignore[attr-defined] + c = self._sparkSession._jconf + source = c.defaultDataSourceName() if description is None: description = "" if schema is None: @@ -356,7 +356,7 @@ def createTable( raise TypeError("schema should be StructType") scala_datatype = self._jsparkSession.parseDataType(schema.json()) df = self._jcatalog.createTable(tableName, source, scala_datatype, description, options) - return DataFrame(df, self._sparkSession._wrapped) + return DataFrame(df, self._sparkSession) def dropTempView(self, viewName: str) -> None: """Drops the local temporary view with the given view name in the catalog. diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index dce0cc6d1b327..04458d560ee8d 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -31,7 +31,7 @@ Union, ) -from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_gateway import JavaObject from pyspark import copy_func from pyspark.context import SparkContext @@ -233,23 +233,13 @@ def __init__(self, jc: JavaObject) -> None: __radd__ = cast( Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"], _bin_op("plus") ) - __rsub__ = cast( - Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"], _reverse_op("minus") - ) + __rsub__ = _reverse_op("minus") __rmul__ = cast( Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"], _bin_op("multiply") ) - __rdiv__ = cast( - Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"], - _reverse_op("divide"), - ) - __rtruediv__ = cast( - Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"], - _reverse_op("divide"), - ) - __rmod__ = cast( - Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"], _reverse_op("mod") - ) + __rdiv__ = _reverse_op("divide") + __rtruediv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") __pow__ = _bin_func_op("pow") __rpow__ = cast( @@ -709,7 +699,7 @@ def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) - if isinstance(startPos, int): jc = self._jc.substr(startPos, length) elif isinstance(startPos, Column): - jc = self._jc.substr(cast("Column", startPos)._jc, cast("Column", length)._jc) + jc = self._jc.substr(startPos._jc, cast("Column", length)._jc) else: raise TypeError("Unexpected type: %s" % type(startPos)) return Column(jc) diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index 7e8a56574822f..40a36a26701a6 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -18,9 +18,9 @@ import sys from typing import Any, Optional, Union 
-from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_gateway import JavaObject -from pyspark import since, _NoValue # type: ignore[attr-defined] +from pyspark import since, _NoValue from pyspark._globals import _NoValueType diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 6ab70ee1c39e0..6c7ab6f937e75 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -32,9 +32,9 @@ cast, ) -from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_gateway import JavaObject -from pyspark import since, _NoValue # type: ignore[attr-defined] +from pyspark import since, _NoValue from pyspark._globals import _NoValueType from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame @@ -46,7 +46,6 @@ from pyspark.rdd import RDD from pyspark.sql.types import AtomicType, DataType, StructType from pyspark.sql.streaming import StreamingQueryManager -from pyspark.conf import SparkConf if TYPE_CHECKING: from pyspark.sql._typing import ( @@ -116,19 +115,19 @@ def __init__( ) self._sc = sparkContext - self._jsc = self._sc._jsc # type: ignore[attr-defined] - self._jvm = self._sc._jvm # type: ignore[attr-defined] + self._jsc = self._sc._jsc + self._jvm = self._sc._jvm if sparkSession is None: sparkSession = SparkSession._getActiveSessionOrCreate() if jsqlContext is None: - jsqlContext = sparkSession._jwrapped + jsqlContext = sparkSession._jsparkSession.sqlContext() self.sparkSession = sparkSession self._jsqlContext = jsqlContext _monkey_patch_RDD(self.sparkSession) install_exception_handler() if ( SQLContext._instantiatedContext is None - or SQLContext._instantiatedContext._sc._jsc is None # type: ignore[attr-defined] + or SQLContext._instantiatedContext._sc._jsc is None ): SQLContext._instantiatedContext = self @@ -141,11 +140,6 @@ def _ssql_ctx(self) -> JavaObject: """ return self._jsqlContext - @property - def _conf(self) -> SparkConf: - """Accessor for the JVM SQL-specific configurations""" - return self.sparkSession._jsparkSession.sessionState().conf() - @classmethod def getOrCreate(cls: Type["SQLContext"], sc: SparkContext) -> "SQLContext": """ @@ -164,17 +158,22 @@ def getOrCreate(cls: Type["SQLContext"], sc: SparkContext) -> "SQLContext": "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.", FutureWarning, ) + return cls._get_or_create(sc) + + @classmethod + def _get_or_create( + cls: Type["SQLContext"], sc: SparkContext, **static_conf: Any + ) -> "SQLContext": if ( cls._instantiatedContext is None or SQLContext._instantiatedContext._sc._jsc is None # type: ignore[union-attr] ): assert sc._jvm is not None - jsqlContext = ( - sc._jvm.SparkSession.builder().sparkContext(sc._jsc.sc()).getOrCreate().sqlContext() - ) - sparkSession = SparkSession(sc, jsqlContext.sparkSession()) - cls(sc, sparkSession, jsqlContext) + # There can be only one running Spark context. That will automatically + # be used in the Spark session internally. 
+ session = SparkSession._getActiveSessionOrCreate(**static_conf) + cls(sc, session, session._jsparkSession.sqlContext()) return cast(SQLContext, cls._instantiatedContext) def newSession(self) -> "SQLContext": @@ -365,7 +364,7 @@ def createDataFrame( def createDataFrame( # type: ignore[misc] self, - data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike"], + data: Union[RDD[Any], Iterable[Any], "PandasDataFrameLike"], schema: Optional[Union[AtomicType, StructType, str]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True, @@ -590,9 +589,9 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: Row(namespace='', tableName='table1', isTemporary=True) """ if dbName is None: - return DataFrame(self._ssql_ctx.tables(), self) + return DataFrame(self._ssql_ctx.tables(), self.sparkSession) else: - return DataFrame(self._ssql_ctx.tables(dbName), self) + return DataFrame(self._ssql_ctx.tables(dbName), self.sparkSession) def tableNames(self, dbName: Optional[str] = None) -> List[str]: """Returns a list of names of tables in the database ``dbName``. @@ -647,7 +646,7 @@ def read(self) -> DataFrameReader: ------- :class:`DataFrameReader` """ - return DataFrameReader(self) + return DataFrameReader(self.sparkSession) @property def readStream(self) -> DataStreamReader: @@ -669,7 +668,7 @@ def readStream(self) -> DataStreamReader: >>> text_sdf.isStreaming True """ - return DataStreamReader(self) + return DataStreamReader(self.sparkSession) @property def streams(self) -> StreamingQueryManager: @@ -708,21 +707,34 @@ class HiveContext(SQLContext): """ - def __init__(self, sparkContext: SparkContext, jhiveContext: Optional[JavaObject] = None): + _static_conf = {"spark.sql.catalogImplementation": "hive"} + + def __init__( + self, + sparkContext: SparkContext, + sparkSession: Optional[SparkSession] = None, + jhiveContext: Optional[JavaObject] = None, + ): warnings.warn( "HiveContext is deprecated in Spark 2.0.0. Please use " + "SparkSession.builder.enableHiveSupport().getOrCreate() instead.", FutureWarning, ) + static_conf = {} if jhiveContext is None: - sparkContext._conf.set( # type: ignore[attr-defined] - "spark.sql.catalogImplementation", "hive" - ) - sparkSession = SparkSession.builder._sparkContext(sparkContext).getOrCreate() - else: - sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession()) + static_conf = HiveContext._static_conf + # There can be only one running Spark context. That will automatically + # be used in the Spark session internally. + if sparkSession is not None: + sparkSession = SparkSession._getActiveSessionOrCreate(**static_conf) SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext) + @classmethod + def _get_or_create( + cls: Type["SQLContext"], sc: SparkContext, **static_conf: Any + ) -> "SQLContext": + return SQLContext._get_or_create(sc, **HiveContext._static_conf) + @classmethod def _createForTesting(cls, sparkContext: SparkContext) -> "HiveContext": """(Internal use only) Create a new HiveContext for testing. 
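A sketch of the new SQLContext._get_or_create path (illustration only, fresh process assumed): the deprecated getOrCreate entry point, and HiveContext via its static conf, now reuse whatever SparkSession is already active instead of building a second JVM SQLContext.

import warnings

from pyspark.sql import SparkSession, SQLContext

spark = SparkSession.builder.master("local[2]").appName("sqlcontext-sketch").getOrCreate()

with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)   # getOrCreate is deprecated
    sqlContext = SQLContext.getOrCreate(spark.sparkContext)

print(sqlContext.sparkSession is spark)   # True: the active session is reused
spark.stop()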
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9e75006723eba..c5de9fb79571f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -16,6 +16,7 @@ # import json +import os import sys import random import warnings @@ -38,12 +39,12 @@ TYPE_CHECKING, ) -from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_gateway import JavaObject -from pyspark import copy_func, since, _NoValue # type: ignore[attr-defined] +from pyspark import copy_func, since, _NoValue from pyspark._globals import _NoValueType from pyspark.context import SparkContext -from pyspark.rdd import ( # type: ignore[attr-defined] +from pyspark.rdd import ( RDD, _load_from_socket, _local_iterator_from_socket, @@ -70,6 +71,7 @@ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType from pyspark.sql.context import SQLContext + from pyspark.sql.session import SparkSession from pyspark.sql.group import GroupedData from pyspark.sql.observation import Observation @@ -102,12 +104,35 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) .. versionadded:: 1.3.0 + + .. note: A DataFrame should only be created as described above. It should not be directly + created via using the constructor. """ - def __init__(self, jdf: JavaObject, sql_ctx: "SQLContext"): - self._jdf = jdf - self.sql_ctx = sql_ctx - self._sc: SparkContext = cast(SparkContext, sql_ctx and sql_ctx._sc) + def __init__( + self, + jdf: JavaObject, + sql_ctx: Union["SQLContext", "SparkSession"], + ): + from pyspark.sql.context import SQLContext + + self._sql_ctx: Optional["SQLContext"] = None + + if isinstance(sql_ctx, SQLContext): + assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage. + assert isinstance(sql_ctx, SQLContext) + # We should remove this if-else branch in the future release, and rename + # sql_ctx to session in the constructor. This is an internal code path but + # was kept with an warning because it's used intensively by third-party libraries. + warnings.warn("DataFrame constructor is internal. Do not directly use it.") + self._sql_ctx = sql_ctx + session = sql_ctx.sparkSession + else: + session = sql_ctx + self._session: "SparkSession" = session + + self._sc: SparkContext = sql_ctx._sc + self._jdf: JavaObject = jdf self.is_cached = False # initialized lazily self._schema: Optional[StructType] = None @@ -116,13 +141,41 @@ def __init__(self, jdf: JavaObject, sql_ctx: "SQLContext"): # by __repr__ and _repr_html_ while eager evaluation opened. self._support_repr_html = False + @property + def sql_ctx(self) -> "SQLContext": + from pyspark.sql.context import SQLContext + + warnings.warn( + "DataFrame.sql_ctx is an internal property, and will be removed " + "in future releases. Use DataFrame.sparkSession instead." + ) + if self._sql_ctx is None: + self._sql_ctx = SQLContext._get_or_create(self._sc) + return self._sql_ctx + + @property + def sparkSession(self) -> "SparkSession": + """Returns Spark session that created this :class:`DataFrame`. + + .. 
versionadded:: 3.3.0 + + Examples + -------- + >>> df = spark.range(1) + >>> type(df.sparkSession) + + """ + return self._session + @property # type: ignore[misc] @since(1.3) def rdd(self) -> "RDD[Row]": """Returns the content as an :class:`pyspark.RDD` of :class:`Row`.""" if self._lazy_rdd is None: jrdd = self._jdf.javaToPython() - self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(CPickleSerializer())) + self._lazy_rdd = RDD( + jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer()) + ) return self._lazy_rdd @property # type: ignore[misc] @@ -137,7 +190,7 @@ def stat(self) -> "DataFrameStatFunctions": """Returns a :class:`DataFrameStatFunctions` for statistic functions.""" return DataFrameStatFunctions(self) - def toJSON(self, use_unicode: bool = True) -> "RDD[str]": + def toJSON(self, use_unicode: bool = True) -> RDD[str]: """Converts a :class:`DataFrame` into a :class:`RDD` of string. Each row is turned into a JSON document as one element in the returned RDD. @@ -456,7 +509,7 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": +---+---+ """ - return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx) + return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession) @since(1.3) def isLocal(self) -> bool: @@ -561,16 +614,13 @@ def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = print(self._jdf.showString(n, int_truncate, vertical)) def __repr__(self) -> str: - if ( - not self._support_repr_html - and self.sql_ctx._conf.isReplEagerEvalEnabled() # type: ignore[attr-defined] - ): + if not self._support_repr_html and self.sparkSession._jconf.isReplEagerEvalEnabled(): vertical = False return self._jdf.showString( - self.sql_ctx._conf.replEagerEvalMaxNumRows(), # type: ignore[attr-defined] - self.sql_ctx._conf.replEagerEvalTruncate(), # type: ignore[attr-defined] + self.sparkSession._jconf.replEagerEvalMaxNumRows(), + self.sparkSession._jconf.replEagerEvalTruncate(), vertical, - ) # type: ignore[attr-defined] + ) else: return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @@ -581,13 +631,11 @@ def _repr_html_(self) -> Optional[str]: """ if not self._support_repr_html: self._support_repr_html = True - if self.sql_ctx._conf.isReplEagerEvalEnabled(): # type: ignore[attr-defined] - max_num_rows = max( - self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0 # type: ignore[attr-defined] - ) + if self.sparkSession._jconf.isReplEagerEvalEnabled(): + max_num_rows = max(self.sparkSession._jconf.replEagerEvalMaxNumRows(), 0) sock_info = self._jdf.getRowsToPython( max_num_rows, - self.sql_ctx._conf.replEagerEvalTruncate(), # type: ignore[attr-defined] + self.sparkSession._jconf.replEagerEvalTruncate(), ) rows = list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))) head = rows[0] @@ -631,7 +679,7 @@ def checkpoint(self, eager: bool = True) -> "DataFrame": This API is experimental. """ jdf = self._jdf.checkpoint(eager) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def localCheckpoint(self, eager: bool = True) -> "DataFrame": """Returns a locally checkpointed version of this :class:`DataFrame`. Checkpointing can be @@ -651,7 +699,7 @@ def localCheckpoint(self, eager: bool = True) -> "DataFrame": This API is experimental. 
""" jdf = self._jdf.localCheckpoint(eager) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def withWatermark(self, eventTime: str, delayThreshold: str) -> "DataFrame": """Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point @@ -695,7 +743,7 @@ def withWatermark(self, eventTime: str, delayThreshold: str) -> "DataFrame": if not delayThreshold or type(delayThreshold) is not str: raise TypeError("delayThreshold should be provided as a string interval") jdf = self._jdf.withWatermark(eventTime, delayThreshold) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def hint( self, name: str, *parameters: Union["PrimitiveType", List["PrimitiveType"]] @@ -740,7 +788,7 @@ def hint( ) jdf = self._jdf.hint(name, self._jseq(parameters)) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def count(self) -> int: """Returns the number of rows in this :class:`DataFrame`. @@ -804,7 +852,7 @@ def limit(self, num: int) -> "DataFrame": [] """ jdf = self._jdf.limit(num) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def take(self, num: int) -> List[Row]: """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. @@ -882,9 +930,7 @@ def cache(self) -> "DataFrame": def persist( self, - storageLevel: StorageLevel = ( - StorageLevel.MEMORY_AND_DISK_DESER # type: ignore[attr-defined] - ), + storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER), ) -> "DataFrame": """Sets the storage level to persist the contents of the :class:`DataFrame` across operations after the first time it is computed. This can only be used to assign @@ -898,7 +944,7 @@ def persist( The default storage level has changed to `MEMORY_AND_DISK_DESER` to match Scala in 3.0. 
""" self.is_cached = True - javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) # type: ignore[attr-defined] + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) self._jdf.persist(javaStorageLevel) return self @@ -970,7 +1016,7 @@ def coalesce(self, numPartitions: int) -> "DataFrame": >>> df.coalesce(1).rdd.getNumPartitions() 1 """ - return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx) + return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession) @overload def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> "DataFrame": @@ -1041,14 +1087,15 @@ def repartition( # type: ignore[misc] """ if isinstance(numPartitions, int): if len(cols) == 0: - return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + return DataFrame(self._jdf.repartition(numPartitions), self.sparkSession) else: return DataFrame( - self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx + self._jdf.repartition(numPartitions, self._jcols(*cols)), + self.sparkSession, ) elif isinstance(numPartitions, (str, Column)): cols = (numPartitions,) + cols - return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx) + return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sparkSession) else: raise TypeError("numPartitions should be an int or Column") @@ -1115,11 +1162,12 @@ def repartitionByRange( # type: ignore[misc] raise ValueError("At least one partition-by expression must be specified.") else: return DataFrame( - self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx + self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), + self.sparkSession, ) elif isinstance(numPartitions, (str, Column)): cols = (numPartitions,) + cols - return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx) + return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sparkSession) else: raise TypeError("numPartitions should be an int, string or Column") @@ -1133,7 +1181,7 @@ def distinct(self) -> "DataFrame": >>> df.distinct().count() 2 """ - return DataFrame(self._jdf.distinct(), self.sql_ctx) + return DataFrame(self._jdf.distinct(), self.sparkSession) @overload def sample(self, fraction: float, seed: Optional[int] = ...) -> "DataFrame": @@ -1228,7 +1276,7 @@ def sample( # type: ignore[misc] seed = int(seed) if seed is not None else None args = [arg for arg in [withReplacement, fraction, seed] if arg is not None] jdf = self._jdf.sample(*args) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def sampleBy( self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None @@ -1283,7 +1331,9 @@ def sampleBy( fractions[k] = float(v) col = col._jc seed = seed if seed is not None else random.randint(0, sys.maxsize) - return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) + return DataFrame( + self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sparkSession + ) def randomSplit(self, weights: List[float], seed: Optional[int] = None) -> List["DataFrame"]: """Randomly splits this :class:`DataFrame` with the provided weights. @@ -1311,10 +1361,10 @@ def randomSplit(self, weights: List[float], seed: Optional[int] = None) -> List[ if w < 0.0: raise ValueError("Weights must be positive. 
Found weight value: %s" % w) seed = seed if seed is not None else random.randint(0, sys.maxsize) - rdd_array = self._jdf.randomSplit( - _to_list(self.sql_ctx._sc, cast(List["ColumnOrName"], weights)), int(seed) + df_array = self._jdf.randomSplit( + _to_list(self.sparkSession._sc, cast(List["ColumnOrName"], weights)), int(seed) ) - return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] + return [DataFrame(df, self.sparkSession) for df in df_array] @property def dtypes(self) -> List[Tuple[str, str]]: @@ -1392,7 +1442,7 @@ def alias(self, alias: str) -> "DataFrame": [Row(name='Bob', name='Bob', age=5), Row(name='Alice', name='Alice', age=2)] """ assert isinstance(alias, str), "alias should be a string" - return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) + return DataFrame(getattr(self._jdf, "as")(alias), self.sparkSession) def crossJoin(self, other: "DataFrame") -> "DataFrame": """Returns the cartesian product with another :class:`DataFrame`. @@ -1416,7 +1466,7 @@ def crossJoin(self, other: "DataFrame") -> "DataFrame": """ jdf = self._jdf.crossJoin(other._jdf) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def join( self, @@ -1486,7 +1536,7 @@ def join( on = self._jseq([]) assert isinstance(how, str), "how should be a string" jdf = self._jdf.join(other._jdf, on, how) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) # TODO(SPARK-22947): Fix the DataFrame API. def _joinAsOf( @@ -1573,10 +1623,10 @@ def _joinAsOf( """ if isinstance(leftAsOfColumn, str): leftAsOfColumn = self[leftAsOfColumn] - left_as_of_jcol = cast(Column, leftAsOfColumn)._jc + left_as_of_jcol = leftAsOfColumn._jc if isinstance(rightAsOfColumn, str): rightAsOfColumn = other[rightAsOfColumn] - right_as_of_jcol = cast(Column, rightAsOfColumn)._jc + right_as_of_jcol = rightAsOfColumn._jc if on is not None and not isinstance(on, list): on = [on] # type: ignore[assignment] @@ -1607,7 +1657,7 @@ def _joinAsOf( allowExactMatches, direction, ) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def sortWithinPartitions( self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any @@ -1639,7 +1689,7 @@ def sortWithinPartitions( +---+-----+ """ jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def sort( self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any @@ -1677,7 +1727,7 @@ def sort( [Row(age=5, name='Bob'), Row(age=2, name='Alice')] """ jdf = self._jdf.sort(self._sort_cols(cols, kwargs)) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) orderBy = sort @@ -1687,11 +1737,11 @@ def _jseq( converter: Optional[Callable[..., Union["PrimitiveType", JavaObject]]] = None, ) -> JavaObject: """Return a JVM Seq of Columns from a list of Column or names""" - return _to_seq(self.sql_ctx._sc, cols, converter) + return _to_seq(self.sparkSession._sc, cols, converter) def _jmap(self, jm: Dict) -> JavaObject: """Return a JVM Scala Map from a dict""" - return _to_scala_map(self.sql_ctx._sc, jm) + return _to_scala_map(self.sparkSession._sc, jm) def _jcols(self, *cols: "ColumnOrName") -> JavaObject: """Return a JVM Seq of Columns from a list of Column or column names @@ -1739,26 +1789,31 @@ def describe(self, *cols: Union[str, List[str]]) -> "DataFrame": Examples -------- + >>> df = spark.createDataFrame( + ... 
[("Bob", 13, 40.3, 150.5), ("Alice", 12, 37.8, 142.3), ("Tom", 11, 44.1, 142.2)], + ... ["name", "age", "weight", "height"], + ... ) >>> df.describe(['age']).show() - +-------+------------------+ - |summary| age| - +-------+------------------+ - | count| 2| - | mean| 3.5| - | stddev|2.1213203435596424| - | min| 2| - | max| 5| - +-------+------------------+ - >>> df.describe().show() - +-------+------------------+-----+ - |summary| age| name| - +-------+------------------+-----+ - | count| 2| 2| - | mean| 3.5| null| - | stddev|2.1213203435596424| null| - | min| 2|Alice| - | max| 5| Bob| - +-------+------------------+-----+ + +-------+----+ + |summary| age| + +-------+----+ + | count| 3| + | mean|12.0| + | stddev| 1.0| + | min| 11| + | max| 13| + +-------+----+ + + >>> df.describe(['age', 'weight', 'height']).show() + +-------+----+------------------+-----------------+ + |summary| age| weight| height| + +-------+----+------------------+-----------------+ + | count| 3| 3| 3| + | mean|12.0| 40.73333333333333| 145.0| + | stddev| 1.0|3.1722757341273704|4.763402145525822| + | min| 11| 37.8| 142.2| + | max| 13| 44.1| 150.5| + +-------+----+------------------+-----------------+ See Also -------- @@ -1767,7 +1822,7 @@ def describe(self, *cols: Union[str, List[str]]) -> "DataFrame": if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] # type: ignore[assignment] jdf = self._jdf.describe(self._jseq(cols)) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def summary(self, *statistics: str) -> "DataFrame": """Computes specified statistics for numeric and string columns. Available statistics are: @@ -1791,39 +1846,34 @@ def summary(self, *statistics: str) -> "DataFrame": Examples -------- - >>> df.summary().show() - +-------+------------------+-----+ - |summary| age| name| - +-------+------------------+-----+ - | count| 2| 2| - | mean| 3.5| null| - | stddev|2.1213203435596424| null| - | min| 2|Alice| - | 25%| 2| null| - | 50%| 2| null| - | 75%| 5| null| - | max| 5| Bob| - +-------+------------------+-----+ - - >>> df.summary("count", "min", "25%", "75%", "max").show() - +-------+---+-----+ - |summary|age| name| - +-------+---+-----+ - | count| 2| 2| - | min| 2|Alice| - | 25%| 2| null| - | 75%| 5| null| - | max| 5| Bob| - +-------+---+-----+ - - To do a summary for specific columns first select them: - - >>> df.select("age", "name").summary("count").show() - +-------+---+----+ - |summary|age|name| - +-------+---+----+ - | count| 2| 2| - +-------+---+----+ + >>> df = spark.createDataFrame( + ... [("Bob", 13, 40.3, 150.5), ("Alice", 12, 37.8, 142.3), ("Tom", 11, 44.1, 142.2)], + ... ["name", "age", "weight", "height"], + ... 
) + >>> df.select("age", "weight", "height").summary().show() + +-------+----+------------------+-----------------+ + |summary| age| weight| height| + +-------+----+------------------+-----------------+ + | count| 3| 3| 3| + | mean|12.0| 40.73333333333333| 145.0| + | stddev| 1.0|3.1722757341273704|4.763402145525822| + | min| 11| 37.8| 142.2| + | 25%| 11| 37.8| 142.2| + | 50%| 12| 40.3| 142.3| + | 75%| 13| 44.1| 150.5| + | max| 13| 44.1| 150.5| + +-------+----+------------------+-----------------+ + + >>> df.select("age", "weight", "height").summary("count", "min", "25%", "75%", "max").show() + +-------+---+------+------+ + |summary|age|weight|height| + +-------+---+------+------+ + | count| 3| 3| 3| + | min| 11| 37.8| 142.2| + | 25%| 11| 37.8| 142.2| + | 75%| 13| 44.1| 150.5| + | max| 13| 44.1| 150.5| + +-------+---+------+------+ See Also -------- @@ -1832,7 +1882,7 @@ def summary(self, *statistics: str) -> "DataFrame": if len(statistics) == 1 and isinstance(statistics[0], list): statistics = statistics[0] jdf = self._jdf.summary(self._jseq(statistics)) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) @overload def head(self) -> Optional[Row]: @@ -1970,7 +2020,7 @@ def select(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] [Row(name='Alice', age=12), Row(name='Bob', age=15)] """ jdf = self._jdf.select(self._jcols(*cols)) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) @overload def selectExpr(self, *expr: str) -> "DataFrame": @@ -1995,7 +2045,7 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame": if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] # type: ignore[assignment] jdf = self._jdf.selectExpr(self._jseq(expr)) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def filter(self, condition: "ColumnOrName") -> "DataFrame": """Filters rows using the given condition. @@ -2028,7 +2078,7 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame": jdf = self._jdf.filter(condition._jc) else: raise TypeError("condition should be string or Column") - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) @overload def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": @@ -2203,7 +2253,7 @@ def union(self, other: "DataFrame") -> "DataFrame": Also as standard in SQL, this function resolves columns by position (not by name). """ - return DataFrame(self._jdf.union(other._jdf), self.sql_ctx) + return DataFrame(self._jdf.union(other._jdf), self.sparkSession) @since(1.3) def unionAll(self, other: "DataFrame") -> "DataFrame": @@ -2260,7 +2310,7 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> Added optional argument `allowMissingColumns` to specify whether to allow missing columns. """ - return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sql_ctx) + return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sparkSession) @since(1.3) def intersect(self, other: "DataFrame") -> "DataFrame": @@ -2269,7 +2319,7 @@ def intersect(self, other: "DataFrame") -> "DataFrame": This is equivalent to `INTERSECT` in SQL. 
""" - return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + return DataFrame(self._jdf.intersect(other._jdf), self.sparkSession) def intersectAll(self, other: "DataFrame") -> "DataFrame": """Return a new :class:`DataFrame` containing rows in both this :class:`DataFrame` @@ -2295,7 +2345,7 @@ def intersectAll(self, other: "DataFrame") -> "DataFrame": +---+---+ """ - return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx) + return DataFrame(self._jdf.intersectAll(other._jdf), self.sparkSession) @since(1.3) def subtract(self, other: "DataFrame") -> "DataFrame": @@ -2305,7 +2355,7 @@ def subtract(self, other: "DataFrame") -> "DataFrame": This is equivalent to `EXCEPT DISTINCT` in SQL. """ - return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sparkSession) def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": """Return a new :class:`DataFrame` with duplicate rows removed, @@ -2350,7 +2400,7 @@ def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": jdf = self._jdf.dropDuplicates() else: jdf = self._jdf.dropDuplicates(self._jseq(subset)) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def dropna( self, @@ -2398,7 +2448,7 @@ def dropna( if thresh is None: thresh = len(subset) if how == "any" else 1 - return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx) + return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sparkSession) @overload def fillna( @@ -2476,16 +2526,16 @@ def fillna( value = float(value) if isinstance(value, dict): - return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + return DataFrame(self._jdf.na().fill(value), self.sparkSession) elif subset is None: - return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + return DataFrame(self._jdf.na().fill(value), self.sparkSession) else: if isinstance(subset, str): subset = [subset] elif not isinstance(subset, (list, tuple)): raise TypeError("subset should be a list or tuple of column names") - return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sparkSession) @overload def replace( @@ -2686,10 +2736,11 @@ def all_of_(xs: Iterable) -> bool: raise ValueError("Mixed type replacements are not supported") if subset is None: - return DataFrame(self._jdf.na().replace("*", rep_dict), self.sql_ctx) + return DataFrame(self._jdf.na().replace("*", rep_dict), self.sparkSession) else: return DataFrame( - self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx + self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), + self.sparkSession, ) @overload @@ -2875,7 +2926,7 @@ def crosstab(self, col1: str, col2: str) -> "DataFrame": raise TypeError("col1 should be a string.") if not isinstance(col2, str): raise TypeError("col2 should be a string.") - return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx) + return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sparkSession) def freqItems( self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None @@ -2909,7 +2960,45 @@ def freqItems( raise TypeError("cols must be a list or tuple of column names as strings.") if not support: support = 0.01 - return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx) + return DataFrame( + self._jdf.stat().freqItems(_to_seq(self._sc, cols), 
support), self.sparkSession + ) + + def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame": + """ + Returns a new :class:`DataFrame` by adding multiple columns or replacing the + existing columns that has the same names. + + The colsMap is a map of column name and column, the column must only refer to attributes + supplied by this Dataset. It is an error to add columns that refer to some other Dataset. + + .. versionadded:: 3.3.0 + Added support for multiple columns adding + + Parameters + ---------- + colsMap : dict + a dict of column name and :class:`Column`. Currently, only single map is supported. + + Examples + -------- + >>> df.withColumns({'age2': df.age + 2, 'age3': df.age + 3}).collect() + [Row(age=2, name='Alice', age2=4, age3=5), Row(age=5, name='Bob', age2=7, age3=8)] + """ + # Below code is to help enable kwargs in future. + assert len(colsMap) == 1 + colsMap = colsMap[0] # type: ignore[assignment] + + if not isinstance(colsMap, dict): + raise TypeError("colsMap must be dict of column name and column.") + + col_names = list(colsMap.keys()) + cols = list(colsMap.values()) + + return DataFrame( + self._jdf.withColumns(_to_seq(self._sc, col_names), self._jcols(*cols)), + self.sparkSession, + ) def withColumn(self, colName: str, col: Column) -> "DataFrame": """ @@ -2943,7 +3032,7 @@ def withColumn(self, colName: str, col: Column) -> "DataFrame": """ if not isinstance(col, Column): raise TypeError("col should be Column") - return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) + return DataFrame(self._jdf.withColumn(colName, col._jc), self.sparkSession) def withColumnRenamed(self, existing: str, new: str) -> "DataFrame": """Returns a new :class:`DataFrame` by renaming an existing column. @@ -2963,7 +3052,7 @@ def withColumnRenamed(self, existing: str, new: str) -> "DataFrame": >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name='Alice'), Row(age2=5, name='Bob')] """ - return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx) + return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sparkSession) def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame": """Returns a new :class:`DataFrame` by updating an existing column with metadata. @@ -2988,7 +3077,7 @@ def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame" sc = SparkContext._active_spark_context assert sc is not None and sc._jvm is not None jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(json.dumps(metadata)) - return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sql_ctx) + return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sparkSession) @overload def drop(self, cols: "ColumnOrName") -> "DataFrame": @@ -3040,7 +3129,7 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] raise TypeError("each col in the param list should be a string") jdf = self._jdf.drop(self._jseq(cols)) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def toDF(self, *cols: "ColumnOrName") -> "DataFrame": """Returns a new :class:`DataFrame` that with new specified column names @@ -3056,7 +3145,7 @@ def toDF(self, *cols: "ColumnOrName") -> "DataFrame": [Row(f1=2, f2='Alice'), Row(f1=5, f2='Bob')] """ jdf = self._jdf.toDF(self._jseq(cols)) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame": """Returns a new :class:`DataFrame`. 
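# --- Illustrative sketch, not part of the patch above ---
# Usage of the new DataFrame.withColumns (3.3.0) introduced in the hunk above:
# a single dict adds or replaces several columns at once, instead of chaining
# withColumn calls. Assumes a local SparkSession.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

df2 = df.withColumns({"age2": df.age + 2, "age3": F.col("age") + 3})
print(df2.collect())  # same result as two chained withColumn calls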
Concise syntax for chaining custom transformations. @@ -3310,7 +3399,10 @@ def __init__(self, df: DataFrame): self.df = df def drop( - self, how: str = "any", thresh: Optional[int] = None, subset: Optional[List[str]] = None + self, + how: str = "any", + thresh: Optional[int] = None, + subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None, ) -> DataFrame: return self.df.dropna(how=how, thresh=thresh, subset=subset) @@ -3399,7 +3491,7 @@ def approxQuantile( ) -> List[List[float]]: ... - def approxQuantile( # type: ignore[misc] + def approxQuantile( self, col: Union[str, List[str], Tuple[str]], probabilities: Union[List[float], Tuple[float]], diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d9ba4220e93aa..06fdbf1ed3904 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -28,6 +28,7 @@ Callable, Dict, List, + Iterable, overload, Optional, Tuple, @@ -65,7 +66,7 @@ # since it requires to make every single overridden definition. -def _get_get_jvm_function(name: str, sc: SparkContext) -> Callable: +def _get_jvm_function(name: str, sc: SparkContext) -> Callable: """ Retrieves JVM function identified by name from Java gateway associated with sc. @@ -80,16 +81,26 @@ def _invoke_function(name: str, *args: Any) -> Column: and wraps the result with :class:`~pyspark.sql.Column`. """ assert SparkContext._active_spark_context is not None - jf = _get_get_jvm_function(name, SparkContext._active_spark_context) + jf = _get_jvm_function(name, SparkContext._active_spark_context) return Column(jf(*args)) -def _invoke_function_over_column(name: str, col: "ColumnOrName") -> Column: +def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: """ - Invokes unary JVM function identified by name + Invokes n-ary JVM function identified by name and wraps the result with :class:`~pyspark.sql.Column`. """ - return _invoke_function(name, _to_java_column(col)) + return _invoke_function(name, *(_to_java_column(col) for col in cols)) + + +def _invoke_function_over_seq_of_columns(name: str, cols: "Iterable[ColumnOrName]") -> Column: + """ + Invokes unary JVM function identified by name with + and wraps the result with :class:`~pyspark.sql.Column`. + """ + sc = SparkContext._active_spark_context + assert sc is not None and sc._jvm is not None + return _invoke_function(name, _to_seq(sc, cols, _to_java_column)) def _invoke_binary_math_function(name: str, col1: Any, col2: Any) -> Column: @@ -164,7 +175,7 @@ def sqrt(col: "ColumnOrName") -> Column: """ Computes the square root of the specified float value. """ - return _invoke_function_over_column("sqrt", col) + return _invoke_function_over_columns("sqrt", col) @since(1.3) @@ -172,7 +183,7 @@ def abs(col: "ColumnOrName") -> Column: """ Computes the absolute value. """ - return _invoke_function_over_column("abs", col) + return _invoke_function_over_columns("abs", col) @since(1.3) @@ -180,7 +191,7 @@ def max(col: "ColumnOrName") -> Column: """ Aggregate function: returns the maximum value of the expression in a group. """ - return _invoke_function_over_column("max", col) + return _invoke_function_over_columns("max", col) @since(1.3) @@ -188,7 +199,7 @@ def min(col: "ColumnOrName") -> Column: """ Aggregate function: returns the minimum value of the expression in a group. 
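# --- Illustrative sketch, not part of the patch above ---
# _invoke_function_over_columns and _invoke_function_over_seq_of_columns above
# are internal helpers; this only shows that the public wrappers keep their
# behaviour after the refactor: they accept a column name or a Column and
# return a Column. Assumes a local SparkSession.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(-4.0,), (9.0,)], ["x"])

df.select(F.abs("x").alias("abs_x"), F.sqrt(F.abs(F.col("x"))).alias("root")).show()
df.agg(F.max("x"), F.min("x")).show()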
""" - return _invoke_function_over_column("min", col) + return _invoke_function_over_columns("min", col) def max_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column: @@ -223,7 +234,7 @@ def max_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column: |dotNET| 2013| +------+----------------------+ """ - return _invoke_function("max_by", _to_java_column(col), _to_java_column(ord)) + return _invoke_function_over_columns("max_by", col, ord) def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column: @@ -258,7 +269,7 @@ def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column: |dotNET| 2012| +------+----------------------+ """ - return _invoke_function("min_by", _to_java_column(col), _to_java_column(ord)) + return _invoke_function_over_columns("min_by", col, ord) @since(1.3) @@ -266,7 +277,7 @@ def count(col: "ColumnOrName") -> Column: """ Aggregate function: returns the number of items in a group. """ - return _invoke_function_over_column("count", col) + return _invoke_function_over_columns("count", col) @since(1.3) @@ -274,7 +285,7 @@ def sum(col: "ColumnOrName") -> Column: """ Aggregate function: returns the sum of all values in the expression. """ - return _invoke_function_over_column("sum", col) + return _invoke_function_over_columns("sum", col) @since(1.3) @@ -282,7 +293,7 @@ def avg(col: "ColumnOrName") -> Column: """ Aggregate function: returns the average of the values in a group. """ - return _invoke_function_over_column("avg", col) + return _invoke_function_over_columns("avg", col) @since(1.3) @@ -290,7 +301,7 @@ def mean(col: "ColumnOrName") -> Column: """ Aggregate function: returns the average of the values in a group. """ - return _invoke_function_over_column("mean", col) + return _invoke_function_over_columns("mean", col) @since(1.3) @@ -310,7 +321,7 @@ def sum_distinct(col: "ColumnOrName") -> Column: """ Aggregate function: returns the sum of distinct values in the expression. 
""" - return _invoke_function_over_column("sum_distinct", col) + return _invoke_function_over_columns("sum_distinct", col) def product(col: "ColumnOrName") -> Column: @@ -338,7 +349,7 @@ def product(col: "ColumnOrName") -> Column: +----+-------+ """ - return _invoke_function_over_column("product", col) + return _invoke_function_over_columns("product", col) def acos(col: "ColumnOrName") -> Column: @@ -352,7 +363,7 @@ def acos(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` inverse cosine of `col`, as if computed by `java.lang.Math.acos()` """ - return _invoke_function_over_column("acos", col) + return _invoke_function_over_columns("acos", col) def acosh(col: "ColumnOrName") -> Column: @@ -365,7 +376,7 @@ def acosh(col: "ColumnOrName") -> Column: ------- :class:`~pyspark.sql.Column` """ - return _invoke_function_over_column("acosh", col) + return _invoke_function_over_columns("acosh", col) def asin(col: "ColumnOrName") -> Column: @@ -380,7 +391,7 @@ def asin(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` inverse sine of `col`, as if computed by `java.lang.Math.asin()` """ - return _invoke_function_over_column("asin", col) + return _invoke_function_over_columns("asin", col) def asinh(col: "ColumnOrName") -> Column: @@ -393,7 +404,7 @@ def asinh(col: "ColumnOrName") -> Column: ------- :class:`~pyspark.sql.Column` """ - return _invoke_function_over_column("asinh", col) + return _invoke_function_over_columns("asinh", col) def atan(col: "ColumnOrName") -> Column: @@ -407,7 +418,7 @@ def atan(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` inverse tangent of `col`, as if computed by `java.lang.Math.atan()` """ - return _invoke_function_over_column("atan", col) + return _invoke_function_over_columns("atan", col) def atanh(col: "ColumnOrName") -> Column: @@ -420,7 +431,7 @@ def atanh(col: "ColumnOrName") -> Column: ------- :class:`~pyspark.sql.Column` """ - return _invoke_function_over_column("atanh", col) + return _invoke_function_over_columns("atanh", col) @since(1.4) @@ -428,7 +439,7 @@ def cbrt(col: "ColumnOrName") -> Column: """ Computes the cube-root of the given value. """ - return _invoke_function_over_column("cbrt", col) + return _invoke_function_over_columns("cbrt", col) @since(1.4) @@ -436,7 +447,7 @@ def ceil(col: "ColumnOrName") -> Column: """ Computes the ceiling of the given value. """ - return _invoke_function_over_column("ceil", col) + return _invoke_function_over_columns("ceil", col) def cos(col: "ColumnOrName") -> Column: @@ -455,7 +466,7 @@ def cos(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` cosine of the angle, as if computed by `java.lang.Math.cos()`. """ - return _invoke_function_over_column("cos", col) + return _invoke_function_over_columns("cos", col) def cosh(col: "ColumnOrName") -> Column: @@ -474,7 +485,7 @@ def cosh(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh()` """ - return _invoke_function_over_column("cosh", col) + return _invoke_function_over_columns("cosh", col) def cot(col: "ColumnOrName") -> Column: @@ -493,7 +504,7 @@ def cot(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` Cotangent of the angle. """ - return _invoke_function_over_column("cot", col) + return _invoke_function_over_columns("cot", col) def csc(col: "ColumnOrName") -> Column: @@ -512,7 +523,7 @@ def csc(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` Cosecant of the angle. 
""" - return _invoke_function_over_column("csc", col) + return _invoke_function_over_columns("csc", col) @since(1.4) @@ -520,7 +531,7 @@ def exp(col: "ColumnOrName") -> Column: """ Computes the exponential of the given value. """ - return _invoke_function_over_column("exp", col) + return _invoke_function_over_columns("exp", col) @since(1.4) @@ -528,7 +539,7 @@ def expm1(col: "ColumnOrName") -> Column: """ Computes the exponential of the given value minus one. """ - return _invoke_function_over_column("expm1", col) + return _invoke_function_over_columns("expm1", col) @since(1.4) @@ -536,7 +547,7 @@ def floor(col: "ColumnOrName") -> Column: """ Computes the floor of the given value. """ - return _invoke_function_over_column("floor", col) + return _invoke_function_over_columns("floor", col) @since(1.4) @@ -544,7 +555,7 @@ def log(col: "ColumnOrName") -> Column: """ Computes the natural logarithm of the given value. """ - return _invoke_function_over_column("log", col) + return _invoke_function_over_columns("log", col) @since(1.4) @@ -552,7 +563,7 @@ def log10(col: "ColumnOrName") -> Column: """ Computes the logarithm of the given value in Base 10. """ - return _invoke_function_over_column("log10", col) + return _invoke_function_over_columns("log10", col) @since(1.4) @@ -560,7 +571,7 @@ def log1p(col: "ColumnOrName") -> Column: """ Computes the natural logarithm of the given value plus one. """ - return _invoke_function_over_column("log1p", col) + return _invoke_function_over_columns("log1p", col) @since(1.4) @@ -569,7 +580,7 @@ def rint(col: "ColumnOrName") -> Column: Returns the double value that is closest in value to the argument and is equal to a mathematical integer. """ - return _invoke_function_over_column("rint", col) + return _invoke_function_over_columns("rint", col) def sec(col: "ColumnOrName") -> Column: @@ -588,7 +599,7 @@ def sec(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` Secant of the angle. """ - return _invoke_function_over_column("sec", col) + return _invoke_function_over_columns("sec", col) @since(1.4) @@ -596,7 +607,7 @@ def signum(col: "ColumnOrName") -> Column: """ Computes the signum of the given value. 
""" - return _invoke_function_over_column("signum", col) + return _invoke_function_over_columns("signum", col) def sin(col: "ColumnOrName") -> Column: @@ -614,7 +625,7 @@ def sin(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` sine of the angle, as if computed by `java.lang.Math.sin()` """ - return _invoke_function_over_column("sin", col) + return _invoke_function_over_columns("sin", col) def sinh(col: "ColumnOrName") -> Column: @@ -634,7 +645,7 @@ def sinh(col: "ColumnOrName") -> Column: hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh()` """ - return _invoke_function_over_column("sinh", col) + return _invoke_function_over_columns("sinh", col) def tan(col: "ColumnOrName") -> Column: @@ -653,7 +664,7 @@ def tan(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` tangent of the given value, as if computed by `java.lang.Math.tan()` """ - return _invoke_function_over_column("tan", col) + return _invoke_function_over_columns("tan", col) def tanh(col: "ColumnOrName") -> Column: @@ -673,7 +684,7 @@ def tanh(col: "ColumnOrName") -> Column: hyperbolic tangent of the given value as if computed by `java.lang.Math.tanh()` """ - return _invoke_function_over_column("tanh", col) + return _invoke_function_over_columns("tanh", col) @since(1.4) @@ -713,7 +724,7 @@ def bitwise_not(col: "ColumnOrName") -> Column: """ Computes bitwise not. """ - return _invoke_function_over_column("bitwise_not", col) + return _invoke_function_over_columns("bitwise_not", col) @since(2.4) @@ -771,7 +782,7 @@ def stddev(col: "ColumnOrName") -> Column: """ Aggregate function: alias for stddev_samp. """ - return _invoke_function_over_column("stddev", col) + return _invoke_function_over_columns("stddev", col) @since(1.6) @@ -780,7 +791,7 @@ def stddev_samp(col: "ColumnOrName") -> Column: Aggregate function: returns the unbiased sample standard deviation of the expression in a group. """ - return _invoke_function_over_column("stddev_samp", col) + return _invoke_function_over_columns("stddev_samp", col) @since(1.6) @@ -789,7 +800,7 @@ def stddev_pop(col: "ColumnOrName") -> Column: Aggregate function: returns population standard deviation of the expression in a group. """ - return _invoke_function_over_column("stddev_pop", col) + return _invoke_function_over_columns("stddev_pop", col) @since(1.6) @@ -797,7 +808,7 @@ def variance(col: "ColumnOrName") -> Column: """ Aggregate function: alias for var_samp """ - return _invoke_function_over_column("variance", col) + return _invoke_function_over_columns("variance", col) @since(1.6) @@ -806,7 +817,7 @@ def var_samp(col: "ColumnOrName") -> Column: Aggregate function: returns the unbiased sample variance of the values in a group. """ - return _invoke_function_over_column("var_samp", col) + return _invoke_function_over_columns("var_samp", col) @since(1.6) @@ -814,7 +825,7 @@ def var_pop(col: "ColumnOrName") -> Column: """ Aggregate function: returns the population variance of the values in a group. """ - return _invoke_function_over_column("var_pop", col) + return _invoke_function_over_columns("var_pop", col) @since(1.6) @@ -822,7 +833,7 @@ def skewness(col: "ColumnOrName") -> Column: """ Aggregate function: returns the skewness of the values in a group. """ - return _invoke_function_over_column("skewness", col) + return _invoke_function_over_columns("skewness", col) @since(1.6) @@ -830,7 +841,7 @@ def kurtosis(col: "ColumnOrName") -> Column: """ Aggregate function: returns the kurtosis of the values in a group. 
""" - return _invoke_function_over_column("kurtosis", col) + return _invoke_function_over_columns("kurtosis", col) def collect_list(col: "ColumnOrName") -> Column: @@ -850,7 +861,7 @@ def collect_list(col: "ColumnOrName") -> Column: >>> df2.agg(collect_list('age')).collect() [Row(collect_list(age)=[2, 5, 5])] """ - return _invoke_function_over_column("collect_list", col) + return _invoke_function_over_columns("collect_list", col) def collect_set(col: "ColumnOrName") -> Column: @@ -870,7 +881,7 @@ def collect_set(col: "ColumnOrName") -> Column: >>> df2.agg(array_sort(collect_set('age')).alias('c')).collect() [Row(c=[2, 5])] """ - return _invoke_function_over_column("collect_set", col) + return _invoke_function_over_columns("collect_set", col) def degrees(col: "ColumnOrName") -> Column: @@ -890,7 +901,7 @@ def degrees(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` angle in degrees, as if computed by `java.lang.Math.toDegrees()` """ - return _invoke_function_over_column("degrees", col) + return _invoke_function_over_columns("degrees", col) def radians(col: "ColumnOrName") -> Column: @@ -910,7 +921,7 @@ def radians(col: "ColumnOrName") -> Column: :class:`~pyspark.sql.Column` angle in radians, as if computed by `java.lang.Math.toRadians()` """ - return _invoke_function_over_column("radians", col) + return _invoke_function_over_columns("radians", col) @overload @@ -1082,13 +1093,10 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C >>> df.agg(approx_count_distinct(df.age).alias('distinct_ages')).collect() [Row(distinct_ages=2)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if rsd is None: - jc = sc._jvm.functions.approx_count_distinct(_to_java_column(col)) + return _invoke_function_over_columns("approx_count_distinct", col) else: - jc = sc._jvm.functions.approx_count_distinct(_to_java_column(col), rsd) - return Column(jc) + return _invoke_function("approx_count_distinct", _to_java_column(col), rsd) @since(1.6) @@ -1097,7 +1105,7 @@ def broadcast(df: DataFrame) -> DataFrame: sc = SparkContext._active_spark_context assert sc is not None and sc._jvm is not None - return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sql_ctx) + return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sparkSession) def coalesce(*cols: "ColumnOrName") -> Column: @@ -1135,10 +1143,7 @@ def coalesce(*cols: "ColumnOrName") -> Column: |null| 2| 0.0| +----+----+----------------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.coalesce(_to_seq(sc, cols, _to_java_column)) - return Column(jc) + return _invoke_function_over_seq_of_columns("coalesce", cols) def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: @@ -1155,9 +1160,7 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df.agg(corr("a", "b").alias('c')).collect() [Row(c=1.0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.corr(_to_java_column(col1), _to_java_column(col2))) + return _invoke_function_over_columns("corr", col1, col2) def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: @@ -1174,9 +1177,7 @@ def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df.agg(covar_pop("a", "b").alias('c')).collect() [Row(c=0.0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return 
Column(sc._jvm.functions.covar_pop(_to_java_column(col1), _to_java_column(col2))) + return _invoke_function_over_columns("covar_pop", col1, col2) def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: @@ -1193,9 +1194,7 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df.agg(covar_samp("a", "b").alias('c')).collect() [Row(c=0.0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.covar_samp(_to_java_column(col1), _to_java_column(col2))) + return _invoke_function_over_columns("covar_samp", col1, col2) def countDistinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column: @@ -1224,8 +1223,9 @@ def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column: """ sc = SparkContext._active_spark_context assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.count_distinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column)) - return Column(jc) + return _invoke_function( + "count_distinct", _to_java_column(col), _to_seq(sc, cols, _to_java_column) + ) def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: @@ -1241,10 +1241,7 @@ def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle. """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) - return Column(jc) + return _invoke_function("first", _to_java_column(col), ignorenulls) def grouping(col: "ColumnOrName") -> Column: @@ -1265,10 +1262,7 @@ def grouping(col: "ColumnOrName") -> Column: | Bob| 0| 5| +-----+--------------+--------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.grouping(_to_java_column(col)) - return Column(jc) + return _invoke_function_over_columns("grouping", col) def grouping_id(*cols: "ColumnOrName") -> Column: @@ -1295,18 +1289,13 @@ def grouping_id(*cols: "ColumnOrName") -> Column: | Bob| 0| 5| +-----+-------------+--------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column)) - return Column(jc) + return _invoke_function_over_seq_of_columns("grouping_id", cols) @since(1.6) def input_file_name() -> Column: """Creates a string column for the file name of the current Spark task.""" - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.input_file_name()) + return _invoke_function("input_file_name") def isnan(col: "ColumnOrName") -> Column: @@ -1320,9 +1309,7 @@ def isnan(col: "ColumnOrName") -> Column: >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect() [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.isnan(_to_java_column(col))) + return _invoke_function_over_columns("isnan", col) def isnull(col: "ColumnOrName") -> Column: @@ -1336,9 +1323,7 @@ def isnull(col: "ColumnOrName") -> Column: >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect() [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return 
Column(sc._jvm.functions.isnull(_to_java_column(col))) + return _invoke_function_over_columns("isnull", col) def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: @@ -1354,10 +1339,7 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle. """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) - return Column(jc) + return _invoke_function("last", _to_java_column(col), ignorenulls) def monotonically_increasing_id() -> Column: @@ -1382,9 +1364,7 @@ def monotonically_increasing_id() -> Column: >>> df0.select(monotonically_increasing_id().alias('id')).collect() [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.monotonically_increasing_id()) + return _invoke_function("monotonically_increasing_id") def nanvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: @@ -1400,9 +1380,7 @@ def nanvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() [Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) + return _invoke_function_over_columns("nanvl", col1, col2) def percentile_approx( @@ -1450,9 +1428,9 @@ def percentile_approx( if isinstance(percentage, (list, tuple)): # A local list - percentage = sc._jvm.functions.array( - _to_seq(sc, [_create_column_from_literal(x) for x in percentage]) - ) + percentage = _invoke_function( + "array", _to_seq(sc, [_create_column_from_literal(x) for x in percentage]) + )._jc elif isinstance(percentage, Column): # Already a Column percentage = _to_java_column(percentage) @@ -1466,7 +1444,7 @@ def percentile_approx( else _create_column_from_literal(accuracy) ) - return Column(sc._jvm.functions.percentile_approx(_to_java_column(col), percentage, accuracy)) + return _invoke_function("percentile_approx", _to_java_column(col), percentage, accuracy) def rand(seed: Optional[int] = None) -> Column: @@ -1485,13 +1463,10 @@ def rand(seed: Optional[int] = None) -> Column: [Row(age=2, name='Alice', rand=2.4052597283576684), Row(age=5, name='Bob', rand=2.3913904055683974)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if seed is not None: - jc = sc._jvm.functions.rand(seed) + return _invoke_function("rand", seed) else: - jc = sc._jvm.functions.rand() - return Column(jc) + return _invoke_function("rand") def randn(seed: Optional[int] = None) -> Column: @@ -1510,13 +1485,10 @@ def randn(seed: Optional[int] = None) -> Column: [Row(age=2, name='Alice', randn=1.1027054481455365), Row(age=5, name='Bob', randn=0.7400395449950132)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if seed is not None: - jc = sc._jvm.functions.randn(seed) + return _invoke_function("randn", seed) else: - jc = sc._jvm.functions.randn() - return Column(jc) + return _invoke_function("randn") def round(col: "ColumnOrName", scale: int = 0) -> Column: @@ -1531,9 +1503,7 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: >>> 
spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect() [Row(r=3.0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.round(_to_java_column(col), scale)) + return _invoke_function("round", _to_java_column(col), scale) def bround(col: "ColumnOrName", scale: int = 0) -> Column: @@ -1548,9 +1518,7 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() [Row(r=2.0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.bround(_to_java_column(col), scale)) + return _invoke_function("bround", _to_java_column(col), scale) def shiftLeft(col: "ColumnOrName", numBits: int) -> Column: @@ -1575,9 +1543,7 @@ def shiftleft(col: "ColumnOrName", numBits: int) -> Column: >>> spark.createDataFrame([(21,)], ['a']).select(shiftleft('a', 1).alias('r')).collect() [Row(r=42)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.shiftleft(_to_java_column(col), numBits)) + return _invoke_function("shiftleft", _to_java_column(col), numBits) def shiftRight(col: "ColumnOrName", numBits: int) -> Column: @@ -1602,10 +1568,7 @@ def shiftright(col: "ColumnOrName", numBits: int) -> Column: >>> spark.createDataFrame([(42,)], ['a']).select(shiftright('a', 1).alias('r')).collect() [Row(r=21)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits) - return Column(jc) + return _invoke_function("shiftright", _to_java_column(col), numBits) def shiftRightUnsigned(col: "ColumnOrName", numBits: int) -> Column: @@ -1631,10 +1594,7 @@ def shiftrightunsigned(col: "ColumnOrName", numBits: int) -> Column: >>> df.select(shiftrightunsigned('a', 1).alias('r')).collect() [Row(r=9223372036854775787)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits) - return Column(jc) + return _invoke_function("shiftrightunsigned", _to_java_column(col), numBits) def spark_partition_id() -> Column: @@ -1651,9 +1611,7 @@ def spark_partition_id() -> Column: >>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() [Row(pid=0), Row(pid=0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.spark_partition_id()) + return _invoke_function("spark_partition_id") def expr(str: str) -> Column: @@ -1666,9 +1624,7 @@ def expr(str: str) -> Column: >>> df.select(expr("length(name)")).collect() [Row(length(name)=5), Row(length(name)=3)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.expr(str)) + return _invoke_function("expr", str) @overload @@ -1700,12 +1656,9 @@ def struct( >>> df.select(struct([df.age, df.name]).alias("struct")).collect() [Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if len(cols) == 1 and isinstance(cols[0], (list, set)): cols = cols[0] # type: ignore[assignment] - jc = sc._jvm.functions.struct(_to_seq(sc, cols, _to_java_column)) # type: ignore[arg-type] - return Column(jc) + return 
_invoke_function_over_seq_of_columns("struct", cols) # type: ignore[arg-type] def greatest(*cols: "ColumnOrName") -> Column: @@ -1723,9 +1676,7 @@ def greatest(*cols: "ColumnOrName") -> Column: """ if len(cols) < 2: raise ValueError("greatest should take at least two columns") - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.greatest(_to_seq(sc, cols, _to_java_column))) + return _invoke_function_over_seq_of_columns("greatest", cols) def least(*cols: "ColumnOrName") -> Column: @@ -1748,9 +1699,7 @@ def least(*cols: "ColumnOrName") -> Column: """ if len(cols) < 2: raise ValueError("least should take at least two columns") - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.least(_to_seq(sc, cols, _to_java_column))) + return _invoke_function_over_seq_of_columns("least", cols) def when(condition: Column, value: Any) -> Column: @@ -1773,15 +1722,12 @@ def when(condition: Column, value: Any) -> Column: >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect() [Row(age=3), Row(age=None)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - # Explicitly not using ColumnOrName type here to make reading condition less opaque if not isinstance(condition, Column): raise TypeError("condition should be a Column") v = value._jc if isinstance(value, Column) else value - jc = sc._jvm.functions.when(condition._jc, v) - return Column(jc) + + return _invoke_function("when", condition._jc, v) @overload # type: ignore[no-redef] @@ -1809,13 +1755,10 @@ def log(arg1: Union["ColumnOrName", float], arg2: Optional["ColumnOrName"] = Non >>> df.select(log(df.age).alias('e')).rdd.map(lambda l: str(l.e)[:7]).collect() ['0.69314', '1.60943'] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if arg2 is None: - jc = sc._jvm.functions.log(_to_java_column(cast("ColumnOrName", arg1))) + return _invoke_function_over_columns("log", cast("ColumnOrName", arg1)) else: - jc = sc._jvm.functions.log(arg1, _to_java_column(arg2)) - return Column(jc) + return _invoke_function("log", arg1, _to_java_column(arg2)) def log2(col: "ColumnOrName") -> Column: @@ -1828,9 +1771,7 @@ def log2(col: "ColumnOrName") -> Column: >>> spark.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect() [Row(log2=2.0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.log2(_to_java_column(col))) + return _invoke_function_over_columns("log2", col) def conv(col: "ColumnOrName", fromBase: int, toBase: int) -> Column: @@ -1845,9 +1786,7 @@ def conv(col: "ColumnOrName", fromBase: int, toBase: int) -> Column: >>> df.select(conv(df.n, 2, 16).alias('hex')).collect() [Row(hex='15')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.conv(_to_java_column(col), fromBase, toBase)) + return _invoke_function("conv", _to_java_column(col), fromBase, toBase) def factorial(col: "ColumnOrName") -> Column: @@ -1862,9 +1801,7 @@ def factorial(col: "ColumnOrName") -> Column: >>> df.select(factorial(df.n).alias('f')).collect() [Row(f=120)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.factorial(_to_java_column(col))) + return _invoke_function_over_columns("factorial", col) # --------------- 
Window functions ------------------------ @@ -1889,9 +1826,7 @@ def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> default : optional default value """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.lag(_to_java_column(col), offset, default)) + return _invoke_function("lag", _to_java_column(col), offset, default) def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> Column: @@ -1913,9 +1848,7 @@ def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> default : optional default value """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.lead(_to_java_column(col), offset, default)) + return _invoke_function("lead", _to_java_column(col), offset, default) def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = False) -> Column: @@ -1940,9 +1873,7 @@ def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = Fa indicates the Nth value should skip null in the determination of which row to use """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.nth_value(_to_java_column(col), offset, ignoreNulls)) + return _invoke_function("nth_value", _to_java_column(col), offset, ignoreNulls) def ntile(n: int) -> Column: @@ -1961,9 +1892,7 @@ def ntile(n: int) -> Column: n : int an integer """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.ntile(int(n))) + return _invoke_function("ntile", int(n)) # ---------------------- Date/Timestamp functions ------------------------------ @@ -1975,9 +1904,7 @@ def current_date() -> Column: Returns the current date at the start of query evaluation as a :class:`DateType` column. All calls of current_date within the same query return the same value. """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.current_date()) + return _invoke_function("current_date") def current_timestamp() -> Column: @@ -1985,9 +1912,7 @@ def current_timestamp() -> Column: Returns the current timestamp at the start of query evaluation as a :class:`TimestampType` column. All calls of current_timestamp within the same query return the same value. 
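# --- Illustrative sketch, not part of the patch above ---
# lag/lead/nth_value/ntile in the hunks above now call _invoke_function
# directly, but they are still used through a window spec exactly as before.
# Assumes a local SparkSession.
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("a", 1), ("a", 2), ("a", 3)], ["grp", "v"])
w = Window.partitionBy("grp").orderBy("v")

df.select(
    "v",
    F.lag("v", 1).over(w).alias("prev"),
    F.lead("v", 1).over(w).alias("next"),
    F.ntile(2).over(w).alias("bucket"),
).show()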
""" - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.current_timestamp()) + return _invoke_function("current_timestamp") def date_format(date: "ColumnOrName", format: str) -> Column: @@ -2012,9 +1937,7 @@ def date_format(date: "ColumnOrName", format: str) -> Column: >>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect() [Row(date='04/08/2015')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.date_format(_to_java_column(date), format)) + return _invoke_function("date_format", _to_java_column(date), format) def year(col: "ColumnOrName") -> Column: @@ -2029,9 +1952,7 @@ def year(col: "ColumnOrName") -> Column: >>> df.select(year('dt').alias('year')).collect() [Row(year=2015)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.year(_to_java_column(col))) + return _invoke_function_over_columns("year", col) def quarter(col: "ColumnOrName") -> Column: @@ -2046,9 +1967,7 @@ def quarter(col: "ColumnOrName") -> Column: >>> df.select(quarter('dt').alias('quarter')).collect() [Row(quarter=2)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.quarter(_to_java_column(col))) + return _invoke_function_over_columns("quarter", col) def month(col: "ColumnOrName") -> Column: @@ -2063,9 +1982,7 @@ def month(col: "ColumnOrName") -> Column: >>> df.select(month('dt').alias('month')).collect() [Row(month=4)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.month(_to_java_column(col))) + return _invoke_function_over_columns("month", col) def dayofweek(col: "ColumnOrName") -> Column: @@ -2081,9 +1998,7 @@ def dayofweek(col: "ColumnOrName") -> Column: >>> df.select(dayofweek('dt').alias('day')).collect() [Row(day=4)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.dayofweek(_to_java_column(col))) + return _invoke_function_over_columns("dayofweek", col) def dayofmonth(col: "ColumnOrName") -> Column: @@ -2098,9 +2013,7 @@ def dayofmonth(col: "ColumnOrName") -> Column: >>> df.select(dayofmonth('dt').alias('day')).collect() [Row(day=8)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.dayofmonth(_to_java_column(col))) + return _invoke_function_over_columns("dayofmonth", col) def dayofyear(col: "ColumnOrName") -> Column: @@ -2115,9 +2028,7 @@ def dayofyear(col: "ColumnOrName") -> Column: >>> df.select(dayofyear('dt').alias('day')).collect() [Row(day=98)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.dayofyear(_to_java_column(col))) + return _invoke_function_over_columns("dayofyear", col) def hour(col: "ColumnOrName") -> Column: @@ -2128,13 +2039,12 @@ def hour(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> import datetime + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) >>> df.select(hour('ts').alias('hour')).collect() [Row(hour=13)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return 
Column(sc._jvm.functions.hour(_to_java_column(col))) + return _invoke_function_over_columns("hour", col) def minute(col: "ColumnOrName") -> Column: @@ -2145,13 +2055,12 @@ def minute(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> import datetime + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) >>> df.select(minute('ts').alias('minute')).collect() [Row(minute=8)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.minute(_to_java_column(col))) + return _invoke_function_over_columns("minute", col) def second(col: "ColumnOrName") -> Column: @@ -2162,13 +2071,12 @@ def second(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> import datetime + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) >>> df.select(second('ts').alias('second')).collect() [Row(second=15)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.second(_to_java_column(col))) + return _invoke_function_over_columns("second", col) def weekofyear(col: "ColumnOrName") -> Column: @@ -2185,9 +2093,7 @@ def weekofyear(col: "ColumnOrName") -> Column: >>> df.select(weekofyear(df.dt).alias('week')).collect() [Row(week=15)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.weekofyear(_to_java_column(col))) + return _invoke_function_over_columns("weekofyear", col) def make_date(year: "ColumnOrName", month: "ColumnOrName", day: "ColumnOrName") -> Column: @@ -2211,13 +2117,7 @@ def make_date(year: "ColumnOrName", month: "ColumnOrName", day: "ColumnOrName") >>> df.select(make_date(df.Y, df.M, df.D).alias("datefield")).collect() [Row(datefield=datetime.date(2020, 6, 26))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - year_col = _to_java_column(year) - month_col = _to_java_column(month) - day_col = _to_java_column(day) - jc = sc._jvm.functions.make_date(year_col, month_col, day_col) - return Column(jc) + return _invoke_function_over_columns("make_date", year, month, day) def date_add(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: @@ -2234,12 +2134,8 @@ def date_add(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: >>> df.select(date_add(df.dt, df.add.cast('integer')).alias('next_date')).collect() [Row(next_date=datetime.date(2015, 4, 10))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - days = lit(days) if isinstance(days, int) else days - - return Column(sc._jvm.functions.date_add(_to_java_column(start), _to_java_column(days))) + return _invoke_function_over_columns("date_add", start, days) def date_sub(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: @@ -2256,12 +2152,8 @@ def date_sub(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: >>> df.select(date_sub(df.dt, df.sub.cast('integer')).alias('prev_date')).collect() [Row(prev_date=datetime.date(2015, 4, 6))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - days = lit(days) if isinstance(days, int) else days - - return Column(sc._jvm.functions.date_sub(_to_java_column(start), _to_java_column(days))) + return 
_invoke_function_over_columns("date_sub", start, days) def datediff(end: "ColumnOrName", start: "ColumnOrName") -> Column: @@ -2276,9 +2168,7 @@ def datediff(end: "ColumnOrName", start: "ColumnOrName") -> Column: >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect() [Row(diff=32)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.datediff(_to_java_column(end), _to_java_column(start))) + return _invoke_function_over_columns("datediff", end, start) def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Column: @@ -2295,12 +2185,8 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col >>> df.select(add_months(df.dt, df.add.cast('integer')).alias('next_month')).collect() [Row(next_month=datetime.date(2015, 6, 8))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - months = lit(months) if isinstance(months, int) else months - - return Column(sc._jvm.functions.add_months(_to_java_column(start), _to_java_column(months))) + return _invoke_function_over_columns("add_months", start, months) def months_between(date1: "ColumnOrName", date2: "ColumnOrName", roundOff: bool = True) -> Column: @@ -2321,10 +2207,8 @@ def months_between(date1: "ColumnOrName", date2: "ColumnOrName", roundOff: bool >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect() [Row(months=3.9495967741935485)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column( - sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2), roundOff) + return _invoke_function( + "months_between", _to_java_column(date1), _to_java_column(date2), roundOff ) @@ -2348,13 +2232,10 @@ def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: >>> df.select(to_date(df.t, 'yyyy-MM-dd HH:mm:ss').alias('date')).collect() [Row(date=datetime.date(1997, 2, 28))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if format is None: - jc = sc._jvm.functions.to_date(_to_java_column(col)) + return _invoke_function_over_columns("to_date", col) else: - jc = sc._jvm.functions.to_date(_to_java_column(col), format) - return Column(jc) + return _invoke_function("to_date", _to_java_column(col), format) @overload @@ -2387,13 +2268,10 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: >>> df.select(to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss').alias('dt')).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if format is None: - jc = sc._jvm.functions.to_timestamp(_to_java_column(col)) + return _invoke_function_over_columns("to_timestamp", col) else: - jc = sc._jvm.functions.to_timestamp(_to_java_column(col), format) - return Column(jc) + return _invoke_function("to_timestamp", _to_java_column(col), format) def trunc(date: "ColumnOrName", format: str) -> Column: @@ -2418,9 +2296,7 @@ def trunc(date: "ColumnOrName", format: str) -> Column: >>> df.select(trunc(df.d, 'mon').alias('month')).collect() [Row(month=datetime.date(1997, 2, 1))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) + return _invoke_function("trunc", _to_java_column(date), format) def date_trunc(format: str, timestamp: 
"ColumnOrName") -> Column: @@ -2447,9 +2323,7 @@ def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: >>> df.select(date_trunc('mon', df.t).alias('month')).collect() [Row(month=datetime.datetime(1997, 2, 1, 0, 0))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.date_trunc(format, _to_java_column(timestamp))) + return _invoke_function("date_trunc", format, _to_java_column(timestamp)) def next_day(date: "ColumnOrName", dayOfWeek: str) -> Column: @@ -2467,9 +2341,7 @@ def next_day(date: "ColumnOrName", dayOfWeek: str) -> Column: >>> df.select(next_day(df.d, 'Sun').alias('date')).collect() [Row(date=datetime.date(2015, 8, 2))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.next_day(_to_java_column(date), dayOfWeek)) + return _invoke_function("next_day", _to_java_column(date), dayOfWeek) def last_day(date: "ColumnOrName") -> Column: @@ -2484,9 +2356,7 @@ def last_day(date: "ColumnOrName") -> Column: >>> df.select(last_day(df.d).alias('date')).collect() [Row(date=datetime.date(1997, 2, 28))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.last_day(_to_java_column(date))) + return _invoke_function("last_day", _to_java_column(date)) def from_unixtime(timestamp: "ColumnOrName", format: str = "yyyy-MM-dd HH:mm:ss") -> Column: @@ -2505,9 +2375,17 @@ def from_unixtime(timestamp: "ColumnOrName", format: str = "yyyy-MM-dd HH:mm:ss" [Row(ts='2015-04-08 00:00:00')] >>> spark.conf.unset("spark.sql.session.timeZone") """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format)) + return _invoke_function("from_unixtime", _to_java_column(timestamp), format) + + +@overload +def unix_timestamp(timestamp: "ColumnOrName", format: str = ...) -> Column: + ... + + +@overload +def unix_timestamp() -> Column: + ... 
def unix_timestamp( @@ -2530,11 +2408,9 @@ def unix_timestamp( [Row(unix_time=1428476400)] >>> spark.conf.unset("spark.sql.session.timeZone") """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if timestamp is None: - return Column(sc._jvm.functions.unix_timestamp()) - return Column(sc._jvm.functions.unix_timestamp(_to_java_column(timestamp), format)) + return _invoke_function("unix_timestamp") + return _invoke_function("unix_timestamp", _to_java_column(timestamp), format) def from_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: @@ -2577,11 +2453,9 @@ def from_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect() [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if isinstance(tz, Column): tz = _to_java_column(tz) - return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) + return _invoke_function("from_utc_timestamp", _to_java_column(timestamp), tz) def to_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: @@ -2624,11 +2498,9 @@ def to_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect() [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if isinstance(tz, Column): tz = _to_java_column(tz) - return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) + return _invoke_function("to_utc_timestamp", _to_java_column(timestamp), tz) def timestamp_seconds(col: "ColumnOrName") -> Column: @@ -2649,9 +2521,7 @@ def timestamp_seconds(col: "ColumnOrName") -> Column: >>> spark.conf.unset("spark.sql.session.timeZone") """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.timestamp_seconds(_to_java_column(col))) + return _invoke_function_over_columns("timestamp_seconds", col) def window( @@ -2684,7 +2554,7 @@ def window( ---------- timeColumn : :class:`~pyspark.sql.Column` The column or the expression to use as the timestamp for windowing by time. - The time column must be of TimestampType. + The time column must be of TimestampType or TimestampNTZType. windowDuration : str A string specifying the width of the window, e.g. `10 minutes`, `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for @@ -2705,7 +2575,10 @@ def window( Examples -------- - >>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") + >>> import datetime + >>> df = spark.createDataFrame( + ... [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)], + ... ).toDF("date", "val") >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) >>> w.select(w.window.start.cast("string").alias("start"), ... 
w.window.end.cast("string").alias("end"), "sum").collect() @@ -2716,23 +2589,20 @@ def check_string_field(field, fieldName): # type: ignore[no-untyped-def] if not field or type(field) is not str: raise TypeError("%s should be provided as a string" % fieldName) - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None time_col = _to_java_column(timeColumn) check_string_field(windowDuration, "windowDuration") if slideDuration and startTime: check_string_field(slideDuration, "slideDuration") check_string_field(startTime, "startTime") - res = sc._jvm.functions.window(time_col, windowDuration, slideDuration, startTime) + return _invoke_function("window", time_col, windowDuration, slideDuration, startTime) elif slideDuration: check_string_field(slideDuration, "slideDuration") - res = sc._jvm.functions.window(time_col, windowDuration, slideDuration) + return _invoke_function("window", time_col, windowDuration, slideDuration) elif startTime: check_string_field(startTime, "startTime") - res = sc._jvm.functions.window(time_col, windowDuration, windowDuration, startTime) + return _invoke_function("window", time_col, windowDuration, windowDuration, startTime) else: - res = sc._jvm.functions.window(time_col, windowDuration) - return Column(res) + return _invoke_function("window", time_col, windowDuration) def session_window(timeColumn: "ColumnOrName", gapDuration: Union[Column, str]) -> Column: @@ -2759,7 +2629,7 @@ def session_window(timeColumn: "ColumnOrName", gapDuration: Union[Column, str]) ---------- timeColumn : :class:`~pyspark.sql.Column` or str The column name or column to use as the timestamp for windowing by time. - The time column must be of TimestampType. + The time column must be of TimestampType or TimestampNTZType. gapDuration : :class:`~pyspark.sql.Column` or str A Python string literal or column specifying the timeout of the session. It could be static value, e.g. 
`10 minutes`, `1 second`, or an expression/UDF that specifies gap @@ -2782,13 +2652,10 @@ def check_field(field: Union[Column, str], fieldName: str) -> None: if field is None or not isinstance(field, (str, Column)): raise TypeError("%s should be provided as a string or Column" % fieldName) - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None time_col = _to_java_column(timeColumn) check_field(gapDuration, "gapDuration") gap_duration = gapDuration if isinstance(gapDuration, str) else _to_java_column(gapDuration) - res = sc._jvm.functions.session_window(time_col, gap_duration) - return Column(res) + return _invoke_function("session_window", time_col, gap_duration) # ---------------------------- misc functions ---------------------------------- @@ -2806,9 +2673,7 @@ def crc32(col: "ColumnOrName") -> Column: >>> spark.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() [Row(crc32=2743272264)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.crc32(_to_java_column(col))) + return _invoke_function_over_columns("crc32", col) def md5(col: "ColumnOrName") -> Column: @@ -2821,10 +2686,7 @@ def md5(col: "ColumnOrName") -> Column: >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() [Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.md5(_to_java_column(col)) - return Column(jc) + return _invoke_function_over_columns("md5", col) def sha1(col: "ColumnOrName") -> Column: @@ -2837,10 +2699,7 @@ def sha1(col: "ColumnOrName") -> Column: >>> spark.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() [Row(hash='3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.sha1(_to_java_column(col)) - return Column(jc) + return _invoke_function_over_columns("sha1", col) def sha2(col: "ColumnOrName", numBits: int) -> Column: @@ -2858,10 +2717,7 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: >>> digests[1] Row(s='cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) - return Column(jc) + return _invoke_function("sha2", _to_java_column(col), numBits) def hash(*cols: "ColumnOrName") -> Column: @@ -2874,10 +2730,7 @@ def hash(*cols: "ColumnOrName") -> Column: >>> spark.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect() [Row(hash=-757602832)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.hash(_to_seq(sc, cols, _to_java_column)) - return Column(jc) + return _invoke_function_over_seq_of_columns("hash", cols) def xxhash64(*cols: "ColumnOrName") -> Column: @@ -2891,10 +2744,7 @@ def xxhash64(*cols: "ColumnOrName") -> Column: >>> spark.createDataFrame([('ABC',)], ['a']).select(xxhash64('a').alias('hash')).collect() [Row(hash=4105715581806190027)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.xxhash64(_to_seq(sc, cols, _to_java_column)) - return Column(jc) + return _invoke_function_over_seq_of_columns("xxhash64", cols) def assert_true(col: "ColumnOrName", errMsg: 
Optional[Union[Column, str]] = None) -> Column: @@ -2923,17 +2773,15 @@ def assert_true(col: "ColumnOrName", errMsg: Optional[Union[Column, str]] = None >>> df.select(assert_true(df.a < df.b, 'error').alias('r')).collect() [Row(r=None)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if errMsg is None: - return Column(sc._jvm.functions.assert_true(_to_java_column(col))) + return _invoke_function_over_columns("assert_true", col) if not isinstance(errMsg, (str, Column)): raise TypeError("errMsg should be a Column or a str, got {}".format(type(errMsg))) errMsg = ( _create_column_from_literal(errMsg) if isinstance(errMsg, str) else _to_java_column(errMsg) ) - return Column(sc._jvm.functions.assert_true(_to_java_column(col), errMsg)) + return _invoke_function("assert_true", _to_java_column(col), errMsg) @since(3.1) @@ -2949,12 +2797,10 @@ def raise_error(errMsg: Union[Column, str]) -> Column: if not isinstance(errMsg, (str, Column)): raise TypeError("errMsg should be a Column or a str, got {}".format(type(errMsg))) - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None errMsg = ( _create_column_from_literal(errMsg) if isinstance(errMsg, str) else _to_java_column(errMsg) ) - return Column(sc._jvm.functions.raise_error(errMsg)) + return _invoke_function("raise_error", errMsg) # ---------------------- String/Binary functions ------------------------------ @@ -2965,7 +2811,7 @@ def upper(col: "ColumnOrName") -> Column: """ Converts a string expression to upper case. """ - return _invoke_function_over_column("upper", col) + return _invoke_function_over_columns("upper", col) @since(1.5) @@ -2973,7 +2819,7 @@ def lower(col: "ColumnOrName") -> Column: """ Converts a string expression to lower case. """ - return _invoke_function_over_column("lower", col) + return _invoke_function_over_columns("lower", col) @since(1.5) @@ -2981,7 +2827,7 @@ def ascii(col: "ColumnOrName") -> Column: """ Computes the numeric value of the first character of the string column. """ - return _invoke_function_over_column("ascii", col) + return _invoke_function_over_columns("ascii", col) @since(1.5) @@ -2989,7 +2835,7 @@ def base64(col: "ColumnOrName") -> Column: """ Computes the BASE64 encoding of a binary column and returns it as a string column. """ - return _invoke_function_over_column("base64", col) + return _invoke_function_over_columns("base64", col) @since(1.5) @@ -2997,7 +2843,7 @@ def unbase64(col: "ColumnOrName") -> Column: """ Decodes a BASE64 encoded string column and returns it as a binary column. """ - return _invoke_function_over_column("unbase64", col) + return _invoke_function_over_columns("unbase64", col) @since(1.5) @@ -3005,7 +2851,7 @@ def ltrim(col: "ColumnOrName") -> Column: """ Trim the spaces from left end for the specified string value. """ - return _invoke_function_over_column("ltrim", col) + return _invoke_function_over_columns("ltrim", col) @since(1.5) @@ -3013,7 +2859,7 @@ def rtrim(col: "ColumnOrName") -> Column: """ Trim the spaces from right end for the specified string value. """ - return _invoke_function_over_column("rtrim", col) + return _invoke_function_over_columns("rtrim", col) @since(1.5) @@ -3021,7 +2867,7 @@ def trim(col: "ColumnOrName") -> Column: """ Trim the spaces from both ends for the specified string column. 
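A convention worth noting across these string-function hunks: when every argument is a column (or column name), the call switches to _invoke_function_over_columns, whereas functions that also take plain Python values (a pad string, an integer length, a charset name) keep _invoke_function and convert only the column argument themselves. The two definitions below are reproduced from this file purely for contrast; they assume the helper sketch shown earlier plus Column and _to_java_column from pyspark.sql.column:

def trim(col: "ColumnOrName") -> Column:
    # every argument is a Column or a column name -> uniform conversion
    return _invoke_function_over_columns("trim", col)


def lpad(col: "ColumnOrName", len: int, pad: str) -> Column:
    # mixed Column + plain Python arguments -> convert only the column
    return _invoke_function("lpad", _to_java_column(col), len, pad)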
""" - return _invoke_function_over_column("trim", col) + return _invoke_function_over_columns("trim", col) def concat_ws(sep: str, *cols: "ColumnOrName") -> Column: @@ -3039,7 +2885,7 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> Column: """ sc = SparkContext._active_spark_context assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.concat_ws(sep, _to_seq(sc, cols, _to_java_column))) + return _invoke_function("concat_ws", sep, _to_seq(sc, cols, _to_java_column)) @since(1.5) @@ -3048,9 +2894,7 @@ def decode(col: "ColumnOrName", charset: str) -> Column: Computes the first argument into a string from a binary using the provided character set (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.decode(_to_java_column(col), charset)) + return _invoke_function("decode", _to_java_column(col), charset) @since(1.5) @@ -3059,9 +2903,7 @@ def encode(col: "ColumnOrName", charset: str) -> Column: Computes the first argument into a binary from a string using the provided character set (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.encode(_to_java_column(col), charset)) + return _invoke_function("encode", _to_java_column(col), charset) def format_number(col: "ColumnOrName", d: int) -> Column: @@ -3081,9 +2923,7 @@ def format_number(col: "ColumnOrName", d: int) -> Column: >>> spark.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() [Row(v='5.0000')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) + return _invoke_function("format_number", _to_java_column(col), d) def format_string(format: str, *cols: "ColumnOrName") -> Column: @@ -3107,7 +2947,7 @@ def format_string(format: str, *cols: "ColumnOrName") -> Column: """ sc = SparkContext._active_spark_context assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.format_string(format, _to_seq(sc, cols, _to_java_column))) + return _invoke_function("format_string", format, _to_seq(sc, cols, _to_java_column)) def instr(str: "ColumnOrName", substr: str) -> Column: @@ -3126,9 +2966,7 @@ def instr(str: "ColumnOrName", substr: str) -> Column: >>> df.select(instr(df.s, 'b').alias('s')).collect() [Row(s=2)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.instr(_to_java_column(str), substr)) + return _invoke_function("instr", _to_java_column(str), substr) def overlay( @@ -3177,12 +3015,7 @@ def overlay( pos = _create_column_from_literal(pos) if isinstance(pos, int) else _to_java_column(pos) len = _create_column_from_literal(len) if isinstance(len, int) else _to_java_column(len) - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - - return Column( - sc._jvm.functions.overlay(_to_java_column(src), _to_java_column(replace), pos, len) - ) + return _invoke_function("overlay", _to_java_column(src), _to_java_column(replace), pos, len) def sentences( @@ -3220,13 +3053,7 @@ def sentences( if country is None: country = lit("") - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column( - 
sc._jvm.functions.sentences( - _to_java_column(string), _to_java_column(language), _to_java_column(country) - ) - ) + return _invoke_function_over_columns("sentences", string, language, country) def substring(str: "ColumnOrName", pos: int, len: int) -> Column: @@ -3247,9 +3074,7 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: >>> df.select(substring(df.s, 1, 2).alias('s')).collect() [Row(s='ab')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len)) + return _invoke_function("substring", _to_java_column(str), pos, len) def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column: @@ -3269,9 +3094,7 @@ def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column: >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect() [Row(s='b.c.d')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count)) + return _invoke_function("substring_index", _to_java_column(str), delim, count) def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column: @@ -3285,10 +3108,7 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> df0.select(levenshtein('l', 'r').alias('d')).collect() [Row(d=3)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right)) - return Column(jc) + return _invoke_function_over_columns("levenshtein", left, right) def locate(substr: str, str: "ColumnOrName", pos: int = 1) -> Column: @@ -3317,9 +3137,7 @@ def locate(substr: str, str: "ColumnOrName", pos: int = 1) -> Column: >>> df.select(locate('b', df.s, 1).alias('s')).collect() [Row(s=2)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.locate(substr, _to_java_column(str), pos)) + return _invoke_function("locate", substr, _to_java_column(str), pos) def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: @@ -3334,9 +3152,7 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: >>> df.select(lpad(df.s, 6, '#').alias('s')).collect() [Row(s='##abcd')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.lpad(_to_java_column(col), len, pad)) + return _invoke_function("lpad", _to_java_column(col), len, pad) def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: @@ -3351,9 +3167,7 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() [Row(s='abcd##')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.rpad(_to_java_column(col), len, pad)) + return _invoke_function("rpad", _to_java_column(col), len, pad) def repeat(col: "ColumnOrName", n: int) -> Column: @@ -3368,9 +3182,7 @@ def repeat(col: "ColumnOrName", n: int) -> Column: >>> df.select(repeat(df.s, 3).alias('s')).collect() [Row(s='ababab')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.repeat(_to_java_column(col), n)) + return _invoke_function("repeat", _to_java_column(col), n) def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: @@ -3406,9 
+3218,7 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect() [Row(s=['one', 'two', 'three', ''])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.split(_to_java_column(str), pattern, limit)) + return _invoke_function("split", _to_java_column(str), pattern, limit) def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: @@ -3429,10 +3239,7 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() [Row(d='')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) - return Column(jc) + return _invoke_function("regexp_extract", _to_java_column(str), pattern, idx) def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Column: @@ -3446,10 +3253,7 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Colum >>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() [Row(d='-----')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement) - return Column(jc) + return _invoke_function("regexp_replace", _to_java_column(str), pattern, replacement) def initcap(col: "ColumnOrName") -> Column: @@ -3462,9 +3266,7 @@ def initcap(col: "ColumnOrName") -> Column: >>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect() [Row(v='Ab Cd')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.initcap(_to_java_column(col))) + return _invoke_function_over_columns("initcap", col) def soundex(col: "ColumnOrName") -> Column: @@ -3479,9 +3281,7 @@ def soundex(col: "ColumnOrName") -> Column: >>> df.select(soundex(df.name).alias("soundex")).collect() [Row(soundex='P362'), Row(soundex='U612')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.soundex(_to_java_column(col))) + return _invoke_function_over_columns("soundex", col) def bin(col: "ColumnOrName") -> Column: @@ -3494,10 +3294,7 @@ def bin(col: "ColumnOrName") -> Column: >>> df.select(bin(df.age).alias('c')).collect() [Row(c='10'), Row(c='101')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.bin(_to_java_column(col)) - return Column(jc) + return _invoke_function_over_columns("bin", col) def hex(col: "ColumnOrName") -> Column: @@ -3512,10 +3309,7 @@ def hex(col: "ColumnOrName") -> Column: >>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() [Row(hex(a)='414243', hex(b)='3')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.hex(_to_java_column(col)) - return Column(jc) + return _invoke_function_over_columns("hex", col) def unhex(col: "ColumnOrName") -> Column: @@ -3529,9 +3323,7 @@ def unhex(col: "ColumnOrName") -> Column: >>> spark.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() [Row(unhex(a)=bytearray(b'ABC'))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return 
Column(sc._jvm.functions.unhex(_to_java_column(col))) + return _invoke_function_over_columns("unhex", col) def length(col: "ColumnOrName") -> Column: @@ -3546,9 +3338,7 @@ def length(col: "ColumnOrName") -> Column: >>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect() [Row(length=4)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.length(_to_java_column(col))) + return _invoke_function_over_columns("length", col) def octet_length(col: "ColumnOrName") -> Column: @@ -3574,7 +3364,7 @@ def octet_length(col: "ColumnOrName") -> Column: ... .select(octet_length('cat')).collect() [Row(octet_length(cat)=3), Row(octet_length(cat)=4)] """ - return _invoke_function_over_column("octet_length", col) + return _invoke_function_over_columns("octet_length", col) def bit_length(col: "ColumnOrName") -> Column: @@ -3600,7 +3390,7 @@ def bit_length(col: "ColumnOrName") -> Column: ... .select(bit_length('cat')).collect() [Row(bit_length(cat)=24), Row(bit_length(cat)=32)] """ - return _invoke_function_over_column("bit_length", col) + return _invoke_function_over_columns("bit_length", col) def translate(srcCol: "ColumnOrName", matching: str, replace: str) -> Column: @@ -3617,9 +3407,7 @@ def translate(srcCol: "ColumnOrName", matching: str, replace: str) -> Column: ... .alias('r')).collect() [Row(r='1a2s3ae')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.translate(_to_java_column(srcCol), matching, replace)) + return _invoke_function("translate", _to_java_column(srcCol), matching, replace) # ---------------------- Collection functions ------------------------------ @@ -3655,12 +3443,9 @@ def create_map( >>> df.select(create_map([df.name, df.age]).alias("map")).collect() [Row(map={'Alice': 2}), Row(map={'Bob': 5})] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if len(cols) == 1 and isinstance(cols[0], (list, set)): cols = cols[0] # type: ignore[assignment] - jc = sc._jvm.functions.map(_to_seq(sc, cols, _to_java_column)) # type: ignore[arg-type] - return Column(jc) + return _invoke_function_over_seq_of_columns("map", cols) # type: ignore[arg-type] def map_from_arrays(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: @@ -3685,9 +3470,7 @@ def map_from_arrays(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: |{2 -> a, 5 -> b}| +----------------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2))) + return _invoke_function_over_columns("map_from_arrays", col1, col2) @overload @@ -3720,12 +3503,9 @@ def array( >>> df.select(array([df.age, df.age]).alias("arr")).collect() [Row(arr=[2, 2]), Row(arr=[5, 5])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if len(cols) == 1 and isinstance(cols[0], (list, set)): cols = cols[0] # type: ignore[assignment] - jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column)) # type: ignore[arg-type] - return Column(jc) + return _invoke_function_over_seq_of_columns("array", cols) # type: ignore[arg-type] def array_contains(col: "ColumnOrName", value: Any) -> Column: @@ -3750,10 +3530,8 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: >>> df.select(array_contains(df.data, lit("a"))).collect() [Row(array_contains(data, a)=True), 
Row(array_contains(data, a)=False)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None value = value._jc if isinstance(value, Column) else value - return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) + return _invoke_function("array_contains", _to_java_column(col), value) def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: @@ -3770,9 +3548,7 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() [Row(overlap=True), Row(overlap=False)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2))) + return _invoke_function_over_columns("arrays_overlap", a1, a2) def slice( @@ -3799,19 +3575,10 @@ def slice( >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() [Row(sliced=[2, 3]), Row(sliced=[5])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - start = lit(start) if isinstance(start, int) else start length = lit(length) if isinstance(length, int) else length - return Column( - sc._jvm.functions.slice( - _to_java_column(x), - _to_java_column(start), - _to_java_column(length), - ) - ) + return _invoke_function_over_columns("slice", x, start, length) def array_join( @@ -3834,11 +3601,9 @@ def array_join( sc = SparkContext._active_spark_context assert sc is not None and sc._jvm is not None if null_replacement is None: - return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter)) + return _invoke_function("array_join", _to_java_column(col), delimiter) else: - return Column( - sc._jvm.functions.array_join(_to_java_column(col), delimiter, null_replacement) - ) + return _invoke_function("array_join", _to_java_column(col), delimiter, null_replacement) def concat(*cols: "ColumnOrName") -> Column: @@ -3858,9 +3623,7 @@ def concat(*cols: "ColumnOrName") -> Column: >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect() [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) + return _invoke_function_over_seq_of_columns("concat", cols) def array_position(col: "ColumnOrName", value: Any) -> Column: @@ -3881,9 +3644,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: >>> df.select(array_position(df.data, "a")).collect() [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) + return _invoke_function("array_position", _to_java_column(col), value) def element_at(col: "ColumnOrName", extraction: Any) -> Column: @@ -3906,17 +3667,15 @@ def element_at(col: "ColumnOrName", extraction: Any) -> Column: Examples -------- - >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df = spark.createDataFrame([(["a", "b", "c"],)], ['data']) >>> df.select(element_at(df.data, 1)).collect() - [Row(element_at(data, 1)='a'), Row(element_at(data, 1)=None)] + [Row(element_at(data, 1)='a')] - >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) + >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},)], ['data']) >>> df.select(element_at(df.data, lit("a"))).collect() - 
[Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)] + [Row(element_at(data, a)=1.0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.element_at(_to_java_column(col), lit(extraction)._jc)) + return _invoke_function_over_columns("element_at", col, lit(extraction)) def array_remove(col: "ColumnOrName", element: Any) -> Column: @@ -3938,9 +3697,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: >>> df.select(array_remove(df.data, 1)).collect() [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) + return _invoke_function("array_remove", _to_java_column(col), element) def array_distinct(col: "ColumnOrName") -> Column: @@ -3960,9 +3717,7 @@ def array_distinct(col: "ColumnOrName") -> Column: >>> df.select(array_distinct(df.data)).collect() [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) + return _invoke_function_over_columns("array_distinct", col) def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: @@ -3986,9 +3741,7 @@ def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df.select(array_intersect(df.c1, df.c2)).collect() [Row(array_intersect(c1, c2)=['a', 'c'])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2))) + return _invoke_function_over_columns("array_intersect", col1, col2) def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: @@ -4012,9 +3765,7 @@ def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df.select(array_union(df.c1, df.c2)).collect() [Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) + return _invoke_function_over_columns("array_union", col1, col2) def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: @@ -4038,9 +3789,7 @@ def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df.select(array_except(df.c1, df.c2)).collect() [Row(array_except(c1, c2)=['b'])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.array_except(_to_java_column(col1), _to_java_column(col2))) + return _invoke_function_over_columns("array_except", col1, col2) def explode(col: "ColumnOrName") -> Column: @@ -4065,10 +3814,7 @@ def explode(col: "ColumnOrName") -> Column: | a| b| +---+-----+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.explode(_to_java_column(col)) - return Column(jc) + return _invoke_function_over_columns("explode", col) def posexplode(col: "ColumnOrName") -> Column: @@ -4093,10 +3839,7 @@ def posexplode(col: "ColumnOrName") -> Column: | 0| a| b| +---+---+-----+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.posexplode(_to_java_column(col)) - return 
Column(jc) + return _invoke_function_over_columns("posexplode", col) def explode_outer(col: "ColumnOrName") -> Column: @@ -4133,10 +3876,7 @@ def explode_outer(col: "ColumnOrName") -> Column: | 3| null|null| +---+----------+----+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.explode_outer(_to_java_column(col)) - return Column(jc) + return _invoke_function_over_columns("explode_outer", col) def posexplode_outer(col: "ColumnOrName") -> Column: @@ -4172,10 +3912,7 @@ def posexplode_outer(col: "ColumnOrName") -> Column: | 3| null|null|null| +---+----------+----+----+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.posexplode_outer(_to_java_column(col)) - return Column(jc) + return _invoke_function_over_columns("posexplode_outer", col) def get_json_object(col: "ColumnOrName", path: str) -> Column: @@ -4200,10 +3937,7 @@ def get_json_object(col: "ColumnOrName", path: str) -> Column: ... get_json_object(df.jstring, '$.f2').alias("c1") ).collect() [Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) - return Column(jc) + return _invoke_function("get_json_object", _to_java_column(col), path) def json_tuple(col: "ColumnOrName", *fields: str) -> Column: @@ -4227,8 +3961,7 @@ def json_tuple(col: "ColumnOrName", *fields: str) -> Column: """ sc = SparkContext._active_spark_context assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.json_tuple(_to_java_column(col), _to_seq(sc, fields)) - return Column(jc) + return _invoke_function("json_tuple", _to_java_column(col), _to_seq(sc, fields)) def from_json( @@ -4284,14 +4017,11 @@ def from_json( [Row(json=[1, 2, 3])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if isinstance(schema, DataType): schema = schema.json() elif isinstance(schema, Column): schema = _to_java_column(schema) - jc = sc._jvm.functions.from_json(_to_java_column(col), schema, _options_to_str(options)) - return Column(jc) + return _invoke_function("from_json", _to_java_column(col), schema, _options_to_str(options)) def to_json(col: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Column: @@ -4340,10 +4070,7 @@ def to_json(col: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Co [Row(json='["Alice","Bob"]')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.to_json(_to_java_column(col), _options_to_str(options)) - return Column(jc) + return _invoke_function("to_json", _to_java_column(col), _options_to_str(options)) def schema_of_json(json: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Column: @@ -4370,10 +4097,10 @@ def schema_of_json(json: "ColumnOrName", options: Optional[Dict[str, str]] = Non -------- >>> df = spark.range(1) >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect() - [Row(json='STRUCT<`a`: BIGINT>')] + [Row(json='STRUCT<a: BIGINT>')] >>> schema = schema_of_json('{a: 1}', {'allowUnquotedFieldNames':'true'}) >>> df.select(schema.alias("json")).collect() - [Row(json='STRUCT<`a`: BIGINT>')] + [Row(json='STRUCT<a: BIGINT>')] """ if isinstance(json, str): col = _create_column_from_literal(json) @@ -4382,10 +4109,7 @@ def schema_of_json(json: "ColumnOrName", options: Optional[Dict[str, str]]
= Non else: raise TypeError("schema argument should be a column or string") - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.schema_of_json(col, _options_to_str(options)) - return Column(jc) + return _invoke_function("schema_of_json", col, _options_to_str(options)) def schema_of_csv(csv: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Column: @@ -4409,9 +4133,9 @@ def schema_of_csv(csv: "ColumnOrName", options: Optional[Dict[str, str]] = None) -------- >>> df = spark.range(1) >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect() - [Row(csv='STRUCT<`_c0`: INT, `_c1`: STRING>')] + [Row(csv='STRUCT<_c0: INT, _c1: STRING>')] >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect() - [Row(csv='STRUCT<`_c0`: INT, `_c1`: STRING>')] + [Row(csv='STRUCT<_c0: INT, _c1: STRING>')] """ if isinstance(csv, str): col = _create_column_from_literal(csv) @@ -4420,10 +4144,7 @@ def schema_of_csv(csv: "ColumnOrName", options: Optional[Dict[str, str]] = None) else: raise TypeError("schema argument should be a column or string") - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.schema_of_csv(col, _options_to_str(options)) - return Column(jc) + return _invoke_function("schema_of_csv", col, _options_to_str(options)) def to_csv(col: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Column: @@ -4453,10 +4174,7 @@ def to_csv(col: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Col [Row(csv='2,Alice')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - jc = sc._jvm.functions.to_csv(_to_java_column(col), _options_to_str(options)) - return Column(jc) + return _invoke_function("to_csv", _to_java_column(col), _options_to_str(options)) def size(col: "ColumnOrName") -> Column: @@ -4476,9 +4194,7 @@ def size(col: "ColumnOrName") -> Column: >>> df.select(size(df.data)).collect() [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.size(_to_java_column(col))) + return _invoke_function_over_columns("size", col) def array_min(col: "ColumnOrName") -> Column: @@ -4498,9 +4214,7 @@ def array_min(col: "ColumnOrName") -> Column: >>> df.select(array_min(df.data).alias('min')).collect() [Row(min=1), Row(min=-1)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.array_min(_to_java_column(col))) + return _invoke_function_over_columns("array_min", col) def array_max(col: "ColumnOrName") -> Column: @@ -4520,9 +4234,7 @@ def array_max(col: "ColumnOrName") -> Column: >>> df.select(array_max(df.data).alias('max')).collect() [Row(max=3), Row(max=10)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.array_max(_to_java_column(col))) + return _invoke_function_over_columns("array_max", col) def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: @@ -4548,9 +4260,7 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) 
+ return _invoke_function("sort_array", _to_java_column(col), asc) def array_sort(col: "ColumnOrName") -> Column: @@ -4571,9 +4281,7 @@ def array_sort(col: "ColumnOrName") -> Column: >>> df.select(array_sort(df.data).alias('r')).collect() [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.array_sort(_to_java_column(col))) + return _invoke_function_over_columns("array_sort", col) def shuffle(col: "ColumnOrName") -> Column: @@ -4597,9 +4305,7 @@ def shuffle(col: "ColumnOrName") -> Column: >>> df.select(shuffle(df.data).alias('s')).collect() # doctest: +SKIP [Row(s=[3, 1, 5, 20]), Row(s=[20, None, 3, 1])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.shuffle(_to_java_column(col))) + return _invoke_function_over_columns("shuffle", col) def reverse(col: "ColumnOrName") -> Column: @@ -4622,9 +4328,7 @@ def reverse(col: "ColumnOrName") -> Column: >>> df.select(reverse(df.data).alias('r')).collect() [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.reverse(_to_java_column(col))) + return _invoke_function_over_columns("reverse", col) def flatten(col: "ColumnOrName") -> Column: @@ -4646,9 +4350,7 @@ def flatten(col: "ColumnOrName") -> Column: >>> df.select(flatten(df.data).alias('r')).collect() [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.flatten(_to_java_column(col))) + return _invoke_function_over_columns("flatten", col) def map_keys(col: "ColumnOrName") -> Column: @@ -4673,9 +4375,7 @@ def map_keys(col: "ColumnOrName") -> Column: |[1, 2]| +------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.map_keys(_to_java_column(col))) + return _invoke_function_over_columns("map_keys", col) def map_values(col: "ColumnOrName") -> Column: @@ -4700,9 +4400,7 @@ def map_values(col: "ColumnOrName") -> Column: |[a, b]| +------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.map_values(_to_java_column(col))) + return _invoke_function_over_columns("map_values", col) def map_entries(col: "ColumnOrName") -> Column: @@ -4727,9 +4425,7 @@ def map_entries(col: "ColumnOrName") -> Column: |[{1, a}, {2, b}]| +----------------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.map_entries(_to_java_column(col))) + return _invoke_function_over_columns("map_entries", col) def map_from_entries(col: "ColumnOrName") -> Column: @@ -4754,9 +4450,7 @@ def map_from_entries(col: "ColumnOrName") -> Column: |{1 -> a, 2 -> b}| +----------------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.map_from_entries(_to_java_column(col))) + return _invoke_function_over_columns("map_from_entries", col) def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column: @@ -4778,12 +4472,9 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu >>> df.select(array_repeat(df.data, 3).alias('r')).collect() [Row(r=['ab', 'ab', 'ab'])] """ - sc = 
SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - count = lit(count) if isinstance(count, int) else count - return Column(sc._jvm.functions.array_repeat(_to_java_column(col), _to_java_column(count))) + return _invoke_function_over_columns("array_repeat", col, count) def arrays_zip(*cols: "ColumnOrName") -> Column: @@ -4805,9 +4496,7 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: >>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect() [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) + return _invoke_function_over_seq_of_columns("arrays_zip", cols) @overload @@ -4843,12 +4532,9 @@ def map_concat( |{1 -> a, 2 -> b, 3 -> c}| +------------------------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if len(cols) == 1 and isinstance(cols[0], (list, set)): cols = cols[0] # type: ignore[assignment] - jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column)) # type: ignore[arg-type] - return Column(jc) + return _invoke_function_over_seq_of_columns("map_concat", cols) # type: ignore[arg-type] def sequence( @@ -4870,16 +4556,10 @@ def sequence( >>> df2.select(sequence('C1', 'C2', 'C3').alias('r')).collect() [Row(r=[4, 2, 0, -2, -4])] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None if step is None: - return Column(sc._jvm.functions.sequence(_to_java_column(start), _to_java_column(stop))) + return _invoke_function_over_columns("sequence", start, stop) else: - return Column( - sc._jvm.functions.sequence( - _to_java_column(start), _to_java_column(stop), _to_java_column(step) - ) - ) + return _invoke_function_over_columns("sequence", start, stop, step) def from_csv( @@ -4931,8 +4611,7 @@ def from_csv( else: raise TypeError("schema argument should be a column or string") - jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, _options_to_str(options)) - return Column(jc) + return _invoke_function("from_csv", _to_java_column(col), schema, _options_to_str(options)) def _unresolved_named_lambda_variable(*name_parts: Any) -> Column: @@ -5008,8 +4687,8 @@ def _create_lambda(f: Callable) -> Callable: if not isinstance(result, Column): raise ValueError("f should return Column, got {}".format(type(result))) - jexpr = result._jc.expr() # type: ignore[operator] - jargs = _to_seq(sc, [arg._jc.expr() for arg in args]) # type: ignore[operator] + jexpr = result._jc.expr() + jargs = _to_seq(sc, [arg._jc.expr() for arg in args]) return expressions.LambdaFunction(jexpr, jargs, False) @@ -5532,9 +5211,7 @@ def years(col: "ColumnOrName") -> Column: method of the `DataFrameWriterV2`. """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.years(_to_java_column(col))) + return _invoke_function_over_columns("years", col) def months(col: "ColumnOrName") -> Column: @@ -5557,9 +5234,7 @@ def months(col: "ColumnOrName") -> Column: method of the `DataFrameWriterV2`. 
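years, months, days, hours (and bucket further down) are partition transform expressions: as the surrounding docstrings say, they are only valid inside the partitionedBy() method of DataFrameWriterV2, not in a select. A hedged usage sketch follows; the table name "catalog.db.events" is hypothetical, and the write only succeeds against a DataSource V2 catalog that supports transform-based partitioning:

from pyspark.sql import SparkSession
from pyspark.sql.functions import bucket, current_timestamp, days

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(10).withColumn("ts", current_timestamp())

# "catalog.db.events" is a hypothetical table name; a V2 catalog with
# transform partitioning support must be configured for this to run.
(df.writeTo("catalog.db.events")
    .partitionedBy(days("ts"), bucket(8, "id"))
    .createOrReplace())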
""" - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.months(_to_java_column(col))) + return _invoke_function_over_columns("months", col) def days(col: "ColumnOrName") -> Column: @@ -5582,9 +5257,7 @@ def days(col: "ColumnOrName") -> Column: method of the `DataFrameWriterV2`. """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.days(_to_java_column(col))) + return _invoke_function_over_columns("days", col) def hours(col: "ColumnOrName") -> Column: @@ -5607,9 +5280,7 @@ def hours(col: "ColumnOrName") -> Column: method of the `DataFrameWriterV2`. """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return Column(sc._jvm.functions.hours(_to_java_column(col))) + return _invoke_function_over_columns("hours", col) def bucket(numBuckets: Union[Column, int], col: "ColumnOrName") -> Column: @@ -5642,7 +5313,7 @@ def bucket(numBuckets: Union[Column, int], col: "ColumnOrName") -> Column: if isinstance(numBuckets, int) else _to_java_column(numBuckets) ) - return Column(sc._jvm.functions.bucket(numBuckets, _to_java_column(col))) + return _invoke_function("bucket", numBuckets, _to_java_column(col)) # ---------------------------- User Defined Function ---------------------------------- diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 485e01776f872..bece13684e087 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,10 +19,10 @@ from typing import Callable, List, Optional, TYPE_CHECKING, overload, Dict, Union, cast, Tuple -from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_gateway import JavaObject from pyspark.sql.column import Column, _to_seq -from pyspark.sql.context import SQLContext +from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin from pyspark.sql.types import StructType, StructField, IntegerType, StringType @@ -37,7 +37,7 @@ def dfapi(f: Callable) -> Callable: def _api(self: "GroupedData") -> DataFrame: name = f.__name__ jdf = getattr(self._jgd, name)() - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.session) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -47,8 +47,8 @@ def _api(self: "GroupedData") -> DataFrame: def df_varargs_api(f: Callable) -> Callable: def _api(self: "GroupedData", *cols: str) -> DataFrame: name = f.__name__ - jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, cols)) - return DataFrame(jdf, self.sql_ctx) + jdf = getattr(self._jgd, name)(_to_seq(self.session._sc, cols)) + return DataFrame(jdf, self.session) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -66,7 +66,7 @@ class GroupedData(PandasGroupedOpsMixin): def __init__(self, jgd: JavaObject, df: DataFrame): self._jgd = jgd self._df = df - self.sql_ctx: SQLContext = df.sql_ctx + self.session: SparkSession = df.sparkSession @overload def agg(self, *exprs: Column) -> DataFrame: @@ -134,8 +134,8 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" exprs = cast(Tuple[Column, ...], exprs) - jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) - return DataFrame(jdf, self.sql_ctx) + jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.session._sc, [c._jc for c in exprs[1:]])) + return 
DataFrame(jdf, self.session) @dfapi def count(self) -> DataFrame: diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py index e5d426ab4c61e..48b3d96a45ae6 100644 --- a/python/pyspark/sql/observation.py +++ b/python/pyspark/sql/observation.py @@ -16,7 +16,7 @@ # from typing import Any, Dict, Optional -from py4j.java_gateway import JavaObject, JVMView # type: ignore[import] +from py4j.java_gateway import JavaObject, JVMView from pyspark.sql import column from pyspark.sql.column import Column @@ -109,7 +109,7 @@ def _on(self, df: DataFrame, *exprs: Column) -> DataFrame: observed_df = self._jo.on( df._jdf, exprs[0]._jc, column._to_seq(df._sc, [c._jc for c in exprs[1:]]) ) - return DataFrame(observed_df, df.sql_ctx) + return DataFrame(observed_df, df.sparkSession) @property def get(self) -> Dict[str, Any]: diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index d3796f48066de..6ecd04f057e02 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -33,7 +33,7 @@ from pyspark.sql._typing import LiteralType from pandas.core.frame import DataFrame as PandasDataFrame from pandas.core.series import Series as PandasSeries -import pyarrow # type: ignore[import] +import pyarrow DataFrameLike = PandasDataFrame SeriesLike = PandasSeries @@ -42,11 +42,12 @@ DataFrameOrSeriesLike_ = TypeVar("DataFrameOrSeriesLike_", bound=DataFrameOrSeri # UDF annotations PandasScalarUDFType = Literal[200] -PandasScalarIterUDFType = Literal[204] PandasGroupedMapUDFType = Literal[201] -PandasCogroupedMapUDFType = Literal[206] PandasGroupedAggUDFType = Literal[202] +PandasWindowAggUDFType = Literal[203] +PandasScalarIterUDFType = Literal[204] PandasMapIterUDFType = Literal[205] +PandasCogroupedMapUDFType = Literal[206] ArrowMapIterUDFType = Literal[207] class PandasVariadicScalarToScalarFunction(Protocol): diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 33a405838cc91..7153450d2bc4f 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -19,7 +19,7 @@ from collections import Counter from typing import List, Optional, Type, Union, no_type_check, overload, TYPE_CHECKING -from pyspark.rdd import _load_from_socket # type: ignore[attr-defined] +from pyspark.rdd import _load_from_socket from pyspark.sql.pandas.serializers import ArrowCollectSerializer from pyspark.sql.types import ( IntegralType, @@ -43,6 +43,7 @@ if TYPE_CHECKING: import numpy as np import pyarrow as pa + from py4j.java_gateway import JavaObject from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike from pyspark.sql import DataFrame @@ -88,9 +89,10 @@ def toPandas(self) -> "PandasDataFrameLike": import pandas as pd from pandas.core.dtypes.common import is_timedelta64_dtype - timezone = self.sql_ctx._conf.sessionLocalTimeZone() # type: ignore[attr-defined] + jconf = self.sparkSession._jconf + timezone = jconf.sessionLocalTimeZone() - if self.sql_ctx._conf.arrowPySparkEnabled(): # type: ignore[attr-defined] + if jconf.arrowPySparkEnabled(): use_arrow = True try: from pyspark.sql.pandas.types import to_arrow_schema @@ -100,7 +102,7 @@ def toPandas(self) -> "PandasDataFrameLike": to_arrow_schema(self.schema) except Exception as e: - if self.sql_ctx._conf.arrowPySparkFallbackEnabled(): # type: ignore[attr-defined] + if jconf.arrowPySparkFallbackEnabled(): msg = ( "toPandas attempted Arrow optimization 
because " "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, " @@ -134,10 +136,8 @@ def toPandas(self) -> "PandasDataFrameLike": # Rename columns to avoid duplicated column names. tmp_column_names = ["col_{}".format(i) for i in range(len(self.columns))] - c = self.sql_ctx._conf - self_destruct = ( - c.arrowPySparkSelfDestructEnabled() # type: ignore[attr-defined] - ) + c = self.sparkSession._jconf + self_destruct = c.arrowPySparkSelfDestructEnabled() batches = self.toDF(*tmp_column_names)._collect_as_arrow( split_batches=self_destruct ) @@ -322,7 +322,7 @@ def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch port, auth_secret, jsocket_auth_server, - ) = self._jdf.collectAsArrowToPython() # type: ignore[operator] + ) = self._jdf.collectAsArrowToPython() # Collect list of un-ordered batches where last element is a list of correct order indices try: @@ -368,6 +368,8 @@ class SparkConversionMixin: can use this class. """ + _jsparkSession: "JavaObject" + @overload def createDataFrame( self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ... @@ -398,20 +400,17 @@ def createDataFrame( # type: ignore[misc] require_minimum_pandas_version() - timezone = self._wrapped._conf.sessionLocalTimeZone() # type: ignore[attr-defined] + timezone = self._jconf.sessionLocalTimeZone() # If no schema supplied by user then get the names of columns only if schema is None: schema = [str(x) if not isinstance(x, str) else x for x in data.columns] - if ( - self._wrapped._conf.arrowPySparkEnabled() # type: ignore[attr-defined] - and len(data) > 0 - ): + if self._jconf.arrowPySparkEnabled() and len(data) > 0: try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: - if self._wrapped._conf.arrowPySparkFallbackEnabled(): # type: ignore[attr-defined] + if self._jconf.arrowPySparkFallbackEnabled(): msg = ( "createDataFrame attempted Arrow optimization because " "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, " @@ -603,25 +602,25 @@ def _create_from_pandas_with_arrow( for pdf_slice in pdf_slices ] - jsqlContext = self._wrapped._jsqlContext # type: ignore[attr-defined] + jsparkSession = self._jsparkSession - safecheck = self._wrapped._conf.arrowSafeTypeConversion() # type: ignore[attr-defined] + safecheck = self._jconf.arrowSafeTypeConversion() col_by_name = True # col by name only applies to StructType columns, can't happen here ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name) @no_type_check def reader_func(temp_filename): - return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename) + return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsparkSession, temp_filename) @no_type_check def create_RDD_server(): - return self._jvm.ArrowRDDServer(jsqlContext) + return self._jvm.ArrowRDDServer(jsparkSession) # Create Spark DataFrame from Arrow stream file, using one batch per partition jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server) assert self._jvm is not None - jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext) - df = DataFrame(jdf, self._wrapped) + jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsparkSession) + df = DataFrame(jdf, self) df._schema = schema return df diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 5f502209ae608..94fabdbb29590 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ 
-17,7 +17,8 @@ import functools import warnings -from inspect import getfullargspec +from inspect import getfullargspec, signature +from typing import get_type_hints from pyspark.rdd import PythonEvalType from pyspark.sql.pandas.typehints import infer_eval_type @@ -317,15 +318,15 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: # | string| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| 'a'| X| X| X| X| X| 'A'| X| # noqa # | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| X| X| # noqa # | array| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| X| X| # noqa - # | map| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa + # | map| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa # | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa # | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| X| X|bytearray(b'')| bytearray(b'')| bytearray(b'')|bytearray(b'A')| bytearray(b'')| # noqa - # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+---------------+--------------------------------+ # noqa # + # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+---------------+--------------------------------+ # noqa # # Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be # used in `returnType`. # Note: The values inside of the table are generated by `repr`. - # Note: Python 3.7.3, Pandas 1.1.1 and PyArrow 1.0.1 are used. + # Note: Python 3.9.5, Pandas 1.4.0 and PyArrow 6.0.1 are used. # Note: Timezone is KST. # Note: 'X' means it throws an exception during the conversion. require_minimum_pandas_version() @@ -385,8 +386,6 @@ def _create_pandas_udf(f, returnType, evalType): argspec = getfullargspec(f) # pandas UDF by type hints. - from inspect import signature - if evalType in [ PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, @@ -410,7 +409,11 @@ def _create_pandas_udf(f, returnType, evalType): # 'SQL_COGROUPED_MAP_PANDAS_UDF', the evaluation type will always be set. 
pass elif len(argspec.annotations) > 0: - evalType = infer_eval_type(signature(f)) + try: + type_hints = get_type_hints(f) + except NameError: + type_hints = {} + evalType = infer_eval_type(signature(f), type_hints) assert evalType is not None if evalType is None: diff --git a/python/pyspark/sql/pandas/functions.pyi b/python/pyspark/sql/pandas/functions.pyi index 7ff06be915137..1af6f8625935e 100644 --- a/python/pyspark/sql/pandas/functions.pyi +++ b/python/pyspark/sql/pandas/functions.pyi @@ -65,11 +65,17 @@ def pandas_udf( functionType: PandasScalarUDFType, ) -> UserDefinedFunctionLike: ... @overload -def pandas_udf(f: Union[StructType, str], returnType: PandasScalarUDFType) -> Callable[[PandasScalarToStructFunction], UserDefinedFunctionLike]: ... # type: ignore[misc] +def pandas_udf( + f: Union[StructType, str], returnType: PandasScalarUDFType +) -> Callable[[PandasScalarToStructFunction], UserDefinedFunctionLike]: ... @overload -def pandas_udf(f: Union[StructType, str], *, functionType: PandasScalarUDFType) -> Callable[[PandasScalarToStructFunction], UserDefinedFunctionLike]: ... # type: ignore[misc] +def pandas_udf( + f: Union[StructType, str], *, functionType: PandasScalarUDFType +) -> Callable[[PandasScalarToStructFunction], UserDefinedFunctionLike]: ... @overload -def pandas_udf(*, returnType: Union[StructType, str], functionType: PandasScalarUDFType) -> Callable[[PandasScalarToStructFunction], UserDefinedFunctionLike]: ... # type: ignore[misc] +def pandas_udf( + *, returnType: Union[StructType, str], functionType: PandasScalarUDFType +) -> Callable[[PandasScalarToStructFunction], UserDefinedFunctionLike]: ... @overload def pandas_udf( f: PandasScalarIterFunction, diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 35f531f5c4d0d..6178433573e9e 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -213,8 +213,8 @@ def applyInPandas( udf = pandas_udf(func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP) df = self._df udf_column = udf(*[df[col] for col in df.columns]) - jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) # type: ignore[attr-defined] - return DataFrame(jdf, self.sql_ctx) + jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) + return DataFrame(jdf, self.session) def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps": """ @@ -246,7 +246,6 @@ class PandasCogroupedOps: def __init__(self, gd1: "GroupedData", gd2: "GroupedData"): self._gd1 = gd1 self._gd2 = gd2 - self.sql_ctx = gd1.sql_ctx def applyInPandas( self, func: "PandasCogroupedMapFunction", schema: Union[StructType, str] @@ -342,10 +341,8 @@ def applyInPandas( all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2) udf_column = udf(*all_cols) - jdf = self._gd1._jgd.flatMapCoGroupsInPandas( # type: ignore[attr-defined] - self._gd2._jgd, udf_column._jc.expr() # type: ignore[attr-defined] - ) - return DataFrame(jdf, self.sql_ctx) + jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr()) + return DataFrame(jdf, self._gd1.session) @staticmethod def _extract_cols(gd: "GroupedData") -> List[Column]: diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py index c1c29ecbc7576..5f89577a1b6c3 100644 --- a/python/pyspark/sql/pandas/map_ops.py +++ b/python/pyspark/sql/pandas/map_ops.py @@ -89,8 +89,8 @@ def mapInPandas( func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF ) # type: 
ignore[call-overload] udf_column = udf(*[self[col] for col in self.columns]) - jdf = self._jdf.mapInPandas(udf_column._jc.expr()) # type: ignore[operator] - return DataFrame(jdf, self.sql_ctx) + jdf = self._jdf.mapInPandas(udf_column._jc.expr()) + return DataFrame(jdf, self.sparkSession) def mapInArrow( self, func: "ArrowMapIterFunction", schema: Union[StructType, str] @@ -153,7 +153,7 @@ def mapInArrow( ) # type: ignore[call-overload] udf_column = udf(*[self[col] for col in self.columns]) jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr()) - return DataFrame(jdf, self.sql_ctx) + return DataFrame(jdf, self.sparkSession) def _test() -> None: diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py index 167104c1ad7dc..fc3dd89a0712a 100644 --- a/python/pyspark/sql/pandas/typehints.py +++ b/python/pyspark/sql/pandas/typehints.py @@ -15,7 +15,7 @@ # limitations under the License. # from inspect import Signature -from typing import Any, Callable, Optional, Union, TYPE_CHECKING +from typing import Any, Callable, Dict, Optional, Union, TYPE_CHECKING from pyspark.sql.pandas.utils import require_minimum_pandas_version @@ -28,11 +28,11 @@ def infer_eval_type( - sig: Signature, + sig: Signature, type_hints: Dict[str, Any] ) -> Union["PandasScalarUDFType", "PandasScalarIterUDFType", "PandasGroupedAggUDFType"]: """ Infers the evaluation type in :class:`pyspark.rdd.PythonEvalType` from - :class:`inspect.Signature` instance. + :class:`inspect.Signature` instance and type hints. """ from pyspark.sql.pandas.functions import PandasUDFType @@ -43,7 +43,7 @@ def infer_eval_type( annotations = {} for param in sig.parameters.values(): if param.annotation is not param.empty: - annotations[param.name] = param.annotation + annotations[param.name] = type_hints.get(param.name, param.annotation) # Check if all arguments have type hints parameters_sig = [ @@ -53,7 +53,7 @@ def infer_eval_type( raise ValueError("Type hints for all parameters should be specified; however, got %s" % sig) # Check if the return has a type hint - return_annotation = sig.return_annotation + return_annotation = type_hints.get("return", sig.return_annotation) if sig.empty is return_annotation: raise ValueError("Type hint for the return type should be specified; however, got %s" % sig) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index df4a0891dcc71..760e54831c2f0 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -17,7 +17,7 @@ import sys from typing import cast, overload, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union -from py4j.java_gateway import JavaClass, JavaObject # type: ignore[import] +from py4j.java_gateway import JavaClass, JavaObject from pyspark import RDD, since from pyspark.sql.column import _to_seq, _to_java_column, Column @@ -27,7 +27,7 @@ if TYPE_CHECKING: from pyspark.sql._typing import OptionalPrimitiveType, ColumnOrName - from pyspark.sql.context import SQLContext + from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.streaming import StreamingQuery @@ -62,8 +62,8 @@ class DataFrameReader(OptionUtils): .. 
versionadded:: 1.4 """ - def __init__(self, spark: "SQLContext"): - self._jreader = spark._ssql_ctx.read() # type: ignore[attr-defined] + def __init__(self, spark: "SparkSession"): + self._jreader = spark._jsparkSession.read() self._spark = spark def _df(self, jdf: JavaObject) -> "DataFrame": @@ -112,9 +112,7 @@ def schema(self, schema: Union[StructType, str]) -> "DataFrameReader": spark = SparkSession._getActiveSessionOrCreate() if isinstance(schema, StructType): - jschema = spark._jsparkSession.parseDataType( - schema.json() - ) # type: ignore[attr-defined] + jschema = spark._jsparkSession.parseDataType(schema.json()) self._jreader = self._jreader.schema(jschema) elif isinstance(schema, str): self._jreader = self._jreader.schema(schema) @@ -187,7 +185,7 @@ def load( def json( self, - path: Union[str, List[str], "RDD[str]"], + path: Union[str, List[str], RDD[str]], schema: Optional[Union[StructType, str]] = None, primitivesAsString: Optional[Union[bool, str]] = None, prefersDecimal: Optional[Union[bool, str]] = None, @@ -283,11 +281,7 @@ def json( path = [path] if type(path) == list: assert self._spark._sc._jvm is not None - return self._df( - self._jreader.json( - self._spark._sc._jvm.PythonUtils.toSeq(path) # type: ignore[attr-defined] - ) - ) + return self._df(self._jreader.json(self._spark._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): def func(iterator: Iterable) -> Iterable: @@ -301,7 +295,7 @@ def func(iterator: Iterable) -> Iterable: keyed = path.mapPartitions(func) keyed._bypass_serializer = True # type: ignore[attr-defined] assert self._spark._jvm is not None - jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString()) # type: ignore[attr-defined] + jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString()) return self._df(self._jreader.json(jrdd)) else: raise TypeError("path can be only string, list or RDD") @@ -424,11 +418,7 @@ def text( if isinstance(paths, str): paths = [paths] assert self._spark._sc._jvm is not None - return self._df( - self._jreader.text( - self._spark._sc._jvm.PythonUtils.toSeq(paths) # type: ignore[attr-defined] - ) - ) + return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) def csv( self, @@ -560,7 +550,7 @@ def func(iterator): # There aren't any jvm api for creating a dataframe from rdd storing csv. # We can do it through creating a jvm dataset firstly and using the jvm api # for creating a dataframe from dataset storing csv. 
- jdataset = self._spark._ssql_ctx.createDataset( + jdataset = self._spark._jsparkSession.createDataset( jrdd.rdd(), self._spark._jvm.Encoders.STRING() ) return self._df(self._jreader.csv(jdataset)) @@ -737,8 +727,8 @@ class DataFrameWriter(OptionUtils): def __init__(self, df: "DataFrame"): self._df = df - self._spark = df.sql_ctx - self._jwrite = df._jdf.write() # type: ignore[operator] + self._spark = df.sparkSession + self._jwrite = df._jdf.write() def _sq(self, jsq: JavaObject) -> "StreamingQuery": from pyspark.sql.streaming import StreamingQuery @@ -1360,8 +1350,8 @@ class DataFrameWriterV2: def __init__(self, df: "DataFrame", table: str): self._df = df - self._spark = df.sql_ctx - self._jwriter = df._jdf.writeTo(table) # type: ignore[operator] + self._spark = df.sparkSession + self._jwriter = df._jdf.writeTo(table) @since(3.1) def using(self, provider: str) -> "DataFrameWriterV2": diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 58621491dfb9a..8f4809907b599 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -36,7 +36,7 @@ TYPE_CHECKING, ) -from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_gateway import JavaObject from pyspark import SparkConf, SparkContext, since from pyspark.rdd import RDD @@ -230,11 +230,6 @@ def enableHiveSupport(self) -> "SparkSession.Builder": """ return self.config("spark.sql.catalogImplementation", "hive") - def _sparkContext(self, sc: SparkContext) -> "SparkSession.Builder": - with self._lock: - self._sc = sc - return self - def getOrCreate(self) -> "SparkSession": """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a new one based on the options set in this builder. @@ -266,15 +261,12 @@ def getOrCreate(self) -> "SparkSession": from pyspark.conf import SparkConf session = SparkSession._instantiatedSession - if session is None or session._sc._jsc is None: # type: ignore[attr-defined] - if self._sc is not None: - sc = self._sc - else: - sparkConf = SparkConf() - for key, value in self._options.items(): - sparkConf.set(key, value) - # This SparkContext may be an existing one. - sc = SparkContext.getOrCreate(sparkConf) + if session is None or session._sc._jsc is None: + sparkConf = SparkConf() + for key, value in self._options.items(): + sparkConf.set(key, value) + # This SparkContext may be an existing one. + sc = SparkContext.getOrCreate(sparkConf) # Do not update `SparkConf` for existing `SparkContext`, as it's shared # by all sessions. session = SparkSession(sc, options=self._options) @@ -296,8 +288,6 @@ def __init__( jsparkSession: Optional[JavaObject] = None, options: Dict[str, Any] = {}, ): - from pyspark.sql.context import SQLContext - self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm @@ -320,8 +310,6 @@ def __init__( jsparkSession, options ) self._jsparkSession = jsparkSession - self._jwrapped = self._jsparkSession.sqlContext() - self._wrapped = SQLContext(self._sc, self, self._jwrapped) _monkey_patch_RDD(self) install_exception_handler() # If we had an instantiated SparkSession attached with a SparkContext @@ -329,7 +317,7 @@ def __init__( # Otherwise, we will use invalid SparkSession when we call Builder.getOrCreate. 
if ( SparkSession._instantiatedSession is None - or SparkSession._instantiatedSession._sc._jsc is None # type: ignore[attr-defined] + or SparkSession._instantiatedSession._sc._jsc is None ): SparkSession._instantiatedSession = self SparkSession._activeSession = self @@ -345,9 +333,14 @@ def _repr_html_(self) -> str:
""".format( catalogImplementation=self.conf.get("spark.sql.catalogImplementation"), - sc_HTML=self.sparkContext._repr_html_(), # type: ignore[attr-defined] + sc_HTML=self.sparkContext._repr_html_(), ) + @property + def _jconf(self) -> "JavaObject": + """Accessor for the JVM SQL-specific configurations""" + return self._jsparkSession.sessionState().conf() + @since(2.0) def newSession(self) -> "SparkSession": """ @@ -498,7 +491,7 @@ def range( else: jdf = self._jsparkSession.range(int(start), int(end), int(step), int(numPartitions)) - return DataFrame(jdf, self._wrapped) + return DataFrame(jdf, self) def _inferSchemaFromList( self, data: Iterable[Any], names: Optional[List[str]] = None @@ -519,7 +512,7 @@ def _inferSchemaFromList( """ if not data: raise ValueError("can not infer schema from empty dataset") - infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct() # type: ignore[attr-defined] + infer_dict_as_struct = self._jconf.inferDictAsStruct() prefer_timestamp_ntz = is_timestamp_ntz_preferred() schema = reduce( _merge_type, @@ -531,7 +524,7 @@ def _inferSchemaFromList( def _inferSchema( self, - rdd: "RDD[Any]", + rdd: RDD[Any], samplingRatio: Optional[float] = None, names: Optional[List[str]] = None, ) -> StructType: @@ -554,7 +547,7 @@ def _inferSchema( if not first: raise ValueError("The first row in RDD is empty, " "can not infer schema") - infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct() # type: ignore[attr-defined] + infer_dict_as_struct = self._jconf.inferDictAsStruct() prefer_timestamp_ntz = is_timestamp_ntz_preferred() if samplingRatio is None: schema = _infer_schema( @@ -596,10 +589,10 @@ def _inferSchema( def _createFromRDD( self, - rdd: "RDD[Any]", + rdd: RDD[Any], schema: Optional[Union[DataType, List[str]]], samplingRatio: Optional[float], - ) -> Tuple["RDD[Tuple]", StructType]: + ) -> Tuple[RDD[Tuple], StructType]: """ Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. """ @@ -625,7 +618,7 @@ def _createFromRDD( def _createFromLocal( self, data: Iterable[Any], schema: Optional[Union[DataType, List[str]]] - ) -> Tuple["RDD[Tuple]", StructType]: + ) -> Tuple[RDD[Tuple], StructType]: """ Create an RDD for DataFrame from a list or pandas.DataFrame, returns the RDD and schema. 
@@ -669,13 +662,13 @@ def _create_shell_session() -> "SparkSession": # Try to access HiveConf, it will raise exception if Hive is not added conf = SparkConf() assert SparkContext._jvm is not None - if cast(str, conf.get("spark.sql.catalogImplementation", "hive")).lower() == "hive": + if conf.get("spark.sql.catalogImplementation", "hive").lower() == "hive": SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() return SparkSession.builder.enableHiveSupport().getOrCreate() else: return SparkSession._getActiveSessionOrCreate() except (py4j.protocol.Py4JError, TypeError): - if cast(str, conf.get("spark.sql.catalogImplementation", "")).lower() == "hive": + if conf.get("spark.sql.catalogImplementation", "").lower() == "hive": warnings.warn( "Fall back to non-hive support because failing to access HiveConf, " "please make sure you build spark with hive" @@ -684,14 +677,20 @@ def _create_shell_session() -> "SparkSession": return SparkSession._getActiveSessionOrCreate() @staticmethod - def _getActiveSessionOrCreate() -> "SparkSession": + def _getActiveSessionOrCreate(**static_conf: Any) -> "SparkSession": """ Returns the active :class:`SparkSession` for the current thread, returned by the builder, or if there is no existing one, creates a new one based on the options set in the builder. + + NOTE that 'static_conf' might not be set if there's an active or default Spark session + running. """ spark = SparkSession.getActiveSession() if spark is None: - spark = SparkSession.builder.getOrCreate() + builder = SparkSession.builder + for k, v in static_conf.items(): + builder = builder.config(k, v) + spark = builder.getOrCreate() return spark @overload @@ -767,7 +766,7 @@ def createDataFrame( def createDataFrame( # type: ignore[misc] self, - data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike"], + data: Union[RDD[Any], Iterable[Any], "PandasDataFrameLike"], schema: Optional[Union[AtomicType, StructType, str]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True, @@ -898,7 +897,7 @@ def createDataFrame( # type: ignore[misc] def _create_dataframe( self, - data: Union["RDD[Any]", Iterable[Any]], + data: Union[RDD[Any], Iterable[Any]], schema: Optional[Union[DataType, List[str]]], samplingRatio: Optional[float], verifySchema: bool, @@ -927,18 +926,18 @@ def prepare(obj): return (obj,) else: - prepare = lambda obj: obj + + def prepare(obj: Any) -> Any: + return obj if isinstance(data, RDD): rdd, struct = self._createFromRDD(data.map(prepare), schema, samplingRatio) else: rdd, struct = self._createFromLocal(map(prepare, data), schema) assert self._jvm is not None - jrdd = self._jvm.SerDeUtil.toJavaArray( - rdd._to_java_object_rdd() # type: ignore[attr-defined] - ) + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), struct.json()) - df = DataFrame(jdf, self._wrapped) + df = DataFrame(jdf, self) df._schema = struct return df @@ -1032,7 +1031,7 @@ def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: if len(kwargs) > 0: sqlQuery = formatter.format(sqlQuery, **kwargs) try: - return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped) + return DataFrame(self._jsparkSession.sql(sqlQuery), self) finally: if len(kwargs) > 0: formatter.clear() @@ -1053,7 +1052,7 @@ def table(self, tableName: str) -> DataFrame: >>> sorted(df.collect()) == sorted(df2.collect()) True """ - return DataFrame(self._jsparkSession.table(tableName), self._wrapped) + return DataFrame(self._jsparkSession.table(tableName), self) 
@property def read(self) -> DataFrameReader: @@ -1067,7 +1066,7 @@ def read(self) -> DataFrameReader: ------- :class:`DataFrameReader` """ - return DataFrameReader(self._wrapped) + return DataFrameReader(self) @property def readStream(self) -> DataStreamReader: @@ -1085,7 +1084,7 @@ def readStream(self) -> DataStreamReader: ------- :class:`DataStreamReader` """ - return DataStreamReader(self._wrapped) + return DataStreamReader(self) @property def streams(self) -> "StreamingQueryManager": diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py index 8528dd3e88352..5e79b9ff5ea98 100644 --- a/python/pyspark/sql/sql_formatter.py +++ b/python/pyspark/sql/sql_formatter.py @@ -50,9 +50,9 @@ def _convert_value(self, val: Any, field_name: str) -> Optional[str]: from pyspark.sql import Column, DataFrame if isinstance(val, Column): - assert SparkContext._gateway is not None # type: ignore[attr-defined] + assert SparkContext._gateway is not None - gw = SparkContext._gateway # type: ignore[attr-defined] + gw = SparkContext._gateway jexpr = val._jc.expr() if is_instance_of( gw, jexpr, "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute" diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index de68ccc3d9a00..7517a41337f90 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -20,7 +20,7 @@ from collections.abc import Iterator from typing import cast, overload, Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union -from py4j.java_gateway import java_import, JavaObject # type: ignore[import] +from py4j.java_gateway import java_import, JavaObject from pyspark import since from pyspark.sql.column import _to_seq @@ -29,7 +29,7 @@ from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException if TYPE_CHECKING: - from pyspark.sql import SQLContext + from pyspark.sql.session import SparkSession from pyspark.sql._typing import SupportsProcess, OptionalPrimitiveType from pyspark.sql.dataframe import DataFrame @@ -316,8 +316,8 @@ class DataStreamReader(OptionUtils): This API is evolving. 
""" - def __init__(self, spark: "SQLContext") -> None: - self._jreader = spark._ssql_ctx.readStream() + def __init__(self, spark: "SparkSession") -> None: + self._jreader = spark._jsparkSession.readStream() self._spark = spark def _df(self, jdf: JavaObject) -> "DataFrame": @@ -856,7 +856,7 @@ class DataStreamWriter: def __init__(self, df: "DataFrame") -> None: self._df = df - self._spark = df.sql_ctx + self._spark = df.sparkSession self._jwrite = df._jdf.writeStream() def _sq(self, jsq: JavaObject) -> StreamingQuery: @@ -1196,7 +1196,7 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt >>> writer = sdf.writeStream.foreach(RowPrinter()) """ - from pyspark.rdd import _wrap_function # type: ignore[attr-defined] + from pyspark.rdd import _wrap_function from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer from pyspark.taskcontext import TaskContext @@ -1474,7 +1474,7 @@ def _test() -> None: import tempfile from pyspark.sql import SparkSession, SQLContext import pyspark.sql.streaming - from py4j.protocol import Py4JError # type: ignore[import] + from py4j.protocol import Py4JError os.chdir(os.environ["SPARK_HOME"]) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 1367fe79f0260..5f5e88fd46deb 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -479,6 +479,33 @@ def foo(): self.assertRaises(TypeError, foo) + def test_with_columns(self): + # With single column + keys = self.df.withColumns({"key": self.df.key}).select("key").collect() + self.assertEqual([r.key for r in keys], list(range(100))) + + # With key and value columns + kvs = ( + self.df.withColumns({"key": self.df.key, "value": self.df.value}) + .select("key", "value") + .collect() + ) + self.assertEqual([(r.key, r.value) for r in kvs], [(i, str(i)) for i in range(100)]) + + # Columns rename + kvs = ( + self.df.withColumns({"key_alias": self.df.key, "value_alias": self.df.value}) + .select("key_alias", "value_alias") + .collect() + ) + self.assertEqual( + [(r.key_alias, r.value_alias) for r in kvs], [(i, str(i)) for i in range(100)] + ) + + # Type check + self.assertRaises(TypeError, self.df.withColumns, ["key"]) + self.assertRaises(AssertionError, self.df.withColumns) + def test_generic_hints(self): from pyspark.sql import DataFrame diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 5021da569fe40..5c6acaffa324b 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -20,7 +20,7 @@ import re import math -from py4j.protocol import Py4JJavaError # type: ignore[import] +from py4j.protocol import Py4JJavaError from pyspark.sql import Row, Window, types from pyspark.sql.functions import ( udf, diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index bee9cff525717..08fba7cea01fc 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -145,7 +145,10 @@ def test_vectorized_udf_basic(self): col("id").cast("boolean").alias("bool"), array(col("id")).alias("array_long"), ) - f = lambda x: x + + def f(x): + return x + for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]: str_f = pandas_udf(f, StringType(), udf_type) int_f = pandas_udf(f, IntegerType(), udf_type) @@ -283,7 +286,9 @@ def test_vectorized_udf_null_string(self): def 
test_vectorized_udf_string_in_udf(self): df = self.spark.range(10) - scalar_f = lambda x: pd.Series(map(str, x)) + + def scalar_f(x): + return pd.Series(map(str, x)) def iter_f(it): for i in it: @@ -305,7 +310,10 @@ def test_vectorized_udf_datatype_string(self): col("id").cast("decimal").alias("decimal"), col("id").cast("boolean").alias("bool"), ) - f = lambda x: x + + def f(x): + return x + for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]: str_f = pandas_udf(f, "string", udf_type) int_f = pandas_udf(f, "integer", udf_type) diff --git a/python/pyspark/sql/tests/test_pandas_udf_typehints.py b/python/pyspark/sql/tests/test_pandas_udf_typehints.py index 119b2cf310f5d..44315c95614b8 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_typehints.py +++ b/python/pyspark/sql/tests/test_pandas_udf_typehints.py @@ -15,8 +15,8 @@ # limitations under the License. # import unittest -import inspect -from typing import Union, Iterator, Tuple, cast +from inspect import signature +from typing import Union, Iterator, Tuple, cast, get_type_hints from pyspark.sql.functions import mean, lit from pyspark.testing.sqlutils import ( @@ -45,84 +45,116 @@ def test_type_annotation_scalar(self): def func(col: pd.Series) -> pd.Series: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) def func(col: pd.DataFrame, col1: pd.Series) -> pd.DataFrame: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) def func(col: pd.DataFrame, *args: pd.Series) -> pd.Series: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> pd.Series: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) def func(col: pd.Series, *, col2: pd.DataFrame) -> pd.DataFrame: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> pd.Series: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) def test_type_annotation_scalar_iter(self): def func(iter: Iterator[pd.Series]) -> Iterator[pd.Series]: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR_ITER) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) def func(iter: Iterator[Tuple[pd.DataFrame, pd.Series]]) -> Iterator[pd.DataFrame]: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR_ITER) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) def func(iter: Iterator[Tuple[pd.DataFrame, ...]]) -> Iterator[pd.Series]: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR_ITER) + self.assertEqual( + infer_eval_type(signature(func), 
get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) def func(iter: Iterator[Tuple[Union[pd.DataFrame, pd.Series], ...]]) -> Iterator[pd.Series]: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR_ITER) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) def test_type_annotation_group_agg(self): def func(col: pd.Series) -> str: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) def func(col: pd.DataFrame, col1: pd.Series) -> int: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) def func(col: pd.DataFrame, *args: pd.Series) -> Row: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> str: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) def func(col: pd.Series, *, col2: pd.DataFrame) -> float: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> float: pass - self.assertEqual(infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) def test_type_annotation_negative(self): def func(col: str) -> pd.Series: @@ -132,7 +164,8 @@ def func(col: str) -> pd.Series: NotImplementedError, "Unsupported signature.*str", infer_eval_type, - inspect.signature(func), + signature(func), + get_type_hints(func), ) def func(col: pd.DataFrame, col1: int) -> pd.DataFrame: @@ -142,7 +175,8 @@ def func(col: pd.DataFrame, col1: int) -> pd.DataFrame: NotImplementedError, "Unsupported signature.*int", infer_eval_type, - inspect.signature(func), + signature(func), + get_type_hints(func), ) def func(col: Union[pd.DataFrame, str], col1: int) -> pd.DataFrame: @@ -152,7 +186,8 @@ def func(col: Union[pd.DataFrame, str], col1: int) -> pd.DataFrame: NotImplementedError, "Unsupported signature.*str", infer_eval_type, - inspect.signature(func), + signature(func), + get_type_hints(func), ) def func(col: pd.Series) -> Tuple[pd.DataFrame]: @@ -162,28 +197,41 @@ def func(col: pd.Series) -> Tuple[pd.DataFrame]: NotImplementedError, "Unsupported signature.*Tuple", infer_eval_type, - inspect.signature(func), + signature(func), + get_type_hints(func), ) def func(col, *args: pd.Series) -> pd.Series: pass self.assertRaisesRegex( - ValueError, "should be specified.*Series", infer_eval_type, inspect.signature(func) + ValueError, + "should be specified.*Series", + infer_eval_type, + signature(func), + get_type_hints(func), ) def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame): pass self.assertRaisesRegex( - ValueError, "should be specified.*Series", infer_eval_type, inspect.signature(func) + ValueError, + "should be specified.*Series", + 
infer_eval_type, + signature(func), + get_type_hints(func), ) def func(col: pd.Series, *, col2) -> pd.DataFrame: pass self.assertRaisesRegex( - ValueError, "should be specified.*Series", infer_eval_type, inspect.signature(func) + ValueError, + "should be specified.*Series", + infer_eval_type, + signature(func), + get_type_hints(func), ) def test_scalar_udf_type_hint(self): @@ -213,7 +261,7 @@ def plus_one(itr: Iterator[pd.Series]) -> Iterator[pd.Series]: def test_group_agg_udf_type_hint(self): df = self.spark.range(10).selectExpr("id", "id as v") - def weighted_mean(v: pd.Series, w: pd.Series) -> float: + def weighted_mean(v: pd.Series, w: pd.Series) -> np.float64: return np.average(v, weights=w) weighted_mean = pandas_udf("double")(weighted_mean) @@ -257,6 +305,56 @@ def pandas_plus_one(iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: expected = df.selectExpr("id + 1 as id") assert_frame_equal(expected.toPandas(), actual.toPandas()) + def test_string_type_annotation(self): + def func(col: "pd.Series") -> "pd.Series": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: "pd.DataFrame", col1: "pd.Series") -> "pd.DataFrame": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: "pd.DataFrame", *args: "pd.Series") -> "pd.Series": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: "pd.Series", *args: "pd.Series", **kwargs: "pd.DataFrame") -> "pd.Series": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: "pd.Series", *, col2: "pd.DataFrame") -> "pd.DataFrame": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: Union["pd.Series", "pd.DataFrame"], *, col2: "pd.DataFrame") -> "pd.Series": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: "Union[pd.Series, pd.DataFrame]", *, col2: "pd.DataFrame") -> "pd.Series": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + if __name__ == "__main__": from pyspark.sql.tests.test_pandas_udf_typehints import * # noqa: #401 diff --git a/python/pyspark/sql/tests/test_pandas_udf_typehints_with_future_annotations.py b/python/pyspark/sql/tests/test_pandas_udf_typehints_with_future_annotations.py new file mode 100644 index 0000000000000..832086cb9ec8f --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_udf_typehints_with_future_annotations.py @@ -0,0 +1,375 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +import sys +import unittest +from inspect import signature +from typing import Union, Iterator, Tuple, cast, get_type_hints + +from pyspark.sql.functions import mean, lit +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) +from pyspark.sql.pandas.typehints import infer_eval_type +from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType +from pyspark.sql import Row + +if have_pandas: + import pandas as pd + import numpy as np + from pandas.testing import assert_frame_equal + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), +) +class PandasUDFTypeHintsWithFutureAnnotationsTests(ReusedSQLTestCase): + def test_type_annotation_scalar(self): + def func(col: pd.Series) -> pd.Series: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: pd.DataFrame, col1: pd.Series) -> pd.DataFrame: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: pd.DataFrame, *args: pd.Series) -> pd.Series: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> pd.Series: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: pd.Series, *, col2: pd.DataFrame) -> pd.DataFrame: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> pd.Series: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def test_type_annotation_scalar_iter(self): + def func(iter: Iterator[pd.Series]) -> Iterator[pd.Series]: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) + + def func(iter: Iterator[Tuple[pd.DataFrame, pd.Series]]) -> Iterator[pd.DataFrame]: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) + + def func(iter: Iterator[Tuple[pd.DataFrame, ...]]) -> Iterator[pd.Series]: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) + + def func(iter: Iterator[Tuple[Union[pd.DataFrame, pd.Series], ...]]) -> Iterator[pd.Series]: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) + + def test_type_annotation_group_agg(self): + def func(col: pd.Series) -> str: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) + + def func(col: pd.DataFrame, col1: pd.Series) -> int: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) + + def func(col: pd.DataFrame, *args: pd.Series) -> Row: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) + + def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> str: + pass + + 
self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) + + def func(col: pd.Series, *, col2: pd.DataFrame) -> float: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) + + def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> float: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG + ) + + def test_type_annotation_negative(self): + def func(col: str) -> pd.Series: + pass + + self.assertRaisesRegex( + NotImplementedError, + "Unsupported signature.*str", + infer_eval_type, + signature(func), + get_type_hints(func), + ) + + def func(col: pd.DataFrame, col1: int) -> pd.DataFrame: + pass + + self.assertRaisesRegex( + NotImplementedError, + "Unsupported signature.*int", + infer_eval_type, + signature(func), + get_type_hints(func), + ) + + def func(col: Union[pd.DataFrame, str], col1: int) -> pd.DataFrame: + pass + + self.assertRaisesRegex( + NotImplementedError, + "Unsupported signature.*str", + infer_eval_type, + signature(func), + get_type_hints(func), + ) + + def func(col: pd.Series) -> Tuple[pd.DataFrame]: + pass + + self.assertRaisesRegex( + NotImplementedError, + "Unsupported signature.*Tuple", + infer_eval_type, + signature(func), + get_type_hints(func), + ) + + def func(col, *args: pd.Series) -> pd.Series: + pass + + self.assertRaisesRegex( + ValueError, + "should be specified.*Series", + infer_eval_type, + signature(func), + get_type_hints(func), + ) + + def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame): + pass + + self.assertRaisesRegex( + ValueError, + "should be specified.*Series", + infer_eval_type, + signature(func), + get_type_hints(func), + ) + + def func(col: pd.Series, *, col2) -> pd.DataFrame: + pass + + self.assertRaisesRegex( + ValueError, + "should be specified.*Series", + infer_eval_type, + signature(func), + get_type_hints(func), + ) + + def test_scalar_udf_type_hint(self): + df = self.spark.range(10).selectExpr("id", "id as v") + + def plus_one(v: Union[pd.Series, pd.DataFrame]) -> pd.Series: + return v + 1 # type: ignore[return-value] + + plus_one = pandas_udf("long")(plus_one) + actual = df.select(plus_one(df.v).alias("plus_one")) + expected = df.selectExpr("(v + 1) as plus_one") + assert_frame_equal(expected.toPandas(), actual.toPandas()) + + def test_scalar_iter_udf_type_hint(self): + df = self.spark.range(10).selectExpr("id", "id as v") + + def plus_one(itr: Iterator[pd.Series]) -> Iterator[pd.Series]: + for s in itr: + yield s + 1 + + plus_one = pandas_udf("long")(plus_one) + + actual = df.select(plus_one(df.v).alias("plus_one")) + expected = df.selectExpr("(v + 1) as plus_one") + assert_frame_equal(expected.toPandas(), actual.toPandas()) + + def test_group_agg_udf_type_hint(self): + df = self.spark.range(10).selectExpr("id", "id as v") + + def weighted_mean(v: pd.Series, w: pd.Series) -> np.float64: + return np.average(v, weights=w) + + weighted_mean = pandas_udf("double")(weighted_mean) + + actual = df.groupby("id").agg(weighted_mean(df.v, lit(1.0))).sort("id") + expected = df.groupby("id").agg(mean(df.v).alias("weighted_mean(v, 1.0)")).sort("id") + assert_frame_equal(expected.toPandas(), actual.toPandas()) + + def test_ignore_type_hint_in_group_apply_in_pandas(self): + df = self.spark.range(10) + + def pandas_plus_one(v: pd.DataFrame) -> pd.DataFrame: + return v + 1 + + actual = df.groupby("id").applyInPandas(pandas_plus_one, 
schema=df.schema).sort("id") + expected = df.selectExpr("id + 1 as id") + assert_frame_equal(expected.toPandas(), actual.toPandas()) + + def test_ignore_type_hint_in_cogroup_apply_in_pandas(self): + df = self.spark.range(10) + + def pandas_plus_one(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: + return left + 1 + + actual = ( + df.groupby("id") + .cogroup(self.spark.range(10).groupby("id")) + .applyInPandas(pandas_plus_one, schema=df.schema) + .sort("id") + ) + expected = df.selectExpr("id + 1 as id") + assert_frame_equal(expected.toPandas(), actual.toPandas()) + + def test_ignore_type_hint_in_map_in_pandas(self): + df = self.spark.range(10) + + def pandas_plus_one(iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: + return map(lambda v: v + 1, iter) + + actual = df.mapInPandas(pandas_plus_one, schema=df.schema) + expected = df.selectExpr("id + 1 as id") + assert_frame_equal(expected.toPandas(), actual.toPandas()) + + @unittest.skipIf( + sys.version_info < (3, 9), + "string annotations with future annotations do not work under Python<3.9", + ) + def test_string_type_annotation(self): + def func(col: "pd.Series") -> "pd.Series": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: "pd.DataFrame", col1: "pd.Series") -> "pd.DataFrame": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: "pd.DataFrame", *args: "pd.Series") -> "pd.Series": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: "pd.Series", *args: "pd.Series", **kwargs: "pd.DataFrame") -> "pd.Series": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: "pd.Series", *, col2: "pd.DataFrame") -> "pd.DataFrame": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: Union["pd.Series", "pd.DataFrame"], *, col2: "pd.DataFrame") -> "pd.Series": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + def func(col: "Union[pd.Series, pd.DataFrame]", *, col2: "pd.DataFrame") -> "pd.Series": + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR + ) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_udf_typehints_with_future_annotations import * # noqa: #401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py index 1262e529b9ccc..91aa923768fc2 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -224,28 +224,26 @@ def tearDown(self): def test_sqlcontext_with_stopped_sparksession(self): # SPARK-30856: test that SQLContext.getOrCreate() returns a usable instance after # the SparkSession is restarted. 
- sql_context = self.spark._wrapped + sql_context = SQLContext.getOrCreate(self.spark.sparkContext) self.spark.stop() - sc = SparkContext("local[4]", self.sc.appName) - spark = SparkSession(sc) # Instantiate the underlying SQLContext - new_sql_context = spark._wrapped + spark = SparkSession.builder.master("local[4]").appName(self.sc.appName).getOrCreate() + new_sql_context = SQLContext.getOrCreate(spark.sparkContext) self.assertIsNot(new_sql_context, sql_context) - self.assertIs(SQLContext.getOrCreate(sc).sparkSession, spark) + self.assertIs(SQLContext.getOrCreate(spark.sparkContext).sparkSession, spark) try: df = spark.createDataFrame([(1, 2)], ["c", "c"]) df.collect() finally: spark.stop() self.assertIsNone(SQLContext._instantiatedContext) - sc.stop() def test_sqlcontext_with_stopped_sparkcontext(self): # SPARK-30856: test initialization via SparkSession when only the SparkContext is stopped self.sc.stop() - self.sc = SparkContext("local[4]", self.sc.appName) - self.spark = SparkSession(self.sc) - self.assertIs(SQLContext.getOrCreate(self.sc).sparkSession, self.spark) + spark = SparkSession.builder.master("local[4]").appName(self.sc.appName).getOrCreate() + self.sc = spark.sparkContext + self.assertIs(SQLContext.getOrCreate(self.sc).sparkSession, spark) def test_get_sqlcontext_with_stopped_sparkcontext(self): # SPARK-30856: test initialization via SQLContext.getOrCreate() when only the SparkContext diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py index 87e35641f648a..4920423be228b 100644 --- a/python/pyspark/sql/tests/test_streaming.py +++ b/python/pyspark/sql/tests/test_streaming.py @@ -86,7 +86,7 @@ def test_stream_save_options(self): .load("python/test_support/sql/streaming") .withColumn("id", lit(1)) ) - for q in self.spark._wrapped.streams.active: + for q in self.spark.streams.active: q.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) @@ -117,7 +117,7 @@ def test_stream_save_options(self): def test_stream_save_options_overwrite(self): df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - for q in self.spark._wrapped.streams.active: + for q in self.spark.streams.active: q.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) @@ -154,7 +154,7 @@ def test_stream_save_options_overwrite(self): def test_stream_status_and_progress(self): df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - for q in self.spark._wrapped.streams.active: + for q in self.spark.streams.active: q.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) @@ -198,7 +198,7 @@ def func(x): def test_stream_await_termination(self): df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - for q in self.spark._wrapped.streams.active: + for q in self.spark.streams.active: q.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) @@ -267,7 +267,7 @@ def _assert_exception_tree_contains_msg(self, exception, msg): def test_query_manager_await_termination(self): df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - for q in self.spark._wrapped.streams.active: + for q in self.spark.streams.active: q.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) @@ -280,13 +280,13 @@ def test_query_manager_await_termination(self): try: self.assertTrue(q.isActive) try: - self.spark._wrapped.streams.awaitAnyTermination("hello") + self.spark.streams.awaitAnyTermination("hello") self.fail("Expected a value exception") except ValueError: pass 
now = time.time() # test should take at least 2 seconds - res = self.spark._wrapped.streams.awaitAnyTermination(2.6) + res = self.spark.streams.awaitAnyTermination(2.6) duration = time.time() - now self.assertTrue(duration >= 2) self.assertFalse(res) @@ -347,7 +347,7 @@ def assert_invalid_writer(self, writer, msg=None): self.stop_all() def stop_all(self): - for q in self.spark._wrapped.streams.active: + for q in self.spark.streams.active: q.stop() def _reset(self): diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 2502387b8cc50..9ae6c3a63457e 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -48,7 +48,7 @@ BooleanType, NullType, ) -from pyspark.sql.types import ( # type: ignore +from pyspark.sql.types import ( _array_signed_int_typecode_ctype_mappings, _array_type_mappings, _array_unsigned_int_typecode_ctype_mappings, @@ -118,8 +118,14 @@ def test_infer_schema(self): with self.tempView("test"): df.createOrReplaceTempView("test") - result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) + result = self.spark.sql("SELECT l from test") + self.assertEqual([], result.head()[0]) + # We set `spark.sql.ansi.enabled` to False for this case + # since it occurs an error in ANSI mode if there is a list index + # or key that does not exist. + with self.sql_conf({"spark.sql.ansi.enabled": False}): + result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) self.assertEqual(df.schema, df2.schema) @@ -128,8 +134,14 @@ def test_infer_schema(self): with self.tempView("test2"): df2.createOrReplaceTempView("test2") - result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) + result = self.spark.sql("SELECT l from test2") + self.assertEqual([], result.head()[0]) + # We set `spark.sql.ansi.enabled` to False for this case + # since it occurs an error in ANSI mode if there is a list index + # or key that does not exist. + with self.sql_conf({"spark.sql.ansi.enabled": False}): + result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) def test_infer_schema_specification(self): from decimal import Decimal diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 7f421aaea892c..805d5a8dfec9a 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -22,7 +22,7 @@ import unittest import datetime -from pyspark import SparkContext +from pyspark import SparkContext, SQLContext from pyspark.sql import SparkSession, Column, Row from pyspark.sql.functions import udf, assert_true, lit from pyspark.sql.udf import UserDefinedFunction @@ -79,7 +79,7 @@ def test_udf(self): self.assertEqual(row[0], 5) # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. - sqlContext = self.spark._wrapped + sqlContext = SQLContext.getOrCreate(self.spark.sparkContext) sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) [row] = sqlContext.sql("SELECT oneArg('test')").collect() self.assertEqual(row[0], 4) @@ -257,7 +257,8 @@ def test_udf_not_supported_in_join_condition(self): def runWithJoinType(join_type, type_string): with self.assertRaisesRegex( - AnalysisException, "Using PythonUDF.*%s is not supported." 
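The `test_infer_schema` updates above split the query in two because ANSI mode turns a missing array index or map key into a runtime error instead of a NULL, so the out-of-bounds lookup is guarded with `spark.sql.ansi.enabled=false` via the `sql_conf` test helper. Outside the test harness the same toggle can be sketched against the session conf directly (column names here are illustrative):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.createDataFrame([([1, 2],)], ["l"]).createOrReplaceTempView("t")

# With ANSI mode off, an out-of-range index evaluates to NULL.
spark.conf.set("spark.sql.ansi.enabled", False)
spark.sql("SELECT l[5] FROM t").show()

# With ANSI mode on, the same lookup fails at runtime, which is why the
# test wraps it in sql_conf({"spark.sql.ansi.enabled": False}).
spark.conf.set("spark.sql.ansi.enabled", True)
try:
    spark.sql("SELECT l[5] FROM t").collect()
except Exception as e:
    print(type(e).__name__)

spark.stop()
```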
% type_string + AnalysisException, + "Using PythonUDF in join condition of join type %s is not supported" % type_string, ): left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect() @@ -372,7 +373,7 @@ def test_udf_registration_returns_udf(self): ) # This is to check if a 'SQLContext.udf' can call its alias. - sqlContext = self.spark._wrapped + sqlContext = SQLContext.getOrCreate(self.spark.sparkContext) add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) self.assertListEqual( @@ -419,7 +420,7 @@ def test_non_existed_udf(self): ) # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. - sqlContext = spark._wrapped + sqlContext = SQLContext.getOrCreate(self.spark.sparkContext) self.assertRaisesRegex( AnalysisException, "Can not load class non_existed_udf", @@ -747,7 +748,8 @@ def f(*a): self.assertEqual(r.first()[0], "success") def test_udf_cache(self): - func = lambda x: x + def func(x): + return x df = self.spark.range(1) df.select(udf(func)("id")).cache() diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py index 27d9458509402..136f423d0a35c 100644 --- a/python/pyspark/sql/tests/test_udf_profiler.py +++ b/python/pyspark/sql/tests/test_udf_profiler.py @@ -21,7 +21,7 @@ import sys from io import StringIO -from pyspark import SparkConf, SparkContext +from pyspark import SparkConf from pyspark.sql import SparkSession from pyspark.sql.functions import udf from pyspark.profiler import UDFBasicProfiler @@ -32,8 +32,13 @@ def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext("local[4]", class_name, conf=conf) - self.spark = SparkSession.builder._sparkContext(self.sc).getOrCreate() + self.spark = ( + SparkSession.builder.master("local[4]") + .config(conf=conf) + .appName(class_name) + .getOrCreate() + ) + self.sc = self.spark.sparkContext def tearDown(self): self.spark.stop() diff --git a/python/pyspark/sql/tests/typing/test_session.yml b/python/pyspark/sql/tests/typing/test_session.yml index 01a6b288aae1a..70d0001c47ca8 100644 --- a/python/pyspark/sql/tests/typing/test_session.yml +++ b/python/pyspark/sql/tests/typing/test_session.yml @@ -35,7 +35,7 @@ spark.createDataFrame(data, schema) spark.createDataFrame(data, "name string, age integer") spark.createDataFrame([(1, ("foo", "bar"))], ("_1", "_2")) - spark.createDataFrame(data, ("name", "age"), samplingRatio=0.1) # type: ignore + spark.createDataFrame(data, ("name", "age"), samplingRatio=0.1) - case: createDataFrameScalarsValid diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index aad225aefcfae..41db22b054049 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -43,8 +43,8 @@ TypeVar, ) -from py4j.protocol import register_input_converter # type: ignore[import] -from py4j.java_gateway import JavaClass, JavaGateway, JavaObject # type: ignore[import] +from py4j.protocol import register_input_converter +from py4j.java_gateway import JavaClass, JavaGateway, JavaObject from pyspark.serializers import CloudPickleSerializer @@ -1011,7 +1011,7 @@ def _parse_datatype_string(s: str) -> DataType: """ from pyspark import SparkContext - sc = SparkContext._active_spark_context # type: ignore[attr-defined] + sc = SparkContext._active_spark_context assert sc is not None def from_ddl_schema(type_str: str) -> DataType: @@ -1355,11 +1355,20 @@ def _merge_type( name: 
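The UDF profiler fixture above now builds its session through the builder, handing the `SparkConf` to `config(conf=...)` and taking the `SparkContext` back off the session instead of constructing one first. A minimal sketch of that construction order:

```python
from pyspark import SparkConf
from pyspark.sql import SparkSession

conf = SparkConf().set("spark.python.profile", "true")

spark = (
    SparkSession.builder
    .master("local[4]")
    .appName("profiler-example")
    .config(conf=conf)
    .getOrCreate()
)
sc = spark.sparkContext  # the context is derived from the session, not the other way round
print(sc.getConf().get("spark.python.profile"))
spark.stop()
```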
Optional[str] = None, ) -> Union[StructType, ArrayType, MapType, DataType]: if name is None: - new_msg = lambda msg: msg - new_name = lambda n: "field %s" % n + + def new_msg(msg: str) -> str: + return msg + + def new_name(n: str) -> str: + return "field %s" % n + else: - new_msg = lambda msg: "%s: %s" % (name, msg) - new_name = lambda n: "field %s in %s" % (n, name) + + def new_msg(msg: str) -> str: + return "%s: %s" % (name, msg) + + def new_name(n: str) -> str: + return "field %s in %s" % (n, name) if isinstance(a, NullType): return b @@ -1551,11 +1560,20 @@ def _make_type_verifier( """ if name is None: - new_msg = lambda msg: msg - new_name = lambda n: "field %s" % n + + def new_msg(msg: str) -> str: + return msg + + def new_name(n: str) -> str: + return "field %s" % n + else: - new_msg = lambda msg: "%s: %s" % (name, msg) - new_name = lambda n: "field %s in %s" % (n, name) + + def new_msg(msg: str) -> str: + return "%s: %s" % (name, msg) + + def new_name(n: str) -> str: + return "field %s in %s" % (n, name) def verify_nullability(obj: Any) -> bool: if obj is None: @@ -1575,14 +1593,15 @@ def assert_acceptable_types(obj: Any) -> None: def verify_acceptable_types(obj: Any) -> None: # subclass of them can not be fromInternal in JVM - if type(obj) not in _acceptable_types[_type]: # type: ignore[operator] + if type(obj) not in _acceptable_types[_type]: raise TypeError( new_msg("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) ) if isinstance(dataType, StringType): # StringType can work with any types - verify_value = lambda _: _ + def verify_value(obj: Any) -> None: + pass elif isinstance(dataType, UserDefinedType): verifier = _make_type_verifier(dataType.sqlType(), name=name) @@ -1665,9 +1684,7 @@ def verify_map(obj: Any) -> None: elif isinstance(dataType, StructType): verifiers = [] for f in dataType.fields: - verifier = _make_type_verifier( - f.dataType, f.nullable, name=new_name(f.name) - ) # type: ignore[arg-type] + verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name)) verifiers.append((f.name, verifier)) def verify_struct(obj: Any) -> None: diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index d98078af743c0..d8856e053faa7 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -22,11 +22,11 @@ import sys from typing import Callable, Any, TYPE_CHECKING, Optional, cast, Union -from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_gateway import JavaObject from pyspark import SparkContext from pyspark.profiler import Profiler -from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType # type: ignore[attr-defined] +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import ( StringType, @@ -53,10 +53,10 @@ def _wrap_function( bytearray(pickled_command), env, includes, - sc.pythonExec, # type: ignore[attr-defined] - sc.pythonVer, # type: ignore[attr-defined] + sc.pythonExec, + sc.pythonVer, broadcast_vars, - sc._javaAccumulator, # type: ignore[attr-defined] + sc._javaAccumulator, ) @@ -505,7 +505,6 @@ def registerJavaFunction( if returnType is not None: if not isinstance(returnType, DataType): returnType = _parse_datatype_string(returnType) - returnType = cast(DataType, returnType) jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) diff --git a/python/pyspark/sql/utils.py 
b/python/pyspark/sql/utils.py index 15645d0085d40..b3219b8b9be4e 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -17,20 +17,20 @@ from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast import py4j -from py4j.java_collections import JavaArray # type: ignore[import] -from py4j.java_gateway import ( # type: ignore[import] +from py4j.java_collections import JavaArray +from py4j.java_gateway import ( JavaClass, JavaGateway, JavaObject, is_instance_of, ) -from py4j.protocol import Py4JJavaError # type: ignore[import] +from py4j.protocol import Py4JJavaError from pyspark import SparkContext from pyspark.find_spark_home import _find_spark_home if TYPE_CHECKING: - from pyspark.sql.context import SQLContext + from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame @@ -61,9 +61,9 @@ def __init__( self._origin = origin def __str__(self) -> str: - assert SparkContext._jvm is not None # type: ignore[attr-defined] + assert SparkContext._jvm is not None - jvm = SparkContext._jvm # type: ignore[attr-defined] + jvm = SparkContext._jvm sql_conf = jvm.org.apache.spark.sql.internal.SQLConf.get() debug_enabled = sql_conf.pysparkJVMStacktraceEnabled() desc = self.desc @@ -72,9 +72,9 @@ def __str__(self) -> str: return str(desc) def getErrorClass(self) -> Optional[str]: - assert SparkContext._gateway is not None # type: ignore[attr-defined] + assert SparkContext._gateway is not None - gw = SparkContext._gateway # type: ignore[attr-defined] + gw = SparkContext._gateway if self._origin is not None and is_instance_of( gw, self._origin, "org.apache.spark.SparkThrowable" ): @@ -83,9 +83,9 @@ def getErrorClass(self) -> Optional[str]: return None def getSqlState(self) -> Optional[str]: - assert SparkContext._gateway is not None # type: ignore[attr-defined] + assert SparkContext._gateway is not None - gw = SparkContext._gateway # type: ignore[attr-defined] + gw = SparkContext._gateway if self._origin is not None and is_instance_of( gw, self._origin, "org.apache.spark.SparkThrowable" ): @@ -144,11 +144,11 @@ class SparkUpgradeException(CapturedException): def convert_exception(e: Py4JJavaError) -> CapturedException: assert e is not None - assert SparkContext._jvm is not None # type: ignore[attr-defined] - assert SparkContext._gateway is not None # type: ignore[attr-defined] + assert SparkContext._jvm is not None + assert SparkContext._gateway is not None - jvm = SparkContext._jvm # type: ignore[attr-defined] - gw = SparkContext._gateway # type: ignore[attr-defined] + jvm = SparkContext._jvm + gw = SparkContext._gateway if is_instance_of(gw, e, "org.apache.spark.sql.catalyst.parser.ParseException"): return ParseException(origin=e) @@ -258,15 +258,15 @@ class ForeachBatchFunction: the query is active. """ - def __init__(self, sql_ctx: "SQLContext", func: Callable[["DataFrame", int], None]): - self.sql_ctx = sql_ctx + def __init__(self, session: "SparkSession", func: Callable[["DataFrame", int], None]): self.func = func + self.session = session def call(self, jdf: JavaObject, batch_id: int) -> None: from pyspark.sql.dataframe import DataFrame try: - self.func(DataFrame(jdf, self.sql_ctx), batch_id) + self.func(DataFrame(jdf, self.session), batch_id) except Exception as e: self.error = e raise e @@ -292,7 +292,7 @@ def is_timestamp_ntz_preferred() -> bool: """ Return a bool if TimestampNTZType is preferred according to the SQL configuration set. 
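The `ForeachBatchFunction` rewrite above binds each micro-batch `DataFrame` to the `SparkSession` instead of the removed `SQLContext` wrapper; user code is untouched because the callback is still registered through `foreachBatch`. A hedged sketch of the wiring (the rate source and the 5-second wait are illustrative only):

```python
from pyspark.sql import DataFrame, SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

def handle_batch(batch_df: DataFrame, batch_id: int) -> None:
    # batch_df is built from the JVM DataFrame and the active session,
    # mirroring DataFrame(jdf, self.session) in ForeachBatchFunction.call.
    print(batch_id, batch_df.count())

query = (
    spark.readStream.format("rate").option("rowsPerSecond", 1).load()
    .writeStream.foreachBatch(handle_batch)
    .start()
)
query.awaitTermination(5)
query.stop()
spark.stop()
```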
""" - jvm = SparkContext._jvm # type: ignore[attr-defined] + jvm = SparkContext._jvm return ( jvm is not None and getattr(getattr(jvm.org.apache.spark.sql.internal, "SQLConf$"), "MODULE$") diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 1690c49a777fb..b8bc90f458cdf 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -21,7 +21,7 @@ from pyspark import since, SparkContext from pyspark.sql.column import _to_seq, _to_java_column -from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_gateway import JavaObject if TYPE_CHECKING: from pyspark.sql._typing import ColumnOrName, ColumnOrName_ diff --git a/python/pyspark/status.py b/python/pyspark/status.py index 193b9ff60f229..7e64c414403f2 100644 --- a/python/pyspark/status.py +++ b/python/pyspark/status.py @@ -19,8 +19,8 @@ from typing import List, NamedTuple, Optional -from py4j.java_collections import JavaArray # type: ignore[import] -from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_collections import JavaArray +from py4j.java_gateway import JavaObject class SparkJobInfo(NamedTuple): diff --git a/python/pyspark/streaming/context.pyi b/python/pyspark/streaming/context.pyi index 3eb252630934d..0d1b2aca7395f 100644 --- a/python/pyspark/streaming/context.pyi +++ b/python/pyspark/streaming/context.pyi @@ -18,7 +18,7 @@ from typing import Any, Callable, List, Optional, TypeVar -from py4j.java_gateway import JavaObject # type: ignore[import] +from py4j.java_gateway import JavaObject from pyspark.context import SparkContext from pyspark.rdd import RDD diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 0c1aa19fdc29b..f445a78bd9530 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -161,7 +161,10 @@ def foreachRDD(self, func): """ if func.__code__.co_argcount == 1: old_func = func - func = lambda t, rdd: old_func(rdd) + + def func(_, rdd): + return old_func(rdd) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) api = self._ssc._jvm.PythonDStream api.callForeachRDD(self._jdstream, jfunc) @@ -194,7 +197,10 @@ def mapValues(self, f): Return a new DStream by applying a map function to the value of each key-value pairs in this DStream without changing the key. """ - map_values_fn = lambda kv: (kv[0], f(kv[1])) + + def map_values_fn(kv): + return kv[0], f(kv[1]) + return self.map(map_values_fn, preservesPartitioning=True) def flatMapValues(self, f): @@ -202,7 +208,10 @@ def flatMapValues(self, f): Return a new DStream by applying a flatmap function to the value of each key-value pairs in this DStream without changing the key. 
""" - flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1])) + + def flat_map_fn(kv): + return ((kv[0], x) for x in f(kv[1])) + return self.flatMap(flat_map_fn, preservesPartitioning=True) def glom(self): @@ -308,7 +317,10 @@ def transform(self, func): """ if func.__code__.co_argcount == 1: oldfunc = func - func = lambda t, rdd: oldfunc(rdd) + + def func(_, rdd): + return oldfunc(rdd) + assert func.__code__.co_argcount == 2, "func should take one or two arguments" return TransformedDStream(self, func) @@ -322,7 +334,10 @@ def transformWith(self, func, other, keepSerializer=False): """ if func.__code__.co_argcount == 2: oldfunc = func - func = lambda t, a, b: oldfunc(a, b) + + def func(_, a, b): + return oldfunc(a, b) + assert func.__code__.co_argcount == 3, "func should take two or three arguments" jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) dstream = self._sc._jvm.PythonTransformed2DStream( diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index e48a91e7ceb86..26d66c394ab83 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -20,7 +20,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.streaming import DStream from pyspark.streaming.context import StreamingContext -from pyspark.util import _print_missing_jar # type: ignore[attr-defined] +from pyspark.util import _print_missing_jar __all__ = ["KinesisUtils", "InitialPositionInStream", "utf8_decoder"] diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 627456b3744e3..c4d10aaeacc14 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -183,7 +183,7 @@ def _getOrCreate(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext": """ if not isinstance(cls._taskContext, BarrierTaskContext): cls._taskContext = object.__new__(cls) - return cast(BarrierTaskContext, cls._taskContext) + return cls._taskContext @classmethod def get(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext": diff --git a/python/pyspark/testing/mlutils.py b/python/pyspark/testing/mlutils.py index 48e3498eb0db3..503ba7c76960b 100644 --- a/python/pyspark/testing/mlutils.py +++ b/python/pyspark/testing/mlutils.py @@ -24,7 +24,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasRegParam from pyspark.ml.classification import Classifier, ClassificationModel from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable -from pyspark.ml.wrapper import _java2py # type: ignore +from pyspark.ml.wrapper import _java2py from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import DoubleType from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase @@ -126,7 +126,7 @@ def _transform(self, dataset): class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): shift = Param( - Params._dummy(), # type: ignore + Params._dummy(), "shift", "The amount by which to shift " + "data in a DataFrame", typeConverter=TypeConverters.toFloat, diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py index 6d402985f4aed..9b07a23ae1b56 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -46,7 +46,7 @@ matplotlib_requirement_message = None try: - import matplotlib # type: ignore # noqa: F401 + import matplotlib # noqa: F401 except ImportError as e: # If matplotlib requirement is not satisfied, skip related tests. 
matplotlib_requirement_message = str(e) @@ -54,7 +54,7 @@ plotly_requirement_message = None try: - import plotly # type: ignore # noqa: F401 + import plotly # noqa: F401 except ImportError as e: # If plotly requirement is not satisfied, skip related tests. plotly_requirement_message = str(e) @@ -259,7 +259,7 @@ def temp_dir(self): @contextmanager def temp_file(self): with self.temp_dir() as tmp: - yield tempfile.mktemp(dir=tmp) + yield tempfile.mkstemp(dir=tmp)[1] class ComparisonTestBase(PandasOnSparkTestCase): diff --git a/python/pyspark/testing/streamingutils.py b/python/pyspark/testing/streamingutils.py index b44fb4c73aeb2..1860c54d31856 100644 --- a/python/pyspark/testing/streamingutils.py +++ b/python/pyspark/testing/streamingutils.py @@ -40,7 +40,7 @@ "spark-streaming-kinesis-asl-assembly_", ) if kinesis_asl_assembly_jar is None: - kinesis_requirement_message = ( # type: ignore + kinesis_requirement_message = ( "Skipping all Kinesis Python tests as the optional Kinesis project was " "not compiled into a JAR. To run these tests, " "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py index 0f092a860b2a1..1b63869562f40 100644 --- a/python/pyspark/tests/test_context.py +++ b/python/pyspark/tests/test_context.py @@ -164,7 +164,7 @@ def test_overwrite_system_module(self): self.assertEqual("My Server", SimpleHTTPServer.__name__) def func(x): - import SimpleHTTPServer # type: ignore[import] + import SimpleHTTPServer return SimpleHTTPServer.__name__ diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index 5790cae616aca..bf066e80b6b3b 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -37,7 +37,7 @@ from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest -global_func = lambda: "Hi" +global_func = lambda: "Hi" # noqa: E731 class RDDTests(ReusedPySparkTestCase): @@ -764,7 +764,7 @@ def test_overwritten_global_func(self): # Regression test for SPARK-27000 global global_func self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Hi") - global_func = lambda: "Yeah" + global_func = lambda: "Yeah" # noqa: E731 self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah") def test_to_local_iterator_failure(self): diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py index 1c04295213c77..0a89861a26f8c 100644 --- a/python/pyspark/tests/test_serializers.py +++ b/python/pyspark/tests/test_serializers.py @@ -72,7 +72,10 @@ def test_itemgetter(self): def test_function_module_name(self): ser = CloudPickleSerializer() - func = lambda x: x + + def func(x): + return x + func2 = ser.loads(ser.dumps(func)) self.assertEqual(func.__module__, func2.__module__) @@ -246,7 +249,7 @@ def test_chunked_stream(self): from pyspark.tests.test_serializers import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index 64e7b7d6a1bcf..0fdf6adb031bf 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -31,7 +31,7 @@ from py4j.protocol import Py4JJavaError from pyspark import SparkConf, SparkContext -from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest 
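The `temp_file` fixture above moves from `tempfile.mktemp`, which only picks a name and is inherently racy, to `tempfile.mkstemp`, which creates the file atomically and returns `(fd, path)`. A standalone sketch of the same idea (the cleanup strategy here is an assumption, not the test helper's):

```python
import os
import tempfile
from contextlib import contextmanager

@contextmanager
def temp_file():
    with tempfile.TemporaryDirectory() as tmp:
        # mkstemp creates the file immediately, so no other process can claim
        # the same name between "pick a name" and "open it".
        fd, path = tempfile.mkstemp(dir=tmp)
        os.close(fd)  # callers only need the path
        yield path
        # the enclosing TemporaryDirectory removes the file on exit

with temp_file() as path:
    with open(path, "w") as f:
        f.write("hello")
```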
+from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, eventually class WorkerTests(ReusedPySparkTestCase): @@ -188,11 +188,15 @@ def f(): class WorkerReuseTest(PySparkTestCase): def test_reuse_worker_of_parallelize_range(self): - rdd = self.sc.parallelize(range(20), 8) - previous_pids = rdd.map(lambda x: os.getpid()).collect() - current_pids = rdd.map(lambda x: os.getpid()).collect() - for pid in current_pids: - self.assertTrue(pid in previous_pids) + def check_reuse_worker_of_parallelize_range(): + rdd = self.sc.parallelize(range(20), 8) + previous_pids = rdd.map(lambda x: os.getpid()).collect() + current_pids = rdd.map(lambda x: os.getpid()).collect() + for pid in current_pids: + self.assertTrue(pid in previous_pids) + return True + + eventually(check_reuse_worker_of_parallelize_range, catch_assertions=True) @unittest.skipIf( diff --git a/python/pyspark/tests/typing/test_rdd.yml b/python/pyspark/tests/typing/test_rdd.yml index 749ad534d5ade..48965829cfdca 100644 --- a/python/pyspark/tests/typing/test_rdd.yml +++ b/python/pyspark/tests/typing/test_rdd.yml @@ -18,11 +18,11 @@ - case: toDF main: | from pyspark.sql.types import ( - IntegerType, - Row, - StructType, - StringType, - StructField, + IntegerType, + Row, + StructType, + StringType, + StructField, ) from collections import namedtuple from pyspark.sql import SparkSession @@ -60,3 +60,70 @@ rdd_named_tuple.toDF(sampleRatio=0.4) rdd_named_tuple.toDF(["a", "b"], sampleRatio=0.4) rdd_named_tuple.toDF(struct) + + +- case: rddMethods + main: | + from operator import add + from typing import Iterable, Set, Tuple + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + sc = spark.sparkContext + + def f1(x: int) -> str: + return str(x) + + reveal_type(sc.range(10).map(f1)) + + def f2(x: int) -> Iterable[int]: + return range(x) + + reveal_type(sc.range(10).flatMap(f2)) + + reveal_type(sc.parallelize([("a", 1), ("b", 0)]).filter(lambda x: x[1] != 0)) + + reveal_type(sc.parallelize([("a", 1), ("b", 0)]).max()) + + reveal_type(sc.range(10).reduce(add)) + + def seq_func(xs: Set[str], x: int) -> Set[str]: + xs.add(str(x % 11)) + return xs + + def comb_func(xs: Set[str], ys: Set[str]) -> Set[str]: + xs.update(ys) + return xs + + zero: Set[str] = set() + + reveal_type(sc.parallelize([("a", 1)]).aggregateByKey(zero, seq_func, comb_func)) + + out: | + main:11: note: Revealed type is "pyspark.rdd.RDD[builtins.str*]" + main:16: note: Revealed type is "pyspark.rdd.RDD[builtins.int*]" + main:18: note: Revealed type is "pyspark.rdd.RDD[Tuple[builtins.str, builtins.int]]" + main:20: note: Revealed type is "Tuple[builtins.str, builtins.int]" + main:22: note: Revealed type is "builtins.int" + main:34: note: Revealed type is "pyspark.rdd.RDD[Tuple[builtins.str, builtins.set[builtins.str]]]" + +- case: rddMethodsErrors + main: | + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + sc = spark.sparkContext + + def f1(x: str) -> str: + return x + + sc.range(10).map(f1) + + def f2(x: int) -> str: + return str(x) + + sc.range(10).reduce(f2) + + out: | + main:9: error: Argument 1 to "map" of "RDD" has incompatible type "Callable[[str], str]"; expected "Callable[[int], str]" [arg-type] + main:14: error: Argument 1 to "reduce" of "RDD" has incompatible type "Callable[[int], str]"; expected "Callable[[int, int], int]" [arg-type] diff --git a/python/pyspark/util.py b/python/pyspark/util.py index de44ab681b74d..b7b972a5d35b8 100644 --- a/python/pyspark/util.py +++ 
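Wrapping the worker-reuse check in `eventually` above turns a transient PID mismatch (for example right after a worker is respawned) into a retry instead of a hard failure. A sketch of the same pattern outside the test class, using the `eventually` helper imported above with its default timeout:

```python
import os

from pyspark.sql import SparkSession
from pyspark.testing.utils import eventually

spark = SparkSession.builder.master("local[4]").getOrCreate()
sc = spark.sparkContext

def workers_are_reused():
    rdd = sc.parallelize(range(20), 8)
    previous = set(rdd.map(lambda _: os.getpid()).collect())
    current = set(rdd.map(lambda _: os.getpid()).collect())
    # An AssertionError here is caught and retried because catch_assertions=True.
    assert current <= previous
    return True

eventually(workers_are_reused, catch_assertions=True)
spark.stop()
```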
b/python/pyspark/util.py @@ -27,7 +27,7 @@ from types import TracebackType from typing import Any, Callable, Iterator, List, Optional, TextIO, Tuple -from py4j.clientserver import ClientServer # type: ignore[import] +from py4j.clientserver import ClientServer __all__: List[str] = [] @@ -331,13 +331,10 @@ def inheritable_thread_target(f: Callable) -> Callable: @functools.wraps(f) def wrapped(*args: Any, **kwargs: Any) -> Any: - try: - # Set local properties in child thread. - assert SparkContext._active_spark_context is not None - SparkContext._active_spark_context._jsc.sc().setLocalProperties(properties) - return f(*args, **kwargs) - finally: - InheritableThread._clean_py4j_conn_for_current_thread() + # Set local properties in child thread. + assert SparkContext._active_spark_context is not None + SparkContext._active_spark_context._jsc.sc().setLocalProperties(properties) + return f(*args, **kwargs) return wrapped else: @@ -377,10 +374,7 @@ def copy_local_properties(*a: Any, **k: Any) -> Any: assert hasattr(self, "_props") assert SparkContext._active_spark_context is not None SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props) - try: - return target(*a, **k) - finally: - InheritableThread._clean_py4j_conn_for_current_thread() + return target(*a, **k) super(InheritableThread, self).__init__( target=copy_local_properties, *args, **kwargs # type: ignore[misc] @@ -401,25 +395,6 @@ def start(self) -> None: self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone() return super(InheritableThread, self).start() - @staticmethod - def _clean_py4j_conn_for_current_thread() -> None: - from pyspark import SparkContext - - jvm = SparkContext._jvm - assert jvm is not None - thread_connection = jvm._gateway_client.get_thread_connection() - if thread_connection is not None: - try: - # Dequeue is shared across other threads but it's thread-safe. - # If this function has to be invoked one more time in the same thead - # Py4J will create a new connection automatically. - jvm._gateway_client.deque.remove(thread_connection) - except ValueError: - # Should never reach this point - return - finally: - thread_connection.close() - if __name__ == "__main__": if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 7): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 1935e27d66363..8784abfb33379 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -60,7 +60,7 @@ ) from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StructType -from pyspark.util import fail_on_stopiteration, try_simplify_traceback # type: ignore +from pyspark.util import fail_on_stopiteration, try_simplify_traceback from pyspark import shuffle pickleSer = CPickleSerializer() @@ -507,7 +507,8 @@ def mapper(a): else: return result - func = lambda _, it: map(mapper, it) + def func(_, it): + return map(mapper, it) # profiling is not supported for UDF return func, None, ser, ser diff --git a/python/setup.py b/python/setup.py index 4ff495c19d4dc..ab9b64f79bc37 100755 --- a/python/setup.py +++ b/python/setup.py @@ -258,10 +258,10 @@ def run(self): license='http://www.apache.org/licenses/LICENSE-2.0', # Don't forget to update python/docs/source/getting_started/install.rst # if you're updating the versions or dependencies. 
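With the py4j connection cleanup removed above, `inheritable_thread_target` is reduced to one job: snapshotting the parent thread's local properties (scheduler pool, job group, and so on) when it wraps the function and re-applying them in the child thread. A hedged sketch of how the wrapper is typically used:

```python
import threading

from pyspark import inheritable_thread_target
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
sc = spark.sparkContext
sc.setLocalProperty("spark.scheduler.pool", "background")

@inheritable_thread_target
def run_job():
    # Under py4j pinned-thread mode the decorator copied the parent's local
    # properties when it wrapped run_job, so the pool above applies here too.
    return spark.range(100).count()

t = threading.Thread(target=run_job)
t.start()
t.join()
spark.stop()
```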
- install_requires=['py4j==0.10.9.3'], + install_requires=['py4j==0.10.9.4'], extras_require={ - 'ml': ['numpy>=1.7'], - 'mllib': ['numpy>=1.7'], + 'ml': ['numpy>=1.15'], + 'mllib': ['numpy>=1.15'], 'sql': [ 'pandas>=%s' % _minimum_pandas_version, 'pyarrow>=%s' % _minimum_pyarrow_version, @@ -269,7 +269,7 @@ def run(self): 'pandas_on_spark': [ 'pandas>=%s' % _minimum_pandas_version, 'pyarrow>=%s' % _minimum_pyarrow_version, - 'numpy>=1.14', + 'numpy>=1.15', ], }, python_requires='>=3.7', diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 0cb5e115906a5..611fee66342e3 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -29,8 +29,30 @@ Spark Project Kubernetes kubernetes + **/*Volcano*.scala + + + volcano + + + + + + io.fabric8 + volcano-model-v1beta1 + ${kubernetes-client.version} + + + io.fabric8 + volcano-client + ${kubernetes-client.version} + + + + + org.apache.spark @@ -103,6 +125,19 @@ + + + + net.alchim31.maven + scala-maven-plugin + + + ${volcano.exclude} + + + + + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.deploy.SparkSubmitOperation b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.deploy.SparkSubmitOperation index d589e6b60f847..057c234287469 100644 --- a/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.deploy.SparkSubmitOperation +++ b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.deploy.SparkSubmitOperation @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.deploy.k8s.submit.K8SSparkSubmitOperation \ No newline at end of file diff --git a/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager index 81d14766ffb8d..72cb48ec46478 100644 --- a/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager +++ b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.scheduler.cluster.k8s.KubernetesClusterManager diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index a2ad0d0a52a7f..7930cd0ce1563 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -58,7 +58,7 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_SERVICE_DELETE_ON_TERMINATION = ConfigBuilder("spark.kubernetes.driver.service.deleteOnTermination") .doc("If true, driver service will be deleted on Spark application termination. " + - "If false, it will be cleaned up when the driver pod is deletion.") + "If false, it will be cleaned up when the driver pod is deleted.") .version("3.2.0") .booleanConf .createWithDefault(true) @@ -147,7 +147,8 @@ private[spark] object Config extends Logging { .createWithDefault(0) object ExecutorRollPolicy extends Enumeration { - val ID, ADD_TIME, TOTAL_GC_TIME, TOTAL_DURATION, AVERAGE_DURATION, FAILED_TASKS, OUTLIER = Value + val ID, ADD_TIME, TOTAL_GC_TIME, TOTAL_DURATION, AVERAGE_DURATION, FAILED_TASKS, + OUTLIER, OUTLIER_NO_FALLBACK = Value } val EXECUTOR_ROLL_POLICY = @@ -165,7 +166,9 @@ private[spark] object Config extends Logging { "OUTLIER policy chooses an executor with outstanding statistics which is bigger than" + "at least two standard deviation from the mean in average task time, " + "total task time, total task GC time, and the number of failed tasks if exists. " + - "If there is no outlier, it works like TOTAL_DURATION policy.") + "If there is no outlier it works like TOTAL_DURATION policy. " + + "OUTLIER_NO_FALLBACK policy picks an outlier using the OUTLIER policy above. " + + "If there is no outlier then no executor will be rolled.") .version("3.3.0") .stringConf .transform(_.toUpperCase(Locale.ROOT)) @@ -280,6 +283,15 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_SCHEDULER_NAME = + ConfigBuilder("spark.kubernetes.scheduler.name") + .doc("Specify the scheduler name for driver and executor pods. If " + + s"`${KUBERNETES_DRIVER_SCHEDULER_NAME.key}` or " + + s"`${KUBERNETES_EXECUTOR_SCHEDULER_NAME.key}` is set, will override this.") + .version("3.3.0") + .stringConf + .createOptional + val KUBERNETES_EXECUTOR_REQUEST_CORES = ConfigBuilder("spark.kubernetes.executor.request.cores") .doc("Specify the cpu request for each executor pod") @@ -341,7 +353,9 @@ private[spark] object Config extends Logging { ConfigBuilder("spark.kubernetes.driver.pod.featureSteps") .doc("Class names of an extra driver pod feature step implementing " + "KubernetesFeatureConfigStep. This is a developer API. Comma separated. " + - "Runs after all of Spark internal feature steps.") + "Runs after all of Spark internal feature steps. 
Since 3.3.0, your driver feature " + + "step can implement `KubernetesDriverCustomFeatureConfigStep` where the driver " + + "config is also available.") .version("3.2.0") .stringConf .toSequence @@ -351,14 +365,16 @@ private[spark] object Config extends Logging { ConfigBuilder("spark.kubernetes.executor.pod.featureSteps") .doc("Class name of an extra executor pod feature step implementing " + "KubernetesFeatureConfigStep. This is a developer API. Comma separated. " + - "Runs after all of Spark internal feature steps.") + "Runs after all of Spark internal feature steps. Since 3.3.0, your executor feature " + + "step can implement `KubernetesExecutorCustomFeatureConfigStep` where the executor " + + "config is also available.") .version("3.2.0") .stringConf .toSequence .createWithDefault(Nil) val KUBERNETES_EXECUTOR_DECOMMISSION_LABEL = - ConfigBuilder("spark.kubernetes.executor.decommmissionLabel") + ConfigBuilder("spark.kubernetes.executor.decommissionLabel") .doc("Label to apply to a pod which is being decommissioned." + " Designed for use with pod disruption budgets and similar mechanism" + " such as pod-deletion-cost.") @@ -367,7 +383,7 @@ private[spark] object Config extends Logging { .createOptional val KUBERNETES_EXECUTOR_DECOMMISSION_LABEL_VALUE = - ConfigBuilder("spark.kubernetes.executor.decommmissionLabelValue") + ConfigBuilder("spark.kubernetes.executor.decommissionLabelValue") .doc("Label value to apply to a pod which is being decommissioned." + " Designed for use with pod disruption budgets and similar mechanism" + " such as pod-deletion-cost.") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 46086fac02021..118f4e5a61d3f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -42,7 +42,7 @@ private[spark] abstract class KubernetesConf(val sparkConf: SparkConf) { def secretEnvNamesToKeyRefs: Map[String, String] def secretNamesToMountPaths: Map[String, String] def volumes: Seq[KubernetesVolumeSpec] - def schedulerName: String + def schedulerName: Option[String] def appId: String def appName: String = get("spark.app.name", "spark") @@ -136,7 +136,9 @@ private[spark] class KubernetesDriverConf( KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, KUBERNETES_DRIVER_VOLUMES_PREFIX) } - override def schedulerName: String = get(KUBERNETES_DRIVER_SCHEDULER_NAME).getOrElse("") + override def schedulerName: Option[String] = { + Option(get(KUBERNETES_DRIVER_SCHEDULER_NAME).getOrElse(get(KUBERNETES_SCHEDULER_NAME).orNull)) + } } private[spark] class KubernetesExecutorConf( @@ -195,7 +197,9 @@ private[spark] class KubernetesExecutorConf( KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX) } - override def schedulerName: String = get(KUBERNETES_EXECUTOR_SCHEDULER_NAME).getOrElse("") + override def schedulerName: Option[String] = { + Option(get(KUBERNETES_EXECUTOR_SCHEDULER_NAME).getOrElse(get(KUBERNETES_SCHEDULER_NAME).orNull)) + } private def checkExecutorEnvKey(key: String): Boolean = { // Pattern for matching an executorEnv key, which meets certain naming rules. 
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 0c8d9646a2b4e..a05d07adcc825 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -344,7 +344,7 @@ object KubernetesUtils extends Logging { delSrc : Boolean = false, overwrite: Boolean = true): Unit = { try { - fs.copyFromLocalFile(false, true, src, dest) + fs.copyFromLocalFile(delSrc, overwrite, src, dest) } catch { case e: IOException => throw new SparkException(s"Error uploading file ${src.getName}", e) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 4131605e62b1f..54f557c750a4b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -24,9 +24,10 @@ import com.google.common.io.Files import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient} import io.fabric8.kubernetes.client.Config.KUBERNETES_REQUEST_RETRY_BACKOFFLIMIT_SYSTEM_PROPERTY import io.fabric8.kubernetes.client.Config.autoConfigure -import io.fabric8.kubernetes.client.utils.HttpClientUtils +import io.fabric8.kubernetes.client.okhttp.OkHttpClientFactory import io.fabric8.kubernetes.client.utils.Utils.getSystemPropertyOrEnvVar import okhttp3.Dispatcher +import okhttp3.OkHttpClient import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ @@ -68,6 +69,8 @@ private[spark] object SparkKubernetesClientFactory extends Logging { .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_KEY_FILE_CONF_SUFFIX") val clientCertFile = sparkConf .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_CERT_FILE_CONF_SUFFIX") + // TODO(SPARK-37687): clean up direct usage of OkHttpClient, see also: + // https://github.com/fabric8io/kubernetes-client/issues/3547 val dispatcher = new Dispatcher( ThreadUtils.newDaemonCachedThreadPool("kubernetes-dispatcher")) @@ -105,13 +108,14 @@ private[spark] object SparkKubernetesClientFactory extends Logging { }.withOption(namespace) { (ns, configBuilder) => configBuilder.withNamespace(ns) }.build() - val baseHttpClient = HttpClientUtils.createHttpClient(config) - val httpClientWithCustomDispatcher = baseHttpClient.newBuilder() - .dispatcher(dispatcher) - .build() + val factoryWithCustomDispatcher = new OkHttpClientFactory() { + override protected def additionalConfig(builder: OkHttpClient.Builder): Unit = { + builder.dispatcher(dispatcher) + } + } logDebug("Kubernetes client config: " + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(config)) - new DefaultKubernetesClient(httpClientWithCustomDispatcher, config) + new DefaultKubernetesClient(factoryWithCustomDispatcher.createHttpClient(config), config) } private implicit class OptionConfigurableConfigBuilder(val configBuilder: ConfigBuilder) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 49681dc8191c2..97151494fc60c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -53,18 +53,23 @@ private[spark] class BasicDriverFeatureStep(conf: KubernetesDriverConf) // Memory settings private val driverMemoryMiB = conf.get(DRIVER_MEMORY) + private val memoryOverheadFactor = if (conf.contains(DRIVER_MEMORY_OVERHEAD_FACTOR)) { + conf.get(DRIVER_MEMORY_OVERHEAD_FACTOR) + } else { + conf.get(MEMORY_OVERHEAD_FACTOR) + } // The memory overhead factor to use. If the user has not set it, then use a different // value for non-JVM apps. This value is propagated to executors. private val overheadFactor = if (conf.mainAppResource.isInstanceOf[NonJVMResource]) { - if (conf.contains(MEMORY_OVERHEAD_FACTOR)) { - conf.get(MEMORY_OVERHEAD_FACTOR) + if (conf.contains(MEMORY_OVERHEAD_FACTOR) || conf.contains(DRIVER_MEMORY_OVERHEAD_FACTOR)) { + memoryOverheadFactor } else { NON_JVM_MEMORY_OVERHEAD_FACTOR } } else { - conf.get(MEMORY_OVERHEAD_FACTOR) + memoryOverheadFactor } private val memoryOverheadMiB = conf @@ -142,8 +147,8 @@ private[spark] class BasicDriverFeatureStep(conf: KubernetesDriverConf) .editOrNewMetadata() .withName(driverPodName) .addToLabels(conf.labels.asJava) - .addToLabels(SPARK_APP_NAME_LABEL, KubernetesConf.getAppNameLabel(conf.appName)) - .addToAnnotations(conf.annotations.asJava) + .addToAnnotations(conf.annotations.map { case (k, v) => + (k, Utils.substituteAppNExecIds(v, conf.appId, "")) }.asJava) .endMetadata() .editOrNewSpec() .withRestartPolicy("Never") @@ -153,7 +158,7 @@ private[spark] class BasicDriverFeatureStep(conf: KubernetesDriverConf) .endSpec() .build() - conf.get(KUBERNETES_DRIVER_SCHEDULER_NAME) + conf.schedulerName .foreach(driverPod.getSpec.setSchedulerName) SparkPod(driverPod, driverContainer) @@ -164,7 +169,7 @@ private[spark] class BasicDriverFeatureStep(conf: KubernetesDriverConf) KUBERNETES_DRIVER_POD_NAME.key -> driverPodName, "spark.app.id" -> conf.appId, KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true", - MEMORY_OVERHEAD_FACTOR.key -> overheadFactor.toString) + DRIVER_MEMORY_OVERHEAD_FACTOR.key -> overheadFactor.toString) // try upload local, resolvable files to a hadoop compatible file system Seq(JARS, FILES, ARCHIVES, SUBMIT_PYTHON_FILES).foreach { key => val uris = conf.get(key).filter(uri => KubernetesUtils.isLocalAndResolvable(uri)) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 3f0a21e72ffbf..15c69ad487f5f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -59,11 +59,16 @@ private[spark] class BasicExecutorFeatureStep( private val isDefaultProfile = resourceProfile.id == ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID private val isPythonApp = kubernetesConf.get(APP_RESOURCE_TYPE) == Some(APP_RESOURCE_TYPE_PYTHON) private val disableConfigMap = kubernetesConf.get(KUBERNETES_EXECUTOR_DISABLE_CONFIGMAP) + private val 
memoryOverheadFactor = if (kubernetesConf.contains(EXECUTOR_MEMORY_OVERHEAD_FACTOR)) { + kubernetesConf.get(EXECUTOR_MEMORY_OVERHEAD_FACTOR) + } else { + kubernetesConf.get(MEMORY_OVERHEAD_FACTOR) + } val execResources = ResourceProfile.getResourcesForClusterManager( resourceProfile.id, resourceProfile.executorResources, - kubernetesConf.get(MEMORY_OVERHEAD_FACTOR), + memoryOverheadFactor, kubernetesConf.sparkConf, isPythonApp, Map.empty) @@ -272,16 +277,14 @@ private[spark] class BasicExecutorFeatureStep( case "statefulset" => "Always" case _ => "Never" } + val annotations = kubernetesConf.annotations.map { case (k, v) => + (k, Utils.substituteAppNExecIds(v, kubernetesConf.appId, kubernetesConf.executorId)) + } val executorPodBuilder = new PodBuilder(pod.pod) .editOrNewMetadata() .withName(name) .addToLabels(kubernetesConf.labels.asJava) - .addToLabels(SPARK_RESOURCE_PROFILE_ID_LABEL, resourceProfile.id.toString) - .addToLabels( - SPARK_APP_NAME_LABEL, - KubernetesConf.getAppNameLabel(kubernetesConf.appName) - ) - .addToAnnotations(kubernetesConf.annotations.asJava) + .addToAnnotations(annotations.asJava) .addToOwnerReferences(ownerReference.toSeq: _*) .endMetadata() .editOrNewSpec() @@ -304,7 +307,7 @@ private[spark] class BasicExecutorFeatureStep( .endSpec() .build() } - kubernetesConf.get(KUBERNETES_EXECUTOR_SCHEDULER_NAME) + kubernetesConf.schedulerName .foreach(executorPod.getSpec.setSchedulerName) SparkPod(executorPod, containerWithLifecycle) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesDriverCustomFeatureConfigStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesDriverCustomFeatureConfigStep.scala new file mode 100644 index 0000000000000..0edd94d3370ab --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesDriverCustomFeatureConfigStep.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import org.apache.spark.annotation.{DeveloperApi, Unstable} +import org.apache.spark.deploy.k8s.KubernetesDriverConf + +/** + * :: DeveloperApi :: + * + * A base interface to help user extend custom feature step in driver side. + * Note: If your custom feature step would be used only in driver or both in driver and executor, + * please use this. 
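The `BasicDriverFeatureStep` and `BasicExecutorFeatureStep` changes above prefer the new role-specific overhead factors and fall back to the legacy `spark.kubernetes.memoryOverheadFactor` only when they are unset, keeping the larger non-JVM default for drivers when neither is given. A rough mock of that precedence; the role-specific key names and the 0.1 / 0.4 defaults are assumptions taken from the surrounding Spark configuration, not from this diff:

```python
def resolve_overhead_factor(conf: dict, role: str, non_jvm_app: bool = False) -> float:
    """Illustrative mirror of the fallback order above; keys and defaults are assumed."""
    role_key = f"spark.{role}.memoryOverheadFactor"       # assumed: role is "driver" or "executor"
    legacy_key = "spark.kubernetes.memoryOverheadFactor"  # pre-3.3 umbrella knob
    if role_key in conf:
        return float(conf[role_key])
    if legacy_key in conf:
        return float(conf[legacy_key])
    # Only the driver step special-cases non-JVM (PySpark / SparkR) applications.
    if role == "driver" and non_jvm_app:
        return 0.4
    return 0.1

print(resolve_overhead_factor({}, "driver", non_jvm_app=True))                                 # 0.4
print(resolve_overhead_factor({"spark.kubernetes.memoryOverheadFactor": "0.2"}, "executor"))   # 0.2
print(resolve_overhead_factor({"spark.driver.memoryOverheadFactor": "0.25"}, "driver", True))  # 0.25
```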
+ * + * Example of driver feature step: + * + * {{{ + * class DriverExampleFeatureStep extends KubernetesDriverCustomFeatureConfigStep { + * private var driverConf: KubernetesDriverConf = _ + * + * override def init(conf: KubernetesDriverConf): Unit = { + * driverConf = conf + * } + * + * // Implements methods of `KubernetesFeatureConfigStep`, such as `configurePod` + * override def configurePod(pod: SparkPod): SparkPod = { + * // Apply modifications on the given pod in accordance to this feature. + * } + * } + * }}} + * + * Example of feature step for both driver and executor: + * + * {{{ + * class DriverAndExecutorExampleFeatureStep extends KubernetesDriverCustomFeatureConfigStep + * with KubernetesExecutorCustomFeatureConfigStep { + * private var kubernetesConf: KubernetesConf = _ + * + * override def init(conf: KubernetesDriverConf): Unit = { + * kubernetesConf = conf + * } + * + * override def init(conf: KubernetesExecutorConf): Unit = { + * kubernetesConf = conf + * } + * + * // Implements methods of `KubernetesFeatureConfigStep`, such as `configurePod` + * override def configurePod(pod: SparkPod): SparkPod = { + * // Apply modifications on the given pod in accordance to this feature. + * } + * } + * }}} + */ +@Unstable +@DeveloperApi +trait KubernetesDriverCustomFeatureConfigStep extends KubernetesFeatureConfigStep { + /** + * Initialize the configuration for driver user feature step, this only applicable when user + * specified `spark.kubernetes.driver.pod.featureSteps`, the init will be called after feature + * step loading. + */ + def init(config: KubernetesDriverConf): Unit +} + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesExecutorCustomFeatureConfigStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesExecutorCustomFeatureConfigStep.scala new file mode 100644 index 0000000000000..dfb1c768c990e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesExecutorCustomFeatureConfigStep.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import org.apache.spark.annotation.{DeveloperApi, Unstable} +import org.apache.spark.deploy.k8s.KubernetesExecutorConf + +/** + * :: DeveloperApi :: + * + * A base interface to help user extend custom feature step in executor side. + * Note: If your custom feature step would be used only in driver or both in driver and executor, + * please use this. 
+ * + * Example of executor feature step: + * + * {{{ + * class ExecutorExampleFeatureStep extends KubernetesExecutorCustomFeatureConfigStep { + * private var executorConf: KubernetesExecutorConf = _ + * + * override def init(conf: KubernetesExecutorConf): Unit = { + * executorConf = conf + * } + * + * // Implements methods of `KubernetesFeatureConfigStep`, such as `configurePod` + * override def configurePod(pod: SparkPod): SparkPod = { + * // Apply modifications on the given pod in accordance to this feature. + * } + * } + * }}} + * + * Example of feature step for both driver and executor: + * + * {{{ + * class DriverAndExecutorExampleFeatureStep extends KubernetesDriverCustomFeatureConfigStep + * with KubernetesExecutorCustomFeatureConfigStep { + * private var kubernetesConf: KubernetesConf = _ + * + * override def init(conf: KubernetesDriverConf): Unit = { + * kubernetesConf = conf + * } + * + * override def init(conf: KubernetesExecutorConf): Unit = { + * kubernetesConf = conf + * } + * + * // Implements methods of `KubernetesFeatureConfigStep`, such as `configurePod` + * override def configurePod(pod: SparkPod): SparkPod = { + * // Apply modifications on the given pod in accordance to this feature. + * } + * } + * }}} + */ +@Unstable +@DeveloperApi +trait KubernetesExecutorCustomFeatureConfigStep extends KubernetesFeatureConfigStep { + /** + * Initialize the configuration for executor user feature step, this only applicable when user + * specified `spark.kubernetes.executor.pod.featureSteps` the init will be called after feature + * step loading. + */ + def init(config: KubernetesExecutorConf): Unit +} + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index 4e1647372ecdc..78dd6ec21ed34 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -85,7 +85,7 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) .withApiVersion("v1") .withNewMetadata() .withName(claimName) - .addToLabels(SPARK_APP_ID_LABEL, conf.sparkConf.getAppId) + .addToLabels(SPARK_APP_ID_LABEL, conf.appId) .endMetadata() .withNewSpec() .withStorageClassName(storageClass.get) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala new file mode 100644 index 0000000000000..091923a78efe5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model._ +import io.fabric8.volcano.client.DefaultVolcanoClient +import io.fabric8.volcano.scheduling.v1beta1.{PodGroup, PodGroupSpec} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverConf, KubernetesExecutorConf, SparkPod} + +private[spark] class VolcanoFeatureStep extends KubernetesDriverCustomFeatureConfigStep + with KubernetesExecutorCustomFeatureConfigStep { + import VolcanoFeatureStep._ + + private var kubernetesConf: KubernetesConf = _ + + private lazy val podGroupName = s"${kubernetesConf.appId}-podgroup" + private lazy val namespace = kubernetesConf.namespace + + override def init(config: KubernetesDriverConf): Unit = { + kubernetesConf = config + } + + override def init(config: KubernetesExecutorConf): Unit = { + kubernetesConf = config + } + + override def getAdditionalPreKubernetesResources(): Seq[HasMetadata] = { + val client = new DefaultVolcanoClient + val template = kubernetesConf.getOption(POD_GROUP_TEMPLATE_FILE_KEY) + val pg = template.map(client.podGroups.load(_).get).getOrElse(new PodGroup()) + var metadata = pg.getMetadata + if (metadata == null) metadata = new ObjectMeta + metadata.setName(podGroupName) + metadata.setNamespace(namespace) + pg.setMetadata(metadata) + + var spec = pg.getSpec + if (spec == null) spec = new PodGroupSpec + pg.setSpec(spec) + + Seq(pg) + } + + override def configurePod(pod: SparkPod): SparkPod = { + val k8sPodBuilder = new PodBuilder(pod.pod) + .editMetadata() + .addToAnnotations(POD_GROUP_ANNOTATION, podGroupName) + .endMetadata() + val k8sPod = k8sPodBuilder.build() + SparkPod(k8sPod, pod.container) + } +} + +private[spark] object VolcanoFeatureStep { + val POD_GROUP_ANNOTATION = "scheduling.k8s.io/group-name" + val POD_GROUP_TEMPLATE_FILE_KEY = "spark.kubernetes.scheduler.volcano.podGroupTemplateFile" +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 96c19bbb3da69..3a3ab081fe843 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -105,7 +105,8 @@ private[spark] class Client( val configMapName = KubernetesClientUtils.configMapNameDriver val confFilesMap = KubernetesClientUtils.buildSparkConfDirFilesMap(configMapName, conf.sparkConf, resolvedDriverSpec.systemProperties) - val configMap = KubernetesClientUtils.buildConfigMap(configMapName, confFilesMap) + val configMap = KubernetesClientUtils.buildConfigMap(configMapName, confFilesMap + + (KUBERNETES_NAMESPACE.key -> conf.namespace)) // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the // Spark command builder to pickup on the Java Options present in the ConfigMap diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index f0c78f371d6d2..e89e52f1af201 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.k8s.submit import io.fabric8.kubernetes.client.KubernetesClient +import org.apache.spark.SparkException import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.features._ import org.apache.spark.util.Utils @@ -39,7 +40,26 @@ private[spark] class KubernetesDriverBuilder { val userFeatures = conf.get(Config.KUBERNETES_DRIVER_POD_FEATURE_STEPS) .map { className => - Utils.classForName(className).newInstance().asInstanceOf[KubernetesFeatureConfigStep] + val feature = Utils.classForName[Any](className).newInstance() + val initializedFeature = feature match { + // Since 3.3, allow user to implement feature with KubernetesDriverConf + case d: KubernetesDriverCustomFeatureConfigStep => + d.init(conf) + Some(d) + // raise SparkException with wrong type feature step + case _: KubernetesExecutorCustomFeatureConfigStep => + None + // Since 3.2, allow user to implement feature without config + case f: KubernetesFeatureConfigStep => + Some(f) + case _ => None + } + initializedFeature.getOrElse { + throw new SparkException(s"Failed to initialize feature step: $className, " + + s"please make sure your driver side feature steps are implemented by " + + s"`${classOf[KubernetesDriverCustomFeatureConfigStep].getName}` or " + + s"`${classOf[KubernetesFeatureConfigStep].getName}`.") + } } val features = Seq( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/AbstractPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/AbstractPodsAllocator.scala index 2e0d4fa7ca00b..cc081202cf89a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/AbstractPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/AbstractPodsAllocator.scala @@ -26,9 +26,8 @@ import org.apache.spark.resource.ResourceProfile * :: DeveloperApi :: * A abstract interface for allowing different types of pods allocation. * - * The internal Spark implementations are [[StatefulsetPodsAllocator]] - * and [[ExecutorPodsAllocator]]. This may be useful for folks integrating with custom schedulers - * such as Volcano, Yunikorn, etc. + * The internal Spark implementations are [[StatefulSetPodsAllocator]] + * and [[ExecutorPodsAllocator]]. This may be useful for folks integrating with custom schedulers. * * This API may change or be removed at anytime. 
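+ *
+ * The allocator in use is selected through the `KUBERNETES_ALLOCATION_PODS_ALLOCATOR` setting,
+ * which accepts "statefulset", "direct", or a fully qualified class name. A sketch of pointing
+ * it at a custom implementation (the class below is only an illustration):
+ *
+ * {{{
+ *   new SparkConf()
+ *     .set(Config.KUBERNETES_ALLOCATION_PODS_ALLOCATOR.key, "com.example.MyPodsAllocator")
+ * }}}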
* diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPlugin.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPlugin.scala index f6054a8dbc5ee..5da4510d2cc86 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPlugin.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPlugin.scala @@ -61,7 +61,7 @@ class ExecutorRollDriverPlugin extends DriverPlugin with Logging { } else if (!sc.conf.get(DECOMMISSION_ENABLED)) { logWarning(s"Disabled because ${DECOMMISSION_ENABLED.key} is false.") } else { - minTasks = sparkContext.conf.get(MINIMUM_TASKS_PER_EXECUTOR_BEFORE_ROLLING) + minTasks = sc.conf.get(MINIMUM_TASKS_PER_EXECUTOR_BEFORE_ROLLING) // Scheduler is not created yet sparkContext = sc @@ -118,20 +118,27 @@ class ExecutorRollDriverPlugin extends DriverPlugin with Logging { case ExecutorRollPolicy.FAILED_TASKS => listWithoutDriver.sortBy(_.failedTasks).reverse case ExecutorRollPolicy.OUTLIER => - // We build multiple outlier lists and concat in the following importance order to find - // outliers in various perspective: - // AVERAGE_DURATION > TOTAL_DURATION > TOTAL_GC_TIME > FAILED_TASKS - // Since we will choose only first item, the duplication is okay. If there is no outlier, - // We fallback to TOTAL_DURATION policy. - outliers(listWithoutDriver.filter(_.totalTasks > 0), e => e.totalDuration / e.totalTasks) ++ - outliers(listWithoutDriver, e => e.totalDuration) ++ - outliers(listWithoutDriver, e => e.totalGCTime) ++ - outliers(listWithoutDriver, e => e.failedTasks) ++ + // If there is no outlier we fallback to TOTAL_DURATION policy. + outliersFromMultipleDimensions(listWithoutDriver) ++ listWithoutDriver.sortBy(_.totalDuration).reverse + case ExecutorRollPolicy.OUTLIER_NO_FALLBACK => + outliersFromMultipleDimensions(listWithoutDriver) } sortedList.headOption.map(_.id) } + /** + * We build multiple outlier lists and concat in the following importance order to find + * outliers in various perspective: + * AVERAGE_DURATION > TOTAL_DURATION > TOTAL_GC_TIME > FAILED_TASKS + * Since we will choose only first item, the duplication is okay. + */ + private def outliersFromMultipleDimensions(listWithoutDriver: Seq[v1.ExecutorSummary]) = + outliers(listWithoutDriver.filter(_.totalTasks > 0), e => e.totalDuration / e.totalTasks) ++ + outliers(listWithoutDriver, e => e.totalDuration) ++ + outliers(listWithoutDriver, e => e.totalGCTime) ++ + outliers(listWithoutDriver, e => e.failedTasks) + /** * Return executors whose metrics is outstanding, '(value - mean) > 2-sigma'. This is * a best-effort approach because the snapshot of ExecutorSummary is not a normal distribution. 
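
The OUTLIER and OUTLIER_NO_FALLBACK policies above rely on a simple 2-sigma rule: an executor counts as an outlier for a given metric when its value exceeds the mean by more than two standard deviations. A minimal, self-contained sketch of that rule (generic over the summary type, which is simplified away here; not the exact Spark implementation):

    // A minimal sketch of the 2-sigma rule behind the OUTLIER policies above; `metric` picks
    // the dimension to rank by (average duration, total duration, GC time, or failed tasks).
    def outliers[T](items: Seq[T], metric: T => Float): Seq[T] = {
      if (items.isEmpty) {
        Seq.empty
      } else {
        val values = items.map(metric)
        val mean = values.sum / values.size
        val stdDev = math.sqrt(values.map(v => (v - mean) * (v - mean)).sum / values.size)
        // Keep executors whose metric exceeds the mean by more than two standard deviations,
        // worst offender first.
        items.filter(i => (metric(i) - mean) > 2 * stdDev).sortBy(metric).reverse
      }
    }

Concatenating the per-dimension results in the importance order quoted in the scaladoc above, and taking the head of that list, then yields the executor to roll (or None under OUTLIER_NO_FALLBACK when no dimension has an outlier).
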
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index 9497349569efc..10ea3a8cb0e46 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -137,7 +137,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit snapshotsStore: ExecutorPodsSnapshotsStore) = { val executorPodsAllocatorName = sc.conf.get(KUBERNETES_ALLOCATION_PODS_ALLOCATOR) match { case "statefulset" => - classOf[StatefulsetPodsAllocator].getName + classOf[StatefulSetPodsAllocator].getName case "direct" => classOf[ExecutorPodsAllocator].getName case fullClass => diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 110225e17473b..43c6597362e41 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -79,7 +79,8 @@ private[spark] class KubernetesClusterSchedulerBackend( val resolvedExecutorProperties = Map(KUBERNETES_NAMESPACE.key -> conf.get(KUBERNETES_NAMESPACE)) val confFilesMap = KubernetesClientUtils - .buildSparkConfDirFilesMap(configMapName, conf, resolvedExecutorProperties) + .buildSparkConfDirFilesMap(configMapName, conf, resolvedExecutorProperties) ++ + resolvedExecutorProperties val labels = Map(SPARK_APP_ID_LABEL -> applicationId(), SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) val configMap = KubernetesClientUtils.buildConfigMap(configMapName, confFilesMap, labels) @@ -95,7 +96,7 @@ private[spark] class KubernetesClusterSchedulerBackend( * @return The application ID */ override def applicationId(): String = { - conf.getOption("spark.app.id").map(_.toString).getOrElse(appId) + conf.getOption("spark.app.id").getOrElse(appId) } override def start(): Unit = { @@ -301,7 +302,7 @@ private[spark] class KubernetesClusterSchedulerBackend( kubernetesClient.pods() .withName(x.podName) .edit({p: Pod => new PodBuilder(p).editMetadata() - .addToLabels(SPARK_EXECUTOR_ID_LABEL, newId.toString) + .addToLabels(SPARK_EXECUTOR_ID_LABEL, newId) .endMetadata() .build()}) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 1a62d08a7b413..1f6d72cb7eee0 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler.cluster.k8s import io.fabric8.kubernetes.client.KubernetesClient -import org.apache.spark.SecurityManager +import org.apache.spark.{SecurityManager, SparkException} import org.apache.spark.deploy.k8s._ import 
org.apache.spark.deploy.k8s.features._ import org.apache.spark.resource.ResourceProfile @@ -43,7 +43,26 @@ private[spark] class KubernetesExecutorBuilder { val userFeatures = conf.get(Config.KUBERNETES_EXECUTOR_POD_FEATURE_STEPS) .map { className => - Utils.classForName(className).newInstance().asInstanceOf[KubernetesFeatureConfigStep] + val feature = Utils.classForName[Any](className).newInstance() + val initializedFeature = feature match { + // Since 3.3, allow user to implement feature with KubernetesExecutorConf + case e: KubernetesExecutorCustomFeatureConfigStep => + e.init(conf) + Some(e) + // raise SparkException with wrong type feature step + case _: KubernetesDriverCustomFeatureConfigStep => + None + // Since 3.2, allow user to implement feature without config + case f: KubernetesFeatureConfigStep => + Some(f) + case _ => None + } + initializedFeature.getOrElse { + throw new SparkException(s"Failed to initialize feature step: $className, " + + s"please make sure your executor side feature steps are implemented by " + + s"`${classOf[KubernetesExecutorCustomFeatureConfigStep].getSimpleName}` or " + + s"`${classOf[KubernetesFeatureConfigStep].getSimpleName}`.") + } } val features = Seq( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/StatefulsetPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/StatefulSetPodsAllocator.scala similarity index 99% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/StatefulsetPodsAllocator.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/StatefulSetPodsAllocator.scala index 0d00d9678048e..294ee70168b23 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/StatefulsetPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/StatefulSetPodsAllocator.scala @@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.resource.ResourceProfile import org.apache.spark.util.{Clock, Utils} -class StatefulsetPodsAllocator( +class StatefulSetPodsAllocator( conf: SparkConf, secMgr: SecurityManager, executorBuilder: KubernetesExecutorBuilder, diff --git a/resource-managers/kubernetes/core/src/test/resources/driver-podgroup-template.yml b/resource-managers/kubernetes/core/src/test/resources/driver-podgroup-template.yml new file mode 100644 index 0000000000000..085d6b84c57aa --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/resources/driver-podgroup-template.yml @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: PodGroup +spec: + minMember: 1 + minResources: + cpu: "2" + memory: "2048Mi" + priorityClassName: driver-priority + queue: driver-queue diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index 1b3aaa579c621..eecaff262bf66 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -206,24 +206,36 @@ class KubernetesConfSuite extends SparkFunSuite { test("SPARK-36059: Set driver.scheduler and executor.scheduler") { val sparkConf = new SparkConf(false) val execUnsetConf = KubernetesTestConf.createExecutorConf(sparkConf) - val driverUnsetConf = KubernetesTestConf.createExecutorConf(sparkConf) - assert(execUnsetConf.schedulerName === "") - assert(driverUnsetConf.schedulerName === "") - + val driverUnsetConf = KubernetesTestConf.createDriverConf(sparkConf) + assert(execUnsetConf.schedulerName === None) + assert(driverUnsetConf.schedulerName === None) + + sparkConf.set(KUBERNETES_SCHEDULER_NAME, "sameScheduler") + // Use KUBERNETES_SCHEDULER_NAME when is NOT set + assert(KubernetesTestConf.createDriverConf(sparkConf).schedulerName === Some("sameScheduler")) + assert(KubernetesTestConf.createExecutorConf(sparkConf).schedulerName === Some("sameScheduler")) + + // Override by driver/executor side scheduler when "" + sparkConf.set(KUBERNETES_DRIVER_SCHEDULER_NAME, "") + sparkConf.set(KUBERNETES_EXECUTOR_SCHEDULER_NAME, "") + assert(KubernetesTestConf.createDriverConf(sparkConf).schedulerName === Some("")) + assert(KubernetesTestConf.createExecutorConf(sparkConf).schedulerName === Some("")) + + // Override by driver/executor side scheduler when set sparkConf.set(KUBERNETES_DRIVER_SCHEDULER_NAME, "driverScheduler") sparkConf.set(KUBERNETES_EXECUTOR_SCHEDULER_NAME, "executorScheduler") val execConf = KubernetesTestConf.createExecutorConf(sparkConf) - assert(execConf.schedulerName === "executorScheduler") + assert(execConf.schedulerName === Some("executorScheduler")) val driverConf = KubernetesTestConf.createDriverConf(sparkConf) - assert(driverConf.schedulerName === "driverScheduler") + assert(driverConf.schedulerName === Some("driverScheduler")) } test("SPARK-37735: access appId in KubernetesConf") { val sparkConf = new SparkConf(false) val driverConf = KubernetesTestConf.createDriverConf(sparkConf) val execConf = KubernetesTestConf.createExecutorConf(sparkConf) - assert(driverConf.asInstanceOf[KubernetesConf].appId === KubernetesTestConf.APP_ID) - assert(execConf.asInstanceOf[KubernetesConf].appId === KubernetesTestConf.APP_ID) + assert(driverConf.appId === KubernetesTestConf.APP_ID) + assert(execConf.appId === KubernetesTestConf.APP_ID) } test("SPARK-36566: get app name label") { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala index ef57a4b861508..5498238307d1c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala @@ -17,13 +17,20 @@ package org.apache.spark.deploy.k8s 
+import java.io.File +import java.nio.charset.StandardCharsets + import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.scalatest.PrivateMethodTester -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -class KubernetesUtilsSuite extends SparkFunSuite { +class KubernetesUtilsSuite extends SparkFunSuite with PrivateMethodTester { private val HOST = "test-host" private val POD = new PodBuilder() .withNewSpec() @@ -65,4 +72,59 @@ class KubernetesUtilsSuite extends SparkFunSuite { assert(sparkPodWithNoContainerName.pod.getSpec.getHostname == HOST) assert(sparkPodWithNoContainerName.container.getName == null) } + + test("SPARK-38201: check uploadFileToHadoopCompatibleFS with different delSrc and overwrite") { + withTempDir { srcDir => + withTempDir { destDir => + val upload = PrivateMethod[Unit](Symbol("uploadFileToHadoopCompatibleFS")) + val fileName = "test.txt" + val srcFile = new File(srcDir, fileName) + val src = new Path(srcFile.getAbsolutePath) + val dest = new Path(destDir.getAbsolutePath, fileName) + val fs = src.getFileSystem(new Configuration()) + + def checkUploadException(delSrc: Boolean, overwrite: Boolean): Unit = { + val message = intercept[SparkException] { + KubernetesUtils.invokePrivate(upload(src, dest, fs, delSrc, overwrite)) + }.getMessage + assert(message.contains("Error uploading file")) + } + + def appendFileAndUpload(content: String, delSrc: Boolean, overwrite: Boolean): Unit = { + FileUtils.write(srcFile, content, StandardCharsets.UTF_8, true) + KubernetesUtils.invokePrivate(upload(src, dest, fs, delSrc, overwrite)) + } + + // Write a new file, upload file with delSrc = false and overwrite = true. + // Upload successful and record the `fileLength`. + appendFileAndUpload("init-content", delSrc = false, overwrite = true) + val firstLength = fs.getFileStatus(dest).getLen + + // Append the file, upload file with delSrc = false and overwrite = true. + // Upload succeeded but `fileLength` changed. + appendFileAndUpload("append-content", delSrc = false, overwrite = true) + val secondLength = fs.getFileStatus(dest).getLen + assert(firstLength < secondLength) + + // Upload file with delSrc = false and overwrite = false. + // Upload failed because dest exists and not changed. + checkUploadException(delSrc = false, overwrite = false) + assert(fs.exists(dest)) + assert(fs.getFileStatus(dest).getLen == secondLength) + + // Append the file again, upload file delSrc = true and overwrite = true. + // Upload succeeded, `fileLength` changed and src not exists. + appendFileAndUpload("append-content", delSrc = true, overwrite = true) + val thirdLength = fs.getFileStatus(dest).getLen + assert(secondLength < thirdLength) + assert(!fs.exists(src)) + + // Rewrite a new file, upload file with delSrc = true and overwrite = false. + // Upload failed because dest exists, src still exists. 
+ FileUtils.write(srcFile, "re-init-content", StandardCharsets.UTF_8, true) + checkUploadException(delSrc = true, overwrite = false) + assert(fs.exists(src)) + } + } + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala index a8a3ca4eea965..642c18db541e1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala @@ -26,15 +26,24 @@ import org.mockito.Mockito.{mock, never, verify, when} import scala.collection.JavaConverters._ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep +import org.apache.spark.deploy.k8s.features.{KubernetesDriverCustomFeatureConfigStep, KubernetesExecutorCustomFeatureConfigStep, KubernetesFeatureConfigStep} import org.apache.spark.internal.config.ConfigEntry abstract class PodBuilderSuite extends SparkFunSuite { + val POD_ROLE: String + val TEST_ANNOTATION_KEY: String + val TEST_ANNOTATION_VALUE: String protected def templateFileConf: ConfigEntry[_] + protected def roleSpecificSchedulerNameConf: ConfigEntry[_] + protected def userFeatureStepsConf: ConfigEntry[_] + protected def userFeatureStepWithExpectedAnnotation: (String, String) + + protected def wrongTypeFeatureStep: String + protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod protected val baseConf = new SparkConf(false) @@ -46,6 +55,20 @@ abstract class PodBuilderSuite extends SparkFunSuite { verify(client, never()).pods() } + test("SPARK-36059: set custom scheduler") { + val client = mockKubernetesClient() + val conf1 = baseConf.clone().set(templateFileConf.key, "template-file.yaml") + .set(Config.KUBERNETES_SCHEDULER_NAME.key, "custom") + val pod1 = buildPod(conf1, client) + assert(pod1.pod.getSpec.getSchedulerName === "custom") + + val conf2 = baseConf.clone().set(templateFileConf.key, "template-file.yaml") + .set(Config.KUBERNETES_SCHEDULER_NAME.key, "custom") + .set(roleSpecificSchedulerNameConf.key, "rolescheduler") + val pod2 = buildPod(conf2, client) + assert(pod2.pod.getSpec.getSchedulerName === "rolescheduler") + } + test("load pod template if specified") { val client = mockKubernetesClient() val sparkConf = baseConf.clone().set(templateFileConf.key, "template-file.yaml") @@ -66,6 +89,57 @@ abstract class PodBuilderSuite extends SparkFunSuite { assert(pod.container.getVolumeMounts.asScala.exists(_.getName == "so_long_two")) } + test("SPARK-37145: configure a custom test step with base config") { + val client = mockKubernetesClient() + val sparkConf = baseConf.clone() + .set(userFeatureStepsConf.key, + "org.apache.spark.deploy.k8s.TestStepWithConf") + .set(templateFileConf.key, "template-file.yaml") + .set("test-features-key", "test-features-value") + val pod = buildPod(sparkConf, client) + verifyPod(pod) + val metadata = pod.pod.getMetadata + assert(metadata.getAnnotations.containsKey("test-features-key")) + assert(metadata.getAnnotations.get("test-features-key") === "test-features-value") + } + + test("SPARK-37145: configure a custom test step with driver or executor config") { + val client = mockKubernetesClient() + val (featureSteps, annotation) = userFeatureStepWithExpectedAnnotation + val sparkConf = baseConf.clone() + .set(templateFileConf.key, 
"template-file.yaml") + .set(userFeatureStepsConf.key, featureSteps) + .set(TEST_ANNOTATION_KEY, annotation) + val pod = buildPod(sparkConf, client) + verifyPod(pod) + val metadata = pod.pod.getMetadata + assert(metadata.getAnnotations.containsKey(TEST_ANNOTATION_KEY)) + assert(metadata.getAnnotations.get(TEST_ANNOTATION_KEY) === annotation) + } + + test("SPARK-37145: configure a custom test step with wrong type config") { + val client = mockKubernetesClient() + val sparkConf = baseConf.clone() + .set(templateFileConf.key, "template-file.yaml") + .set(userFeatureStepsConf.key, wrongTypeFeatureStep) + val e = intercept[SparkException] { + buildPod(sparkConf, client) + } + assert(e.getMessage.contains(s"please make sure your $POD_ROLE side feature steps")) + } + + test("SPARK-37145: configure a custom test step with wrong name") { + val client = mockKubernetesClient() + val featureSteps = "unknow.class" + val sparkConf = baseConf.clone() + .set(templateFileConf.key, "template-file.yaml") + .set(userFeatureStepsConf.key, featureSteps) + val e = intercept[ClassNotFoundException] { + buildPod(sparkConf, client) + } + assert(e.getMessage.contains("unknow.class")) + } + test("complain about misconfigured pod template") { val client = mockKubernetesClient( new PodBuilder() @@ -249,3 +323,30 @@ class TestStepTwo extends KubernetesFeatureConfigStep { SparkPod(podWithLocalDirVolumes, containerWithLocalDirVolumeMounts) } } + +/** + * A test user feature step would be used in driver and executor. + */ +class TestStepWithConf extends KubernetesDriverCustomFeatureConfigStep + with KubernetesExecutorCustomFeatureConfigStep { + import io.fabric8.kubernetes.api.model._ + + private var kubernetesConf: KubernetesConf = _ + + override def init(conf: KubernetesDriverConf): Unit = { + kubernetesConf = conf + } + + override def init(conf: KubernetesExecutorConf): Unit = { + kubernetesConf = conf + } + + override def configurePod(pod: SparkPod): SparkPod = { + val k8sPodBuilder = new PodBuilder(pod.pod) + .editOrNewMetadata() + .addToAnnotations("test-features-key", kubernetesConf.get("test-features-key")) + .endMetadata() + val k8sPod = k8sPodBuilder.build() + SparkPod(k8sPod, pod.container) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index 9e52c6ef6ccf1..d45f5f97da213 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model.{ContainerPort, ContainerPortBuilder, LocalObjectReferenceBuilder, Quantity} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesTestConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesDriverConf, KubernetesTestConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features.KubernetesFeaturesTestUtils.TestResourceInformation @@ -34,12 +34,14 @@ import org.apache.spark.util.Utils class BasicDriverFeatureStepSuite extends SparkFunSuite { - private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") + private val 
CUSTOM_DRIVER_LABELS = Map("labelkey" -> "labelvalue") private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" - private val DRIVER_ANNOTATIONS = Map("customAnnotation" -> "customAnnotationValue") + private val DRIVER_ANNOTATIONS = Map( + "customAnnotation" -> "customAnnotationValue", + "yunikorn.apache.org/app-id" -> "{{APPID}}") private val DRIVER_ENVS = Map( - "customDriverEnv1" -> "customDriverEnv2", - "customDriverEnv2" -> "customDriverEnv2") + "customDriverEnv1" -> "customDriverEnv1Value", + "customDriverEnv2" -> "customDriverEnv2Value") private val TEST_IMAGE_PULL_SECRETS = Seq("my-secret-1", "my-secret-2") private val TEST_IMAGE_PULL_SECRET_OBJECTS = TEST_IMAGE_PULL_SECRETS.map { secret => @@ -62,9 +64,9 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { sparkConf.set(testRInfo.rId.amountConf, testRInfo.count) sparkConf.set(testRInfo.rId.vendorConf, testRInfo.vendor) } - val kubernetesConf = KubernetesTestConf.createDriverConf( + val kubernetesConf: KubernetesDriverConf = KubernetesTestConf.createDriverConf( sparkConf = sparkConf, - labels = DRIVER_LABELS, + labels = CUSTOM_DRIVER_LABELS, environment = DRIVER_ENVS, annotations = DRIVER_ANNOTATIONS) @@ -90,7 +92,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .map { env => (env.getName, env.getValue) } .toMap DRIVER_ENVS.foreach { case (k, v) => - assert(envs(v) === v) + assert(envs(k) === v) } assert(envs(ENV_SPARK_USER) === Utils.getCurrentUserName()) assert(envs(ENV_APPLICATION_ID) === kubernetesConf.appId) @@ -116,19 +118,23 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val driverPodMetadata = configuredPod.pod.getMetadata assert(driverPodMetadata.getName === "spark-driver-pod") - val DEFAULT_LABELS = Map( - SPARK_APP_NAME_LABEL-> KubernetesConf.getAppNameLabel(kubernetesConf.appName) - ) - (DRIVER_LABELS ++ DEFAULT_LABELS).foreach { case (k, v) => + + // Check custom and preset labels are as expected + CUSTOM_DRIVER_LABELS.foreach { case (k, v) => assert(driverPodMetadata.getLabels.get(k) === v) } - assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) + assert(driverPodMetadata.getLabels === kubernetesConf.labels.asJava) + + val annotations = driverPodMetadata.getAnnotations.asScala + DRIVER_ANNOTATIONS.foreach { case (k, v) => + assert(annotations(k) === Utils.substituteAppNExecIds(v, KubernetesTestConf.APP_ID, "")) + } assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", "spark.app.id" -> KubernetesTestConf.APP_ID, "spark.kubernetes.submitInDriver" -> "true", - MEMORY_OVERHEAD_FACTOR.key -> MEMORY_OVERHEAD_FACTOR.defaultValue.get.toString) + DRIVER_MEMORY_OVERHEAD_FACTOR.key -> DRIVER_MEMORY_OVERHEAD_FACTOR.defaultValue.get.toString) assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) } @@ -187,7 +193,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { // Memory overhead tests. 
Tuples are: // test name, main resource, overhead factor, expected factor Seq( - ("java", JavaMainAppResource(None), None, MEMORY_OVERHEAD_FACTOR.defaultValue.get), + ("java", JavaMainAppResource(None), None, DRIVER_MEMORY_OVERHEAD_FACTOR.defaultValue.get), ("python default", PythonMainAppResource(null), None, NON_JVM_MEMORY_OVERHEAD_FACTOR), ("python w/ override", PythonMainAppResource(null), Some(0.9d), 0.9d), ("r default", RMainAppResource(null), None, NON_JVM_MEMORY_OVERHEAD_FACTOR) @@ -195,13 +201,13 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { test(s"memory overhead factor: $name") { // Choose a driver memory where the default memory overhead is > MEMORY_OVERHEAD_MIN_MIB val driverMem = - ResourceProfile.MEMORY_OVERHEAD_MIN_MIB / MEMORY_OVERHEAD_FACTOR.defaultValue.get * 2 + ResourceProfile.MEMORY_OVERHEAD_MIN_MIB / DRIVER_MEMORY_OVERHEAD_FACTOR.defaultValue.get * 2 // main app resource, overhead factor val sparkConf = new SparkConf(false) .set(CONTAINER_IMAGE, "spark-driver:latest") .set(DRIVER_MEMORY.key, s"${driverMem.toInt}m") - factor.foreach { value => sparkConf.set(MEMORY_OVERHEAD_FACTOR, value) } + factor.foreach { value => sparkConf.set(DRIVER_MEMORY_OVERHEAD_FACTOR, value) } val conf = KubernetesTestConf.createDriverConf( sparkConf = sparkConf, mainAppResource = resource) @@ -212,10 +218,63 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(mem === s"${expected}Mi") val systemProperties = step.getAdditionalPodSystemProperties() - assert(systemProperties(MEMORY_OVERHEAD_FACTOR.key) === expectedFactor.toString) + assert(systemProperties(DRIVER_MEMORY_OVERHEAD_FACTOR.key) === expectedFactor.toString) } } + test(s"SPARK-38194: memory overhead factor precendence") { + // Choose a driver memory where the default memory overhead is > MEMORY_OVERHEAD_MIN_MIB + val driverMem = + ResourceProfile.MEMORY_OVERHEAD_MIN_MIB / DRIVER_MEMORY_OVERHEAD_FACTOR.defaultValue.get * 2 + + // main app resource, overhead factor + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(DRIVER_MEMORY.key, s"${driverMem.toInt}m") + + // New config should take precedence + val expectedFactor = 0.2 + sparkConf.set(DRIVER_MEMORY_OVERHEAD_FACTOR, expectedFactor) + sparkConf.set(MEMORY_OVERHEAD_FACTOR, 0.3) + + val conf = KubernetesTestConf.createDriverConf( + sparkConf = sparkConf) + val step = new BasicDriverFeatureStep(conf) + val pod = step.configurePod(SparkPod.initialPod()) + val mem = amountAndFormat(pod.container.getResources.getRequests.get("memory")) + val expected = (driverMem + driverMem * expectedFactor).toInt + assert(mem === s"${expected}Mi") + + val systemProperties = step.getAdditionalPodSystemProperties() + assert(systemProperties(DRIVER_MEMORY_OVERHEAD_FACTOR.key) === expectedFactor.toString) + } + + test(s"SPARK-38194: old memory factor settings is applied if new one isn't given") { + // Choose a driver memory where the default memory overhead is > MEMORY_OVERHEAD_MIN_MIB + val driverMem = + ResourceProfile.MEMORY_OVERHEAD_MIN_MIB / DRIVER_MEMORY_OVERHEAD_FACTOR.defaultValue.get * 2 + + // main app resource, overhead factor + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(DRIVER_MEMORY.key, s"${driverMem.toInt}m") + + // Old config still works if new config isn't given + val expectedFactor = 0.3 + sparkConf.set(MEMORY_OVERHEAD_FACTOR, expectedFactor) + + val conf = KubernetesTestConf.createDriverConf( + sparkConf = sparkConf) + val step = new BasicDriverFeatureStep(conf) + val 
pod = step.configurePod(SparkPod.initialPod()) + val mem = amountAndFormat(pod.container.getResources.getRequests.get("memory")) + val expected = (driverMem + driverMem * expectedFactor).toInt + assert(mem === s"${expected}Mi") + + val systemProperties = step.getAdditionalPodSystemProperties() + assert(systemProperties(DRIVER_MEMORY_OVERHEAD_FACTOR.key) === expectedFactor.toString) + } + test("SPARK-35493: make spark.blockManager.port be able to be fallen back to in driver pod") { val initPod = SparkPod.initialPod() val sparkConf = new SparkConf() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index b0e7a34a4732f..731a9b77d2059 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -27,7 +27,7 @@ import io.fabric8.kubernetes.api.model._ import org.scalatest.BeforeAndAfter import org.apache.spark.{SecurityManager, SparkConf, SparkException, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorConf, KubernetesTestConf, SecretVolumeUtils, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, SecretVolumeUtils, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features.KubernetesFeaturesTestUtils.TestResourceInformation @@ -54,7 +54,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_UID = "driver-uid" private val RESOURCE_NAME_PREFIX = "base" private val EXECUTOR_IMAGE = "executor-image" - private val LABELS = Map("label1key" -> "label1value") + private val CUSTOM_EXECUTOR_LABELS = Map("label1key" -> "label1value") private var defaultProfile: ResourceProfile = _ private val TEST_IMAGE_PULL_SECRETS = Seq("my-1secret-1", "my-secret-2") private val TEST_IMAGE_PULL_SECRET_OBJECTS = @@ -93,7 +93,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { KubernetesTestConf.createExecutorConf( sparkConf = baseConf, driverPod = Some(DRIVER_POD), - labels = LABELS, + labels = CUSTOM_EXECUTOR_LABELS, environment = environment) } @@ -156,12 +156,13 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { // The executor pod name and default labels. assert(executor.pod.getMetadata.getName === s"$RESOURCE_NAME_PREFIX-exec-1") - val DEFAULT_LABELS = Map( - SPARK_APP_NAME_LABEL-> KubernetesConf.getAppNameLabel(conf.appName) - ) - (LABELS ++ DEFAULT_LABELS).foreach { case (k, v) => + + // Check custom and preset labels are as expected + CUSTOM_EXECUTOR_LABELS.foreach { case (k, v) => assert(executor.pod.getMetadata.getLabels.get(k) === v) } + assert(executor.pod.getMetadata.getLabels === conf.labels.asJava) + assert(executor.pod.getSpec.getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) // There is exactly 1 container with 1 volume mount and default memory limits. 
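
The SPARK-38194 tests in the driver suite above and the executor suite below exercise one precedence rule for the memory overhead factor: an explicitly set role-specific factor is preferred, the legacy `spark.kubernetes.memoryOverheadFactor` is used only as a fallback, and the default applies otherwise. A minimal sketch of that resolution, assuming the per-role key names and the 0.1 default:

    // Sketch of the factor resolution the SPARK-38194 tests assert, assuming the new per-role
    // keys (e.g. spark.driver.memoryOverheadFactor / spark.executor.memoryOverheadFactor) and
    // the legacy spark.kubernetes.memoryOverheadFactor, with a 0.1 default.
    def resolveOverheadFactor(
        roleSpecificFactor: Option[Double],
        legacyFactor: Option[Double],
        defaultFactor: Double = 0.1): Double =
      roleSpecificFactor.orElse(legacyFactor).getOrElse(defaultFactor)

    // Mirrors the expectations in the tests: requested pod memory is memory + memory * factor.
    assert(resolveOverheadFactor(Some(0.2), Some(0.3)) == 0.2)  // new key wins
    assert(resolveOverheadFactor(None, Some(0.3)) == 0.3)       // legacy key still honored
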
@@ -440,6 +441,60 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { )) } + test(s"SPARK-38194: memory overhead factor precendence") { + // Choose an executor memory where the default memory overhead is > MEMORY_OVERHEAD_MIN_MIB + val defaultFactor = EXECUTOR_MEMORY_OVERHEAD_FACTOR.defaultValue.get + val executorMem = ResourceProfile.MEMORY_OVERHEAD_MIN_MIB / defaultFactor * 2 + + // main app resource, overhead factor + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(EXECUTOR_MEMORY.key, s"${executorMem.toInt}m") + + // New config should take precedence + val expectedFactor = 0.2 + sparkConf.set(EXECUTOR_MEMORY_OVERHEAD_FACTOR, expectedFactor) + sparkConf.set(MEMORY_OVERHEAD_FACTOR, 0.3) + + val conf = KubernetesTestConf.createExecutorConf( + sparkConf = sparkConf) + ResourceProfile.clearDefaultProfile() + val resourceProfile = ResourceProfile.getOrCreateDefaultProfile(sparkConf) + val step = new BasicExecutorFeatureStep(conf, new SecurityManager(baseConf), + resourceProfile) + val pod = step.configurePod(SparkPod.initialPod()) + val mem = amountAndFormat(pod.container.getResources.getRequests.get("memory")) + val expected = (executorMem + executorMem * expectedFactor).toInt + assert(mem === s"${expected}Mi") + } + + test(s"SPARK-38194: old memory factor settings is applied if new one isn't given") { + // Choose an executor memory where the default memory overhead is > MEMORY_OVERHEAD_MIN_MIB + val defaultFactor = EXECUTOR_MEMORY_OVERHEAD_FACTOR.defaultValue.get + val executorMem = ResourceProfile.MEMORY_OVERHEAD_MIN_MIB / defaultFactor * 2 + + // main app resource, overhead factor + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(EXECUTOR_MEMORY.key, s"${executorMem.toInt}m") + + // New config should take precedence + val expectedFactor = 0.3 + sparkConf.set(MEMORY_OVERHEAD_FACTOR, expectedFactor) + + val conf = KubernetesTestConf.createExecutorConf( + sparkConf = sparkConf) + ResourceProfile.clearDefaultProfile() + val resourceProfile = ResourceProfile.getOrCreateDefaultProfile(sparkConf) + val step = new BasicExecutorFeatureStep(conf, new SecurityManager(baseConf), + resourceProfile) + val pod = step.configurePod(SparkPod.initialPod()) + val mem = amountAndFormat(pod.container.getResources.getRequests.get("memory")) + val expected = (executorMem + executorMem * expectedFactor).toInt + assert(mem === s"${expected}Mi") + } + + // There is always exactly one controller reference, and it points to the driver pod. 
private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala index 38f8fac1858f1..468d1dde9fb6d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -89,6 +89,31 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(executorPVC.getClaimName === s"pvc-spark-${KubernetesTestConf.EXECUTOR_ID}") } + test("SPARK-32713 Mounts parameterized persistentVolumeClaims in executors with storage class") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + true, + KubernetesPVCVolumeConf("pvc-spark-SPARK_EXECUTOR_ID", Some("fast"), Some("512mb")) + ) + val driverConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) + val driverStep = new MountVolumesFeatureStep(driverConf) + val driverPod = driverStep.configurePod(SparkPod.initialPod()) + + assert(driverPod.pod.getSpec.getVolumes.size() === 1) + val driverPVC = driverPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(driverPVC.getClaimName === "pvc-spark-SPARK_EXECUTOR_ID") + + val executorConf = KubernetesTestConf.createExecutorConf(volumes = Seq(volumeConf)) + val executorStep = new MountVolumesFeatureStep(executorConf) + val executorPod = executorStep.configurePod(SparkPod.initialPod()) + + assert(executorPod.pod.getSpec.getVolumes.size() === 1) + val executorPVC = executorPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(executorPVC.getClaimName === s"pvc-spark-${KubernetesTestConf.EXECUTOR_ID}") + } + test("Create and mounts persistentVolumeClaims in driver") { val volumeConf = KubernetesVolumeSpec( "testVolume", diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala new file mode 100644 index 0000000000000..d0d1f5ee5e11b --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features + +import java.io.File + +import io.fabric8.volcano.scheduling.v1beta1.PodGroup + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class VolcanoFeatureStepSuite extends SparkFunSuite { + + test("SPARK-36061: Driver Pod with Volcano PodGroup") { + val sparkConf = new SparkConf() + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf) + val step = new VolcanoFeatureStep() + step.init(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + val annotations = configuredPod.pod.getMetadata.getAnnotations + + assert(annotations.get("scheduling.k8s.io/group-name") === s"${kubernetesConf.appId}-podgroup") + val podGroup = step.getAdditionalPreKubernetesResources().head.asInstanceOf[PodGroup] + assert(podGroup.getMetadata.getName === s"${kubernetesConf.appId}-podgroup") + } + + test("SPARK-36061: Executor Pod with Volcano PodGroup") { + val sparkConf = new SparkConf() + val kubernetesConf = KubernetesTestConf.createExecutorConf(sparkConf) + val step = new VolcanoFeatureStep() + step.init(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + val annotations = configuredPod.pod.getMetadata.getAnnotations + assert(annotations.get("scheduling.k8s.io/group-name") === s"${kubernetesConf.appId}-podgroup") + } + + test("SPARK-38455: Support driver podgroup template") { + val templatePath = new File( + getClass.getResource("/driver-podgroup-template.yml").getFile).getAbsolutePath + val sparkConf = new SparkConf() + .set(VolcanoFeatureStep.POD_GROUP_TEMPLATE_FILE_KEY, templatePath) + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf) + val step = new VolcanoFeatureStep() + step.init(kubernetesConf) + step.configurePod(SparkPod.initialPod()) + val podGroup = step.getAdditionalPreKubernetesResources().head.asInstanceOf[PodGroup] + assert(podGroup.getSpec.getMinMember == 1) + assert(podGroup.getSpec.getMinResources.get("cpu").getAmount == "2") + assert(podGroup.getSpec.getMinResources.get("memory").getAmount == "2048") + assert(podGroup.getSpec.getMinResources.get("memory").getFormat == "Mi") + assert(podGroup.getSpec.getPriorityClassName == "driver-priority") + assert(podGroup.getSpec.getQueue == "driver-queue") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index bd4a78b3bdf97..12a5202b9d067 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -321,7 +321,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { val configMapName = KubernetesClientUtils.configMapNameDriver val configMap: ConfigMap = configMaps.head assert(configMap.getMetadata.getName == configMapName) - val configMapLoadedFiles = configMap.getData.keySet().asScala.toSet + val configMapLoadedFiles = configMap.getData.keySet().asScala.toSet - + Config.KUBERNETES_NAMESPACE.key assert(configMapLoadedFiles === expectedConfFiles.toSet ++ Set(SPARK_CONF_FILE_NAME)) for (f <- configMapLoadedFiles) { assert(configMap.getData.get(f).contains("conf1key=conf1value")) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 8bf43d909dee3..861b8e0fff943 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -22,19 +22,34 @@ import io.fabric8.kubernetes.client.KubernetesClient import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s._ -import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep +import org.apache.spark.deploy.k8s.features.{KubernetesDriverCustomFeatureConfigStep, KubernetesFeatureConfigStep} import org.apache.spark.internal.config.ConfigEntry class KubernetesDriverBuilderSuite extends PodBuilderSuite { + val POD_ROLE: String = "driver" + val TEST_ANNOTATION_KEY: String = "driver-annotation-key" + val TEST_ANNOTATION_VALUE: String = "driver-annotation-value" override protected def templateFileConf: ConfigEntry[_] = { Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE } + override protected def roleSpecificSchedulerNameConf: ConfigEntry[_] = { + Config.KUBERNETES_DRIVER_SCHEDULER_NAME + } + override protected def userFeatureStepsConf: ConfigEntry[_] = { Config.KUBERNETES_DRIVER_POD_FEATURE_STEPS } + override protected def userFeatureStepWithExpectedAnnotation: (String, String) = { + ("org.apache.spark.deploy.k8s.submit.TestStepWithDrvConf", TEST_ANNOTATION_VALUE) + } + + override protected def wrongTypeFeatureStep: String = { + "org.apache.spark.scheduler.cluster.k8s.TestStepWithExecConf" + } + override protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod = { val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) new KubernetesDriverBuilder().buildFromFeatures(conf, client).pod @@ -82,3 +97,27 @@ class TestStep extends KubernetesFeatureConfigStep { .build() ) } + + +/** + * A test driver user feature step would be used in only driver. 
+ */ +class TestStepWithDrvConf extends KubernetesDriverCustomFeatureConfigStep { + import io.fabric8.kubernetes.api.model._ + + private var driverConf: KubernetesDriverConf = _ + + override def init(config: KubernetesDriverConf): Unit = { + driverConf = config + } + + override def configurePod(pod: SparkPod): SparkPod = { + val k8sPodBuilder = new PodBuilder(pod.pod) + .editOrNewMetadata() + // The annotation key = TEST_ANNOTATION_KEY, value = TEST_ANNOTATION_VALUE + .addToAnnotations("driver-annotation-key", driverConf.get("driver-annotation-key")) + .endMetadata() + val k8sPod = k8sPodBuilder.build() + SparkPod(k8sPod, pod.container) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPluginSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPluginSuite.scala index 9a6836bee93f7..886abc033893d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPluginSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPluginSuite.scala @@ -132,7 +132,7 @@ class ExecutorRollPluginSuite extends SparkFunSuite with PrivateMethodTester { } test("A one-item executor list") { - ExecutorRollPolicy.values.foreach { value => + ExecutorRollPolicy.values.filter(_ != ExecutorRollPolicy.OUTLIER_NO_FALLBACK).foreach { value => assertEquals( Some(execWithSmallestID.id), plugin.invokePrivate(_choose(Seq(execWithSmallestID), value))) @@ -216,4 +216,47 @@ class ExecutorRollPluginSuite extends SparkFunSuite with PrivateMethodTester { plugin.invokePrivate(_choose(list :+ outlier, ExecutorRollPolicy.TOTAL_GC_TIME)), plugin.invokePrivate(_choose(list :+ outlier, ExecutorRollPolicy.OUTLIER))) } + + test("Policy: OUTLIER_NO_FALLBACK - Return None if there are no outliers") { + assertEquals(None, plugin.invokePrivate(_choose(list, ExecutorRollPolicy.OUTLIER_NO_FALLBACK))) + } + + test("Policy: OUTLIER_NO_FALLBACK - Detect an average task duration outlier") { + val outlier = new ExecutorSummary("9999", "host:port", true, 1, + 0, 0, 1, 0, 0, + 3, 0, 1, 300, + 20, 0, 0, + 0, false, 0, new Date(1639300001000L), + Option.empty, Option.empty, Map(), Option.empty, Set(), Option.empty, Map(), Map(), 1, + false, Set()) + assertEquals( + plugin.invokePrivate(_choose(list :+ outlier, ExecutorRollPolicy.AVERAGE_DURATION)), + plugin.invokePrivate(_choose(list :+ outlier, ExecutorRollPolicy.OUTLIER_NO_FALLBACK))) + } + + test("Policy: OUTLIER_NO_FALLBACK - Detect a total task duration outlier") { + val outlier = new ExecutorSummary("9999", "host:port", true, 1, + 0, 0, 1, 0, 0, + 3, 0, 1000, 1000, + 0, 0, 0, + 0, false, 0, new Date(1639300001000L), + Option.empty, Option.empty, Map(), Option.empty, Set(), Option.empty, Map(), Map(), 1, + false, Set()) + assertEquals( + plugin.invokePrivate(_choose(list :+ outlier, ExecutorRollPolicy.TOTAL_DURATION)), + plugin.invokePrivate(_choose(list :+ outlier, ExecutorRollPolicy.OUTLIER_NO_FALLBACK))) + } + + test("Policy: OUTLIER_NO_FALLBACK - Detect a total GC time outlier") { + val outlier = new ExecutorSummary("9999", "host:port", true, 1, + 0, 0, 1, 0, 0, + 3, 0, 1, 100, + 1000, 0, 0, + 0, false, 0, new Date(1639300001000L), + Option.empty, Option.empty, Map(), Option.empty, Set(), Option.empty, Map(), Map(), 1, + false, Set()) + assertEquals( + plugin.invokePrivate(_choose(list :+ outlier, ExecutorRollPolicy.TOTAL_GC_TIME)), + 
plugin.invokePrivate(_choose(list :+ outlier, ExecutorRollPolicy.OUTLIER_NO_FALLBACK))) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManagerSuite.scala index ae1477e51bdf6..2b6bfe851dbd3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManagerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManagerSuite.scala @@ -47,8 +47,8 @@ class KubernetesClusterManagerSuite extends SparkFunSuite with BeforeAndAfter { test("constructing a AbstractPodsAllocator works") { val validConfigs = List("statefulset", "direct", - "org.apache.spark.scheduler.cluster.k8s.StatefulsetPodsAllocator", - "org.apache.spark.scheduler.cluster.k8s.ExecutorPodsAllocator") + classOf[StatefulSetPodsAllocator].getName, + classOf[ExecutorPodsAllocator].getName) validConfigs.foreach { c => val manager = new KubernetesClusterManager() when(sc.conf.get(KUBERNETES_ALLOCATION_PODS_ALLOCATOR)).thenReturn(c) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 53aaba206fe48..9c31f9f912f01 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -45,8 +45,8 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private val sparkConf = new SparkConf(false) .set("spark.executor.instances", "3") .set("spark.app.id", TEST_SPARK_APP_ID) - .set("spark.kubernetes.executor.decommmissionLabel", "soLong") - .set("spark.kubernetes.executor.decommmissionLabelValue", "cruelWorld") + .set(KUBERNETES_EXECUTOR_DECOMMISSION_LABEL.key, "soLong") + .set(KUBERNETES_EXECUTOR_DECOMMISSION_LABEL_VALUE.key, "cruelWorld") @Mock private var sc: SparkContext = _ diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index ec60c6fc0bf82..97f7f4876ec12 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -20,19 +20,35 @@ import io.fabric8.kubernetes.client.KubernetesClient import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features.KubernetesExecutorCustomFeatureConfigStep import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.resource.ResourceProfile class KubernetesExecutorBuilderSuite extends PodBuilderSuite { + val POD_ROLE: String = "executor" + val TEST_ANNOTATION_KEY: String = "executor-annotation-key" + val TEST_ANNOTATION_VALUE: String = "executor-annotation-value" override protected def templateFileConf: 
ConfigEntry[_] = { Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE } + override protected def roleSpecificSchedulerNameConf: ConfigEntry[_] = { + Config.KUBERNETES_EXECUTOR_SCHEDULER_NAME + } + override protected def userFeatureStepsConf: ConfigEntry[_] = { Config.KUBERNETES_EXECUTOR_POD_FEATURE_STEPS } + override protected def userFeatureStepWithExpectedAnnotation: (String, String) = { + ("org.apache.spark.scheduler.cluster.k8s.TestStepWithExecConf", TEST_ANNOTATION_VALUE) + } + + override protected def wrongTypeFeatureStep: String = { + "org.apache.spark.deploy.k8s.submit.TestStepWithDrvConf" + } + override protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod = { sparkConf.set("spark.driver.host", "https://driver.host.com") val conf = KubernetesTestConf.createExecutorConf(sparkConf = sparkConf) @@ -40,5 +56,27 @@ class KubernetesExecutorBuilderSuite extends PodBuilderSuite { val defaultProfile = ResourceProfile.getOrCreateDefaultProfile(sparkConf) new KubernetesExecutorBuilder().buildFromFeatures(conf, secMgr, client, defaultProfile).pod } +} +/** + * A test executor user feature step would be used in only executor. + */ +class TestStepWithExecConf extends KubernetesExecutorCustomFeatureConfigStep { + import io.fabric8.kubernetes.api.model._ + + private var executorConf: KubernetesExecutorConf = _ + + def init(config: KubernetesExecutorConf): Unit = { + executorConf = config + } + + override def configurePod(pod: SparkPod): SparkPod = { + val k8sPodBuilder = new PodBuilder(pod.pod) + .editOrNewMetadata() + // The annotation key = TEST_ANNOTATION_KEY, value = TEST_ANNOTATION_VALUE + .addToAnnotations("executor-annotation-key", executorConf.get("executor-annotation-key")) + .endMetadata() + val k8sPod = k8sPodBuilder.build() + SparkPod(k8sPod, pod.container) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/StatefulsetAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/StatefulSetAllocatorSuite.scala similarity index 98% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/StatefulsetAllocatorSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/StatefulSetAllocatorSuite.scala index 5f8ceb2d3ffc5..748f509e01303 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/StatefulsetAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/StatefulSetAllocatorSuite.scala @@ -80,7 +80,7 @@ class StatefulSetAllocatorSuite extends SparkFunSuite with BeforeAndAfter { @Mock private var driverPodOperations: PodResource[Pod] = _ - private var podsAllocatorUnderTest: StatefulsetPodsAllocator = _ + private var podsAllocatorUnderTest: StatefulSetPodsAllocator = _ private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ @@ -111,7 +111,7 @@ class StatefulSetAllocatorSuite extends SparkFunSuite with BeforeAndAfter { when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf]), meq(secMgr), meq(kubernetesClient), any(classOf[ResourceProfile]))).thenAnswer(executorPodAnswer()) snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() - podsAllocatorUnderTest = new StatefulsetPodsAllocator( + podsAllocatorUnderTest = new StatefulSetPodsAllocator( conf, secMgr, executorBuilder, kubernetesClient, snapshotsStore, null) 
when(schedulerBackend.getExecutorIds).thenReturn(Seq.empty) podsAllocatorUnderTest.start(TEST_SPARK_APP_ID, schedulerBackend) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index b3e0d69909ab0..5691011795dcf 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -30,9 +30,9 @@ set -e # If there is no passwd entry for the container UID, attempt to create one if [ -z "$uidentry" ] ; then if [ -w /etc/passwd ] ; then - echo "$myuid:x:$myuid:$mygid:${SPARK_USER_NAME:-anonymous uid}:$SPARK_HOME:/bin/false" >> /etc/passwd + echo "$myuid:x:$myuid:$mygid:${SPARK_USER_NAME:-anonymous uid}:$SPARK_HOME:/bin/false" >> /etc/passwd else - echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" + echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" fi fi diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index 2c759d9095ce4..748664cf41b74 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -28,7 +28,7 @@ To run tests with Hadoop 2.x instead of Hadoop 3.x, use `--hadoop-profile`. ./dev/dev-run-integration-tests.sh --hadoop-profile hadoop-2 -The minimum tested version of Minikube is 1.7.3. The kube-dns addon must be enabled. Minikube should +The minimum tested version of Minikube is 1.18.0. The kube-dns addon must be enabled. Minikube should run with a minimum of 4 CPUs and 6G of memory: minikube start --cpus 4 --memory 6144 @@ -47,15 +47,15 @@ default this is set to `minikube`, the available backends are their prerequisite ### `minikube` -Uses the local `minikube` cluster, this requires that `minikube` 1.7.3 or greater be installed and that it be allocated +Uses the local `minikube` cluster, this requires that `minikube` 1.18.0 or greater be installed and that it be allocated at least 4 CPUs and 6GB memory (some users have reported success with as few as 3 CPUs and 4GB memory). The tests will check if `minikube` is started and abort early if it isn't currently running. -### `docker-for-desktop` +### `docker-desktop` Since July 2018 Docker for Desktop provide an optional Kubernetes cluster that can be enabled as described in this [blog post](https://blog.docker.com/2018/07/kubernetes-is-now-available-in-docker-desktop-stable-channel/). Assuming -this is enabled using this backend will auto-configure itself from the `docker-for-desktop` context that Docker creates +this is enabled using this backend will auto-configure itself from the `docker-desktop` context that Docker creates in your `~/.kube/config` file. If your config file is in a different location you should set the `KUBECONFIG` environment variable appropriately. @@ -139,7 +139,7 @@ properties to Maven. For example: -Dspark.kubernetes.test.imageTag=sometag \ -Dspark.kubernetes.test.imageRepo=docker.io/somerepo \ -Dspark.kubernetes.test.namespace=spark-int-tests \ - -Dspark.kubernetes.test.deployMode=docker-for-desktop \ + -Dspark.kubernetes.test.deployMode=docker-desktop \ -Dtest.include.tags=k8s @@ -172,7 +172,7 @@ to the wrapper scripts and using the wrapper scripts will simply set these appro spark.kubernetes.test.deployMode The integration test backend to use. 
Acceptable values are minikube,
-      docker-for-desktop and cloud.
+      docker-desktop and cloud.
    minikube
@@ -187,7 +187,7 @@ to the wrapper scripts and using the wrapper scripts will simply set these appro
    spark.kubernetes.test.master
-    When using the cloud-url backend must be specified to indicate the K8S master URL to communicate
+    When using the cloud backend must be specified to indicate the K8S master URL to communicate
    with.
@@ -269,3 +269,85 @@ to the wrapper scripts and using the wrapper scripts will simply set these appro
+
+# Running the Kubernetes Integration Tests with SBT
+
+You can use SBT in the same way to build the image and run all K8s integration tests except the Minikube-only ones.
+
+    build/sbt -Psparkr -Pkubernetes -Pkubernetes-integration-tests \
+      -Dtest.exclude.tags=minikube \
+      -Dspark.kubernetes.test.deployMode=docker-desktop \
+      -Dspark.kubernetes.test.imageTag=2022-03-06 \
+      'kubernetes-integration-tests/test'
+
+The following is an example of rerunning the tests with the pre-built image.
+
+    build/sbt -Psparkr -Pkubernetes -Pkubernetes-integration-tests \
+      -Dtest.exclude.tags=minikube \
+      -Dspark.kubernetes.test.deployMode=docker-desktop \
+      -Dspark.kubernetes.test.imageTag=2022-03-06 \
+      'kubernetes-integration-tests/runIts'
+
+In addition, you can run a single test selectively.
+
+    build/sbt -Psparkr -Pkubernetes -Pkubernetes-integration-tests \
+      -Dspark.kubernetes.test.deployMode=docker-desktop \
+      -Dspark.kubernetes.test.imageTag=2022-03-06 \
+      'kubernetes-integration-tests/testOnly -- -z "Run SparkPi with a very long application name"'
+
+You can also specify your own Dockerfiles to build the JVM/Python/R based images to test.
+
+    build/sbt -Psparkr -Pkubernetes -Pkubernetes-integration-tests \
+      -Dtest.exclude.tags=minikube \
+      -Dspark.kubernetes.test.deployMode=docker-desktop \
+      -Dspark.kubernetes.test.imageTag=2022-03-06 \
+      -Dspark.kubernetes.test.dockerFile=/path/to/Dockerfile \
+      -Dspark.kubernetes.test.pyDockerFile=/path/to/py/Dockerfile \
+      -Dspark.kubernetes.test.rDockerFile=/path/to/r/Dockerfile \
+      'kubernetes-integration-tests/test'
+
+# Running the Volcano Integration Tests
+
+Volcano integration is experimental in Apache Spark 3.3.0 and the test coverage is limited.
+
+## Requirements
+- A minimum of 6 CPUs and 9G of memory is required to complete all Volcano test cases.
+- Volcano v1.5.1.
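These figures are larger than the 4 CPU / 6G minimum quoted earlier for the plain Kubernetes integration tests, so a local cluster has to be sized up before running the Volcano suite. A minimal sketch, assuming the `minikube` backend and expressing 9G as 9216 MiB:

    # Assumption: minikube backend; CPU/memory values follow the Volcano requirements listed above
    minikube start --cpus 6 --memory 9216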
+
+## Installation
+
+    # x86_64
+    kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.5.1/installer/volcano-development.yaml
+
+    # arm64:
+    kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.5.1/installer/volcano-development-arm64.yaml
+
+## Run tests
+
+You can specify `-Pvolcano` to enable the Volcano module and run all Kubernetes and Volcano tests:
+
+    build/sbt -Pvolcano -Pkubernetes -Pkubernetes-integration-tests \
+      -Dtest.exclude.tags=minikube \
+      -Dspark.kubernetes.test.deployMode=docker-desktop \
+      'kubernetes-integration-tests/test'
+
+You can also specify the `volcano` tag to run only the Volcano tests:
+
+    build/sbt -Pvolcano -Pkubernetes -Pkubernetes-integration-tests \
+      -Dtest.include.tags=volcano \
+      -Dtest.exclude.tags=minikube \
+      -Dspark.kubernetes.test.deployMode=docker-desktop \
+      'kubernetes-integration-tests/test'
+
+## Cleanup Volcano
+
+    # x86_64
+    kubectl delete -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.5.1/installer/volcano-development.yaml
+
+    # arm64:
+    kubectl delete -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.5.1/installer/volcano-development-arm64.yaml
+
+    # Cleanup Volcano webhook
+    kubectl delete validatingwebhookconfigurations volcano-admission-service-jobs-validate volcano-admission-service-pods-validate volcano-admission-service-queues-validate
+    kubectl delete mutatingwebhookconfigurations volcano-admission-service-jobs-mutate volcano-admission-service-podgroups-mutate volcano-admission-service-pods-mutate volcano-admission-service-queues-mutate
+
diff --git a/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml
index a4c242f2f2645..f6b8b10c87b15 100644
--- a/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml
+++ b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml
@@ -26,7 +26,7 @@ metadata: name: spark-sa namespace: spark --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: name: spark-role
@@ -38,7 +38,7 @@ rules: verbs: - "*" --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding metadata: name: spark-role-binding
@@ -49,4 +49,4 @@ subjects: roleRef: kind: ClusterRole name: spark-role - apiGroup: rbac.authorization.k8s.io \ No newline at end of file + apiGroup: rbac.authorization.k8s.io
diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml
index 4c5f14b79f690..318a903c14215 100644
--- a/resource-managers/kubernetes/integration-tests/pom.xml
+++ b/resource-managers/kubernetes/integration-tests/pom.xml
@@ -35,7 +35,7 @@ N/A ${project.build.directory}/spark-dist-unpacked N/A - 8-jre-slim + N/A ${project.build.directory}/imageTag.txt minikube docker.io/kubespark
@@ -43,10 +43,11 @@ - N/A + Dockerfile.java17 + **/*Volcano*.scala jar Spark Project Kubernetes Integration Tests
@@ -74,9 +75,28 @@ spark-tags_${scala.binary.version} test-jar + + org.apache.spark + spark-kubernetes_${scala.binary.version} + ${project.version} + test + + + + + net.alchim31.maven + scala-maven-plugin + + + ${volcano.exclude} + + + + + org.codehaus.mojo
@@ -209,5 +229,18 @@
+ + volcano + + + + + + io.fabric8 + volcano-client + ${kubernetes-client.version} + + + diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index 562d1d820cdd1..d8960349f0080 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -23,7 +23,7 @@ IMAGE_TAG_OUTPUT_FILE="$TEST_ROOT_DIR/target/image-tag.txt" DEPLOY_MODE="minikube" IMAGE_REPO="docker.io/kubespark" IMAGE_TAG="N/A" -JAVA_IMAGE_TAG="8-jre-slim" +JAVA_IMAGE_TAG="N/A" SPARK_TGZ="N/A" MVN="$TEST_ROOT_DIR/build/mvn" DOCKER_FILE="N/A" @@ -106,7 +106,11 @@ then # OpenJDK base-image tag (e.g. 8-jre-slim, 11-jre-slim) JAVA_IMAGE_TAG_BUILD_ARG="-b java_image_tag=$JAVA_IMAGE_TAG" else - JAVA_IMAGE_TAG_BUILD_ARG="-f $DOCKER_FILE" + if [[ $DOCKER_FILE = /* ]]; then + JAVA_IMAGE_TAG_BUILD_ARG="-f $DOCKER_FILE" + else + JAVA_IMAGE_TAG_BUILD_ARG="-f $DOCKER_FILE_BASE_PATH/$DOCKER_FILE" + fi fi # Build PySpark image @@ -136,7 +140,7 @@ then fi ;; - docker-for-desktop) + docker-desktop | docker-for-desktop) # Only need to build as this will place it in our local Docker repo which is all # we need for Docker for Desktop to work so no need to also push $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $JAVA_IMAGE_TAG_BUILD_ARG $LANGUAGE_BINDING_BUILD_ARGS build diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/driver-schedule-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/driver-schedule-template.yml new file mode 100644 index 0000000000000..22eaa6c13a85d --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/driver-schedule-template.yml @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +apiVersion: v1 +Kind: Pod +metadata: + labels: + template-label-key: driver-template-label-value +spec: + priorityClassName: system-node-critical + containers: + - name: test-driver-container + image: will-be-overwritten + diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/pagerank_data.txt b/resource-managers/kubernetes/integration-tests/src/test/resources/pagerank_data.txt new file mode 100644 index 0000000000000..95755ab8f5af8 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/pagerank_data.txt @@ -0,0 +1,6 @@ +1 2 +1 3 +1 4 +2 1 +3 1 +4 1 diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/disable-queue.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/disable-queue.yml new file mode 100644 index 0000000000000..d9f8c36471ec8 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/disable-queue.yml @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: Queue +metadata: + name: queue +spec: + weight: 1 + capability: + cpu: "0.001" diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/disable-queue0-enable-queue1.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/disable-queue0-enable-queue1.yml new file mode 100644 index 0000000000000..82e479478ccd9 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/disable-queue0-enable-queue1.yml @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: Queue +metadata: + name: queue0 +spec: + weight: 1 + capability: + cpu: "0.001" +--- +apiVersion: scheduling.volcano.sh/v1beta1 +kind: Queue +metadata: + name: queue1 +spec: + weight: 1 diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/driver-podgroup-template-cpu-2u.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/driver-podgroup-template-cpu-2u.yml new file mode 100644 index 0000000000000..e6d53ddc8b5cd --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/driver-podgroup-template-cpu-2u.yml @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: PodGroup +spec: + queue: queue-2u-3g + minMember: 1 + minResources: + cpu: "2" diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/driver-podgroup-template-memory-3g.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/driver-podgroup-template-memory-3g.yml new file mode 100644 index 0000000000000..9aaa5cf20658b --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/driver-podgroup-template-memory-3g.yml @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: PodGroup +spec: + queue: queue-2u-3g + minMember: 1 + minResources: + memory: "3Gi" diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/enable-queue.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/enable-queue.yml new file mode 100644 index 0000000000000..e753b8c07f01e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/enable-queue.yml @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: Queue +metadata: + name: queue +spec: + weight: 1 + capability: + cpu: "1" diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/enable-queue0-enable-queue1.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/enable-queue0-enable-queue1.yml new file mode 100644 index 0000000000000..aadeb2851882e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/enable-queue0-enable-queue1.yml @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: Queue +metadata: + name: queue0 +spec: + weight: 1 +--- +apiVersion: scheduling.volcano.sh/v1beta1 +kind: Queue +metadata: + name: queue1 +spec: + weight: 1 diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/high-priority-driver-podgroup-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/high-priority-driver-podgroup-template.yml new file mode 100644 index 0000000000000..a64431d69daa5 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/high-priority-driver-podgroup-template.yml @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: PodGroup +spec: + priorityClassName: high + queue: queue diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/high-priority-driver-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/high-priority-driver-template.yml new file mode 100644 index 0000000000000..a7968bfcb2c1a --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/high-priority-driver-template.yml @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: v1 +Kind: Pod +metadata: + labels: + template-label-key: driver-template-label-value +spec: + priorityClassName: high + containers: + - name: test-driver-container + image: will-be-overwritten diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/low-priority-driver-podgroup-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/low-priority-driver-podgroup-template.yml new file mode 100644 index 0000000000000..5e89630c01705 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/low-priority-driver-podgroup-template.yml @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: PodGroup +spec: + priorityClassName: low + queue: queue diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/low-priority-driver-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/low-priority-driver-template.yml new file mode 100644 index 0000000000000..7f04b9e120c83 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/low-priority-driver-template.yml @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: v1 +Kind: Pod +metadata: + labels: + template-label-key: driver-template-label-value +spec: + priorityClassName: low + containers: + - name: test-driver-container + image: will-be-overwritten diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/medium-priority-driver-podgroup-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/medium-priority-driver-podgroup-template.yml new file mode 100644 index 0000000000000..5773e8b6b14be --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/medium-priority-driver-podgroup-template.yml @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: PodGroup +spec: + priorityClassName: medium + queue: queue diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/medium-priority-driver-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/medium-priority-driver-template.yml new file mode 100644 index 0000000000000..78d9295399c2e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/medium-priority-driver-template.yml @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +apiVersion: v1 +Kind: Pod +metadata: + labels: + template-label-key: driver-template-label-value +spec: + priorityClassName: medium + containers: + - name: test-driver-container + image: will-be-overwritten diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/priorityClasses.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/priorityClasses.yml new file mode 100644 index 0000000000000..64e9b0d530363 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/priorityClasses.yml @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.k8s.io/v1 +kind: PriorityClass +metadata: + name: high +value: 100 +--- +apiVersion: scheduling.k8s.io/v1 +kind: PriorityClass +metadata: + name: medium +value: 50 +--- +apiVersion: scheduling.k8s.io/v1 +kind: PriorityClass +metadata: + name: low +value: 0 diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue-2u-3g.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue-2u-3g.yml new file mode 100644 index 0000000000000..094ec233fd041 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue-2u-3g.yml @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: Queue +metadata: + name: queue-2u-3g +spec: + weight: 1 + capability: + cpu: "2" + memory: "3Gi" diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue-driver-podgroup-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue-driver-podgroup-template.yml new file mode 100644 index 0000000000000..591000a0d02d3 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue-driver-podgroup-template.yml @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: PodGroup +spec: + queue: queue diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue0-driver-podgroup-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue0-driver-podgroup-template.yml new file mode 100644 index 0000000000000..faba21abe1ec2 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue0-driver-podgroup-template.yml @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: PodGroup +spec: + queue: queue0 diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue1-driver-podgroup-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue1-driver-podgroup-template.yml new file mode 100644 index 0000000000000..280656450ea06 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/volcano/queue1-driver-podgroup-template.yml @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +apiVersion: scheduling.volcano.sh/v1beta1 +kind: PodGroup +spec: + queue: queue1 diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala index 6db4beef6d221..a79442ac63581 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala @@ -22,13 +22,13 @@ import io.fabric8.kubernetes.api.model.Pod import org.scalatest.concurrent.Eventually import org.scalatest.matchers.should.Matchers._ -import org.apache.spark.TestUtils +import org.apache.spark.{SparkFunSuite, TestUtils} import org.apache.spark.launcher.SparkLauncher private[spark] trait BasicTestsSuite { k8sSuite: KubernetesSuite => import BasicTestsSuite._ - import KubernetesSuite.k8sTestTag + import KubernetesSuite.{k8sTestTag, localTestTag} import KubernetesSuite.{TIMEOUT, INTERVAL} test("Run SparkPi with no resources", k8sTestTag) { @@ -82,12 +82,14 @@ private[spark] trait BasicTestsSuite { k8sSuite: KubernetesSuite => .set("spark.kubernetes.driver.label.label2", "label2-value") .set("spark.kubernetes.driver.annotation.annotation1", "annotation1-value") .set("spark.kubernetes.driver.annotation.annotation2", "annotation2-value") + .set("spark.kubernetes.driver.annotation.yunikorn.apache.org/app-id", "{{APP_ID}}") .set("spark.kubernetes.driverEnv.ENV1", "VALUE1") .set("spark.kubernetes.driverEnv.ENV2", "VALUE2") .set("spark.kubernetes.executor.label.label1", "label1-value") .set("spark.kubernetes.executor.label.label2", "label2-value") .set("spark.kubernetes.executor.annotation.annotation1", "annotation1-value") .set("spark.kubernetes.executor.annotation.annotation2", "annotation2-value") + .set("spark.kubernetes.executor.annotation.yunikorn.apache.org/app-id", "{{APP_ID}}") .set("spark.executorEnv.ENV1", "VALUE1") .set("spark.executorEnv.ENV2", "VALUE2") @@ -116,21 +118,21 @@ private[spark] trait BasicTestsSuite { k8sSuite: KubernetesSuite => expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) } - test("Run SparkRemoteFileTest using a remote data file", k8sTestTag) { + test("Run SparkRemoteFileTest using a remote data file", k8sTestTag, localTestTag) { assert(sys.props.contains("spark.test.home"), "spark.test.home is not set!") TestUtils.withHttpServer(sys.props("spark.test.home")) { baseURL => - sparkAppConf - .set("spark.files", baseURL.toString + REMOTE_PAGE_RANK_DATA_FILE) + sparkAppConf.set("spark.files", baseURL.toString + + REMOTE_PAGE_RANK_DATA_FILE.replace(sys.props("spark.test.home"), "").substring(1)) runSparkRemoteCheckAndVerifyCompletion(appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) } } } -private[spark] object BasicTestsSuite { +private[spark] object BasicTestsSuite extends SparkFunSuite { val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files" val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" - val REMOTE_PAGE_RANK_DATA_FILE = "data/mllib/pagerank_data.txt" + val REMOTE_PAGE_RANK_DATA_FILE = getTestResourcePath("pagerank_data.txt") val REMOTE_PAGE_RANK_FILE_NAME = "pagerank_data.txt" } diff --git 
a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 9605f6c42a45d..51ea1307236c8 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.matchers.should.Matchers._ import org.scalatest.time.{Minutes, Seconds, Span} +import org.apache.spark.deploy.k8s.Config.{KUBERNETES_EXECUTOR_DECOMMISSION_LABEL, KUBERNETES_EXECUTOR_DECOMMISSION_LABEL_VALUE} import org.apache.spark.internal.config import org.apache.spark.internal.config.PLUGINS @@ -140,8 +141,8 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => // give enough time to validate the labels are set. .set("spark.storage.decommission.replicationReattemptInterval", "75") // Configure labels for decommissioning pods. - .set("spark.kubernetes.executor.decommmissionLabel", "solong") - .set("spark.kubernetes.executor.decommmissionLabelValue", "cruelworld") + .set(KUBERNETES_EXECUTOR_DECOMMISSION_LABEL.key, "solong") + .set(KUBERNETES_EXECUTOR_DECOMMISSION_LABEL_VALUE.key, "cruelworld") // This is called on all exec pods but we only care about exec 0 since it's the "first." // We only do this inside of this test since the other tests trigger k8s side deletes where we @@ -151,7 +152,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => val client = kubernetesTestComponents.kubernetesClient // The label will be added eventually, but k8s objects don't refresh. 
Eventually.eventually( - PatienceConfiguration.Timeout(Span(1200, Seconds)), + PatienceConfiguration.Timeout(Span(120, Seconds)), PatienceConfiguration.Interval(Span(1, Seconds))) { val currentPod = client.pods().withName(pod.getMetadata.getName).get diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index e608b3aa76ae9..3db51b2860023 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -26,6 +26,7 @@ import com.google.common.base.Charsets import com.google.common.io.Files import io.fabric8.kubernetes.api.model.Pod import io.fabric8.kubernetes.client.{Watcher, WatcherException} +import io.fabric8.kubernetes.client.KubernetesClientException import io.fabric8.kubernetes.client.Watcher.Action import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Tag} import org.scalatest.concurrent.{Eventually, PatienceConfiguration} @@ -35,9 +36,9 @@ import org.scalatest.matchers.should.Matchers._ import org.scalatest.time.{Minutes, Seconds, Span} import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants.ENV_APPLICATION_ID import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} -import org.apache.spark.deploy.k8s.integrationtest.backend.minikube.Minikube import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -76,7 +77,7 @@ class KubernetesSuite extends SparkFunSuite protected override def logForFailedTest(): Unit = { logInfo("\n\n===== EXTRA LOGS FOR THE FAILED TEST\n") logInfo("BEGIN DESCRIBE PODS for application\n" + - Minikube.describePods(s"spark-app-locator=$appLocator").mkString("\n")) + testBackend.describePods(s"spark-app-locator=$appLocator").mkString("\n")) logInfo("END DESCRIBE PODS for the application") val driverPodOption = kubernetesTestComponents.kubernetesClient .pods() @@ -106,11 +107,10 @@ class KubernetesSuite extends SparkFunSuite .withName(execPod.getMetadata.getName) .getLog } catch { - case e: io.fabric8.kubernetes.client.KubernetesClientException => - "Error fetching log (pod is likely not ready) ${e}" + case e: KubernetesClientException => + s"Error fetching log (pod is likely not ready) $e" } - logInfo(s"\nBEGIN executor (${execPod.getMetadata.getName}) POD log:\n" + - podLog) + logInfo(s"\nBEGIN executor (${execPod.getMetadata.getName}) POD log:\n$podLog") logInfo(s"END executor (${execPod.getMetadata.getName}) POD log") } } @@ -182,9 +182,10 @@ class KubernetesSuite extends SparkFunSuite } } - before { + protected def setUpTest(): Unit = { appLocator = UUID.randomUUID().toString.replaceAll("-", "") - driverPodName = "spark-test-app-" + UUID.randomUUID().toString.replaceAll("-", "") + driverPodName = "spark-test-app-" + + UUID.randomUUID().toString.replaceAll("-", "") + "-driver" sparkAppConf = kubernetesTestComponents.newSparkAppConf() .set("spark.kubernetes.container.image", image) .set("spark.kubernetes.driver.pod.name", driverPodName) @@ -196,6 +197,10 @@ class KubernetesSuite extends SparkFunSuite } } + before { + setUpTest() + } + after { if 
(!kubernetesTestComponents.hasUserSpecifiedNamespace) { kubernetesTestComponents.deleteNamespace() @@ -209,7 +214,9 @@ class KubernetesSuite extends SparkFunSuite driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, appArgs: Array[String] = Array.empty[String], - isJVM: Boolean = true ): Unit = { + isJVM: Boolean = true, + customSparkConf: Option[SparkAppConf] = None, + customAppLocator: Option[String] = None): Unit = { runSparkApplicationAndVerifyCompletion( appResource, SPARK_PI_MAIN_CLASS, @@ -218,7 +225,10 @@ class KubernetesSuite extends SparkFunSuite appArgs, driverPodChecker, executorPodChecker, - isJVM) + isJVM, + customSparkConf = customSparkConf, + customAppLocator = customAppLocator + ) } protected def runDFSReadWriteAndVerifyCompletion( @@ -333,7 +343,9 @@ class KubernetesSuite extends SparkFunSuite pyFiles: Option[String] = None, executorPatience: Option[(Option[Interval], Option[Timeout])] = None, decommissioningTest: Boolean = false, - env: Map[String, String] = Map.empty[String, String]): Unit = { + env: Map[String, String] = Map.empty[String, String], + customSparkConf: Option[SparkAppConf] = None, + customAppLocator: Option[String] = None): Unit = { // scalastyle:on argcount val appArguments = SparkAppArguments( @@ -367,7 +379,7 @@ class KubernetesSuite extends SparkFunSuite val execWatcher = kubernetesTestComponents.kubernetesClient .pods() - .withLabel("spark-app-locator", appLocator) + .withLabel("spark-app-locator", customAppLocator.getOrElse(appLocator)) .withLabel("spark-role", "executor") .watch(new Watcher[Pod] { logDebug("Beginning watch of executors") @@ -431,7 +443,7 @@ class KubernetesSuite extends SparkFunSuite logDebug("Starting Spark K8s job") SparkAppLauncher.launch( appArguments, - sparkAppConf, + customSparkConf.getOrElse(sparkAppConf), TIMEOUT.value.toSeconds.toInt, sparkHomeDir, isJVM, @@ -440,7 +452,7 @@ class KubernetesSuite extends SparkFunSuite val driverPod = kubernetesTestComponents.kubernetesClient .pods() - .withLabel("spark-app-locator", appLocator) + .withLabel("spark-app-locator", customAppLocator.getOrElse(appLocator)) .withLabel("spark-role", "driver") .list() .getItems @@ -553,6 +565,7 @@ class KubernetesSuite extends SparkFunSuite assert(pod.getMetadata.getLabels.get("label2") === "label2-value") assert(pod.getMetadata.getAnnotations.get("annotation1") === "annotation1-value") assert(pod.getMetadata.getAnnotations.get("annotation2") === "annotation2-value") + val appId = pod.getMetadata.getAnnotations.get("yunikorn.apache.org/app-id") val container = pod.getSpec.getContainers.get(0) val envVars = container @@ -564,6 +577,7 @@ class KubernetesSuite extends SparkFunSuite .toMap assert(envVars("ENV1") === "VALUE1") assert(envVars("ENV2") === "VALUE2") + assert(appId === envVars(ENV_APPLICATION_ID)) } private def deleteDriverPod(): Unit = { @@ -596,6 +610,8 @@ class KubernetesSuite extends SparkFunSuite private[spark] object KubernetesSuite { val k8sTestTag = Tag("k8s") + val localTestTag = Tag("local") + val schedulingTestTag = Tag("schedule") val rTestTag = Tag("r") val MinikubeTag = Tag("minikube") val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index 
411857f0229db..4fdb89eab6eb6 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -37,7 +37,8 @@ private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesCl val namespaceOption = Option(System.getProperty(CONFIG_KEY_KUBE_NAMESPACE)) val hasUserSpecifiedNamespace = namespaceOption.isDefined - val namespace = namespaceOption.getOrElse(UUID.randomUUID().toString.replaceAll("-", "")) + val namespace = namespaceOption.getOrElse("spark-" + + UUID.randomUUID().toString.replaceAll("-", "")) val serviceAccountName = Option(System.getProperty(CONFIG_KEY_KUBE_SVC_ACCOUNT)) .getOrElse("default") diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PodTemplateSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PodTemplateSuite.scala index e5a847e7210cb..2cd3bb4c4a7d9 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PodTemplateSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PodTemplateSuite.scala @@ -20,7 +20,7 @@ import java.io.File import io.fabric8.kubernetes.api.model.Pod -import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite.k8sTestTag +import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite.{k8sTestTag, schedulingTestTag} private[spark] trait PodTemplateSuite { k8sSuite: KubernetesSuite => @@ -46,10 +46,36 @@ private[spark] trait PodTemplateSuite { k8sSuite: KubernetesSuite => } ) } + + test("SPARK-38398: Schedule pod creation from template", k8sTestTag, schedulingTestTag) { + sparkAppConf + .set("spark.kubernetes.driver.podTemplateFile", + DRIVER_SCHEDULE_TEMPLATE_FILE.getAbsolutePath) + .set("spark.kubernetes.executor.podTemplateFile", EXECUTOR_TEMPLATE_FILE.getAbsolutePath) + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === image) + assert(driverPod.getSpec.getContainers.get(0).getName === "test-driver-container") + assert(driverPod.getMetadata.getLabels.containsKey(LABEL_KEY)) + assert(driverPod.getMetadata.getLabels.get(LABEL_KEY) === "driver-template-label-value") + assert(driverPod.getSpec.getPriority() === 2000001000) + }, + executorPodChecker = (executorPod: Pod) => { + assert(executorPod.getSpec.getContainers.get(0).getImage === image) + assert(executorPod.getSpec.getContainers.get(0).getName === "test-executor-container") + assert(executorPod.getMetadata.getLabels.containsKey(LABEL_KEY)) + assert(executorPod.getMetadata.getLabels.get(LABEL_KEY) === "executor-template-label-value") + assert(executorPod.getSpec.getPriority() === 0) // When there is no default, 0 is used. 
+ } + ) + } } private[spark] object PodTemplateSuite { val LABEL_KEY = "template-label-key" val DRIVER_TEMPLATE_FILE = new File(getClass.getResource("/driver-template.yml").getFile) + val DRIVER_SCHEDULE_TEMPLATE_FILE = + new File(getClass.getResource("/driver-schedule-template.yml").getFile) val EXECUTOR_TEMPLATE_FILE = new File(getClass.getResource("/executor-template.yml").getFile) } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala index 2b1fd08164616..c46839f1dffcc 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.k8s.integrationtest object TestConstants { val BACKEND_MINIKUBE = "minikube" val BACKEND_DOCKER_FOR_DESKTOP = "docker-for-desktop" + val BACKEND_DOCKER_DESKTOP = "docker-desktop" val BACKEND_CLOUD = "cloud" val CONFIG_KEY_DEPLOY_MODE = "spark.kubernetes.test.deployMode" diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala index cc258533c2c8d..e0fd92617ba6d 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala @@ -24,7 +24,7 @@ import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import io.fabric8.kubernetes.client.dsl.ExecListener -import okhttp3.Response +import io.fabric8.kubernetes.client.dsl.ExecListener.Response import org.apache.commons.compress.archivers.tar.{TarArchiveEntry, TarArchiveOutputStream} import org.apache.commons.compress.compressors.gzip.GzipCompressorOutputStream import org.apache.commons.compress.utils.IOUtils @@ -62,7 +62,7 @@ object Utils extends Logging { val openLatch: CountDownLatch = new CountDownLatch(1) val closeLatch: CountDownLatch = new CountDownLatch(1) - override def onOpen(response: Response): Unit = { + override def onOpen(): Unit = { openLatch.countDown() } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala new file mode 100644 index 0000000000000..ed7371718f9a5 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.scalatest.Tag + +class VolcanoSuite extends KubernetesSuite with VolcanoTestsSuite { + + override protected def setUpTest(): Unit = { + super.setUpTest() + sparkAppConf + .set("spark.kubernetes.driver.scheduler.name", "volcano") + .set("spark.kubernetes.executor.scheduler.name", "volcano") + } +} + +private[spark] object VolcanoSuite { + val volcanoTag = Tag("volcano") +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala new file mode 100644 index 0000000000000..8d5054465b9e5 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.{File, FileInputStream} +import java.time.Instant +import java.util.UUID + +import scala.collection.JavaConverters._ +import scala.collection.mutable +// scalastyle:off executioncontextglobal +import scala.concurrent.ExecutionContext.Implicits.global +// scalastyle:on executioncontextglobal +import scala.concurrent.Future + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.NamespacedKubernetesClient +import io.fabric8.volcano.client.VolcanoClient +import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.features.VolcanoFeatureStep +import org.apache.spark.internal.config.NETWORK_AUTH_ENABLED + +private[spark] trait VolcanoTestsSuite extends BeforeAndAfterEach { k8sSuite: KubernetesSuite => + import VolcanoTestsSuite._ + import org.apache.spark.deploy.k8s.integrationtest.VolcanoSuite.volcanoTag + import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite.{k8sTestTag, INTERVAL, TIMEOUT, + SPARK_DRIVER_MAIN_CLASS} + + lazy val volcanoClient: VolcanoClient + = kubernetesTestComponents.kubernetesClient.adapt(classOf[VolcanoClient]) + lazy val k8sClient: NamespacedKubernetesClient = kubernetesTestComponents.kubernetesClient + private val testGroups: mutable.Set[String] = mutable.Set.empty + private val testYAMLPaths: mutable.Set[String] = mutable.Set.empty + + private def deletePodInTestGroup(): Unit = { + testGroups.foreach { g => + k8sClient.pods().withLabel("spark-group-locator", g).delete() + Eventually.eventually(TIMEOUT, INTERVAL) { + assert(k8sClient.pods().withLabel("spark-group-locator", g).list().getItems.isEmpty) + } + } + testGroups.clear() + } + + private def deleteYamlResources(): Unit = { + testYAMLPaths.foreach { yaml => + deleteYAMLResource(yaml) + Eventually.eventually(TIMEOUT, INTERVAL) { + val resources = k8sClient.load(new FileInputStream(yaml)).fromServer.get.asScala + // Make sure all elements are null (no specific resources in cluster) + resources.foreach { r => assert(r === null) } + } + } + testYAMLPaths.clear() + } + + override protected def afterEach(): Unit = { + deletePodInTestGroup() + deleteYamlResources() + super.afterEach() + } + + protected def generateGroupName(name: String): String = { + val groupName = GROUP_PREFIX + name + // Append to testGroups + testGroups += groupName + groupName + } + + protected def checkScheduler(pod: Pod): Unit = { + assert(pod.getSpec.getSchedulerName === "volcano") + } + + protected def checkAnnotaion(pod: Pod): Unit = { + val appId = pod.getMetadata.getLabels.get("spark-app-selector") + val annotations = pod.getMetadata.getAnnotations + assert(annotations.get("scheduling.k8s.io/group-name") === s"$appId-podgroup") + } + + protected def checkPodGroup( + pod: Pod, + queue: Option[String] = None, + priorityClassName: Option[String] = None): Unit = { + val appId = pod.getMetadata.getLabels.get("spark-app-selector") + val podGroupName = s"$appId-podgroup" + val podGroup = volcanoClient.podGroups().withName(podGroupName).get() + assert(podGroup.getMetadata.getOwnerReferences.get(0).getName === pod.getMetadata.getName) + queue.foreach(q => assert(q === podGroup.getSpec.getQueue)) + priorityClassName.foreach(_ => + assert(pod.getSpec.getPriorityClassName === podGroup.getSpec.getPriorityClassName)) + } + + private def createOrReplaceYAMLResource(yamlPath: String): Unit = { + 
k8sClient.load(new FileInputStream(yamlPath)).createOrReplace() + testYAMLPaths += yamlPath + } + + private def deleteYAMLResource(yamlPath: String): Unit = { + k8sClient.load(new FileInputStream(yamlPath)).delete() + } + + private def getPods( + role: String, + groupLocator: String, + statusPhase: String): mutable.Buffer[Pod] = { + k8sClient + .pods() + .withLabel("spark-group-locator", groupLocator) + .withLabel("spark-role", role) + .withField("status.phase", statusPhase) + .list() + .getItems.asScala + } + + def runJobAndVerify( + batchSuffix: String, + groupLoc: Option[String] = None, + queue: Option[String] = None, + driverTemplate: Option[String] = None, + isDriverJob: Boolean = false, + driverPodGroupTemplate: Option[String] = None): Unit = { + val appLoc = s"${appLocator}${batchSuffix}" + val podName = s"${driverPodName}-${batchSuffix}" + // create new configuration for every job + val conf = createVolcanoSparkConf(podName, appLoc, groupLoc, queue, driverTemplate, + driverPodGroupTemplate) + if (isDriverJob) { + runSparkDriverSubmissionAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + checkScheduler(driverPod) + checkAnnotaion(driverPod) + checkPodGroup(driverPod, queue) + }, + customSparkConf = Option(conf), + customAppLocator = Option(appLoc) + ) + } else { + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + checkScheduler(driverPod) + checkAnnotaion(driverPod) + checkPodGroup(driverPod, queue) + }, + executorPodChecker = (executorPod: Pod) => { + checkScheduler(executorPod) + checkAnnotaion(executorPod) + }, + customSparkConf = Option(conf), + customAppLocator = Option(appLoc) + ) + } + } + + protected def runSparkDriverSubmissionAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + mainClass: String = SPARK_DRIVER_MAIN_CLASS, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + appArgs: Array[String] = Array("2"), + customSparkConf: Option[SparkAppConf] = None, + customAppLocator: Option[String] = None): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch( + appArguments, + customSparkConf.getOrElse(sparkAppConf), + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + true) + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", customAppLocator.getOrElse(appLocator)) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + driverPodChecker(driverPod) + } + + private def createVolcanoSparkConf( + driverPodName: String = driverPodName, + appLoc: String = appLocator, + groupLoc: Option[String] = None, + queue: Option[String] = None, + driverTemplate: Option[String] = None, + driverPodGroupTemplate: Option[String] = None): SparkAppConf = { + val conf = kubernetesTestComponents.newSparkAppConf() + .set(CONTAINER_IMAGE.key, image) + .set(KUBERNETES_DRIVER_POD_NAME.key, driverPodName) + .set(s"${KUBERNETES_DRIVER_LABEL_PREFIX}spark-app-locator", appLoc) + .set(s"${KUBERNETES_EXECUTOR_LABEL_PREFIX}spark-app-locator", appLoc) + .set(NETWORK_AUTH_ENABLED.key, "true") + // below is volcano specific configuration + .set(KUBERNETES_SCHEDULER_NAME.key, "volcano") + .set(KUBERNETES_DRIVER_POD_FEATURE_STEPS.key, VOLCANO_FEATURE_STEP) + .set(KUBERNETES_EXECUTOR_POD_FEATURE_STEPS.key, VOLCANO_FEATURE_STEP) + queue.foreach { q => + conf.set(VolcanoFeatureStep.POD_GROUP_TEMPLATE_FILE_KEY, + new File( + 
getClass.getResource(s"/volcano/$q-driver-podgroup-template.yml").getFile + ).getAbsolutePath) + } + driverPodGroupTemplate.foreach(conf.set(VolcanoFeatureStep.POD_GROUP_TEMPLATE_FILE_KEY, _)) + groupLoc.foreach { locator => + conf.set(s"${KUBERNETES_DRIVER_LABEL_PREFIX}spark-group-locator", locator) + conf.set(s"${KUBERNETES_EXECUTOR_LABEL_PREFIX}spark-group-locator", locator) + } + driverTemplate.foreach(conf.set(KUBERNETES_DRIVER_PODTEMPLATE_FILE.key, _)) + conf + } + + test("Run SparkPi with volcano scheduler", k8sTestTag, volcanoTag) { + sparkAppConf + .set("spark.kubernetes.driver.pod.featureSteps", VOLCANO_FEATURE_STEP) + .set("spark.kubernetes.executor.pod.featureSteps", VOLCANO_FEATURE_STEP) + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkScheduler(driverPod) + checkAnnotaion(driverPod) + checkPodGroup(driverPod) + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkScheduler(executorPod) + checkAnnotaion(executorPod) + } + ) + } + + private def verifyJobsSucceededOneByOne(jobNum: Int, groupName: String): Unit = { + // Check Pending jobs completed one by one + (1 until jobNum).map { completedNum => + Eventually.eventually(TIMEOUT, INTERVAL) { + val pendingPods = getPods(role = "driver", groupName, statusPhase = "Pending") + assert(pendingPods.size === jobNum - completedNum) + } + } + // All jobs succeeded finally + Eventually.eventually(TIMEOUT, INTERVAL) { + val succeededPods = getPods(role = "driver", groupName, statusPhase = "Succeeded") + assert(succeededPods.size === jobNum) + } + } + + test("SPARK-38187: Run SparkPi Jobs with minCPU", k8sTestTag, volcanoTag) { + val groupName = generateGroupName("min-cpu") + // Create a queue with 2 CPU, 3G memory capacity + createOrReplaceYAMLResource(QUEUE_2U_3G_YAML) + // Submit 3 jobs with minCPU = 2 + val jobNum = 3 + (1 to jobNum).map { i => + Future { + runJobAndVerify( + i.toString, + groupLoc = Option(groupName), + driverPodGroupTemplate = Option(DRIVER_PG_TEMPLATE_CPU_2U)) + } + } + verifyJobsSucceededOneByOne(jobNum, groupName) + } + + test("SPARK-38187: Run SparkPi Jobs with minMemory", k8sTestTag, volcanoTag) { + val groupName = generateGroupName("min-mem") + // Create a queue with 2 CPU, 3G memory capacity + createOrReplaceYAMLResource(QUEUE_2U_3G_YAML) + // Submit 3 jobs with minMemory = 3g + val jobNum = 3 + (1 to jobNum).map { i => + Future { + runJobAndVerify( + i.toString, + groupLoc = Option(groupName), + driverPodGroupTemplate = Option(DRIVER_PG_TEMPLATE_MEMORY_3G)) + } + } + verifyJobsSucceededOneByOne(jobNum, groupName) + } + + test("SPARK-38188: Run SparkPi jobs with 2 queues (only 1 enabled)", k8sTestTag, volcanoTag) { + // Disabled queue0 and enabled queue1 + createOrReplaceYAMLResource(VOLCANO_Q0_DISABLE_Q1_ENABLE_YAML) + // Submit jobs into disabled queue0 and enabled queue1 + val jobNum = 4 + (1 to jobNum).foreach { i => + Future { + val queueName = s"queue${i % 2}" + val groupName = generateGroupName(queueName) + runJobAndVerify(i.toString, Option(groupName), Option(queueName)) + } + } + // There are two `Succeeded` jobs and two `Pending` jobs + Eventually.eventually(TIMEOUT, INTERVAL) { + val completedPods = getPods("driver", s"${GROUP_PREFIX}queue1", "Succeeded") + assert(completedPods.size === 2) + val pendingPods = getPods("driver", s"${GROUP_PREFIX}queue0", "Pending") + assert(pendingPods.size === 2) + } + } + + test("SPARK-38188: Run SparkPi jobs with 2 queues (all enabled)", k8sTestTag, 
volcanoTag) { + val groupName = generateGroupName("queue-enable") + // Enable all queues + createOrReplaceYAMLResource(VOLCANO_ENABLE_Q0_AND_Q1_YAML) + val jobNum = 4 + // Submit jobs into these two queues + (1 to jobNum).foreach { i => + Future { + val queueName = s"queue${i % 2}" + runJobAndVerify(i.toString, Option(groupName), Option(queueName)) + } + } + // All jobs "Succeeded" + Eventually.eventually(TIMEOUT, INTERVAL) { + val completedPods = getPods("driver", groupName, "Succeeded") + assert(completedPods.size === jobNum) + } + } + + test("SPARK-38423: Run driver job to validate priority order", k8sTestTag, volcanoTag) { + // Prepare the priority resource and queue + createOrReplaceYAMLResource(DISABLE_QUEUE) + createOrReplaceYAMLResource(VOLCANO_PRIORITY_YAML) + // Submit 3 jobs with different priority + val priorities = Seq("low", "medium", "high") + priorities.foreach { p => + Future { + val templatePath = new File( + getClass.getResource(s"/volcano/$p-priority-driver-template.yml").getFile + ).getAbsolutePath + val pgTemplatePath = new File( + getClass.getResource(s"/volcano/$p-priority-driver-podgroup-template.yml").getFile + ).getAbsolutePath + val groupName = generateGroupName(p) + runJobAndVerify( + p, groupLoc = Option(groupName), + queue = Option("queue"), + driverTemplate = Option(templatePath), + driverPodGroupTemplate = Option(pgTemplatePath), + isDriverJob = true + ) + } + } + // Make sure 3 jobs are pending + Eventually.eventually(TIMEOUT, INTERVAL) { + priorities.foreach { p => + val pods = getPods(role = "driver", s"$GROUP_PREFIX$p", statusPhase = "Pending") + assert(pods.size === 1) + } + } + + // Enable queue to let jobs running one by one + createOrReplaceYAMLResource(ENABLE_QUEUE) + + // Verify scheduling order follow the specified priority + Eventually.eventually(TIMEOUT, INTERVAL) { + var m = Map.empty[String, Instant] + priorities.foreach { p => + val pods = getPods(role = "driver", s"$GROUP_PREFIX$p", statusPhase = "Succeeded") + assert(pods.size === 1) + val conditions = pods.head.getStatus.getConditions.asScala + val scheduledTime + = conditions.filter(_.getType === "PodScheduled").head.getLastTransitionTime + m += (p -> Instant.parse(scheduledTime)) + } + // high --> medium --> low + assert(m("high").isBefore(m("medium"))) + assert(m("medium").isBefore(m("low"))) + } + } +} + +private[spark] object VolcanoTestsSuite extends SparkFunSuite { + val VOLCANO_FEATURE_STEP = classOf[VolcanoFeatureStep].getName + val VOLCANO_ENABLE_Q0_AND_Q1_YAML = new File( + getClass.getResource("/volcano/enable-queue0-enable-queue1.yml").getFile + ).getAbsolutePath + val VOLCANO_Q0_DISABLE_Q1_ENABLE_YAML = new File( + getClass.getResource("/volcano/disable-queue0-enable-queue1.yml").getFile + ).getAbsolutePath + val GROUP_PREFIX = "volcano-test" + UUID.randomUUID().toString.replaceAll("-", "") + "-" + val VOLCANO_PRIORITY_YAML + = new File(getClass.getResource("/volcano/priorityClasses.yml").getFile).getAbsolutePath + val ENABLE_QUEUE = new File( + getClass.getResource("/volcano/enable-queue.yml").getFile + ).getAbsolutePath + val DISABLE_QUEUE = new File( + getClass.getResource("/volcano/disable-queue.yml").getFile + ).getAbsolutePath + val QUEUE_2U_3G_YAML = new File( + getClass.getResource("/volcano/queue-2u-3g.yml").getFile + ).getAbsolutePath + val DRIVER_PG_TEMPLATE_CPU_2U = new File( + getClass.getResource("/volcano/driver-podgroup-template-cpu-2u.yml").getFile + ).getAbsolutePath + val DRIVER_PG_TEMPLATE_MEMORY_3G = new File( + 
getClass.getResource("/volcano/driver-podgroup-template-memory-3g.yml").getFile + ).getAbsolutePath +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala index 56ddae0c9c57c..ced8151b709b5 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.k8s.integrationtest.backend import io.fabric8.kubernetes.client.DefaultKubernetesClient +import org.apache.spark.deploy.k8s.integrationtest.ProcessUtils import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ import org.apache.spark.deploy.k8s.integrationtest.backend.cloud.KubeConfigBackend import org.apache.spark.deploy.k8s.integrationtest.backend.docker.DockerForDesktopBackend @@ -28,6 +29,10 @@ private[spark] trait IntegrationTestBackend { def initialize(): Unit def getKubernetesClient: DefaultKubernetesClient def cleanUp(): Unit = {} + def describePods(labels: String): Seq[String] = + ProcessUtils.executeProcess( + Array("bash", "-c", s"kubectl describe pods --all-namespaces -l $labels"), + timeout = 60, dumpOutput = false).filter { !_.contains("https://github.com/kubernetes") } } private[spark] object IntegrationTestBackendFactory { @@ -38,7 +43,7 @@ private[spark] object IntegrationTestBackendFactory { case BACKEND_MINIKUBE => MinikubeTestBackend case BACKEND_CLOUD => new KubeConfigBackend(System.getProperty(CONFIG_KEY_KUBE_CONFIG_CONTEXT)) - case BACKEND_DOCKER_FOR_DESKTOP => DockerForDesktopBackend + case BACKEND_DOCKER_FOR_DESKTOP | BACKEND_DOCKER_DESKTOP => DockerForDesktopBackend case _ => throw new IllegalArgumentException("Invalid " + CONFIG_KEY_DEPLOY_MODE + ": " + deployMode) } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala index 0fbed4a220e68..83535488cc0ab 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala @@ -60,6 +60,9 @@ private[spark] class KubeConfigBackend(var context: String) } override def cleanUp(): Unit = { + if (defaultClient != null) { + defaultClient.close() + } super.cleanUp() } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/docker/DockerForDesktopBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/docker/DockerForDesktopBackend.scala index 81a11ae9dcdc6..f206befc64ff1 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/docker/DockerForDesktopBackend.scala +++ 
b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/docker/DockerForDesktopBackend.scala @@ -20,6 +20,6 @@ import org.apache.spark.deploy.k8s.integrationtest.TestConstants import org.apache.spark.deploy.k8s.integrationtest.backend.cloud.KubeConfigBackend private[spark] object DockerForDesktopBackend - extends KubeConfigBackend(TestConstants.BACKEND_DOCKER_FOR_DESKTOP) { + extends KubeConfigBackend(TestConstants.BACKEND_DOCKER_DESKTOP) { } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala index 1ebc64445b717..755feb9aca9e6 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala @@ -48,9 +48,9 @@ private[spark] object Minikube extends Logging { versionArrayOpt match { case Some(Array(x, y, z)) => - if (Ordering.Tuple3[Int, Int, Int].lt((x, y, z), (1, 7, 3))) { + if (Ordering.Tuple3[Int, Int, Int].lt((x, y, z), (1, 18, 0))) { assert(false, s"Unsupported Minikube version is detected: $minikubeVersionString." + - "For integration testing Minikube version 1.7.3 or greater is expected.") + "For integration testing Minikube version 1.18.0 or greater is expected.") } case _ => assert(false, s"Unexpected version format detected in `$minikubeVersionString`." + @@ -111,10 +111,6 @@ private[spark] object Minikube extends Logging { def minikubeServiceAction(args: String*): String = { executeMinikube(true, "service", args: _*).head } - - def describePods(labels: String): Seq[String] = - Minikube.executeMinikube(false, "kubectl", "--", "describe", "pods", "--all-namespaces", - "-l", labels) } private[spark] object MinikubeStatus extends Enumeration { diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala index f92977ddacdf5..f2ca57f89d0aa 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala @@ -34,10 +34,17 @@ private[spark] object MinikubeTestBackend extends IntegrationTestBackend { } override def cleanUp(): Unit = { + if (defaultClient != null) { + defaultClient.close() + } super.cleanUp() } override def getKubernetesClient: DefaultKubernetesClient = { defaultClient } + + override def describePods(labels: String): Seq[String] = + Minikube.executeMinikube(false, "kubectl", "--", "describe", "pods", "--all-namespaces", + "-l", labels) } diff --git a/resource-managers/mesos/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/resource-managers/mesos/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager index 12b6d5b64d68c..f83bfa166bec8 100644 --- 
a/resource-managers/mesos/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager +++ b/resource-managers/mesos/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.scheduler.cluster.mesos.MesosClusterManager diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 2fd13a5903243..9e4187837b680 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -105,6 +105,7 @@ private[mesos] class MesosSubmitRequestServlet( val superviseDriver = sparkProperties.get(config.DRIVER_SUPERVISE.key) val driverMemory = sparkProperties.get(config.DRIVER_MEMORY.key) val driverMemoryOverhead = sparkProperties.get(config.DRIVER_MEMORY_OVERHEAD.key) + val driverMemoryOverheadFactor = sparkProperties.get(config.DRIVER_MEMORY_OVERHEAD_FACTOR.key) val driverCores = sparkProperties.get(config.DRIVER_CORES.key) val name = request.sparkProperties.getOrElse("spark.app.name", mainClass) @@ -121,8 +122,10 @@ private[mesos] class MesosSubmitRequestServlet( mainClass, appArgs, environmentVariables, extraClassPath, extraLibraryPath, javaOpts) val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverMemoryFactor = driverMemoryOverheadFactor.map(_.toDouble).getOrElse( + MEMORY_OVERHEAD_FACTOR) val actualDriverMemoryOverhead = driverMemoryOverhead.map(_.toInt).getOrElse( - math.max((MEMORY_OVERHEAD_FACTOR * actualDriverMemory).toInt, MEMORY_OVERHEAD_MIN)) + math.max((actualDriverMemoryFactor * actualDriverMemory).toInt, MEMORY_OVERHEAD_MIN)) val actualDriverCores = driverCores.map(_.toDouble).getOrElse(DEFAULT_CORES) val submitDate = new Date() val submissionId = newDriverId(submitDate) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 6fedce61d8208..e5a6a5f1ef166 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -525,7 +525,7 @@ private[spark] class 
MesosCoarseGrainedSchedulerBackend( partitionTaskResources(resources, taskCPUs, taskMemory, taskGPUs) val taskBuilder = MesosTaskInfo.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) .setName(s"${sc.appName} $taskId") @@ -679,7 +679,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( "is Spark installed on it?") } } - executorTerminated(d, agentId, taskId, s"Executor finished with state $state") + executorTerminated(agentId, taskId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node d.reviveOffers() } @@ -740,7 +740,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( * what tasks are running. It also notifies the driver that an executor was removed. */ private def executorTerminated( - d: org.apache.mesos.SchedulerDriver, agentId: String, taskId: String, reason: String): Unit = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 586c2bdd67cfa..cc67dad196880 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -419,8 +419,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( } } - private def recordAgentLost( - d: org.apache.mesos.SchedulerDriver, agentId: AgentID, reason: ExecutorLossReason): Unit = { + private def recordAgentLost(agentId: AgentID, reason: ExecutorLossReason): Unit = { inClassLoader() { logInfo("Mesos agent lost: " + agentId.getValue) removeExecutor(agentId.getValue, reason.toString) @@ -429,7 +428,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( } override def agentLost(d: org.apache.mesos.SchedulerDriver, agentId: AgentID): Unit = { - recordAgentLost(d, agentId, ExecutorProcessLost()) + recordAgentLost(agentId, ExecutorProcessLost()) } override def executorLost( @@ -439,7 +438,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( status: Int): Unit = { logInfo("Executor lost: %s, marking agent %s as lost".format(executorId.getValue, agentId.getValue)) - recordAgentLost(d, agentId, ExecutorExited(status, exitCausedByApp = true)) + recordAgentLost(agentId, ExecutorExited(status, exitCausedByApp = true)) } override def killTask( diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 38f83df00e428..524b1d514fafe 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -387,8 +387,7 @@ trait MesosSchedulerUtils extends Logging { } } - // These defaults copied from YARN - private val MEMORY_OVERHEAD_FRACTION = 0.10 + // This default copied from YARN private val MEMORY_OVERHEAD_MINIMUM = 384 /** @@ -400,8 +399,9 @@ trait MesosSchedulerUtils extends Logging { * (whichever is larger) */ def executorMemory(sc: SparkContext): Int = { 
+ val memoryOverheadFactor = sc.conf.get(EXECUTOR_MEMORY_OVERHEAD_FACTOR) sc.conf.get(mesosConfig.EXECUTOR_MEMORY_OVERHEAD).getOrElse( - math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + + math.max(memoryOverheadFactor * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + sc.executorMemory } @@ -415,7 +415,8 @@ trait MesosSchedulerUtils extends Logging { * `MEMORY_OVERHEAD_FRACTION (=0.1) * driverMemory` */ def driverContainerMemory(driverDesc: MesosDriverDescription): Int = { - val defaultMem = math.max(MEMORY_OVERHEAD_FRACTION * driverDesc.mem, MEMORY_OVERHEAD_MINIMUM) + val memoryOverheadFactor = driverDesc.conf.get(DRIVER_MEMORY_OVERHEAD_FACTOR) + val defaultMem = math.max(memoryOverheadFactor * driverDesc.mem, MEMORY_OVERHEAD_MINIMUM) driverDesc.conf.get(mesosConfig.DRIVER_MEMORY_OVERHEAD).getOrElse(defaultMem.toInt) + driverDesc.mem } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/rest/mesos/MesosRestServerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/rest/mesos/MesosRestServerSuite.scala index 344fc38c84fb1..8bed43a54d5d0 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/rest/mesos/MesosRestServerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/rest/mesos/MesosRestServerSuite.scala @@ -35,10 +35,16 @@ class MesosRestServerSuite extends SparkFunSuite testOverheadMemory(new SparkConf(), "2000M", 2384) } - test("test driver overhead memory with overhead factor") { + test("test driver overhead memory with default overhead factor") { testOverheadMemory(new SparkConf(), "5000M", 5500) } + test("test driver overhead memory with overhead factor") { + val conf = new SparkConf() + conf.set(config.DRIVER_MEMORY_OVERHEAD_FACTOR.key, "0.2") + testOverheadMemory(conf, "5000M", 6000) + } + test("test configured driver overhead memory") { val conf = new SparkConf() conf.set(config.DRIVER_MEMORY_OVERHEAD.key, "1000") diff --git a/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager index 6e8a1ebfc61da..3759c3f197a9c 100644 --- a/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager +++ b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + org.apache.spark.scheduler.cluster.YarnClusterManager diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ca4fbbb97ad28..f364b79216098 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -54,6 +54,7 @@ import org.apache.spark.api.python.PythonUtils import org.apache.spark.deploy.{SparkApplication, SparkHadoopUtil} import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.deploy.yarn.ResourceRequestHelper._ +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -70,7 +71,6 @@ private[spark] class Client( extends Logging { import Client._ - import YarnSparkHadoopUtil._ private val yarnClient = YarnClient.createYarnClient private val hadoopConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf)) @@ -85,6 +85,12 @@ private[spark] class Client( private var appMaster: ApplicationMaster = _ private var stagingDirPath: Path = _ + private val amMemoryOverheadFactor = if (isClusterMode) { + sparkConf.get(DRIVER_MEMORY_OVERHEAD_FACTOR) + } else { + AM_MEMORY_OVERHEAD_FACTOR + } + // AM related configurations private val amMemory = if (isClusterMode) { sparkConf.get(DRIVER_MEMORY).toInt @@ -94,7 +100,7 @@ private[spark] class Client( private val amMemoryOverhead = { val amMemoryOverheadEntry = if (isClusterMode) DRIVER_MEMORY_OVERHEAD else AM_MEMORY_OVERHEAD sparkConf.get(amMemoryOverheadEntry).getOrElse( - math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toLong, + math.max((amMemoryOverheadFactor * amMemory).toLong, ResourceProfile.MEMORY_OVERHEAD_MIN_MIB)).toInt } private val amCores = if (isClusterMode) { @@ -107,8 +113,10 @@ private[spark] class Client( private val executorMemory = sparkConf.get(EXECUTOR_MEMORY) // Executor offHeap memory in MiB. 
protected val executorOffHeapMemory = Utils.executorOffHeapMemorySizeAsMb(sparkConf) + + private val executorMemoryOvereadFactor = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD_FACTOR) private val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( - math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, + math.max((executorMemoryOvereadFactor * executorMemory).toLong, ResourceProfile.MEMORY_OVERHEAD_MIN_MIB)).toInt private val isPython = sparkConf.get(IS_PYTHON_APP) @@ -183,8 +191,10 @@ private[spark] class Client( yarnClient.init(hadoopConf) yarnClient.start() - logInfo("Requesting a new application from cluster with %d NodeManagers" - .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) + if (log.isDebugEnabled) { + logDebug("Requesting a new application from cluster with %d NodeManagers" + .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) + } // Get a new application from our RM val newApp = yarnClient.createApplication() diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala index 50e822510fd3d..5a5334dc76321 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala @@ -51,14 +51,14 @@ private object ResourceRequestHelper extends Logging { if (splitIndex == -1) { val errorMessage = s"Missing suffix for ${componentName}${key}, you must specify" + s" a suffix - $AMOUNT is currently the only supported suffix." - throw new IllegalArgumentException(errorMessage.toString()) + throw new IllegalArgumentException(errorMessage) } val resourceName = key.substring(0, splitIndex) val resourceSuffix = key.substring(splitIndex + 1) if (!AMOUNT.equals(resourceSuffix)) { val errorMessage = s"Unsupported suffix: $resourceSuffix in: ${componentName}${key}, " + s"only .$AMOUNT is supported." 
- throw new IllegalArgumentException(errorMessage.toString()) + throw new IllegalArgumentException(errorMessage) } (resourceName, value) }.toMap diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 54ab643f2755b..a85b7174673af 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -163,6 +163,8 @@ private[yarn] class YarnAllocator( private val isPythonApp = sparkConf.get(IS_PYTHON_APP) + private val memoryOverheadFactor = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD_FACTOR) + private val launcherPool = ThreadUtils.newDaemonCachedThreadPool( "ContainerLauncher", sparkConf.get(CONTAINER_LAUNCH_MAX_THREADS)) @@ -280,9 +282,10 @@ private[yarn] class YarnAllocator( // track the resource profile if not already there getOrUpdateRunningExecutorForRPId(rp.id) logInfo(s"Resource profile ${rp.id} doesn't exist, adding it") + val resourcesWithDefaults = ResourceProfile.getResourcesForClusterManager(rp.id, rp.executorResources, - MEMORY_OVERHEAD_FACTOR, sparkConf, isPythonApp, resourceNameMapping) + memoryOverheadFactor, sparkConf, isPythonApp, resourceNameMapping) val customSparkResources = resourcesWithDefaults.customResources.map { case (name, execReq) => (name, execReq.amount.toString) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 2f272be60ba25..842611807db4d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -124,7 +124,7 @@ private[spark] class YarnRMClient extends Logging { /** Returns the maximum number of attempts to register the AM. */ def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = { - val sparkMaxAttempts = sparkConf.get(MAX_APP_ATTEMPTS).map(_.toInt) + val sparkMaxAttempts = sparkConf.get(MAX_APP_ATTEMPTS) val yarnMaxAttempts = yarnConf.getInt( YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) sparkMaxAttempts match { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index f347e37ba24ab..1869c739e4844 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -34,11 +34,10 @@ import org.apache.spark.util.Utils object YarnSparkHadoopUtil { - // Additional memory overhead + // Additional memory overhead for application masters in client mode. // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering // the common cases. Memory overhead tends to grow with container size. 
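  // Worked example of the overhead sizing used throughout this change (a sketch; the 4096 MiB
  // executor size is assumed for illustration):
  //   overhead = max(factor * memory, 384 MiB)
  // where the factor now comes from DRIVER_MEMORY_OVERHEAD_FACTOR / EXECUTOR_MEMORY_OVERHEAD_FACTOR
  // (default 0.1) instead of the previously hard-coded 0.10; only the YARN AM in client mode keeps
  // the fixed 10% constant below.
  //   factor 0.1 -> max(409, 384)  = 409 MiB overhead  (container ~4505 MiB)
  //   factor 0.5 -> max(2048, 384) = 2048 MiB overhead (container 6144 MiB, i.e. the 1.5x total
  //                                                     asserted in YarnAllocatorSuite further down)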
- - val MEMORY_OVERHEAD_FACTOR = 0.10 + val AM_MEMORY_OVERHEAD_FACTOR = 0.10 val ANY_HOST = "*" diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index db65d128b07f0..ae010f11503dd 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -706,4 +706,33 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter sparkConf.set(MEMORY_OFFHEAP_SIZE, originalOffHeapSize) } } + + test("SPARK-38194: Configurable memory overhead factor") { + val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toLong + try { + sparkConf.set(EXECUTOR_MEMORY_OVERHEAD_FACTOR, 0.5) + val (handler, _) = createAllocator(maxExecutors = 1, + additionalConfigs = Map(EXECUTOR_MEMORY.key -> executorMemory.toString)) + val defaultResource = handler.rpIdToYarnResource.get(defaultRPId) + val memory = defaultResource.getMemory + assert(memory == (executorMemory * 1.5).toLong) + } finally { + sparkConf.set(EXECUTOR_MEMORY_OVERHEAD_FACTOR, 0.1) + } + } + + test("SPARK-38194: Memory overhead takes precedence over factor") { + val executorMemory = sparkConf.get(EXECUTOR_MEMORY) + try { + sparkConf.set(EXECUTOR_MEMORY_OVERHEAD_FACTOR, 0.5) + sparkConf.set(EXECUTOR_MEMORY_OVERHEAD, (executorMemory * 0.4).toLong) + val (handler, _) = createAllocator(maxExecutors = 1, + additionalConfigs = Map(EXECUTOR_MEMORY.key -> executorMemory.toString)) + val defaultResource = handler.rpIdToYarnResource.get(defaultRPId) + val memory = defaultResource.getMemory + assert(memory == (executorMemory * 1.4).toLong) + } finally { + sparkConf.set(EXECUTOR_MEMORY_OVERHEAD_FACTOR, 0.1) + } + } } diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index f27b6fe8d9a04..341eb053ed7b2 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.9.3-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.9.4-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index e563f7bff1667..3cfd5acfe2b56 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -86,12 +86,12 @@ spark_rotate_log () fi if [ -f "$log" ]; then # rotate logs - while [ $num -gt 1 ]; do - prev=`expr $num - 1` - [ -f "$log.$prev" ] && mv "$log.$prev" "$log.$num" - num=$prev - done - mv "$log" "$log.$num"; + while [ $num -gt 1 ]; do + prev=`expr $num - 1` + [ -f "$log.$prev" ] && mv "$log.$prev" "$log.$num" + num=$prev + done + mv "$log" "$log.$num"; fi } diff --git a/sbin/start-master.sh b/sbin/start-master.sh index b6a566e4daf4b..36fe4b4abeb91 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -51,11 +51,11 @@ fi if [ "$SPARK_MASTER_HOST" = "" ]; then case `uname` in (SunOS) - SPARK_MASTER_HOST="`/usr/sbin/check-hostname | awk '{print $NF}'`" - ;; + SPARK_MASTER_HOST="`/usr/sbin/check-hostname | awk '{print $NF}'`" + ;; (*) - SPARK_MASTER_HOST="`hostname -f`" - ;; + SPARK_MASTER_HOST="`hostname -f`" + ;; esac fi diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh index 
ecaad7ad09634..c2e30d8c0b080 100755 --- a/sbin/start-mesos-dispatcher.sh +++ b/sbin/start-mesos-dispatcher.sh @@ -36,11 +36,11 @@ fi if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then case `uname` in (SunOS) - SPARK_MESOS_DISPATCHER_HOST="`/usr/sbin/check-hostname | awk '{print $NF}'`" - ;; + SPARK_MESOS_DISPATCHER_HOST="`/usr/sbin/check-hostname | awk '{print $NF}'`" + ;; (*) - SPARK_MESOS_DISPATCHER_HOST="`hostname -f`" - ;; + SPARK_MESOS_DISPATCHER_HOST="`hostname -f`" + ;; esac fi diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 791d91040c816..9585785835d62 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -264,6 +264,12 @@ This file is divided into 3 sections: of Commons Lang 2 (package org.apache.commons.lang.*) + + scala\.concurrent\.ExecutionContext\.Implicits\.global + User queries can use global thread pool, causing starvation and eventual OOM. + Thus, Spark-internal APIs should not use this thread pool + + FileSystem.get\([a-zA-Z_$][a-zA-Z_$0-9]*\) Omit braces in case clauses. + + new (java\.lang\.)?(Byte|Integer|Long|Short)\( + Use static factory 'valueOf' or 'parseXXX' instead of the deprecated constructors. + + diff --git a/sql/catalyst/benchmarks/EnumTypeSetBenchmark-jdk11-results.txt b/sql/catalyst/benchmarks/EnumTypeSetBenchmark-jdk11-results.txt index 956db2edfbc02..d26da81a2514e 100644 --- a/sql/catalyst/benchmarks/EnumTypeSetBenchmark-jdk11-results.txt +++ b/sql/catalyst/benchmarks/EnumTypeSetBenchmark-jdk11-results.txt @@ -1,105 +1,105 @@ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test contains use empty Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 1 1 0 1722.9 0.6 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 1 1 0 1120.1 0.9 1.0X +Use EnumSet 2 2 0 550.8 1.8 0.5X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test contains use 1 item Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 10 11 1 97.5 10.3 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 8 8 1 126.0 7.9 1.0X +Use EnumSet 2 2 0 590.4 1.7 4.7X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test contains use 3 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 22 23 1 46.1 21.7 1.0X -Use EnumSet 0 0 0 10000000.0 0.0 216928.7X +Use HashSet 15 15 1 67.4 14.8 1.0X +Use EnumSet 2 2 0 652.3 1.5 9.7X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test contains use 5 items Set: Best Time(ms) 
Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 18 20 2 57.0 17.6 1.0X -Use EnumSet 0 0 0 10000000.0 0.0 175588.1X +Use HashSet 17 18 1 57.5 17.4 1.0X +Use EnumSet 2 2 0 591.2 1.7 10.3X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test contains use 10 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 20 22 2 50.2 19.9 1.0X -Use EnumSet 0 0 0 10000000.0 0.0 199224.4X +Use HashSet 18 18 0 54.8 18.2 1.0X +Use EnumSet 2 2 0 591.4 1.7 10.8X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test create empty Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 1 1 0 147.1 6.8 1.0X -Use EnumSet 2 2 0 57.9 17.3 0.4X +Use HashSet 1 1 0 95.0 10.5 1.0X +Use EnumSet 2 2 0 54.4 18.4 0.6X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test create 1 item Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 15 16 2 6.7 149.6 1.0X -Use EnumSet 2 3 1 42.6 23.5 6.4X +Use HashSet 31 32 2 3.2 310.3 1.0X +Use EnumSet 3 3 0 38.0 26.3 11.8X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test create 3 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 45 47 2 2.2 450.9 1.0X -Use EnumSet 2 3 1 41.2 24.3 18.6X +Use HashSet 75 75 0 1.3 751.6 1.0X +Use EnumSet 3 3 0 36.1 27.7 27.2X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test create 5 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 104 108 5 1.0 1036.7 1.0X -Use EnumSet 2 3 1 44.3 22.6 46.0X +Use HashSet 122 123 0 0.8 1225.0 1.0X +Use EnumSet 2 2 0 41.8 23.9 51.2X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test create 10 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) 
Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 147 154 5 0.7 1474.0 1.0X -Use EnumSet 2 2 1 56.9 17.6 83.8X +Use HashSet 161 162 0 0.6 1614.9 1.0X +Use EnumSet 2 2 0 52.1 19.2 84.2X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test create and contains use empty Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 1 1 0 798.4 1.3 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 2 2 0 608.4 1.6 1.0X +Use EnumSet 3 4 0 295.5 3.4 0.5X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test create and contains use 1 item Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 39 42 3 25.5 39.2 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 57 58 2 17.6 56.8 1.0X +Use EnumSet 4 4 0 284.2 3.5 16.2X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test create and contains use 3 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 73 75 2 13.7 73.2 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 97 97 0 10.3 96.7 1.0X +Use EnumSet 4 4 0 263.3 3.8 25.5X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test create and contains use 5 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 157 162 3 6.4 157.3 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 174 175 2 5.8 173.6 1.0X +Use EnumSet 4 4 0 240.7 4.2 41.8X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Test create and contains use 10 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Use HashSet 197 206 6 5.1 197.4 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 211 214 5 4.7 211.4 1.0X +Use EnumSet 4 4 0 272.3 3.7 57.6X diff --git a/sql/catalyst/benchmarks/EnumTypeSetBenchmark-jdk17-results.txt b/sql/catalyst/benchmarks/EnumTypeSetBenchmark-jdk17-results.txt index 982ad076e100d..d110a292f8e66 100644 --- a/sql/catalyst/benchmarks/EnumTypeSetBenchmark-jdk17-results.txt 
+++ b/sql/catalyst/benchmarks/EnumTypeSetBenchmark-jdk17-results.txt @@ -1,105 +1,105 @@ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test contains use empty Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 0 1 0 2155.1 0.5 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 5 6 0 194.8 5.1 1.0X +Use EnumSet 1 1 0 879.0 1.1 4.5X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test contains use 1 item Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 10 10 0 100.7 9.9 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 8 10 1 117.8 8.5 1.0X +Use EnumSet 1 1 0 904.7 1.1 7.7X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test contains use 3 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 18 18 0 57.0 17.5 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 16 18 1 60.8 16.4 1.0X +Use EnumSet 1 1 0 965.2 1.0 15.9X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test contains use 5 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 15 15 0 68.0 14.7 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 16 17 1 63.7 15.7 1.0X +Use EnumSet 1 1 0 933.1 1.1 14.7X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test contains use 10 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 16 17 1 61.0 16.4 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 16 19 2 60.7 16.5 1.0X +Use EnumSet 1 1 0 831.7 1.2 13.7X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test create empty Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 0 1 0 218.9 4.6 1.0X -Use EnumSet 1 1 0 83.2 12.0 0.4X +Use 
HashSet 1 1 0 99.7 10.0 1.0X +Use EnumSet 1 1 0 82.8 12.1 0.8X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test create 1 item Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 17 17 1 5.9 168.4 1.0X -Use EnumSet 2 2 0 56.2 17.8 9.5X +Use HashSet 13 14 1 7.6 132.1 1.0X +Use EnumSet 2 2 0 46.9 21.3 6.2X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test create 3 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 49 50 1 2.0 493.3 1.0X -Use EnumSet 1 1 0 89.7 11.1 44.2X +Use HashSet 45 46 1 2.2 446.6 1.0X +Use EnumSet 1 2 0 68.6 14.6 30.7X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test create 5 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 116 118 1 0.9 1164.2 1.0X -Use EnumSet 1 1 0 89.7 11.2 104.4X +Use HashSet 127 128 1 0.8 1268.6 1.0X +Use EnumSet 1 2 0 80.3 12.5 101.9X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test create 10 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 168 170 2 0.6 1681.4 1.0X -Use EnumSet 1 1 0 83.6 12.0 140.6X +Use HashSet 148 158 6 0.7 1479.8 1.0X +Use EnumSet 1 1 0 87.4 11.4 129.4X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test create and contains use empty Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 1 1 0 904.5 1.1 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 1 1 1 870.5 1.1 1.0X +Use EnumSet 2 2 0 497.6 2.0 0.6X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test create and contains use 1 item Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 38 38 2 26.5 37.8 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 27 30 2 36.9 
27.1 1.0X +Use EnumSet 2 3 0 457.0 2.2 12.4X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test create and contains use 3 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 67 68 2 14.9 67.2 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 60 64 3 16.6 60.1 1.0X +Use EnumSet 2 2 0 460.9 2.2 27.7X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test create and contains use 5 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 135 137 3 7.4 134.6 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 146 151 4 6.9 145.6 1.0X +Use EnumSet 2 2 0 645.0 1.6 93.9X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz Test create and contains use 10 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Use HashSet 187 190 3 5.3 187.2 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 193 200 5 5.2 192.6 1.0X +Use EnumSet 2 2 0 602.8 1.7 116.1X diff --git a/sql/catalyst/benchmarks/EnumTypeSetBenchmark-results.txt b/sql/catalyst/benchmarks/EnumTypeSetBenchmark-results.txt index e2c0e3d5dc22d..4d4eb0269ebf3 100644 --- a/sql/catalyst/benchmarks/EnumTypeSetBenchmark-results.txt +++ b/sql/catalyst/benchmarks/EnumTypeSetBenchmark-results.txt @@ -1,105 +1,105 @@ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test contains use empty Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 0 1 1 2192.0 0.5 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 1 1 0 1709.1 0.6 1.0X +Use EnumSet 2 2 0 554.8 1.8 0.3X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test contains use 1 item Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 10 11 1 102.0 9.8 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 8 8 0 124.2 8.1 1.0X +Use EnumSet 2 2 0 423.8 2.4 3.4X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) 
Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test contains use 3 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 19 21 1 53.2 18.8 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 16 16 0 62.6 16.0 1.0X +Use EnumSet 2 2 0 423.8 2.4 6.8X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test contains use 5 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 16 17 1 61.5 16.3 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 15 15 0 66.3 15.1 1.0X +Use EnumSet 2 4 2 423.8 2.4 6.4X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test contains use 10 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 18 20 1 56.6 17.7 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 15 16 0 65.3 15.3 1.0X +Use EnumSet 2 3 0 423.8 2.4 6.5X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test create empty Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 1 1 0 136.6 7.3 1.0X -Use EnumSet 2 2 0 65.5 15.3 0.5X +Use HashSet 1 1 0 132.0 7.6 1.0X +Use EnumSet 2 2 0 62.4 16.0 0.5X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test create 1 item Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 13 14 1 7.9 127.3 1.0X -Use EnumSet 2 2 0 54.9 18.2 7.0X +Use HashSet 16 17 1 6.4 156.6 1.0X +Use EnumSet 2 2 0 59.7 16.7 9.4X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test create 3 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 41 43 2 2.4 408.2 1.0X -Use EnumSet 2 2 0 66.0 15.1 27.0X +Use HashSet 51 51 1 2.0 510.7 1.0X +Use EnumSet 2 2 0 62.9 15.9 32.1X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test 
create 5 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 101 106 3 1.0 1010.4 1.0X -Use EnumSet 2 2 0 60.4 16.6 61.0X +Use HashSet 110 118 7 0.9 1099.9 1.0X +Use EnumSet 2 2 0 58.4 17.1 64.3X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test create 10 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 157 164 3 0.6 1571.3 1.0X -Use EnumSet 1 2 0 78.2 12.8 122.9X +Use HashSet 144 145 1 0.7 1442.6 1.0X +Use EnumSet 1 2 0 71.0 14.1 102.4X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test create and contains use empty Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 1 2 0 714.6 1.4 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 1 1 0 816.8 1.2 1.0X +Use EnumSet 2 2 0 484.0 2.1 0.6X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test create and contains use 1 item Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 38 42 2 26.5 37.7 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 33 33 0 30.7 32.6 1.0X +Use EnumSet 2 3 0 405.3 2.5 13.2X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test create and contains use 3 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 68 72 2 14.8 67.5 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 76 76 1 13.2 75.6 1.0X +Use EnumSet 2 3 0 400.6 2.5 30.3X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test create and contains use 5 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Use HashSet 151 160 4 6.6 150.8 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 170 170 1 5.9 169.6 1.0X +Use EnumSet 3 9 1 308.1 3.2 52.3X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1028-azure +Intel(R) 
Xeon(R) Platinum 8272CL CPU @ 2.60GHz Test create and contains use 10 items Set: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Use HashSet 209 223 12 4.8 208.5 1.0X -Use EnumSet 0 0 0 Infinity 0.0 InfinityX +Use HashSet 156 157 1 6.4 155.8 1.0X +Use EnumSet 9 9 0 110.2 9.1 17.2X diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 new file mode 100644 index 0000000000000..e84d4fa45eb99 --- /dev/null +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -0,0 +1,487 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +lexer grammar SqlBaseLexer; + +@members { + /** + * When true, parser should throw ParseExcetion for unclosed bracketed comment. + */ + public boolean has_unclosed_bracketed_comment = false; + + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } + + /** + * This method will be called when the character stream ends and try to find out the + * unclosed bracketed comment. + * If the method be called, it means the end of the entire character stream match, + * and we set the flag and fail later. 
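The two EnumTypeSetBenchmark result files above compare java.util.HashSet against java.util.EnumSet for enum-keyed membership tests and set construction; an EnumSet over a small enum is backed by a single long bitmask, so contains() is a bit test and of() skips hashing entirely, which is where the single- to triple-digit speedups come from. The snippet below is a minimal, self-contained sketch of the access pattern being measured; the Color enum, iteration count, and timing loop are illustrative stand-ins, not the actual benchmark code in the catalyst module.

import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Set;

public class EnumSetVsHashSetSketch {
  // Illustrative enum; the real benchmark exercises enums used inside catalyst.
  enum Color { RED, GREEN, BLUE, YELLOW, PURPLE }

  public static void main(String[] args) {
    // HashSet hashes every element; EnumSet stores membership in one long bitmask.
    Set<Color> hashSet = new HashSet<>(Arrays.asList(Color.RED, Color.GREEN, Color.BLUE));
    Set<Color> enumSet = EnumSet.of(Color.RED, Color.GREEN, Color.BLUE);

    final long iterations = 10_000_000L;
    boolean sink = false;

    long t0 = System.nanoTime();
    for (long i = 0; i < iterations; i++) {
      sink ^= hashSet.contains(Color.BLUE);
    }
    long hashSetNanos = System.nanoTime() - t0;

    t0 = System.nanoTime();
    for (long i = 0; i < iterations; i++) {
      sink ^= enumSet.contains(Color.BLUE);
    }
    long enumSetNanos = System.nanoTime() - t0;

    System.out.println("HashSet contains: " + hashSetNanos / 1_000_000 + " ms");
    System.out.println("EnumSet contains: " + enumSetNanos / 1_000_000 + " ms (ignore: " + sink + ")");
  }
}

The absolute numbers will differ from the tables above on any given machine, but for small enum sets EnumSet.contains should stay well ahead of HashSet.contains.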
+ */ + public void markUnclosedComment() { + has_unclosed_bracketed_comment = true; + } +} + +SEMICOLON: ';'; + +LEFT_PAREN: '('; +RIGHT_PAREN: ')'; +COMMA: ','; +DOT: '.'; +LEFT_BRACKET: '['; +RIGHT_BRACKET: ']'; + +// NOTE: If you add a new token in the list below, you should update the list of keywords +// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`. + +//============================ +// Start of the keywords list +//============================ +//--SPARK-KEYWORD-LIST-START +ADD: 'ADD'; +AFTER: 'AFTER'; +ALL: 'ALL'; +ALTER: 'ALTER'; +ANALYZE: 'ANALYZE'; +AND: 'AND'; +ANTI: 'ANTI'; +ANY: 'ANY'; +ARCHIVE: 'ARCHIVE'; +ARRAY: 'ARRAY'; +AS: 'AS'; +ASC: 'ASC'; +AT: 'AT'; +AUTHORIZATION: 'AUTHORIZATION'; +BETWEEN: 'BETWEEN'; +BOTH: 'BOTH'; +BUCKET: 'BUCKET'; +BUCKETS: 'BUCKETS'; +BY: 'BY'; +CACHE: 'CACHE'; +CASCADE: 'CASCADE'; +CASE: 'CASE'; +CAST: 'CAST'; +CATALOG: 'CATALOG'; +CATALOGS: 'CATALOGS'; +CHANGE: 'CHANGE'; +CHECK: 'CHECK'; +CLEAR: 'CLEAR'; +CLUSTER: 'CLUSTER'; +CLUSTERED: 'CLUSTERED'; +CODEGEN: 'CODEGEN'; +COLLATE: 'COLLATE'; +COLLECTION: 'COLLECTION'; +COLUMN: 'COLUMN'; +COLUMNS: 'COLUMNS'; +COMMENT: 'COMMENT'; +COMMIT: 'COMMIT'; +COMPACT: 'COMPACT'; +COMPACTIONS: 'COMPACTIONS'; +COMPUTE: 'COMPUTE'; +CONCATENATE: 'CONCATENATE'; +CONSTRAINT: 'CONSTRAINT'; +COST: 'COST'; +CREATE: 'CREATE'; +CROSS: 'CROSS'; +CUBE: 'CUBE'; +CURRENT: 'CURRENT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_USER: 'CURRENT_USER'; +DAY: 'DAY'; +DAYOFYEAR: 'DAYOFYEAR'; +DATA: 'DATA'; +DATABASE: 'DATABASE'; +DATABASES: 'DATABASES'; +DATEADD: 'DATEADD'; +DATEDIFF: 'DATEDIFF'; +DBPROPERTIES: 'DBPROPERTIES'; +DEFAULT: 'DEFAULT'; +DEFINED: 'DEFINED'; +DELETE: 'DELETE'; +DELIMITED: 'DELIMITED'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DFS: 'DFS'; +DIRECTORIES: 'DIRECTORIES'; +DIRECTORY: 'DIRECTORY'; +DISTINCT: 'DISTINCT'; +DISTRIBUTE: 'DISTRIBUTE'; +DIV: 'DIV'; +DROP: 'DROP'; +ELSE: 'ELSE'; +END: 'END'; +ESCAPE: 'ESCAPE'; +ESCAPED: 'ESCAPED'; +EXCEPT: 'EXCEPT'; +EXCHANGE: 'EXCHANGE'; +EXISTS: 'EXISTS'; +EXPLAIN: 'EXPLAIN'; +EXPORT: 'EXPORT'; +EXTENDED: 'EXTENDED'; +EXTERNAL: 'EXTERNAL'; +EXTRACT: 'EXTRACT'; +FALSE: 'FALSE'; +FETCH: 'FETCH'; +FIELDS: 'FIELDS'; +FILTER: 'FILTER'; +FILEFORMAT: 'FILEFORMAT'; +FIRST: 'FIRST'; +FOLLOWING: 'FOLLOWING'; +FOR: 'FOR'; +FOREIGN: 'FOREIGN'; +FORMAT: 'FORMAT'; +FORMATTED: 'FORMATTED'; +FROM: 'FROM'; +FULL: 'FULL'; +FUNCTION: 'FUNCTION'; +FUNCTIONS: 'FUNCTIONS'; +GLOBAL: 'GLOBAL'; +GRANT: 'GRANT'; +GROUP: 'GROUP'; +GROUPING: 'GROUPING'; +HAVING: 'HAVING'; +HOUR: 'HOUR'; +IF: 'IF'; +IGNORE: 'IGNORE'; +IMPORT: 'IMPORT'; +IN: 'IN'; +INDEX: 'INDEX'; +INDEXES: 'INDEXES'; +INNER: 'INNER'; +INPATH: 'INPATH'; +INPUTFORMAT: 'INPUTFORMAT'; +INSERT: 'INSERT'; +INTERSECT: 'INTERSECT'; +INTERVAL: 'INTERVAL'; +INTO: 'INTO'; +IS: 'IS'; +ITEMS: 'ITEMS'; +JOIN: 'JOIN'; +KEYS: 'KEYS'; +LAST: 'LAST'; +LATERAL: 'LATERAL'; +LAZY: 'LAZY'; +LEADING: 'LEADING'; +LEFT: 'LEFT'; +LIKE: 'LIKE'; +ILIKE: 'ILIKE'; +LIMIT: 'LIMIT'; +LINES: 'LINES'; +LIST: 'LIST'; +LOAD: 'LOAD'; +LOCAL: 'LOCAL'; +LOCATION: 'LOCATION'; +LOCK: 'LOCK'; +LOCKS: 'LOCKS'; +LOGICAL: 'LOGICAL'; +MACRO: 'MACRO'; +MAP: 'MAP'; +MATCHED: 'MATCHED'; +MERGE: 'MERGE'; +MICROSECOND: 'MICROSECOND'; +MILLISECOND: 'MILLISECOND'; +MINUTE: 'MINUTE'; +MONTH: 'MONTH'; +MSCK: 'MSCK'; +NAMESPACE: 'NAMESPACE'; +NAMESPACES: 'NAMESPACES'; +NATURAL: 'NATURAL'; +NO: 'NO'; +NOT: 'NOT' | '!'; +NULL: 'NULL'; +NULLS: 'NULLS'; +OF: 'OF'; +ON: 'ON'; +ONLY: 'ONLY'; +OPTION: 
'OPTION'; +OPTIONS: 'OPTIONS'; +OR: 'OR'; +ORDER: 'ORDER'; +OUT: 'OUT'; +OUTER: 'OUTER'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +OVER: 'OVER'; +OVERLAPS: 'OVERLAPS'; +OVERLAY: 'OVERLAY'; +OVERWRITE: 'OVERWRITE'; +PARTITION: 'PARTITION'; +PARTITIONED: 'PARTITIONED'; +PARTITIONS: 'PARTITIONS'; +PERCENTILE_CONT: 'PERCENTILE_CONT'; +PERCENTLIT: 'PERCENT'; +PIVOT: 'PIVOT'; +PLACING: 'PLACING'; +POSITION: 'POSITION'; +PRECEDING: 'PRECEDING'; +PRIMARY: 'PRIMARY'; +PRINCIPALS: 'PRINCIPALS'; +PROPERTIES: 'PROPERTIES'; +PURGE: 'PURGE'; +QUARTER: 'QUARTER'; +QUERY: 'QUERY'; +RANGE: 'RANGE'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +RECOVER: 'RECOVER'; +REDUCE: 'REDUCE'; +REFERENCES: 'REFERENCES'; +REFRESH: 'REFRESH'; +RENAME: 'RENAME'; +REPAIR: 'REPAIR'; +REPEATABLE: 'REPEATABLE'; +REPLACE: 'REPLACE'; +RESET: 'RESET'; +RESPECT: 'RESPECT'; +RESTRICT: 'RESTRICT'; +REVOKE: 'REVOKE'; +RIGHT: 'RIGHT'; +RLIKE: 'RLIKE' | 'REGEXP'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +ROLLBACK: 'ROLLBACK'; +ROLLUP: 'ROLLUP'; +ROW: 'ROW'; +ROWS: 'ROWS'; +SECOND: 'SECOND'; +SCHEMA: 'SCHEMA'; +SCHEMAS: 'SCHEMAS'; +SELECT: 'SELECT'; +SEMI: 'SEMI'; +SEPARATED: 'SEPARATED'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +SESSION_USER: 'SESSION_USER'; +SET: 'SET'; +SETMINUS: 'MINUS'; +SETS: 'SETS'; +SHOW: 'SHOW'; +SKEWED: 'SKEWED'; +SOME: 'SOME'; +SORT: 'SORT'; +SORTED: 'SORTED'; +START: 'START'; +STATISTICS: 'STATISTICS'; +STORED: 'STORED'; +STRATIFY: 'STRATIFY'; +STRUCT: 'STRUCT'; +SUBSTR: 'SUBSTR'; +SUBSTRING: 'SUBSTRING'; +SYNC: 'SYNC'; +SYSTEM_TIME: 'SYSTEM_TIME'; +SYSTEM_VERSION: 'SYSTEM_VERSION'; +TABLE: 'TABLE'; +TABLES: 'TABLES'; +TABLESAMPLE: 'TABLESAMPLE'; +TBLPROPERTIES: 'TBLPROPERTIES'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +TERMINATED: 'TERMINATED'; +THEN: 'THEN'; +TIME: 'TIME'; +TIMESTAMP: 'TIMESTAMP'; +TIMESTAMPADD: 'TIMESTAMPADD'; +TIMESTAMPDIFF: 'TIMESTAMPDIFF'; +TO: 'TO'; +TOUCH: 'TOUCH'; +TRAILING: 'TRAILING'; +TRANSACTION: 'TRANSACTION'; +TRANSACTIONS: 'TRANSACTIONS'; +TRANSFORM: 'TRANSFORM'; +TRIM: 'TRIM'; +TRUE: 'TRUE'; +TRUNCATE: 'TRUNCATE'; +TRY_CAST: 'TRY_CAST'; +TYPE: 'TYPE'; +UNARCHIVE: 'UNARCHIVE'; +UNBOUNDED: 'UNBOUNDED'; +UNCACHE: 'UNCACHE'; +UNION: 'UNION'; +UNIQUE: 'UNIQUE'; +UNKNOWN: 'UNKNOWN'; +UNLOCK: 'UNLOCK'; +UNSET: 'UNSET'; +UPDATE: 'UPDATE'; +USE: 'USE'; +USER: 'USER'; +USING: 'USING'; +VALUES: 'VALUES'; +VERSION: 'VERSION'; +VIEW: 'VIEW'; +VIEWS: 'VIEWS'; +WEEK: 'WEEK'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WINDOW: 'WINDOW'; +WITH: 'WITH'; +WITHIN: 'WITHIN'; +YEAR: 'YEAR'; +ZONE: 'ZONE'; +//--SPARK-KEYWORD-LIST-END +//============================ +// End of the keywords list +//============================ + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<=' | '!>'; +GT : '>'; +GTE : '>=' | '!<'; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +CONCAT_PIPE: '||'; +HAT: '^'; +COLON: ':'; +ARROW: '->'; +HENT_START: '/*+'; +HENT_END: '*/'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '"' ( ~('"'|'\\') | ('\\' .) )* '"' + | 'R\'' (~'\'')* '\'' + | 'R"'(~'"')* '"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? 
+ ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? ( BRACKETED_COMMENT | . )*? ('*/' | {markUnclosedComment();} EOF) -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 similarity index 69% rename from sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 rename to sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 6331798ef5db0..fb3bccacaf94b 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -14,9 +14,11 @@ * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. */ -grammar SqlBase; +parser grammar SqlBaseParser; -@parser::members { +options { tokenVocab = SqlBaseLexer; } + +@members { /** * When false, INTERSECT is given the greater precedence over the other set * operations (UNION, EXCEPT and MINUS) as per the SQL standard. @@ -35,63 +37,8 @@ grammar SqlBase; public boolean SQL_standard_keyword_behavior = false; } -@lexer::members { - /** - * When true, parser should throw ParseExcetion for unclosed bracketed comment. - */ - public boolean has_unclosed_bracketed_comment = false; - - /** - * Verify whether current token is a valid decimal token (which contains dot). - * Returns true if the character that follows the token is not a digit or letter or underscore. - * - * For example: - * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. - * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. - * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. - * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed - * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' - * which is not a digit or letter or underscore. - */ - public boolean isValidDecimal() { - int nextChar = _input.LA(1); - if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || - nextChar == '_') { - return false; - } else { - return true; - } - } - - /** - * This method will be called when we see '/*' and try to match it as a bracketed comment. - * If the next character is '+', it should be parsed as hint later, and we cannot match - * it as a bracketed comment. - * - * Returns true if the next character is '+'. 
- */ - public boolean isHint() { - int nextChar = _input.LA(1); - if (nextChar == '+') { - return true; - } else { - return false; - } - } - - /** - * This method will be called when the character stream ends and try to find out the - * unclosed bracketed comment. - * If the method be called, it means the end of the entire character stream match, - * and we set the flag and fail later. - */ - public void markUnclosedComment() { - has_unclosed_bracketed_comment = true; - } -} - singleStatement - : statement ';'* EOF + : statement SEMICOLON* EOF ; singleExpression @@ -136,7 +83,7 @@ statement (RESTRICT | CASCADE)? #dropNamespace | SHOW namespaces ((FROM | IN) multipartIdentifier)? (LIKE? pattern=STRING)? #showNamespaces - | createTableHeader ('(' colTypeList ')')? tableProvider? + | createTableHeader (LEFT_PAREN createOrReplaceTableColTypeList RIGHT_PAREN)? tableProvider? createTableClauses (AS? query)? #createTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier @@ -146,7 +93,7 @@ statement createFileFormat | locationSpec | (TBLPROPERTIES tableProps=propertyList))* #createTableLike - | replaceTableHeader ('(' colTypeList ')')? tableProvider? + | replaceTableHeader (LEFT_PAREN createOrReplaceTableColTypeList RIGHT_PAREN)? tableProvider? createTableClauses (AS? query)? #replaceTable | ANALYZE TABLE multipartIdentifier partitionSpec? COMPUTE STATISTICS @@ -158,13 +105,13 @@ statement columns=qualifiedColTypeWithPositionList #addTableColumns | ALTER TABLE multipartIdentifier ADD (COLUMN | COLUMNS) - '(' columns=qualifiedColTypeWithPositionList ')' #addTableColumns + LEFT_PAREN columns=qualifiedColTypeWithPositionList RIGHT_PAREN #addTableColumns | ALTER TABLE table=multipartIdentifier RENAME COLUMN from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn | ALTER TABLE multipartIdentifier DROP (COLUMN | COLUMNS) - '(' columns=multipartIdentifierList ')' #dropTableColumns + LEFT_PAREN columns=multipartIdentifierList RIGHT_PAREN #dropTableColumns | ALTER TABLE multipartIdentifier DROP (COLUMN | COLUMNS) columns=multipartIdentifierList #dropTableColumns | ALTER (TABLE | VIEW) from=multipartIdentifier @@ -181,7 +128,8 @@ statement colName=multipartIdentifier colType colPosition? #hiveChangeColumn | ALTER TABLE table=multipartIdentifier partitionSpec? REPLACE COLUMNS - '(' columns=qualifiedColTypeWithPositionList ')' #hiveReplaceColumns + LEFT_PAREN columns=qualifiedColTypeWithPositionList + RIGHT_PAREN #hiveReplaceColumns | ALTER TABLE multipartIdentifier (partitionSpec)? SET SERDE STRING (WITH SERDEPROPERTIES propertyList)? #setTableSerDe | ALTER TABLE multipartIdentifier (partitionSpec)? @@ -191,7 +139,7 @@ statement | ALTER TABLE multipartIdentifier from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition | ALTER (TABLE | VIEW) multipartIdentifier - DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + DROP (IF EXISTS)? partitionSpec (COMMA partitionSpec)* PURGE? #dropTablePartitions | ALTER TABLE multipartIdentifier (partitionSpec)? SET locationSpec #setTableLocation | ALTER TABLE multipartIdentifier RECOVER PARTITIONS #recoverPartitions @@ -205,12 +153,12 @@ statement (TBLPROPERTIES propertyList))* AS query #createView | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW - tableIdentifier ('(' colTypeList ')')? tableProvider + tableIdentifier (LEFT_PAREN colTypeList RIGHT_PAREN)? tableProvider (OPTIONS propertyList)? #createTempViewUsing | ALTER VIEW multipartIdentifier AS? query #alterViewQuery | CREATE (OR REPLACE)? TEMPORARY? 
FUNCTION (IF NOT EXISTS)? multipartIdentifier AS className=STRING - (USING resource (',' resource)*)? #createFunction + (USING resource (COMMA resource)*)? #createFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? multipartIdentifier #dropFunction | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? statement #explain @@ -219,7 +167,7 @@ statement | SHOW TABLE EXTENDED ((FROM | IN) ns=multipartIdentifier)? LIKE pattern=STRING partitionSpec? #showTableExtended | SHOW TBLPROPERTIES table=multipartIdentifier - ('(' key=propertyKey ')')? #showTblProperties + (LEFT_PAREN key=propertyKey RIGHT_PAREN)? #showTblProperties | SHOW COLUMNS (FROM | IN) table=multipartIdentifier ((FROM | IN) ns=multipartIdentifier)? #showColumns | SHOW VIEWS ((FROM | IN) multipartIdentifier)? @@ -264,7 +212,7 @@ statement | RESET .*? #resetConfiguration | CREATE INDEX (IF NOT EXISTS)? identifier ON TABLE? multipartIdentifier (USING indexType=identifier)? - '(' columns=multipartIdentifierPropertyList ')' + LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN (OPTIONS options=propertyList)? #createIndex | DROP INDEX (IF EXISTS)? identifier ON TABLE? multipartIdentifier #dropIndex | unsupportedHiveNativeCommands .*? #failNativeCommand @@ -369,7 +317,7 @@ partitionSpecLocation ; partitionSpec - : PARTITION '(' partitionVal (',' partitionVal)* ')' + : PARTITION LEFT_PAREN partitionVal (COMMA partitionVal)* RIGHT_PAREN ; partitionVal @@ -397,15 +345,15 @@ describeFuncName ; describeColName - : nameParts+=identifier ('.' nameParts+=identifier)* + : nameParts+=identifier (DOT nameParts+=identifier)* ; ctes - : WITH namedQuery (',' namedQuery)* + : WITH namedQuery (COMMA namedQuery)* ; namedQuery - : name=errorCapturingIdentifier (columnAliases=identifierList)? AS? '(' query ')' + : name=errorCapturingIdentifier (columnAliases=identifierList)? AS? LEFT_PAREN query RIGHT_PAREN ; tableProvider @@ -425,7 +373,7 @@ createTableClauses ; propertyList - : '(' property (',' property)* ')' + : LEFT_PAREN property (COMMA property)* RIGHT_PAREN ; property @@ -433,7 +381,7 @@ property ; propertyKey - : identifier ('.' identifier)* + : identifier (DOT identifier)* | STRING ; @@ -445,11 +393,11 @@ propertyValue ; constantList - : '(' constant (',' constant)* ')' + : LEFT_PAREN constant (COMMA constant)* RIGHT_PAREN ; nestedConstantList - : '(' constantList (',' constantList)* ')' + : LEFT_PAREN constantList (COMMA constantList)* RIGHT_PAREN ; createFileFormat @@ -477,17 +425,17 @@ dmlStatementNoWith | UPDATE multipartIdentifier tableAlias setClause whereClause? #updateTable | MERGE INTO target=multipartIdentifier targetAlias=tableAlias USING (source=multipartIdentifier | - '(' sourceQuery=query')') sourceAlias=tableAlias + LEFT_PAREN sourceQuery=query RIGHT_PAREN) sourceAlias=tableAlias ON mergeCondition=booleanExpression matchedClause* notMatchedClause* #mergeIntoTable ; queryOrganization - : (ORDER BY order+=sortItem (',' order+=sortItem)*)? - (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? - (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? - (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + : (ORDER BY order+=sortItem (COMMA order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (COMMA clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (COMMA distributeBy+=expression)*)? + (SORT BY sort+=sortItem (COMMA sort+=sortItem)*)? windowClause? (LIMIT (ALL | limit=expression))? 
; @@ -511,7 +459,7 @@ queryPrimary | fromStatement #fromStmt | TABLE multipartIdentifier #table | inlineTable #inlineTableDefault1 - | '(' query ')' #subquery + | LEFT_PAREN query RIGHT_PAREN #subquery ; sortItem @@ -553,13 +501,13 @@ querySpecification ; transformClause - : (SELECT kind=TRANSFORM '(' setQuantifier? expressionSeq ')' + : (SELECT kind=TRANSFORM LEFT_PAREN setQuantifier? expressionSeq RIGHT_PAREN | kind=MAP setQuantifier? expressionSeq | kind=REDUCE setQuantifier? expressionSeq) inRowFormat=rowFormat? (RECORDWRITER recordWriter=STRING)? USING script=STRING - (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + (AS (identifierSeq | colTypeList | (LEFT_PAREN (identifierSeq | colTypeList) RIGHT_PAREN)))? outRowFormat=rowFormat? (RECORDREADER recordReader=STRING)? ; @@ -587,12 +535,12 @@ matchedAction notMatchedAction : INSERT ASTERISK - | INSERT '(' columns=multipartIdentifierList ')' - VALUES '(' expression (',' expression)* ')' + | INSERT LEFT_PAREN columns=multipartIdentifierList RIGHT_PAREN + VALUES LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN ; assignmentList - : assignment (',' assignment)* + : assignment (COMMA assignment)* ; assignment @@ -608,16 +556,16 @@ havingClause ; hint - : '/*+' hintStatements+=hintStatement (','? hintStatements+=hintStatement)* '*/' + : HENT_START hintStatements+=hintStatement (COMMA? hintStatements+=hintStatement)* HENT_END ; hintStatement : hintName=identifier - | hintName=identifier '(' parameters+=primaryExpression (',' parameters+=primaryExpression)* ')' + | hintName=identifier LEFT_PAREN parameters+=primaryExpression (COMMA parameters+=primaryExpression)* RIGHT_PAREN ; fromClause - : FROM relation (',' relation)* lateralView* pivotClause? + : FROM relation (COMMA relation)* lateralView* pivotClause? ; temporalClause @@ -627,11 +575,11 @@ temporalClause aggregationClause : GROUP BY groupingExpressionsWithGroupingAnalytics+=groupByClause - (',' groupingExpressionsWithGroupingAnalytics+=groupByClause)* - | GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + (COMMA groupingExpressionsWithGroupingAnalytics+=groupByClause)* + | GROUP BY groupingExpressions+=expression (COMMA groupingExpressions+=expression)* ( WITH kind=ROLLUP | WITH kind=CUBE - | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + | kind=GROUPING SETS LEFT_PAREN groupingSet (COMMA groupingSet)* RIGHT_PAREN)? ; groupByClause @@ -640,8 +588,8 @@ groupByClause ; groupingAnalytics - : (ROLLUP | CUBE) '(' groupingSet (',' groupingSet)* ')' - | GROUPING SETS '(' groupingElement (',' groupingElement)* ')' + : (ROLLUP | CUBE) LEFT_PAREN groupingSet (COMMA groupingSet)* RIGHT_PAREN + | GROUPING SETS LEFT_PAREN groupingElement (COMMA groupingElement)* RIGHT_PAREN ; groupingElement @@ -650,17 +598,17 @@ groupingElement ; groupingSet - : '(' (expression (',' expression)*)? ')' + : LEFT_PAREN (expression (COMMA expression)*)? 
RIGHT_PAREN | expression ; pivotClause - : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')' + : PIVOT LEFT_PAREN aggregates=namedExpressionSeq FOR pivotColumn IN LEFT_PAREN pivotValues+=pivotValue (COMMA pivotValues+=pivotValue)* RIGHT_PAREN RIGHT_PAREN ; pivotColumn : identifiers+=identifier - | '(' identifiers+=identifier (',' identifiers+=identifier)* ')' + | LEFT_PAREN identifiers+=identifier (COMMA identifiers+=identifier)* RIGHT_PAREN ; pivotValue @@ -668,7 +616,7 @@ pivotValue ; lateralView - : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + : LATERAL VIEW (OUTER)? qualifiedName LEFT_PAREN (expression (COMMA expression)*)? RIGHT_PAREN tblName=identifier (AS? colName+=identifier (COMMA colName+=identifier)*)? ; setQuantifier @@ -701,27 +649,27 @@ joinCriteria ; sample - : TABLESAMPLE '(' sampleMethod? ')' (REPEATABLE '('seed=INTEGER_VALUE')')? + : TABLESAMPLE LEFT_PAREN sampleMethod? RIGHT_PAREN (REPEATABLE LEFT_PAREN seed=INTEGER_VALUE RIGHT_PAREN)? ; sampleMethod : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile | expression ROWS #sampleByRows | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE - (ON (identifier | qualifiedName '(' ')'))? #sampleByBucket + (ON (identifier | qualifiedName LEFT_PAREN RIGHT_PAREN))? #sampleByBucket | bytes=expression #sampleByBytes ; identifierList - : '(' identifierSeq ')' + : LEFT_PAREN identifierSeq RIGHT_PAREN ; identifierSeq - : ident+=errorCapturingIdentifier (',' ident+=errorCapturingIdentifier)* + : ident+=errorCapturingIdentifier (COMMA ident+=errorCapturingIdentifier)* ; orderedIdentifierList - : '(' orderedIdentifier (',' orderedIdentifier)* ')' + : LEFT_PAREN orderedIdentifier (COMMA orderedIdentifier)* RIGHT_PAREN ; orderedIdentifier @@ -729,7 +677,7 @@ orderedIdentifier ; identifierCommentList - : '(' identifierComment (',' identifierComment)* ')' + : LEFT_PAREN identifierComment (COMMA identifierComment)* RIGHT_PAREN ; identifierComment @@ -738,19 +686,19 @@ identifierComment relationPrimary : multipartIdentifier temporalClause? - sample? tableAlias #tableName - | '(' query ')' sample? tableAlias #aliasedQuery - | '(' relation ')' sample? tableAlias #aliasedRelation - | inlineTable #inlineTableDefault2 - | functionTable #tableValuedFunction + sample? tableAlias #tableName + | LEFT_PAREN query RIGHT_PAREN sample? tableAlias #aliasedQuery + | LEFT_PAREN relation RIGHT_PAREN sample? tableAlias #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction ; inlineTable - : VALUES expression (',' expression)* tableAlias + : VALUES expression (COMMA expression)* tableAlias ; functionTable - : funcName=functionName '(' (expression (',' expression)*)? ')' tableAlias + : funcName=functionName LEFT_PAREN (expression (COMMA expression)*)? RIGHT_PAREN tableAlias ; tableAlias @@ -768,15 +716,15 @@ rowFormat ; multipartIdentifierList - : multipartIdentifier (',' multipartIdentifier)* + : multipartIdentifier (COMMA multipartIdentifier)* ; multipartIdentifier - : parts+=errorCapturingIdentifier ('.' 
parts+=errorCapturingIdentifier)* + : parts+=errorCapturingIdentifier (DOT parts+=errorCapturingIdentifier)* ; multipartIdentifierPropertyList - : multipartIdentifierProperty (',' multipartIdentifierProperty)* + : multipartIdentifierProperty (COMMA multipartIdentifierProperty)* ; multipartIdentifierProperty @@ -784,11 +732,11 @@ multipartIdentifierProperty ; tableIdentifier - : (db=errorCapturingIdentifier '.')? table=errorCapturingIdentifier + : (db=errorCapturingIdentifier DOT)? table=errorCapturingIdentifier ; functionIdentifier - : (db=errorCapturingIdentifier '.')? function=errorCapturingIdentifier + : (db=errorCapturingIdentifier DOT)? function=errorCapturingIdentifier ; namedExpression @@ -796,11 +744,11 @@ namedExpression ; namedExpressionSeq - : namedExpression (',' namedExpression)* + : namedExpression (COMMA namedExpression)* ; partitionFieldList - : '(' fields+=partitionField (',' fields+=partitionField)* ')' + : LEFT_PAREN fields+=partitionField (COMMA fields+=partitionField)* RIGHT_PAREN ; partitionField @@ -809,9 +757,9 @@ partitionField ; transform - : qualifiedName #identityTransform + : qualifiedName #identityTransform | transformName=identifier - '(' argument+=transformArgument (',' argument+=transformArgument)* ')' #applyTransform + LEFT_PAREN argument+=transformArgument (COMMA argument+=transformArgument)* RIGHT_PAREN #applyTransform ; transformArgument @@ -824,12 +772,12 @@ expression ; expressionSeq - : expression (',' expression)* + : expression (COMMA expression)* ; booleanExpression : NOT booleanExpression #logicalNot - | EXISTS '(' query ')' #exists + | EXISTS LEFT_PAREN query RIGHT_PAREN #exists | valueExpression predicate? #predicated | left=booleanExpression operator=AND right=booleanExpression #logicalBinary | left=booleanExpression operator=OR right=booleanExpression #logicalBinary @@ -837,10 +785,10 @@ booleanExpression predicate : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression - | NOT? kind=IN '(' expression (',' expression)* ')' - | NOT? kind=IN '(' query ')' + | NOT? kind=IN LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN + | NOT? kind=IN LEFT_PAREN query RIGHT_PAREN | NOT? kind=RLIKE pattern=valueExpression - | NOT? kind=(LIKE | ILIKE) quantifier=(ANY | SOME | ALL) ('('')' | '(' expression (',' expression)* ')') + | NOT? kind=(LIKE | ILIKE) quantifier=(ANY | SOME | ALL) (LEFT_PAREN RIGHT_PAREN | LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN) | NOT? kind=(LIKE | ILIKE) pattern=valueExpression (ESCAPE escapeChar=STRING)? | IS NOT? kind=NULL | IS NOT? kind=(TRUE | FALSE | UNKNOWN) @@ -858,38 +806,46 @@ valueExpression | left=valueExpression comparisonOperator right=valueExpression #comparison ; +datetimeUnit + : YEAR | QUARTER | MONTH + | WEEK | DAY | DAYOFYEAR + | HOUR | MINUTE | SECOND | MILLISECOND | MICROSECOND + ; + primaryExpression : name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER) #currentLike + | name=(TIMESTAMPADD | DATEADD) LEFT_PAREN unit=datetimeUnit COMMA unitsAmount=valueExpression COMMA timestamp=valueExpression RIGHT_PAREN #timestampadd + | name=(TIMESTAMPDIFF | DATEDIFF) LEFT_PAREN unit=datetimeUnit COMMA startTimestamp=valueExpression COMMA endTimestamp=valueExpression RIGHT_PAREN #timestampdiff | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase - | name=(CAST | TRY_CAST) '(' expression AS dataType ')' #cast - | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? 
')' #struct - | FIRST '(' expression (IGNORE NULLS)? ')' #first - | LAST '(' expression (IGNORE NULLS)? ')' #last - | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position + | name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN #cast + | STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct + | FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #first + | LAST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #last + | POSITION LEFT_PAREN substr=valueExpression IN str=valueExpression RIGHT_PAREN #position | constant #constantDefault | ASTERISK #star - | qualifiedName '.' ASTERISK #star - | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor - | '(' query ')' #subqueryExpression - | functionName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' - (FILTER '(' WHERE where=booleanExpression ')')? + | qualifiedName DOT ASTERISK #star + | LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN #rowConstructor + | LEFT_PAREN query RIGHT_PAREN #subqueryExpression + | functionName LEFT_PAREN (setQuantifier? argument+=expression (COMMA argument+=expression)*)? RIGHT_PAREN + (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? (nullsOption=(IGNORE | RESPECT) NULLS)? ( OVER windowSpec)? #functionCall - | identifier '->' expression #lambda - | '(' identifier (',' identifier)+ ')' '->' expression #lambda - | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier ARROW expression #lambda + | LEFT_PAREN identifier (COMMA identifier)+ RIGHT_PAREN ARROW expression #lambda + | value=primaryExpression LEFT_BRACKET index=valueExpression RIGHT_BRACKET #subscript | identifier #columnReference - | base=primaryExpression '.' fieldName=identifier #dereference - | '(' expression ')' #parenthesizedExpression - | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract - | (SUBSTR | SUBSTRING) '(' str=valueExpression (FROM | ',') pos=valueExpression - ((FOR | ',') len=valueExpression)? ')' #substring - | TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)? - FROM srcStr=valueExpression ')' #trim - | OVERLAY '(' input=valueExpression PLACING replace=valueExpression - FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay - | PERCENTILE_CONT '(' percentage=valueExpression ')' - WITHIN GROUP '(' ORDER BY sortItem ')' #percentile + | base=primaryExpression DOT fieldName=identifier #dereference + | LEFT_PAREN expression RIGHT_PAREN #parenthesizedExpression + | EXTRACT LEFT_PAREN field=identifier FROM source=valueExpression RIGHT_PAREN #extract + | (SUBSTR | SUBSTRING) LEFT_PAREN str=valueExpression (FROM | COMMA) pos=valueExpression + ((FOR | COMMA) len=valueExpression)? RIGHT_PAREN #substring + | TRIM LEFT_PAREN trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)? + FROM srcStr=valueExpression RIGHT_PAREN #trim + | OVERLAY LEFT_PAREN input=valueExpression PLACING replace=valueExpression + FROM position=valueExpression (FOR length=valueExpression)? RIGHT_PAREN #overlay + | PERCENTILE_CONT LEFT_PAREN percentage=valueExpression RIGHT_PAREN + WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN #percentile ; constant @@ -946,37 +902,50 @@ colPosition ; dataType - : complex=ARRAY '<' dataType '>' #complexDataType - | complex=MAP '<' dataType ',' dataType '>' #complexDataType - | complex=STRUCT ('<' complexColTypeList? 
'>' | NEQ) #complexDataType + : complex=ARRAY LT dataType GT #complexDataType + | complex=MAP LT dataType COMMA dataType GT #complexDataType + | complex=STRUCT (LT complexColTypeList? GT | NEQ) #complexDataType | INTERVAL from=(YEAR | MONTH) (TO to=MONTH)? #yearMonthIntervalDataType | INTERVAL from=(DAY | HOUR | MINUTE | SECOND) (TO to=(HOUR | MINUTE | SECOND))? #dayTimeIntervalDataType - | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + | identifier (LEFT_PAREN INTEGER_VALUE + (COMMA INTEGER_VALUE)* RIGHT_PAREN)? #primitiveDataType ; qualifiedColTypeWithPositionList - : qualifiedColTypeWithPosition (',' qualifiedColTypeWithPosition)* + : qualifiedColTypeWithPosition (COMMA qualifiedColTypeWithPosition)* ; qualifiedColTypeWithPosition - : name=multipartIdentifier dataType (NOT NULL)? commentSpec? colPosition? + : name=multipartIdentifier dataType (NOT NULL)? defaultExpression? commentSpec? colPosition? + ; + +defaultExpression + : DEFAULT expression ; colTypeList - : colType (',' colType)* + : colType (COMMA colType)* ; colType : colName=errorCapturingIdentifier dataType (NOT NULL)? commentSpec? ; +createOrReplaceTableColTypeList + : createOrReplaceTableColType (COMMA createOrReplaceTableColType)* + ; + +createOrReplaceTableColType + : colName=errorCapturingIdentifier dataType (NOT NULL)? defaultExpression? commentSpec? + ; + complexColTypeList - : complexColType (',' complexColType)* + : complexColType (COMMA complexColType)* ; complexColType - : identifier ':'? dataType (NOT NULL)? commentSpec? + : identifier COLON? dataType (NOT NULL)? commentSpec? ; whenClause @@ -984,7 +953,7 @@ whenClause ; windowClause - : WINDOW namedWindow (',' namedWindow)* + : WINDOW namedWindow (COMMA namedWindow)* ; namedWindow @@ -992,14 +961,14 @@ namedWindow ; windowSpec - : name=errorCapturingIdentifier #windowRef - | '('name=errorCapturingIdentifier')' #windowRef - | '(' - ( CLUSTER BY partition+=expression (',' partition+=expression)* - | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? - ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + : name=errorCapturingIdentifier #windowRef + | LEFT_PAREN name=errorCapturingIdentifier RIGHT_PAREN #windowRef + | LEFT_PAREN + ( CLUSTER BY partition+=expression (COMMA partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (COMMA partition+=expression)*)? + ((ORDER | SORT) BY sortItem (COMMA sortItem)*)?) windowFrame? - ')' #windowDef + RIGHT_PAREN #windowDef ; windowFrame @@ -1016,7 +985,7 @@ frameBound ; qualifiedNameList - : qualifiedName (',' qualifiedName)* + : qualifiedName (COMMA qualifiedName)* ; functionName @@ -1027,7 +996,7 @@ functionName ; qualifiedName - : identifier ('.' 
identifier)* + : identifier (DOT identifier)* ; // this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table` @@ -1077,6 +1046,8 @@ alterColumnAction | commentSpec | colPosition | setOrDrop=(SET | DROP) NOT NULL + | SET defaultExpression + | dropDefault=DROP DEFAULT ; @@ -1129,8 +1100,12 @@ ansiNonReserved | DATA | DATABASE | DATABASES + | DATEADD + | DATEDIFF | DAY + | DAYOFYEAR | DBPROPERTIES + | DEFAULT | DEFINED | DELETE | DELIMITED @@ -1189,6 +1164,8 @@ ansiNonReserved | MAP | MATCHED | MERGE + | MICROSECOND + | MILLISECOND | MINUTE | MONTH | MSCK @@ -1215,6 +1192,7 @@ ansiNonReserved | PRINCIPALS | PROPERTIES | PURGE + | QUARTER | QUERY | RANGE | RECORDREADER @@ -1267,6 +1245,8 @@ ansiNonReserved | TEMPORARY | TERMINATED | TIMESTAMP + | TIMESTAMPADD + | TIMESTAMPDIFF | TOUCH | TRANSACTION | TRANSACTIONS @@ -1287,6 +1267,7 @@ ansiNonReserved | VERSION | VIEW | VIEWS + | WEEK | WINDOW | YEAR | ZONE @@ -1375,8 +1356,12 @@ nonReserved | DATA | DATABASE | DATABASES + | DATEADD + | DATEDIFF | DAY + | DAYOFYEAR | DBPROPERTIES + | DEFAULT | DEFINED | DELETE | DELIMITED @@ -1452,6 +1437,8 @@ nonReserved | MAP | MATCHED | MERGE + | MICROSECOND + | MILLISECOND | MINUTE | MONTH | MSCK @@ -1487,6 +1474,7 @@ nonReserved | PRINCIPALS | PROPERTIES | PURGE + | QUARTER | QUERY | RANGE | RECORDREADER @@ -1544,6 +1532,8 @@ nonReserved | THEN | TIME | TIMESTAMP + | TIMESTAMPADD + | TIMESTAMPDIFF | TO | TOUCH | TRAILING @@ -1569,6 +1559,7 @@ nonReserved | VERSION | VIEW | VIEWS + | WEEK | WHEN | WHERE | WINDOW @@ -1578,395 +1569,3 @@ nonReserved | ZONE //--DEFAULT-NON-RESERVED-END ; - -// NOTE: If you add a new token in the list below, you should update the list of keywords -// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`. 
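The rename of SqlBase.g4 to SqlBaseParser.g4, together with the new SqlBaseLexer.g4, splits the combined grammar into a separate lexer and parser joined through `options { tokenVocab = SqlBaseLexer; }`, with the keyword and token definitions removed from the parser file as shown in this hunk. The generated classes are still wired together in the usual ANTLR way; the sketch below is a minimal driver, assuming the generated org.apache.spark.sql.catalyst.parser classes and the ANTLR 4 runtime are on the classpath. It is not Spark's real entry point, which additionally installs a case-insensitive char stream, custom error listeners, and post-processing for quoted identifiers.

import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.apache.spark.sql.catalyst.parser.SqlBaseLexer;
import org.apache.spark.sql.catalyst.parser.SqlBaseParser;

// Minimal driver for the split grammar: lex with SqlBaseLexer, parse with SqlBaseParser.
public class SqlBaseParseSketch {
  public static void main(String[] args) {
    // Upper-case input: the lexer's LETTER fragment only covers [A-Z], and Spark
    // normally compensates with its own upper-casing char-stream wrapper.
    SqlBaseLexer lexer = new SqlBaseLexer(CharStreams.fromString("SELECT 1"));
    CommonTokenStream tokens = new CommonTokenStream(lexer);
    SqlBaseParser parser = new SqlBaseParser(tokens);

    // singleStatement is the same top-level rule the combined grammar exposed.
    SqlBaseParser.SingleStatementContext tree = parser.singleStatement();
    System.out.println(tree.toStringTree(parser));
  }
}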
- -//============================ -// Start of the keywords list -//============================ -//--SPARK-KEYWORD-LIST-START -ADD: 'ADD'; -AFTER: 'AFTER'; -ALL: 'ALL'; -ALTER: 'ALTER'; -ANALYZE: 'ANALYZE'; -AND: 'AND'; -ANTI: 'ANTI'; -ANY: 'ANY'; -ARCHIVE: 'ARCHIVE'; -ARRAY: 'ARRAY'; -AS: 'AS'; -ASC: 'ASC'; -AT: 'AT'; -AUTHORIZATION: 'AUTHORIZATION'; -BETWEEN: 'BETWEEN'; -BOTH: 'BOTH'; -BUCKET: 'BUCKET'; -BUCKETS: 'BUCKETS'; -BY: 'BY'; -CACHE: 'CACHE'; -CASCADE: 'CASCADE'; -CASE: 'CASE'; -CAST: 'CAST'; -CATALOG: 'CATALOG'; -CATALOGS: 'CATALOGS'; -CHANGE: 'CHANGE'; -CHECK: 'CHECK'; -CLEAR: 'CLEAR'; -CLUSTER: 'CLUSTER'; -CLUSTERED: 'CLUSTERED'; -CODEGEN: 'CODEGEN'; -COLLATE: 'COLLATE'; -COLLECTION: 'COLLECTION'; -COLUMN: 'COLUMN'; -COLUMNS: 'COLUMNS'; -COMMENT: 'COMMENT'; -COMMIT: 'COMMIT'; -COMPACT: 'COMPACT'; -COMPACTIONS: 'COMPACTIONS'; -COMPUTE: 'COMPUTE'; -CONCATENATE: 'CONCATENATE'; -CONSTRAINT: 'CONSTRAINT'; -COST: 'COST'; -CREATE: 'CREATE'; -CROSS: 'CROSS'; -CUBE: 'CUBE'; -CURRENT: 'CURRENT'; -CURRENT_DATE: 'CURRENT_DATE'; -CURRENT_TIME: 'CURRENT_TIME'; -CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; -CURRENT_USER: 'CURRENT_USER'; -DAY: 'DAY'; -DATA: 'DATA'; -DATABASE: 'DATABASE'; -DATABASES: 'DATABASES'; -DBPROPERTIES: 'DBPROPERTIES'; -DEFINED: 'DEFINED'; -DELETE: 'DELETE'; -DELIMITED: 'DELIMITED'; -DESC: 'DESC'; -DESCRIBE: 'DESCRIBE'; -DFS: 'DFS'; -DIRECTORIES: 'DIRECTORIES'; -DIRECTORY: 'DIRECTORY'; -DISTINCT: 'DISTINCT'; -DISTRIBUTE: 'DISTRIBUTE'; -DIV: 'DIV'; -DROP: 'DROP'; -ELSE: 'ELSE'; -END: 'END'; -ESCAPE: 'ESCAPE'; -ESCAPED: 'ESCAPED'; -EXCEPT: 'EXCEPT'; -EXCHANGE: 'EXCHANGE'; -EXISTS: 'EXISTS'; -EXPLAIN: 'EXPLAIN'; -EXPORT: 'EXPORT'; -EXTENDED: 'EXTENDED'; -EXTERNAL: 'EXTERNAL'; -EXTRACT: 'EXTRACT'; -FALSE: 'FALSE'; -FETCH: 'FETCH'; -FIELDS: 'FIELDS'; -FILTER: 'FILTER'; -FILEFORMAT: 'FILEFORMAT'; -FIRST: 'FIRST'; -FOLLOWING: 'FOLLOWING'; -FOR: 'FOR'; -FOREIGN: 'FOREIGN'; -FORMAT: 'FORMAT'; -FORMATTED: 'FORMATTED'; -FROM: 'FROM'; -FULL: 'FULL'; -FUNCTION: 'FUNCTION'; -FUNCTIONS: 'FUNCTIONS'; -GLOBAL: 'GLOBAL'; -GRANT: 'GRANT'; -GROUP: 'GROUP'; -GROUPING: 'GROUPING'; -HAVING: 'HAVING'; -HOUR: 'HOUR'; -IF: 'IF'; -IGNORE: 'IGNORE'; -IMPORT: 'IMPORT'; -IN: 'IN'; -INDEX: 'INDEX'; -INDEXES: 'INDEXES'; -INNER: 'INNER'; -INPATH: 'INPATH'; -INPUTFORMAT: 'INPUTFORMAT'; -INSERT: 'INSERT'; -INTERSECT: 'INTERSECT'; -INTERVAL: 'INTERVAL'; -INTO: 'INTO'; -IS: 'IS'; -ITEMS: 'ITEMS'; -JOIN: 'JOIN'; -KEYS: 'KEYS'; -LAST: 'LAST'; -LATERAL: 'LATERAL'; -LAZY: 'LAZY'; -LEADING: 'LEADING'; -LEFT: 'LEFT'; -LIKE: 'LIKE'; -ILIKE: 'ILIKE'; -LIMIT: 'LIMIT'; -LINES: 'LINES'; -LIST: 'LIST'; -LOAD: 'LOAD'; -LOCAL: 'LOCAL'; -LOCATION: 'LOCATION'; -LOCK: 'LOCK'; -LOCKS: 'LOCKS'; -LOGICAL: 'LOGICAL'; -MACRO: 'MACRO'; -MAP: 'MAP'; -MATCHED: 'MATCHED'; -MERGE: 'MERGE'; -MINUTE: 'MINUTE'; -MONTH: 'MONTH'; -MSCK: 'MSCK'; -NAMESPACE: 'NAMESPACE'; -NAMESPACES: 'NAMESPACES'; -NATURAL: 'NATURAL'; -NO: 'NO'; -NOT: 'NOT' | '!'; -NULL: 'NULL'; -NULLS: 'NULLS'; -OF: 'OF'; -ON: 'ON'; -ONLY: 'ONLY'; -OPTION: 'OPTION'; -OPTIONS: 'OPTIONS'; -OR: 'OR'; -ORDER: 'ORDER'; -OUT: 'OUT'; -OUTER: 'OUTER'; -OUTPUTFORMAT: 'OUTPUTFORMAT'; -OVER: 'OVER'; -OVERLAPS: 'OVERLAPS'; -OVERLAY: 'OVERLAY'; -OVERWRITE: 'OVERWRITE'; -PARTITION: 'PARTITION'; -PARTITIONED: 'PARTITIONED'; -PARTITIONS: 'PARTITIONS'; -PERCENTILE_CONT: 'PERCENTILE_CONT'; -PERCENTLIT: 'PERCENT'; -PIVOT: 'PIVOT'; -PLACING: 'PLACING'; -POSITION: 'POSITION'; -PRECEDING: 'PRECEDING'; -PRIMARY: 'PRIMARY'; -PRINCIPALS: 'PRINCIPALS'; -PROPERTIES: 'PROPERTIES'; -PURGE: 'PURGE'; 
-QUERY: 'QUERY'; -RANGE: 'RANGE'; -RECORDREADER: 'RECORDREADER'; -RECORDWRITER: 'RECORDWRITER'; -RECOVER: 'RECOVER'; -REDUCE: 'REDUCE'; -REFERENCES: 'REFERENCES'; -REFRESH: 'REFRESH'; -RENAME: 'RENAME'; -REPAIR: 'REPAIR'; -REPEATABLE: 'REPEATABLE'; -REPLACE: 'REPLACE'; -RESET: 'RESET'; -RESPECT: 'RESPECT'; -RESTRICT: 'RESTRICT'; -REVOKE: 'REVOKE'; -RIGHT: 'RIGHT'; -RLIKE: 'RLIKE' | 'REGEXP'; -ROLE: 'ROLE'; -ROLES: 'ROLES'; -ROLLBACK: 'ROLLBACK'; -ROLLUP: 'ROLLUP'; -ROW: 'ROW'; -ROWS: 'ROWS'; -SECOND: 'SECOND'; -SCHEMA: 'SCHEMA'; -SCHEMAS: 'SCHEMAS'; -SELECT: 'SELECT'; -SEMI: 'SEMI'; -SEPARATED: 'SEPARATED'; -SERDE: 'SERDE'; -SERDEPROPERTIES: 'SERDEPROPERTIES'; -SESSION_USER: 'SESSION_USER'; -SET: 'SET'; -SETMINUS: 'MINUS'; -SETS: 'SETS'; -SHOW: 'SHOW'; -SKEWED: 'SKEWED'; -SOME: 'SOME'; -SORT: 'SORT'; -SORTED: 'SORTED'; -START: 'START'; -STATISTICS: 'STATISTICS'; -STORED: 'STORED'; -STRATIFY: 'STRATIFY'; -STRUCT: 'STRUCT'; -SUBSTR: 'SUBSTR'; -SUBSTRING: 'SUBSTRING'; -SYNC: 'SYNC'; -SYSTEM_TIME: 'SYSTEM_TIME'; -SYSTEM_VERSION: 'SYSTEM_VERSION'; -TABLE: 'TABLE'; -TABLES: 'TABLES'; -TABLESAMPLE: 'TABLESAMPLE'; -TBLPROPERTIES: 'TBLPROPERTIES'; -TEMPORARY: 'TEMPORARY' | 'TEMP'; -TERMINATED: 'TERMINATED'; -THEN: 'THEN'; -TIME: 'TIME'; -TIMESTAMP: 'TIMESTAMP'; -TO: 'TO'; -TOUCH: 'TOUCH'; -TRAILING: 'TRAILING'; -TRANSACTION: 'TRANSACTION'; -TRANSACTIONS: 'TRANSACTIONS'; -TRANSFORM: 'TRANSFORM'; -TRIM: 'TRIM'; -TRUE: 'TRUE'; -TRUNCATE: 'TRUNCATE'; -TRY_CAST: 'TRY_CAST'; -TYPE: 'TYPE'; -UNARCHIVE: 'UNARCHIVE'; -UNBOUNDED: 'UNBOUNDED'; -UNCACHE: 'UNCACHE'; -UNION: 'UNION'; -UNIQUE: 'UNIQUE'; -UNKNOWN: 'UNKNOWN'; -UNLOCK: 'UNLOCK'; -UNSET: 'UNSET'; -UPDATE: 'UPDATE'; -USE: 'USE'; -USER: 'USER'; -USING: 'USING'; -VALUES: 'VALUES'; -VERSION: 'VERSION'; -VIEW: 'VIEW'; -VIEWS: 'VIEWS'; -WHEN: 'WHEN'; -WHERE: 'WHERE'; -WINDOW: 'WINDOW'; -WITH: 'WITH'; -WITHIN: 'WITHIN'; -YEAR: 'YEAR'; -ZONE: 'ZONE'; -//--SPARK-KEYWORD-LIST-END -//============================ -// End of the keywords list -//============================ - -EQ : '=' | '=='; -NSEQ: '<=>'; -NEQ : '<>'; -NEQJ: '!='; -LT : '<'; -LTE : '<=' | '!>'; -GT : '>'; -GTE : '>=' | '!<'; - -PLUS: '+'; -MINUS: '-'; -ASTERISK: '*'; -SLASH: '/'; -PERCENT: '%'; -TILDE: '~'; -AMPERSAND: '&'; -PIPE: '|'; -CONCAT_PIPE: '||'; -HAT: '^'; - -STRING - : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' - | '"' ( ~('"'|'\\') | ('\\' .) )* '"' - | 'R\'' (~'\'')* '\'' - | 'R"'(~'"')* '"' - ; - -BIGINT_LITERAL - : DIGIT+ 'L' - ; - -SMALLINT_LITERAL - : DIGIT+ 'S' - ; - -TINYINT_LITERAL - : DIGIT+ 'Y' - ; - -INTEGER_VALUE - : DIGIT+ - ; - -EXPONENT_VALUE - : DIGIT+ EXPONENT - | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? - ; - -DECIMAL_VALUE - : DECIMAL_DIGITS {isValidDecimal()}? - ; - -FLOAT_LITERAL - : DIGIT+ EXPONENT? 'F' - | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? - ; - -DOUBLE_LITERAL - : DIGIT+ EXPONENT? 'D' - | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? - ; - -BIGDECIMAL_LITERAL - : DIGIT+ EXPONENT? 'BD' - | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? - ; - -IDENTIFIER - : (LETTER | DIGIT | '_')+ - ; - -BACKQUOTED_IDENTIFIER - : '`' ( ~'`' | '``' )* '`' - ; - -fragment DECIMAL_DIGITS - : DIGIT+ '.' DIGIT* - | '.' DIGIT+ - ; - -fragment EXPONENT - : 'E' [+-]? DIGIT+ - ; - -fragment DIGIT - : [0-9] - ; - -fragment LETTER - : [A-Z] - ; - -SIMPLE_COMMENT - : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) - ; - -BRACKETED_COMMENT - : '/*' {!isHint()}? ( BRACKETED_COMMENT | . )*? 
('*/' | {markUnclosedComment();} EOF) -> channel(HIDDEN) - ; - -WS - : [ \r\n\t]+ -> channel(HIDDEN) - ; - -// Catch-all for anything we can't recognize. -// We use this to be able to ignore and recover all the text -// when splitting statements with DelimiterLexer -UNRECOGNIZED - : . - ; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5088d06de9b32..476201c9a8d8e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -80,7 +80,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { static { mutableFieldTypes = Collections.unmodifiableSet( new HashSet<>( - Arrays.asList(new DataType[] { + Arrays.asList( NullType, BooleanType, ByteType, @@ -90,8 +90,9 @@ public static int calculateBitSetWidthInBytes(int numFields) { FloatType, DoubleType, DateType, - TimestampType - }))); + TimestampType, + TimestampNTZType + ))); } public static boolean isFixedLength(DataType dt) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java index 48a859a4159fb..865ac553199aa 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java @@ -155,8 +155,10 @@ public void alterNamespace( } @Override - public boolean dropNamespace(String[] namespace) throws NoSuchNamespaceException { - return asNamespaceCatalog().dropNamespace(namespace); + public boolean dropNamespace( + String[] namespace, + boolean cascade) throws NoSuchNamespaceException, NonEmptyNamespaceException { + return asNamespaceCatalog().dropNamespace(namespace, cascade); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java index f70746b612e92..c1a4960068d24 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java @@ -20,6 +20,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException; import java.util.Map; @@ -136,15 +137,20 @@ void alterNamespace( NamespaceChange... changes) throws NoSuchNamespaceException; /** - * Drop a namespace from the catalog, recursively dropping all objects within the namespace. + * Drop a namespace from the catalog with cascade mode, recursively dropping all objects + * within the namespace if cascade is true. *

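For context, here is a minimal caller-side sketch of the new two-argument contract. It is not part of this patch; the class name DropNamespaceExample and the retry policy are illustrative only, and the only catalog APIs used are the ones shown in this file.

    import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
    import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException;
    import org.apache.spark.sql.connector.catalog.SupportsNamespaces;

    public final class DropNamespaceExample {
      // A plain drop may now be rejected when the namespace still contains objects;
      // the caller decides whether to escalate to a recursive (cascade) drop.
      public static boolean dropOrCascade(SupportsNamespaces catalog, String[] namespace)
          throws NoSuchNamespaceException, NonEmptyNamespaceException {
        try {
          return catalog.dropNamespace(namespace, false);
        } catch (NonEmptyNamespaceException nonEmpty) {
          // Only drop everything under the namespace when the plain drop is refused.
          return catalog.dropNamespace(namespace, true);
        }
      }
    }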
* If the catalog implementation does not support this operation, it may throw * {@link UnsupportedOperationException}. * * @param namespace a multi-part namespace + * @param cascade When true, deletes all objects under the namespace * @return true if the namespace was dropped * @throws NoSuchNamespaceException If the namespace does not exist (optional) + * @throws NonEmptyNamespaceException If the namespace is non-empty and cascade is false * @throws UnsupportedOperationException If drop is not a supported operation */ - boolean dropNamespace(String[] namespace) throws NoSuchNamespaceException; + boolean dropNamespace( + String[] namespace, + boolean cascade) throws NoSuchNamespaceException, NonEmptyNamespaceException; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java new file mode 100644 index 0000000000000..b3dd2cbfe3d7d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import java.io.Serializable; +import java.util.Arrays; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder; + +// scalastyle:off line.size.limit +/** + * The general representation of SQL scalar expressions, which contains the upper-cased + * expression name and all the children expressions. + *

+ * The currently supported SQL scalar expressions: + *

+ *   1. Name: IS_NULL; SQL semantic: expr IS NULL; Since version: 3.3.0
+ *   2. Name: IS_NOT_NULL; SQL semantic: expr IS NOT NULL; Since version: 3.3.0
+ *   3. Name: =; SQL semantic: expr1 = expr2; Since version: 3.3.0
+ *   4. Name: !=; SQL semantic: expr1 != expr2; Since version: 3.3.0
+ *   5. Name: <>; SQL semantic: expr1 <> expr2; Since version: 3.3.0
+ *   6. Name: <=>; SQL semantic: expr1 <=> expr2; Since version: 3.3.0
+ *   7. Name: <; SQL semantic: expr1 < expr2; Since version: 3.3.0
+ *   8. Name: <=; SQL semantic: expr1 <= expr2; Since version: 3.3.0
+ *   9. Name: >; SQL semantic: expr1 > expr2; Since version: 3.3.0
+ *  10. Name: >=; SQL semantic: expr1 >= expr2; Since version: 3.3.0
+ *  11. Name: +; SQL semantic: expr1 + expr2; Since version: 3.3.0
+ *  12. Name: -; SQL semantic: expr1 - expr2 or - expr; Since version: 3.3.0
+ *  13. Name: *; SQL semantic: expr1 * expr2; Since version: 3.3.0
+ *  14. Name: /; SQL semantic: expr1 / expr2; Since version: 3.3.0
+ *  15. Name: %; SQL semantic: expr1 % expr2; Since version: 3.3.0
+ *  16. Name: &; SQL semantic: expr1 & expr2; Since version: 3.3.0
+ *  17. Name: |; SQL semantic: expr1 | expr2; Since version: 3.3.0
+ *  18. Name: ^; SQL semantic: expr1 ^ expr2; Since version: 3.3.0
+ *  19. Name: AND; SQL semantic: expr1 AND expr2; Since version: 3.3.0
+ *  20. Name: OR; SQL semantic: expr1 OR expr2; Since version: 3.3.0
+ *  21. Name: NOT; SQL semantic: NOT expr; Since version: 3.3.0
+ *  22. Name: ~; SQL semantic: ~ expr; Since version: 3.3.0
+ *  23. Name: CASE_WHEN; SQL semantic: CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END; Since version: 3.3.0
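To make the list above concrete, here is a small standalone sketch, not part of this patch, that builds one of the listed comparisons. The class name ScalarExpressionExample and the column name price are illustrative, and the printed form is an assumption based on the V2ExpressionSQLBuilder added later in this patch.

    import org.apache.spark.sql.connector.expressions.Expression;
    import org.apache.spark.sql.connector.expressions.Expressions;
    import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
    import org.apache.spark.sql.connector.expressions.LiteralValue;
    import org.apache.spark.sql.types.DataTypes;

    public class ScalarExpressionExample {
      public static void main(String[] args) {
        // price > 10, expressed with the new general scalar expression.
        Expression column = Expressions.column("price");
        Expression ten = new LiteralValue<Integer>(10, DataTypes.IntegerType);
        GeneralScalarExpression greaterThan =
            new GeneralScalarExpression(">", new Expression[] { column, ten });
        // toString() delegates to V2ExpressionSQLBuilder and should print
        // something like: (price) > (10)
        System.out.println(greaterThan);
      }
    }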
+ * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off, + * including: add, subtract, multiply, divide, remainder, pmod. + * + * @since 3.3.0 + */ +// scalastyle:on line.size.limit +@Evolving +public class GeneralScalarExpression implements Expression, Serializable { + private String name; + private Expression[] children; + + public GeneralScalarExpression(String name, Expression[] children) { + this.name = name; + this.children = children; + } + + public String name() { return name; } + public Expression[] children() { return children; } + + @Override + public String toString() { + V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder(); + try { + return builder.build(this); + } catch (Throwable e) { + return name + "(" + + Arrays.stream(children).map(child -> child.toString()).reduce((a,b) -> a + "," + b) + ")"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java new file mode 100644 index 0000000000000..cc9d27ab8e59c --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; + +/** + * An aggregate function that returns the mean of all the values in a group. 
+ * + * @since 3.3.0 + */ +@Evolving +public final class Avg implements AggregateFunc { + private final Expression input; + private final boolean isDistinct; + + public Avg(Expression column, boolean isDistinct) { + this.input = column; + this.isDistinct = isDistinct; + } + + public Expression column() { return input; } + public boolean isDistinct() { return isDistinct; } + + @Override + public String toString() { + if (isDistinct) { + return "AVG(DISTINCT " + input.describe() + ")"; + } else { + return "AVG(" + input.describe() + ")"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java index 1685770604a46..54c64b83c5d52 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the number of the specific row in a group. @@ -27,23 +27,23 @@ */ @Evolving public final class Count implements AggregateFunc { - private final NamedReference column; + private final Expression input; private final boolean isDistinct; - public Count(NamedReference column, boolean isDistinct) { - this.column = column; + public Count(Expression column, boolean isDistinct) { + this.input = column; this.isDistinct = isDistinct; } - public NamedReference column() { return column; } + public Expression column() { return input; } public boolean isDistinct() { return isDistinct; } @Override public String toString() { if (isDistinct) { - return "COUNT(DISTINCT " + column.describe() + ")"; + return "COUNT(DISTINCT " + input.describe() + ")"; } else { - return "COUNT(" + column.describe() + ")"; + return "COUNT(" + input.describe() + ")"; } } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java index 32615e201643b..0ff26c8875b7a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -31,7 +31,6 @@ *

* The currently supported SQL aggregate functions: *

- *   1. AVG(input1); Since 3.3.0
 *   2. VAR_POP(input1); Since 3.3.0
 *   3. VAR_SAMP(input1); Since 3.3.0
 *   4. STDDEV_POP(input1); Since 3.3.0
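As a quick illustration of the new dedicated Avg class and the switch from NamedReference to Expression inputs, a hypothetical snippet (the class and column names are illustrative):

    import org.apache.spark.sql.connector.expressions.Expressions;
    import org.apache.spark.sql.connector.expressions.aggregate.Avg;

    public class AvgPushDownExample {
      public static void main(String[] args) {
        // AVG(DISTINCT price) as a V2 aggregate; any V2 Expression is accepted as the
        // input now, not only a plain column reference.
        Avg avg = new Avg(Expressions.column("price"), true);
        System.out.println(avg); // expected to print something like: AVG(DISTINCT price)
      }
    }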
  8. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java index 5acdf14bf7e2f..971aac279e09b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the maximum value in a group. @@ -27,12 +27,12 @@ */ @Evolving public final class Max implements AggregateFunc { - private final NamedReference column; + private final Expression input; - public Max(NamedReference column) { this.column = column; } + public Max(Expression column) { this.input = column; } - public NamedReference column() { return column; } + public Expression column() { return input; } @Override - public String toString() { return "MAX(" + column.describe() + ")"; } + public String toString() { return "MAX(" + input.describe() + ")"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java index 824c607ea7df0..8d0644b0f0103 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the minimum value in a group. @@ -27,12 +27,12 @@ */ @Evolving public final class Min implements AggregateFunc { - private final NamedReference column; + private final Expression input; - public Min(NamedReference column) { this.column = column; } + public Min(Expression column) { this.input = column; } - public NamedReference column() { return column; } + public Expression column() { return input; } @Override - public String toString() { return "MIN(" + column.describe() + ")"; } + public String toString() { return "MIN(" + input.describe() + ")"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java index 6b04dc38c2846..721ef31c9a817 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the summation of all the values in a group. 
@@ -27,23 +27,23 @@ */ @Evolving public final class Sum implements AggregateFunc { - private final NamedReference column; + private final Expression input; private final boolean isDistinct; - public Sum(NamedReference column, boolean isDistinct) { - this.column = column; + public Sum(Expression column, boolean isDistinct) { + this.input = column; this.isDistinct = isDistinct; } - public NamedReference column() { return column; } + public Expression column() { return input; } public boolean isDistinct() { return isDistinct; } @Override public String toString() { if (isDistinct) { - return "SUM(DISTINCT " + column.describe() + ")"; + return "SUM(DISTINCT " + input.describe() + ")"; } else { - return "SUM(" + column.describe() + ")"; + return "SUM(" + input.describe() + ")"; } } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/AcceptsLatestSeenOffset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/AcceptsLatestSeenOffset.java new file mode 100644 index 0000000000000..e8515c063cffd --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/AcceptsLatestSeenOffset.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read.streaming; + +/** + * Indicates that the source accepts the latest seen offset, which requires streaming execution + * to provide the latest seen offset when restarting the streaming query from checkpoint. + * + * Note that this interface aims to only support DSv2 streaming sources. Spark may throw error + * if the interface is implemented along with DSv1 streaming sources. + * + * The callback method will be called once per run. + */ +public interface AcceptsLatestSeenOffset extends SparkDataStream { + /** + * Callback method to receive the latest seen offset information from streaming execution. + * The method will be called only when the streaming query is restarted from checkpoint. + * + * @param offset The offset which was latest seen in the previous run. + */ + void setLatestSeenOffset(Offset offset); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java new file mode 100644 index 0000000000000..0af0d88b0f622 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.util; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.GeneralScalarExpression; +import org.apache.spark.sql.connector.expressions.LiteralValue; + +/** + * The builder to generate SQL from V2 expressions. + */ +public class V2ExpressionSQLBuilder { + public String build(Expression expr) { + if (expr instanceof LiteralValue) { + return visitLiteral((LiteralValue) expr); + } else if (expr instanceof FieldReference) { + return visitFieldReference((FieldReference) expr); + } else if (expr instanceof GeneralScalarExpression) { + GeneralScalarExpression e = (GeneralScalarExpression) expr; + String name = e.name(); + switch (name) { + case "IS_NULL": + return visitIsNull(build(e.children()[0])); + case "IS_NOT_NULL": + return visitIsNotNull(build(e.children()[0])); + case "=": + case "!=": + case "<=>": + case "<": + case "<=": + case ">": + case ">=": + return visitBinaryComparison(name, build(e.children()[0]), build(e.children()[1])); + case "+": + case "*": + case "/": + case "%": + case "&": + case "|": + case "^": + return visitBinaryArithmetic(name, build(e.children()[0]), build(e.children()[1])); + case "-": + if (e.children().length == 1) { + return visitUnaryArithmetic(name, build(e.children()[0])); + } else { + return visitBinaryArithmetic(name, build(e.children()[0]), build(e.children()[1])); + } + case "AND": + return visitAnd(name, build(e.children()[0]), build(e.children()[1])); + case "OR": + return visitOr(name, build(e.children()[0]), build(e.children()[1])); + case "NOT": + return visitNot(build(e.children()[0])); + case "~": + return visitUnaryArithmetic(name, build(e.children()[0])); + case "CASE_WHEN": + List children = new ArrayList<>(); + for (Expression child : e.children()) { + children.add(build(child)); + } + return visitCaseWhen(children.toArray(new String[e.children().length])); + // TODO supports other expressions + default: + return visitUnexpectedExpr(expr); + } + } else { + return visitUnexpectedExpr(expr); + } + } + + protected String visitLiteral(LiteralValue literalValue) { + return literalValue.toString(); + } + + protected String visitFieldReference(FieldReference fieldRef) { + return fieldRef.toString(); + } + + protected String visitIsNull(String v) { + return v + " IS NULL"; + } + + protected String visitIsNotNull(String v) { + return v + " IS NOT NULL"; + } + + protected String visitBinaryComparison(String name, String l, String r) { + return "(" + l + ") " + name + " (" + r + ")"; + } + + protected String visitBinaryArithmetic(String name, String l, String r) { + return "(" + l + ") " + name + " (" + r + ")"; + } + + protected String visitAnd(String name, String l, String r) { + return "(" + l + ") " + name + " (" + r + ")"; + } + + protected 
String visitOr(String name, String l, String r) { + return "(" + l + ") " + name + " (" + r + ")"; + } + + protected String visitNot(String v) { + return "NOT (" + v + ")"; + } + + protected String visitUnaryArithmetic(String name, String v) { return name +" (" + v + ")"; } + + protected String visitCaseWhen(String[] children) { + StringBuilder sb = new StringBuilder("CASE"); + for (int i = 0; i < children.length; i += 2) { + String c = children[i]; + int j = i + 1; + if (j < children.length) { + String v = children[j]; + sb.append(" WHEN "); + sb.append(c); + sb.append(" THEN "); + sb.append(v); + } else { + sb.append(" ELSE "); + sb.append(c); + } + } + sb.append(" END"); + return sb.toString(); + } + + protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException { + throw new IllegalArgumentException("Unexpected V2 expression: " + expr); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java index 2284086f99f6e..983e6b0fffb20 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java @@ -35,6 +35,11 @@ public interface RequiresDistributionAndOrdering extends Write { * Spark will distribute incoming records across partitions to satisfy the required distribution * before passing the records to the data source table on write. *

    + * Batch and micro-batch writes can request a particular data distribution. + * If a distribution is requested in the micro-batch context, incoming records in each micro batch + * will satisfy the required distribution (but not across micro batches). The continuous execution + * mode continuously processes streaming data and does not support distribution requirements. + *

    * Implementations may return {@link UnspecifiedDistribution} if they don't require any specific * distribution of data on write. * @@ -61,6 +66,11 @@ public interface RequiresDistributionAndOrdering extends Write { * Spark will order incoming records within partitions to satisfy the required ordering * before passing those records to the data source table on write. *

    + * Batch and micro-batch writes can request a particular data ordering. + * If an ordering is requested in the micro-batch context, incoming records in each micro batch + * will satisfy the required ordering (but not across micro batches). The continuous execution + * mode continuously processes streaming data and does not support ordering requirements. + *
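A minimal sketch of a Write that uses both hooks described above; the class and column names are illustrative, and it assumes the existing Distributions and Expressions factory methods of the connector API.

    import org.apache.spark.sql.connector.distributions.Distribution;
    import org.apache.spark.sql.connector.distributions.Distributions;
    import org.apache.spark.sql.connector.expressions.Expression;
    import org.apache.spark.sql.connector.expressions.Expressions;
    import org.apache.spark.sql.connector.expressions.SortDirection;
    import org.apache.spark.sql.connector.expressions.SortOrder;
    import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering;

    // Cluster incoming records by "date" and sort them by "id" within each partition.
    // In micro-batch execution the distribution and ordering hold per micro batch.
    public class ClusteredSortedWrite implements RequiresDistributionAndOrdering {
      @Override
      public Distribution requiredDistribution() {
        return Distributions.clustered(new Expression[] { Expressions.column("date") });
      }

      @Override
      public SortOrder[] requiredOrdering() {
        return new SortOrder[] {
          Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING)
        };
      }
    }

With such a Write, Spark shuffles and sorts the incoming records of each batch (or micro batch) before handing them to the writer, as the surrounding Javadoc describes.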

    * Implementations may return an empty array if they don't require any specific ordering of data * on write. * diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 9aee1050370da..fe60605525ae4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -21,6 +21,7 @@ import org.apache.arrow.vector.complex.*; import org.apache.arrow.vector.holders.NullableVarCharHolder; +import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.sql.util.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; @@ -28,10 +29,13 @@ /** * A column vector backed by Apache Arrow. */ -public final class ArrowColumnVector extends ColumnVector { +@DeveloperApi +public class ArrowColumnVector extends ColumnVector { - private final ArrowVectorAccessor accessor; - private ArrowColumnVector[] childColumns; + ArrowVectorAccessor accessor; + ArrowColumnVector[] childColumns; + + public ValueVector getValueVector() { return accessor.vector; } @Override public boolean hasNull() { @@ -128,9 +132,16 @@ public ColumnarMap getMap(int rowId) { @Override public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } + ArrowColumnVector(DataType type) { + super(type); + } + public ArrowColumnVector(ValueVector vector) { - super(ArrowUtils.fromArrowField(vector.getField())); + this(ArrowUtils.fromArrowField(vector.getField())); + initAccessor(vector); + } + void initAccessor(ValueVector vector) { if (vector instanceof BitVector) { accessor = new BooleanAccessor((BitVector) vector); } else if (vector instanceof TinyIntVector) { @@ -182,9 +193,9 @@ public ArrowColumnVector(ValueVector vector) { } } - private abstract static class ArrowVectorAccessor { + abstract static class ArrowVectorAccessor { - private final ValueVector vector; + final ValueVector vector; ArrowVectorAccessor(ValueVector vector) { this.vector = vector; @@ -252,7 +263,7 @@ ColumnarMap getMap(int rowId) { } } - private static class BooleanAccessor extends ArrowVectorAccessor { + static class BooleanAccessor extends ArrowVectorAccessor { private final BitVector accessor; @@ -267,7 +278,7 @@ final boolean getBoolean(int rowId) { } } - private static class ByteAccessor extends ArrowVectorAccessor { + static class ByteAccessor extends ArrowVectorAccessor { private final TinyIntVector accessor; @@ -282,7 +293,7 @@ final byte getByte(int rowId) { } } - private static class ShortAccessor extends ArrowVectorAccessor { + static class ShortAccessor extends ArrowVectorAccessor { private final SmallIntVector accessor; @@ -297,7 +308,7 @@ final short getShort(int rowId) { } } - private static class IntAccessor extends ArrowVectorAccessor { + static class IntAccessor extends ArrowVectorAccessor { private final IntVector accessor; @@ -312,7 +323,7 @@ final int getInt(int rowId) { } } - private static class LongAccessor extends ArrowVectorAccessor { + static class LongAccessor extends ArrowVectorAccessor { private final BigIntVector accessor; @@ -327,7 +338,7 @@ final long getLong(int rowId) { } } - private static class FloatAccessor extends ArrowVectorAccessor { + static class FloatAccessor extends ArrowVectorAccessor { private final Float4Vector accessor; @@ -342,7 +353,7 @@ final float getFloat(int rowId) { } } - private static class DoubleAccessor 
extends ArrowVectorAccessor { + static class DoubleAccessor extends ArrowVectorAccessor { private final Float8Vector accessor; @@ -357,7 +368,7 @@ final double getDouble(int rowId) { } } - private static class DecimalAccessor extends ArrowVectorAccessor { + static class DecimalAccessor extends ArrowVectorAccessor { private final DecimalVector accessor; @@ -373,7 +384,7 @@ final Decimal getDecimal(int rowId, int precision, int scale) { } } - private static class StringAccessor extends ArrowVectorAccessor { + static class StringAccessor extends ArrowVectorAccessor { private final VarCharVector accessor; private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); @@ -396,7 +407,7 @@ final UTF8String getUTF8String(int rowId) { } } - private static class BinaryAccessor extends ArrowVectorAccessor { + static class BinaryAccessor extends ArrowVectorAccessor { private final VarBinaryVector accessor; @@ -411,7 +422,7 @@ final byte[] getBinary(int rowId) { } } - private static class DateAccessor extends ArrowVectorAccessor { + static class DateAccessor extends ArrowVectorAccessor { private final DateDayVector accessor; @@ -426,7 +437,7 @@ final int getInt(int rowId) { } } - private static class TimestampAccessor extends ArrowVectorAccessor { + static class TimestampAccessor extends ArrowVectorAccessor { private final TimeStampMicroTZVector accessor; @@ -441,7 +452,7 @@ final long getLong(int rowId) { } } - private static class TimestampNTZAccessor extends ArrowVectorAccessor { + static class TimestampNTZAccessor extends ArrowVectorAccessor { private final TimeStampMicroVector accessor; @@ -456,7 +467,7 @@ final long getLong(int rowId) { } } - private static class ArrayAccessor extends ArrowVectorAccessor { + static class ArrayAccessor extends ArrowVectorAccessor { private final ListVector accessor; private final ArrowColumnVector arrayData; @@ -493,14 +504,14 @@ final ColumnarArray getArray(int rowId) { * bug in the code. 
* */ - private static class StructAccessor extends ArrowVectorAccessor { + static class StructAccessor extends ArrowVectorAccessor { StructAccessor(StructVector vector) { super(vector); } } - private static class MapAccessor extends ArrowVectorAccessor { + static class MapAccessor extends ArrowVectorAccessor { private final MapVector accessor; private final ArrowColumnVector keys; private final ArrowColumnVector values; @@ -522,14 +533,14 @@ final ColumnarMap getMap(int rowId) { } } - private static class NullAccessor extends ArrowVectorAccessor { + static class NullAccessor extends ArrowVectorAccessor { NullAccessor(NullVector vector) { super(vector); } } - private static class IntervalYearAccessor extends ArrowVectorAccessor { + static class IntervalYearAccessor extends ArrowVectorAccessor { private final IntervalYearVector accessor; @@ -544,7 +555,7 @@ int getInt(int rowId) { } } - private static class DurationAccessor extends ArrowVectorAccessor { + static class DurationAccessor extends ArrowVectorAccessor { private final DurationVector accessor; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 3fcbf2155f456..fced82c97b445 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -102,7 +102,7 @@ object ScalaReflection extends ScalaReflection { val className = getClassNameFromType(tpe) className match { case "scala.Array" => - val TypeRef(_, _, Seq(elementType)) = tpe + val TypeRef(_, _, Seq(elementType)) = tpe.dealias arrayClassFor(elementType) case other => val clazz = getClassFromType(tpe) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 3c17575860db3..8dec923649f1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -86,6 +86,15 @@ object SerializerBuildHelper { returnNullable = false) } + def createSerializerForAnyTimestamp(inputObject: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + TimestampType, + "anyToMicros", + inputObject :: Nil, + returnNullable = false) + } + def createSerializerForLocalDateTime(inputObject: Expression): Expression = { StaticInvoke( DateTimeUtils.getClass, @@ -113,6 +122,15 @@ object SerializerBuildHelper { returnNullable = false) } + def createSerializerForAnyDate(inputObject: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + DateType, + "anyToDays", + inputObject :: Nil, + returnNullable = false) + } + def createSerializerForJavaDuration(inputObject: Expression): Expression = { StaticInvoke( IntervalUtils.getClass, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 182e5997ec34b..528998398ddeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -28,6 +28,7 @@ import scala.util.{Failure, Random, Success, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ +import 
org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.{extraHintForAnsiTypeCoercionExpression, DATA_TYPE_MISMATCH_ERROR} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _} @@ -508,7 +509,7 @@ class Analyzer(override val catalogManager: CatalogManager) } private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = - exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) + exprs.exists(_.exists(_.isInstanceOf[UnresolvedAlias])) def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsPattern(UNRESOLVED_ALIAS), ruleId) { @@ -529,10 +530,7 @@ class Analyzer(override val catalogManager: CatalogManager) object ResolveGroupingAnalytics extends Rule[LogicalPlan] { private[analysis] def hasGroupingFunction(e: Expression): Boolean = { - e.collectFirst { - case g: Grouping => g - case g: GroupingID => g - }.isDefined + e.exists (g => g.isInstanceOf[Grouping] || g.isInstanceOf[GroupingID]) } private def replaceGroupingFunc( @@ -615,7 +613,7 @@ class Analyzer(override val catalogManager: CatalogManager) val aggsBuffer = ArrayBuffer[Expression]() // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. def isPartOfAggregation(e: Expression): Boolean = { - aggsBuffer.exists(a => a.find(_ eq e).isDefined) + aggsBuffer.exists(a => a.exists(_ eq e)) } replaceGroupingFunc(agg, groupByExprs, gid).transformDown { // AggregateExpression should be computed on the unmodified value of its argument @@ -965,14 +963,14 @@ class Analyzer(override val catalogManager: CatalogManager) } private def hasMetadataCol(plan: LogicalPlan): Boolean = { - plan.expressions.exists(_.find { + plan.expressions.exists(_.exists { case a: Attribute => // If an attribute is resolved before being labeled as metadata // (i.e. 
from the originating Dataset), we check with expression ID a.isMetadataCol || plan.children.exists(c => c.metadataOutput.exists(_.exprId == a.exprId)) case _ => false - }.isDefined) + }) } private def addMetadataCol(plan: LogicalPlan): LogicalPlan = plan match { @@ -1379,6 +1377,31 @@ class Analyzer(override val catalogManager: CatalogManager) throw QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF", extractStar(g.generator.children)) + case u @ Union(children, _, _) + // if there are duplicate output columns, give them unique expr ids + if children.exists(c => c.output.map(_.exprId).distinct.length < c.output.length) => + val newChildren = children.map { c => + if (c.output.map(_.exprId).distinct.length < c.output.length) { + val existingExprIds = mutable.HashSet[ExprId]() + val projectList = c.output.map { attr => + if (existingExprIds.contains(attr.exprId)) { + // replace non-first duplicates with aliases and tag them + val newMetadata = new MetadataBuilder().withMetadata(attr.metadata) + .putNull("__is_duplicate").build() + Alias(attr, attr.name)(explicitMetadata = Some(newMetadata)) + } else { + // leave first duplicate alone + existingExprIds.add(attr.exprId) + attr + } + } + Project(projectList, c) + } else { + c + } + } + u.withNewChildren(newChildren) + // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => @@ -1592,7 +1615,7 @@ class Analyzer(override val catalogManager: CatalogManager) exprs.exists(_.collect { case _: Star => true }.nonEmpty) private def extractStar(exprs: Seq[Expression]): Seq[Star] = - exprs.map(_.collect { case s: Star => s }).flatten + exprs.flatMap(_.collect { case s: Star => s }) /** * Expands the matching attribute.*'s in `child`'s output. @@ -1648,7 +1671,7 @@ class Analyzer(override val catalogManager: CatalogManager) } private def containsDeserializer(exprs: Seq[Expression]): Boolean = { - exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined) + exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id @@ -1843,7 +1866,7 @@ class Analyzer(override val catalogManager: CatalogManager) withPosition(ordinal) { if (index > 0 && index <= aggs.size) { val ordinalExpr = aggs(index - 1) - if (ordinalExpr.find(_.isInstanceOf[AggregateExpression]).nonEmpty) { + if (ordinalExpr.exists(_.isInstanceOf[AggregateExpression])) { throw QueryCompilationErrors.groupByPositionRefersToAggregateFunctionError( index, ordinalExpr) } else { @@ -1879,9 +1902,6 @@ class Analyzer(override val catalogManager: CatalogManager) }} } - // Group by alias is not allowed in ANSI mode. - private def allowGroupByAlias: Boolean = conf.groupByAliases && !conf.ansiEnabled - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( // mayResolveAttrByAggregateExprs requires the TreePattern UNRESOLVED_ATTRIBUTE. _.containsAllPatterns(AGGREGATE, UNRESOLVED_ATTRIBUTE), ruleId) { @@ -2048,7 +2068,8 @@ class Analyzer(override val catalogManager: CatalogManager) _.containsAnyPattern(UNRESOLVED_FUNC, UNRESOLVED_FUNCTION, GENERATOR), ruleId) { // Resolve functions with concrete relations from v2 catalog. 
case u @ UnresolvedFunc(nameParts, cmd, requirePersistentFunc, mismatchHint, _) => - lookupBuiltinOrTempFunction(nameParts).map { info => + lookupBuiltinOrTempFunction(nameParts) + .orElse(lookupBuiltinOrTempTableFunction(nameParts)).map { info => if (requirePersistentFunc) { throw QueryCompilationErrors.expectPersistentFuncError( nameParts.head, cmd, mismatchHint, u) @@ -2081,10 +2102,12 @@ class Analyzer(override val catalogManager: CatalogManager) case u if !u.childrenResolved => u // Skip until children are resolved. case u @ UnresolvedGenerator(name, arguments) => withPosition(u) { - resolveBuiltinOrTempFunction(name.asMultipart, arguments, None).getOrElse { - // For generator function, the parser only accepts v1 function name and creates - // `FunctionIdentifier`. - v1SessionCatalog.resolvePersistentFunction(name, arguments) + // For generator function, the parser only accepts v1 function name and creates + // `FunctionIdentifier`. + v1SessionCatalog.lookupFunction(name, arguments) match { + case generator: Generator => generator + case other => throw QueryCompilationErrors.generatorNotExpectedError( + name, other.getClass.getCanonicalName) } } @@ -2117,6 +2140,14 @@ class Analyzer(override val catalogManager: CatalogManager) } } + def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = { + if (name.length == 1) { + v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head) + } else { + None + } + } + private def resolveBuiltinOrTempFunction( name: Seq[String], arguments: Seq[Expression], @@ -2274,12 +2305,14 @@ class Analyzer(override val catalogManager: CatalogManager) case Some(m) if Modifier.isStatic(m.getModifiers) => StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(), MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes, - propagateNull = false, returnNullable = scalarFunc.isResultNullable) + propagateNull = false, returnNullable = scalarFunc.isResultNullable, + isDeterministic = scalarFunc.isDeterministic) case Some(_) => val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), arguments, methodInputTypes = declaredInputTypes, propagateNull = false, - returnNullable = scalarFunc.isResultNullable) + returnNullable = scalarFunc.isResultNullable, + isDeterministic = scalarFunc.isDeterministic) case _ => // TODO: handle functions defined in Scala too - in Scala, even if a // subclass do not override the default method in parent interface @@ -2487,11 +2520,11 @@ class Analyzer(override val catalogManager: CatalogManager) }.toSet // Find the first Aggregate Expression that is not Windowed. 
- exprs.exists(_.collectFirst { - case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae - case e: PythonUDF if PythonUDF.isGroupedAggPandasUDF(e) && - !windowedAggExprs.contains(e) => e - }.isDefined) + exprs.exists(_.exists { + case ae: AggregateExpression => !windowedAggExprs.contains(ae) + case e: PythonUDF => PythonUDF.isGroupedAggPandasUDF(e) && !windowedAggExprs.contains(e) + case _ => false + }) } } @@ -2651,7 +2684,7 @@ class Analyzer(override val catalogManager: CatalogManager) */ object ExtractGenerator extends Rule[LogicalPlan] { private def hasGenerator(expr: Expression): Boolean = { - expr.find(_.isInstanceOf[Generator]).isDefined + expr.exists(_.isInstanceOf[Generator]) } private def hasNestedGenerator(expr: NamedExpression): Boolean = { @@ -2661,10 +2694,10 @@ class Analyzer(override val catalogManager: CatalogManager) case go: GeneratorOuter => hasInnerGenerator(go.child) case _ => - g.children.exists { _.find { + g.children.exists { _.exists { case _: Generator => true case _ => false - }.isDefined } + } } } trimNonTopLevelAliases(expr) match { case UnresolvedAlias(g: Generator, _) => hasInnerGenerator(g) @@ -2675,12 +2708,12 @@ class Analyzer(override val catalogManager: CatalogManager) } private def hasAggFunctionInGenerator(ne: Seq[NamedExpression]): Boolean = { - ne.exists(_.find { + ne.exists(_.exists { case g: Generator => - g.children.exists(_.find(_.isInstanceOf[AggregateFunction]).isDefined) + g.children.exists(_.exists(_.isInstanceOf[AggregateFunction])) case _ => false - }.nonEmpty) + }) } private def trimAlias(expr: NamedExpression): Expression = expr match { @@ -2733,6 +2766,7 @@ class Analyzer(override val catalogManager: CatalogManager) val projectExprs = Array.ofDim[NamedExpression](aggList.length) val newAggList = aggList + .toIndexedSeq .map(trimNonTopLevelAliases) .zipWithIndex .flatMap { @@ -2801,6 +2835,9 @@ class Analyzer(override val catalogManager: CatalogManager) p } + case g @ Generate(GeneratorOuter(generator), _, _, _, _, _) => + g.copy(generator = generator, outer = true) + case g: Generate => g case p if p.expressions.exists(hasGenerator) => @@ -2878,10 +2915,10 @@ class Analyzer(override val catalogManager: CatalogManager) exprs.exists(hasWindowFunction) private def hasWindowFunction(expr: Expression): Boolean = { - expr.find { + expr.exists { case window: WindowExpression => true case _ => false - }.isDefined + } } /** @@ -3581,7 +3618,8 @@ class Analyzer(override val catalogManager: CatalogManager) case u @ UpCast(child, _, _) if !child.resolved => u case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] => - throw QueryCompilationErrors.unsupportedAbstractDataTypeForUpCastError(target) + throw new IllegalStateException( + s"UpCast only supports DecimalType as AbstractDataType yet, but got: $target") case UpCast(child, target, walkedTypePath) if target == DecimalType && child.dataType.isInstanceOf[DecimalType] => @@ -3716,7 +3754,7 @@ class Analyzer(override val catalogManager: CatalogManager) } private def hasUnresolvedFieldName(a: AlterTableCommand): Boolean = { - a.expressions.exists(_.find(_.isInstanceOf[UnresolvedFieldName]).isDefined) + a.expressions.exists(_.exists(_.isInstanceOf[UnresolvedFieldName])) } } @@ -3833,8 +3871,8 @@ object TimeWindowing extends Rule[LogicalPlan] { * The windows are calculated as below: * maxNumOverlapping <- ceil(windowDuration / slideDuration) * for (i <- 0 until maxNumOverlapping) - * windowId <- ceil((timestamp - startTime) / slideDuration) - * windowStart 
<- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration + startTime + * lastStart <- timestamp - (timestamp - startTime + slideDuration) % slideDuration + * windowStart <- lastStart - i * slideDuration * windowEnd <- windowStart + windowDuration * return windowStart, windowEnd * @@ -3874,21 +3912,20 @@ object TimeWindowing extends Rule[LogicalPlan] { case _ => Metadata.empty } - def getWindow(i: Int, overlappingWindows: Int, dataType: DataType): Expression = { - val division = (PreciseTimestampConversion( - window.timeColumn, dataType, LongType) - window.startTime) / window.slideDuration - val ceil = Ceil(division) - // if the division is equal to the ceiling, our record is the start of a window - val windowId = CaseWhen(Seq((ceil === division, ceil + 1)), Some(ceil)) - val windowStart = (windowId + i - overlappingWindows) * - window.slideDuration + window.startTime + def getWindow(i: Int, dataType: DataType): Expression = { + val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType) + val lastStart = timestamp - (timestamp - window.startTime + + window.slideDuration) % window.slideDuration + val windowStart = lastStart - i * window.slideDuration val windowEnd = windowStart + window.windowDuration + // We make sure value fields are nullable since the dataType of TimeWindow defines them + // as nullable. CreateNamedStruct( Literal(WINDOW_START) :: - PreciseTimestampConversion(windowStart, LongType, dataType) :: + PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() :: Literal(WINDOW_END) :: - PreciseTimestampConversion(windowEnd, LongType, dataType) :: + PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() :: Nil) } @@ -3896,7 +3933,7 @@ object TimeWindowing extends Rule[LogicalPlan] { WINDOW_COL_NAME, window.dataType, metadata = metadata)() if (window.windowDuration == window.slideDuration) { - val windowStruct = Alias(getWindow(0, 1, window.timeColumn.dataType), WINDOW_COL_NAME)( + val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)( exprId = windowAttr.exprId, explicitMetadata = Some(metadata)) val replacedPlan = p transformExpressions { @@ -3914,13 +3951,20 @@ object TimeWindowing extends Rule[LogicalPlan] { math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt val windows = Seq.tabulate(overlappingWindows)(i => - getWindow(i, overlappingWindows, window.timeColumn.dataType)) + getWindow(i, window.timeColumn.dataType)) val projections = windows.map(_ +: child.output) + // When the condition windowDuration % slideDuration = 0 is fulfilled, + // the estimation of the number of windows becomes exact one, + // which means all produced windows are valid. val filterExpr = - window.timeColumn >= windowAttr.getField(WINDOW_START) && - window.timeColumn < windowAttr.getField(WINDOW_END) + if (window.windowDuration % window.slideDuration == 0) { + IsNotNull(window.timeColumn) + } else { + window.timeColumn >= windowAttr.getField(WINDOW_START) && + window.timeColumn < windowAttr.getField(WINDOW_END) + } val substitutedPlan = Filter(filterExpr, Expand(projections, windowAttr +: child.output, child)) @@ -3996,11 +4040,15 @@ object SessionWindowing extends Rule[LogicalPlan] { val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration, session.timeColumn.dataType, LongType) + // We make sure value fields are nullable since the dataType of SessionWindow defines them + // as nullable. 
val literalSessionStruct = CreateNamedStruct( Literal(SESSION_START) :: - PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) :: + PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) + .castNullable() :: Literal(SESSION_END) :: - PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) :: + PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) + .castNullable() :: Nil) val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( @@ -4011,7 +4059,7 @@ object SessionWindowing extends Rule[LogicalPlan] { } // As same as tumbling window, we add a filter to filter out nulls. - // And we also filter out events with negative or zero gap duration. + // And we also filter out events with negative or zero or invalid gap duration. val filterExpr = IsNotNull(session.timeColumn) && (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START)) @@ -4239,7 +4287,30 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { * rule right after the main resolution batch. */ object RemoveTempResolvedColumn extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case t: TempResolvedColumn => UnresolvedAttribute(t.nameParts) + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.foreachUp { + // HAVING clause will be resolved as a Filter. When having func(column with wrong data type), + // the column could be wrapped by a TempResolvedColumn, e.g. mean(tempresolvedcolumn(t.c)). + // Because TempResolvedColumn can still preserve column data type, here is a chance to check + // if the data type matches with the required data type of the function. We can throw an error + // when data types mismatches. + case operator: Filter => + operator.expressions.foreach(_.foreachUp { + case e: Expression if e.childrenResolved && e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + e.setTagValue(DATA_TYPE_MISMATCH_ERROR, true) + e.failAnalysis( + s"cannot resolve '${e.sql}' due to data type mismatch: $message" + + extraHintForAnsiTypeCoercionExpression(plan)) + } + case _ => + }) + case _ => + } + + plan.resolveExpressions { + case t: TempResolvedColumn => UnresolvedAttribute(t.nameParts) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 4ff2fbf3b3a9d..036efba34fab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -68,7 +68,7 @@ import org.apache.spark.sql.types._ * * CreateMap * * For complex types (struct, array, map), Spark recursively looks into the element type and * applies the rules above. - * Note: this new type coercion system will allow implicit converting String type literals as other + * Note: this new type coercion system will allow implicit converting String type as other * primitive types, in case of breaking too many existing Spark SQL queries. This is a special * rule and it is not from the ANSI SQL standard. 
*/ @@ -77,7 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { WidenSetOperationTypes :: new AnsiCombinedTypeCoercionRule( InConversion :: - PromoteStringLiterals :: + PromoteStrings :: DecimalPrecision :: FunctionArgumentConversion :: ConcatCoercion :: @@ -130,9 +130,27 @@ object AnsiTypeCoercion extends TypeCoercionBase { override def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) + .orElse(findWiderTypeForString(t1, t2)) .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) } + /** Promotes StringType to other data types. */ + @scala.annotation.tailrec + private def findWiderTypeForString(dt1: DataType, dt2: DataType): Option[DataType] = { + (dt1, dt2) match { + case (StringType, _: IntegralType) => Some(LongType) + case (StringType, _: FractionalType) => Some(DoubleType) + case (StringType, NullType) => Some(StringType) + // If a binary operation contains interval type and string, we can't decide which + // interval type the string should be promoted as. There are many possible interval + // types, such as year interval, month interval, day interval, hour interval, etc. + case (StringType, _: AnsiIntervalType) => None + case (StringType, a: AtomicType) => Some(a) + case (other, StringType) if other != StringType => findWiderTypeForString(StringType, other) + case _ => None + } + } + override def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { @@ -142,7 +160,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { } override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { - implicitCast(e.dataType, expectedType, e.foldable).map { dt => + implicitCast(e.dataType, expectedType).map { dt => if (dt == e.dataType) e else Cast(e, dt) } } @@ -153,8 +171,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { */ private def implicitCast( inType: DataType, - expectedType: AbstractDataType, - isInputFoldable: Boolean): Option[DataType] = { + expectedType: AbstractDataType): Option[DataType] = { (inType, expectedType) match { // If the expected type equals the input type, no need to cast. case _ if expectedType.acceptsType(inType) => Some(inType) @@ -169,19 +186,28 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (NullType, target) if !target.isInstanceOf[TypeCollection] => Some(target.defaultConcreteType) - // This type coercion system will allow implicit converting String type literals as other + // This type coercion system will allow implicit converting String type as other // primitive types, in case of breaking too many existing Spark SQL queries. - case (StringType, a: AtomicType) if isInputFoldable => + case (StringType, a: AtomicType) => Some(a) - // If the target type is any Numeric type, convert the String type literal as Double type. - case (StringType, NumericType) if isInputFoldable => + // If the target type is any Numeric type, convert the String type as Double type. + case (StringType, NumericType) => Some(DoubleType) - // If the target type is any Decimal type, convert the String type literal as Double type. - case (StringType, DecimalType) if isInputFoldable => + // If the target type is any Decimal type, convert the String type as the default + // Decimal type. + case (StringType, DecimalType) => Some(DecimalType.SYSTEM_DEFAULT) + // If the target type is any timestamp type, convert the String type as the default + // Timestamp type. 
+ case (StringType, AnyTimestampType) => + Some(AnyTimestampType.defaultConcreteType) + + case (DateType, AnyTimestampType) => + Some(AnyTimestampType.defaultConcreteType) + case (_, target: DataType) => if (Cast.canANSIStoreAssign(inType, target)) { Some(target) @@ -192,7 +218,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. case (_, TypeCollection(types)) => - types.flatMap(implicitCast(inType, _, isInputFoldable)).headOption + types.flatMap(implicitCast(inType, _)).headOption case _ => None } @@ -200,10 +226,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { override def canCast(from: DataType, to: DataType): Boolean = AnsiCast.canCast(from, to) - /** - * Promotes string literals that appear in arithmetic, comparison, and datetime expressions. - */ - object PromoteStringLiterals extends TypeCoercionRule { + object PromoteStrings extends TypeCoercionRule { private def castExpr(expr: Expression, targetType: DataType): Expression = { expr.dataType match { case NullType => Literal.create(null, targetType) @@ -212,55 +235,37 @@ object AnsiTypeCoercion extends TypeCoercionBase { } } - // Return whether a string literal can be promoted as the give data type in a binary operation. - private def canPromoteAsInBinaryOperation(dt: DataType) = dt match { - // If a binary operation contains interval type and string literal, we can't decide which - // interval type the string literal should be promoted as. There are many possible interval - // types, such as year interval, month interval, day interval, hour interval, etc. - case _: AnsiIntervalType => false - case _: AtomicType => true - case _ => false - } - override def transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryOperator(left @ StringType(), right) - if left.foldable && canPromoteAsInBinaryOperation(right.dataType) => - b.makeCopy(Array(castExpr(left, right.dataType), right)) + case b @ BinaryOperator(left, right) + if findWiderTypeForString(left.dataType, right.dataType).isDefined => + val promoteType = findWiderTypeForString(left.dataType, right.dataType).get + b.withNewChildren(Seq(castExpr(left, promoteType), castExpr(right, promoteType))) - case b @ BinaryOperator(left, right @ StringType()) - if right.foldable && canPromoteAsInBinaryOperation(left.dataType) => - b.makeCopy(Array(left, castExpr(right, left.dataType))) + case Abs(e @ StringType(), failOnError) => Abs(Cast(e, DoubleType), failOnError) + case m @ UnaryMinus(e @ StringType(), _) => m.withNewChildren(Seq(Cast(e, DoubleType))) + case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType)) - // Promotes string literals in `In predicate`. 
- case p @ In(a, b) - if a.dataType != StringType && b.exists( e => e.foldable && e.dataType == StringType) => - val newList = b.map { - case e @ StringType() if e.foldable => Cast(e, a.dataType) - case other => other - } - p.makeCopy(Array(a, newList)) - - case d @ DateAdd(left @ StringType(), _) if left.foldable => + case d @ DateAdd(left @ StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) - case d @ DateAdd(_, right @ StringType()) if right.foldable => + case d @ DateAdd(_, right @ StringType()) => d.copy(days = Cast(right, IntegerType)) - case d @ DateSub(left @ StringType(), _) if left.foldable => + case d @ DateSub(left @ StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) - case d @ DateSub(_, right @ StringType()) if right.foldable => + case d @ DateSub(_, right @ StringType()) => d.copy(days = Cast(right, IntegerType)) - case s @ SubtractDates(left @ StringType(), _, _) if left.foldable => + case s @ SubtractDates(left @ StringType(), _, _) => s.copy(left = Cast(s.left, DateType)) - case s @ SubtractDates(_, right @ StringType(), _) if right.foldable => + case s @ SubtractDates(_, right @ StringType(), _) => s.copy(right = Cast(s.right, DateType)) - case t @ TimeAdd(left @ StringType(), _, _) if left.foldable => + case t @ TimeAdd(left @ StringType(), _, _) => t.copy(start = Cast(t.start, TimestampType)) - case t @ SubtractTimestamps(left @ StringType(), _, _, _) if left.foldable => + case t @ SubtractTimestamps(left @ StringType(), _, _, _) => t.copy(left = Cast(t.left, t.right.dataType)) - case t @ SubtractTimestamps(_, right @ StringType(), _, _) if right.foldable => + case t @ SubtractTimestamps(_, right @ StringType(), _, _) => t.copy(right = Cast(right, t.left.dataType)) } } @@ -275,6 +280,9 @@ object AnsiTypeCoercion extends TypeCoercionBase { */ object GetDateFieldOperations extends TypeCoercionRule { override def transform: PartialFunction[Expression, Expression] = { + // Skip nodes who's children have not been resolved yet. 
+ case g if !g.childrenResolved => g + case g: GetDateField if AnyTimestampType.unapply(g.child) => g.withNewChildren(Seq(Cast(g.child, DateType))) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 2397527133f13..c0ba3598e4ba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -51,10 +51,10 @@ object CTESubstitution extends Rule[LogicalPlan] { if (!plan.containsPattern(UNRESOLVED_WITH)) { return plan } - val isCommand = plan.find { + val isCommand = plan.exists { case _: Command | _: ParsedStatement | _: InsertIntoDir => true case _ => false - }.isDefined + } val cteDefs = mutable.ArrayBuffer.empty[CTERelationDef] val (substituted, lastSubstituted) = LegacyBehaviorPolicy.withName(conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d06996a09df02..c05b9326d2304 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification +import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -199,6 +199,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { failAnalysis(s"invalid cast from ${c.child.dataType.catalogString} to " + c.dataType.catalogString) + case e: RuntimeReplaceable if !e.replacement.resolved => + throw new IllegalStateException("Illegal RuntimeReplaceable: " + e + + "\nReplacement is unresolved: " + e.replacement) + case g: Grouping => failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup") case g: GroupingID => @@ -330,7 +334,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { } def checkValidGroupingExprs(expr: Expression): Unit = { - if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) { + if (expr.exists(_.isInstanceOf[AggregateExpression])) { failAnalysis( "aggregate functions are not allowed in GROUP BY, but found " + expr.sql) } @@ -431,7 +435,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // Check if the data types match. 
dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns - if (dataTypesAreCompatibleFn(dt1, dt2)) { + if (!dataTypesAreCompatibleFn(dt1, dt2)) { val errorMessage = s""" |${operator.nodeName} can only be performed on tables with the compatible @@ -603,11 +607,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { val isUnion = plan.isInstanceOf[Union] if (isUnion) { (dt1: DataType, dt2: DataType) => - !DataType.equalsStructurally(dt1, dt2, true) + DataType.equalsStructurally(dt1, dt2, true) } else { // SPARK-18058: we shall not care about the nullability of columns (dt1: DataType, dt2: DataType) => - TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty + TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).nonEmpty } } @@ -623,7 +627,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { } } - private def extraHintForAnsiTypeCoercionExpression(plan: LogicalPlan): String = { + private[analysis] def extraHintForAnsiTypeCoercionExpression(plan: LogicalPlan): String = { if (!SQLConf.get.ansiEnabled) { "" } else { @@ -658,7 +662,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { nonAnsiPlan.children.tail.zipWithIndex.foreach { case (child, ti) => // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => - if (dataTypesAreCompatibleFn(dt1, dt2)) { + if (!dataTypesAreCompatibleFn(dt1, dt2)) { issueFixedIfAnsiOff = false } } @@ -714,7 +718,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // Check whether the given expressions contains the subquery expression. def containsExpr(expressions: Seq[Expression]): Boolean = { - expressions.exists(_.find(_.semanticEquals(expr)).isDefined) + expressions.exists(_.exists(_.semanticEquals(expr))) } // Validate the subquery plan. @@ -755,7 +759,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { } // Validate to make sure the correlations appearing in the query are valid and // allowed by spark. - checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true) + checkCorrelationsInSubquery(expr.plan, isScalar = true) case _: LateralSubquery => assert(plan.isInstanceOf[LateralJoin]) @@ -774,7 +778,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { } // Validate to make sure the correlations appearing in the query are valid and // allowed by spark. - checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true) + checkCorrelationsInSubquery(expr.plan, isLateral = true) case inSubqueryOrExistsSubquery => plan match { @@ -827,7 +831,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { */ private def checkCorrelationsInSubquery( sub: LogicalPlan, - isScalarOrLateral: Boolean = false): Unit = { + isScalar: Boolean = false, + isLateral: Boolean = false): Unit = { // Validate that correlated aggregate expression do not contain a mixture // of outer and local references. def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { @@ -852,7 +857,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // DecorrelateInnerQuery is enabled. Otherwise, only Filter can only outer references. 
def canHostOuter(plan: LogicalPlan): Boolean = plan match { case _: Filter => true - case _: Project => isScalarOrLateral && SQLConf.get.decorrelateInnerQueryEnabled + case _: Project => (isScalar || isLateral) && SQLConf.get.decorrelateInnerQueryEnabled case _ => false } @@ -932,31 +937,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { } } - def containsAttribute(e: Expression): Boolean = { - e.find(_.isInstanceOf[Attribute]).isDefined - } - - // Given a correlated predicate, check if it is either a non-equality predicate or - // equality predicate that does not guarantee one-on-one mapping between inner and - // outer attributes. When the correlated predicate does not contain any attribute - // (i.e. only has outer references), it is supported and should return false. E.G.: - // (a = outer(c)) -> false - // (outer(c) = outer(d)) -> false - // (a > outer(c)) -> true - // (a + b = outer(c)) -> true - // The last one is true because there can be multiple combinations of (a, b) that - // satisfy the equality condition. For example, if outer(c) = 0, then both (0, 0) - // and (-1, 1) can make the predicate evaluate to true. - def isUnsupportedPredicate(condition: Expression): Boolean = condition match { - // Only allow equality condition with one side being an attribute and another - // side being an expression without attributes from the inner query. Note - // OuterReference is a leaf node and will not be found here. - case Equality(_: Attribute, b) => containsAttribute(b) - case Equality(a, _: Attribute) => containsAttribute(a) - case e @ Equality(_, _) => containsAttribute(e) - case _ => true - } - val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression] // Simplify the predicates before validating any unsupported correlation patterns in the plan. @@ -980,8 +960,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // up to the operator producing the correlated values. // Category 1: - // ResolvedHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => + // ResolvedHint, LeafNode, Repartition, and SubqueryAlias + case _: ResolvedHint | _: LeafNode | _: Repartition | _: SubqueryAlias => // Category 2: // These operators can be anywhere in a correlated subquery. @@ -1003,7 +983,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // The other operator is Join. Filter can be anywhere in a correlated subquery. case f: Filter => val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) - unsupportedPredicates ++= correlated.filter(isUnsupportedPredicate) + unsupportedPredicates ++= correlated.filterNot(DecorrelateInnerQuery.canPullUpOverAgg) failOnInvalidOuterReference(f) // Aggregate cannot host any correlated expressions @@ -1015,6 +995,16 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { failOnInvalidOuterReference(a) failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, a) + // Distinct does not host any correlated expressions, but during the optimization phase + // it will be rewritten as Aggregate, which can only be on a correlation path if the + // correlation contains only the supported correlated equality predicates. + // Only block it for lateral subqueries because scalar subqueries must be aggregated + // and it does not impact the results for IN/EXISTS subqueries. 
+ case d: Distinct => + if (isLateral) { + failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, d) + } + // Join can host correlated expressions. case j @ Join(left, right, joinType, _, _) => joinType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index 55b1c221c8378..4c351e3237df2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -40,7 +40,12 @@ case class ReferenceEqualPlanWrapper(plan: LogicalPlan) { object DeduplicateRelations extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { - renewDuplicatedRelations(mutable.HashSet.empty, plan)._1.resolveOperatorsUpWithPruning( + val newPlan = renewDuplicatedRelations(mutable.HashSet.empty, plan)._1 + if (newPlan.find(p => p.resolved && p.missingInput.nonEmpty).isDefined) { + // Wait for `ResolveMissingReferences` to resolve missing attributes first + return newPlan + } + newPlan.resolveOperatorsUpWithPruning( _.containsAnyPattern(JOIN, LATERAL_JOIN, AS_OF_JOIN, INTERSECT, EXCEPT, UNION, COMMAND), ruleId) { case p: LogicalPlan if !p.childrenResolved => p diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c995ff8637529..e5954c8f26942 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -23,6 +23,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.spark.SparkThrowable import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier @@ -110,9 +111,11 @@ object FunctionRegistryBase { name: String, since: Option[String]): (ExpressionInfo, Seq[Expression] => T) = { val runtimeClass = scala.reflect.classTag[T].runtimeClass - // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main - // constructor and contains non-parameter `child` and should not be used as function builder. - val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(runtimeClass)) { + // For `InheritAnalysisRules`, skip the constructor with most arguments, which is the main + // constructor and contains non-parameter `replacement` and should not be used as + // function builder. + val isRuntime = classOf[InheritAnalysisRules].isAssignableFrom(runtimeClass) + val constructors = if (isRuntime) { val all = runtimeClass.getConstructors val maxNumArgs = all.map(_.getParameterCount).max all.filterNot(_.getParameterCount == maxNumArgs) @@ -129,7 +132,11 @@ object FunctionRegistryBase { } catch { // the exception is an invocation exception. To get a meaningful message, we need the // cause. - case e: Exception => throw new AnalysisException(e.getCause.getMessage) + case e: Exception => + throw e.getCause match { + case ae: SparkThrowable => ae + case _ => new AnalysisException(e.getCause.getMessage) + } } } else { // Otherwise, find a constructor method that matches the number of arguments, and use that. 
@@ -319,7 +326,37 @@ object FunctionRegistry { val FUNC_ALIAS = TreeNodeTag[String]("functionAliasName") - // Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite + // ============================================================================================== + // The guideline for adding SQL functions + // ============================================================================================== + // To add a SQL function, we usually need to create a new `Expression` for the function, and + // implement the function logic in both the interpretation code path and codegen code path of the + // `Expression`. We also need to define the type coercion behavior for the function inputs, by + // extending `ImplicitCastInputTypes` or updating type coercion rules directly. + // + // It's much simpler if the SQL function can be implemented with existing expression(s). There are + // a few cases: + // - The function is simply an alias of another function. We can just register the same + // expression with a different function name, e.g. `expression[Rand]("random", true)`. + // - The function is mostly the same with another function, but has a different parameter list. + // We can use `RuntimeReplaceable` to create a new expression, which can customize the + // parameter list and analysis behavior (type coercion). The `RuntimeReplaceable` expression + // will be replaced by the actual expression at the end of analysis. See `Left` as an example. + // - The function can be implemented by combining some existing expressions. We can use + // `RuntimeReplaceable` to define the combination. See `ParseToDate` as an example. + // To inherit the analysis behavior from the replacement expression + // mix-in `InheritAnalysisRules` with `RuntimeReplaceable`. See `TryAdd` as an example. + // - For `AggregateFunction`, `RuntimeReplaceableAggregate` should be mixed-in. See + // `CountIf` as an example. + // + // Sometimes, multiple functions share the same/similar expression replacement logic and it's + // tedious to create many similar `RuntimeReplaceable` expressions. We can use `ExpressionBuilder` + // to share the replacement logic. See `ParseToTimestampLTZExpressionBuilder` as an example. + // + // With these tools, we can even implement a new SQL function with a Java (static) method, and + // then create a `RuntimeReplaceable` expression to call the Java method with `Invoke` or + // `StaticInvoke` expression. By doing so we don't need to implement codegen for new functions + // anymore. See `AesEncrypt`/`AesDecrypt` as an example. 
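To make the guideline above concrete, here is a minimal, hedged sketch of the `RuntimeReplaceable` + `InheritAnalysisRules` pattern. `BoolToInt` is a hypothetical expression invented purely for illustration (it is not part of this patch); it mirrors the shape of `TryAdd`: the main constructor carries the non-parameter `replacement`, while the auxiliary constructor exposes only the user-facing argument.

import org.apache.spark.sql.catalyst.expressions.{Expression, If, InheritAnalysisRules, Literal, RuntimeReplaceable}

// Hypothetical illustration only: maps a boolean to 0/1 by delegating to existing expressions.
case class BoolToInt(input: Expression, replacement: Expression)
  extends RuntimeReplaceable with InheritAnalysisRules {

  // User-facing constructor; FunctionRegistryBase skips the main (max-arg) constructor
  // because it carries the non-parameter `replacement`.
  def this(input: Expression) = this(input, If(input, Literal(1), Literal(0)))

  override def prettyName: String = "bool_to_int"
  override def parameters: Seq[Expression] = Seq(input)
  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(replacement = newChild)
}

Registering such an expression would then be a single entry in the map that follows, e.g. `expression[BoolToInt]("bool_to_int")`; no interpreted or codegen implementation is needed, because the replacement expression is what ultimately gets evaluated.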
val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( // misc non-aggregate functions expression[Abs]("abs"), @@ -331,7 +368,7 @@ object FunctionRegistry { expression[Inline]("inline"), expressionGeneratorOuter[Inline]("inline_outer"), expression[IsNaN]("isnan"), - expression[IfNull]("ifnull"), + expression[Nvl]("ifnull", setAlias = true), expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"), expression[Least]("least"), @@ -358,8 +395,8 @@ object FunctionRegistry { expression[Bin]("bin"), expression[BRound]("bround"), expression[Cbrt]("cbrt"), - expression[Ceil]("ceil"), - expression[Ceil]("ceiling", true), + expressionBuilder("ceil", CeilExpressionBuilder), + expressionBuilder("ceiling", CeilExpressionBuilder, true), expression[Cos]("cos"), expression[Sec]("sec"), expression[Cosh]("cosh"), @@ -368,7 +405,7 @@ object FunctionRegistry { expression[EulerNumber]("e"), expression[Exp]("exp"), expression[Expm1]("expm1"), - expression[Floor]("floor"), + expressionBuilder("floor", FloorExpressionBuilder), expression[Factorial]("factorial"), expression[Hex]("hex"), expression[Hypot]("hypot"), @@ -412,6 +449,8 @@ object FunctionRegistry { // "try_*" function which always return Null instead of runtime error. expression[TryAdd]("try_add"), expression[TryDivide]("try_divide"), + expression[TrySubtract]("try_subtract"), + expression[TryMultiply]("try_multiply"), expression[TryElementAt]("try_element_at"), // aggregate functions @@ -455,14 +494,16 @@ object FunctionRegistry { expression[BoolOr]("some", true), expression[BoolOr]("bool_or"), expression[RegrCount]("regr_count"), + expression[RegrAvgX]("regr_avgx"), + expression[RegrAvgY]("regr_avgy"), // string functions expression[Ascii]("ascii"), expression[Chr]("char", true), expression[Chr]("chr"), - expression[Contains]("contains"), - expression[StartsWith]("startswith"), - expression[EndsWith]("endswith"), + expressionBuilder("contains", ContainsExpressionBuilder), + expressionBuilder("startswith", StartsWithExpressionBuilder), + expressionBuilder("endswith", EndsWithExpressionBuilder), expression[Base64]("base64"), expression[BitLength]("bit_length"), expression[Length]("char_length", true), @@ -474,6 +515,7 @@ object FunctionRegistry { expression[FindInSet]("find_in_set"), expression[FormatNumber]("format_number"), expression[FormatString]("format_string"), + expression[ToNumber]("to_number"), expression[GetJsonObject]("get_json_object"), expression[InitCap]("initcap"), expression[StringInstr]("instr"), @@ -554,10 +596,12 @@ object FunctionRegistry { expression[Second]("second"), expression[ParseToTimestamp]("to_timestamp"), expression[ParseToDate]("to_date"), + expression[ToBinary]("to_binary"), expression[ToUnixTimestamp]("to_unix_timestamp"), expression[ToUTCTimestamp]("to_utc_timestamp"), - expression[ParseToTimestampNTZ]("to_timestamp_ntz"), - expression[ParseToTimestampLTZ]("to_timestamp_ltz"), + // We keep the 2 expression builders below to have different function docs. 
+ expressionBuilder("to_timestamp_ntz", ParseToTimestampNTZExpressionBuilder, setAlias = true), + expressionBuilder("to_timestamp_ltz", ParseToTimestampLTZExpressionBuilder, setAlias = true), expression[TruncDate]("trunc"), expression[TruncTimestamp]("date_trunc"), expression[UnixTimestamp]("unix_timestamp"), @@ -569,13 +613,15 @@ object FunctionRegistry { expression[SessionWindow]("session_window"), expression[MakeDate]("make_date"), expression[MakeTimestamp]("make_timestamp"), - expression[MakeTimestampNTZ]("make_timestamp_ntz"), - expression[MakeTimestampLTZ]("make_timestamp_ltz"), + // We keep the 2 expression builders below to have different function docs. + expressionBuilder("make_timestamp_ntz", MakeTimestampNTZExpressionBuilder, setAlias = true), + expressionBuilder("make_timestamp_ltz", MakeTimestampLTZExpressionBuilder, setAlias = true), expression[MakeInterval]("make_interval"), expression[MakeDTInterval]("make_dt_interval"), expression[MakeYMInterval]("make_ym_interval"), - expression[DatePart]("date_part"), expression[Extract]("extract"), + // We keep the `DatePartExpressionBuilder` to have different function docs. + expressionBuilder("date_part", DatePartExpressionBuilder, setAlias = true), expression[DateFromUnixDate]("date_from_unix_date"), expression[UnixDate]("unix_date"), expression[SecondsToTimestamp]("timestamp_seconds"), @@ -593,6 +639,7 @@ object FunctionRegistry { expression[ArrayIntersect]("array_intersect"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), + expression[ArraySize]("array_size"), expression[ArraySort]("array_sort"), expression[ArrayExcept]("array_except"), expression[ArrayUnion]("array_union"), @@ -796,11 +843,15 @@ object FunctionRegistry { } private def expressionBuilder[T <: ExpressionBuilder : ClassTag]( - name: String, builder: T): (String, (ExpressionInfo, FunctionBuilder)) = { + name: String, + builder: T, + setAlias: Boolean = false): (String, (ExpressionInfo, FunctionBuilder)) = { val info = FunctionRegistryBase.expressionInfo[T](name, None) val funcBuilder = (expressions: Seq[Expression]) => { assert(expressions.forall(_.resolved), "function arguments must be resolved.") - builder.build(expressions) + val expr = builder.build(name, expressions) + if (setAlias) expr.setTagValue(FUNC_ALIAS, name) + expr } (name, (info, funcBuilder)) } @@ -902,5 +953,5 @@ object TableFunctionRegistry { } trait ExpressionBuilder { - def build(expressions: Seq[Expression]): Expression + def build(funcName: String, expressions: Seq[Expression]): Expression } diff --git a/core/src/main/java/org/apache/hadoop/shaded/net/jpountz/lz4/LZ4Compressor.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala similarity index 55% rename from core/src/main/java/org/apache/hadoop/shaded/net/jpountz/lz4/LZ4Compressor.java rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala index 092ed59c6bb14..f3ff28f74fcc3 100644 --- a/core/src/main/java/org/apache/hadoop/shaded/net/jpountz/lz4/LZ4Compressor.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala @@ -15,23 +15,22 @@ * limitations under the License. */ -package org.apache.hadoop.shaded.net.jpountz.lz4; +package org.apache.spark.sql.catalyst.analysis -/** - * TODO(SPARK-36679): A temporary workaround for SPARK-36669. We should remove this after - * Hadoop 3.3.2 release which fixes the LZ4 relocation in shaded Hadoop client libraries. 
- * This does not need implement all net.jpountz.lz4.LZ4Compressor API, just the ones used - * by Hadoop Lz4Compressor. - */ -public final class LZ4Compressor { +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - private net.jpountz.lz4.LZ4Compressor lz4Compressor; - public LZ4Compressor(net.jpountz.lz4.LZ4Compressor lz4Compressor) { - this.lz4Compressor = lz4Compressor; - } +/** + * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception + * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. + */ +case class NonEmptyNamespaceException( + override val message: String, + override val cause: Option[Throwable] = None) + extends AnalysisException(message, cause = cause) { - public void compress(java.nio.ByteBuffer src, java.nio.ByteBuffer dest) { - lz4Compressor.compress(src, dest); + def this(namespace: Array[String]) = { + this(s"Namespace '${namespace.quoted}' is non empty.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 27f2a5f416d56..46ebffea1aec5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -250,14 +250,26 @@ object ResolveHints { } private def createRebalance(hint: UnresolvedHint): LogicalPlan = { + def createRebalancePartitions( + partitionExprs: Seq[Any], initialNumPartitions: Option[Int]): RebalancePartitions = { + val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute]) + if (invalidParams.nonEmpty) { + val hintName = hint.name.toUpperCase(Locale.ROOT) + throw QueryCompilationErrors.invalidHintParameterError(hintName, invalidParams) + } + RebalancePartitions( + partitionExprs.map(_.asInstanceOf[Expression]), + hint.child, + initialNumPartitions) + } + hint.parameters match { + case param @ Seq(IntegerLiteral(numPartitions), _*) => + createRebalancePartitions(param.tail, Some(numPartitions)) + case param @ Seq(numPartitions: Int, _*) => + createRebalancePartitions(param.tail, Some(numPartitions)) case partitionExprs @ Seq(_*) => - val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute]) - if (invalidParams.nonEmpty) { - val hintName = hint.name.toUpperCase(Locale.ROOT) - throw QueryCompilationErrors.invalidHintParameterError(hintName, invalidParams) - } - RebalancePartitions(partitionExprs.map(_.asInstanceOf[Expression]), hint.child) + createRebalancePartitions(partitionExprs, None) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala index 9cdd77ee5a52d..3c5ab55a8a72a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -169,7 +169,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging { return None } val constraintTerm = constraintTerms.head - if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { + if (!constraintTerm.exists(_.isInstanceOf[UnaryMinus])) { // Incorrect condition. 
We want the constraint term in canonical form to be `-leftTime` // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index d471d754e7f8b..2cd069e5858da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -236,6 +236,9 @@ object TableOutputResolver { val casted = storeAssignmentPolicy match { case StoreAssignmentPolicy.ANSI => AnsiCast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) + case StoreAssignmentPolicy.LEGACY => + Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone), + ansiEnabled = false) case _ => Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala index cbb6e8bb06a4c..7e79c03b5ff6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala @@ -41,7 +41,7 @@ object TimeTravelSpec { throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts) } val tsToEval = ts.transform { - case r: RuntimeReplaceable => r.child + case r: RuntimeReplaceable => r.replacement case _: Unevaluable => throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts) case e if !e.deterministic => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index b861e5df72c3a..9d24ae4a15950 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -591,7 +591,7 @@ case class GetViewColumnByNameAndOrdinal( override def dataType: DataType = throw new UnresolvedException("dataType") override def nullable: Boolean = throw new UnresolvedException("nullable") override lazy val resolved = false - override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).toIterator + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator } /** @@ -645,4 +645,5 @@ case class TempResolvedColumn(child: Expression, nameParts: Seq[String]) extends override def dataType: DataType = child.dataType override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild) + override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index e3896c598eac9..5ca96f097b2f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -52,7 +52,7 @@ class InMemoryCatalog( import CatalogTypes.TablePartitionSpec private class TableDesc(var table: 
CatalogTable) { - val partitions = new mutable.HashMap[TablePartitionSpec, CatalogTablePartition] + var partitions = new mutable.HashMap[TablePartitionSpec, CatalogTablePartition] } private class DatabaseDesc(var db: CatalogDatabase) { @@ -298,8 +298,17 @@ class InMemoryCatalog( oldName, newName, oldDir, e) } oldDesc.table = oldDesc.table.withNewStorage(locationUri = Some(newDir.toUri)) - } + val newPartitions = oldDesc.partitions.map { case (spec, partition) => + val storage = partition.storage + val newLocationUri = storage.locationUri.map { uri => + new Path(uri.toString.replace(oldDir.toString, newDir.toString)).toUri + } + val newPartition = partition.copy(storage = storage.copy(locationUri = newLocationUri)) + (spec, newPartition) + } + oldDesc.partitions = newPartitions + } catalog(db).tables.put(newName, oldDesc) catalog(db).tables.remove(oldName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c712d2ccccade..3727bb3c101cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -365,10 +365,12 @@ class SessionCatalog( if (!ignoreIfExists) { throw new TableAlreadyExistsException(db = db, table = table) } - } else if (validateLocation) { - validateTableLocation(newTableDefinition) + } else { + if (validateLocation) { + validateTableLocation(newTableDefinition) + } + externalCatalog.createTable(newTableDefinition, ignoreIfExists) } - externalCatalog.createTable(newTableDefinition, ignoreIfExists) } def validateTableLocation(table: CatalogTable): Unit = { @@ -388,6 +390,8 @@ class SessionCatalog( private def makeQualifiedTablePath(locationUri: URI, database: String): URI = { if (locationUri.isAbsolute) { locationUri + } else if (new Path(locationUri).isAbsolute) { + makeQualifiedPath(locationUri) } else { val dbName = formatDatabaseName(database) val dbLocation = makeQualifiedDBPath(getDatabaseMetadata(dbName).locationUri) @@ -1553,18 +1557,24 @@ class SessionCatalog( /** * Look up the `ExpressionInfo` of the given function by name if it's a built-in or temp function. - * This supports both scalar and table functions. + * This only supports scalar functions. */ def lookupBuiltinOrTempFunction(name: String): Option[ExpressionInfo] = { FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT)).orElse { - def lookup(ident: FunctionIdentifier): Option[ExpressionInfo] = { - functionRegistry.lookupFunction(ident).orElse( - tableFunctionRegistry.lookupFunction(ident)) - } - synchronized(lookupTempFuncWithViewContext(name, isBuiltinFunction, lookup)) + synchronized(lookupTempFuncWithViewContext( + name, FunctionRegistry.builtin.functionExists, functionRegistry.lookupFunction)) } } + /** + * Look up the `ExpressionInfo` of the given function by name if it's a built-in or + * temp table function. + */ + def lookupBuiltinOrTempTableFunction(name: String): Option[ExpressionInfo] = synchronized { + lookupTempFuncWithViewContext( + name, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry.lookupFunction) + } + /** * Look up a built-in or temp scalar function by name and resolves it to an Expression if such * a function exists. 
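The new branch in `makeQualifiedTablePath` above can be illustrated with a small, standalone sketch (the example locations are invented here, not taken from the patch): a scheme-less location such as `/tmp/tbl` is not an absolute URI, but it is an absolute filesystem path, so it should be qualified against the default filesystem instead of being resolved relative to the database location.

import java.net.URI
import org.apache.hadoop.fs.Path

object QualifiedPathCheck {
  def main(args: Array[String]): Unit = {
    val withScheme = new URI("hdfs:///user/hive/warehouse/tbl")
    val noScheme = new URI("/tmp/tbl")

    assert(withScheme.isAbsolute)         // first branch: already a fully qualified URI
    assert(!noScheme.isAbsolute)          // no scheme, so URI.isAbsolute is false...
    assert(new Path(noScheme).isAbsolute) // ...but the new second branch still qualifies it
  }
}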
@@ -1709,15 +1719,16 @@ class SessionCatalog( */ def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized { if (name.database.isEmpty) { - lookupBuiltinOrTempFunction(name.funcName).getOrElse(lookupPersistentFunction(name)) + lookupBuiltinOrTempFunction(name.funcName) + .orElse(lookupBuiltinOrTempTableFunction(name.funcName)) + .getOrElse(lookupPersistentFunction(name)) } else { lookupPersistentFunction(name) } } - // Test only. The actual function lookup logic looks up temp/built-in function first, then - // persistent function from either v1 or v2 catalog. This method only look up v1 catalog and is - // no longer valid. + // The actual function lookup logic looks up temp/built-in function first, then persistent + // function from either v1 or v2 catalog. This method only look up v1 catalog. def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = { if (name.database.isEmpty) { resolveBuiltinOrTempFunction(name.funcName, children) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 70ccb06c109fc..4ab14c3156294 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI -import java.time.ZoneOffset +import java.time.{ZoneId, ZoneOffset} import java.util.Date import scala.collection.mutable @@ -656,10 +656,13 @@ object CatalogColumnStat extends Logging { val VERSION = 2 - private def getTimestampFormatter(isParsing: Boolean): TimestampFormatter = { + def getTimestampFormatter( + isParsing: Boolean, + format: String = "yyyy-MM-dd HH:mm:ss.SSSSSS", + zoneId: ZoneId = ZoneOffset.UTC): TimestampFormatter = { TimestampFormatter( - format = "yyyy-MM-dd HH:mm:ss.SSSSSS", - zoneId = ZoneOffset.UTC, + format = format, + zoneId = zoneId, isParsing = isParsing) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index 9d6582476b76b..5dd8c35e4c2e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -60,6 +60,7 @@ class UnivocityGenerator( legacyFormat = FAST_DATE_FORMAT, isParsing = false) + @scala.annotation.tailrec private def makeConverter(dataType: DataType): ValueConverter = dataType match { case DateType => (row: InternalRow, ordinal: Int) => dateFormatter.format(row.getInt(ordinal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 1c95ec8d1a573..62b3ee7440745 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -138,6 +138,14 @@ package object dsl { } } + def castNullable(): Expression = { + if (expr.resolved && expr.nullable) { + expr + } else { + KnownNullable(expr) + } + } + def asc: SortOrder = SortOrder(expr, Ascending) def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Seq.empty) def desc: SortOrder = SortOrder(expr, Descending) @@ -432,7 +440,7 @@ package object dsl { 
def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { val aliasedExprs = aggregateExprs.map { case ne: NamedExpression => ne - case e => Alias(e, e.toString)() + case e => UnresolvedAlias(e) } Aggregate(groupingExprs, aliasedExprs, logicalPlan) } @@ -490,6 +498,9 @@ package object dsl { def distribute(exprs: Expression*)(n: Int): LogicalPlan = RepartitionByExpression(exprs, logicalPlan, numPartitions = n) + def rebalance(exprs: Expression*): LogicalPlan = + RebalancePartitions(exprs, logicalPlan) + def analyze: LogicalPlan = { val analyzed = analysis.SimpleAnalyzer.execute(logicalPlan) analysis.SimpleAnalyzer.checkAnalysis(analyzed) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d34d9531c3f34..d7e497fafa86a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -66,23 +66,27 @@ import org.apache.spark.sql.types._ * }}} */ object RowEncoder { - def apply(schema: StructType): ExpressionEncoder[Row] = { + def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val serializer = serializerFor(inputObject, schema) + val serializer = serializerFor(inputObject, schema, lenient) val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema) new ExpressionEncoder[Row]( serializer, deserializer, ClassTag(cls)) } + def apply(schema: StructType): ExpressionEncoder[Row] = { + apply(schema, lenient = false) + } private def serializerFor( inputObject: Expression, - inputType: DataType): Expression = inputType match { + inputType: DataType, + lenient: Boolean): Expression = inputType match { case dt if ScalaReflection.isNativeType(dt) => inputObject - case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType) + case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType, lenient) case udt: UserDefinedType[_] => val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) @@ -100,7 +104,9 @@ object RowEncoder { Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) case TimestampType => - if (SQLConf.get.datetimeJava8ApiEnabled) { + if (lenient) { + createSerializerForAnyTimestamp(inputObject) + } else if (SQLConf.get.datetimeJava8ApiEnabled) { createSerializerForJavaInstant(inputObject) } else { createSerializerForSqlTimestamp(inputObject) @@ -109,7 +115,9 @@ object RowEncoder { case TimestampNTZType => createSerializerForLocalDateTime(inputObject) case DateType => - if (SQLConf.get.datetimeJava8ApiEnabled) { + if (lenient) { + createSerializerForAnyDate(inputObject) + } else if (SQLConf.get.datetimeJava8ApiEnabled) { createSerializerForJavaLocalDate(inputObject) } else { createSerializerForSqlDate(inputObject) @@ -144,7 +152,7 @@ object RowEncoder { inputObject, ObjectType(classOf[Object]), element => { - val value = serializerFor(ValidateExternalType(element, et), et) + val value = serializerFor(ValidateExternalType(element, et, lenient), et, lenient) expressionWithNullSafety(value, containsNull, WalkedTypePath()) }) } @@ -156,7 +164,7 @@ object RowEncoder { returnNullable = false), "toSeq", ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) - val convertedKeys = serializerFor(keys, 
ArrayType(kt, false)) + val convertedKeys = serializerFor(keys, ArrayType(kt, false), lenient) val values = Invoke( @@ -164,7 +172,7 @@ object RowEncoder { returnNullable = false), "toSeq", ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) - val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) + val convertedValues = serializerFor(values, ArrayType(vt, valueNullable), lenient) val nonNullOutput = NewInstance( classOf[ArrayBasedMapData], @@ -183,8 +191,10 @@ object RowEncoder { val fieldValue = serializerFor( ValidateExternalType( GetExternalRowField(inputObject, index, field.name), - field.dataType), - field.dataType) + field.dataType, + lenient), + field.dataType, + lenient) val convertedField = if (field.nullable) { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), @@ -214,12 +224,13 @@ object RowEncoder { * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or * `org.apache.spark.sql.types.Decimal`. */ - def externalDataTypeForInput(dt: DataType): DataType = dt match { + def externalDataTypeForInput(dt: DataType, lenient: Boolean): DataType = dt match { // In order to support both Decimal and java/scala BigDecimal in external row, we make this // as java.lang.Object. case _: DecimalType => ObjectType(classOf[java.lang.Object]) // In order to support both Array and Seq in external row, we make this as java.lang.Object. case _: ArrayType => ObjectType(classOf[java.lang.Object]) + case _: DateType | _: TimestampType if lenient => ObjectType(classOf[java.lang.Object]) case _ => externalDataTypeFor(dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala index b33b9ed57f112..da4000f53e3e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala @@ -31,6 +31,8 @@ case class ApplyFunctionExpression( override def name: String = function.name() override def dataType: DataType = function.resultType() override def inputTypes: Seq[AbstractDataType] = function.inputTypes().toSeq + override lazy val deterministic: Boolean = function.isDeterministic && + children.forall(_.deterministic) private lazy val reusedRow = new SpecificInternalRow(function.inputTypes()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index e5fa433b78d64..39463ed122b6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,7 +23,7 @@ import java.util.concurrent.TimeUnit._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} -import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability} +import org.apache.spark.sql.catalyst.expressions.Cast.resolvableNullability import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -293,10 +293,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit */ def typeCheckFailureMessage: String - 
override def toString: String = { - val ansi = if (ansiEnabled) "ansi_" else "" - s"${ansi}cast($child as ${dataType.simpleString})" - } + override def toString: String = s"cast($child as ${dataType.simpleString})" override def checkInputDataTypes(): TypeCheckResult = { if (canCast(child.dataType, dataType)) { @@ -671,7 +668,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit if (longValue == longValue.toInt) { longValue.toInt } else { - throw QueryExecutionErrors.castingCauseOverflowError(t, IntegerType.catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(t, IntegerType) } }) case TimestampType => @@ -707,7 +704,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit if (longValue == longValue.toShort) { longValue.toShort } else { - throw QueryExecutionErrors.castingCauseOverflowError(t, ShortType.catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(t, ShortType) } }) case TimestampType => @@ -718,12 +715,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b) } catch { case _: ArithmeticException => - throw QueryExecutionErrors.castingCauseOverflowError(b, ShortType.catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(b, ShortType) } if (intValue == intValue.toShort) { intValue.toShort } else { - throw QueryExecutionErrors.castingCauseOverflowError(b, ShortType.catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(b, ShortType) } case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort @@ -754,7 +751,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit if (longValue == longValue.toByte) { longValue.toByte } else { - throw QueryExecutionErrors.castingCauseOverflowError(t, ByteType.catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(t, ByteType) } }) case TimestampType => @@ -765,12 +762,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b) } catch { case _: ArithmeticException => - throw QueryExecutionErrors.castingCauseOverflowError(b, ByteType.catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(b, ByteType) } if (intValue == intValue.toByte) { intValue.toByte } else { - throw QueryExecutionErrors.castingCauseOverflowError(b, ByteType.catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(b, ByteType) } case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte @@ -1639,20 +1636,21 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castTimestampToIntegralTypeCode( ctx: CodegenContext, integralType: String, - catalogType: String): CastFunction = { + dataType: DataType): CastFunction = { if (ansiEnabled) { val longValue = ctx.freshName("longValue") - (c, evPrim, evNull) => + val dt = ctx.addReferenceObj("dataType", dataType, dataType.getClass.getName) + (c, evPrim, _) => code""" long $longValue = ${timestampToLongCode(c)}; if ($longValue == ($integralType) $longValue) { $evPrim = ($integralType) $longValue; } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, "$catalogType"); + throw QueryExecutionErrors.castingCauseOverflowError($c, $dt); } """ } else { - (c, evPrim, evNull) => code"$evPrim = ($integralType) ${timestampToLongCode(c)};" + (c, evPrim, _) => code"$evPrim = ($integralType) 
${timestampToLongCode(c)};" } } @@ -1678,31 +1676,31 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit """ } - private[this] def castDecimalToIntegralTypeCode( - ctx: CodegenContext, - integralType: String, - catalogType: String): CastFunction = { + private[this] def castDecimalToIntegralTypeCode(integralType: String): CastFunction = { if (ansiEnabled) { - (c, evPrim, evNull) => code"$evPrim = $c.roundTo${integralType.capitalize}();" + (c, evPrim, _) => code"$evPrim = $c.roundTo${integralType.capitalize}();" } else { - (c, evPrim, evNull) => code"$evPrim = $c.to${integralType.capitalize}();" + (c, evPrim, _) => code"$evPrim = $c.to${integralType.capitalize}();" } } private[this] def castIntegralTypeToIntegralTypeExactCode( + ctx: CodegenContext, integralType: String, - catalogType: String): CastFunction = { + dataType: DataType): CastFunction = { assert(ansiEnabled) - (c, evPrim, evNull) => + val dt = ctx.addReferenceObj("dataType", dataType, dataType.getClass.getName) + (c, evPrim, _) => code""" if ($c == ($integralType) $c) { $evPrim = ($integralType) $c; } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, "$catalogType"); + throw QueryExecutionErrors.castingCauseOverflowError($c, $dt); } """ } + private[this] def lowerAndUpperBound(integralType: String): (String, String) = { val (min, max, typeIndicator) = integralType.toLowerCase(Locale.ROOT) match { case "long" => (Long.MinValue, Long.MaxValue, "L") @@ -1714,22 +1712,24 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } private[this] def castFractionToIntegralTypeCode( + ctx: CodegenContext, integralType: String, - catalogType: String): CastFunction = { + dataType: DataType): CastFunction = { assert(ansiEnabled) val (min, max) = lowerAndUpperBound(integralType) val mathClass = classOf[Math].getName + val dt = ctx.addReferenceObj("dataType", dataType, dataType.getClass.getName) // When casting floating values to integral types, Spark uses the method `Numeric.toInt` // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to `Math.floor`; // for negative floating values, it is equivalent to `Math.ceil`. // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound` // to check if the floating value x is in the range of an integral type after rounding. - (c, evPrim, evNull) => + (c, evPrim, _) => code""" if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) { $evPrim = ($integralType) $c; } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, "$catalogType"); + throw QueryExecutionErrors.castingCauseOverflowError($c, $dt); } """ } @@ -1754,12 +1754,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = $c ? 
(byte) 1 : (byte) 0;" case DateType => (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => castTimestampToIntegralTypeCode(ctx, "byte", ByteType.catalogString) - case DecimalType() => castDecimalToIntegralTypeCode(ctx, "byte", ByteType.catalogString) + case TimestampType => castTimestampToIntegralTypeCode(ctx, "byte", ByteType) + case DecimalType() => castDecimalToIntegralTypeCode("byte") case ShortType | IntegerType | LongType if ansiEnabled => - castIntegralTypeToIntegralTypeExactCode("byte", ByteType.catalogString) + castIntegralTypeToIntegralTypeExactCode(ctx, "byte", ByteType) case FloatType | DoubleType if ansiEnabled => - castFractionToIntegralTypeCode("byte", ByteType.catalogString) + castFractionToIntegralTypeCode(ctx, "byte", ByteType) case x: NumericType => (c, evPrim, evNull) => code"$evPrim = (byte) $c;" case x: DayTimeIntervalType => @@ -1790,12 +1790,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = $c ? (short) 1 : (short) 0;" case DateType => (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => castTimestampToIntegralTypeCode(ctx, "short", ShortType.catalogString) - case DecimalType() => castDecimalToIntegralTypeCode(ctx, "short", ShortType.catalogString) + case TimestampType => castTimestampToIntegralTypeCode(ctx, "short", ShortType) + case DecimalType() => castDecimalToIntegralTypeCode("short") case IntegerType | LongType if ansiEnabled => - castIntegralTypeToIntegralTypeExactCode("short", ShortType.catalogString) + castIntegralTypeToIntegralTypeExactCode(ctx, "short", ShortType) case FloatType | DoubleType if ansiEnabled => - castFractionToIntegralTypeCode("short", ShortType.catalogString) + castFractionToIntegralTypeCode(ctx, "short", ShortType) case x: NumericType => (c, evPrim, evNull) => code"$evPrim = (short) $c;" case x: DayTimeIntervalType => @@ -1824,12 +1824,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = $c ? 
1 : 0;" case DateType => (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => castTimestampToIntegralTypeCode(ctx, "int", IntegerType.catalogString) - case DecimalType() => castDecimalToIntegralTypeCode(ctx, "int", IntegerType.catalogString) + case TimestampType => castTimestampToIntegralTypeCode(ctx, "int", IntegerType) + case DecimalType() => castDecimalToIntegralTypeCode("int") case LongType if ansiEnabled => - castIntegralTypeToIntegralTypeExactCode("int", IntegerType.catalogString) + castIntegralTypeToIntegralTypeExactCode(ctx, "int", IntegerType) case FloatType | DoubleType if ansiEnabled => - castFractionToIntegralTypeCode("int", IntegerType.catalogString) + castFractionToIntegralTypeCode(ctx, "int", IntegerType) case x: NumericType => (c, evPrim, evNull) => code"$evPrim = (int) $c;" case x: DayTimeIntervalType => @@ -1860,9 +1860,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => (c, evPrim, evNull) => code"$evPrim = (long) ${timestampToLongCode(c)};" - case DecimalType() => castDecimalToIntegralTypeCode(ctx, "long", LongType.catalogString) + case DecimalType() => castDecimalToIntegralTypeCode("long") case FloatType | DoubleType if ansiEnabled => - castFractionToIntegralTypeCode("long", LongType.catalogString) + castFractionToIntegralTypeCode(ctx, "long", LongType) case x: NumericType => (c, evPrim, evNull) => code"$evPrim = (long) $c;" case x: DayTimeIntervalType => @@ -2226,23 +2226,17 @@ object AnsiCast { case (TimestampType, _: NumericType) => true case (ArrayType(fromType, fn), ArrayType(toType, tn)) => - canCast(fromType, toType) && - resolvableNullability(fn || forceNullable(fromType, toType), tn) + canCast(fromType, toType) && resolvableNullability(fn, tn) case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - canCast(fromKey, toKey) && - (!forceNullable(fromKey, toKey)) && - canCast(fromValue, toValue) && - resolvableNullability(fn || forceNullable(fromValue, toValue), tn) + canCast(fromKey, toKey) && canCast(fromValue, toValue) && resolvableNullability(fn, tn) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && fromFields.zip(toFields).forall { case (fromField, toField) => canCast(fromField.dataType, toField.dataType) && - resolvableNullability( - fromField.nullable || forceNullable(fromField.dataType, toField.dataType), - toField.nullable) + resolvableNullability(fromField.nullable, toField.nullable) } case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 59e2be4a6f5aa..903a6fd7bd014 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -194,7 +194,7 @@ class EquivalentExpressions { expr.isInstanceOf[LeafExpression] || // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. - expr.find(_.isInstanceOf[LambdaVariable]).isDefined || + expr.exists(_.isInstanceOf[LambdaVariable]) || // `PlanExpression` wraps query plan. 
To compare query plans of `PlanExpression` on executor, // can cause error like NPE. (expr.isInstanceOf[PlanExpression[_]] && Utils.isInRunningSparkTask) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1d54efd7319e3..b5695e8c87268 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} -import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike} @@ -315,7 +315,7 @@ abstract class Expression extends TreeNode[Expression] { } override def simpleStringWithNodeId(): String = { - throw QueryExecutionErrors.simpleStringWithNodeIdUnsupportedError(nodeName) + throw new IllegalStateException(s"$nodeName does not implement simpleStringWithNodeId") } protected def typeSuffix = @@ -352,65 +352,56 @@ trait Unevaluable extends Expression { * An expression that gets replaced at runtime (currently by the optimizer) into a different * expression for evaluation. This is mainly used to provide compatibility with other databases. * For example, we use this to support "nvl" by replacing it with "coalesce". - * - * A RuntimeReplaceable should have the original parameters along with a "child" expression in the - * case class constructor, and define a normal constructor that accepts only the original - * parameters. For an example, see [[Nvl]]. To make sure the explain plan and expression SQL - * works correctly, the implementation should also override flatArguments method and sql method. */ -trait RuntimeReplaceable extends UnaryExpression with Unevaluable { - override def nullable: Boolean = child.nullable - override def dataType: DataType = child.dataType +trait RuntimeReplaceable extends Expression { + def replacement: Expression + + override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE) + override def nullable: Boolean = replacement.nullable + override def dataType: DataType = replacement.dataType // As this expression gets replaced at optimization with its `child" expression, // two `RuntimeReplaceable` are considered to be semantically equal if their "child" expressions // are semantically equal. - override lazy val preCanonicalized: Expression = child.preCanonicalized + override lazy val preCanonicalized: Expression = replacement.preCanonicalized - /** - * Only used to generate SQL representation of this expression. 
- * - * Implementations should override this with original parameters - */ - def exprsReplaced: Seq[Expression] - - override def sql: String = mkString(exprsReplaced.map(_.sql)) - - final override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE) + final override def eval(input: InternalRow = null): Any = + throw QueryExecutionErrors.cannotEvaluateExpressionError(this) + final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) +} - def mkString(childrenString: Seq[String]): String = { - prettyName + childrenString.mkString("(", ", ", ")") +/** + * An add-on of [[RuntimeReplaceable]]. It makes `replacement` the child of the expression, to + * inherit the analysis rules for it, such as type coercion. The implementation should put + * `replacement` in the case class constructor, and define a normal constructor that accepts only + * the original parameters. For an example, see [[TryAdd]]. To make sure the explain plan and + * expression SQL works correctly, the implementation should also implement the `parameters` method. + */ +trait InheritAnalysisRules extends UnaryLike[Expression] { self: RuntimeReplaceable => + override def child: Expression = replacement + def parameters: Seq[Expression] + override def flatArguments: Iterator[Any] = parameters.iterator + // This method is used to generate a SQL string with transformed inputs. This is necessary as + // the actual inputs are not the children of this expression. + def makeSQLString(childrenSQL: Seq[String]): String = { + prettyName + childrenSQL.mkString("(", ", ", ")") } + final override def sql: String = makeSQLString(parameters.map(_.sql)) } /** - * An aggregate expression that gets rewritten (currently by the optimizer) into a + * An add-on of [[AggregateFunction]]. This gets rewritten (currently by the optimizer) into a * different aggregate expression for evaluation. This is mainly used to provide compatibility * with other databases. For example, we use this to support every, any/some aggregates by rewriting * them with Min and Max respectively. 
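A minimal sketch of the InheritAnalysisRules pattern described above, modeled on TryAdd/TryDivide in this file. `TryNegate` is a hypothetical expression invented here for illustration only (it is not part of the patch); it assumes the catalyst expressions already present in this module (TryEval, UnaryMinus).

case class TryNegate(input: Expression, replacement: Expression)
  extends RuntimeReplaceable with InheritAnalysisRules {

  // The public constructor takes only the original parameter; `replacement` is the
  // ANSI-strict expression wrapped in TryEval and becomes the child that analysis
  // rules such as type coercion actually operate on.
  def this(input: Expression) =
    this(input, TryEval(UnaryMinus(input, failOnError = true)))

  override def prettyName: String = "try_negate"

  // `parameters` feeds flatArguments and the generated SQL/explain output, because
  // the real child (`replacement`) is not the original argument list.
  override def parameters: Seq[Expression] = Seq(input)

  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(replacement = newChild)
}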
*/ -trait UnevaluableAggregate extends DeclarativeAggregate { - - override def nullable: Boolean = true - - override lazy val aggBufferAttributes = - throw QueryExecutionErrors.evaluateUnevaluableAggregateUnsupportedError( - "aggBufferAttributes", this) - - override lazy val initialValues: Seq[Expression] = - throw QueryExecutionErrors.evaluateUnevaluableAggregateUnsupportedError( - "initialValues", this) - - override lazy val updateExpressions: Seq[Expression] = - throw QueryExecutionErrors.evaluateUnevaluableAggregateUnsupportedError( - "updateExpressions", this) - - override lazy val mergeExpressions: Seq[Expression] = - throw QueryExecutionErrors.evaluateUnevaluableAggregateUnsupportedError( - "mergeExpressions", this) - - override lazy val evaluateExpression: Expression = - throw QueryExecutionErrors.evaluateUnevaluableAggregateUnsupportedError( - "evaluateExpression", this) +trait RuntimeReplaceableAggregate extends RuntimeReplaceable { self: AggregateFunction => + override def aggBufferSchema: StructType = throw new IllegalStateException( + "RuntimeReplaceableAggregate.aggBufferSchema should not be called") + override def aggBufferAttributes: Seq[AttributeReference] = throw new IllegalStateException( + "RuntimeReplaceableAggregate.aggBufferAttributes should not be called") + override def inputAggBufferAttributes: Seq[AttributeReference] = throw new IllegalStateException( + "RuntimeReplaceableAggregate.inputAggBufferAttributes should not be called") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index d02d1e8b55b9d..731ad16cc7d9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{UserDefinedType, _} import org.apache.spark.unsafe.Platform @@ -254,7 +253,8 @@ object InterpretedUnsafeProjection { (_, _) => {} case _ => - throw QueryExecutionErrors.dataTypeUnsupportedError(dt) + throw new IllegalStateException(s"The data type '${dt.typeName}' is not supported in " + + "generating a writer function for a struct field, array element, map key or map value.") } // Always wrap the writer with a null safe version. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala index 2a182b6424db2..fd5b2db61f31e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala @@ -33,8 +33,8 @@ object SchemaPruning extends SQLConfHelper { * 1. The schema field ordering at original schema is still preserved in pruned schema. * 2. The top-level fields are not pruned here. 
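A minimal sketch of how an aggregate would use the RuntimeReplaceableAggregate trait defined above (illustrative only, not part of the patch). `CountFalse` is an invented name; the shape mirrors the CountIf and BoolAnd/BoolOr rewrites later in this diff.

case class CountFalse(child: Expression)
  extends AggregateFunction
  with RuntimeReplaceableAggregate
  with ImplicitCastInputTypes
  with UnaryLike[Expression] {

  // The optimizer swaps this node for `replacement`. NullIf maps true to null and
  // Count skips nulls, so only rows where `child` is false are counted.
  override lazy val replacement: Expression = Count(new NullIf(child, Literal.TrueLiteral))

  override def nodeName: String = "count_false"
  override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)

  override protected def withNewChildInternal(newChild: Expression): CountFalse =
    copy(child = newChild)
}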
*/ - def pruneDataSchema( - dataSchema: StructType, + def pruneSchema( + schema: StructType, requestedRootFields: Seq[RootField]): StructType = { val resolver = conf.resolver // Merge the requested root fields into a single schema. Note the ordering of the fields @@ -44,10 +44,10 @@ object SchemaPruning extends SQLConfHelper { .map { root: RootField => StructType(Array(root.field)) } .reduceLeft(_ merge _) val mergedDataSchema = - StructType(dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d))) + StructType(schema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d))) // Sort the fields of mergedDataSchema according to their order in dataSchema, // recursively. This makes mergedDataSchema a pruned schema of dataSchema - sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType] + sortLeftFieldsByRight(mergedDataSchema, schema).asInstanceOf[StructType] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 8e6f07611bfe8..974d4b5f86889 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -132,7 +132,8 @@ object SortOrder { case class SortPrefix(child: SortOrder) extends UnaryExpression { val nullValue = child.child.dataType match { - case BooleanType | DateType | TimestampType | _: IntegralType | _: AnsiIntervalType => + case BooleanType | DateType | TimestampType | TimestampNTZType | + _: IntegralType | _: AnsiIntervalType => if (nullAsSmallest) Long.MinValue else Long.MaxValue case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => if (nullAsSmallest) Long.MinValue else Long.MaxValue @@ -154,7 +155,8 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { private lazy val calcPrefix: Any => Long = child.child.dataType match { case BooleanType => (raw) => if (raw.asInstanceOf[Boolean]) 1 else 0 - case DateType | TimestampType | _: IntegralType | _: AnsiIntervalType => (raw) => + case DateType | TimestampType | TimestampNTZType | + _: IntegralType | _: AnsiIntervalType => (raw) => raw.asInstanceOf[java.lang.Number].longValue() case FloatType | DoubleType => (raw) => { val dVal = raw.asInstanceOf[java.lang.Number].doubleValue() @@ -198,7 +200,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { s"$input ? 
1L : 0L" case _: IntegralType => s"(long) $input" - case DateType | TimestampType | _: AnsiIntervalType => + case DateType | TimestampType | TimestampNTZType | _: AnsiIntervalType => s"(long) $input" case FloatType | DoubleType => s"$DoublePrefixCmp.computePrefix((double)$input)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala index 0f63de1bf7e45..f43a80bf997a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} /** * A special version of [[AnsiCast]]. It performs the same operation (i.e. converts a value of @@ -56,7 +57,32 @@ case class TryCast(child: Expression, dataType: DataType, timeZoneId: Option[Str override def nullable: Boolean = true - override def canCast(from: DataType, to: DataType): Boolean = AnsiCast.canCast(from, to) + // If the target data type is a complex type which can't have Null values, we should guarantee + // that the casting between the element types won't produce Null results. + override def canCast(from: DataType, to: DataType): Boolean = (from, to) match { + case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + canCast(fromType, toType) && + resolvableNullability(fn || forceNullable(fromType, toType), tn) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + canCast(fromKey, toKey) && + (!forceNullable(fromKey, toKey)) && + canCast(fromValue, toValue) && + resolvableNullability(fn || forceNullable(fromValue, toValue), tn) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { + case (fromField, toField) => + canCast(fromField.dataType, toField.dataType) && + resolvableNullability( + fromField.nullable || forceNullable(fromField.dataType, toField.dataType), + toField.nullable) + } + + case _ => + AnsiCast.canCast(from, to) + } override def cast(from: DataType, to: DataType): Any => Any = (input: Any) => try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala index bc2604a3447ed..7a8a689a1bd3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala @@ -75,19 +75,17 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntoleran since = "3.2.0", group = "math_funcs") // scalastyle:on line.size.limit -case class TryAdd(left: Expression, right: Expression, child: Expression) - extends RuntimeReplaceable { +case class TryAdd(left: Expression, right: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { def this(left: Expression, right: Expression) = this(left, right, TryEval(Add(left, right, failOnError = true))) - override def flatArguments: Iterator[Any] = Iterator(left, 
right) - - override def exprsReplaced: Seq[Expression] = Seq(left, right) - override def prettyName: String = "try_add" + override def parameters: Seq[Expression] = Seq(left, right) + override protected def withNewChildInternal(newChild: Expression): Expression = - this.copy(child = newChild) + this.copy(replacement = newChild) } // scalastyle:off line.size.limit @@ -110,17 +108,76 @@ case class TryAdd(left: Expression, right: Expression, child: Expression) since = "3.2.0", group = "math_funcs") // scalastyle:on line.size.limit -case class TryDivide(left: Expression, right: Expression, child: Expression) - extends RuntimeReplaceable { +case class TryDivide(left: Expression, right: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { def this(left: Expression, right: Expression) = this(left, right, TryEval(Divide(left, right, failOnError = true))) - override def flatArguments: Iterator[Any] = Iterator(left, right) + override def prettyName: String = "try_divide" + + override def parameters: Seq[Expression] = Seq(left, right) - override def exprsReplaced: Seq[Expression] = Seq(left, right) + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(replacement = newChild) + } +} - override def prettyName: String = "try_divide" +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Returns `expr1`-`expr2` and the result is null on overflow. " + + "The acceptable input types are the same with the `-` operator.", + examples = """ + Examples: + > SELECT _FUNC_(2, 1); + 1 + > SELECT _FUNC_(-2147483648, 1); + NULL + > SELECT _FUNC_(date'2021-01-02', 1); + 2021-01-01 + > SELECT _FUNC_(date'2021-01-01', interval 1 year); + 2020-01-01 + > SELECT _FUNC_(timestamp'2021-01-02 00:00:00', interval 1 day); + 2021-01-01 00:00:00 + > SELECT _FUNC_(interval 2 year, interval 1 year); + 1-0 + """, + since = "3.3.0", + group = "math_funcs") +case class TrySubtract(left: Expression, right: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { + def this(left: Expression, right: Expression) = + this(left, right, TryEval(Subtract(left, right, failOnError = true))) + + override def prettyName: String = "try_subtract" + + override def parameters: Seq[Expression] = Seq(left, right) + + override protected def withNewChildInternal(newChild: Expression): Expression = + this.copy(replacement = newChild) +} + +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Returns `expr1`*`expr2` and the result is null on overflow. 
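The try_subtract and try_multiply expressions added here follow the same recipe as try_add and try_divide above: build the ANSI-strict operator (failOnError = true) and wrap it in TryEval, which returns null instead of raising the ANSI error. A small usage sketch (not part of the patch), assuming a SparkSession named `spark` and the function registrations added elsewhere in this change:

// Both queries would fail with an arithmetic overflow error under plain ANSI
// arithmetic; the try_* variants return NULL instead.
spark.sql("SELECT try_subtract(-2147483648, 1)").show()   // NULL
spark.sql("SELECT try_multiply(-2147483648, 10)").show()  // NULL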
" + + "The acceptable input types are the same with the `*` operator.", + examples = """ + Examples: + > SELECT _FUNC_(2, 3); + 6 + > SELECT _FUNC_(-2147483648, 10); + NULL + > SELECT _FUNC_(interval 2 year, 3); + 6-0 + """, + since = "3.3.0", + group = "math_funcs") +case class TryMultiply(left: Expression, right: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { + def this(left: Expression, right: Expression) = + this(left, right, TryEval(Multiply(left, right, failOnError = true))) + + override def prettyName: String = "try_multiply" + + override def parameters: Seq[Expression] = Seq(left, right) override protected def withNewChildInternal(newChild: Expression): Expression = - this.copy(child = newChild) + this.copy(replacement = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 9714a096a69a2..05f7edaeb5d48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -69,7 +69,7 @@ case class Average( case _ => DoubleType } - private lazy val sumDataType = child.dataType match { + lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _: YearMonthIntervalType => YearMonthIntervalType() case _: DayTimeIntervalType => DayTimeIntervalType() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala index 66800b277ffed..248ade05ab1d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate} -import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT_IF, TreePattern} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, Literal, NullIf, RuntimeReplaceableAggregate} import org.apache.spark.sql.catalyst.trees.UnaryLike -import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, LongType} +import org.apache.spark.sql.types.{AbstractDataType, BooleanType} @ExpressionDescription( usage = """ @@ -36,30 +34,14 @@ import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, Long """, group = "agg_funcs", since = "3.0.0") -case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes +case class CountIf(child: Expression) + extends AggregateFunction + with RuntimeReplaceableAggregate + with ImplicitCastInputTypes with UnaryLike[Expression] { - - override def prettyName: String = "count_if" - - override def child: Expression = predicate - - override def nullable: Boolean = false - - override def dataType: DataType = LongType - + override lazy val replacement: Expression = Count(new NullIf(child, Literal.FalseLiteral)) + override def nodeName: String = "count_if" override def inputTypes: 
Seq[AbstractDataType] = Seq(BooleanType) - - final override val nodePatterns: Seq[TreePattern] = Seq(COUNT_IF) - - override def checkInputDataTypes(): TypeCheckResult = predicate.dataType match { - case BooleanType => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure( - s"function $prettyName requires boolean type, not ${predicate.dataType.catalogString}" - ) - } - override protected def withNewChildInternal(newChild: Expression): CountIf = - copy(predicate = newChild) + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala index 09408e6eff18a..23609faad9a76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, DateType, DayTimeIntervalType, DoubleType, IntegerType, NumericType, StructField, StructType, TimestampNTZType, TimestampType, TypeCollection, YearMonthIntervalType} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ import org.apache.spark.sql.util.NumericHistogram /** @@ -46,12 +47,13 @@ import org.apache.spark.sql.util.NumericHistogram smaller datasets. Note that this function creates a histogram with non-uniform bin widths. It offers no guarantees in terms of the mean-squared-error of the histogram, but in practice is comparable to the histograms produced by the R/S-Plus - statistical computing packages. + statistical computing packages. Note: the output type of the 'x' field in the return value is + propagated from the input value consumed in the aggregate function. """, examples = """ Examples: > SELECT _FUNC_(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col); - [{"x":0.0,"y":1.0},{"x":1.0,"y":1.0},{"x":2.0,"y":1.0},{"x":10.0,"y":1.0}] + [{"x":0,"y":1.0},{"x":1,"y":1.0},{"x":2,"y":1.0},{"x":10,"y":1.0}] """, group = "agg_funcs", since = "3.3.0") @@ -72,6 +74,8 @@ case class HistogramNumeric( case n: Int => n } + private lazy val propagateInputType: Boolean = SQLConf.get.histogramNumericPropagateInputType + override def inputTypes: Seq[AbstractDataType] = { // Support NumericType, DateType, TimestampType and TimestampNTZType, YearMonthIntervalType, // DayTimeIntervalType since their internal types are all numeric, @@ -124,8 +128,33 @@ case class HistogramNumeric( null } else { val result = (0 until buffer.getUsedBins).map { index => + // Note that the 'coord.x' and 'coord.y' have double-precision floating point type here. val coord = buffer.getBin(index) - InternalRow.apply(coord.x, coord.y) + if (propagateInputType) { + // If the SQLConf.spark.sql.legacy.histogramNumericPropagateInputType is set to true, + // we need to internally convert the 'coord.x' value to the expected result type, for + // cases like integer types, timestamps, and intervals which are valid inputs to the + // numeric histogram aggregate function. 
For example, in this case: + // 'SELECT histogram_numeric(val, 3) FROM VALUES (0L), (1L), (2L), (10L) AS tab(col)' + // returns an array of structs where the first field has LongType. + val result: Any = left.dataType match { + case ByteType => coord.x.toByte + case IntegerType | DateType | _: YearMonthIntervalType => + coord.x.toInt + case FloatType => coord.x.toFloat + case ShortType => coord.x.toShort + case _: DayTimeIntervalType | LongType | TimestampType | TimestampNTZType => + coord.x.toLong + case _ => coord.x + } + InternalRow.apply(result, coord.y) + } else { + // Otherwise, just apply the double-precision values in 'coord.x' and 'coord.y' to the + // output row directly. In this case: 'SELECT histogram_numeric(val, 3) + // FROM VALUES (0L), (1L), (2L), (10L) AS tab(col)' returns an array of structs where the + // first field has DoubleType. + InternalRow.apply(coord.x, coord.y) + } } new GenericArrayData(result) } @@ -157,10 +186,17 @@ case class HistogramNumeric( override def nullable: Boolean = true - override def dataType: DataType = + override def dataType: DataType = { + // If the SQLConf.spark.sql.legacy.histogramNumericPropagateInputType is set to true, + // the output data type of this aggregate function is an array of structs, where each struct + // has two fields (x, y): one of the same data type as the left child and another of double + // type. Otherwise, the 'x' field always has double type. ArrayType(new StructType(Array( - StructField("x", DoubleType, true), + StructField(name = "x", + dataType = if (propagateInputType) left.dataType else DoubleType, + nullable = true), StructField("y", DoubleType, true))), true) + } override def prettyName: String = "histogram_numeric" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 7d3dd0ae1c52e..a98585e0ff1e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -198,14 +198,12 @@ case class Percentile( return Seq.empty } - val ordering = - if (child.dataType.isInstanceOf[NumericType]) { - child.dataType.asInstanceOf[NumericType].ordering - } else if (child.dataType.isInstanceOf[YearMonthIntervalType]) { - child.dataType.asInstanceOf[YearMonthIntervalType].ordering - } else if (child.dataType.isInstanceOf[DayTimeIntervalType]) { - child.dataType.asInstanceOf[DayTimeIntervalType].ordering - } + val ordering = child.dataType match { + case numericType: NumericType => numericType.ordering + case intervalType: YearMonthIntervalType => intervalType.ordering + case intervalType: DayTimeIntervalType => intervalType.ordering + case otherType => QueryExecutionErrors.unsupportedTypeError(otherType) + } val sortedCounts = buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]]) val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) { case ((key1, count1), (key2, count2)) => (key2, count1 + count2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/RegrCount.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/RegrCount.scala deleted file mode 100644 index 57dbc14a1702d..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/RegrCount.scala +++ /dev/null @@ -1,56 
+0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate} -import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.catalyst.trees.TreePattern.{REGR_COUNT, TreePattern} -import org.apache.spark.sql.types.{AbstractDataType, DataType, LongType, NumericType} - -@ExpressionDescription( - usage = """ - _FUNC_(expr) - Returns the number of non-null number pairs in a group. - """, - examples = """ - Examples: - > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x); - 4 - > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (2, 3), (2, 4) AS tab(y, x); - 3 - > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (null, 3), (2, 4) AS tab(y, x); - 2 - """, - group = "agg_funcs", - since = "3.3.0") -case class RegrCount(left: Expression, right: Expression) - extends UnevaluableAggregate with ImplicitCastInputTypes with BinaryLike[Expression] { - - override def prettyName: String = "regr_count" - - override def nullable: Boolean = false - - override def dataType: DataType = LongType - - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType) - - final override val nodePatterns: Seq[TreePattern] = Seq(REGR_COUNT) - - override protected def withNewChildrenInternal( - newLeft: Expression, newRight: Expression): RegrCount = - this.copy(left = newLeft, right = newRight) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/boolAggregates.scala similarity index 63% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/boolAggregates.scala index 244e9d9755752..ae759abf8a4f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/boolAggregates.scala @@ -17,33 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{BOOL_AGG, TreePattern} import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.types._ -abstract class UnevaluableBooleanAggBase(arg: Expression) - extends UnevaluableAggregate with ImplicitCastInputTypes with 
UnaryLike[Expression] { - - override def child: Expression = arg - - override def dataType: DataType = BooleanType - - override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) - - final override val nodePatterns: Seq[TreePattern] = Seq(BOOL_AGG) - - override def checkInputDataTypes(): TypeCheckResult = { - arg.dataType match { - case dt if dt != BooleanType => - TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + - s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].") - case _ => TypeCheckResult.TypeCheckSuccess - } - } -} - @ExpressionDescription( usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.", examples = """ @@ -57,10 +34,13 @@ abstract class UnevaluableBooleanAggBase(arg: Expression) """, group = "agg_funcs", since = "3.0.0") -case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { - override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_and") +case class BoolAnd(child: Expression) extends AggregateFunction with RuntimeReplaceableAggregate + with ImplicitCastInputTypes with UnaryLike[Expression] { + override lazy val replacement: Expression = Min(child) + override def nodeName: String = "bool_and" + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) override protected def withNewChildInternal(newChild: Expression): Expression = - copy(arg = newChild) + copy(child = newChild) } @ExpressionDescription( @@ -76,8 +56,11 @@ case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { """, group = "agg_funcs", since = "3.0.0") -case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) { - override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_or") +case class BoolOr(child: Expression) extends AggregateFunction with RuntimeReplaceableAggregate + with ImplicitCastInputTypes with UnaryLike[Expression] { + override lazy val replacement: Expression = Max(child) + override def nodeName: String = "bool_or" + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) override protected def withNewChildInternal(newChild: Expression): Expression = - copy(arg = newChild) + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 3ba90659748e5..f97293dc9b464 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -83,7 +83,7 @@ object AggregateExpression { } def containsAggregate(expr: Expression): Boolean = { - expr.find(isAggregate).isDefined + expr.exists(isAggregate) } def isAggregate(expr: Expression): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala new file mode 100644 index 0000000000000..8507069a7ac26 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
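With the rewrite above, bool_and and bool_or become RuntimeReplaceableAggregate wrappers around Min and Max, which is sound because BooleanType orders false before true. A short usage sketch (not part of the patch), assuming a SparkSession named `spark`:

// bool_and(col) is replaced by Min(col): true only when no false value is present.
// bool_or(col) is replaced by Max(col): true when at least one true value is present.
spark.sql("SELECT bool_and(col), bool_or(col) FROM VALUES (true), (false), (true) AS t(col)").show()
// -> false, true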
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions.{And, Expression, ExpressionDescription, If, ImplicitCastInputTypes, IsNotNull, Literal, RuntimeReplaceableAggregate} +import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.types.{AbstractDataType, NumericType} + +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns the number of non-null number pairs in a group. + """, + examples = """ + Examples: + > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x); + 4 + > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (2, 3), (2, 4) AS tab(y, x); + 3 + > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (null, 3), (2, 4) AS tab(y, x); + 2 + """, + group = "agg_funcs", + since = "3.3.0") +case class RegrCount(left: Expression, right: Expression) + extends AggregateFunction + with RuntimeReplaceableAggregate + with ImplicitCastInputTypes + with BinaryLike[Expression] { + override lazy val replacement: Expression = Count(Seq(left, right)) + override def nodeName: String = "regr_count" + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType) + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): RegrCount = + this.copy(left = newLeft, right = newRight) +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of the independent variable for non-null pairs in a group, where `y` is the dependent variable and `x` is the independent variable.", + examples = """ + Examples: + > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x); + 2.75 + > SELECT _FUNC_(y, x) FROM VALUES (1, null) AS tab(y, x); + NULL + > SELECT _FUNC_(y, x) FROM VALUES (null, 1) AS tab(y, x); + NULL + > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (2, 3), (2, 4) AS tab(y, x); + 3.0 + > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (null, 3), (2, 4) AS tab(y, x); + 3.0 + """, + group = "agg_funcs", + since = "3.3.0") +// scalastyle:on line.size.limit +case class RegrAvgX( + left: Expression, + right: Expression) + extends AggregateFunction + with RuntimeReplaceableAggregate + with ImplicitCastInputTypes + with BinaryLike[Expression] { + override lazy val replacement: Expression = + Average(If(And(IsNotNull(left), IsNotNull(right)), right, Literal.create(null, right.dataType))) + override def nodeName: String = "regr_avgx" + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType) + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): RegrAvgX = + this.copy(left = newLeft, right = newRight) +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of the 
dependent variable for non-null pairs in a group, where `y` is the dependent variable and `x` is the independent variable.", + examples = """ + Examples: + > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x); + 1.75 + > SELECT _FUNC_(y, x) FROM VALUES (1, null) AS tab(y, x); + NULL + > SELECT _FUNC_(y, x) FROM VALUES (null, 1) AS tab(y, x); + NULL + > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (2, 3), (2, 4) AS tab(y, x); + 1.6666666666666667 + > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (null, 3), (2, 4) AS tab(y, x); + 1.5 + """, + group = "agg_funcs", + since = "3.3.0") +// scalastyle:on line.size.limit +case class RegrAvgY( + left: Expression, + right: Expression) + extends AggregateFunction + with RuntimeReplaceableAggregate + with ImplicitCastInputTypes + with BinaryLike[Expression] { + override lazy val replacement: Expression = + Average(If(And(IsNotNull(left), IsNotNull(right)), left, Literal.create(null, left.dataType))) + override def nodeName: String = "regr_avgy" + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType) + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): RegrAvgY = + this.copy(left = newLeft, right = newRight) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 2a906a69606cc..88a38612fc4f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -204,6 +204,8 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled protected override def nullSafeEval(input: Any): Any = numeric.abs(input) + override def flatArguments: Iterator[Any] = Iterator(child) + override protected def withNewChildInternal(newChild: Expression): Abs = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index dbe9a810a493e..3651dc420fa21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -203,7 +203,7 @@ trait Block extends TreeNode[Block] with JavaCode { override def verboseString(maxFields: Int): String = toString override def simpleStringWithNodeId(): String = { - throw QueryExecutionErrors.simpleStringWithNodeIdUnsupportedError(nodeName) + throw new IllegalStateException(s"$nodeName does not implement simpleStringWithNodeId") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 65b6a05fbeb47..363c531b04272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import 
org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -89,8 +90,6 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression 4 > SELECT _FUNC_(map('a', 1, 'b', 2)); 2 - > SELECT _FUNC_(NULL); - -1 """, since = "1.5.0", group = "collection_funcs") @@ -134,6 +133,31 @@ object Size { def apply(child: Expression): Size = new Size(child) } + +/** + * Given an array, returns total number of elements in it. + */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the size of an array. The function returns null for null input.", + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', 'c', 'a')); + 4 + """, + since = "3.3.0", + group = "collection_funcs") +case class ArraySize(child: Expression) + extends RuntimeReplaceable with ImplicitCastInputTypes with UnaryLike[Expression] { + + override lazy val replacement: Expression = Size(child, legacySizeOfNull = false) + + override def prettyName: String = "array_size" + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + protected def withNewChildInternal(newChild: Expression): ArraySize = copy(child = newChild) +} + /** * Returns an unordered array containing the keys of the map. */ @@ -182,19 +206,41 @@ case class MapKeys(child: Expression) """, group = "map_funcs", since = "3.3.0") -case class MapContainsKey( - left: Expression, - right: Expression, - child: Expression) extends RuntimeReplaceable { - def this(left: Expression, right: Expression) = - this(left, right, ArrayContains(MapKeys(left), right)) +case class MapContainsKey(left: Expression, right: Expression) + extends RuntimeReplaceable with BinaryLike[Expression] with ImplicitCastInputTypes { - override def exprsReplaced: Seq[Expression] = Seq(left, right) + override lazy val replacement: Expression = ArrayContains(MapKeys(left), right) + + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (_, NullType) => Seq.empty + case (MapType(kt, vt, valueContainsNull), dt) => + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(kt, dt) match { + case Some(widerType) => Seq(MapType(widerType, vt, valueContainsNull), widerType) + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (_, NullType) => + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + case (MapType(kt, _, _), dt) if kt.sameType(dt) => + TypeUtils.checkForOrderingExpr(kt, s"function $prettyName") + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${MapType.simpleString} followed by a value with same key type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + } + } override def prettyName: String = "map_contains_key" - override protected def withNewChildInternal(newChild: Expression): MapContainsKey = - copy(child = newChild) + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = { + copy(newLeft, newRight) + } } @ExpressionDescription( @@ -2229,19 +2275,17 @@ case class ElementAt( """, since = "3.3.0", group = "map_funcs") -case class TryElementAt(left: Expression, right: Expression, child: 
Expression) - extends RuntimeReplaceable { +case class TryElementAt(left: Expression, right: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { def this(left: Expression, right: Expression) = this(left, right, ElementAt(left, right, failOnError = false)) - override def flatArguments: Iterator[Any] = Iterator(left, right) - - override def exprsReplaced: Seq[Expression] = Seq(left, right) - override def prettyName: String = "try_element_at" + override def parameters: Seq[Expression] = Seq(left, right) + override protected def withNewChildInternal(newChild: Expression): Expression = - this.copy(child = newChild) + this.copy(replacement = newChild) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala index 8feaf52ecb134..75d912633a0fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -30,6 +30,17 @@ trait TaggingExpression extends UnaryExpression { override def eval(input: InternalRow): Any = child.eval(input) } +case class KnownNullable(child: Expression) extends TaggingExpression { + override def nullable: Boolean = true + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx) + } + + override protected def withNewChildInternal(newChild: Expression): KnownNullable = + copy(child = newChild) +} + case class KnownNotNull(child: Expression) extends TaggingExpression { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 79bbc103c92d3..6e08ad346c853 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -91,7 +91,7 @@ case class CsvToStructs( assert(!rows.hasNext) result } else { - throw QueryExecutionErrors.rowFromCSVParserNotExpectedError + throw new IllegalStateException("Expected one row from CSV parser.") } } @@ -153,7 +153,7 @@ case class CsvToStructs( examples = """ Examples: > SELECT _FUNC_('1,abc'); - STRUCT<`_c0`: INT, `_c1`: STRING> + STRUCT<_c0: INT, _c1: STRING> """, since = "3.0.0", group = "csv_funcs") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index c679c1f5a5801..013f11ac29786 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -26,6 +26,7 @@ import org.apache.commons.text.StringEscapeUtils import org.apache.spark.SparkDateTimeException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -1112,25 +1113,15 @@ case class GetTimestamp( group = 
"datetime_funcs", since = "3.3.0") // scalastyle:on line.size.limit -case class ParseToTimestampNTZ( - left: Expression, - format: Option[Expression], - child: Expression) extends RuntimeReplaceable { - - def this(left: Expression, format: Expression) = { - this(left, Option(format), GetTimestamp(left, format, TimestampNTZType)) +object ParseToTimestampNTZExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + if (numArgs == 1 || numArgs == 2) { + ParseToTimestamp(expressions(0), expressions.drop(1).lastOption, TimestampNTZType) + } else { + throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(1, 2), funcName, numArgs) + } } - - def this(left: Expression) = this(left, None, Cast(left, TimestampNTZType)) - - override def flatArguments: Iterator[Any] = Iterator(left, format) - override def exprsReplaced: Seq[Expression] = left +: format.toSeq - - override def prettyName: String = "to_timestamp_ntz" - override def dataType: DataType = TimestampNTZType - - override protected def withNewChildInternal(newChild: Expression): ParseToTimestampNTZ = - copy(child = newChild) } /** @@ -1159,25 +1150,15 @@ case class ParseToTimestampNTZ( group = "datetime_funcs", since = "3.3.0") // scalastyle:on line.size.limit -case class ParseToTimestampLTZ( - left: Expression, - format: Option[Expression], - child: Expression) extends RuntimeReplaceable { - - def this(left: Expression, format: Expression) = { - this(left, Option(format), GetTimestamp(left, format, TimestampType)) +object ParseToTimestampLTZExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + if (numArgs == 1 || numArgs == 2) { + ParseToTimestamp(expressions(0), expressions.drop(1).lastOption, TimestampType) + } else { + throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(1, 2), funcName, numArgs) + } } - - def this(left: Expression) = this(left, None, Cast(left, TimestampType)) - - override def flatArguments: Iterator[Any] = Iterator(left, format) - override def exprsReplaced: Seq[Expression] = left +: format.toSeq - - override def prettyName: String = "to_timestamp_ltz" - override def dataType: DataType = TimestampType - - override protected def withNewChildInternal(newChild: Expression): ParseToTimestampLTZ = - copy(child = newChild) } abstract class ToTimestamp @@ -1606,12 +1587,19 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S case class DatetimeSub( start: Expression, interval: Expression, - child: Expression) extends RuntimeReplaceable { - override def exprsReplaced: Seq[Expression] = Seq(start, interval) + replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { + + override def parameters: Seq[Expression] = Seq(start, interval) + + override def makeSQLString(childrenSQL: Seq[String]): String = { + childrenSQL.mkString(" - ") + } + override def toString: String = s"$start - $interval" - override def mkString(childrenString: Seq[String]): String = childrenString.mkString(" - ") - override protected def withNewChildInternal(newChild: Expression): DatetimeSub = - copy(child = newChild) + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(replacement = newChild) + } } /** @@ -1991,25 +1979,48 @@ case class MonthsBetween( group = "datetime_funcs", since = "1.5.0") // scalastyle:on 
line.size.limit -case class ParseToDate(left: Expression, format: Option[Expression], child: Expression) - extends RuntimeReplaceable { +case class ParseToDate( + left: Expression, + format: Option[Expression], + timeZoneId: Option[String] = None) + extends RuntimeReplaceable with ImplicitCastInputTypes with TimeZoneAwareExpression { + + override lazy val replacement: Expression = format.map { f => + Cast(GetTimestamp(left, f, TimestampType, timeZoneId), DateType, timeZoneId) + }.getOrElse(Cast(left, DateType, timeZoneId)) // backwards compatibility def this(left: Expression, format: Expression) = { - this(left, Option(format), Cast(GetTimestamp(left, format, TimestampType), DateType)) + this(left, Option(format)) } def this(left: Expression) = { - // backwards compatibility - this(left, None, Cast(left, DateType)) + this(left, None) } - override def exprsReplaced: Seq[Expression] = left +: format.toSeq - override def flatArguments: Iterator[Any] = Iterator(left, format) - override def prettyName: String = "to_date" - override protected def withNewChildInternal(newChild: Expression): ParseToDate = - copy(child = newChild) + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Some(timeZoneId)) + + override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE) + + override def children: Seq[Expression] = left +: format.toSeq + + override def inputTypes: Seq[AbstractDataType] = { + // Note: ideally this function should only take string input, but we allow more types here to + // be backward compatible. + TypeCollection(StringType, DateType, TimestampType, TimestampNTZType) +: + format.map(_ => StringType).toSeq + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + if (format.isDefined) { + copy(left = newChildren.head, format = Some(newChildren.last)) + } else { + copy(left = newChildren.head) + } + } } /** @@ -2043,23 +2054,44 @@ case class ParseToTimestamp( left: Expression, format: Option[Expression], override val dataType: DataType, - child: Expression) extends RuntimeReplaceable { + timeZoneId: Option[String] = None) + extends RuntimeReplaceable with ImplicitCastInputTypes with TimeZoneAwareExpression { + + override lazy val replacement: Expression = format.map { f => + GetTimestamp(left, f, dataType, timeZoneId) + }.getOrElse(Cast(left, dataType, timeZoneId)) def this(left: Expression, format: Expression) = { - this(left, Option(format), SQLConf.get.timestampType, - GetTimestamp(left, format, SQLConf.get.timestampType)) + this(left, Option(format), SQLConf.get.timestampType) } def this(left: Expression) = - this(left, None, SQLConf.get.timestampType, Cast(left, SQLConf.get.timestampType)) + this(left, None, SQLConf.get.timestampType) - override def flatArguments: Iterator[Any] = Iterator(left, format) - override def exprsReplaced: Seq[Expression] = left +: format.toSeq + override def nodeName: String = "to_timestamp" - override def prettyName: String = "to_timestamp" + override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE) - override protected def withNewChildInternal(newChild: Expression): ParseToTimestamp = - copy(child = newChild) + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Some(timeZoneId)) + + override def children: Seq[Expression] = left +: format.toSeq + + override def inputTypes: Seq[AbstractDataType] = { + // Note: ideally this function should only take string input, but we allow more 
types here to + // be backward compatible. + TypeCollection(StringType, DateType, TimestampType, TimestampNTZType) +: + format.map(_ => StringType).toSeq + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + if (format.isDefined) { + copy(left = newChildren.head, format = Some(newChildren.last)) + } else { + copy(left = newChildren.head) + } + } } trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { @@ -2312,8 +2344,9 @@ case class DateDiff(endDate: Expression, startDate: Expression) copy(endDate = newLeft, startDate = newRight) } +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(year, month, day) - Create date from year, month and day fields.", + usage = "_FUNC_(year, month, day) - Create date from year, month and day fields. If the configuration `spark.sql.ansi.enabled` is false, the function returns NULL on invalid inputs. Otherwise, it will throw an error instead.", arguments = """ Arguments: * year - the year to represent, from 1 to 9999 @@ -2324,15 +2357,12 @@ case class DateDiff(endDate: Expression, startDate: Expression) Examples: > SELECT _FUNC_(2013, 7, 15); 2013-07-15 - > SELECT _FUNC_(2019, 13, 1); - NULL > SELECT _FUNC_(2019, 7, NULL); NULL - > SELECT _FUNC_(2019, 2, 30); - NULL """, group = "datetime_funcs", since = "3.0.0") +// scalastyle:on line.size.limit case class MakeDate( year: Expression, month: Expression, @@ -2386,7 +2416,7 @@ case class MakeDate( // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(year, month, day, hour, min, sec) - Create local date-time from year, month, day, hour, min, sec fields. ", + usage = "_FUNC_(year, month, day, hour, min, sec) - Create local date-time from year, month, day, hour, min, sec fields. If the configuration `spark.sql.ansi.enabled` is false, the function returns NULL on invalid inputs. 
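The reworded usage strings here (and the removal of the NULL examples from make_date and the timestamp constructors) reflect that these functions now honor spark.sql.ansi.enabled. A brief sketch (not part of the patch), assuming a SparkSession named `spark`:

spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT make_date(2019, 13, 1)").show()   // NULL: month 13 is invalid
spark.conf.set("spark.sql.ansi.enabled", "true")
// The same query now raises an error at runtime instead of returning NULL:
// spark.sql("SELECT make_date(2019, 13, 1)").show()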
Otherwise, it will throw an error instead.", arguments = """ Arguments: * year - the year to represent, from 1 to 9999 @@ -2410,37 +2440,27 @@ case class MakeDate( group = "datetime_funcs", since = "3.3.0") // scalastyle:on line.size.limit -case class MakeTimestampNTZ( - year: Expression, - month: Expression, - day: Expression, - hour: Expression, - min: Expression, - sec: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled, - child: Expression) extends RuntimeReplaceable { - def this( - year: Expression, - month: Expression, - day: Expression, - hour: Expression, - min: Expression, - sec: Expression) = { - this(year, month, day, hour, min, sec, failOnError = SQLConf.get.ansiEnabled, - MakeTimestamp(year, month, day, hour, min, sec, dataType = TimestampNTZType)) +object MakeTimestampNTZExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + if (numArgs == 6) { + MakeTimestamp( + expressions(0), + expressions(1), + expressions(2), + expressions(3), + expressions(4), + expressions(5), + dataType = TimestampNTZType) + } else { + throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(6), funcName, numArgs) + } } - - override def prettyName: String = "make_timestamp_ntz" - - override def exprsReplaced: Seq[Expression] = Seq(year, month, day, hour, min, sec) - - override protected def withNewChildInternal(newChild: Expression): Expression = - copy(child = newChild) } // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(year, month, day, hour, min, sec[, timezone]) - Create the current timestamp with local time zone from year, month, day, hour, min, sec and timezone fields. ", + usage = "_FUNC_(year, month, day, hour, min, sec[, timezone]) - Create the current timestamp with local time zone from year, month, day, hour, min, sec and timezone fields. If the configuration `spark.sql.ansi.enabled` is false, the function returns NULL on invalid inputs. 
Otherwise, it will throw an error instead.", arguments = """ Arguments: * year - the year to represent, from 1 to 9999 @@ -2461,59 +2481,34 @@ case class MakeTimestampNTZ( 2014-12-27 21:30:45.887 > SELECT _FUNC_(2019, 6, 30, 23, 59, 60); 2019-07-01 00:00:00 - > SELECT _FUNC_(2019, 13, 1, 10, 11, 12, 'PST'); - NULL > SELECT _FUNC_(null, 7, 22, 15, 30, 0); NULL """, group = "datetime_funcs", since = "3.3.0") // scalastyle:on line.size.limit -case class MakeTimestampLTZ( - year: Expression, - month: Expression, - day: Expression, - hour: Expression, - min: Expression, - sec: Expression, - timezone: Option[Expression], - failOnError: Boolean = SQLConf.get.ansiEnabled, - child: Expression) extends RuntimeReplaceable { - def this( - year: Expression, - month: Expression, - day: Expression, - hour: Expression, - min: Expression, - sec: Expression) = { - this(year, month, day, hour, min, sec, None, failOnError = SQLConf.get.ansiEnabled, - MakeTimestamp(year, month, day, hour, min, sec, dataType = TimestampType)) - } - - def this( - year: Expression, - month: Expression, - day: Expression, - hour: Expression, - min: Expression, - sec: Expression, - timezone: Expression) = { - this(year, month, day, hour, min, sec, Some(timezone), failOnError = SQLConf.get.ansiEnabled, - MakeTimestamp(year, month, day, hour, min, sec, Some(timezone), dataType = TimestampType)) +object MakeTimestampLTZExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + if (numArgs == 6 || numArgs == 7) { + MakeTimestamp( + expressions(0), + expressions(1), + expressions(2), + expressions(3), + expressions(4), + expressions(5), + expressions.drop(6).lastOption, + dataType = TimestampType) + } else { + throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(6), funcName, numArgs) + } } - - override def prettyName: String = "make_timestamp_ltz" - - override def exprsReplaced: Seq[Expression] = Seq(year, month, day, hour, min, sec) - - override protected def withNewChildInternal(newChild: Expression): Expression = - copy(child = newChild) } // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(year, month, day, hour, min, sec[, timezone]) - Create timestamp from year, month, day, hour, min, sec and timezone fields. " + - "The result data type is consistent with the value of configuration `spark.sql.timestampType`", + usage = "_FUNC_(year, month, day, hour, min, sec[, timezone]) - Create timestamp from year, month, day, hour, min, sec and timezone fields. The result data type is consistent with the value of configuration `spark.sql.timestampType`. If the configuration `spark.sql.ansi.enabled` is false, the function returns NULL on invalid inputs. 
Otherwise, it will throw an error instead.", arguments = """ Arguments: * year - the year to represent, from 1 to 9999 @@ -2537,8 +2532,6 @@ case class MakeTimestampLTZ( 2019-07-01 00:00:00 > SELECT _FUNC_(2019, 6, 30, 23, 59, 1); 2019-06-30 23:59:01 - > SELECT _FUNC_(2019, 13, 1, 10, 11, 12, 'PST'); - NULL > SELECT _FUNC_(null, 7, 22, 15, 30, 0); NULL """, @@ -2699,7 +2692,7 @@ case class MakeTimestamp( }) } - override def prettyName: String = "make_timestamp" + override def nodeName: String = "make_timestamp" // override def children: Seq[Expression] = Seq(year, month, day, hour, min, sec) ++ timezone override protected def withNewChildrenInternal( @@ -2720,8 +2713,7 @@ object DatePart { def parseExtractField( extractField: String, - source: Expression, - errorHandleFunc: => Nothing): Expression = extractField.toUpperCase(Locale.ROOT) match { + source: Expression): Expression = extractField.toUpperCase(Locale.ROOT) match { case "YEAR" | "Y" | "YEARS" | "YR" | "YRS" => Year(source) case "YEAROFWEEK" => YearOfWeek(source) case "QUARTER" | "QTR" => Quarter(source) @@ -2734,29 +2726,8 @@ object DatePart { case "HOUR" | "H" | "HOURS" | "HR" | "HRS" => Hour(source) case "MINUTE" | "M" | "MIN" | "MINS" | "MINUTES" => Minute(source) case "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" => SecondWithFraction(source) - case _ => errorHandleFunc - } - - def toEquivalentExpr(field: Expression, source: Expression): Expression = { - if (!field.foldable) { - throw QueryCompilationErrors.unfoldableFieldUnsupportedError - } - val fieldEval = field.eval() - if (fieldEval == null) { - Literal(null, DoubleType) - } else { - val fieldStr = fieldEval.asInstanceOf[UTF8String].toString - - def analysisException = - throw QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(fieldStr, source) - - source.dataType match { - case _: AnsiIntervalType | CalendarIntervalType => - ExtractIntervalPart.parseExtractField(fieldStr, source, analysisException) - case _ => - DatePart.parseExtractField(fieldStr, source, analysisException) - } - } + case _ => + throw QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(extractField, source) } } @@ -2793,20 +2764,17 @@ object DatePart { group = "datetime_funcs", since = "3.0.0") // scalastyle:on line.size.limit -case class DatePart(field: Expression, source: Expression, child: Expression) - extends RuntimeReplaceable { - - def this(field: Expression, source: Expression) = { - this(field, source, DatePart.toEquivalentExpr(field, source)) +object DatePartExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + if (numArgs == 2) { + val field = expressions(0) + val source = expressions(1) + Extract(field, source, Extract.createExpr(funcName, field, source)) + } else { + throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(2), funcName, numArgs) + } } - - override def flatArguments: Iterator[Any] = Iterator(field, source) - override def exprsReplaced: Seq[Expression] = Seq(field, source) - - override def prettyName: String = "date_part" - - override protected def withNewChildInternal(newChild: Expression): DatePart = - copy(child = newChild) } // scalastyle:off line.size.limit @@ -2862,23 +2830,45 @@ case class DatePart(field: Expression, source: Expression, child: Expression) group = "datetime_funcs", since = "3.0.0") // scalastyle:on line.size.limit -case class Extract(field: Expression, source: Expression, child: Expression) - extends 
RuntimeReplaceable { +case class Extract(field: Expression, source: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { - def this(field: Expression, source: Expression) = { - this(field, source, DatePart.toEquivalentExpr(field, source)) - } + def this(field: Expression, source: Expression) = + this(field, source, Extract.createExpr("extract", field, source)) - override def flatArguments: Iterator[Any] = Iterator(field, source) + override def parameters: Seq[Expression] = Seq(field, source) - override def exprsReplaced: Seq[Expression] = Seq(field, source) + override def makeSQLString(childrenSQL: Seq[String]): String = { + getTagValue(FunctionRegistry.FUNC_ALIAS) match { + case Some("date_part") => s"$prettyName(${childrenSQL.mkString(", ")})" + case _ => s"$prettyName(${childrenSQL.mkString(" FROM ")})" + } + } - override def mkString(childrenString: Seq[String]): String = { - prettyName + childrenString.mkString("(", " FROM ", ")") + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(replacement = newChild) } +} - override protected def withNewChildInternal(newChild: Expression): Extract = - copy(child = newChild) +object Extract { + def createExpr(funcName: String, field: Expression, source: Expression): Expression = { + // both string and null literals are allowed. + if ((field.dataType == StringType || field.dataType == NullType) && field.foldable) { + val fieldStr = field.eval().asInstanceOf[UTF8String] + if (fieldStr == null) { + Literal(null, DoubleType) + } else { + source.dataType match { + case _: AnsiIntervalType | CalendarIntervalType => + ExtractIntervalPart.parseExtractField(fieldStr.toString, source) + case _ => + DatePart.parseExtractField(fieldStr.toString, source) + } + } + } else { + throw QueryCompilationErrors.requireLiteralParameter(funcName, "field", "string") + } + } } /** @@ -2913,25 +2903,25 @@ case class SubtractTimestamps( @transient private lazy val zoneIdInEval: ZoneId = zoneIdForType(left.dataType) @transient - private lazy val evalFunc: (Long, Long) => Any = legacyInterval match { - case false => (leftMicros, rightMicros) => - subtractTimestamps(leftMicros, rightMicros, zoneIdInEval) - case true => (leftMicros, rightMicros) => + private lazy val evalFunc: (Long, Long) => Any = if (legacyInterval) { + (leftMicros, rightMicros) => new CalendarInterval(0, 0, leftMicros - rightMicros) + } else { + (leftMicros, rightMicros) => + subtractTimestamps(leftMicros, rightMicros, zoneIdInEval) } override def nullSafeEval(leftMicros: Any, rightMicros: Any): Any = { evalFunc(leftMicros.asInstanceOf[Long], rightMicros.asInstanceOf[Long]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = legacyInterval match { - case false => - val zid = ctx.addReferenceObj("zoneId", zoneIdInEval, classOf[ZoneId].getName) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (l, r) => s"""$dtu.subtractTimestamps($l, $r, $zid)""") - case true => - defineCodeGen(ctx, ev, (end, start) => - s"new org.apache.spark.unsafe.types.CalendarInterval(0, 0, $end - $start)") + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = if (legacyInterval) { + defineCodeGen(ctx, ev, (end, start) => + s"new org.apache.spark.unsafe.types.CalendarInterval(0, 0, $end - $start)") + } else { + val zid = ctx.addReferenceObj("zoneId", zoneIdInEval, classOf[ZoneId].getName) + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (l, r) => 
s"""$dtu.subtractTimestamps($l, $r, $zid)""") } override def toString: String = s"($left - $right)" @@ -2971,26 +2961,26 @@ case class SubtractDates( } @transient - private lazy val evalFunc: (Int, Int) => Any = legacyInterval match { - case false => (leftDays: Int, rightDays: Int) => + private lazy val evalFunc: (Int, Int) => Any = if (legacyInterval) { + (leftDays: Int, rightDays: Int) => subtractDates(leftDays, rightDays) + } else { + (leftDays: Int, rightDays: Int) => Math.multiplyExact(Math.subtractExact(leftDays, rightDays), MICROS_PER_DAY) - case true => (leftDays: Int, rightDays: Int) => subtractDates(leftDays, rightDays) } override def nullSafeEval(leftDays: Any, rightDays: Any): Any = { evalFunc(leftDays.asInstanceOf[Int], rightDays.asInstanceOf[Int]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = legacyInterval match { - case false => - val m = classOf[Math].getName - defineCodeGen(ctx, ev, (leftDays, rightDays) => - s"$m.multiplyExact($m.subtractExact($leftDays, $rightDays), ${MICROS_PER_DAY}L)") - case true => - defineCodeGen(ctx, ev, (leftDays, rightDays) => { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - s"$dtu.subtractDates($leftDays, $rightDays)" - }) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = if (legacyInterval) { + defineCodeGen(ctx, ev, (leftDays, rightDays) => { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + s"$dtu.subtractDates($leftDays, $rightDays)" + }) + } else { + val m = classOf[Math].getName + defineCodeGen(ctx, ev, (leftDays, rightDays) => + s"$m.multiplyExact($m.subtractExact($leftDays, $rightDays), ${MICROS_PER_DAY}L)") } override def toString: String = s"($left - $right)" @@ -3057,3 +3047,163 @@ case class ConvertTimezone( copy(sourceTz = newFirst, targetTz = newSecond, sourceTs = newThird) } } + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(unit, quantity, timestamp) - Adds the specified number of units to the given timestamp.", + arguments = """ + Arguments: + * unit - this indicates the units of datetime that you want to add. + Supported string values of `unit` are (case insensitive): + - "YEAR" + - "QUARTER" - 3 months + - "MONTH" + - "WEEK" - 7 days + - "DAY", "DAYOFYEAR" + - "HOUR" + - "MINUTE" + - "SECOND" + - "MILLISECOND" + - "MICROSECOND" + * quantity - this is the number of units of time that you want to add. + * timestamp - this is a timestamp (w/ or w/o timezone) to which you want to add. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(HOUR, 8, timestamp_ntz'2022-02-11 20:30:00'); + 2022-02-12 04:30:00 + > SELECT _FUNC_(MONTH, 1, timestamp_ltz'2022-01-31 00:00:00'); + 2022-02-28 00:00:00 + > SELECT _FUNC_(SECOND, -10, date'2022-01-01'); + 2021-12-31 23:59:50 + > SELECT _FUNC_(YEAR, 10, timestamp'2000-01-01 01:02:03.123456'); + 2010-01-01 01:02:03.123456 + """, + group = "datetime_funcs", + since = "3.3.0") +// scalastyle:on line.size.limit +case class TimestampAdd( + unit: String, + quantity: Expression, + timestamp: Expression, + timeZoneId: Option[String] = None) + extends BinaryExpression + with ImplicitCastInputTypes + with NullIntolerant + with TimeZoneAwareExpression { + + def this(unit: String, quantity: Expression, timestamp: Expression) = + this(unit, quantity, timestamp, None) + + override def left: Expression = quantity + override def right: Expression = timestamp + + override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, AnyTimestampType) + override def dataType: DataType = timestamp.dataType + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + @transient private lazy val zoneIdInEval: ZoneId = zoneIdForType(timestamp.dataType) + + override def nullSafeEval(q: Any, micros: Any): Any = { + DateTimeUtils.timestampAdd(unit, q.asInstanceOf[Int], micros.asInstanceOf[Long], zoneIdInEval) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val zid = ctx.addReferenceObj("zoneId", zoneIdInEval, classOf[ZoneId].getName) + defineCodeGen(ctx, ev, (q, micros) => + s"""$dtu.timestampAdd("$unit", $q, $micros, $zid)""") + } + + override def prettyName: String = "timestampadd" + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): TimestampAdd = { + copy(quantity = newLeft, timestamp = newRight) + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(unit, startTimestamp, endTimestamp) - Gets the difference between the timestamps `endTimestamp` and `startTimestamp` in the specified units by truncating the fraction part.", + arguments = """ + Arguments: + * unit - this indicates the units of the difference between the given timestamps. + Supported string values of `unit` are (case insensitive): + - "YEAR" + - "QUARTER" - 3 months + - "MONTH" + - "WEEK" - 7 days + - "DAY" + - "HOUR" + - "MINUTE" + - "SECOND" + - "MILLISECOND" + - "MICROSECOND" + * startTimestamp - A timestamp which the expression subtracts from `endTimestamp`. + * endTimestamp - A timestamp from which the expression subtracts `startTimestamp`. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(HOUR, timestamp_ntz'2022-02-11 20:30:00', timestamp_ntz'2022-02-12 04:30:00'); + 8 + > SELECT _FUNC_(MONTH, timestamp_ltz'2022-01-01 00:00:00', timestamp_ltz'2022-02-28 00:00:00'); + 1 + > SELECT _FUNC_(SECOND, date'2022-01-01', timestamp'2021-12-31 23:59:50'); + -10 + > SELECT _FUNC_(YEAR, timestamp'2000-01-01 01:02:03.123456', timestamp'2010-01-01 01:02:03.123456'); + 10 + """, + group = "datetime_funcs", + since = "3.3.0") +// scalastyle:on line.size.limit +case class TimestampDiff( + unit: String, + startTimestamp: Expression, + endTimestamp: Expression, + timeZoneId: Option[String] = None) + extends BinaryExpression + with ImplicitCastInputTypes + with NullIntolerant + with TimeZoneAwareExpression { + + def this(unit: String, quantity: Expression, timestamp: Expression) = + this(unit, quantity, timestamp, None) + + override def left: Expression = startTimestamp + override def right: Expression = endTimestamp + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + override def dataType: DataType = LongType + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + @transient private lazy val zoneIdInEval: ZoneId = zoneIdForType(endTimestamp.dataType) + + override def nullSafeEval(startMicros: Any, endMicros: Any): Any = { + DateTimeUtils.timestampDiff( + unit, + startMicros.asInstanceOf[Long], + endMicros.asInstanceOf[Long], + zoneIdInEval) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val zid = ctx.addReferenceObj("zoneId", zoneIdInEval, classOf[ZoneId].getName) + defineCodeGen(ctx, ev, (s, e) => + s"""$dtu.timestampDiff("$unit", $s, $e, $zid)""") + } + + override def prettyName: String = "timestampdiff" + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): TimestampDiff = { + copy(startTimestamp = newLeft, endTimestamp = newRight) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 48ccc2e82b0ad..8116537d7b06d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -149,7 +149,7 @@ case class CheckOverflow( }) } - override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)" + override def toString: String = s"CheckOverflow($child, $dataType)" override def sql: String = child.sql diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 0c45f495097aa..f9b2ade9a6029 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -388,18 +388,13 @@ case class ArraySort( checkArgumentDataTypes() match { case TypeCheckResult.TypeCheckSuccess => argument.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + case ArrayType(_, _) => if (function.dataType == IntegerType) { TypeCheckResult.TypeCheckSuccess } else { 
TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " + "IntegerType") } - case ArrayType(dt, _) => - val dtSimple = dt.catalogString - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type $dtSimple which is not " + - "orderable") case _ => TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") } @@ -452,6 +447,8 @@ object ArraySort { If(LessThan(left, right), litm1, If(GreaterThan(left, right), lit1, lit0))))) } + // Default Comparator only works for orderable types. + // This is validated by the underlying LessTan and GreaterThan val defaultComparator: LambdaFunction = { val left = UnresolvedNamedLambdaVariable(Seq("left")) val right = UnresolvedNamedLambdaVariable(Seq("right")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 5568d7c4a6cba..c461b8f51eedc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.util.DateTimeConstants.MONTHS_PER_YEAR import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND} @@ -122,10 +122,7 @@ case class ExtractANSIIntervalSeconds(child: Expression) object ExtractIntervalPart { - def parseExtractField( - extractField: String, - source: Expression, - errorHandleFunc: => Nothing): Expression = { + def parseExtractField(extractField: String, source: Expression): Expression = { (extractField.toUpperCase(Locale.ROOT), source.dataType) match { case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", YearMonthIntervalType(start, end)) if isUnitInIntervalRange(YEAR, start, end) => @@ -157,7 +154,8 @@ object ExtractIntervalPart { ExtractANSIIntervalSeconds(source) case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType) => ExtractIntervalSeconds(source) - case _ => errorHandleFunc + case _ => + throw QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(extractField, source) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 5b058626e2227..9f00b7c8b7409 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -766,9 +766,9 @@ case class StructsToJson( examples = """ Examples: > SELECT _FUNC_('[{"col":0}]'); - ARRAY> + ARRAY> > SELECT _FUNC_('[{"col":01}]', map('allowNumericLeadingZeros', 'true')); - ARRAY> + ARRAY> """, group = "json_funcs", since = "2.4.0") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 
cc207e51f85c4..af10a18e4d16d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -158,6 +158,7 @@ object Literal { Literal(CatalystTypeConverters.createToCatalystConverter(dataType)(v), dataType) case _: DayTimeIntervalType if v.isInstanceOf[Duration] => Literal(CatalystTypeConverters.createToCatalystConverter(dataType)(v), dataType) + case _: ObjectType => Literal(v, dataType) case _ => Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 03f9da66cab48..f64b6ea078a46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -21,11 +21,12 @@ import java.{lang => jl} import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -238,17 +239,6 @@ case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT" override protected def withNewChildInternal(newChild: Expression): Cbrt = copy(child = newChild) } -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the smallest integer not smaller than `expr`.", - examples = """ - Examples: - > SELECT _FUNC_(-0.1); - 0 - > SELECT _FUNC_(5); - 5 - """, - since = "1.4.0", - group = "math_funcs") case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") { override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt @@ -279,6 +269,76 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" override protected def withNewChildInternal(newChild: Expression): Ceil = copy(child = newChild) } +trait CeilFloorExpressionBuilderBase extends ExpressionBuilder { + protected def buildWithOneParam(param: Expression): Expression + protected def buildWithTwoParams(param1: Expression, param2: Expression): Expression + + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + if (numArgs == 1) { + buildWithOneParam(expressions.head) + } else if (numArgs == 2) { + val scale = expressions(1) + if (!(scale.foldable && scale.dataType == IntegerType)) { + throw QueryCompilationErrors.requireLiteralParameter(funcName, "scale", "int") + } + if (scale.eval() == null) { + throw QueryCompilationErrors.requireLiteralParameter(funcName, "scale", "int") + } + buildWithTwoParams(expressions(0), scale) + } else { + throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(2), funcName, numArgs) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = 
"_FUNC_(expr[, scale]) - Returns the smallest number after rounding up that is not smaller than `expr`. An optional `scale` parameter can be specified to control the rounding behavior.", + examples = """ + Examples: + > SELECT _FUNC_(-0.1); + 0 + > SELECT _FUNC_(5); + 5 + > SELECT _FUNC_(3.1411, 3); + 3.142 + > SELECT _FUNC_(3.1411, -3); + 1000 + """, + since = "3.3.0", + group = "math_funcs") +// scalastyle:on line.size.limit +object CeilExpressionBuilder extends CeilFloorExpressionBuilderBase { + override protected def buildWithOneParam(param: Expression): Expression = Ceil(param) + + override protected def buildWithTwoParams(param1: Expression, param2: Expression): Expression = + RoundCeil(param1, param2) +} + +case class RoundCeil(child: Expression, scale: Expression) + extends RoundBase(child, scale, BigDecimal.RoundingMode.CEILING, "ROUND_CEILING") + with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, IntegerType) + + override lazy val dataType: DataType = child.dataType match { + case DecimalType.Fixed(p, s) => + if (_scale < 0) { + DecimalType(math.max(p, 1 - _scale), 0) + } else { + DecimalType(p, math.min(s, _scale)) + } + case t => t + } + + override def nodeName: String = "ceil" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): RoundCeil = + copy(child = newLeft, scale = newRight) +} + @ExpressionDescription( usage = """ _FUNC_(expr) - Returns the cosine of `expr`, as if computed by @@ -448,17 +508,6 @@ case class Expm1(child: Expression) extends UnaryMathExpression(StrictMath.expm1 override protected def withNewChildInternal(newChild: Expression): Expm1 = copy(child = newChild) } -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the largest integer not greater than `expr`.", - examples = """ - Examples: - > SELECT _FUNC_(-0.1); - -1 - > SELECT _FUNC_(5); - 5 - """, - since = "1.4.0", - group = "math_funcs") case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") { override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt @@ -484,9 +533,56 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case LongType => defineCodeGen(ctx, ev, c => s"$c") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } + } + override protected def withNewChildInternal(newChild: Expression): Floor = + copy(child = newChild) +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = " _FUNC_(expr[, scale]) - Returns the largest number after rounding down that is not greater than `expr`. 
An optional `scale` parameter can be specified to control the rounding behavior.", + examples = """ + Examples: + > SELECT _FUNC_(-0.1); + -1 + > SELECT _FUNC_(5); + 5 + > SELECT _FUNC_(3.1411, 3); + 3.141 + > SELECT _FUNC_(3.1411, -3); + 0 + """, + since = "3.3.0", + group = "math_funcs") +// scalastyle:on line.size.limit +object FloorExpressionBuilder extends CeilFloorExpressionBuilderBase { + override protected def buildWithOneParam(param: Expression): Expression = Floor(param) + + override protected def buildWithTwoParams(param1: Expression, param2: Expression): Expression = + RoundFloor(param1, param2) +} + +case class RoundFloor(child: Expression, scale: Expression) + extends RoundBase(child, scale, BigDecimal.RoundingMode.FLOOR, "ROUND_FLOOR") + with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, IntegerType) + + override lazy val dataType: DataType = child.dataType match { + case DecimalType.Fixed(p, s) => + if (_scale < 0) { + DecimalType(math.max(p, 1 - _scale), 0) + } else { + DecimalType(p, math.min(s, _scale)) + } + case t => t } - override protected def withNewChildInternal(newChild: Expression): Floor = copy(child = newChild) + override def nodeName: String = "floor" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): RoundFloor = + copy(child = newLeft, scale = newRight) } object Factorial { @@ -1375,7 +1471,7 @@ abstract class RoundBase(child: Expression, scale: Expression, // avoid unnecessary `child` evaluation in both codegen and non-codegen eval // by checking if scaleV == null as well. private lazy val scaleV: Any = scale.eval(EmptyRow) - private lazy val _scale: Int = scaleV.asInstanceOf[Int] + protected lazy val _scale: Int = scaleV.asInstanceOf[Int] override def eval(input: InternalRow): Any = { if (scaleV == null) { // if scale is null, no need to eval its child at all @@ -1393,10 +1489,14 @@ abstract class RoundBase(child: Expression, scale: Expression, // not overriding since _scale is a constant int at runtime def nullSafeEval(input1: Any): Any = { dataType match { - case DecimalType.Fixed(_, s) => + case DecimalType.Fixed(p, s) => val decimal = input1.asInstanceOf[Decimal] - // Overflow cannot happen, so no need to control nullOnOverflow - decimal.toPrecision(decimal.precision, s, mode) + if (_scale >= 0) { + // Overflow cannot happen, so no need to control nullOnOverflow + decimal.toPrecision(decimal.precision, s, mode) + } else { + Decimal(decimal.toBigDecimal.setScale(_scale, mode), p, s) + } case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => @@ -1426,12 +1526,18 @@ abstract class RoundBase(child: Expression, scale: Expression, val ce = child.genCode(ctx) val evaluationCode = dataType match { - case DecimalType.Fixed(_, s) => - s""" - |${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, - | Decimal.$modeStr(), true); - |${ev.isNull} = ${ev.value} == null; - """.stripMargin + case DecimalType.Fixed(p, s) => + if (_scale >= 0) { + s""" + ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, + Decimal.$modeStr(), true); + ${ev.isNull} = ${ev.value} == null;""" + } else { + s""" + ${ev.value} = new Decimal().set(${ce.value}.toBigDecimal() + .setScale(${_scale}, Decimal.$modeStr()), $p, $s); + ${ev.isNull} = ${ev.value} == null;""" + } case ByteType => if (_scale < 0) { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 941ccb7088393..eb21bd555db7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -126,8 +126,8 @@ object RaiseError { """, since = "2.0.0", group = "misc_funcs") -case class AssertTrue(left: Expression, right: Expression, child: Expression) - extends RuntimeReplaceable { +case class AssertTrue(left: Expression, right: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { override def prettyName: String = "assert_true" @@ -139,11 +139,10 @@ case class AssertTrue(left: Expression, right: Expression, child: Expression) this(left, Literal(s"'${left.simpleString(SQLConf.get.maxToStringFields)}' is not true!")) } - override def flatArguments: Iterator[Any] = Iterator(left, right) - override def exprsReplaced: Seq[Expression] = Seq(left, right) + override def parameters: Seq[Expression] = Seq(left, right) override protected def withNewChildInternal(newChild: Expression): AssertTrue = - copy(child = newChild) + copy(replacement = newChild) } object AssertTrue { @@ -341,31 +340,31 @@ case class AesEncrypt( input: Expression, key: Expression, mode: Expression, - padding: Expression, - child: Expression) - extends RuntimeReplaceable { - - def this(input: Expression, key: Expression, mode: Expression, padding: Expression) = { - this( - input, - key, - mode, - padding, - StaticInvoke( - classOf[ExpressionImplUtils], - BinaryType, - "aesEncrypt", - Seq(input, key, mode, padding), - Seq(BinaryType, BinaryType, StringType, StringType))) - } + padding: Expression) + extends RuntimeReplaceable with ImplicitCastInputTypes { + + override lazy val replacement: Expression = StaticInvoke( + classOf[ExpressionImplUtils], + BinaryType, + "aesEncrypt", + Seq(input, key, mode, padding), + inputTypes) + def this(input: Expression, key: Expression, mode: Expression) = this(input, key, mode, Literal("DEFAULT")) def this(input: Expression, key: Expression) = this(input, key, Literal("GCM")) - def exprsReplaced: Seq[Expression] = Seq(input, key, mode, padding) - protected def withNewChildInternal(newChild: Expression): AesEncrypt = - copy(child = newChild) + override def prettyName: String = "aes_encrypt" + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, BinaryType, StringType, StringType) + + override def children: Seq[Expression] = Seq(input, key, mode, padding) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + copy(newChildren(0), newChildren(1), newChildren(2), newChildren(3)) + } } /** @@ -405,30 +404,32 @@ case class AesDecrypt( input: Expression, key: Expression, mode: Expression, - padding: Expression, - child: Expression) - extends RuntimeReplaceable { - - def this(input: Expression, key: Expression, mode: Expression, padding: Expression) = { - this( - input, - key, - mode, - padding, - StaticInvoke( - classOf[ExpressionImplUtils], - BinaryType, - "aesDecrypt", - Seq(input, key, mode, padding), - Seq(BinaryType, BinaryType, StringType, StringType))) - } + padding: Expression) + extends RuntimeReplaceable with ImplicitCastInputTypes { + + override lazy val replacement: Expression = StaticInvoke( + classOf[ExpressionImplUtils], + BinaryType, + "aesDecrypt", + Seq(input, key, mode, padding), + inputTypes) + def this(input: Expression, key: Expression, mode: Expression) = this(input, 
key, mode, Literal("DEFAULT")) def this(input: Expression, key: Expression) = this(input, key, Literal("GCM")) - def exprsReplaced: Seq[Expression] = Seq(input, key) - protected def withNewChildInternal(newChild: Expression): AesDecrypt = - copy(child = newChild) + override def inputTypes: Seq[AbstractDataType] = { + Seq(BinaryType, BinaryType, StringType, StringType) + } + + override def prettyName: String = "aes_decrypt" + + override def children: Seq[Expression] = Seq(input, key, mode, padding) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + copy(newChildren(0), newChildren(1), newChildren(2), newChildren(3)) + } } // scalastyle:on line.size.limit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index c51030fdd6405..d5df6a12aa45b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -294,7 +294,7 @@ case class AttributeReference( } override lazy val preCanonicalized: Expression = { - AttributeReference("none", dataType.asNullable)(exprId) + AttributeReference("none", dataType)(exprId) } override def newInstance(): AttributeReference = @@ -342,7 +342,7 @@ case class AttributeReference( AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier) } - override def withDataType(newType: DataType): Attribute = { + override def withDataType(newType: DataType): AttributeReference = { AttributeReference(name, newType, nullable, metadata)(exprId, qualifier) } @@ -434,8 +434,8 @@ object VirtualColumn { } /** - * The internal representation of the hidden metadata struct: - * set `__metadata_col` to `true` in AttributeReference metadata + * The internal representation of the MetadataAttribute, + * it sets `__metadata_col` to `true` in AttributeReference metadata * - apply() will create a metadata attribute reference * - unapply() will check if an attribute reference is the metadata attribute reference */ @@ -451,3 +451,43 @@ object MetadataAttribute { } else None } } + +/** + * The internal representation of the FileSourceMetadataAttribute, it sets `__metadata_col` + * and `__file_source_metadata_col` to `true` in AttributeReference's metadata + * - apply() will create a file source metadata attribute reference + * - unapply() will check if an attribute reference is the file source metadata attribute reference + */ +object FileSourceMetadataAttribute { + + val FILE_SOURCE_METADATA_COL_ATTR_KEY = "__file_source_metadata_col" + + def apply(name: String, dataType: DataType, nullable: Boolean = true): AttributeReference = + AttributeReference(name, dataType, nullable, + new MetadataBuilder() + .putBoolean(METADATA_COL_ATTR_KEY, value = true) + .putBoolean(FILE_SOURCE_METADATA_COL_ATTR_KEY, value = true).build())() + + def unapply(attr: AttributeReference): Option[AttributeReference] = + attr match { + case MetadataAttribute(attr) + if attr.metadata.contains(FILE_SOURCE_METADATA_COL_ATTR_KEY) + && attr.metadata.getBoolean(FILE_SOURCE_METADATA_COL_ATTR_KEY) => Some(attr) + case _ => None + } + + /** + * Cleanup the internal metadata information of an attribute if it is + * a [[FileSourceMetadataAttribute]], it will remove both [[METADATA_COL_ATTR_KEY]] and + * [[FILE_SOURCE_METADATA_COL_ATTR_KEY]] from the attribute [[Metadata]] 
+ */ + def cleanupFileSourceMetadataInformation(attr: Attribute): Attribute = attr match { + case FileSourceMetadataAttribute(attr) => attr.withMetadata( + new MetadataBuilder().withMetadata(attr.metadata) + .remove(METADATA_COL_ATTR_KEY) + .remove(FILE_SOURCE_METADATA_COL_ATTR_KEY) + .build() + ) + case attr => attr + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index a15126a3347a3..3c6a9b8e78041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -129,29 +129,6 @@ case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpress } -@ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns `expr2` if `expr1` is null, or `expr1` otherwise.", - examples = """ - Examples: - > SELECT _FUNC_(NULL, array('2')); - ["2"] - """, - since = "2.0.0", - group = "conditional_funcs") -case class IfNull(left: Expression, right: Expression, child: Expression) - extends RuntimeReplaceable { - - def this(left: Expression, right: Expression) = { - this(left, right, Coalesce(Seq(left, right))) - } - - override def flatArguments: Iterator[Any] = Iterator(left, right) - override def exprsReplaced: Seq[Expression] = Seq(left, right) - - override protected def withNewChildInternal(newChild: Expression): IfNull = copy(child = newChild) -} - - @ExpressionDescription( usage = "_FUNC_(expr1, expr2) - Returns null if `expr1` equals to `expr2`, or `expr1` otherwise.", examples = """ @@ -161,17 +138,18 @@ case class IfNull(left: Expression, right: Expression, child: Expression) """, since = "2.0.0", group = "conditional_funcs") -case class NullIf(left: Expression, right: Expression, child: Expression) - extends RuntimeReplaceable { +case class NullIf(left: Expression, right: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { def this(left: Expression, right: Expression) = { this(left, right, If(EqualTo(left, right), Literal.create(null, left.dataType), left)) } - override def flatArguments: Iterator[Any] = Iterator(left, right) - override def exprsReplaced: Seq[Expression] = Seq(left, right) + override def parameters: Seq[Expression] = Seq(left, right) - override protected def withNewChildInternal(newChild: Expression): NullIf = copy(child = newChild) + override protected def withNewChildInternal(newChild: Expression): NullIf = { + copy(replacement = newChild) + } } @@ -184,16 +162,17 @@ case class NullIf(left: Expression, right: Expression, child: Expression) """, since = "2.0.0", group = "conditional_funcs") -case class Nvl(left: Expression, right: Expression, child: Expression) extends RuntimeReplaceable { +case class Nvl(left: Expression, right: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { def this(left: Expression, right: Expression) = { this(left, right, Coalesce(Seq(left, right))) } - override def flatArguments: Iterator[Any] = Iterator(left, right) - override def exprsReplaced: Seq[Expression] = Seq(left, right) + override def parameters: Seq[Expression] = Seq(left, right) - override protected def withNewChildInternal(newChild: Expression): Nvl = copy(child = newChild) + override protected def withNewChildInternal(newChild: Expression): Nvl = + copy(replacement = newChild) } @@ -208,17 
+187,18 @@ case class Nvl(left: Expression, right: Expression, child: Expression) extends R since = "2.0.0", group = "conditional_funcs") // scalastyle:on line.size.limit -case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child: Expression) - extends RuntimeReplaceable { +case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { def this(expr1: Expression, expr2: Expression, expr3: Expression) = { this(expr1, expr2, expr3, If(IsNotNull(expr1), expr2, expr3)) } - override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3) - override def exprsReplaced: Seq[Expression] = Seq(expr1, expr2, expr3) + override def parameters: Seq[Expression] = Seq(expr1, expr2, expr3) - override protected def withNewChildInternal(newChild: Expression): Nvl2 = copy(child = newChild) + override protected def withNewChildInternal(newChild: Expression): Nvl2 = { + copy(replacement = newChild) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala new file mode 100644 index 0000000000000..e29a425eef199 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.Locale + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper +import org.apache.spark.sql.catalyst.util.NumberFormatter +import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * A function that converts string to numeric. + */ +@ExpressionDescription( + usage = """ + _FUNC_(strExpr, formatExpr) - Convert `strExpr` to a number based on the `formatExpr`. + The format can consist of the following characters: + '0' or '9': digit position + '.' 
or 'D': decimal point (only allowed once) + ',' or 'G': group (thousands) separator + '-' or 'S': sign anchored to number (only allowed once) + '$': value with a leading dollar sign (only allowed once) + """, + examples = """ + Examples: + > SELECT _FUNC_('454', '999'); + 454 + > SELECT _FUNC_('454.00', '000D00'); + 454.00 + > SELECT _FUNC_('12,454', '99G999'); + 12454 + > SELECT _FUNC_('$78.12', '$99.99'); + 78.12 + > SELECT _FUNC_('12,454.8-', '99G999D9S'); + -12454.8 + """, + since = "3.3.0", + group = "string_funcs") +case class ToNumber(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + + private lazy val numberFormat = right.eval().toString.toUpperCase(Locale.ROOT) + private lazy val numberFormatter = new NumberFormatter(numberFormat) + + override def dataType: DataType = numberFormatter.parsedDecimalType + + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def checkInputDataTypes(): TypeCheckResult = { + val inputTypeCheck = super.checkInputDataTypes() + if (inputTypeCheck.isSuccess) { + if (right.foldable) { + numberFormatter.check() + } else { + TypeCheckResult.TypeCheckFailure(s"Format expression must be foldable, but got $right") + } + } else { + inputTypeCheck + } + } + + override def prettyName: String = "to_number" + + override def nullSafeEval(string: Any, format: Any): Any = { + val input = string.asInstanceOf[UTF8String] + numberFormatter.parse(input) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val builder = + ctx.addReferenceObj("builder", numberFormatter, classOf[NumberFormatter].getName) + val eval = left.genCode(ctx) + ev.copy(code = + code""" + |${eval.code} + |boolean ${ev.isNull} = ${eval.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${ev.isNull}) { + | ${ev.value} = $builder.parse(${eval.value}); + |} + """.stripMargin) + } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ToNumber = copy(left = newLeft, right = newRight) +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 50e214011b616..6974ada8735c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -50,6 +50,7 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput def propagateNull: Boolean + override def foldable: Boolean = children.forall(_.foldable) && deterministic protected lazy val needNullCheck: Boolean = needNullCheckForIndex.contains(true) protected lazy val needNullCheckForIndex: Array[Boolean] = arguments.map(a => a.nullable && (propagateNull || @@ -240,6 +241,8 @@ object SerializerSupport { * without invoking the function. * @param returnNullable When false, indicating the invoked method will always return * non-null value. + * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark + * will not apply certain optimizations such as constant folding. 
*/ case class StaticInvoke( staticObject: Class[_], @@ -248,7 +251,8 @@ case class StaticInvoke( arguments: Seq[Expression] = Nil, inputTypes: Seq[AbstractDataType] = Nil, propagateNull: Boolean = true, - returnNullable: Boolean = true) extends InvokeLike { + returnNullable: Boolean = true, + isDeterministic: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") val cls = if (staticObject.getName == objectName) { @@ -259,6 +263,7 @@ case class StaticInvoke( override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments + override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic) lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) @transient lazy val method = findMethod(cls, functionName, argClasses) @@ -340,6 +345,8 @@ case class StaticInvoke( * without invoking the function. * @param returnNullable When false, indicating the invoked method will always return * non-null value. + * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark + * will not apply certain optimizations such as constant folding. */ case class Invoke( targetObject: Expression, @@ -348,12 +355,14 @@ case class Invoke( arguments: Seq[Expression] = Nil, methodInputTypes: Seq[AbstractDataType] = Nil, propagateNull: Boolean = true, - returnNullable : Boolean = true) extends InvokeLike { + returnNullable : Boolean = true, + isDeterministic: Boolean = true) extends InvokeLike { lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable override def children: Seq[Expression] = targetObject +: arguments + override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic) override def inputTypes: Seq[AbstractDataType] = if (methodInputTypes.nonEmpty) { Seq(targetObject.dataType) ++ methodInputTypes @@ -819,7 +828,7 @@ case class MapObjects private( private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = { val row = new GenericInternalRow(1) - inputCollection.toIterator.map { element => + inputCollection.iterator.map { element => row.update(0, element) lambdaFunction.eval(row) } @@ -1866,14 +1875,14 @@ case class GetExternalRowField( * Validates the actual data type of input expression at runtime. If it doesn't match the * expectation, throw an exception. 
*/ -case class ValidateExternalType(child: Expression, expected: DataType) +case class ValidateExternalType(child: Expression, expected: DataType, lenient: Boolean) extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object])) override def nullable: Boolean = child.nullable - override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) + override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected, lenient) private lazy val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" @@ -1887,6 +1896,14 @@ case class ValidateExternalType(child: Expression, expected: DataType) (value: Any) => { value.getClass.isArray || value.isInstanceOf[Seq[_]] } + case _: DateType => + (value: Any) => { + value.isInstanceOf[java.sql.Date] || value.isInstanceOf[java.time.LocalDate] + } + case _: TimestampType => + (value: Any) => { + value.isInstanceOf[java.sql.Timestamp] || value.isInstanceOf[java.time.Instant] + } case _ => val dataTypeClazz = ScalaReflection.javaBoxedType(dataType) (value: Any) => { @@ -1909,13 +1926,21 @@ case class ValidateExternalType(child: Expression, expected: DataType) val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val input = child.genCode(ctx) val obj = input.value - + def genCheckTypes(classes: Seq[Class[_]]): String = { + classes.map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") + } val typeCheck = expected match { case _: DecimalType => - Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) - .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") + genCheckTypes(Seq( + classOf[java.math.BigDecimal], + classOf[scala.math.BigDecimal], + classOf[Decimal])) case _: ArrayType => s"$obj.getClass().isArray() || $obj instanceof ${classOf[scala.collection.Seq[_]].getName}" + case _: DateType => + genCheckTypes(Seq(classOf[java.sql.Date], classOf[java.time.LocalDate])) + case _: TimestampType => + genCheckTypes(Seq(classOf[java.sql.Timestamp], classOf[java.time.Instant])) case _ => s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index d950fef3b26a5..6a4fb099c8b78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -335,8 +335,14 @@ package object expressions { matchWithFourOrMoreQualifierParts(nameParts, resolver) } + val prunedCandidates = if (candidates.size > 1) { + candidates.filter(c => !c.metadata.contains("__is_duplicate")) + } else { + candidates + } + def name = UnresolvedAttribute(nameParts).name - candidates match { + prunedCandidates match { case Seq(a) if nestedFields.nonEmpty => // One match, but we also need to extract the requested nested field. 
// The foldLeft adds ExtractValues for every remaining parts of the identifier, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c09d3e47e460a..a2fd668f495e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -1134,7 +1134,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) Examples: > SELECT 2 _FUNC_ 1; true - > SELECT 2 _FUNC_ '1.1'; + > SELECT 2 _FUNC_ 1.1; true > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-07-30 04:17:52'); false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 889c53bc548bb..368cbfd6be641 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern} import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.errors.QueryExecutionErrors @@ -240,18 +241,20 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) case class ILike( left: Expression, right: Expression, - escapeChar: Char, - child: Expression) extends RuntimeReplaceable { - def this(left: Expression, right: Expression, escapeChar: Char) = - this(left, right, escapeChar, Like(Lower(left), Lower(right), escapeChar)) + escapeChar: Char) extends RuntimeReplaceable + with ImplicitCastInputTypes with BinaryLike[Expression] { + + override lazy val replacement: Expression = Like(Lower(left), Lower(right), escapeChar) + def this(left: Expression, right: Expression) = this(left, right, '\\') - override def exprsReplaced: Seq[Expression] = Seq(left, right) - override def flatArguments: Iterator[Any] = Iterator(left, right, escapeChar) + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) - override protected def withNewChildInternal(newChild: Expression): ILike = - copy(child = newChild) + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = { + copy(left = newLeft, right = newRight) + } } sealed abstract class MultiLikeBase diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index f1762f4eac767..fc73216b296af 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -30,12 +30,14 @@ import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegist import org.apache.spark.sql.catalyst.expressions.codegen._ import 
org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -361,15 +363,19 @@ case class Elt( ev.copy( code""" |${index.code} - |final int $indexVal = ${index.value}; - |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; - |$inputVal = null; - |do { - | $codes - |} while (false); - |$indexOutOfBoundBranch - |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; - |final boolean ${ev.isNull} = ${ev.value} == null; + |boolean ${ev.isNull} = ${index.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = null; + |if (!${index.isNull}) { + | final int $indexVal = ${index.value}; + | ${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; + | $inputVal = null; + | do { + | $codes + | } while (false); + | $indexOutOfBoundBranch + | ${ev.value} = $inputVal; + | ${ev.isNull} = ${ev.value} == null; + |} """.stripMargin) } @@ -463,13 +469,62 @@ abstract class StringPredicate extends BinaryExpression override def toString: String = s"$nodeName($left, $right)" } -/** - * A function that returns true if the string `left` contains the string `right`. - */ +trait StringBinaryPredicateExpressionBuilderBase extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + if (numArgs == 2) { + if (expressions(0).dataType == BinaryType && expressions(1).dataType == BinaryType) { + BinaryPredicate(funcName, expressions(0), expressions(1)) + } else { + createStringPredicate(expressions(0), expressions(1)) + } + } else { + throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(2), funcName, numArgs) + } + } + + protected def createStringPredicate(left: Expression, right: Expression): Expression +} + +object BinaryPredicate { + def unapply(expr: Expression): Option[StaticInvoke] = expr match { + case s @ StaticInvoke(clz, _, "contains" | "startsWith" | "endsWith", Seq(_, _), _, _, _, _) + if clz == classOf[ByteArrayMethods] => Some(s) + case _ => None + } +} + +case class BinaryPredicate(override val prettyName: String, left: Expression, right: Expression) + extends RuntimeReplaceable with ImplicitCastInputTypes with BinaryLike[Expression] { + + private lazy val realFuncName = prettyName match { + case "startswith" => "startsWith" + case "endswith" => "endsWith" + case name => name + } + + override lazy val replacement = + StaticInvoke( + classOf[ByteArrayMethods], + BooleanType, + realFuncName, + Seq(left, right), + Seq(BinaryType, BinaryType)) + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, BinaryType) + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = { + copy(left = newLeft, right = newRight) + } +} + @ExpressionDescription( usage = """ _FUNC_(left, right) - Returns a boolean. 
The value is True if right is found inside left. Returns NULL if either input expression is NULL. Otherwise, returns False. + Both left and right must be of STRING or BINARY type. """, examples = """ Examples: @@ -479,10 +534,18 @@ abstract class StringPredicate extends BinaryExpression false > SELECT _FUNC_('Spark SQL', null); NULL + > SELECT _FUNC_(x'537061726b2053514c', x'537061726b'); + true """, since = "3.3.0", group = "string_funcs" ) +object ContainsExpressionBuilder extends StringBinaryPredicateExpressionBuilderBase { + override protected def createStringPredicate(left: Expression, right: Expression): Expression = { + Contains(left, right) + } +} + case class Contains(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -494,8 +557,9 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate @ExpressionDescription( usage = """ - _FUNC_(left, right) - Returns true if the string `left` starts with the string `right`. - Returns NULL if either input expression is NULL. + _FUNC_(left, right) - Returns a boolean. The value is True if left starts with right. + Returns NULL if either input expression is NULL. Otherwise, returns False. + Both left and right must be of STRING or BINARY type. """, examples = """ Examples: @@ -505,10 +569,20 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate false > SELECT _FUNC_('Spark SQL', null); NULL + > SELECT _FUNC_(x'537061726b2053514c', x'537061726b'); + true + > SELECT _FUNC_(x'537061726b2053514c', x'53514c'); + false """, since = "3.3.0", group = "string_funcs" ) +object StartsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilderBase { + override protected def createStringPredicate(left: Expression, right: Expression): Expression = { + StartsWith(left, right) + } +} + case class StartsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -520,8 +594,9 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica @ExpressionDescription( usage = """ - _FUNC_(left, right) - Returns true if the string `left` ends with the string `right`. - Returns NULL if either input expression is NULL. + _FUNC_(left, right) - Returns a boolean. The value is True if left ends with right. + Returns NULL if either input expression is NULL. Otherwise, returns False. + Both left and right must be of STRING or BINARY type.
""", examples = """ Examples: @@ -531,10 +606,20 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica false > SELECT _FUNC_('Spark SQL', null); NULL + > SELECT _FUNC_(x'537061726b2053514c', x'537061726b'); + false + > SELECT _FUNC_(x'537061726b2053514c', x'53514c'); + true """, since = "3.3.0", group = "string_funcs" ) +object EndsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilderBase { + override protected def createStringPredicate(left: Expression, right: Expression): Expression = { + EndsWith(left, right) + } +} + case class EndsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -1047,8 +1132,8 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) """, since = "3.2.0", group = "string_funcs") -case class StringTrimBoth(srcStr: Expression, trimStr: Option[Expression], child: Expression) - extends RuntimeReplaceable { +case class StringTrimBoth(srcStr: Expression, trimStr: Option[Expression], replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { def this(srcStr: Expression, trimStr: Expression) = { this(srcStr, Option(trimStr), StringTrim(srcStr, trimStr)) @@ -1058,13 +1143,12 @@ case class StringTrimBoth(srcStr: Expression, trimStr: Option[Expression], child this(srcStr, None, StringTrim(srcStr)) } - override def exprsReplaced: Seq[Expression] = srcStr +: trimStr.toSeq - override def flatArguments: Iterator[Any] = Iterator(srcStr, trimStr) - override def prettyName: String = "btrim" + override def parameters: Seq[Expression] = srcStr +: trimStr.toSeq + override protected def withNewChildInternal(newChild: Expression): StringTrimBoth = - copy(child = newChild) + copy(replacement = newChild) } object StringTrimLeft { @@ -1376,17 +1460,17 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } trait PadExpressionBuilderBase extends ExpressionBuilder { - override def build(expressions: Seq[Expression]): Expression = { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { val numArgs = expressions.length if (numArgs == 2) { if (expressions(0).dataType == BinaryType) { - createBinaryPad(expressions(0), expressions(1), Literal(Array[Byte](0))) + BinaryPad(funcName, expressions(0), expressions(1), Literal(Array[Byte](0))) } else { createStringPad(expressions(0), expressions(1), Literal(" ")) } } else if (numArgs == 3) { if (expressions(0).dataType == BinaryType && expressions(2).dataType == BinaryType) { - createBinaryPad(expressions(0), expressions(1), expressions(2)) + BinaryPad(funcName, expressions(0), expressions(1), expressions(2)) } else { createStringPad(expressions(0), expressions(1), expressions(2)) } @@ -1395,8 +1479,6 @@ trait PadExpressionBuilderBase extends ExpressionBuilder { } } - protected def funcName: String - protected def createBinaryPad(str: Expression, len: Expression, pad: Expression): Expression protected def createStringPad(str: Expression, len: Expression, pad: Expression): Expression } @@ -1423,10 +1505,6 @@ trait PadExpressionBuilderBase extends ExpressionBuilder { since = "1.5.0", group = "string_funcs") object LPadExpressionBuilder extends PadExpressionBuilderBase { - override def funcName: String = "lpad" - override def createBinaryPad(str: Expression, len: Expression, pad: Expression): Expression = { - new BinaryLPad(str, len, pad) - } 
override def createStringPad(str: Expression, len: Expression, pad: Expression): Expression = { StringLPad(str, len, pad) } @@ -1459,21 +1537,28 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) copy(str = newFirst, len = newSecond, pad = newThird) } -case class BinaryLPad(str: Expression, len: Expression, pad: Expression, child: Expression) - extends RuntimeReplaceable { +case class BinaryPad(funcName: String, str: Expression, len: Expression, pad: Expression) + extends RuntimeReplaceable with ImplicitCastInputTypes { + assert(funcName == "lpad" || funcName == "rpad") - def this(str: Expression, len: Expression, pad: Expression) = this(str, len, pad, StaticInvoke( + override lazy val replacement: Expression = StaticInvoke( classOf[ByteArray], BinaryType, - "lpad", + funcName, Seq(str, len, pad), - Seq(BinaryType, IntegerType, BinaryType), + inputTypes, returnNullable = false) - ) - override def prettyName: String = "lpad" - def exprsReplaced: Seq[Expression] = Seq(str, len, pad) - protected def withNewChildInternal(newChild: Expression): BinaryLPad = copy(child = newChild) + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, IntegerType, BinaryType) + + override def nodeName: String = funcName + + override def children: Seq[Expression] = Seq(str, len, pad) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + copy(str = newChildren(0), len = newChildren(1), pad = newChildren(2)) + } } @ExpressionDescription( @@ -1499,10 +1584,6 @@ case class BinaryLPad(str: Expression, len: Expression, pad: Expression, child: since = "1.5.0", group = "string_funcs") object RPadExpressionBuilder extends PadExpressionBuilderBase { - override def funcName: String = "rpad" - override def createBinaryPad(str: Expression, len: Expression, pad: Expression): Expression = { - new BinaryRPad(str, len, pad) - } override def createStringPad(str: Expression, len: Expression, pad: Expression): Expression = { StringRPad(str, len, pad) } @@ -1535,23 +1616,6 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera copy(str = newFirst, len = newSecond, pad = newThird) } -case class BinaryRPad(str: Expression, len: Expression, pad: Expression, child: Expression) - extends RuntimeReplaceable { - - def this(str: Expression, len: Expression, pad: Expression) = this(str, len, pad, StaticInvoke( - classOf[ByteArray], - BinaryType, - "rpad", - Seq(str, len, pad), - Seq(BinaryType, IntegerType, BinaryType), - returnNullable = false) - ) - - override def prettyName: String = "rpad" - def exprsReplaced: Seq[Expression] = Seq(str, len, pad) - protected def withNewChildInternal(newChild: Expression): BinaryRPad = copy(child = newChild) -} - object ParseUrl { private val HOST = UTF8String.fromString("HOST") private val PATH = UTF8String.fromString("PATH") @@ -2025,16 +2089,26 @@ case class Substring(str: Expression, pos: Expression, len: Expression) since = "2.3.0", group = "string_funcs") // scalastyle:on line.size.limit -case class Right(str: Expression, len: Expression, child: Expression) extends RuntimeReplaceable { - def this(str: Expression, len: Expression) = { - this(str, len, If(IsNull(str), Literal(null, StringType), If(LessThanOrEqual(len, Literal(0)), - Literal(UTF8String.EMPTY_UTF8, StringType), new Substring(str, UnaryMinus(len))))) - } - - override def flatArguments: Iterator[Any] = Iterator(str, len) - override def exprsReplaced: Seq[Expression] = Seq(str, len) +case class Right(str: Expression, len: 
Expression) extends RuntimeReplaceable + with ImplicitCastInputTypes with BinaryLike[Expression] { + + override lazy val replacement: Expression = If( + IsNull(str), + Literal(null, StringType), + If( + LessThanOrEqual(len, Literal(0)), + Literal(UTF8String.EMPTY_UTF8, StringType), + new Substring(str, UnaryMinus(len)) + ) + ) - override protected def withNewChildInternal(newChild: Expression): Right = copy(child = newChild) + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType) + override def left: Expression = str + override def right: Expression = len + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = { + copy(str = newLeft, len = newRight) + } } /** @@ -2051,14 +2125,21 @@ case class Right(str: Expression, len: Expression, child: Expression) extends Ru since = "2.3.0", group = "string_funcs") // scalastyle:on line.size.limit -case class Left(str: Expression, len: Expression, child: Expression) extends RuntimeReplaceable { - def this(str: Expression, len: Expression) = { - this(str, len, Substring(str, Literal(1), len)) +case class Left(str: Expression, len: Expression) extends RuntimeReplaceable + with ImplicitCastInputTypes with BinaryLike[Expression] { + + override lazy val replacement: Expression = Substring(str, Literal(1), len) + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(StringType, BinaryType), IntegerType) } - override def flatArguments: Iterator[Any] = Iterator(str, len) - override def exprsReplaced: Seq[Expression] = Seq(str, len) - override protected def withNewChildInternal(newChild: Expression): Left = copy(child = newChild) + override def left: Expression = str + override def right: Expression = len + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = { + copy(str = newLeft, len = newRight) + } } /** @@ -2418,12 +2499,12 @@ object Decode { // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - |_FUNC_(bin, charset) - Decodes the first argument using the second argument character set. - | - |_FUNC_(expr, search, result [, search, result ] ... [, default]) - Decode compares expr - | to each search value one by one. If expr is equal to a search, returns the corresponding result. - | If no match is found, then Oracle returns default. If default is omitted, returns null. - """, + _FUNC_(bin, charset) - Decodes the first argument using the second argument character set. + + _FUNC_(expr, search, result [, search, result ] ... [, default]) - Decode compares expr + to each search value one by one. If expr is equal to a search, returns the corresponding result. + If no match is found, then Oracle returns default. If default is omitted, returns null. 
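For the refactored `right` and `left` expressions above, behaviour is unchanged; only the wiring moved from an extra `child` constructor argument to the `replacement` override. A quick check (assuming a SparkSession named `spark`):

  spark.sql("SELECT right('Spark SQL', 3) AS r, left('Spark SQL', 5) AS l").show()
  // r is 'SQL' and l is 'Spark'; the optimized plan shows the underlying If/Substring rewrite.
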
+ """, examples = """ Examples: > SELECT _FUNC_(encode('abc', 'utf-8'), 'utf-8'); @@ -2438,16 +2519,16 @@ object Decode { since = "3.2.0", group = "string_funcs") // scalastyle:on line.size.limit -case class Decode(params: Seq[Expression], child: Expression) extends RuntimeReplaceable { +case class Decode(params: Seq[Expression], replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { - def this(params: Seq[Expression]) = { - this(params, Decode.createExpr(params)) - } + def this(params: Seq[Expression]) = this(params, Decode.createExpr(params)) - override def flatArguments: Iterator[Any] = Iterator(params) - override def exprsReplaced: Seq[Expression] = params + override def parameters: Seq[Expression] = params - override protected def withNewChildInternal(newChild: Expression): Decode = copy(child = newChild) + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(replacement = newChild) + } } /** @@ -2538,6 +2619,73 @@ case class Encode(value: Expression, charset: Expression) newLeft: Expression, newRight: Expression): Encode = copy(value = newLeft, charset = newRight) } +/** + * Converts the input expression to a binary value based on the supplied format. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(str[, fmt]) - Converts the input `str` to a binary value based on the supplied `fmt`. + `fmt` can be a case-insensitive string literal of "hex", "utf-8", or "base64". + By default, the binary format for conversion is "hex" if `fmt` is omitted. + The function returns NULL if at least one of the input parameters is NULL. + """, + examples = """ + Examples: + > SELECT _FUNC_('abc', 'utf-8'); + abc + """, + since = "3.3.0", + group = "string_funcs") +// scalastyle:on line.size.limit +case class ToBinary(expr: Expression, format: Option[Expression]) extends RuntimeReplaceable + with ImplicitCastInputTypes { + + override lazy val replacement: Expression = format.map { f => + assert(f.foldable && (f.dataType == StringType || f.dataType == NullType)) + val value = f.eval() + if (value == null) { + Literal(null, BinaryType) + } else { + value.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT) match { + case "hex" => Unhex(expr) + case "utf-8" => Encode(expr, Literal("UTF-8")) + case "base64" => UnBase64(expr) + case other => throw QueryCompilationErrors.invalidStringLiteralParameter( + "to_binary", "format", other, + Some("The value has to be a case-insensitive string literal of " + + "'hex', 'utf-8', or 'base64'.")) + } + } + }.getOrElse(Unhex(expr)) + + def this(expr: Expression) = this(expr, None) + + def this(expr: Expression, format: Expression) = this(expr, Some({ + // We perform this check in the constructor to make it eager and not go through type coercion. 
+ if (format.foldable && (format.dataType == StringType || format.dataType == NullType)) { + format + } else { + throw QueryCompilationErrors.requireLiteralParameter("to_binary", "format", "string") + } + })) + + override def prettyName: String = "to_binary" + + override def children: Seq[Expression] = expr +: format.toSeq + + override def inputTypes: Seq[AbstractDataType] = children.map(_ => StringType) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + if (format.isDefined) { + copy(expr = newChildren.head, format = Some(newChildren.last)) + } else { + copy(expr = newChildren.head) + } + } +} + /** * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, * and returns the result as a string. If D is 0, the result has no decimal point or diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index d7112a291f661..71b36fa8ef9ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -88,11 +88,11 @@ object SubqueryExpression { * and false otherwise. */ def hasInOrCorrelatedExistsSubquery(e: Expression): Boolean = { - e.find { + e.exists { case _: ListQuery => true case ex: Exists => ex.isCorrelated case _ => false - }.isDefined + } } /** @@ -101,20 +101,20 @@ object SubqueryExpression { * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveSubquery]] */ def hasCorrelatedSubquery(e: Expression): Boolean = { - e.find { + e.exists { case s: SubqueryExpression => s.isCorrelated case _ => false - }.isDefined + } } /** * Returns true when an expression contains a subquery */ def hasSubquery(e: Expression): Boolean = { - e.find { + e.exists { case _: SubqueryExpression => true case _ => false - }.isDefined + } } } @@ -124,7 +124,7 @@ object SubExprUtils extends PredicateHelper { * returns false otherwise. */ def containsOuter(e: Expression): Boolean = { - e.find(_.isInstanceOf[OuterReference]).isDefined + e.exists(_.isInstanceOf[OuterReference]) } /** @@ -161,7 +161,7 @@ object SubExprUtils extends PredicateHelper { * Given a logical plan, returns TRUE if it has an outer reference and false otherwise. 
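A hedged usage sketch for the `to_binary` expression added above (assuming a SparkSession named `spark`): `fmt` must be a foldable string literal and defaults to 'hex'.

  spark.sql("SELECT to_binary('abc', 'utf-8') AS b1, to_binary('537061726b', 'hex') AS b2").show()
  // b1 holds the UTF-8 bytes of 'abc'; b2 holds the bytes of 'Spark' (hex-decoded).
  spark.sql("SELECT to_binary('537061726b')").show()  // format omitted, defaults to 'hex'
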
*/ def hasOuterReferences(plan: LogicalPlan): Boolean = { - plan.find(_.expressions.exists(containsOuter)).isDefined + plan.exists(_.expressions.exists(containsOuter)) } /** @@ -282,10 +282,10 @@ case class ScalarSubquery( object ScalarSubquery { def hasCorrelatedScalarSubquery(e: Expression): Boolean = { - e.find { + e.exists { case s: ScalarSubquery => s.isCorrelated case _ => false - }.isDefined + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 6396fde575b8f..c701d10b00b73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -825,7 +825,7 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow zero, zero, zero, - (n.cast(DecimalType.IntDecimal) / buckets.cast(DecimalType.IntDecimal)).cast(IntegerType), + (n div buckets).cast(IntegerType), (n % buckets).cast(IntegerType) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index a1f9487fe2e08..abcbdb83813b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -204,9 +204,12 @@ class JacksonParser( case VALUE_STRING if parser.getTextLength >= 1 => // Special case handling for NaN and Infinity. parser.getText match { - case "NaN" => Float.NaN - case "Infinity" => Float.PositiveInfinity - case "-Infinity" => Float.NegativeInfinity + case "NaN" if options.allowNonNumericNumbers => + Float.NaN + case "+INF" | "+Infinity" | "Infinity" if options.allowNonNumericNumbers => + Float.PositiveInfinity + case "-INF" | "-Infinity" if options.allowNonNumericNumbers => + Float.NegativeInfinity case _ => throw QueryExecutionErrors.cannotParseStringAsDataTypeError( parser, VALUE_STRING, FloatType) } @@ -220,9 +223,12 @@ class JacksonParser( case VALUE_STRING if parser.getTextLength >= 1 => // Special case handling for NaN and Infinity. 
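To illustrate the tightened JSON number parsing above (a sketch; assumes a SparkSession named `spark`): the quoted non-numeric spellings are only honoured when `allowNonNumericNumbers` is enabled (it is on by default), and the `+INF`/`+Infinity`/`-INF` variants are now recognised as well.

  import spark.implicits._

  val ds = Seq("""{"v": "NaN"}""", """{"v": "+INF"}""", """{"v": "-Infinity"}""").toDS()
  // With the option enabled the three rows parse to NaN, Infinity and -Infinity;
  // with it disabled they are treated as malformed records instead.
  spark.read.schema("v DOUBLE").option("allowNonNumericNumbers", "true").json(ds).show()
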
parser.getText match { - case "NaN" => Double.NaN - case "Infinity" => Double.PositiveInfinity - case "-Infinity" => Double.NegativeInfinity + case "NaN" if options.allowNonNumericNumbers => + Double.NaN + case "+INF" | "+Infinity" | "Infinity" if options.allowNonNumericNumbers => + Double.PositiveInfinity + case "-INF" | "-Infinity" if options.allowNonNumericNumbers => + Double.NegativeInfinity case _ => throw QueryExecutionErrors.cannotParseStringAsDataTypeError( parser, VALUE_STRING, DoubleType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 6a63118698106..d08773d846960 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -100,7 +100,7 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { wrappedCharException.initCause(e) handleJsonErrorsByParseMode(parseMode, columnNameOfCorruptRecord, wrappedCharException) } - }.reduceOption(typeMerger).toIterator + }.reduceOption(typeMerger).iterator } // Here we manually submit a fold-like Spark job, so that we can set the SQLConf when running diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala index 71f3897ccf50b..9a4d1a33e30bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -87,25 +87,35 @@ object DecorrelateInnerQuery extends PredicateHelper { * leaf node and will not be found here. */ private def containsAttribute(expression: Expression): Boolean = { - expression.find(_.isInstanceOf[Attribute]).isDefined + expression.exists(_.isInstanceOf[Attribute]) } /** * Check if an expression can be pulled up over an [[Aggregate]] without changing the * semantics of the plan. The expression must be an equality predicate that guarantees - * one-to-one mapping between inner and outer attributes. More specifically, one side - * of the predicate must be an attribute and another side of the predicate must not - * contain other attributes from the inner query. + * one-to-one mapping between inner and outer attributes. * For example: * (a = outer(c)) -> true * (a > outer(c)) -> false * (a + b = outer(c)) -> false * (a = outer(c) - b) -> false */ - private def canPullUpOverAgg(expression: Expression): Boolean = expression match { - case Equality(_: Attribute, b) => !containsAttribute(b) - case Equality(a, _: Attribute) => !containsAttribute(a) - case o => !containsAttribute(o) + def canPullUpOverAgg(expression: Expression): Boolean = { + def isSupported(e: Expression): Boolean = e match { + case _: Attribute => true + // Allow Cast expressions that guarantee 1:1 mapping. + case Cast(a: Attribute, dataType, _, _) => Cast.canUpCast(a.dataType, dataType) + case _ => false + } + + // Only allow equality condition with one side being an attribute or an expression that + // guarantees 1:1 mapping and another side being an expression without attributes from + // the inner query. 
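To make the relaxed decorrelation condition concrete, a hypothetical correlated query it is intended to admit (table and column names are made up, not part of the patch): the inner side of the equality is an up-cast of a single attribute, which now counts as a 1:1 mapping and can be pulled up over the Aggregate.

  // t1(id INT, grp BIGINT) and t2(grp INT, v INT) are assumed tables.
  spark.sql("""
    SELECT id,
           (SELECT sum(v) FROM t2 WHERE CAST(t2.grp AS BIGINT) = t1.grp) AS total
    FROM t1
  """)
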
+ expression match { + case Equality(a, b) if isSupported(a) => !containsAttribute(b) + case Equality(a, b) if isSupported(b) => !containsAttribute(a) + case o => !containsAttribute(o) + } } /** @@ -258,7 +268,7 @@ object DecorrelateInnerQuery extends PredicateHelper { // The decorrelation framework adds domain inner joins by traversing down the plan tree // recursively until it reaches a node that is not correlated with the outer query. // So the child node of a domain inner join shouldn't contain another domain join. - assert(child.find(_.isInstanceOf[DomainJoin]).isEmpty, + assert(!child.exists(_.isInstanceOf[DomainJoin]), s"Child of a domain inner join shouldn't contain another domain join.\n$child") child case o => @@ -599,6 +609,11 @@ object DecorrelateInnerQuery extends PredicateHelper { (newAggregate, joinCond, outerReferenceMap) } + case d: Distinct => + val (newChild, joinCond, outerReferenceMap) = + decorrelate(d.child, parentOuterReferences, aggregated = true) + (d.copy(child = newChild), joinCond, outerReferenceMap) + case j @ Join(left, right, joinType, condition, _) => val outerReferences = collectOuterReferences(j.expressions) // Join condition containing outer references is not supported. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index 1de300ef9c09d..61577b1d21ea4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -54,7 +54,7 @@ object InlineCTE extends Rule[LogicalPlan] { // 2) Any `CTERelationRef` that contains `OuterReference` would have been inlined first. refCount == 1 || cteDef.deterministic || - cteDef.child.find(_.expressions.exists(_.isInstanceOf[OuterReference])).isDefined + cteDef.child.exists(_.expressions.exists(_.isInstanceOf[OuterReference])) } private def buildCTEMap( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 9d63f4e94647c..4c7130e51e0b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -210,6 +210,7 @@ object NestedColumnAliasing { case _: Repartition => true case _: Sample => true case _: RepartitionByExpression => true + case _: RebalancePartitions => true case _: Join => true case _: Window => true case _: Sort => true @@ -246,7 +247,7 @@ object NestedColumnAliasing { exprList.foreach { e => collectRootReferenceAndExtractValue(e).foreach { // we can not alias the attr from lambda variable whose expr id is not available - case ev: ExtractValue if ev.find(_.isInstanceOf[NamedLambdaVariable]).isEmpty => + case ev: ExtractValue if !ev.exists(_.isInstanceOf[NamedLambdaVariable]) => if (ev.references.size == 1) { nestedFieldReferences.append(ev) } @@ -266,7 +267,7 @@ object NestedColumnAliasing { // that do should not have an alias generated as it can lead to pushing the aggregate down // into a projection. def containsAggregateFunction(ev: ExtractValue): Boolean = - ev.find(_.isInstanceOf[AggregateFunction]).isDefined + ev.exists(_.isInstanceOf[AggregateFunction]) // Remove redundant [[ExtractValue]]s if they share the same parent nest field. 
// For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`. @@ -276,7 +277,7 @@ object NestedColumnAliasing { // [[GetStructField]] case e @ (_: GetStructField | _: GetArrayStructFields) => val child = e.children.head - nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty) + nestedFields.forall(f => !child.exists(_.semanticEquals(f))) case _ => true } .distinct @@ -371,6 +372,17 @@ object GeneratorNestedColumnAliasing { e.withNewChildren(Seq(extractor)) } + // If after replacing generator expression with nested extractor, there + // is invalid extractor pattern like + // `GetArrayStructFields(GetArrayStructFields(...), ...), we cannot do + // pruning but fallback to original query plan. + val invalidExtractor = rewrittenG.generator.children.head.collect { + case GetArrayStructFields(_: GetArrayStructFields, _, _, _, _) => true + } + if (invalidExtractor.nonEmpty) { + return Some(pushedThrough) + } + // As we change the child of the generator, its output data type must be updated. val updatedGeneratorOutput = rewrittenG.generatorOutput .zip(rewrittenG.generator.elementSchema.toAttributes) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala new file mode 100644 index 0000000000000..83646611578cb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.trees.TreePattern._ + +/** + * The rule is applied both normal and AQE Optimizer. 
It optimizes plans using their max-rows information: + * - if the max rows of the child of a sort is less than or equal to 1, remove the sort + * - if the max rows per partition of the child of a local sort is less than or equal to 1, + * remove the local sort + * - if the max rows of the child of an aggregate is less than or equal to 1 and the aggregate + * is group-only (this includes the rewritten distinct plan), convert the aggregate to a project + * - if the max rows of the child of an aggregate is less than or equal to 1, + * set distinct to false in all aggregate expressions + */ +object OptimizeOneRowPlan extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.transformUpWithPruning(_.containsAnyPattern(SORT, AGGREGATE), ruleId) { + case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) => child + case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) => child + case agg @ Aggregate(_, _, child) if agg.groupOnly && child.maxRows.exists(_ <= 1L) => + Project(agg.aggregateExpressions, child) + case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) => + agg.transformExpressions { + case aggExpr: AggregateExpression if aggExpr.isDistinct => + aggExpr.copy(isDistinct = false) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3d41953ebfb58..debd5a66adb23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -52,7 +52,7 @@ abstract class Optimizer(catalogManager: CatalogManager) previousPlan: LogicalPlan, currentPlan: LogicalPlan): Boolean = { !Utils.isTesting || (currentPlan.resolved && - currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && + !currentPlan.exists(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty) && LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) && DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) } @@ -108,7 +108,6 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateAggregateFilter, ReorderAssociativeOperator, LikeSimplification, - NotPropagation, BooleanSimplification, SimplifyConditionals, PushFoldableIntoBranches, @@ -239,6 +238,7 @@ abstract class Optimizer(catalogManager: CatalogManager) // PropagateEmptyRelation can change the nullability of an attribute from nullable to // non-nullable when an empty relation child of a Union is removed UpdateAttributeNullability) :+ + Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan) :+ // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :+ @@ -415,7 +415,8 @@ abstract class Optimizer(catalogManager: CatalogManager) * This rule should be applied before RewriteDistinctAggregates.
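A rough feel for the plans OptimizeOneRowPlan (added above) targets, assuming a SparkSession named `spark` and a hypothetical table `t`: the inner aggregate produces at most one row, so the outer GROUP BY degenerates to a Project and the ORDER BY is dropped.

  spark.sql("SELECT m FROM (SELECT max(a) AS m FROM t) GROUP BY m ORDER BY m").explain(true)
  // The optimized plan is expected to contain neither a Sort nor an outer Aggregate.
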
*/ object EliminateDistinct extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressions { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsPattern(AGGREGATE_EXPRESSION)) { case ae: AggregateExpression if ae.isDistinct && isDuplicateAgnostic(ae.aggregateFunction) => ae.copy(isDistinct = false) } @@ -437,8 +438,8 @@ object EliminateDistinct extends Rule[LogicalPlan] { * This rule should be applied before RewriteDistinctAggregates. */ object EliminateAggregateFilter extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsWithPruning( - _.containsAllPatterns(TRUE_OR_FALSE_LITERAL), ruleId) { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsAllPatterns(AGGREGATE_EXPRESSION, TRUE_OR_FALSE_LITERAL), ruleId) { case ae @ AggregateExpression(_, _, _, Some(Literal.TrueLiteral), _) => ae.copy(filter = None) case AggregateExpression(af: DeclarativeAggregate, _, _, Some(Literal.FalseLiteral), _) => @@ -684,7 +685,9 @@ object LimitPushDown extends Rule[LogicalPlan] { left = maybePushLocalLimit(limitExpr, join.left), right = maybePushLocalLimit(limitExpr, join.right)) case LeftSemi | LeftAnti if join.condition.isEmpty => - join.copy(left = maybePushLocalLimit(limitExpr, join.left)) + join.copy( + left = maybePushLocalLimit(limitExpr, join.left), + right = maybePushLocalLimit(Literal(1, IntegerType), join.right)) case _ => join } } @@ -719,9 +722,9 @@ object LimitPushDown extends Rule[LogicalPlan] { LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join))) // Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only. 
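As a sketch of the extended LimitPushDown behaviour above (hypothetical tables `t1` and `t2`; assumes a SparkSession named `spark`): for a condition-less LEFT SEMI join only the existence of rows on the right matters, so a LocalLimit(1) can now be pushed to the right side in addition to pushing the user limit to the left.

  spark.sql("SELECT * FROM t1 LEFT SEMI JOIN t2 LIMIT 5").explain(true)
  // The optimized plan is expected to show LocalLimit 5 over t1 and LocalLimit 1 over t2.
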
case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly => - Limit(le, Project(a.output, LocalLimit(le, a.child))) + Limit(le, Project(a.aggregateExpressions, LocalLimit(le, a.child))) case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly => - Limit(le, p.copy(child = Project(a.output, LocalLimit(le, a.child)))) + Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le, a.child)))) } } @@ -762,22 +765,22 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper result.asInstanceOf[A] } + def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = { + val newFirstChild = Project(projectList, u.children.head) + val newOtherChildren = u.children.tail.map { child => + val rewrites = buildRewrites(u.children.head, child) + Project(projectList.map(pushToRight(_, rewrites)), child) + } + newFirstChild +: newOtherChildren + } + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsAllPatterns(UNION, PROJECT)) { // Push down deterministic projection through UNION ALL - case p @ Project(projectList, u: Union) => - assert(u.children.nonEmpty) - if (projectList.forall(_.deterministic)) { - val newFirstChild = Project(projectList, u.children.head) - val newOtherChildren = u.children.tail.map { child => - val rewrites = buildRewrites(u.children.head, child) - Project(projectList.map(pushToRight(_, rewrites)), child) - } - u.copy(children = newFirstChild +: newOtherChildren) - } else { - p - } + case Project(projectList, u: Union) + if projectList.forall(_.deterministic) && u.children.nonEmpty => + u.copy(children = pushProjectionThroughUnion(projectList, u)) } } @@ -1004,7 +1007,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { }.isEmpty) } - private def buildCleanedProjectList( + def buildCleanedProjectList( upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Seq[NamedExpression] = { val aliases = getAliasMap(lower) @@ -1049,11 +1052,11 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { } /** - * Combines adjacent [[RepartitionOperation]] operators + * Combines adjacent [[RepartitionOperation]] and [[RebalancePartitions]] operators */ object CollapseRepartition extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( - _.containsPattern(REPARTITION_OPERATION), ruleId) { + _.containsAnyPattern(REPARTITION_OPERATION, REBALANCE_PARTITIONS), ruleId) { // Case 1: When a Repartition has a child of Repartition or RepartitionByExpression, // 1) When the top node does not enable the shuffle (i.e., coalesce API), but the child // enables the shuffle. Returns the child node if the last numPartitions is bigger; @@ -1067,6 +1070,14 @@ object CollapseRepartition extends Rule[LogicalPlan] { // RepartitionByExpression we can remove the child. case r @ RepartitionByExpression(_, child @ (Sort(_, true, _) | _: RepartitionOperation), _) => r.withNewChildren(child.children) + // Case 3: When a RebalancePartitions has a child of local or global Sort, Repartition or + // RepartitionByExpression we can remove the child. + case r @ RebalancePartitions(_, child @ (_: Sort | _: RepartitionOperation), _) => + r.withNewChildren(child.children) + // Case 4: When a RebalancePartitions has a child of RebalancePartitions we can remove the + // child. 
+ case r @ RebalancePartitions(_, child: RebalancePartitions, _) => + r.withNewChildren(child.children) } } @@ -1290,6 +1301,9 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] * Combines all adjacent [[Union]] operators into a single [[Union]]. */ object CombineUnions extends Rule[LogicalPlan] { + import CollapseProject.{buildCleanedProjectList, canCollapseExpressions} + import PushProjectionThroughUnion.pushProjectionThroughUnion + def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning( _.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) { case u: Union => flattenUnion(u, false) @@ -1311,6 +1325,10 @@ object CombineUnions extends Rule[LogicalPlan] { // rules (by position and by name) could cause incorrect results. while (stack.nonEmpty) { stack.pop() match { + case p1 @ Project(_, p2: Project) + if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline = false) => + val newProjectList = buildCleanedProjectList(p1.projectList, p2.projectList) + stack.pushAll(Seq(p2.copy(projectList = newProjectList))) case Distinct(Union(children, byName, allowMissingCol)) if flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol => stack.pushAll(children.reverse) @@ -1322,6 +1340,20 @@ object CombineUnions extends Rule[LogicalPlan] { case Union(children, byName, allowMissingCol) if byName == topByName && allowMissingCol == topAllowMissingCol => stack.pushAll(children.reverse) + // Push down projection through Union and then push pushed plan to Stack if + // there is a Project. + case Project(projectList, Distinct(u @ Union(children, byName, allowMissingCol))) + if projectList.forall(_.deterministic) && children.nonEmpty && + flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol => + stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse) + case Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union)) + if projectList.forall(_.deterministic) && flattenDistinct && u.byName == topByName && + u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == u.outputSet => + stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse) + case Project(projectList, u @ Union(children, byName, allowMissingCol)) + if projectList.forall(_.deterministic) && children.nonEmpty && + byName == topByName && allowMissingCol == topAllowMissingCol => + stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse) case child => flattened += child } @@ -1359,24 +1391,22 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { * Removes Sort operations if they don't affect the final output ordering. * Note that changes in the final output ordering may affect the file size (SPARK-32318). 
* This rule handles the following cases: - * 1) if the child maximum number of rows less than or equal to 1 - * 2) if the sort order is empty or the sort order does not have any reference - * 3) if the Sort operator is a local sort and the child is already sorted - * 4) if there is another Sort operator separated by 0...n Project, Filter, Repartition or - * RepartitionByExpression (with deterministic expressions) operators - * 5) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or - * RepartitionByExpression (with deterministic expressions) operators only and the Join condition - * is deterministic - * 6) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or - * RepartitionByExpression (with deterministic expressions) operators only and the aggregate - * function is order irrelevant + * 1) if the sort order is empty or the sort order does not have any reference + * 2) if the Sort operator is a local sort and the child is already sorted + * 3) if there is another Sort operator separated by 0...n Project, Filter, Repartition or + * RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators + * 4) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or + * RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only + * and the Join condition is deterministic + * 5) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or + * RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only + * and the aggregate function is order irrelevant */ object EliminateSorts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsPattern(SORT))(applyLocally) private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = { - case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) => recursiveRemoveSort(child) case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) => val newOrders = orders.filterNot(_.child.foldable) if (newOrders.isEmpty) { @@ -1409,6 +1439,7 @@ object EliminateSorts extends Rule[LogicalPlan] { case p: Project => p.projectList.forall(_.deterministic) case f: Filter => f.condition.deterministic case r: RepartitionByExpression => r.partitionExpressions.forall(_.deterministic) + case r: RebalancePartitions => r.partitionExpressions.forall(_.deterministic) case _: Repartition => true case _ => false } @@ -1659,11 +1690,10 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe */ private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = { val attributes = plan.outputSet - val matched = condition.find { + !condition.exists { case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty case _ => false } - matched.isEmpty } } @@ -1925,7 +1955,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { } private def hasUnevaluableExpr(expr: Expression): Boolean = { - expr.find(e => e.isInstanceOf[Unevaluable] && !e.isInstanceOf[AttributeReference]).isDefined + expr.exists(e => e.isInstanceOf[Unevaluable] && !e.isInstanceOf[AttributeReference]) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index d02f12d67e19f..2c964fa6da3db 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_ * [[LocalRelation]]. * 3. Unary-node Logical Plans * - Project/Filter/Sample with all empty children. - * - Limit/Repartition with all empty children. + * - Limit/Repartition/RepartitionByExpression/Rebalance with all empty children. * - Aggregate with all empty children and at least one grouping expression. * - Generate(Explode) with all empty children. Others like Hive UDTF may return results. */ @@ -138,6 +138,7 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup case _: LocalLimit if !p.isStreaming => empty(p) case _: Repartition => empty(p) case _: RepartitionByExpression => empty(p) + case _: RebalancePartitions => empty(p) // An aggregate with non-empty group expression will return one output row per group when the // input to the aggregate is not empty. If the input to the aggregate is empty then all groups // will be empty and thus the output will be empty. If we're working on batch data, we can diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala index 859a73a4842f0..1bd186d89a07d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala @@ -50,7 +50,7 @@ object PullOutGroupingExpressions extends Rule[LogicalPlan] { plan.transformWithPruning(_.containsPattern(AGGREGATE)) { case a: Aggregate if a.resolved => val complexGroupingExpressionMap = mutable.LinkedHashMap.empty[Expression, NamedExpression] - val newGroupingExpressions = a.groupingExpressions.map { + val newGroupingExpressions = a.groupingExpressions.toIndexedSeq.map { case e if !e.foldable && e.children.nonEmpty => complexGroupingExpressionMap .getOrElseUpdate(e.canonicalized, Alias(e, s"_groupingexpression")()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala index bf17791fdd0a0..2104bce3711f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.PullOutNondeterministic -import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet} +import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet, ExpressionSet} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE @@ -47,15 +47,19 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper { } else { newAggregate } + + case agg @ 
Aggregate(groupingExps, _, child) + if agg.groupOnly && child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) => + Project(agg.aggregateExpressions, child) } private def isLowerRedundant(upper: Aggregate, lower: Aggregate): Boolean = { val upperHasNoDuplicateSensitiveAgg = upper .aggregateExpressions - .forall(expr => expr.find { + .forall(expr => !expr.exists { case ae: AggregateExpression => isDuplicateSensitive(ae) case e => AggregateExpression.isAggregate(e) - }.isEmpty) + }) lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet( lower diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 8218051c584b3..f66128dcbc3fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -87,8 +87,8 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { val rightProjectList = projectList(right) left.output.size == left.output.map(_.name).distinct.size && - left.find(_.expressions.exists(SubqueryExpression.hasSubquery)).isEmpty && - right.find(_.expressions.exists(SubqueryExpression.hasSubquery)).isEmpty && + !left.exists(_.expressions.exists(SubqueryExpression.hasSubquery)) && + !right.exists(_.expressions.exists(SubqueryExpression.hasSubquery)) && Project(leftProjectList, nonFilterChild(skipProject(left))).sameResult( Project(rightProjectList, nonFilterChild(skipProject(right)))) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index b002930391222..5aa134a0c1109 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, MultiLikeBase, _} +import org.apache.spark.sql.catalyst.expressions.{MultiLikeBase, _} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull @@ -447,53 +447,6 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { } -/** - * Move/Push `Not` operator if it's beneficial. - */ -object NotPropagation extends Rule[LogicalPlan] { - // Given argument x, return true if expression Not(x) can be simplified - // E.g. 
let x == Not(y), then canSimplifyNot(x) == true because Not(x) == Not(Not(y)) == y - // For the case of x = EqualTo(a, b), recursively check each child expression - // Extra nullable check is required for EqualNullSafe because - // Not(EqualNullSafe(e, null)) is different from EqualNullSafe(e, Not(null)) - private def canSimplifyNot(x: Expression): Boolean = x match { - case Literal(_, BooleanType) | Literal(_, NullType) => true - case _: Not | _: IsNull | _: IsNotNull | _: And | _: Or => true - case _: GreaterThan | _: GreaterThanOrEqual | _: LessThan | _: LessThanOrEqual => true - case EqualTo(a, b) if canSimplifyNot(a) || canSimplifyNot(b) => true - case EqualNullSafe(a, b) - if !a.nullable && !b.nullable && (canSimplifyNot(a) || canSimplifyNot(b)) => true - case _ => false - } - - def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( - _.containsPattern(NOT), ruleId) { - case q: LogicalPlan => q.transformExpressionsDownWithPruning(_.containsPattern(NOT), ruleId) { - // Move `Not` from one side of `EqualTo`/`EqualNullSafe` to the other side if it's beneficial. - // E.g. `EqualTo(Not(a), b)` where `b = Not(c)`, it will become - // `EqualTo(a, Not(b))` => `EqualTo(a, Not(Not(c)))` => `EqualTo(a, c)` - // In addition, `if canSimplifyNot(b)` checks if the optimization can converge - // that avoids the situation two conditions are returning to each other. - case EqualTo(Not(a), b) if !canSimplifyNot(a) && canSimplifyNot(b) => EqualTo(a, Not(b)) - case EqualTo(a, Not(b)) if canSimplifyNot(a) && !canSimplifyNot(b) => EqualTo(Not(a), b) - case EqualNullSafe(Not(a), b) if !canSimplifyNot(a) && canSimplifyNot(b) => - EqualNullSafe(a, Not(b)) - case EqualNullSafe(a, Not(b)) if canSimplifyNot(a) && !canSimplifyNot(b) => - EqualNullSafe(Not(a), b) - - // Push `Not` to one side of `EqualTo`/`EqualNullSafe` if it's beneficial. - // E.g. Not(EqualTo(x, false)) => EqualTo(x, true) - case Not(EqualTo(a, b)) if canSimplifyNot(b) => EqualTo(a, Not(b)) - case Not(EqualTo(a, b)) if canSimplifyNot(a) => EqualTo(Not(a), b) - case Not(EqualNullSafe(a, b)) if !a.nullable && !b.nullable && canSimplifyNot(b) => - EqualNullSafe(a, Not(b)) - case Not(EqualNullSafe(a, b)) if !a.nullable && !b.nullable && canSimplifyNot(a) => - EqualNullSafe(Not(a), b) - } - } -} - - /** * Simplifies binary comparisons with semantically-equal expressions: * 1) Replace '<=>' with 'true' literal. @@ -662,17 +615,6 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { case _ => false } - // Not all BinaryExpression can be pushed into (if / case) branches. 
- private def supportedBinaryExpression(e: BinaryExpression): Boolean = e match { - case _: BinaryComparison | _: StringPredicate | _: StringRegexExpression => true - case _: BinaryArithmetic => true - case _: BinaryMathExpression => true - case _: AddMonths | _: DateAdd | _: DateAddInterval | _: DateDiff | _: DateSub | - _: DateAddYMInterval | _: TimestampAddYMInterval | _: TimeAdd => true - case _: FindInSet | _: RoundBase => true - case _ => false - } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsAnyPattern(CASE_WHEN, IF), ruleId) { case q: LogicalPlan => q.transformExpressionsUpWithPruning( @@ -689,30 +631,26 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))), Some(u.withNewChildren(Array(elseValue.getOrElse(Literal(null, c.dataType)))))) - case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) - if supportedBinaryExpression(b) && right.foldable && - atMostOneUnfoldable(Seq(trueValue, falseValue)) => + case SupportedBinaryExpr(b, i @ If(_, trueValue, falseValue), right) + if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( trueValue = b.withNewChildren(Array(trueValue, right)), falseValue = b.withNewChildren(Array(falseValue, right))) - case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) - if supportedBinaryExpression(b) && left.foldable && - atMostOneUnfoldable(Seq(trueValue, falseValue)) => + case SupportedBinaryExpr(b, left, i @ If(_, trueValue, falseValue)) + if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( trueValue = b.withNewChildren(Array(left, trueValue)), falseValue = b.withNewChildren(Array(left, falseValue))) - case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right) - if supportedBinaryExpression(b) && right.foldable && - atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => + case SupportedBinaryExpr(b, c @ CaseWhen(branches, elseValue), right) + if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))), Some(b.withNewChildren(Array(elseValue.getOrElse(Literal(null, c.dataType)), right)))) - case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue)) - if supportedBinaryExpression(b) && left.foldable && - atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => + case SupportedBinaryExpr(b, left, c @ CaseWhen(branches, elseValue)) + if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))), Some(b.withNewChildren(Array(left, elseValue.getOrElse(Literal(null, c.dataType)))))) @@ -720,6 +658,21 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { } } +object SupportedBinaryExpr { + def unapply(expr: Expression): Option[(Expression, Expression, Expression)] = expr match { + case _: BinaryComparison | _: StringPredicate | _: StringRegexExpression => + Some(expr, expr.children.head, expr.children.last) + case _: BinaryArithmetic => Some(expr, expr.children.head, expr.children.last) + case _: BinaryMathExpression => Some(expr, expr.children.head, expr.children.last) + case _: AddMonths | _: DateAdd | _: DateAddInterval | _: DateDiff | _: DateSub | + _: DateAddYMInterval | _: TimestampAddYMInterval | _: TimeAdd => + Some(expr, expr.children.head, expr.children.last) + case _: FindInSet | _: RoundBase => Some(expr, 
expr.children.head, expr.children.last) + case BinaryPredicate(expr) => + Some(expr, expr.arguments.head, expr.arguments.last) + case _ => None + } +} /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. @@ -1023,6 +976,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { case _: AppendColumnsWithObject => true case _: RepartitionByExpression => true case _: Repartition => true + case _: RebalancePartitions => true case _: Sort => true case _: TypedFilter => true case _ => false @@ -1037,6 +991,9 @@ object SimplifyCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( _.containsPattern(CAST), ruleId) { case Cast(e, dataType, _, _) if e.dataType == dataType => e + case c @ Cast(Cast(e, dt1: NumericType, _, _), dt2: NumericType, _, _) + if isWiderCast(e.dataType, dt1) && isWiderCast(dt1, dt2) => + c.copy(child = e) case c @ Cast(e, dataType, _, _) => (e.dataType, dataType) match { case (ArrayType(from, false), ArrayType(to, true)) if from == to => e case (MapType(fromKey, fromValue, false), MapType(toKey, toValue, true)) @@ -1044,6 +1001,15 @@ object SimplifyCasts extends Rule[LogicalPlan] { case _ => c } } + + // Returns whether the from DataType can be safely casted to the to DataType without losing + // any precision or range. + private def isWiderCast(from: DataType, to: NumericType): Boolean = (from, to) match { + case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true + case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true + case (from: IntegralType, to: IntegralType) => Cast.canUpCast(from, to) + case _ => from == to + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 645ff6bdee975..ef9c4b9af40d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -21,7 +21,6 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -32,26 +31,23 @@ import org.apache.spark.util.Utils /** - * Finds all the expressions that are unevaluable and replace/rewrite them with semantically - * equivalent expressions that can be evaluated. Currently we replace two kinds of expressions: - * 1) [[RuntimeReplaceable]] expressions - * 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any, CountIf + * Finds all the [[RuntimeReplaceable]] expressions that are unevaluable and replace them + * with semantically equivalent expressions that can be evaluated. + * * This is mainly used to provide compatibility with other databases. * Few examples are: - * we use this to support "nvl" by replacing it with "coalesce". + * we use this to support "left" by replacing it with "substring". * we use this to replace Every and Any with Min and Max respectively. - * - * TODO: In future, explore an option to replace aggregate functions similar to - * how RuntimeReplaceable does. 
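// Illustrative sketch (not from this patch) of the idea behind the nested-cast rule added to
// SimplifyCasts above: Cast(Cast(e, dt1), dt2) can drop the inner cast when each step only
// widens the type. A simplified integral-width check stands in for Spark's
// DecimalType.isWiderThan / Cast.canUpCast machinery; all names are hypothetical.
object WiderCastSketch {
  sealed trait NumType
  case object ByteT extends NumType
  case object IntT extends NumType
  case object LongT extends NumType

  private val byteWidth: Map[NumType, Int] = Map(ByteT -> 1, IntT -> 4, LongT -> 8)

  // True if `from` can be cast to `to` without losing precision or range.
  def isWiderCast(from: NumType, to: NumType): Boolean = byteWidth(from) <= byteWidth(to)

  def main(args: Array[String]): Unit = {
    // CAST(CAST(byteCol AS INT) AS BIGINT)  ==>  CAST(byteCol AS BIGINT)
    println(isWiderCast(ByteT, IntT) && isWiderCast(IntT, LongT)) // true: inner cast removable
    // CAST(CAST(intCol AS BIGINT) AS INT) keeps both casts
    println(isWiderCast(IntT, LongT) && isWiderCast(LongT, IntT)) // false
  }
}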
*/ object ReplaceExpressions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( - _.containsAnyPattern(RUNTIME_REPLACEABLE, COUNT_IF, BOOL_AGG, REGR_COUNT)) { - case e: RuntimeReplaceable => e.child - case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral)) - case BoolOr(arg) => Max(arg) - case BoolAnd(arg) => Min(arg) - case RegrCount(left, right) => Count(Seq(left, right)) + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsAnyPattern(RUNTIME_REPLACEABLE)) { + case p => p.mapExpressions(replace) + } + + private def replace(e: Expression): Expression = e match { + case r: RuntimeReplaceable => replace(r.replacement) + case _ => e.mapChildren(replace) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index e03360d3d44d6..6d683a7a11384 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -143,7 +143,7 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { val attributes = e.references.toSeq val emptyRow = new GenericInternalRow(attributes.length) val boundE = BindReferences.bindReference(e, attributes) - if (boundE.find(_.isInstanceOf[Unevaluable]).isDefined) return false + if (boundE.exists(_.isInstanceOf[Unevaluable])) return false val v = boundE.eval(emptyRow) v == null || v == false } @@ -195,9 +195,9 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { object ExtractPythonUDFFromJoinCondition extends Rule[LogicalPlan] with PredicateHelper { private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = { - expr.find { e => + expr.exists { e => PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right) - }.isDefined + } } override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 52544ff3e241d..82aef32c5a22f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -186,7 +186,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] { serializer: NamedExpression, prunedDataType: DataType): NamedExpression = { val prunedStructTypes = collectStructType(prunedDataType, ArrayBuffer.empty[StructType]) - .toIterator + .iterator def transformer: PartialFunction[Expression, Expression] = { case m: ExternalMapToCatalyst => @@ -222,7 +222,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] { if (conf.serializerNestedSchemaPruningEnabled && rootFields.nonEmpty) { // Prunes nested fields in serializers. 
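// Rough model (not from this patch) of the rewritten ReplaceExpressions traversal shown earlier
// in this diff: a RuntimeReplaceable-like node is swapped for its `replacement`, and the
// replacement itself is processed recursively so nested replaceables are handled too. The toy
// tree and names below are made up, not Catalyst's Expression.
object ReplaceExpressionsSketch {
  sealed trait Expr { def mapChildren(f: Expr => Expr): Expr }
  case class Leaf(name: String) extends Expr {
    def mapChildren(f: Expr => Expr): Expr = this
  }
  case class Node(name: String, children: Seq[Expr]) extends Expr {
    def mapChildren(f: Expr => Expr): Expr = copy(children = children.map(f))
  }
  // Stand-in for RuntimeReplaceable: carries the expression it rewrites to.
  case class Replaceable(replacement: Expr) extends Expr {
    def mapChildren(f: Expr => Expr): Expr = copy(replacement = f(replacement))
  }

  def replace(e: Expr): Expr = e match {
    case r: Replaceable => replace(r.replacement)
    case other => other.mapChildren(replace)
  }

  def main(args: Array[String]): Unit = {
    val plan = Node("Project", Seq(Replaceable(Node("Coalesce", Seq(Leaf("a"), Leaf("b"))))))
    println(replace(plan)) // Node(Project,List(Node(Coalesce,List(Leaf(a), Leaf(b)))))
  }
}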
- val prunedSchema = SchemaPruning.pruneDataSchema( + val prunedSchema = SchemaPruning.pruneSchema( StructType.fromAttributes(prunedSerializer.map(_.toAttribute)), rootFields) val nestedPrunedSerializer = prunedSerializer.zipWithIndex.map { case (serializer, idx) => pruneSerializer(serializer, prunedSchema(idx).dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 6d6b8b7d8aca8..7ef5ef55fabda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -728,7 +728,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { } private def hasCorrelatedSubquery(plan: LogicalPlan): Boolean = { - plan.find(_.expressions.exists(SubqueryExpression.hasCorrelatedSubquery)).isDefined + plan.exists(_.expressions.exists(SubqueryExpression.hasCorrelatedSubquery)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 6a509db73718c..5eb72af6b2f09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -54,7 +54,8 @@ import org.apache.spark.util.random.RandomSampler * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. */ -class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logging { +class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper with Logging { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import ParserUtils._ protected def typedVisit[T](ctx: ParseTree): T = { @@ -1145,7 +1146,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg case Some(c) if c.booleanExpression != null => (baseJoinType, Option(expression(c.booleanExpression))) case Some(c) => - throw QueryParsingErrors.joinCriteriaUnimplementedError(c, ctx) + throw new IllegalStateException(s"Unimplemented joinCriteria: $c") case None if join.NATURAL != null => if (join.LATERAL != null) { throw QueryParsingErrors.lateralJoinWithNaturalJoinUnsupportedError(ctx) @@ -1669,7 +1670,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg str.charAt(0) }.getOrElse('\\') val likeExpr = ctx.kind.getType match { - case SqlBaseParser.ILIKE => new ILike(e, expression(ctx.pattern), escapeChar) + case SqlBaseParser.ILIKE => ILike(e, expression(ctx.pattern), escapeChar) case _ => Like(e, expression(ctx.pattern), escapeChar) } invertIfNotDefined(likeExpr) @@ -2090,6 +2091,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg return false } + /** + * Returns whether the pattern is a regex expression (instead of a normal + * string). Normal string is a string with all alphabets/digits and "_". + */ + private def isRegex(pattern: String): Boolean = { + pattern.exists(p => !Character.isLetterOrDigit(p) && p != '_') + } + /** * Create a dereference expression. The return type depends on the type of the parent. 
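// A small standalone illustration (not from this patch) of the new isRegex guard used above:
// a quoted identifier is only resolved as a column regex if it contains something other than
// letters, digits and '_'; plain identifiers stay ordinary column references.
object IsRegexSketch {
  def isRegex(pattern: String): Boolean =
    pattern.exists(p => !Character.isLetterOrDigit(p) && p != '_')

  def main(args: Array[String]): Unit = {
    println(isRegex("col_1")) // false: treated as a normal column name
    println(isRegex("col.*")) // true: treated as UnresolvedRegex
  }
}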
* If the parent is an [[UnresolvedAttribute]], it can be a [[UnresolvedAttribute]] or @@ -2102,7 +2111,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg case unresolved_attr @ UnresolvedAttribute(nameParts) => ctx.fieldName.getStart.getText match { case escapedIdentifier(columnNameRegex) - if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) => + if conf.supportQuotedRegexColumnName && + isRegex(columnNameRegex) && canApplyRegex(ctx) => UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name), conf.caseSensitiveAnalysis) case _ => @@ -2120,7 +2130,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { ctx.getStart.getText match { case escapedIdentifier(columnNameRegex) - if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) => + if conf.supportQuotedRegexColumnName && + isRegex(columnNameRegex) && canApplyRegex(ctx) => UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis) case _ => UnresolvedAttribute.quoted(ctx.getText) @@ -2569,11 +2580,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } if (values(i).MINUS() == null) { value + } else if (value.startsWith("-")) { + value.replaceFirst("-", "") } else { - value.startsWith("-") match { - case true => value.replaceFirst("-", "") - case false => s"-$value" - } + s"-$value" } } else { values(i).getText @@ -2598,11 +2608,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val value = Option(ctx.intervalValue.STRING).map(string).map { interval => if (ctx.intervalValue().MINUS() == null) { interval + } else if (interval.startsWith("-")) { + interval.replaceFirst("-", "") } else { - interval.startsWith("-") match { - case true => interval.replaceFirst("-", "") - case false => s"-$interval" - } + s"-$interval" } }.getOrElse { throw QueryParsingErrors.invalidFromToUnitValueError(ctx.intervalValue) @@ -2719,6 +2728,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) } + /** + * Create top level table schema. + */ + protected def createSchema(ctx: CreateOrReplaceTableColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitCreateOrReplaceTableColTypeList)) + } + /** * Create a [[StructType]] from a number of column definitions. */ @@ -2745,6 +2761,41 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg metadata = builder.build()) } + /** + * Create a [[StructType]] from a number of CREATE TABLE column definitions. + */ + override def visitCreateOrReplaceTableColTypeList( + ctx: CreateOrReplaceTableColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.createOrReplaceTableColType().asScala.map(visitCreateOrReplaceTableColType).toSeq + } + + /** + * Create a top level [[StructField]] from a CREATE TABLE column definition. + */ + override def visitCreateOrReplaceTableColType( + ctx: CreateOrReplaceTableColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + val builder = new MetadataBuilder + // Add comment to metadata + Option(commentSpec()).map(visitCommentSpec).foreach { + builder.putString("comment", _) + } + + // Process the 'DEFAULT expression' clause in the column definition, if any. 
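// For reference, a sketch (not from this patch) of roughly the StructField that the column
// definition handling above produces for something like
//   name STRING NOT NULL COMMENT 'customer name'
// using Spark's public types API. The concrete column name and comment are only an example.
import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField}

object ColTypeSketch {
  def main(args: Array[String]): Unit = {
    val metadata = new MetadataBuilder().putString("comment", "customer name").build()
    val field = StructField("name", StringType, nullable = false, metadata)
    println(field.name)                          // name
    println(field.dataType.simpleString)         // string
    println(field.nullable)                      // false
    println(field.metadata.getString("comment")) // customer name
  }
}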
+ val name: String = colName.getText + val defaultExpr = Option(ctx.defaultExpression()).map(visitDefaultExpression) + if (defaultExpr.isDefined) { + throw QueryParsingErrors.defaultColumnNotImplementedYetError(ctx) + } + + StructField( + name = name, + dataType = typedVisit[DataType](ctx.dataType), + nullable = NULL == null, + metadata = builder.build()) + } + /** * Create a [[StructType]] from a sequence of [[StructField]]s. */ @@ -2930,15 +2981,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null) } - /** - * Validate a replace table statement and return the [[TableIdentifier]]. - */ - override def visitReplaceTableHeader( - ctx: ReplaceTableHeaderContext): TableHeader = withOrigin(ctx) { - val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq - (multipartIdentifier, false, false, false) - } - /** * Parse a qualified name to a multipart name. */ @@ -2982,7 +3024,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg if (arguments.size > 1) { throw QueryParsingErrors.tooManyArgumentsForTransformError(name, ctx) } else if (arguments.isEmpty) { - throw QueryParsingErrors.notEnoughArgumentsForTransformError(name, ctx) + throw new IllegalStateException(s"Not enough arguments for transform $name") } else { getFieldReference(ctx, arguments.head) } @@ -3043,7 +3085,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg .map(typedVisit[Literal]) .map(lit => LiteralValue(lit.value, lit.dataType)) reference.orElse(literal) - .getOrElse(throw QueryParsingErrors.invalidTransformArgumentError(ctx)) + .getOrElse(throw new IllegalStateException("Invalid transform argument")) } } @@ -3203,6 +3245,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg throw QueryParsingErrors.cannotCleanReservedTablePropertyError( PROP_OWNER, ctx, "it will be set to the current user") case (PROP_OWNER, _) => false + case (PROP_EXTERNAL, _) if !legacyOn => + throw QueryParsingErrors.cannotCleanReservedTablePropertyError( + PROP_EXTERNAL, ctx, "please use CREATE EXTERNAL TABLE") + case (PROP_EXTERNAL, _) => false case _ => true } } @@ -3453,7 +3499,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil) + val columns = Option(ctx.createOrReplaceTableColTypeList()) + .map(visitCreateOrReplaceTableColTypeList).getOrElse(Nil) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) = visitCreateTableClauses(ctx.createTableClauses()) @@ -3468,8 +3515,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg s"CREATE TEMPORARY TABLE ...$asSelect, use CREATE TEMPORARY VIEW instead", ctx) } - val partitioning = partitionExpressions(partTransforms, partCols, ctx) - val tableSpec = TableSpec(bucketSpec, properties, provider, options, location, comment, + val partitioning = + partitionExpressions(partTransforms, partCols, ctx) ++ bucketSpec.map(_.asTransform) + val tableSpec = TableSpec(properties, provider, options, location, comment, serdeInfo, external) Option(ctx.query).map(plan) match { 
@@ -3527,33 +3575,21 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * }}} */ override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) { - val (table, temp, ifNotExists, external) = visitReplaceTableHeader(ctx.replaceTableHeader) + val table = visitMultipartIdentifier(ctx.replaceTableHeader.multipartIdentifier()) val orCreate = ctx.replaceTableHeader().CREATE() != null - - if (temp) { - val action = if (orCreate) "CREATE OR REPLACE" else "REPLACE" - operationNotAllowed(s"$action TEMPORARY TABLE ..., use $action TEMPORARY VIEW instead.", ctx) - } - - if (external) { - operationNotAllowed("REPLACE EXTERNAL TABLE ...", ctx) - } - - if (ifNotExists) { - operationNotAllowed("REPLACE ... IF NOT EXISTS, use CREATE IF NOT EXISTS instead", ctx) - } - val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) = visitCreateTableClauses(ctx.createTableClauses()) - val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil) + val columns = Option(ctx.createOrReplaceTableColTypeList()) + .map(visitCreateOrReplaceTableColTypeList).getOrElse(Nil) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) if (provider.isDefined && serdeInfo.isDefined) { operationNotAllowed(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx) } - val partitioning = partitionExpressions(partTransforms, partCols, ctx) - val tableSpec = TableSpec(bucketSpec, properties, provider, options, location, comment, + val partitioning = + partitionExpressions(partTransforms, partCols, ctx) ++ bucketSpec.map(_.asTransform) + val tableSpec = TableSpec(properties, provider, options, location, comment, serdeInfo, false) Option(ctx.query).map(plan) match { @@ -3663,6 +3699,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg override def visitQualifiedColTypeWithPosition( ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) { val name = typedVisit[Seq[String]](ctx.name) + val defaultExpr = Option(ctx.defaultExpression()).map(visitDefaultExpression) + if (defaultExpr.isDefined) { + throw QueryParsingErrors.defaultColumnNotImplementedYetError(ctx) + } QualifiedColType( path = if (name.length > 1) Some(UnresolvedFieldName(name.init)) else None, colName = name.last, @@ -3751,6 +3791,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } else { None } + if (action.defaultExpression != null) { + throw QueryParsingErrors.defaultColumnNotImplementedYetError(ctx) + } + if (action.dropDefault != null) { + throw QueryParsingErrors.defaultColumnNotImplementedYetError(ctx) + } assert(Seq(dataType, nullable, comment, position).count(_.nonEmpty) == 1) @@ -3819,6 +3865,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError( "Replacing with a nested column", "REPLACE COLUMNS", ctx) } + if (Option(colType.defaultExpression()).map(visitDefaultExpression).isDefined) { + throw QueryParsingErrors.defaultColumnNotImplementedYetError(ctx) + } col }.toSeq ) @@ -4503,4 +4552,18 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg private def alterViewTypeMismatchHint: Option[String] = Some("Please use ALTER TABLE instead.") private def alterTableTypeMismatchHint: Option[String] = Some("Please use ALTER VIEW instead.") + + /** + * Create a TimestampAdd expression. 
+ */ + override def visitTimestampadd(ctx: TimestampaddContext): Expression = withOrigin(ctx) { + TimestampAdd(ctx.unit.getText, expression(ctx.unitsAmount), expression(ctx.timestamp)) + } + + /** + * Create a TimestampDiff expression. + */ + override def visitTimestampdiff(ctx: TimestampdiffContext): Expression = withOrigin(ctx) { + TimestampDiff(ctx.unit.getText, expression(ctx.startTimestamp), expression(ctx.endTimestamp)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 1057c78f3c282..5c9c382d08d04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -104,6 +104,7 @@ abstract class AbstractSqlParser extends ParserInterface with SQLConfHelper with parser.addParseListener(UnclosedCommentProcessor(command, tokenStream)) parser.removeErrorListeners() parser.addErrorListener(ParseErrorListener) + parser.setErrorHandler(new SparkParserErrorStrategy()) parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled parser.SQL_standard_keyword_behavior = conf.enforceReservedKeywords @@ -207,7 +208,12 @@ case object ParseErrorListener extends BaseErrorListener { val start = Origin(Some(line), Some(charPositionInLine)) (start, start) } - throw new ParseException(None, msg, start, stop) + e match { + case sre: SparkRecognitionException if sre.errorClass.isDefined => + throw new ParseException(None, start, stop, sre.errorClass.get, sre.messageParameters) + case _ => + throw new ParseException(None, msg, start, stop) + } } } @@ -246,6 +252,21 @@ class ParseException( Some(errorClass), messageParameters) + /** Compose the message through SparkThrowableHelper given errorClass and messageParameters. */ + def this( + command: Option[String], + start: Origin, + stop: Origin, + errorClass: String, + messageParameters: Array[String]) = + this( + command, + SparkThrowableHelper.getMessage(errorClass, messageParameters), + start, + stop, + Some(errorClass), + messageParameters) + override def getMessage: String = { val builder = new StringBuilder builder ++= "\n" ++= message @@ -268,14 +289,19 @@ class ParseException( } def withCommand(cmd: String): ParseException = { - new ParseException(Option(cmd), message, start, stop, errorClass, messageParameters) + // PARSE_EMPTY_STATEMENT error class overrides the PARSE_INPUT_MISMATCHED when cmd is empty + if (cmd.trim().isEmpty && errorClass.isDefined && errorClass.get == "PARSE_INPUT_MISMATCHED") { + new ParseException(Option(cmd), start, stop, "PARSE_EMPTY_STATEMENT", Array[String]()) + } else { + new ParseException(Option(cmd), message, start, stop, errorClass, messageParameters) + } } } /** * The post-processor validates & cleans-up the parse tree during the parse process. */ -case object PostProcessor extends SqlBaseBaseListener { +case object PostProcessor extends SqlBaseParserBaseListener { /** Throws error message when exiting a explicitly captured wrong identifier rule */ override def exitErrorIdent(ctx: SqlBaseParser.ErrorIdentContext): Unit = { @@ -319,7 +345,7 @@ case object PostProcessor extends SqlBaseBaseListener { * The post-processor checks the unclosed bracketed comment. 
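// Sketch (not from this patch) of the error-class selection added to ParseException.withCommand
// above: an input mismatch reported for an empty command is re-raised with the more specific
// PARSE_EMPTY_STATEMENT class, while other commands keep their original class. The helper below
// is hypothetical and works on plain strings only.
object WithCommandSketch {
  def chooseErrorClass(cmd: String, errorClass: Option[String]): Option[String] =
    if (cmd.trim.isEmpty && errorClass.contains("PARSE_INPUT_MISMATCHED")) {
      Some("PARSE_EMPTY_STATEMENT")
    } else {
      errorClass
    }

  def main(args: Array[String]): Unit = {
    println(chooseErrorClass("   ", Some("PARSE_INPUT_MISMATCHED")))     // Some(PARSE_EMPTY_STATEMENT)
    println(chooseErrorClass("SELEC 1", Some("PARSE_INPUT_MISMATCHED"))) // Some(PARSE_INPUT_MISMATCHED)
  }
}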
*/ case class UnclosedCommentProcessor( - command: String, tokenStream: CommonTokenStream) extends SqlBaseBaseListener { + command: String, tokenStream: CommonTokenStream) extends SqlBaseParserBaseListener { override def exitSingleDataType(ctx: SqlBaseParser.SingleDataTypeContext): Unit = { checkUnclosedComment(tokenStream, command) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala new file mode 100644 index 0000000000000..0ce514c4d2298 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser + +import org.antlr.v4.runtime.{DefaultErrorStrategy, InputMismatchException, IntStream, Parser, + ParserRuleContext, RecognitionException, Recognizer} + +/** + * A [[SparkRecognitionException]] extends the [[RecognitionException]] with more information + * including the error class and parameters for the error message, which align with the interface + * of [[SparkThrowableHelper]]. + */ +class SparkRecognitionException( + message: String, + recognizer: Recognizer[_, _], + input: IntStream, + ctx: ParserRuleContext, + val errorClass: Option[String] = None, + val messageParameters: Array[String] = Array.empty) + extends RecognitionException(message, recognizer, input, ctx) { + + /** Construct from a given [[RecognitionException]], with additional error information. */ + def this( + recognitionException: RecognitionException, + errorClass: String, + messageParameters: Array[String]) = + this( + recognitionException.getMessage, + recognitionException.getRecognizer, + recognitionException.getInputStream, + recognitionException.getCtx match { + case p: ParserRuleContext => p + case _ => null + }, + Some(errorClass), + messageParameters) +} + +/** + * A [[SparkParserErrorStrategy]] extends the [[DefaultErrorStrategy]], that does special handling + * on errors. + * + * The intention of this class is to provide more information of these errors encountered in + * ANTLR parser to the downstream consumers, to be able to apply the [[SparkThrowable]] error + * message framework to these exceptions. 
+ */ +class SparkParserErrorStrategy() extends DefaultErrorStrategy { + private val userWordDict : Map[String, String] = Map("''" -> "end of input") + private def getUserFacingLanguage(input: String) = { + userWordDict.getOrElse(input, input) + } + + override def reportInputMismatch(recognizer: Parser, e: InputMismatchException): Unit = { + // Keep the original error message in ANTLR + val msg = "mismatched input " + + this.getTokenErrorDisplay(e.getOffendingToken) + + " expecting " + + e.getExpectedTokens.toString(recognizer.getVocabulary) + + val exceptionWithErrorClass = new SparkRecognitionException( + e, + "PARSE_INPUT_MISMATCHED", + Array(getUserFacingLanguage(getTokenErrorDisplay(e.getOffendingToken)))) + recognizer.notifyErrorListeners(e.getOffendingToken, msg, exceptionWithErrorClass) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 2417ff904570b..5d749b8fc4b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.rules.RuleId import org.apache.spark.sql.catalyst.rules.UnknownRuleId import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} -import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE +import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} import org.apache.spark.sql.catalyst.trees.TreePatternBits import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -354,7 +354,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] private def updateAttr(a: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { attrMap.get(a) match { case Some(b) => - AttributeReference(a.name, b.dataType, b.nullable, a.metadata)(b.exprId, a.qualifier) + // The new Attribute has to + // - use a.nullable, because nullability cannot be propagated bottom-up without considering + // enclosed operators, e.g., operators such as Filters and Outer Joins can change + // nullability; + // - use b.dataType because transformUpWithNewOutput is used in the Analyzer for resolution, + // e.g., WidenSetOperationTypes uses it to propagate types bottom-up. + AttributeReference(a.name, b.dataType, a.nullable, a.metadata)(b.exprId, a.qualifier) case None => a } } @@ -427,8 +433,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] /** * All the top-level subqueries of the current plan node. Nested subqueries are not included. */ - def subqueries: Seq[PlanType] = { - expressions.flatMap(_.collect { + @transient lazy val subqueries: Seq[PlanType] = { + expressions.filter(_.containsPattern(PLAN_EXPRESSION)).flatMap(_.collect { case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala new file mode 100644 index 0000000000000..bb2bc4e3d2f93 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExpressionSet, NamedExpression} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, LeftSemiOrAnti, RightOuter} + +/** + * A visitor pattern for traversing a [[LogicalPlan]] tree and propagate the distinct attributes. + */ +object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] { + + private def projectDistinctKeys( + keys: Set[ExpressionSet], projectList: Seq[NamedExpression]): Set[ExpressionSet] = { + val outputSet = ExpressionSet(projectList.map(_.toAttribute)) + val aliases = projectList.filter(_.isInstanceOf[Alias]) + if (aliases.isEmpty) { + keys.filter(_.subsetOf(outputSet)) + } else { + val aliasedDistinctKeys = keys.map { expressionSet => + expressionSet.map { expression => + expression transform { + case expr: Expression => + // TODO: Expand distinctKeys for redundant aliases on the same expression + aliases + .collectFirst { case a: Alias if a.child.semanticEquals(expr) => a.toAttribute } + .getOrElse(expr) + } + } + } + aliasedDistinctKeys.collect { + case es: ExpressionSet if es.subsetOf(outputSet) => ExpressionSet(es) + } ++ keys.filter(_.subsetOf(outputSet)) + }.filter(_.nonEmpty) + } + + override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet] + + override def visitAggregate(p: Aggregate): Set[ExpressionSet] = { + val groupingExps = ExpressionSet(p.groupingExpressions) // handle group by a, a + projectDistinctKeys(Set(groupingExps), p.aggregateExpressions) + } + + override def visitDistinct(p: Distinct): Set[ExpressionSet] = Set(ExpressionSet(p.output)) + + override def visitExcept(p: Except): Set[ExpressionSet] = + if (!p.isAll) Set(ExpressionSet(p.output)) else default(p) + + override def visitExpand(p: Expand): Set[ExpressionSet] = default(p) + + override def visitFilter(p: Filter): Set[ExpressionSet] = p.child.distinctKeys + + override def visitGenerate(p: Generate): Set[ExpressionSet] = default(p) + + override def visitGlobalLimit(p: GlobalLimit): Set[ExpressionSet] = { + p.maxRows match { + case Some(value) if value <= 1 => Set(ExpressionSet(p.output)) + case _ => p.child.distinctKeys + } + } + + override def visitIntersect(p: Intersect): Set[ExpressionSet] = { + if (!p.isAll) Set(ExpressionSet(p.output)) else default(p) + } + + override def visitJoin(p: Join): Set[ExpressionSet] = { + p match { + case Join(_, _, LeftSemiOrAnti(_), _, _) => + p.left.distinctKeys + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, left, right, _) + if left.distinctKeys.nonEmpty || right.distinctKeys.nonEmpty => + val rightJoinKeySet = ExpressionSet(rightKeys) + val leftJoinKeySet = ExpressionSet(leftKeys) + joinType match { + case Inner if 
left.distinctKeys.exists(_.subsetOf(leftJoinKeySet)) && + right.distinctKeys.exists(_.subsetOf(rightJoinKeySet)) => + left.distinctKeys ++ right.distinctKeys + case Inner | LeftOuter if right.distinctKeys.exists(_.subsetOf(rightJoinKeySet)) => + p.left.distinctKeys + case Inner | RightOuter if left.distinctKeys.exists(_.subsetOf(leftJoinKeySet)) => + p.right.distinctKeys + case _ => + default(p) + } + case _ => default(p) + } + } + + override def visitLocalLimit(p: LocalLimit): Set[ExpressionSet] = p.child.distinctKeys + + override def visitPivot(p: Pivot): Set[ExpressionSet] = default(p) + + override def visitProject(p: Project): Set[ExpressionSet] = { + if (p.child.distinctKeys.nonEmpty) { + projectDistinctKeys(p.child.distinctKeys, p.projectList) + } else { + default(p) + } + } + + override def visitRepartition(p: Repartition): Set[ExpressionSet] = p.child.distinctKeys + + override def visitRepartitionByExpr(p: RepartitionByExpression): Set[ExpressionSet] = + p.child.distinctKeys + + override def visitSample(p: Sample): Set[ExpressionSet] = { + if (!p.withReplacement) p.child.distinctKeys else default(p) + } + + override def visitScriptTransform(p: ScriptTransformation): Set[ExpressionSet] = default(p) + + override def visitUnion(p: Union): Set[ExpressionSet] = default(p) + + override def visitWindow(p: Window): Set[ExpressionSet] = p.child.distinctKeys + + override def visitTail(p: Tail): Set[ExpressionSet] = p.child.distinctKeys + + override def visitSort(p: Sort): Set[ExpressionSet] = p.child.distinctKeys + + override def visitRebalancePartitions(p: RebalancePartitions): Set[ExpressionSet] = + p.child.distinctKeys + + override def visitWithCTE(p: WithCTE): Set[ExpressionSet] = p.plan.distinctKeys +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 49634a2a0eb89..7640d9234c71f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -31,6 +31,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with AnalysisHelper with LogicalPlanStats + with LogicalPlanDistinctKeys with QueryPlanConstraints with Logging { @@ -183,7 +184,7 @@ trait UnaryNode extends LogicalPlan with UnaryLike[LogicalPlan] { projectList.foreach { case a @ Alias(l: Literal, _) => allConstraints += EqualNullSafe(a.toAttribute, l) - case a @ Alias(e, _) => + case a @ Alias(e, _) if e.deterministic => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { case expr: Expression if expr.semanticEquals(e) => @@ -212,11 +213,12 @@ object LogicalPlanIntegrity { private def canGetOutputAttrs(p: LogicalPlan): Boolean = { p.resolved && !p.expressions.exists { e => - e.collectFirst { + e.exists { // We cannot call `output` in plans with a `ScalarSubquery` expr having no column, // so, we filter out them in advance. 
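// Toy model (not from this patch) of the distinct-key propagation performed by the
// DistinctKeyVisitor introduced above, using plain strings instead of Catalyst attributes and
// ExpressionSet. An aggregate's grouping keys become a distinct key set, and a projection keeps
// a key set only if all of its attributes survive the projection.
object DistinctKeysSketch {
  type KeySet = Set[String]

  def visitAggregate(groupingKeys: Seq[String]): Set[KeySet] = Set(groupingKeys.toSet)

  def projectDistinctKeys(keys: Set[KeySet], output: Seq[String]): Set[KeySet] =
    keys.filter(_.subsetOf(output.toSet))

  def main(args: Array[String]): Unit = {
    // SELECT a, b, SUM(c) FROM t GROUP BY a, b  ==>  distinct keys {{a, b}}
    val aggKeys = visitAggregate(Seq("a", "b"))
    println(aggKeys)                                     // Set(Set(a, b))
    println(projectDistinctKeys(aggKeys, Seq("a", "b"))) // kept: Set(Set(a, b))
    println(projectDistinctKeys(aggKeys, Seq("a")))      // dropped: Set()
  }
}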
- case s: ScalarSubquery if s.plan.schema.fields.isEmpty => true - }.isDefined + case s: ScalarSubquery => s.plan.schema.fields.isEmpty + case _ => false + } } } diff --git a/core/src/main/java/org/apache/hadoop/shaded/net/jpountz/lz4/LZ4SafeDecompressor.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala similarity index 53% rename from core/src/main/java/org/apache/hadoop/shaded/net/jpountz/lz4/LZ4SafeDecompressor.java rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala index cd3dd6f060f52..1843c2da478ef 100644 --- a/core/src/main/java/org/apache/hadoop/shaded/net/jpountz/lz4/LZ4SafeDecompressor.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala @@ -15,22 +15,20 @@ * limitations under the License. */ -package org.apache.hadoop.shaded.net.jpountz.lz4; +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.ExpressionSet +import org.apache.spark.sql.internal.SQLConf.PROPAGATE_DISTINCT_KEYS_ENABLED /** - * TODO(SPARK-36679): A temporary workaround for SPARK-36669. We should remove this after - * Hadoop 3.3.2 release which fixes the LZ4 relocation in shaded Hadoop client libraries. - * This does not need implement all net.jpountz.lz4.LZ4SafeDecompressor API, just the ones - * used by Hadoop Lz4Decompressor. + * A trait to add distinct attributes to [[LogicalPlan]]. For example: + * {{{ + * SELECT a, b, SUM(c) FROM Tab1 GROUP BY a, b + * // returns a, b + * }}} */ -public final class LZ4SafeDecompressor { - private net.jpountz.lz4.LZ4SafeDecompressor lz4Decompressor; - - public LZ4SafeDecompressor(net.jpountz.lz4.LZ4SafeDecompressor lz4Decompressor) { - this.lz4Decompressor = lz4Decompressor; - } - - public void decompress(java.nio.ByteBuffer src, java.nio.ByteBuffer dest) { - lz4Decompressor.decompress(src, dest); +trait LogicalPlanDistinctKeys { self: LogicalPlan => + lazy val distinctKeys: Set[ExpressionSet] = { + if (conf.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)) DistinctKeyVisitor.visit(self) else Set.empty } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index ba927746bbf6a..fd5f9051719dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -37,6 +37,7 @@ trait LogicalPlanVisitor[T] { case p: Project => visitProject(p) case p: Repartition => visitRepartition(p) case p: RepartitionByExpression => visitRepartitionByExpr(p) + case p: RebalancePartitions => visitRebalancePartitions(p) case p: Sample => visitSample(p) case p: ScriptTransformation => visitScriptTransform(p) case p: Union => visitUnion(p) @@ -77,6 +78,8 @@ trait LogicalPlanVisitor[T] { def visitRepartitionByExpr(p: RepartitionByExpression): T + def visitRebalancePartitions(p: RebalancePartitions): T + def visitSample(p: Sample): T def visitScriptTransform(p: ScriptTransformation): T diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index e8a632d01598f..895eeb772075d 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -277,11 +277,18 @@ case class Union( assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if `byName` is true.") override def maxRows: Option[Long] = { - if (children.exists(_.maxRows.isEmpty)) { - None - } else { - Some(children.flatMap(_.maxRows).sum) + var sum = BigInt(0) + children.foreach { child => + if (child.maxRows.isDefined) { + sum += child.maxRows.get + if (!sum.isValidLong) { + return None + } + } else { + return None + } } + Some(sum.toLong) } final override val nodePatterns: Seq[TreePattern] = Seq(UNION) @@ -290,11 +297,18 @@ case class Union( * Note the definition has assumption about how union is implemented physically. */ override def maxRowsPerPartition: Option[Long] = { - if (children.exists(_.maxRowsPerPartition.isEmpty)) { - None - } else { - Some(children.flatMap(_.maxRowsPerPartition).sum) + var sum = BigInt(0) + children.foreach { child => + if (child.maxRowsPerPartition.isDefined) { + sum += child.maxRowsPerPartition.get + if (!sum.isValidLong) { + return None + } + } else { + return None + } } + Some(sum.toLong) } def duplicateResolved: Boolean = { @@ -975,7 +989,7 @@ case class Aggregate( final override val nodePatterns : Seq[TreePattern] = Seq(AGGREGATE) override lazy val validConstraints: ExpressionSet = { - val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) + val nonAgg = aggregateExpressions.filter(!_.exists(_.isInstanceOf[AggregateExpression])) getAllValidConstraints(nonAgg) } @@ -984,10 +998,12 @@ case class Aggregate( // Whether this Aggregate operator is group only. For example: SELECT a, a FROM t GROUP BY a private[sql] def groupOnly: Boolean = { - aggregateExpressions.map { + // aggregateExpressions can be empty through Dataset.agg, + // so we should also check that groupingExpressions is non-empty + groupingExpressions.nonEmpty && aggregateExpressions.map { case Alias(child, _) => child case e => e - }.forall(a => groupingExpressions.exists(g => a.semanticEquals(g))) + }.forall(a => a.foldable || groupingExpressions.exists(g => a.semanticEquals(g))) } } @@ -1344,7 +1360,11 @@ case class Sample( s"Sampling fraction ($fraction) must be on interval [0, 1] without replacement") } - override def maxRows: Option[Long] = child.maxRows + override def maxRows: Option[Long] = { + // when withReplacement is true, PoissonSampler is applied in SampleExec, + // which may output more rows than child.maxRows. 
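// Sketch (not from this patch) of the overflow-safe maxRows computation Union now uses above:
// child row counts are summed as BigInt, and the result is dropped as soon as a child's count
// is unknown or the running sum no longer fits in a Long.
object UnionMaxRowsSketch {
  def unionMaxRows(childMaxRows: Seq[Option[Long]]): Option[Long] = {
    var sum = BigInt(0)
    childMaxRows.foreach {
      case Some(n) =>
        sum += n
        if (!sum.isValidLong) return None
      case None => return None
    }
    Some(sum.toLong)
  }

  def main(args: Array[String]): Unit = {
    println(unionMaxRows(Seq(Some(10L), Some(20L))))          // Some(30)
    println(unionMaxRows(Seq(Some(Long.MaxValue), Some(1L)))) // None: sum overflows a Long
    println(unionMaxRows(Seq(Some(10L), None)))               // None: unknown child
  }
}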
+ if (withReplacement) None else child.maxRows + } override def output: Seq[Attribute] = child.output override protected def withNewChildInternal(newChild: LogicalPlan): Sample = @@ -1459,14 +1479,19 @@ object RepartitionByExpression { */ case class RebalancePartitions( partitionExpressions: Seq[Expression], - child: LogicalPlan) extends UnaryNode { + child: LogicalPlan, + initialNumPartitionOpt: Option[Int] = None) extends UnaryNode { override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output + override val nodePatterns: Seq[TreePattern] = Seq(REBALANCE_PARTITIONS) - def partitioning: Partitioning = if (partitionExpressions.isEmpty) { - RoundRobinPartitioning(conf.numShufflePartitions) - } else { - HashPartitioning(partitionExpressions, conf.numShufflePartitions) + def partitioning: Partitioning = { + val initialNumPartitions = initialNumPartitionOpt.getOrElse(conf.numShufflePartitions) + if (partitionExpressions.isEmpty) { + RoundRobinPartitioning(initialNumPartitions) + } else { + HashPartitioning(partitionExpressions, initialNumPartitions) + } } override protected def withNewChildInternal(newChild: LogicalPlan): RebalancePartitions = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index 3f702724cca53..0f09022fb9c2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -88,6 +88,8 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = fallback(p) + override def visitRebalancePartitions(p: RebalancePartitions): Statistics = fallback(p) + override def visitSample(p: Sample): Statistics = fallback(p) override def visitScriptTransform(p: ScriptTransformation): Statistics = default(p) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 73c1b9445f693..67a045fe5ec1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -132,6 +132,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = p.child.stats + override def visitRebalancePartitions(p: RebalancePartitions): Statistics = p.child.stats + override def visitSample(p: Sample): Statistics = { val ratio = p.upperBound - p.lowerBound var sizeInBytes = EstimationUtils.ceil(BigDecimal(p.child.stats.sizeInBytes) * ratio) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index edf3abfacbb72..45465b0f99d3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, FieldName, NamedRelation, PartitionSpec, ResolvedDBObjectName, UnresolvedException} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, FunctionResource} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.catalog.FunctionResource import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.trees.BinaryLike @@ -1129,7 +1129,6 @@ case class DropIndex( } case class TableSpec( - bucketSpec: Option[BucketSpec], properties: Map[String, String], provider: Option[String], options: Map[String, String], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 7d30ecd97c3ca..78d153c5a0e83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.physical import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -71,9 +72,14 @@ case object AllTuples extends Distribution { /** * Represents data where tuples that share the same values for the `clustering` * [[Expression Expressions]] will be co-located in the same partition. + * + * @param requireAllClusterKeys When true, `Partitioning` which satisfies this distribution, + * must match all `clustering` expressions in the same ordering. */ case class ClusteredDistribution( clustering: Seq[Expression], + requireAllClusterKeys: Boolean = SQLConf.get.getConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION), requiredNumPartitions: Option[Int] = None) extends Distribution { require( clustering != Nil, @@ -87,6 +93,60 @@ case class ClusteredDistribution( s"the actual number of partitions is $numPartitions.") HashPartitioning(clustering, numPartitions) } + + /** + * Checks if `expressions` match all `clustering` expressions in the same ordering. + * + * `Partitioning` should call this to check its expressions when `requireAllClusterKeys` + * is set to true. + */ + def areAllClusterKeysMatched(expressions: Seq[Expression]): Boolean = { + expressions.length == clustering.length && + expressions.zip(clustering).forall { + case (l, r) => l.semanticEquals(r) + } + } +} + +/** + * Represents the requirement of distribution on the stateful operator in Structured Streaming. + * + * Each partition in stateful operator initializes state store(s), which are independent with state + * store(s) in other partitions. Since it is not possible to repartition the data in state store, + * Spark should make sure the physical partitioning of the stateful operator is unchanged across + * Spark versions. Violation of this requirement may bring silent correctness issue. 
+ * + * Since this distribution relies on [[HashPartitioning]] on the physical partitioning of the + * stateful operator, only [[HashPartitioning]] (and HashPartitioning in + * [[PartitioningCollection]]) can satisfy this distribution. + * When `_requiredNumPartitions` is 1, [[SinglePartition]] is essentially same as + * [[HashPartitioning]], so it can satisfy this distribution as well. + * + * NOTE: This is applied only to stream-stream join as of now. For other stateful operators, we + * have been using ClusteredDistribution, which could construct the physical partitioning of the + * state in different way (ClusteredDistribution requires relaxed condition and multiple + * partitionings can satisfy the requirement.) We need to construct the way to fix this with + * minimizing possibility to break the existing checkpoints. + * + * TODO(SPARK-38204): address the issue explained in above note. + */ +case class StatefulOpClusteredDistribution( + expressions: Seq[Expression], + _requiredNumPartitions: Int) extends Distribution { + require( + expressions != Nil, + "The expressions for hash of a StatefulOpClusteredDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") + + override val requiredNumPartitions: Option[Int] = Some(_requiredNumPartitions) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(_requiredNumPartitions == numPartitions, + s"This StatefulOpClusteredDistribution requires ${_requiredNumPartitions} " + + s"partitions, but the actual number of partitions is $numPartitions.") + HashPartitioning(expressions, numPartitions) + } } /** @@ -199,6 +259,11 @@ case object SinglePartition extends Partitioning { * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be * in the same partition. + * + * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires + * stateful operators to retain the same physical partitioning during the lifetime of the query + * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged + * across Spark versions. Violation of this requirement may bring silent correctness issue. */ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning with Unevaluable { @@ -210,8 +275,18 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies0(required: Distribution): Boolean = { super.satisfies0(required) || { required match { - case ClusteredDistribution(requiredClustering, _) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + case h: StatefulOpClusteredDistribution => + expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { + case (l, r) => l.semanticEquals(r) + } + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + if (requireAllClusterKeys) { + // Checks `HashPartitioning` is partitioned on exactly same clustering keys of + // `ClusteredDistribution`. + c.areAllClusterKeysMatched(expressions) + } else { + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } case _ => false } } @@ -271,8 +346,15 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) // `RangePartitioning(a, b, c)` satisfies `OrderedDistribution(a, b)`. 
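// A standalone sketch (not from this patch) of the areAllClusterKeysMatched check that
// ClusteredDistribution now exposes, consulted by HashPartitioning above and RangePartitioning
// just below when requireAllClusterKeys is set: the partitioning keys must equal the clustering
// keys pairwise and in order. Plain strings stand in for Catalyst expressions / semanticEquals,
// and the check is written as a free function here rather than a member.
object ClusterKeysMatchSketch {
  def areAllClusterKeysMatched(expressions: Seq[String], clustering: Seq[String]): Boolean =
    expressions.length == clustering.length &&
      expressions.zip(clustering).forall { case (l, r) => l == r }

  def main(args: Array[String]): Unit = {
    println(areAllClusterKeysMatched(Seq("a", "b"), Seq("a", "b"))) // true
    println(areAllClusterKeysMatched(Seq("b", "a"), Seq("a", "b"))) // false: order matters
    println(areAllClusterKeysMatched(Seq("a"), Seq("a", "b")))      // false: a subset is not enough
  }
}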
val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, _) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + val expressions = ordering.map(_.child) + if (requireAllClusterKeys) { + // Checks `RangePartitioning` is partitioned on exactly same clustering keys of + // `ClusteredDistribution`. + c.areAllClusterKeysMatched(expressions) + } else { + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } case _ => false } } @@ -380,7 +462,7 @@ trait ShuffleSpec { /** * Whether this shuffle spec can be used to create partitionings for the other children. */ - def canCreatePartitioning: Boolean = false + def canCreatePartitioning: Boolean /** * Creates a partitioning that can be used to re-partition the other side with the given @@ -412,6 +494,11 @@ case class RangeShuffleSpec( numPartitions: Int, distribution: ClusteredDistribution) extends ShuffleSpec { + // `RangePartitioning` is not compatible with any other partitioning since it can't guarantee + // data are co-partitioned for all the children, as range boundaries are randomly sampled. We + // can't let `RangeShuffleSpec` to create a partitioning. + override def canCreatePartitioning: Boolean = false + override def isCompatibleWith(other: ShuffleSpec): Boolean = other match { case SinglePartitionShuffleSpec => numPartitions == 1 case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith) @@ -424,8 +511,19 @@ case class RangeShuffleSpec( case class HashShuffleSpec( partitioning: HashPartitioning, distribution: ClusteredDistribution) extends ShuffleSpec { - lazy val hashKeyPositions: Seq[mutable.BitSet] = - createHashKeyPositions(distribution.clustering, partitioning.expressions) + + /** + * A sequence where each element is a set of positions of the hash partition key to the cluster + * keys. For instance, if cluster keys are [a, b, b] and hash partition keys are [a, b], the + * result will be [(0), (1, 2)]. + */ + lazy val hashKeyPositions: Seq[mutable.BitSet] = { + val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet] + distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) => + distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos) + } + partitioning.expressions.map(k => distKeyToPos.getOrElse(k.canonicalized, mutable.BitSet.empty)) + } override def isCompatibleWith(other: ShuffleSpec): Boolean = other match { case SinglePartitionShuffleSpec => @@ -451,7 +549,17 @@ case class HashShuffleSpec( false } - override def canCreatePartitioning: Boolean = true + override def canCreatePartitioning: Boolean = { + // To avoid potential data skew, we don't allow `HashShuffleSpec` to create partitioning if + // the hash partition keys are not the full join keys (the cluster keys). Then the planner + // will add shuffles with the default partitioning of `ClusteredDistribution`, which uses all + // the join keys. 
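// Toy version (not from this patch) of the hashKeyPositions computation documented above in
// HashShuffleSpec: for every hash partition key, record the positions at which it occurs among
// the cluster keys, e.g. cluster keys [a, b, b] and hash keys [a, b] give [{0}, {1, 2}].
// Strings stand in for canonicalized Catalyst expressions.
import scala.collection.mutable

object HashKeyPositionsSketch {
  def hashKeyPositions(clustering: Seq[String], hashKeys: Seq[String]): Seq[mutable.BitSet] = {
    val keyToPositions = mutable.Map.empty[String, mutable.BitSet]
    clustering.zipWithIndex.foreach { case (key, pos) =>
      keyToPositions.getOrElseUpdate(key, mutable.BitSet.empty).add(pos)
    }
    hashKeys.map(k => keyToPositions.getOrElse(k, mutable.BitSet.empty))
  }

  def main(args: Array[String]): Unit = {
    println(hashKeyPositions(Seq("a", "b", "b"), Seq("a", "b"))) // List(BitSet(0), BitSet(1, 2))
  }
}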
+ if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) { + distribution.areAllClusterKeysMatched(partitioning.expressions) + } else { + true + } + } override def createPartitioning(clustering: Seq[Expression]): Partitioning = { val exprs = hashKeyPositions.map(v => clustering(v.head)) @@ -459,22 +567,6 @@ case class HashShuffleSpec( } override def numPartitions: Int = partitioning.numPartitions - - /** - * Returns a sequence where each element is a set of positions of the key in `hashKeys` to its - * positions in `requiredClusterKeys`. For instance, if `requiredClusterKeys` is [a, b, b] and - * `hashKeys` is [a, b], the result will be [(0), (1, 2)]. - */ - private def createHashKeyPositions( - requiredClusterKeys: Seq[Expression], - hashKeys: Seq[Expression]): Seq[mutable.BitSet] = { - val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet] - requiredClusterKeys.zipWithIndex.foreach { case (distKey, distKeyPos) => - distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos) - } - - hashKeys.map(k => distKeyToPos(k.canonicalized)) - } } case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala index 8efc3593d72f5..b5a5e239b68ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala @@ -79,7 +79,7 @@ case class QueryExecutionMetering() { val maxLengthRuleNames = if (map.isEmpty) { 0 } else { - map.keys.map(_.toString.length).max + map.keys.map(_.length).max } val colRuleName = "Rule".padTo(maxLengthRuleNames, " ").mkString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 66a6a890022ac..e36a76b0b26cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -116,12 +116,12 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" :: "org.apache.spark.sql.catalyst.optimizer.LimitPushDown" :: "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" :: - "org.apache.spark.sql.catalyst.optimizer.NotPropagation" :: "org.apache.spark.sql.catalyst.optimizer.NullDownPropagation" :: "org.apache.spark.sql.catalyst.optimizer.NullPropagation" :: "org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning" :: "org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" :: "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" :: + "org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowPlan" :: "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" :: "org.apache.spark.sql.catalyst.optimizer.OptimizeRepartition" :: "org.apache.spark.sql.catalyst.optimizer.OptimizeWindowFunctions" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f78bbbf6c7516..ac60e18b2c1bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -246,6 +246,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) } } + /** + * Test whether there is [[TreeNode]] satisfies the conditions specified in `f`. + * The condition is recursively applied to this node and all of its children (pre-order). + */ + def exists(f: BaseType => Boolean): Boolean = if (f(this)) { + true + } else { + children.exists(_.exists(f)) + } + /** * Runs the given function on this node and then recursively on [[children]]. * @param f the function to be applied to each node in the tree. @@ -341,10 +351,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre // This is a temporary solution, we will change the type of children to IndexedSeq in a // followup PR private def asIndexedSeq(seq: Seq[BaseType]): IndexedSeq[BaseType] = { - if (seq.isInstanceOf[IndexedSeq[BaseType]]) { - seq.asInstanceOf[IndexedSeq[BaseType]] - } else { - seq.toIndexedSeq + seq match { + case types: IndexedSeq[BaseType] => types + case other => other.toIndexedSeq } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index e02bc475cfee0..b595966bcc235 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -33,13 +33,11 @@ object TreePattern extends Enumeration { val GROUPING_ANALYTICS: Value = Value val BINARY_ARITHMETIC: Value = Value val BINARY_COMPARISON: Value = Value - val BOOL_AGG: Value = Value val CASE_WHEN: Value = Value val CAST: Value = Value val COALESCE: Value = Value val CONCAT: Value = Value val COUNT: Value = Value - val COUNT_IF: Value = Value val CREATE_NAMED_STRUCT: Value = Value val CURRENT_LIKE: Value = Value val DESERIALIZE_TO_OBJECT: Value = Value @@ -74,7 +72,6 @@ object TreePattern extends Enumeration { val PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value val PYTHON_UDF: Value = Value - val REGR_COUNT: Value = Value val RUNTIME_REPLACEABLE: Value = Value val SCALAR_SUBQUERY: Value = Value val SCALA_UDF: Value = Value @@ -111,6 +108,7 @@ object TreePattern extends Enumeration { val PROJECT: Value = Value val RELATION_TIME_TRAVEL: Value = Value val REPARTITION_OPERATION: Value = Value + val REBALANCE_PARTITIONS: Value = Value val UNION: Value = Value val UNRESOLVED_RELATION: Value = Value val UNRESOLVED_WITH: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 445ec8444a915..65da5e9cb4251 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.time._ -import java.time.temporal.{ChronoField, ChronoUnit, IsoFields} +import java.time.temporal.{ChronoField, ChronoUnit, IsoFields, Temporal} import java.util.{Locale, TimeZone} import java.util.concurrent.TimeUnit._ @@ -107,6 +107,17 @@ object DateTimeUtils { rebaseJulianToGregorianDays(julianDays) } + /** + * Converts an Java object to days. 
+ * + * @param obj Either an object of `java.sql.Date` or `java.time.LocalDate`. + * @return The number of days since 1970-01-01. + */ + def anyToDays(obj: Any): Int = obj match { + case d: Date => fromJavaDate(d) + case ld: LocalDate => localDateToDays(ld) + } + /** * Converts days since the epoch 1970-01-01 in Proleptic Gregorian calendar to a local date * at the default JVM time zone in the hybrid calendar (Julian + Gregorian). It rebases the given @@ -180,6 +191,17 @@ object DateTimeUtils { rebaseJulianToGregorianMicros(micros) } + /** + * Converts an Java object to microseconds. + * + * @param obj Either an object of `java.sql.Timestamp` or `java.time.Instant`. + * @return The number of micros since the epoch. + */ + def anyToMicros(obj: Any): Long = obj match { + case t: Timestamp => fromJavaTimestamp(t) + case i: Instant => instantToMicros(i) + } + /** * Returns the number of microseconds since epoch from Julian day and nanoseconds in a day. */ @@ -1163,4 +1185,81 @@ object DateTimeUtils { val localStartTs = getLocalDateTime(startMicros, zoneId) ChronoUnit.MICROS.between(localStartTs, localEndTs) } + + /** + * Adds the specified number of units to a timestamp. + * + * @param unit A keyword that specifies the interval units to add to the input timestamp. + * @param quantity The amount of `unit`s to add. It can be positive or negative. + * @param micros The input timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z. + * @param zoneId The time zone ID at which the operation is performed. + * @return A timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z. + */ + def timestampAdd(unit: String, quantity: Int, micros: Long, zoneId: ZoneId): Long = { + try { + unit.toUpperCase(Locale.ROOT) match { + case "MICROSECOND" => + timestampAddDayTime(micros, quantity, zoneId) + case "MILLISECOND" => + timestampAddDayTime(micros, quantity * MICROS_PER_MILLIS, zoneId) + case "SECOND" => + timestampAddDayTime(micros, quantity * MICROS_PER_SECOND, zoneId) + case "MINUTE" => + timestampAddDayTime(micros, quantity * MICROS_PER_MINUTE, zoneId) + case "HOUR" => + timestampAddDayTime(micros, quantity * MICROS_PER_HOUR, zoneId) + case "DAY" | "DAYOFYEAR" => + timestampAddDayTime(micros, quantity * MICROS_PER_DAY, zoneId) + case "WEEK" => + timestampAddDayTime(micros, quantity * MICROS_PER_DAY * DAYS_PER_WEEK, zoneId) + case "MONTH" => + timestampAddMonths(micros, quantity, zoneId) + case "QUARTER" => + timestampAddMonths(micros, quantity * 3, zoneId) + case "YEAR" => + timestampAddMonths(micros, quantity * MONTHS_PER_YEAR, zoneId) + } + } catch { + case _: scala.MatchError => + throw new IllegalStateException(s"Got the unexpected unit '$unit'.") + case _: ArithmeticException | _: DateTimeException => + throw QueryExecutionErrors.timestampAddOverflowError(micros, quantity, unit) + case e: Throwable => + throw new IllegalStateException(s"Failure of 'timestampAdd': ${e.getMessage}") + } + } + + private val timestampDiffMap = Map[String, (Temporal, Temporal) => Long]( + "MICROSECOND" -> ChronoUnit.MICROS.between, + "MILLISECOND" -> ChronoUnit.MILLIS.between, + "SECOND" -> ChronoUnit.SECONDS.between, + "MINUTE" -> ChronoUnit.MINUTES.between, + "HOUR" -> ChronoUnit.HOURS.between, + "DAY" -> ChronoUnit.DAYS.between, + "WEEK" -> ChronoUnit.WEEKS.between, + "MONTH" -> ChronoUnit.MONTHS.between, + "QUARTER" -> ((startTs: Temporal, endTs: Temporal) => + ChronoUnit.MONTHS.between(startTs, endTs) / 3), + "YEAR" -> ChronoUnit.YEARS.between) + + /** + * Gets the difference between two 
timestamps. + * + * @param unit Specifies the interval units in which to express the difference between + * the two timestamp parameters. + * @param startTs A timestamp which the function subtracts from `endTs`. + * @param endTs A timestamp from which the function subtracts `startTs`. + * @param zoneId The time zone ID at which the operation is performed. + * @return The time span between two timestamp values, in the units specified. + */ + def timestampDiff(unit: String, startTs: Long, endTs: Long, zoneId: ZoneId): Long = { + val unitInUpperCase = unit.toUpperCase(Locale.ROOT) + if (timestampDiffMap.contains(unitInUpperCase)) { + val startLocalTs = getLocalDateTime(startTs, zoneId) + val endLocalTs = getLocalDateTime(endTs, zoneId) + timestampDiffMap(unitInUpperCase)(startLocalTs, endLocalTs) + } else { + throw new IllegalStateException(s"Got the unexpected unit '$unit'.") + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index ab7c9310bf844..5a9e52a51a27f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -57,7 +57,7 @@ class FailureSafeParser[IN]( def parse(input: IN): Iterator[InternalRow] = { try { - rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + rawParser.apply(input).iterator.map(row => toResultRow(Some(row), () => null)) } catch { case e: BadRecordException => mode match { case PermissiveMode => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index fc927ba054f82..ceed8df5026d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -1263,7 +1263,7 @@ object IntervalUtils { Math.multiplyExact(v, MONTHS_PER_YEAR) } catch { case _: ArithmeticException => - throw QueryExecutionErrors.castingCauseOverflowError(v, YM(endField).catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(v, YM(endField)) } case MONTH => v } @@ -1272,7 +1272,7 @@ object IntervalUtils { def longToYearMonthInterval(v: Long, endField: Byte): Int = { val vInt = v.toInt if (v != vInt) { - throw QueryExecutionErrors.castingCauseOverflowError(v, YM(endField).catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(v, YM(endField)) } intToYearMonthInterval(vInt, endField) } @@ -1289,7 +1289,7 @@ object IntervalUtils { val vShort = vInt.toShort if (vInt != vShort) { throw QueryExecutionErrors.castingCauseOverflowError( - toYearMonthIntervalString(v, ANSI_STYLE, startField, endField), ShortType.catalogString) + toYearMonthIntervalString(v, ANSI_STYLE, startField, endField), ShortType) } vShort } @@ -1299,7 +1299,7 @@ object IntervalUtils { val vByte = vInt.toByte if (vInt != vByte) { throw QueryExecutionErrors.castingCauseOverflowError( - toYearMonthIntervalString(v, ANSI_STYLE, startField, endField), ByteType.catalogString) + toYearMonthIntervalString(v, ANSI_STYLE, startField, endField), ByteType) } vByte } @@ -1311,7 +1311,7 @@ object IntervalUtils { Math.multiplyExact(v, MICROS_PER_DAY) } catch { case _: ArithmeticException => - throw QueryExecutionErrors.castingCauseOverflowError(v, 
DT(endField).catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(v, DT(endField)) } case HOUR => v * MICROS_PER_HOUR case MINUTE => v * MICROS_PER_MINUTE @@ -1329,7 +1329,7 @@ object IntervalUtils { } } catch { case _: ArithmeticException => - throw QueryExecutionErrors.castingCauseOverflowError(v, DT(endField).catalogString) + throw QueryExecutionErrors.castingCauseOverflowError(v, DT(endField)) } } @@ -1347,7 +1347,7 @@ object IntervalUtils { val vInt = vLong.toInt if (vLong != vInt) { throw QueryExecutionErrors.castingCauseOverflowError( - toDayTimeIntervalString(v, ANSI_STYLE, startField, endField), IntegerType.catalogString) + toDayTimeIntervalString(v, ANSI_STYLE, startField, endField), IntegerType) } vInt } @@ -1357,7 +1357,7 @@ object IntervalUtils { val vShort = vLong.toShort if (vLong != vShort) { throw QueryExecutionErrors.castingCauseOverflowError( - toDayTimeIntervalString(v, ANSI_STYLE, startField, endField), ShortType.catalogString) + toDayTimeIntervalString(v, ANSI_STYLE, startField, endField), ShortType) } vShort } @@ -1367,7 +1367,7 @@ object IntervalUtils { val vByte = vLong.toByte if (vLong != vByte) { throw QueryExecutionErrors.castingCauseOverflowError( - toDayTimeIntervalString(v, ANSI_STYLE, startField, endField), ByteType.catalogString) + toDayTimeIntervalString(v, ANSI_STYLE, startField, endField), ByteType) } vByte } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberFormatter.scala new file mode 100644 index 0000000000000..a14aceb692291 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberFormatter.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.math.BigDecimal +import java.text.{DecimalFormat, ParsePosition} +import java.util.Locale + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.unsafe.types.UTF8String + +object NumberFormatter { + final val POINT_SIGN = '.' 
+ final val POINT_LETTER = 'D' + final val COMMA_SIGN = ',' + final val COMMA_LETTER = 'G' + final val MINUS_SIGN = '-' + final val MINUS_LETTER = 'S' + final val DOLLAR_SIGN = '$' + final val NINE_DIGIT = '9' + final val ZERO_DIGIT = '0' + final val POUND_SIGN = '#' + + final val COMMA_SIGN_STRING = COMMA_SIGN.toString + final val POUND_SIGN_STRING = POUND_SIGN.toString + + final val SIGN_SET = Set(POINT_SIGN, COMMA_SIGN, MINUS_SIGN, DOLLAR_SIGN) +} + +class NumberFormatter(originNumberFormat: String, isParse: Boolean = true) extends Serializable { + import NumberFormatter._ + + protected val normalizedNumberFormat = normalize(originNumberFormat) + + private val transformedFormat = transform(normalizedNumberFormat) + + private lazy val numberDecimalFormat = { + val decimalFormat = new DecimalFormat(transformedFormat) + decimalFormat.setParseBigDecimal(true) + decimalFormat + } + + private lazy val (precision, scale) = { + val formatSplits = normalizedNumberFormat.split(POINT_SIGN).map(_.filterNot(isSign)) + assert(formatSplits.length <= 2) + val precision = formatSplits.map(_.length).sum + val scale = if (formatSplits.length == 2) formatSplits.last.length else 0 + (precision, scale) + } + + def parsedDecimalType: DecimalType = DecimalType(precision, scale) + + /** + * DecimalFormat provides '#' and '0' as placeholder of digit, ',' as grouping separator, + * '.' as decimal separator, '-' as minus, '$' as dollar, but not '9', 'G', 'D', 'S'. So we need + * replace them show below: + * 1. '9' -> '#' + * 2. 'G' -> ',' + * 3. 'D' -> '.' + * 4. 'S' -> '-' + * + * Note: When calling format, we must preserve the digits after decimal point, so the digits + * after decimal point should be replaced as '0'. For example: '999.9' will be normalized as + * '###.0' and '999.99' will be normalized as '###.00', so if the input is 454, the format + * output will be 454.0 and 454.00 respectively. + * + * @param format number format string + * @return normalized number format string + */ + private def normalize(format: String): String = { + var notFindDecimalPoint = true + val normalizedFormat = format.toUpperCase(Locale.ROOT).map { + case NINE_DIGIT if notFindDecimalPoint => POUND_SIGN + case ZERO_DIGIT if isParse && notFindDecimalPoint => POUND_SIGN + case NINE_DIGIT if !notFindDecimalPoint => ZERO_DIGIT + case COMMA_LETTER => COMMA_SIGN + case POINT_LETTER | POINT_SIGN => + notFindDecimalPoint = false + POINT_SIGN + case MINUS_LETTER => MINUS_SIGN + case other => other + } + // If the comma is at the beginning or end of number format, then DecimalFormat will be + // invalid. For example, "##,###," or ",###,###" for DecimalFormat is invalid, so we must use + // "##,###" or "###,###". + normalizedFormat.stripPrefix(COMMA_SIGN_STRING).stripSuffix(COMMA_SIGN_STRING) + } + + private def isSign(c: Char): Boolean = { + SIGN_SET.contains(c) + } + + private def transform(format: String): String = { + if (format.contains(MINUS_SIGN)) { + // For example: '#.######' represents a positive number, + // but '#.######;#.######-' represents a negative number. 
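// An illustrative, standalone use of java.text.DecimalFormat showing the
// normalization and negative-subpattern rewriting described above; it is not
// part of the diff. The patterns are the assumed results of `normalize` and
// `transform` for "99,999.99" and "99,999.99S"; US symbols are pinned so ','
// groups and '.' is the decimal separator.
import java.math.BigDecimal
import java.text.{DecimalFormat, DecimalFormatSymbols}
import java.util.Locale

val symbols = new DecimalFormatSymbols(Locale.US)

// "99,999.99" -> "##,###.00": digits after the decimal point are kept as '0'.
val positiveOnly = new DecimalFormat("##,###.00", symbols)
positiveOnly.format(new BigDecimal("12454.5"))   // "12,454.50"

// An 'S' (or '-') at the end adds a negative subpattern with a trailing minus.
val withNegative = new DecimalFormat("##,###.00;##,###.00-", symbols)
withNegative.format(new BigDecimal("-12454.5"))  // "12,454.50-"

// Parsing goes the other way; setParseBigDecimal keeps full decimal precision.
positiveOnly.setParseBigDecimal(true)
positiveOnly.parse("12,454.50")                  // java.math.BigDecimal 12454.50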
+ val positiveFormatString = format.replaceAll("-", "") + s"$positiveFormatString;$format" + } else { + format + } + } + + def check(): TypeCheckResult = { + def invalidSignPosition(c: Char): Boolean = { + val signIndex = normalizedNumberFormat.indexOf(c) + signIndex > 0 && signIndex < normalizedNumberFormat.length - 1 + } + + def multipleSignInNumberFormatError(message: String): String = { + s"At most one $message is allowed in the number format: '$originNumberFormat'" + } + + def nonFistOrLastCharInNumberFormatError(message: String): String = { + s"$message must be the first or last char in the number format: '$originNumberFormat'" + } + + if (normalizedNumberFormat.length == 0) { + TypeCheckResult.TypeCheckFailure("Number format cannot be empty") + } else if (normalizedNumberFormat.count(_ == POINT_SIGN) > 1) { + TypeCheckResult.TypeCheckFailure( + multipleSignInNumberFormatError(s"'$POINT_LETTER' or '$POINT_SIGN'")) + } else if (normalizedNumberFormat.count(_ == MINUS_SIGN) > 1) { + TypeCheckResult.TypeCheckFailure( + multipleSignInNumberFormatError(s"'$MINUS_LETTER' or '$MINUS_SIGN'")) + } else if (normalizedNumberFormat.count(_ == DOLLAR_SIGN) > 1) { + TypeCheckResult.TypeCheckFailure(multipleSignInNumberFormatError(s"'$DOLLAR_SIGN'")) + } else if (invalidSignPosition(MINUS_SIGN)) { + TypeCheckResult.TypeCheckFailure( + nonFistOrLastCharInNumberFormatError(s"'$MINUS_LETTER' or '$MINUS_SIGN'")) + } else if (invalidSignPosition(DOLLAR_SIGN)) { + TypeCheckResult.TypeCheckFailure( + nonFistOrLastCharInNumberFormatError(s"'$DOLLAR_SIGN'")) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + /** + * Convert string to numeric based on the given number format. + * The format can consist of the following characters: + * '0' or '9': digit position + * '.' or 'D': decimal point (only allowed once) + * ',' or 'G': group (thousands) separator + * '-' or 'S': sign anchored to number (only allowed once) + * '$': value with a leading dollar sign (only allowed once) + * + * @param input the string need to converted + * @return decimal obtained from string parsing + */ + def parse(input: UTF8String): Decimal = { + val inputStr = input.toString.trim + val inputSplits = inputStr.split(POINT_SIGN) + assert(inputSplits.length <= 2) + if (inputSplits.length == 1) { + if (inputStr.filterNot(isSign).length > precision - scale) { + throw QueryExecutionErrors.invalidNumberFormatError(input, originNumberFormat) + } + } else if (inputSplits(0).filterNot(isSign).length > precision - scale || + inputSplits(1).filterNot(isSign).length > scale) { + throw QueryExecutionErrors.invalidNumberFormatError(input, originNumberFormat) + } + + try { + val number = numberDecimalFormat.parse(inputStr, new ParsePosition(0)) + assert(number.isInstanceOf[BigDecimal]) + Decimal(number.asInstanceOf[BigDecimal]) + } catch { + case _: IllegalArgumentException => + throw QueryExecutionErrors.invalidNumberFormatError(input, originNumberFormat) + } + } + + /** + * Convert numeric to string based on the given number format. + * The format can consist of the following characters: + * '9': digit position (can be dropped if insignificant) + * '0': digit position (will not be dropped, even if insignificant) + * '.' 
or 'D': decimal point (only allowed once) + * ',' or 'G': group (thousands) separator + * '-' or 'S': sign anchored to number (only allowed once) + * '$': value with a leading dollar sign (only allowed once) + * + * @param input the decimal to format + * @param numberFormat the format string + * @return The string after formatting input decimal + */ + def format(input: Decimal): String = { + val bigDecimal = input.toJavaBigDecimal + val decimalPlainStr = bigDecimal.toPlainString + if (decimalPlainStr.length > transformedFormat.length) { + transformedFormat.replaceAll("0", POUND_SIGN_STRING) + } else { + var resultStr = numberDecimalFormat.format(bigDecimal) + // Since we trimmed the comma at the beginning or end of number format in function + // `normalize`, we restore the comma to the result here. + // For example, if the specified number format is "99,999," or ",999,999", function + // `normalize` normalize them to "##,###" or "###,###". + // new DecimalFormat("##,###").parse(12454) and new DecimalFormat("###,###").parse(124546) + // will return "12,454" and "124,546" respectively. So we add ',' at the end and head of + // the result, then the final output are "12,454," or ",124,546". + if (originNumberFormat.last == COMMA_SIGN || originNumberFormat.last == COMMA_LETTER) { + resultStr = resultStr + COMMA_SIGN + } + if (originNumberFormat.charAt(0) == COMMA_SIGN || + originNumberFormat.charAt(0) == COMMA_LETTER) { + resultStr = COMMA_SIGN + resultStr + } + + resultStr + } + } +} + +// Visible for testing +class TestNumberFormatter(originNumberFormat: String, isParse: Boolean = true) + extends NumberFormatter(originNumberFormat, isParse) { + def checkWithException(): Unit = { + check() match { + case TypeCheckResult.TypeCheckFailure(message) => + throw new AnalysisException(message) + case _ => + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberUtils.scala deleted file mode 100644 index 6efde2aa657b9..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberUtils.scala +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -import java.math.BigDecimal -import java.text.{DecimalFormat, NumberFormat, ParsePosition} -import java.util.Locale - -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.Decimal -import org.apache.spark.unsafe.types.UTF8String - -object NumberUtils { - - private val pointSign = '.' 
- private val letterPointSign = 'D' - private val commaSign = ',' - private val letterCommaSign = 'G' - private val minusSign = '-' - private val letterMinusSign = 'S' - private val dollarSign = '$' - - private val commaSignStr = commaSign.toString - - private def normalize(format: String): String = { - var notFindDecimalPoint = true - val normalizedFormat = format.toUpperCase(Locale.ROOT).map { - case '9' if notFindDecimalPoint => '#' - case '9' if !notFindDecimalPoint => '0' - case `letterPointSign` => - notFindDecimalPoint = false - pointSign - case `letterCommaSign` => commaSign - case `letterMinusSign` => minusSign - case `pointSign` => - notFindDecimalPoint = false - pointSign - case other => other - } - // If the comma is at the beginning or end of number format, then DecimalFormat will be invalid. - // For example, "##,###," or ",###,###" for DecimalFormat is invalid, so we must use "##,###" - // or "###,###". - normalizedFormat.stripPrefix(commaSignStr).stripSuffix(commaSignStr) - } - - private def isSign(c: Char): Boolean = { - Set(pointSign, commaSign, minusSign, dollarSign).contains(c) - } - - private def transform(format: String): String = { - if (format.contains(minusSign)) { - val positiveFormatString = format.replaceAll("-", "") - s"$positiveFormatString;$format" - } else { - format - } - } - - private def check(normalizedFormat: String, numberFormat: String) = { - def invalidSignPosition(format: String, c: Char): Boolean = { - val signIndex = format.indexOf(c) - signIndex > 0 && signIndex < format.length - 1 - } - - if (normalizedFormat.count(_ == pointSign) > 1) { - throw QueryCompilationErrors.multipleSignInNumberFormatError( - s"'$letterPointSign' or '$pointSign'", numberFormat) - } else if (normalizedFormat.count(_ == minusSign) > 1) { - throw QueryCompilationErrors.multipleSignInNumberFormatError( - s"'$letterMinusSign' or '$minusSign'", numberFormat) - } else if (normalizedFormat.count(_ == dollarSign) > 1) { - throw QueryCompilationErrors.multipleSignInNumberFormatError(s"'$dollarSign'", numberFormat) - } else if (invalidSignPosition(normalizedFormat, minusSign)) { - throw QueryCompilationErrors.nonFistOrLastCharInNumberFormatError( - s"'$letterMinusSign' or '$minusSign'", numberFormat) - } else if (invalidSignPosition(normalizedFormat, dollarSign)) { - throw QueryCompilationErrors.nonFistOrLastCharInNumberFormatError( - s"'$dollarSign'", numberFormat) - } - } - - /** - * Convert string to numeric based on the given number format. - * The format can consist of the following characters: - * '9': digit position (can be dropped if insignificant) - * '0': digit position (will not be dropped, even if insignificant) - * '.': decimal point (only allowed once) - * ',': group (thousands) separator - * 'S': sign anchored to number (uses locale) - * 'D': decimal point (uses locale) - * 'G': group separator (uses locale) - * '$': specifies that the input value has a leading $ (Dollar) sign. 
- * - * @param input the string need to converted - * @param numberFormat the given number format - * @return decimal obtained from string parsing - */ - def parse(input: UTF8String, numberFormat: String): Decimal = { - val normalizedFormat = normalize(numberFormat) - check(normalizedFormat, numberFormat) - - val precision = normalizedFormat.filterNot(isSign).length - val formatSplits = normalizedFormat.split(pointSign) - val scale = if (formatSplits.length == 1) { - 0 - } else { - formatSplits(1).filterNot(isSign).length - } - val transformedFormat = transform(normalizedFormat) - val numberFormatInstance = NumberFormat.getInstance() - val numberDecimalFormat = numberFormatInstance.asInstanceOf[DecimalFormat] - numberDecimalFormat.setParseBigDecimal(true) - numberDecimalFormat.applyPattern(transformedFormat) - val inputStr = input.toString.trim - val inputSplits = inputStr.split(pointSign) - if (inputSplits.length == 1) { - if (inputStr.filterNot(isSign).length > precision - scale) { - throw QueryExecutionErrors.invalidNumberFormatError(numberFormat) - } - } else if (inputSplits(0).filterNot(isSign).length > precision - scale || - inputSplits(1).filterNot(isSign).length > scale) { - throw QueryExecutionErrors.invalidNumberFormatError(numberFormat) - } - val number = numberDecimalFormat.parse(inputStr, new ParsePosition(0)) - Decimal(number.asInstanceOf[BigDecimal]) - } - - /** - * Convert numeric to string based on the given number format. - * The format can consist of the following characters: - * '9': digit position (can be dropped if insignificant) - * '0': digit position (will not be dropped, even if insignificant) - * '.': decimal point (only allowed once) - * ',': group (thousands) separator - * 'S': sign anchored to number (uses locale) - * 'D': decimal point (uses locale) - * 'G': group separator (uses locale) - * '$': specifies that the input value has a leading $ (Dollar) sign. - * - * @param input the decimal to format - * @param numberFormat the format string - * @return The string after formatting input decimal - */ - def format(input: Decimal, numberFormat: String): String = { - val normalizedFormat = normalize(numberFormat) - check(normalizedFormat, numberFormat) - - val transformedFormat = transform(normalizedFormat) - val bigDecimal = input.toJavaBigDecimal - val decimalPlainStr = bigDecimal.toPlainString - if (decimalPlainStr.length > transformedFormat.length) { - transformedFormat.replaceAll("0", "#") - } else { - val decimalFormat = new DecimalFormat(transformedFormat) - var resultStr = decimalFormat.format(bigDecimal) - // Since we trimmed the comma at the beginning or end of number format in function - // `normalize`, we restore the comma to the result here. - // For example, if the specified number format is "99,999," or ",999,999", function - // `normalize` normalize them to "##,###" or "###,###". - // new DecimalFormat("##,###").parse(12454) and new DecimalFormat("###,###").parse(124546) - // will return "12,454" and "124,546" respectively. So we add ',' at the end and head of - // the result, then the final output are "12,454," or ",124,546". 
- if (numberFormat.last == commaSign || numberFormat.last == letterCommaSign) { - resultStr = resultStr + commaSign - } - if (numberFormat.charAt(0) == commaSign || numberFormat.charAt(0) == letterCommaSign) { - resultStr = commaSign + resultStr - } - - resultStr - } - } - -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala index 812d5ded4bf0f..57fecb774bd20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -43,7 +43,7 @@ class StringKeyHashMap[T](normalizer: (String) => String) { def remove(key: String): Option[T] = base.remove(normalizer(key)) - def iterator: Iterator[(String, T)] = base.toIterator + def iterator: Iterator[(String, T)] = base.iterator def clear(): Unit = base.clear() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 0d0f7a07bb478..4ad0337abc45e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -44,7 +44,7 @@ object StringUtils extends Logging { * @return the equivalent Java regular expression of the pattern */ def escapeLikeRegex(pattern: String, escapeChar: Char): String = { - val in = pattern.toIterator + val in = pattern.iterator val out = new StringBuilder() def fail(message: String) = throw QueryCompilationErrors.invalidPatternError(pattern, message) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index e26f397bb0b52..e06072cbed282 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -22,6 +22,8 @@ import java.nio.charset.Charset import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean +import com.google.common.io.ByteStreams + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf @@ -48,42 +50,22 @@ package object util extends Logging { def fileToString(file: File, encoding: Charset = UTF_8): String = { val inStream = new FileInputStream(file) - val outStream = new ByteArrayOutputStream try { - var reading = true - while ( reading ) { - inStream.read() match { - case -1 => reading = false - case c => outStream.write(c) - } - } - outStream.flush() - } - finally { + new String(ByteStreams.toByteArray(inStream), encoding) + } finally { inStream.close() } - new String(outStream.toByteArray, encoding) } def resourceToBytes( resource: String, classLoader: ClassLoader = Utils.getSparkClassLoader): Array[Byte] = { val inStream = classLoader.getResourceAsStream(resource) - val outStream = new ByteArrayOutputStream try { - var reading = true - while ( reading ) { - inStream.read() match { - case -1 => reading = false - case c => outStream.write(c) - } - } - outStream.flush() - } - finally { + ByteStreams.toByteArray(inStream) + } finally { inStream.close() } - outStream.toByteArray } def resourceToString( @@ -135,8 +117,8 @@ package object util extends Logging { 
PrettyAttribute(usePrettyExpression(e.child).sql + "." + name, e.dataType) case e: GetArrayStructFields => PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType) - case r: RuntimeReplaceable => - PrettyAttribute(r.mkString(r.exprsReplaced.map(toPrettySQL)), r.dataType) + case r: InheritAnalysisRules => + PrettyAttribute(r.makeSQLString(r.parameters.map(toPrettySQL)), r.dataType) case c: CastBase if !c.getTagValue(Cast.USER_SPECIFIED_CAST).getOrElse(false) => PrettyAttribute(usePrettyExpression(c.child).sql, c.dataType) case p: PythonUDF => PrettyPythonUDF(p.name, p.dataType, p.children) @@ -159,7 +141,7 @@ package object util extends Logging { def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql def escapeSingleQuotedString(str: String): String = { - val builder = StringBuilder.newBuilder + val builder = new StringBuilder str.foreach { case '\'' => builder ++= s"\\\'" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index dbc4bd373751f..04af7eda6aaa9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.connector.catalog +import scala.collection.mutable + import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.quoteIfNeeded -import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform} -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform} +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} /** * Conversion helpers for working with v2 [[CatalogPlugin]]. 
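// A minimal standalone sketch of the Guava-based stream reading that replaces
// the hand-rolled read loops in `fileToString`/`resourceToBytes` above; the
// helper name is illustrative, not part of the diff.
import java.io.FileInputStream
import java.nio.charset.StandardCharsets.UTF_8

import com.google.common.io.ByteStreams

def readFileToString(path: String): String = {
  val in = new FileInputStream(path)
  try {
    // ByteStreams.toByteArray drains the whole stream into a byte array in one call.
    new String(ByteStreams.toByteArray(in), UTF_8)
  } finally {
    in.close()
  }
}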
@@ -37,7 +39,7 @@ private[sql] object CatalogV2Implicits { } implicit class BucketSpecHelper(spec: BucketSpec) { - def asTransform: BucketTransform = { + def asTransform: Transform = { val references = spec.bucketColumnNames.map(col => reference(Seq(col))) if (spec.sortColumnNames.nonEmpty) { val sortedCol = spec.sortColumnNames.map(col => reference(Seq(col))) @@ -49,21 +51,28 @@ private[sql] object CatalogV2Implicits { } implicit class TransformHelper(transforms: Seq[Transform]) { - def asPartitionColumns: Seq[String] = { - val (idTransforms, nonIdTransforms) = transforms.partition(_.isInstanceOf[IdentityTransform]) - - if (nonIdTransforms.nonEmpty) { - throw QueryCompilationErrors.cannotConvertTransformsToPartitionColumnsError(nonIdTransforms) + def convertTransforms: (Seq[String], Option[BucketSpec]) = { + val identityCols = new mutable.ArrayBuffer[String] + var bucketSpec = Option.empty[BucketSpec] + + transforms.map { + case IdentityTransform(FieldReference(Seq(col))) => + identityCols += col + + case BucketTransform(numBuckets, col, sortCol) => + if (bucketSpec.nonEmpty) throw QueryExecutionErrors.multipleBucketTransformsError + if (sortCol.isEmpty) { + bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")), Nil)) + } else { + bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")), + sortCol.map(_.fieldNames.mkString(".")))) + } + + case transform => + throw QueryExecutionErrors.unsupportedPartitionTransformError(transform) } - idTransforms.map(_.asInstanceOf[IdentityTransform]).map(_.reference).map { ref => - val parts = ref.fieldNames - if (parts.size > 1) { - throw QueryCompilationErrors.cannotPartitionByNestedColumnError(ref) - } else { - parts(0) - } - } + (identityCols.toSeq, bucketSpec) } } @@ -164,6 +173,18 @@ private[sql] object CatalogV2Implicits { def quoted: String = parts.map(quoteIfNeeded).mkString(".") } + implicit class TableIdentifierHelper(identifier: TableIdentifier) { + def quoted: String = { + identifier.database match { + case Some(db) => + Seq(db, identifier.table).map(quoteIfNeeded).mkString(".") + case _ => + quoteIfNeeded(identifier.table) + + } + } + } + def parseColumnPath(name: String): Seq[String] = { CatalystSqlParser.parseMultipartIdentifier(name) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 597b3c3884c62..4092674046eca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -47,7 +47,8 @@ private[sql] object CatalogV2Util { Seq(TableCatalog.PROP_COMMENT, TableCatalog.PROP_LOCATION, TableCatalog.PROP_PROVIDER, - TableCatalog.PROP_OWNER) + TableCatalog.PROP_OWNER, + TableCatalog.PROP_EXTERNAL) /** * The list of reserved namespace properties, which can not be removed or changed directly by diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala index 07f66a614b2ad..bf92107f6ae2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala @@ -22,9 +22,8 @@ import java.util import scala.collection.JavaConverters._ import scala.collection.mutable -import 
org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.util.quoteIfNeeded +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper import org.apache.spark.sql.connector.catalog.V1Table.addV2TableProperties import org.apache.spark.sql.connector.expressions.{LogicalExpressions, Transform} import org.apache.spark.sql.types.StructType @@ -33,17 +32,6 @@ import org.apache.spark.sql.types.StructType * An implementation of catalog v2 `Table` to expose v1 table metadata. */ private[sql] case class V1Table(v1Table: CatalogTable) extends Table { - implicit class IdentifierHelper(identifier: TableIdentifier) { - def quoted: String = { - identifier.database match { - case Some(db) => - Seq(db, identifier.table).map(quoteIfNeeded).mkString(".") - case _ => - quoteIfNeeded(identifier.table) - - } - } - } def catalogTable: CatalogTable = v1Table @@ -92,7 +80,9 @@ private[sql] object V1Table { TableCatalog.OPTION_PREFIX + key -> value } ++ v1Table.provider.map(TableCatalog.PROP_PROVIDER -> _) ++ v1Table.comment.map(TableCatalog.PROP_COMMENT -> _) ++ - v1Table.storage.locationUri.map(TableCatalog.PROP_LOCATION -> _.toString) ++ + (if (external) { + v1Table.storage.locationUri.map(TableCatalog.PROP_LOCATION -> _.toString) + } else None) ++ (if (external) Some(TableCatalog.PROP_EXTERNAL -> "true") else None) ++ Some(TableCatalog.PROP_OWNER -> v1Table.owner) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index 996b2566eeb7b..fbd2520e2a774 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.expressions +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types.{DataType, IntegerType, StringType} @@ -48,8 +49,8 @@ private[sql] object LogicalExpressions { def bucket( numBuckets: Int, references: Array[NamedReference], - sortedCols: Array[NamedReference]): BucketTransform = - BucketTransform(literal(numBuckets, IntegerType), references, sortedCols) + sortedCols: Array[NamedReference]): SortedBucketTransform = + SortedBucketTransform(literal(numBuckets, IntegerType), references, sortedCols) def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference) @@ -101,8 +102,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R private[sql] final case class BucketTransform( numBuckets: Literal[Int], - columns: Seq[NamedReference], - sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform { + columns: Seq[NamedReference]) extends RewritableTransform { override val name: String = "bucket" @@ -112,13 +112,9 @@ private[sql] final case class BucketTransform( override def arguments: Array[Expression] = numBuckets +: columns.toArray - override def toString: String = - if (sortedColumns.nonEmpty) { - s"bucket(${arguments.map(_.describe).mkString(", ")}," + - s" ${sortedColumns.map(_.describe).mkString(", ")})" - } else { - s"bucket(${arguments.map(_.describe).mkString(", ")})" - } + override def describe: String = 
s"bucket(${arguments.map(_.describe).mkString(", ")})" + + override def toString: String = describe override def withReferences(newReferences: Seq[NamedReference]): Transform = { this.copy(columns = newReferences) @@ -126,32 +122,52 @@ private[sql] final case class BucketTransform( } private[sql] object BucketTransform { - def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] = - expr match { - case transform: Transform => + def unapply(transform: Transform): Option[(Int, Seq[NamedReference], Seq[NamedReference])] = transform match { - case BucketTransform(n, FieldReference(parts), FieldReference(sortCols)) => - Some((n, FieldReference(parts), FieldReference(sortCols))) + case NamedTransform("sorted_bucket", arguments) => + var posOfLit: Int = -1 + var numOfBucket: Int = -1 + arguments.zipWithIndex.foreach { + case (Lit(value: Int, IntegerType), i) => + numOfBucket = value + posOfLit = i case _ => - None } + Some(numOfBucket, arguments.take(posOfLit).map(_.asInstanceOf[NamedReference]), + arguments.drop(posOfLit + 1).map(_.asInstanceOf[NamedReference])) + case NamedTransform("bucket", arguments) => + var numOfBucket: Int = -1 + arguments(0) match { + case Lit(value: Int, IntegerType) => + numOfBucket = value + case _ => throw new SparkException("The first element in BucketTransform arguments " + + "should be an Integer Literal.") + } + Some(numOfBucket, arguments.drop(1).map(_.asInstanceOf[NamedReference]), + Seq.empty[FieldReference]) case _ => None } +} - def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] = - transform match { - case NamedTransform("bucket", Seq( - Lit(value: Int, IntegerType), - Ref(partCols: Seq[String]), - Ref(sortCols: Seq[String]))) => - Some((value, FieldReference(partCols), FieldReference(sortCols))) - case NamedTransform("bucket", Seq( - Lit(value: Int, IntegerType), - Ref(partCols: Seq[String]))) => - Some((value, FieldReference(partCols), FieldReference(Seq.empty[String]))) - case _ => - None +private[sql] final case class SortedBucketTransform( + numBuckets: Literal[Int], + columns: Seq[NamedReference], + sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform { + + override val name: String = "sorted_bucket" + + override def references: Array[NamedReference] = { + arguments.collect { case named: NamedReference => named } + } + + override def arguments: Array[Expression] = (columns.toArray :+ numBuckets) ++ sortedColumns + + override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})" + + override def withReferences(newReferences: Seq[NamedReference]): Transform = { + this.copy(columns = newReferences.take(columns.length), + sortedColumns = newReferences.drop(columns.length)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index fcbcb5491587e..6bf0ec8eb8c40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.{toPrettySQL, FailFastMode, ParseMode, import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, UnboundFunction} -import org.apache.spark.sql.connector.expressions.{NamedReference, 
Transform} +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED, LEGACY_CTE_PRECEDENCE_POLICY} import org.apache.spark.sql.sources.Filter @@ -93,8 +93,8 @@ object QueryCompilationErrors { def unsupportedIfNotExistsError(tableName: String): Throwable = { new AnalysisException( - errorClass = "IF_PARTITION_NOT_EXISTS_UNSUPPORTED", - messageParameters = Array(tableName)) + errorClass = "UNSUPPORTED_FEATURE", + messageParameters = Array(s"IF NOT EXISTS for the table '$tableName' by INSERT INTO.")) } def nonPartitionColError(partitionName: String): Throwable = { @@ -158,16 +158,16 @@ object QueryCompilationErrors { def upCastFailureError( fromStr: String, from: Expression, to: DataType, walkedTypePath: Seq[String]): Throwable = { new AnalysisException( - s"Cannot up cast $fromStr from " + - s"${from.dataType.catalogString} to ${to.catalogString}.\n" + + errorClass = "CANNOT_UP_CAST_DATATYPE", + messageParameters = Array( + fromStr, + from.dataType.catalogString, + to.catalogString, s"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + - "You can either add an explicit cast to the input data or choose a higher precision " + - "type of the field in the target object") - } - - def unsupportedAbstractDataTypeForUpCastError(gotType: AbstractDataType): Throwable = { - new AnalysisException( - s"UpCast only support DecimalType as AbstractDataType yet, but got: $gotType") + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object" + ) + ) } def outerScopeFailureForNewInstanceError(className: String): Throwable = { @@ -192,11 +192,15 @@ object QueryCompilationErrors { } def groupingMustWithGroupingSetsOrCubeOrRollupError(): Throwable = { - new AnalysisException("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + new AnalysisException( + errorClass = "UNSUPPORTED_GROUPING_EXPRESSION", + messageParameters = Array.empty) } def pandasUDFAggregateNotSupportedInPivotError(): Throwable = { - new AnalysisException("Pandas UDF aggregate expressions are currently not supported in pivot.") + new AnalysisException( + errorClass = "UNSUPPORTED_FEATURE", + messageParameters = Array("Pandas UDF aggregate expressions don't support pivot.")) } def aggregateExpressionRequiredForPivotError(sql: String): Throwable = { @@ -415,13 +419,6 @@ object QueryCompilationErrors { s"'${child.output.map(_.name).mkString("(", ",", ")")}'") } - def cannotUpCastAsAttributeError( - fromAttr: Attribute, toAttr: Attribute): Throwable = { - new AnalysisException(s"Cannot up cast ${fromAttr.sql} from " + - s"${fromAttr.dataType.catalogString} to ${toAttr.dataType.catalogString} " + - "as it may truncate") - } - def functionUndefinedError(name: FunctionIdentifier): Throwable = { new AnalysisException(s"undefined function $name") } @@ -726,8 +723,20 @@ object QueryCompilationErrors { s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") } - def unfoldableFieldUnsupportedError(): Throwable = { - new AnalysisException("The field parameter needs to be a foldable string value.") + def requireLiteralParameter( + funcName: String, argName: String, requiredType: String): Throwable = { + new AnalysisException( + s"The '$argName' parameter of function '$funcName' needs to be a $requiredType literal.") + } + + def invalidStringLiteralParameter( + funcName: 
String, + argName: String, + invalidValue: String, + allowedValues: Option[String] = None): Throwable = { + val endingMsg = allowedValues.map(" " + _).getOrElse("") + new AnalysisException(s"Invalid value for the '$argName' parameter of function '$funcName': " + + s"$invalidValue.$endingMsg") } def literalTypeUnsupportedForSourceTypeError(field: String, source: Expression): Throwable = { @@ -1323,10 +1332,6 @@ object QueryCompilationErrors { s"Expected: ${dataType.typeName}; Found: ${expression.dataType.typeName}") } - def groupAggPandasUDFUnsupportedByStreamingAggError(): Throwable = { - new AnalysisException("Streaming aggregation doesn't support group aggregate pandas UDF") - } - def streamJoinStreamWithoutEqualityPredicateUnsupportedError(plan: LogicalPlan): Throwable = { new AnalysisException( "Stream-stream join without equality predicate is not supported", plan = Some(plan)) @@ -1334,7 +1339,8 @@ object QueryCompilationErrors { def cannotUseMixtureOfAggFunctionAndGroupAggPandasUDFError(): Throwable = { new AnalysisException( - "Cannot use a mixture of aggregate function and group aggregate pandas UDF") + errorClass = "CANNOT_USE_MIXTURE", + messageParameters = Array.empty) } def ambiguousAttributesInSelfJoinError( @@ -1369,11 +1375,6 @@ object QueryCompilationErrors { new AnalysisException("Cannot use interval type in the table schema.") } - def cannotConvertTransformsToPartitionColumnsError(nonIdTransforms: Seq[Transform]): Throwable = { - new AnalysisException("Transforms cannot be converted to partition columns: " + - nonIdTransforms.map(_.describe).mkString(", ")) - } - def cannotPartitionByNestedColumnError(reference: NamedReference): Throwable = { new AnalysisException(s"Cannot partition by nested column: $reference") } @@ -1568,8 +1569,10 @@ object QueryCompilationErrors { } def usePythonUDFInJoinConditionUnsupportedError(joinType: JoinType): Throwable = { - new AnalysisException("Using PythonUDF in join condition of join type" + - s" $joinType is not supported.") + new AnalysisException( + errorClass = "UNSUPPORTED_FEATURE", + messageParameters = Array( + s"Using PythonUDF in join condition of join type $joinType is not supported")) } def conflictingAttributesInJoinConditionError( @@ -2381,11 +2384,8 @@ object QueryCompilationErrors { new UnsupportedOperationException(s"Table $tableName does not support time travel.") } - def multipleSignInNumberFormatError(message: String, numberFormat: String): Throwable = { - new AnalysisException(s"Multiple $message in '$numberFormat'") - } - - def nonFistOrLastCharInNumberFormatError(message: String, numberFormat: String): Throwable = { - new AnalysisException(s"$message must be the first or last char in '$numberFormat'") + def writeDistributionAndOrderingNotSupportedInContinuousExecution(): Throwable = { + new AnalysisException( + "Sinks cannot request distribution and ordering in continuous execution mode") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index ede4c393b1308..c6a69e4ce5d6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -42,13 +42,13 @@ import org.apache.spark.sql.catalyst.ScalaReflection.Schema import org.apache.spark.sql.catalyst.WalkedTypePath import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator import 
org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, UnevaluableAggregate} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LogicalPlan} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, FailFastMode} +import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode} import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Identifier, Table, TableProvider} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.expressions.Transform @@ -68,11 +68,6 @@ import org.apache.spark.util.CircularBuffer */ object QueryExecutionErrors { - def columnChangeUnsupportedError(): Throwable = { - new SparkUnsupportedOperationException(errorClass = "UNSUPPORTED_CHANGE_COLUMN", - messageParameters = Array.empty) - } - def logicalHintOperatorNotRemovedDuringAnalysisError(): Throwable = { new SparkIllegalStateException(errorClass = "INTERNAL_ERROR", messageParameters = Array( @@ -94,9 +89,9 @@ object QueryExecutionErrors { messageParameters = Array(s"Cannot terminate expression: $generator")) } - def castingCauseOverflowError(t: Any, targetType: String): ArithmeticException = { + def castingCauseOverflowError(t: Any, dataType: DataType): ArithmeticException = { new SparkArithmeticException(errorClass = "CAST_CAUSES_OVERFLOW", - messageParameters = Array(t.toString, targetType, SQLConf.ANSI_ENABLED.key)) + messageParameters = Array(t.toString, dataType.catalogString, SQLConf.ANSI_ENABLED.key)) } def cannotChangeDecimalPrecisionError( @@ -131,22 +126,6 @@ object QueryExecutionErrors { messageParameters = Array.empty) } - def simpleStringWithNodeIdUnsupportedError(nodeName: String): Throwable = { - new SparkUnsupportedOperationException(errorClass = "UNSUPPORTED_SIMPLE_STRING_WITH_NODE_ID", - messageParameters = Array(nodeName)) - } - - def evaluateUnevaluableAggregateUnsupportedError( - methodName: String, unEvaluable: UnevaluableAggregate): Throwable = { - new SparkUnsupportedOperationException(errorClass = "INTERNAL_ERROR", - messageParameters = Array(s"Cannot evaluate expression: $methodName: $unEvaluable")) - } - - def dataTypeUnsupportedError(dt: DataType): Throwable = { - new SparkException(errorClass = "UNSUPPORTED_DATATYPE", - messageParameters = Array(dt.typeName), null) - } - def dataTypeUnsupportedError(dataType: String, failure: String): Throwable = { new SparkIllegalArgumentException(errorClass = "UNSUPPORTED_DATATYPE", messageParameters = Array(dataType + failure)) @@ -196,11 +175,6 @@ object QueryExecutionErrors { } } - def rowFromCSVParserNotExpectedError(): Throwable = { - new SparkIllegalArgumentException(errorClass = "ROW_FROM_CSV_PARSER_NOT_EXPECTED", - messageParameters = Array.empty) - } - def inputTypeUnsupportedError(dataType: DataType): Throwable = { new IllegalArgumentException(s"Unsupported input type ${dataType.catalogString}") } @@ -257,8 +231,17 @@ object QueryExecutionErrors { } def literalTypeUnsupportedError(v: Any): RuntimeException = { - new SparkRuntimeException("UNSUPPORTED_LITERAL_TYPE", - 
Array(v.getClass.toString, v.toString)) + new SparkRuntimeException( + errorClass = "UNSUPPORTED_FEATURE", + messageParameters = Array(s"literal for '${v.toString}' of ${v.getClass.toString}.")) + } + + def pivotColumnUnsupportedError(v: Any, dataType: DataType): RuntimeException = { + new SparkRuntimeException( + errorClass = "UNSUPPORTED_FEATURE", + messageParameters = Array( + s"pivoting by the value '${v.toString}' of the column data type" + + s" '${dataType.catalogString}'.")) } def noDefaultForDataTypeError(dataType: DataType): RuntimeException = { @@ -550,30 +533,43 @@ object QueryExecutionErrors { def sparkUpgradeInReadingDatesError( format: String, config: String, option: String): SparkUpgradeException = { - new SparkUpgradeException("3.0", - s""" - |reading dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z from $format - |files can be ambiguous, as the files may be written by Spark 2.x or legacy versions of - |Hive, which uses a legacy hybrid calendar that is different from Spark 3.0+'s Proleptic - |Gregorian calendar. See more details in SPARK-31404. You can set the SQL config - |'$config' or the datasource option '$option' to 'LEGACY' to rebase the datetime values - |w.r.t. the calendar difference during reading. To read the datetime values as it is, - |set the SQL config '$config' or the datasource option '$option' to 'CORRECTED'. - """.stripMargin, null) + new SparkUpgradeException( + errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION", + messageParameters = Array( + "3.0", + s""" + |reading dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z + |from $format files can be ambiguous, as the files may be written by + |Spark 2.x or legacy versions of Hive, which uses a legacy hybrid calendar + |that is different from Spark 3.0+'s Proleptic Gregorian calendar. + |See more details in SPARK-31404. You can set the SQL config '$config' or + |the datasource option '$option' to 'LEGACY' to rebase the datetime values + |w.r.t. the calendar difference during reading. To read the datetime values + |as it is, set the SQL config '$config' or the datasource option '$option' + |to 'CORRECTED'. + |""".stripMargin), + cause = null + ) } def sparkUpgradeInWritingDatesError(format: String, config: String): SparkUpgradeException = { - new SparkUpgradeException("3.0", - s""" - |writing dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z into $format - |files can be dangerous, as the files may be read by Spark 2.x or legacy versions of Hive - |later, which uses a legacy hybrid calendar that is different from Spark 3.0+'s Proleptic - |Gregorian calendar. See more details in SPARK-31404. You can set $config to 'LEGACY' to - |rebase the datetime values w.r.t. the calendar difference during writing, to get maximum - |interoperability. Or set $config to 'CORRECTED' to write the datetime values as it is, - |if you are 100% sure that the written files will only be read by Spark 3.0+ or other - |systems that use Proleptic Gregorian calendar. - """.stripMargin, null) + new SparkUpgradeException( + errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION", + messageParameters = Array( + "3.0", + s""" + |writing dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z + |into $format files can be dangerous, as the files may be read by Spark 2.x + |or legacy versions of Hive later, which uses a legacy hybrid calendar that + |is different from Spark 3.0+'s Proleptic Gregorian calendar. See more + |details in SPARK-31404. 
You can set $config to 'LEGACY' to rebase the + |datetime values w.r.t. the calendar difference during writing, to get maximum + |interoperability. Or set $config to 'CORRECTED' to write the datetime values + |as it is, if you are 100% sure that the written files will only be read by + |Spark 3.0+ or other systems that use Proleptic Gregorian calendar. + |""".stripMargin), + cause = null + ) } def buildReaderUnsupportedForFileFormatError(format: String): Throwable = { @@ -669,7 +665,7 @@ object QueryExecutionErrors { def unsupportedPartitionTransformError(transform: Transform): Throwable = { new UnsupportedOperationException( - s"SessionCatalog does not support partition transform: $transform") + s"Unsupported partition transform: $transform") } def missingDatabaseLocationError(): Throwable = { @@ -784,8 +780,10 @@ object QueryExecutionErrors { } def transactionUnsupportedByJdbcServerError(): Throwable = { - new SparkSQLFeatureNotSupportedException(errorClass = "UNSUPPORTED_TRANSACTION_BY_JDBC_SERVER", - Array.empty) + new SparkSQLFeatureNotSupportedException( + errorClass = "UNSUPPORTED_FEATURE", + messageParameters = Array("the target JDBC server does not support transaction and " + + "can only support ALTER TABLE with a single action.")) } def dataTypeUnsupportedYetError(dataType: DataType): Throwable = { @@ -817,6 +815,15 @@ object QueryExecutionErrors { """.stripMargin.replaceAll("\n", " ")) } + def foundDuplicateFieldInFieldIdLookupModeError( + requiredId: Int, matchedFields: String): Throwable = { + new RuntimeException( + s""" + |Found duplicate field(s) "$requiredId": $matchedFields + |in id mapping mode + """.stripMargin.replaceAll("\n", " ")) + } + def failedToMergeIncompatibleSchemasError( left: StructType, right: StructType, e: Throwable): Throwable = { new SparkException(s"Failed to merge incompatible schemas $left and $right", e) @@ -1623,8 +1630,12 @@ object QueryExecutionErrors { } def timeZoneIdNotSpecifiedForTimestampTypeError(): Throwable = { - new UnsupportedOperationException( - s"${TimestampType.catalogString} must supply timeZoneId parameter") + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_OPERATION", + messageParameters = Array( + s"${TimestampType.catalogString} must supply timeZoneId parameter " + + s"while converting to ArrowType") + ) } def notPublicClassError(name: String): Throwable = { @@ -1897,23 +1908,40 @@ object QueryExecutionErrors { } def repeatedPivotsUnsupportedError(): Throwable = { - new UnsupportedOperationException("repeated pivots are not supported") + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE", + messageParameters = Array("Repeated pivots.")) } def pivotNotAfterGroupByUnsupportedError(): Throwable = { - new UnsupportedOperationException("pivot is only supported after a groupBy") + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE", + messageParameters = Array("Pivot not after a groupBy.")) } def invalidAesKeyLengthError(actualLength: Int): RuntimeException = { - new SparkRuntimeException("INVALID_AES_KEY_LENGTH", Array(actualLength.toString)) + new SparkRuntimeException( + errorClass = "INVALID_PARAMETER_VALUE", + messageParameters = Array( + "key", + "the aes_encrypt/aes_decrypt function", + s"expects a binary value with 16, 24 or 32 bytes, but got ${actualLength.toString} bytes.")) } def aesModeUnsupportedError(mode: String, padding: String): RuntimeException = { - new SparkRuntimeException("UNSUPPORTED_AES_MODE", Array(mode, padding)) + new 
SparkRuntimeException( + errorClass = "UNSUPPORTED_FEATURE", + messageParameters = Array( + s"AES-$mode with the padding $padding by the aes_encrypt/aes_decrypt function.")) } def aesCryptoError(detailMessage: String): RuntimeException = { - new SparkRuntimeException("AES_CRYPTO_ERROR", Array(detailMessage)) + new SparkRuntimeException( + errorClass = "INVALID_PARAMETER_VALUE", + messageParameters = Array( + "expr, key", + "the aes_encrypt/aes_decrypt function", + s"Detail message: $detailMessage")) } def hiveTableWithAnsiIntervalsError(tableName: String): Throwable = { @@ -1921,7 +1949,16 @@ object QueryExecutionErrors { } def cannotConvertOrcTimestampToTimestampNTZError(): Throwable = { - new RuntimeException("Unable to convert timestamp of Orc to data type 'timestamp_ntz'") + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_OPERATION", + messageParameters = Array("Unable to convert timestamp of Orc to data type 'timestamp_ntz'")) + } + + def cannotConvertOrcTimestampNTZToTimestampLTZError(): Throwable = { + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_OPERATION", + messageParameters = + Array("Unable to convert timestamp ntz of Orc to data type 'timestamp_ltz'")) } def writePartitionExceedConfigSizeWhenDynamicPartitionError( @@ -1935,9 +1972,30 @@ object QueryExecutionErrors { s" to at least $numWrittenParts.") } - def invalidNumberFormatError(format: String): Throwable = { + def invalidNumberFormatError(input: UTF8String, format: String): Throwable = { new IllegalArgumentException( - s"Format '$format' used for parsing string to number or " + - "formatting number to string is invalid") + s"The input string '$input' does not match the given number format: '$format'") + } + + def multipleBucketTransformsError(): Throwable = { + new UnsupportedOperationException("Multiple bucket transforms are not supported.") + } + + def unsupportedCreateNamespaceCommentError(): Throwable = { + new SQLFeatureNotSupportedException("Create namespace comment is not supported") + } + + def unsupportedRemoveNamespaceCommentError(): Throwable = { + new SQLFeatureNotSupportedException("Remove namespace comment is not supported") + } + + def unsupportedDropNamespaceRestrictError(): Throwable = { + new SQLFeatureNotSupportedException("Drop namespace restrict is not supported") + } + + def timestampAddOverflowError(micros: Long, amount: Int, unit: String): ArithmeticException = { + new SparkArithmeticException( + errorClass = "DATETIME_OVERFLOW", + messageParameters = Array(s"add $amount $unit to '${DateTimeUtils.microsToInstant(micros)}'")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 938bbfdb49c33..d055299b39396 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -90,11 +90,13 @@ object QueryParsingErrors { } def transformNotSupportQuantifierError(ctx: ParserRuleContext): Throwable = { - new ParseException("TRANSFORM does not support DISTINCT/ALL in inputs", ctx) + new ParseException("UNSUPPORTED_FEATURE", + Array("TRANSFORM does not support DISTINCT/ALL in inputs"), ctx) } def transformWithSerdeUnsupportedError(ctx: ParserRuleContext): Throwable = { - new ParseException("TRANSFORM with serde is only supported in hive mode", ctx) + new ParseException("UNSUPPORTED_FEATURE", + Array("TRANSFORM with serde 
is only supported in hive mode"), ctx) } def lateralWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = { @@ -102,39 +104,38 @@ object QueryParsingErrors { } def lateralJoinWithNaturalJoinUnsupportedError(ctx: ParserRuleContext): Throwable = { - new ParseException("LATERAL join with NATURAL join is not supported", ctx) + new ParseException("UNSUPPORTED_FEATURE", Array("LATERAL join with NATURAL join."), ctx) } def lateralJoinWithUsingJoinUnsupportedError(ctx: ParserRuleContext): Throwable = { - new ParseException("LATERAL join with USING join is not supported", ctx) + new ParseException("UNSUPPORTED_FEATURE", Array("LATERAL join with USING join."), ctx) } def unsupportedLateralJoinTypeError(ctx: ParserRuleContext, joinType: String): Throwable = { - new ParseException(s"Unsupported LATERAL join type $joinType", ctx) + new ParseException("UNSUPPORTED_FEATURE", Array(s"LATERAL join type '$joinType'."), ctx) } def invalidLateralJoinRelationError(ctx: RelationPrimaryContext): Throwable = { - new ParseException(s"LATERAL can only be used with subquery", ctx) + new ParseException("INVALID_SQL_SYNTAX", Array("LATERAL can only be used with subquery."), ctx) } def repetitiveWindowDefinitionError(name: String, ctx: WindowClauseContext): Throwable = { - new ParseException(s"The definition of window '$name' is repetitive", ctx) + new ParseException("INVALID_SQL_SYNTAX", + Array(s"The definition of window '$name' is repetitive."), ctx) } def invalidWindowReferenceError(name: String, ctx: WindowClauseContext): Throwable = { - new ParseException(s"Window reference '$name' is not a window specification", ctx) + new ParseException("INVALID_SQL_SYNTAX", + Array(s"Window reference '$name' is not a window specification."), ctx) } def cannotResolveWindowReferenceError(name: String, ctx: WindowClauseContext): Throwable = { - new ParseException(s"Cannot resolve window reference '$name'", ctx) - } - - def joinCriteriaUnimplementedError(join: JoinCriteriaContext, ctx: RelationContext): Throwable = { - new ParseException(s"Unimplemented joinCriteria: $join", ctx) + new ParseException("INVALID_SQL_SYNTAX", + Array(s"Cannot resolve window reference '$name'."), ctx) } def naturalCrossJoinUnsupportedError(ctx: RelationContext): Throwable = { - new ParseException("NATURAL CROSS JOIN is not supported", ctx) + new ParseException("UNSUPPORTED_FEATURE", Array("NATURAL CROSS JOIN."), ctx) } def emptyInputForTableSampleError(ctx: ParserRuleContext): Throwable = { @@ -160,7 +161,8 @@ object QueryParsingErrors { } def functionNameUnsupportedError(functionName: String, ctx: ParserRuleContext): Throwable = { - new ParseException(s"Unsupported function name '$functionName'", ctx) + new ParseException("INVALID_SQL_SYNTAX", + Array(s"Unsupported function name '$functionName'"), ctx) } def cannotParseValueTypeError( @@ -225,21 +227,13 @@ object QueryParsingErrors { } def tooManyArgumentsForTransformError(name: String, ctx: ApplyTransformContext): Throwable = { - new ParseException(s"Too many arguments for transform $name", ctx) - } - - def notEnoughArgumentsForTransformError(name: String, ctx: ApplyTransformContext): Throwable = { - new ParseException(s"Not enough arguments for transform $name", ctx) + new ParseException("INVALID_SQL_SYNTAX", Array(s"Too many arguments for transform $name"), ctx) } def invalidBucketsNumberError(describe: String, ctx: ApplyTransformContext): Throwable = { new ParseException(s"Invalid number of buckets: $describe", ctx) } - def invalidTransformArgumentError(ctx: 
TransformArgumentContext): Throwable = { - new ParseException("Invalid transform argument", ctx) - } - def cannotCleanReservedNamespacePropertyError( property: String, ctx: ParserRuleContext, msg: String): Throwable = { new ParseException(s"$property is a reserved namespace property, $msg.", ctx) @@ -300,12 +294,13 @@ object QueryParsingErrors { } def showFunctionsUnsupportedError(identifier: String, ctx: IdentifierContext): Throwable = { - new ParseException(s"SHOW $identifier FUNCTIONS not supported", ctx) + new ParseException("INVALID_SQL_SYNTAX", + Array(s"SHOW $identifier FUNCTIONS not supported"), ctx) } def showFunctionsInvalidPatternError(pattern: String, ctx: ParserRuleContext): Throwable = { - new ParseException(s"Invalid pattern in SHOW FUNCTIONS: $pattern. It must be " + - "a string literal.", ctx) + new ParseException("INVALID_SQL_SYNTAX", + Array(s"Invalid pattern in SHOW FUNCTIONS: $pattern. It must be a string literal."), ctx) } def duplicateCteDefinitionNamesError(duplicateNames: String, ctx: CtesContext): Throwable = { @@ -410,22 +405,27 @@ object QueryParsingErrors { } def createFuncWithBothIfNotExistsAndReplaceError(ctx: CreateFunctionContext): Throwable = { - new ParseException("CREATE FUNCTION with both IF NOT EXISTS and REPLACE is not allowed.", ctx) + new ParseException("INVALID_SQL_SYNTAX", + Array("CREATE FUNCTION with both IF NOT EXISTS and REPLACE is not allowed."), ctx) } def defineTempFuncWithIfNotExistsError(ctx: CreateFunctionContext): Throwable = { - new ParseException("It is not allowed to define a TEMPORARY function with IF NOT EXISTS.", ctx) + new ParseException("INVALID_SQL_SYNTAX", + Array("It is not allowed to define a TEMPORARY function with IF NOT EXISTS."), ctx) } def unsupportedFunctionNameError(quoted: String, ctx: CreateFunctionContext): Throwable = { - new ParseException(s"Unsupported function name '$quoted'", ctx) + new ParseException("INVALID_SQL_SYNTAX", + Array(s"Unsupported function name '$quoted'"), ctx) } def specifyingDBInCreateTempFuncError( databaseName: String, ctx: CreateFunctionContext): Throwable = { new ParseException( - s"Specifying a database in CREATE TEMPORARY FUNCTION is not allowed: '$databaseName'", ctx) + "INVALID_SQL_SYNTAX", + Array(s"Specifying a database in CREATE TEMPORARY FUNCTION is not allowed: '$databaseName'"), + ctx) } def unclosedBracketedCommentError(command: String, position: Origin): Throwable = { @@ -437,7 +437,11 @@ object QueryParsingErrors { } def invalidNameForDropTempFunc(name: Seq[String], ctx: ParserRuleContext): Throwable = { - new ParseException( - s"DROP TEMPORARY FUNCTION requires a single part name but got: ${name.quoted}", ctx) + new ParseException("INVALID_SQL_SYNTAX", + Array(s"DROP TEMPORARY FUNCTION requires a single part name but got: ${name.quoted}"), ctx) + } + + def defaultColumnNotImplementedYetError(ctx: ParserRuleContext): Throwable = { + new ParseException("Support for DEFAULT column values is not implemented yet", ctx) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 252dd5bad30b4..3314dd1916498 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -396,6 +396,29 @@ object SQLConf { .booleanConf .createWithDefault(true) + val REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION = + buildConf("spark.sql.requireAllClusterKeysForCoPartition") + .internal() + 
.doc("When true, the planner requires all the clustering keys as the hash partition keys " + + "of the children, to eliminate the shuffles for the operator that needs its children to " + + "be co-partitioned, such as JOIN node. This is to avoid data skews which can lead to " + + "significant performance regression if shuffles are eliminated.") + .version("3.3.0") + .booleanConf + .createWithDefault(true) + + val REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION = + buildConf("spark.sql.requireAllClusterKeysForDistribution") + .internal() + .doc("When true, the planner requires all the clustering keys as the partition keys " + + "(with same ordering) of the children, to eliminate the shuffle for the operator that " + + "requires its children be clustered distributed, such as AGGREGATE and WINDOW node. " + + "This is to avoid data skews which can lead to significant performance regression if " + + "shuffle is eliminated.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val RADIX_SORT_ENABLED = buildConf("spark.sql.sort.enableRadixSort") .internal() .doc("When true, enable use of radix sort when possible. Radix sort is much faster but " + @@ -721,6 +744,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PROPAGATE_DISTINCT_KEYS_ENABLED = + buildConf("spark.sql.optimizer.propagateDistinctKeys.enabled") + .internal() + .doc("When true, the query optimizer will propagate a set of distinct attributes from the " + + "current node and use it to optimize query.") + .version("3.3.0") + .booleanConf + .createWithDefault(true) + val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") .internal() .doc("When true, string literals (including regex patterns) remain escaped in our SQL " + @@ -923,6 +955,33 @@ object SQLConf { .intConf .createWithDefault(4096) + val PARQUET_FIELD_ID_WRITE_ENABLED = + buildConf("spark.sql.parquet.fieldId.write.enabled") + .doc("Field ID is a native field of the Parquet schema spec. When enabled, " + + "Parquet writers will populate the field Id " + + "metadata (if present) in the Spark schema to the Parquet schema.") + .version("3.3.0") + .booleanConf + .createWithDefault(true) + + val PARQUET_FIELD_ID_READ_ENABLED = + buildConf("spark.sql.parquet.fieldId.read.enabled") + .doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers " + + "will use field IDs (if present) in the requested Spark schema to look up Parquet " + + "fields instead of using column names") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + + val IGNORE_MISSING_PARQUET_FIELD_ID = + buildConf("spark.sql.parquet.fieldId.read.ignoreMissing") + .doc("When the Parquet file doesn't have any field IDs but the " + + "Spark read schema is using field IDs to read, we will silently return nulls " + + "when this flag is enabled, or error otherwise.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + "`orc.compress` is specified in the table-specific options/properties, the precedence " + @@ -1670,7 +1729,6 @@ object SQLConf { val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION = buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition") - .internal() .doc("When true, streaming session window sorts and merge sessions in local partition " + "prior to shuffle. 
This is to reduce the rows to shuffle, but only beneficial when " + "there're lots of rows in a batch being assigned to same sessions.") @@ -1724,6 +1782,23 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION = + buildConf("spark.sql.streaming.statefulOperator.useStrictDistribution") + .internal() + .doc("The purpose of this config is only compatibility; DO NOT MANUALLY CHANGE THIS!!! " + + "When true, the stateful operator for streaming query will use " + + "StatefulOpClusteredDistribution which guarantees stable state partitioning as long as " + + "the operator provides consistent grouping keys across the lifetime of query. " + + "When false, the stateful operator for streaming query will use ClusteredDistribution " + + "which is not sufficient to guarantee stable state partitioning despite the operator " + + "provides consistent grouping keys across the lifetime of query. " + + "This config will be set to true for new streaming queries to guarantee stable state " + + "partitioning, and set to false for existing streaming queries to not break queries " + + "which are restored from existing checkpoints. Please refer SPARK-38204 for details.") + .version("3.3.0") + .booleanConf + .createWithDefault(true) + val FILESTREAM_SINK_METADATA_IGNORED = buildConf("spark.sql.streaming.fileStreamSink.ignoreMetadata") .internal() @@ -2372,7 +2447,8 @@ object SQLConf { "and shows a Python-friendly exception only.") .version("3.0.0") .booleanConf - .createWithDefault(false) + // show full stacktrace in tests but hide in production by default. + .createWithDefault(Utils.isTesting) val ARROW_SPARKR_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.sparkr.enabled") @@ -2429,7 +2505,8 @@ object SQLConf { "shows the exception messages from UDFs. Note that this works only with CPython 3.7+.") .version("3.1.0") .booleanConf - .createWithDefault(true) + // show full stacktrace in tests but hide in production by default. + .createWithDefault(!Utils.isTesting) val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME = buildConf("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName") @@ -2655,7 +2732,7 @@ object SQLConf { "standard directly, but their behaviors align with ANSI SQL's style") .version("3.0.0") .booleanConf - .createWithDefault(false) + .createWithDefault(sys.env.get("SPARK_ANSI_SQL_MODE").contains("true")) val ENFORCE_RESERVED_KEYWORDS = buildConf("spark.sql.ansi.enforceReservedKeywords") .doc(s"When true and '${ANSI_ENABLED.key}' is true, the Spark SQL parser enforces the ANSI " + @@ -3510,6 +3587,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE = + buildConf("spark.sql.legacy.histogramNumericPropagateInputType") + .internal() + .doc("The histogram_numeric function computes a histogram on numeric 'expr' using nb bins. " + + "The return value is an array of (x,y) pairs representing the centers of the histogram's " + + "bins. If this config is set to true, the output type of the 'x' field in the return " + + "value is propagated from the input value consumed in the aggregate function. Otherwise, " + + "'x' always has double type.") + .version("3.3.0") + .booleanConf + .createWithDefault(true) + /** * Holds information about keys that have been deprecated. 
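Sketch (not part of the diff): the changed default above derives spark.sql.ansi.enabled from the SPARK_ANSI_SQL_MODE environment variable at the time SQLConf is loaded. A minimal illustration of how that surfaces through the public conf API, assuming only a local SparkSession (the master/appName values are placeholders):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf

    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("ansi-default-sketch")
      .getOrCreate()
    // The default of SQLConf.ANSI_ENABLED is now computed from the environment:
    // true only when SPARK_ANSI_SQL_MODE is set to "true", false otherwise.
    println(spark.conf.get(SQLConf.ANSI_ENABLED.key))
    // A session-level override still takes precedence over the environment-derived default.
    spark.conf.set(SQLConf.ANSI_ENABLED.key, "true")
    assert(spark.conf.get(SQLConf.ANSI_ENABLED.key) == "true")
    spark.stop()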
* @@ -4240,8 +4329,17 @@ class SQLConf extends Serializable with Logging { def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT) + def parquetFieldIdReadEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED) + + def parquetFieldIdWriteEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED) + + def ignoreMissingParquetFieldId: Boolean = getConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID) + def useV1Command: Boolean = getConf(SQLConf.LEGACY_USE_V1_COMMAND) + def histogramNumericPropagateInputType: Boolean = + getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index cb468c523f36c..bbf902849e7f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -251,9 +251,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toByte: Byte = toLong.toByte - private def overflowException(dataType: String) = - throw QueryExecutionErrors.castingCauseOverflowError(this, dataType) - /** * @return the Byte value that is equal to the rounded decimal. * @throws ArithmeticException if the decimal is too big to fit in Byte type. @@ -264,14 +261,14 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (actualLongVal == actualLongVal.toByte) { actualLongVal.toByte } else { - throw QueryExecutionErrors.castingCauseOverflowError(this, "byte") + throw QueryExecutionErrors.castingCauseOverflowError(this, ByteType) } } else { val doubleVal = decimalVal.toDouble if (Math.floor(doubleVal) <= Byte.MaxValue && Math.ceil(doubleVal) >= Byte.MinValue) { doubleVal.toByte } else { - throw QueryExecutionErrors.castingCauseOverflowError(this, "byte") + throw QueryExecutionErrors.castingCauseOverflowError(this, ByteType) } } } @@ -286,14 +283,14 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (actualLongVal == actualLongVal.toShort) { actualLongVal.toShort } else { - throw QueryExecutionErrors.castingCauseOverflowError(this, "short") + throw QueryExecutionErrors.castingCauseOverflowError(this, ShortType) } } else { val doubleVal = decimalVal.toDouble if (Math.floor(doubleVal) <= Short.MaxValue && Math.ceil(doubleVal) >= Short.MinValue) { doubleVal.toShort } else { - throw QueryExecutionErrors.castingCauseOverflowError(this, "short") + throw QueryExecutionErrors.castingCauseOverflowError(this, ShortType) } } } @@ -308,14 +305,14 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (actualLongVal == actualLongVal.toInt) { actualLongVal.toInt } else { - throw QueryExecutionErrors.castingCauseOverflowError(this, "int") + throw QueryExecutionErrors.castingCauseOverflowError(this, IntegerType) } } else { val doubleVal = decimalVal.toDouble if (Math.floor(doubleVal) <= Int.MaxValue && Math.ceil(doubleVal) >= Int.MinValue) { doubleVal.toInt } else { - throw QueryExecutionErrors.castingCauseOverflowError(this, "int") + throw QueryExecutionErrors.castingCauseOverflowError(this, IntegerType) } } } @@ -335,7 +332,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { decimalVal.bigDecimal.toBigInteger.longValueExact() } catch { case _: ArithmeticException => - throw QueryExecutionErrors.castingCauseOverflowError(this, "long") + throw 
QueryExecutionErrors.castingCauseOverflowError(this, LongType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index 93d57a7fe6f3b..f490f8318ef84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -21,7 +21,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded} import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.util.SchemaUtils @@ -93,7 +93,7 @@ case class StructField( * Returns a string containing a schema in SQL format. For example the following value: * `StructField("eventId", IntegerType)` will be converted to `eventId`: INT. */ - private[sql] def sql = s"${quoteIdentifier(name)}: ${dataType.sql}$getDDLComment" + private[sql] def sql = s"${quoteIfNeeded(name)}: ${dataType.sql}$getDDLComment" /** * Returns a string containing a schema in DDL format. For example, the following value: @@ -103,6 +103,6 @@ case class StructField( */ def toDDL: String = { val nullString = if (nullable) "" else " NOT NULL" - s"${quoteIdentifier(name)} ${dataType.sql}${nullString}$getDDLComment" + s"${quoteIfNeeded(name)} ${dataType.sql}${nullString}$getDDLComment" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala index 6811e50ccdf94..16adec71bc84f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala @@ -115,7 +115,7 @@ private[sql] object LongExactNumeric extends LongIsIntegral with Ordering.LongOr if (x == x.toInt) { x.toInt } else { - throw QueryExecutionErrors.castingCauseOverflowError(x, "int") + throw QueryExecutionErrors.castingCauseOverflowError(x, IntegerType) } } @@ -135,7 +135,7 @@ private[sql] object FloatExactNumeric extends FloatIsFractional { if (Math.floor(x) <= intUpperBound && Math.ceil(x) >= intLowerBound) { x.toInt } else { - throw QueryExecutionErrors.castingCauseOverflowError(x, "int") + throw QueryExecutionErrors.castingCauseOverflowError(x, IntegerType) } } @@ -143,7 +143,7 @@ private[sql] object FloatExactNumeric extends FloatIsFractional { if (Math.floor(x) <= longUpperBound && Math.ceil(x) >= longLowerBound) { x.toLong } else { - throw QueryExecutionErrors.castingCauseOverflowError(x, "int") + throw QueryExecutionErrors.castingCauseOverflowError(x, LongType) } } @@ -160,7 +160,7 @@ private[sql] object DoubleExactNumeric extends DoubleIsFractional { if (Math.floor(x) <= intUpperBound && Math.ceil(x) >= intLowerBound) { x.toInt } else { - throw QueryExecutionErrors.castingCauseOverflowError(x, "int") + throw QueryExecutionErrors.castingCauseOverflowError(x, IntegerType) } } @@ -168,7 +168,7 @@ private[sql] object DoubleExactNumeric extends DoubleIsFractional { if (Math.floor(x) <= longUpperBound && Math.ceil(x) >= longLowerBound) { x.toLong } else { - throw QueryExecutionErrors.castingCauseOverflowError(x, "long") + throw QueryExecutionErrors.castingCauseOverflowError(x, LongType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 5d3f960c3bfac..a924a9ed02e5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash, Pmod} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.IntegerType class DistributionSuite extends SparkFunSuite { @@ -167,6 +169,24 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq($"d", $"e")), false) + // When ClusteredDistribution.requireAllClusterKeys is set to true, + // HashPartitioning can only satisfy ClusteredDistribution iff its hash expressions are + // exactly same as the required clustering expressions. + checkSatisfied( + HashPartitioning(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + true) + + checkSatisfied( + HashPartitioning(Seq($"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + false) + + checkSatisfied( + HashPartitioning(Seq($"b", $"a", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + false) + // HashPartitioning cannot satisfy OrderedDistribution checkSatisfied( HashPartitioning(Seq($"a", $"b", $"c"), 10), @@ -247,22 +267,116 @@ class DistributionSuite extends SparkFunSuite { RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10), ClusteredDistribution(Seq($"c", $"d")), false) + + // When ClusteredDistribution.requireAllClusterKeys is set to true, + // RangePartitioning can only satisfy ClusteredDistribution iff its ordering expressions are + // exactly same as the required clustering expressions. + checkSatisfied( + RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + true) + + checkSatisfied( + RangePartitioning(Seq($"a".asc, $"b".asc), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + false) + + checkSatisfied( + RangePartitioning(Seq($"b".asc, $"a".asc, $"c".asc), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + false) } test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") { checkSatisfied( SinglePartition, - ClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)), + ClusteredDistribution(Seq($"a", $"b", $"c"), requiredNumPartitions = Some(10)), + false) + + checkSatisfied( + HashPartitioning(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requiredNumPartitions = Some(5)), + false) + + checkSatisfied( + RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requiredNumPartitions = Some(5)), false) + } + + test("Structured Streaming output partitioning and distribution") { + // Validate HashPartitioning.partitionIdExpression to be exactly expected format, because + // Structured Streaming state store requires it to be consistent across Spark versions. 
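Sketch (not part of the diff): a standalone restatement of the requireAllClusterKeys behaviour the checks above exercise, using the same expression DSL and physical-plan classes. With the flag set, a hash partitioning satisfies the clustered distribution only when its expressions are exactly the clustering keys, in the same order, so the planner keeps the shuffle in the subset and reordered cases:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}

    val strict = ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true)
    // Exact keys in the same order: satisfied.
    assert(HashPartitioning(Seq($"a", $"b", $"c"), 10).satisfies(strict))
    // A subset of the clustering keys: not satisfied.
    assert(!HashPartitioning(Seq($"b", $"c"), 10).satisfies(strict))
    // Same keys but reordered: also not satisfied.
    assert(!HashPartitioning(Seq($"b", $"a", $"c"), 10).satisfies(strict))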
+ val expressions = Seq($"a", $"b", $"c") + val hashPartitioning = HashPartitioning(expressions, 10) + hashPartitioning.partitionIdExpression match { + case Pmod(Murmur3Hash(es, 42), Literal(10, IntegerType), _) => + assert(es.length == expressions.length && es.zip(expressions).forall { + case (l, r) => l.semanticEquals(r) + }) + case x => fail(s"Unexpected partitionIdExpression $x for $hashPartitioning") + } + // Validate only HashPartitioning (and HashPartitioning in PartitioningCollection) can satisfy + // StatefulOpClusteredDistribution. SinglePartition can also satisfy this distribution when + // `_requiredNumPartitions` is 1. checkSatisfied( HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"), Some(5)), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + true) + + checkSatisfied( + PartitioningCollection(Seq( + HashPartitioning(Seq($"a", $"b", $"c"), 10), + RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10))), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + true) + + checkSatisfied( + SinglePartition, + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 1), + true) + + checkSatisfied( + PartitioningCollection(Seq( + HashPartitioning(Seq($"a", $"b"), 1), + SinglePartition)), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 1), + true) + + checkSatisfied( + HashPartitioning(Seq($"a", $"b"), 10), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + HashPartitioning(Seq($"a", $"b", $"c"), 5), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), false) checkSatisfied( RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"), Some(5)), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + SinglePartition, + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + RoundRobinPartitioning(10), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + UnknownPartitioning(10), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala index 45f88628f3ab3..0c1c9d5bfeeaf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala @@ -30,11 +30,15 @@ import org.apache.spark.sql.catalyst.util.fileToString trait SQLKeywordUtils extends SparkFunSuite with SQLHelper { val sqlSyntaxDefs = { - val sqlBasePath = { + val sqlBaseParserPath = getWorkspaceFilePath("sql", "catalyst", "src", "main", "antlr4", "org", - "apache", "spark", "sql", "catalyst", "parser", "SqlBase.g4").toFile - } - fileToString(sqlBasePath).split("\n") + "apache", "spark", "sql", "catalyst", "parser", "SqlBaseParser.g4").toFile + + val sqlBaseLexerPath = + getWorkspaceFilePath("sql", "catalyst", "src", "main", "antlr4", "org", + "apache", "spark", "sql", "catalyst", "parser", "SqlBaseLexer.g4").toFile + + (fileToString(sqlBaseParserPath) + fileToString(sqlBaseLexerPath)).split("\n") } // each element is an array of 4 string: the keyword name, reserve or not in Spark ANSI mode, @@ -54,7 +58,7 @@ trait 
SQLKeywordUtils extends SparkFunSuite with SQLHelper { val default = (_: String) => Nil var startTagFound = false var parseFinished = false - val lineIter = sqlSyntaxDefs.toIterator + val lineIter = sqlSyntaxDefs.iterator while (!parseFinished && lineIter.hasNext) { val line = lineIter.next() if (line.trim.startsWith(startTag)) { @@ -67,8 +71,9 @@ trait SQLKeywordUtils extends SparkFunSuite with SQLHelper { } } } - assert(keywords.nonEmpty && startTagFound && parseFinished, "cannot extract keywords from " + - s"the `SqlBase.g4` file, so please check if the start/end tags (`$startTag` and `$endTag`) " + + assert(keywords.nonEmpty && startTagFound && parseFinished, + "cannot extract keywords from the `SqlBaseParser.g4` or `SqlBaseLexer.g4` file, " + + s"so please check if the start/end tags (`$startTag` and `$endTag`) " + "are placed correctly in the file.") keywords.toSet } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index d4d73b363e23d..74ec949fe4470 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite -/* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.internal.SQLConf -class ShuffleSpecSuite extends SparkFunSuite { +class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { protected def checkCompatible( left: ShuffleSpec, right: ShuffleSpec, @@ -349,12 +350,22 @@ class ShuffleSpecSuite extends SparkFunSuite { test("canCreatePartitioning") { val distribution = ClusteredDistribution(Seq($"a", $"b")) - assert(HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution).canCreatePartitioning) + withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false") { + assert(HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution).canCreatePartitioning) + } + withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "true") { + assert(!HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution) + .canCreatePartitioning) + assert(HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), distribution) + .canCreatePartitioning) + } assert(SinglePartitionShuffleSpec.canCreatePartitioning) - assert(ShuffleSpecCollection(Seq( + withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false") { + assert(ShuffleSpecCollection(Seq( HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution), HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), distribution))) - .canCreatePartitioning) + .canCreatePartitioning) + } assert(!RangeShuffleSpec(10, distribution).canCreatePartitioning) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8f690e2021602..c69d51938aef0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -305,7 +305,7 @@ class AnalysisErrorSuite extends AnalysisTest { .where(sum($"b") > 0) .orderBy($"havingCondition".asc), 
"MISSING_COLUMN", - Array("havingCondition", "max('b)")) + Array("havingCondition", "max(b)")) errorTest( "unresolved star expansion in max", @@ -793,7 +793,8 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", IntegerType)() - val t1 = LocalRelation(a, b) + val d = AttributeReference("d", DoubleType)() + val t1 = LocalRelation(a, b, d) val t2 = LocalRelation(c) val conditions = Seq( (abs($"a") === $"c", "abs(a) = outer(c)"), @@ -801,7 +802,7 @@ class AnalysisErrorSuite extends AnalysisTest { ($"a" + 1 === $"c", "(a + 1) = outer(c)"), ($"a" + $"b" === $"c", "(a + b) = outer(c)"), ($"a" + $"c" === $"b", "(a + outer(c)) = b"), - (And($"a" === $"c", Cast($"a", IntegerType) === $"c"), "CAST(a AS INT) = outer(c)")) + (And($"a" === $"c", Cast($"d", IntegerType) === $"c"), "CAST(d AS INT) = outer(c)")) conditions.foreach { case (cond, msg) => val plan = Project( ScalarSubquery( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 63f90a8d6b886..fff25b59eff98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -221,7 +221,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { val pl = plan.asInstanceOf[Project].projectList assert(pl(0).dataType == DoubleType) - assert(pl(1).dataType == DoubleType) + if (!SQLConf.get.ansiEnabled) { + assert(pl(1).dataType == DoubleType) + } assert(pl(2).dataType == DoubleType) assert(pl(3).dataType == DoubleType) assert(pl(4).dataType == DoubleType) @@ -1150,4 +1152,28 @@ class AnalysisSuite extends AnalysisTest with Matchers { "MISSING_COLUMN", Array("c.y", "x")) } + + test("SPARK-38118: Func(wrong_type) in the HAVING clause should throw data mismatch error") { + Seq("mean", "abs").foreach { func => + assertAnalysisError(parsePlan( + s""" + |WITH t as (SELECT true c) + |SELECT t.c + |FROM t + |GROUP BY t.c + |HAVING ${func}(t.c) > 0d""".stripMargin), + Seq(s"cannot resolve '$func(t.c)' due to data type mismatch"), + false) + + assertAnalysisError(parsePlan( + s""" + |WITH t as (SELECT true c, false d) + |SELECT (t.c AND t.d) c + |FROM t + |GROUP BY t.c + |HAVING ${func}(c) > 0d""".stripMargin), + Seq(s"cannot resolve '$func(t.c)' due to data type mismatch"), + false) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 53dc9be6c69b7..804f1edbe06fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -200,11 +200,15 @@ trait AnalysisTest extends PlanTest { } } - protected def interceptParseException( - parser: String => Any)(sqlCommand: String, messages: String*): Unit = { + protected def interceptParseException(parser: String => Any)( + sqlCommand: String, messages: String*)( + errorClass: Option[String] = None): Unit = { val e = intercept[ParseException](parser(sqlCommand)) messages.foreach { message => assert(e.message.contains(message)) } + if (errorClass.isDefined) { + assert(e.getErrorClass == errorClass.get) + } } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala index 809cbb2cebdbf..1f23aeb61e1f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala @@ -99,24 +99,15 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { } } - test("implicit type cast - unfoldable StringType") { - val nonCastableTypes = allTypes.filterNot(_ == StringType) - nonCastableTypes.foreach { dt => - shouldNotCastStringInput(dt) - } - shouldNotCastStringInput(DecimalType) - shouldNotCastStringInput(NumericType) - } - - test("implicit type cast - foldable StringType") { - atomicTypes.foreach { dt => - shouldCastStringLiteral(dt, dt) - } - allTypes.filterNot(atomicTypes.contains).foreach { dt => - shouldNotCastStringLiteral(dt) - } - shouldCastStringLiteral(DecimalType, DecimalType.defaultConcreteType) - shouldCastStringLiteral(NumericType, DoubleType) + test("implicit type cast - StringType") { + val checkedType = StringType + val nonCastableTypes = + complexTypes ++ Seq(NullType, CalendarIntervalType) + checkTypeCasting(checkedType, castableTypes = allTypes.filterNot(nonCastableTypes.contains)) + shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT) + shouldCast(checkedType, NumericType, NumericType.defaultConcreteType) + shouldCast(checkedType, AnyTimestampType, AnyTimestampType.defaultConcreteType) + shouldNotCast(checkedType, IntegralType) } test("implicit type cast - unfoldable ArrayType(StringType)") { @@ -153,6 +144,26 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { shouldNotCast(checkedType, IntegralType) } + test("wider data type of two for string") { + def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit = { + checkWidenType(AnsiTypeCoercion.findWiderTypeForTwo, t1, t2, expected) + checkWidenType(AnsiTypeCoercion.findWiderTypeForTwo, t2, t1, expected) + } + + widenTest(NullType, StringType, Some(StringType)) + widenTest(StringType, StringType, Some(StringType)) + Seq(ByteType, ShortType, IntegerType, LongType).foreach { dt => + widenTest(dt, StringType, Some(LongType)) + } + Seq(FloatType, DecimalType(20, 10), DoubleType).foreach { dt => + widenTest(dt, StringType, Some(DoubleType)) + } + + Seq(DateType, TimestampType, BinaryType, BooleanType).foreach { dt => + widenTest(dt, StringType, Some(dt)) + } + } + test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit = checkWidenType(AnsiTypeCoercion.findTightestCommonType, t1, t2, expected) @@ -408,7 +419,7 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { ruleTest(rule, Coalesce(Seq(timestampLit, stringLit)), - Coalesce(Seq(timestampLit, stringLit))) + Coalesce(Seq(timestampLit, Cast(stringLit, TimestampType)))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, intLit)), @@ -422,7 +433,8 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { // There is no a common type among Float/Double/String ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), - Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit))) + Coalesce(Seq(Cast(nullLit, DoubleType), Cast(floatNullLit, DoubleType), + doubleLit, Cast(stringLit, DoubleType)))) // There is no a common type among Timestamp/Int/String ruleTest(rule, @@ -451,8 +463,8 @@ 
class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { :: Literal("a") :: Nil), CreateArray(Literal(1.0) - :: Literal(1) - :: Literal("a") + :: Cast(Literal(1), DoubleType) + :: Cast(Literal("a"), DoubleType) :: Nil)) ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, @@ -506,7 +518,7 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { :: Literal(3.0) :: Nil), CreateMap(Literal(1) - :: Literal("a") + :: Cast(Literal("a"), DoubleType) :: Literal(2) :: Literal(3.0) :: Nil)) @@ -523,13 +535,13 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { :: Nil)) // type coercion for both map keys and values ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, - CreateMap(Literal(1) - :: Literal("a") + CreateMap(Cast(Literal(1), DoubleType) + :: Cast(Literal("a"), DoubleType) :: Literal(2.0) :: Literal(3.0) :: Nil), CreateMap(Cast(Literal(1), DoubleType) - :: Literal("a") + :: Cast(Literal("a"), DoubleType) :: Literal(2.0) :: Literal(3.0) :: Nil)) @@ -644,11 +656,11 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { ruleTest(rule, If(falseLit, stringLit, doubleLit), - If(falseLit, stringLit, doubleLit)) + If(falseLit, Cast(stringLit, DoubleType), doubleLit)) ruleTest(rule, If(trueLit, timestampLit, stringLit), - If(trueLit, timestampLit, stringLit)) + If(trueLit, timestampLit, Cast(stringLit, TimestampType))) } test("type coercion for CaseKeyWhen") { @@ -878,7 +890,7 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v")) assert(wp1.isInstanceOf[Project]) // The attribute `p1.output.head` should be replaced in the root `Project`. - assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty)) + assert(wp1.expressions.forall(!_.exists(_ == p1.output.head))) val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union)) assert(wp2.isInstanceOf[Aggregate]) assert(wp2.missingInput.isEmpty) @@ -901,7 +913,8 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { ) ruleTest(inConversion, In(Literal("a"), Seq(Literal(1), Literal("b"))), - In(Literal("a"), Seq(Literal(1), Literal("b"))) + In(Cast(Literal("a"), LongType), + Seq(Cast(Literal(1), LongType), Cast(Literal("b"), LongType))) ) } @@ -1024,55 +1037,6 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { IntegralDivide(Cast(2, LongType), 1L)) } - test("Promote string literals") { - val rule = AnsiTypeCoercion.PromoteStringLiterals - val stringLiteral = Literal("123") - val castStringLiteralAsInt = Cast(stringLiteral, IntegerType) - val castStringLiteralAsDouble = Cast(stringLiteral, DoubleType) - val castStringLiteralAsDate = Cast(stringLiteral, DateType) - val castStringLiteralAsTimestamp = Cast(stringLiteral, TimestampType) - ruleTest(rule, - GreaterThan(stringLiteral, Literal(1)), - GreaterThan(castStringLiteralAsInt, Literal(1))) - ruleTest(rule, - LessThan(Literal(true), stringLiteral), - LessThan(Literal(true), Cast(stringLiteral, BooleanType))) - ruleTest(rule, - EqualTo(Literal(Array(1, 2)), stringLiteral), - EqualTo(Literal(Array(1, 2)), stringLiteral)) - ruleTest(rule, - GreaterThan(stringLiteral, Literal(0.5)), - GreaterThan(castStringLiteralAsDouble, Literal(0.5))) - - val dateLiteral = Literal(java.sql.Date.valueOf("2021-01-01")) - ruleTest(rule, - EqualTo(stringLiteral, dateLiteral), - EqualTo(castStringLiteralAsDate, dateLiteral)) - - val timestampLiteral = Literal(Timestamp.valueOf("2021-01-01 00:00:00")) - ruleTest(rule, - EqualTo(stringLiteral, timestampLiteral), - 
EqualTo(castStringLiteralAsTimestamp, timestampLiteral)) - - ruleTest(rule, Add(stringLiteral, Literal(1)), - Add(castStringLiteralAsInt, Literal(1))) - ruleTest(rule, Divide(stringLiteral, Literal(1)), - Divide(castStringLiteralAsInt, Literal(1))) - - ruleTest(rule, - In(Literal(1), Seq(stringLiteral, Literal(2))), - In(Literal(1), Seq(castStringLiteralAsInt, Literal(2)))) - ruleTest(rule, - In(Literal(1.0), Seq(stringLiteral, Literal(2.2))), - In(Literal(1.0), Seq(castStringLiteralAsDouble, Literal(2.2)))) - ruleTest(rule, - In(dateLiteral, Seq(stringLiteral)), - In(dateLiteral, Seq(castStringLiteralAsDate))) - ruleTest(rule, - In(timestampLiteral, Seq(stringLiteral)), - In(timestampLiteral, Seq(castStringLiteralAsTimestamp))) - } - test("SPARK-35937: GetDateFieldOperations") { val ts = Literal(Timestamp.valueOf("2021-01-01 01:30:00")) Seq( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala index 41b22bc019014..ced83b31c7f04 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class CreateTablePartitioningValidationSuite extends AnalysisTest { test("CreateTableAsSelect: fail missing top-level column") { - val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, None, Map.empty, None, None, None, false) val plan = CreateTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), @@ -46,7 +46,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { } test("CreateTableAsSelect: fail missing top-level column nested reference") { - val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, None, Map.empty, None, None, None, false) val plan = CreateTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), @@ -63,7 +63,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { } test("CreateTableAsSelect: fail missing nested column") { - val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, None, Map.empty, None, None, None, false) val plan = CreateTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), @@ -80,7 +80,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { } test("CreateTableAsSelect: fail with multiple errors") { - val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, None, Map.empty, None, None, None, false) val plan = CreateTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), @@ -98,7 +98,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { } test("CreateTableAsSelect: success with top-level column") { - val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, None, Map.empty, None, None, None, false) val plan = CreateTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), @@ -112,7 +112,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { } test("CreateTableAsSelect: success using nested column") { - val tableSpec = 
TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, None, Map.empty, None, None, None, false) val plan = CreateTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), @@ -126,7 +126,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { } test("CreateTableAsSelect: success using complex column") { - val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, None, Map.empty, None, None, None, false) val plan = CreateTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 239d886303a02..da6b981fb4bf6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -23,10 +23,12 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -class ExpressionTypeCheckingSuite extends SparkFunSuite { +class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper { val testRelation = LocalRelation( Symbol("intField").int, @@ -103,8 +105,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(GreaterThanOrEqual(Symbol("intField"), Symbol("stringField"))) // We will transform EqualTo with numeric and boolean types to CaseKeyWhen - assertSuccess(EqualTo(Symbol("intField"), Symbol("booleanField"))) - assertSuccess(EqualNullSafe(Symbol("intField"), Symbol("booleanField"))) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + assertSuccess(EqualTo(Symbol("intField"), Symbol("booleanField"))) + assertSuccess(EqualNullSafe(Symbol("intField"), Symbol("booleanField"))) + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + assertError(EqualTo(Symbol("intField"), Symbol("booleanField")), "differing types") + assertError(EqualNullSafe(Symbol("intField"), Symbol("booleanField")), "differing types") + } assertErrorForDifferingTypes(EqualTo(Symbol("intField"), Symbol("mapField"))) assertErrorForDifferingTypes(EqualNullSafe(Symbol("intField"), Symbol("mapField"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 77dc5b4ccedc4..ab8bcee121232 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -299,11 +299,19 @@ class ResolveHintsSuite extends AnalysisTest { } } - test("SPARK-35786: Support optimize repartition by expression in AQE") { + test("SPARK-35786: Support optimize rebalance by expression in AQE") { checkAnalysisWithoutViewWrapper( UnresolvedHint("REBALANCE", Seq(UnresolvedAttribute("a")), table("TaBlE")), RebalancePartitions(Seq(AttributeReference("a", IntegerType)()), testRelation)) + 
checkAnalysisWithoutViewWrapper( + UnresolvedHint("REBALANCE", Seq(1, UnresolvedAttribute("a")), table("TaBlE")), + RebalancePartitions(Seq(AttributeReference("a", IntegerType)()), testRelation, Some(1))) + + checkAnalysisWithoutViewWrapper( + UnresolvedHint("REBALANCE", Seq(Literal(1), UnresolvedAttribute("a")), table("TaBlE")), + RebalancePartitions(Seq(AttributeReference("a", IntegerType)()), testRelation, Some(1))) + checkAnalysisWithoutViewWrapper( UnresolvedHint("REBALANCE", Seq.empty, table("TaBlE")), RebalancePartitions(Seq.empty, testRelation)) @@ -313,13 +321,42 @@ class ResolveHintsSuite extends AnalysisTest { UnresolvedHint("REBALANCE", Seq(UnresolvedAttribute("a")), table("TaBlE")), testRelation) + checkAnalysisWithoutViewWrapper( + UnresolvedHint("REBALANCE", Seq(1, UnresolvedAttribute("a")), table("TaBlE")), + testRelation) + + checkAnalysisWithoutViewWrapper( + UnresolvedHint("REBALANCE", Seq(Literal(1), UnresolvedAttribute("a")), table("TaBlE")), + testRelation) + checkAnalysisWithoutViewWrapper( UnresolvedHint("REBALANCE", Seq.empty, table("TaBlE")), testRelation) + + checkAnalysisWithoutViewWrapper( + UnresolvedHint("REBALANCE", 1 :: Nil, table("TaBlE")), + testRelation) } assertAnalysisError( - UnresolvedHint("REBALANCE", Seq(Literal(1)), table("TaBlE")), + UnresolvedHint("REBALANCE", Seq(Literal(1), Literal(1)), table("TaBlE")), Seq("Hint parameter should include columns")) + + assertAnalysisError( + UnresolvedHint("REBALANCE", Seq(1, Literal(1)), table("TaBlE")), + Seq("Hint parameter should include columns")) + } + + test("SPARK-38410: Support specify initial partition number for rebalance") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3") { + Seq( + Nil -> 3, + Seq(1) -> 1, + Seq(UnresolvedAttribute("a")) -> 3, + Seq(1, UnresolvedAttribute("a")) -> 1).foreach { case (param, initialNumPartitions) => + assert(UnresolvedHint("REBALANCE", param, testRelation).analyze + .asInstanceOf[RebalancePartitions].partitioning.numPartitions == initialNumPartitions) + } + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 1f3d1c4516778..782f3e41f42c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -190,6 +190,7 @@ abstract class TypeCoercionSuiteBase extends AnalysisTest { test("implicit type cast - DateType") { val checkedType = DateType checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType) ++ datetimeTypes) + shouldCast(checkedType, AnyTimestampType, AnyTimestampType.defaultConcreteType) shouldNotCast(checkedType, DecimalType) shouldNotCast(checkedType, NumericType) shouldNotCast(checkedType, IntegralType) @@ -198,6 +199,16 @@ abstract class TypeCoercionSuiteBase extends AnalysisTest { test("implicit type cast - TimestampType") { val checkedType = TimestampType checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType) ++ datetimeTypes) + shouldCast(checkedType, AnyTimestampType, checkedType) + shouldNotCast(checkedType, DecimalType) + shouldNotCast(checkedType, NumericType) + shouldNotCast(checkedType, IntegralType) + } + + test("implicit type cast - TimestampNTZType") { + val checkedType = TimestampNTZType + checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType) ++ datetimeTypes) + shouldCast(checkedType, 
AnyTimestampType, checkedType) shouldNotCast(checkedType, DecimalType) shouldNotCast(checkedType, NumericType) shouldNotCast(checkedType, IntegralType) @@ -476,6 +487,7 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase { checkTypeCasting(checkedType, castableTypes = allTypes.filterNot(nonCastableTypes.contains)) shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT) shouldCast(checkedType, NumericType, NumericType.defaultConcreteType) + shouldCast(checkedType, AnyTimestampType, AnyTimestampType.defaultConcreteType) shouldNotCast(checkedType, IntegralType) } @@ -1478,7 +1490,7 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase { val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v")) assert(wp1.isInstanceOf[Project]) // The attribute `p1.output.head` should be replaced in the root `Project`. - assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty)) + assert(wp1.expressions.forall(!_.exists(_ == p1.output.head))) val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union)) assert(wp2.isInstanceOf[Aggregate]) assert(wp2.missingInput.isEmpty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index d310538e302de..f791f778ecdc6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -481,6 +481,29 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(catalog.listPartitions("db2", "tbl1", Some(part2.spec)).map(_.spec) == Seq(part2.spec)) } + test("SPARK-38120: list partitions with special chars and mixed case column name") { + val catalog = newBasicCatalog() + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.EXTERNAL, + storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().toURI)), + schema = new StructType() + .add("col1", "int") + .add("col2", "string") + .add("partCol1", "int") + .add("partCol2", "string"), + provider = Some(defaultProvider), + partitionColumnNames = Seq("partCol1", "partCol2")) + catalog.createTable(table, ignoreIfExists = false) + + val part1 = CatalogTablePartition(Map("partCol1" -> "1", "partCol2" -> "i+j"), storageFormat) + val part2 = CatalogTablePartition(Map("partCol1" -> "1", "partCol2" -> "i.j"), storageFormat) + catalog.createPartitions("db1", "tbl", Seq(part1, part2), ignoreIfExists = false) + + assert(catalog.listPartitions("db1", "tbl", Some(part1.spec)).map(_.spec) == Seq(part1.spec)) + assert(catalog.listPartitions("db1", "tbl", Some(part2.spec)).map(_.spec) == Seq(part2.spec)) + } + test("list partitions by filter") { val tz = TimeZone.getDefault.getID val catalog = newBasicCatalog() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 1a427848fa11b..c6bddfa5eee1f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -31,12 +31,11 @@ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearM class ExamplePoint(val x: Double, val y: Double) extends Serializable { 
override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt override def equals(that: Any): Boolean = { - if (that.isInstanceOf[ExamplePoint]) { - val e = that.asInstanceOf[ExamplePoint] - (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) && - (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity)) - } else { - false + that match { + case e: ExamplePoint => + (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) && + (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity)) + case _ => false } } } @@ -436,4 +435,27 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { } } } + + test("SPARK-38437: encoding TimestampType/DateType from any supported datetime Java types") { + Seq(true, false).foreach { java8Api => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) { + val schema = new StructType() + .add("t0", TimestampType) + .add("t1", TimestampType) + .add("d0", DateType) + .add("d1", DateType) + val encoder = RowEncoder(schema, lenient = true).resolveAndBind() + val instant = java.time.Instant.parse("2019-02-26T16:56:00Z") + val ld = java.time.LocalDate.parse("2022-03-08") + val row = encoder.createSerializer().apply( + Row(instant, java.sql.Timestamp.from(instant), ld, java.sql.Date.valueOf(ld))) + val expectedMicros = DateTimeUtils.instantToMicros(instant) + assert(row.getLong(0) === expectedMicros) + assert(row.getLong(1) === expectedMicros) + val expectedDays = DateTimeUtils.localDateToDays(ld) + assert(row.getInt(2) === expectedDays) + assert(row.getInt(3) === expectedDays) + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala index 6338be1a2eb54..7fb04fe8b7f76 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala @@ -21,6 +21,7 @@ import java.sql.Timestamp import java.time.DateTimeException import org.apache.spark.SparkArithmeticException +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND import org.apache.spark.sql.catalyst.util.DateTimeTestUtils @@ -315,6 +316,28 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase { assert(ret.resolved) checkCastToBooleanError(array_notNull, to, Seq(null, true, false)) } + + { + val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) + assert(ret.resolved == !isTryCast) + if (!isTryCast) { + checkExceptionInExpression[UnsupportedOperationException]( + ret, "invalid input syntax for type boolean") + } + } + } + + test("cast from array III") { + if (!isTryCast) { + val from: DataType = ArrayType(DoubleType, containsNull = false) + val array = Literal.create(Seq(1.0, 2.0), from) + val to: DataType = ArrayType(IntegerType, containsNull = false) + val answer = Literal.create(Seq(1, 2), to).value + checkEvaluation(cast(array, to), answer) + + val overflowArray = Literal.create(Seq(Int.MaxValue + 1.0D), from) + checkExceptionInExpression[ArithmeticException](cast(overflowArray, to), "overflow") + } } test("cast from map II") { @@ -340,6 +363,49 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase { assert(ret.resolved) 
checkCastToBooleanError(map_notNull, to, Map("a" -> null, "b" -> true, "c" -> false)) } + + { + val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(ret.resolved == !isTryCast) + if (!isTryCast) { + checkExceptionInExpression[NumberFormatException]( + ret, "invalid input syntax for type numeric") + } + } + + { + val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) + assert(ret.resolved == !isTryCast) + if (!isTryCast) { + checkExceptionInExpression[UnsupportedOperationException]( + ret, "invalid input syntax for type boolean") + } + } + + { + val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(ret.resolved == !isTryCast) + if (!isTryCast) { + checkExceptionInExpression[NumberFormatException]( + ret, "invalid input syntax for type numeric") + } + } + } + + test("cast from map III") { + if (!isTryCast) { + val from: DataType = MapType(DoubleType, DoubleType, valueContainsNull = false) + val map = Literal.create(Map(1.0 -> 2.0), from) + val to: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false) + val answer = Literal.create(Map(1 -> 2), to).value + checkEvaluation(cast(map, to), answer) + + Seq( + Literal.create(Map((Int.MaxValue + 1.0) -> 2.0), from), + Literal.create(Map(1.0 -> (Int.MinValue - 1.0)), from)).foreach { overflowMap => + checkExceptionInExpression[ArithmeticException](cast(overflowMap, to), "overflow") + } + } } test("cast from struct II") { @@ -392,6 +458,62 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase { assert(ret.resolved) checkCastToBooleanError(struct_notNull, to, InternalRow(null, true, false)) } + + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = false)))) + assert(ret.resolved == !isTryCast) + if (!isTryCast) { + checkExceptionInExpression[UnsupportedOperationException]( + ret, "invalid input syntax for type boolean") + } + } + } + + test("cast from struct III") { + if (!isTryCast) { + val from: DataType = StructType(Seq(StructField("a", DoubleType, nullable = false))) + val struct = Literal.create(InternalRow(1.0), from) + val to: DataType = StructType(Seq(StructField("a", IntegerType, nullable = false))) + val answer = Literal.create(InternalRow(1), to).value + checkEvaluation(cast(struct, to), answer) + + val overflowStruct = Literal.create(InternalRow(Int.MaxValue + 1.0), from) + checkExceptionInExpression[ArithmeticException](cast(overflowStruct, to), "overflow") + } + } + + test("complex casting") { + val complex = Literal.create( + Row( + Seq("123", "true", "f"), + Map("a" -> "123", "b" -> "true", "c" -> "f"), + Row(0)), + StructType(Seq( + StructField("a", + ArrayType(StringType, containsNull = false), nullable = true), + StructField("m", + MapType(StringType, StringType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("i", IntegerType, nullable = true))))))) + + val ret = cast(complex, StructType(Seq( + StructField("a", + ArrayType(IntegerType, containsNull = true), nullable = true), + StructField("m", + MapType(StringType, BooleanType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("l", LongType, nullable = true))))))) + + assert(ret.resolved === !isTryCast) + if (!isTryCast) { + checkExceptionInExpression[NumberFormatException]( + ret, "invalid input syntax for type 
numeric") + } } test("ANSI mode: cast string to timestamp with parse error") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 31d7a4b0a87e0..522313ffeb184 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -78,10 +78,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(input), convert(-1)) checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null) } - checkEvaluation(UnaryMinus(Literal(Long.MinValue)), Long.MinValue) - checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue) - checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue) - checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(UnaryMinus(Literal(Long.MinValue)), Long.MinValue) + checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue) + checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue) + checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue) + } withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { checkExceptionInExpression[ArithmeticException]( UnaryMinus(Literal(Long.MinValue)), "overflow") @@ -170,7 +172,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(left, right), convert(2)) checkEvaluation(Divide(Literal.create(null, left.dataType), right), null) checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) - checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkExceptionInExpression[ArithmeticException]( + Divide(left, Literal(convert(0))), "divide by zero") + } } Seq("true", "false").foreach { failOnError => @@ -194,7 +202,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(IntegralDivide(left, right), 0L) checkEvaluation(IntegralDivide(Literal.create(null, left.dataType), right), null) checkEvaluation(IntegralDivide(left, Literal.create(null, right.dataType)), null) - checkEvaluation(IntegralDivide(left, Literal(convert(0))), null) // divide by zero + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(IntegralDivide(left, Literal(convert(0))), null) // divide by zero + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(left, Literal(convert(0))), "divide by zero") + } } checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L) @@ -222,7 +236,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Remainder(left, right), convert(1)) checkEvaluation(Remainder(Literal.create(null, left.dataType), right), null) checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null) - checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 + } + 
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkExceptionInExpression[ArithmeticException]( + Remainder(left, Literal(convert(0))), "divide by zero") + } } checkEvaluation(Remainder(positiveShortLit, positiveShortLit), 0.toShort) checkEvaluation(Remainder(negativeShortLit, negativeShortLit), 0.toShort) @@ -304,7 +324,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Pmod(left, right), convert(1)) checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null) checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null) - checkEvaluation(Pmod(left, Literal(convert(0))), null) // mod by 0 + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(Pmod(left, Literal(convert(0))), null) // mod by 0 + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkExceptionInExpression[ArithmeticException]( + Pmod(left, Literal(convert(0))), "divide by zero") + } } checkEvaluation(Pmod(Literal(-7), Literal(3)), 2) checkEvaluation(Pmod(Literal(7.2D), Literal(4.1D)), 3.1000000000000005) @@ -461,15 +487,24 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(IntegralDivide(Literal(Decimal(1)), Literal(Decimal(2))), 0L) checkEvaluation(IntegralDivide(Literal(Decimal(2.4)), Literal(Decimal(1.1))), 2L) checkEvaluation(IntegralDivide(Literal(Decimal(1.2)), Literal(Decimal(1.1))), 1L) - checkEvaluation(IntegralDivide(Literal(Decimal(0.2)), Literal(Decimal(0.0))), null) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation( + IntegralDivide(Literal(Decimal(0.2)), Literal(Decimal(0.0))), null) // mod by 0 + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(Literal(Decimal(0.2)), Literal(Decimal(0.0))), "divide by zero") + } // overflows long and so returns a wrong result checkEvaluation(DecimalPrecision.decimalAndDecimal.apply(IntegralDivide( Literal(Decimal("99999999999999999999999999999999999")), Literal(Decimal(0.001)))), 687399551400672280L) // overflow during promote precision - checkEvaluation(DecimalPrecision.decimalAndDecimal.apply(IntegralDivide( - Literal(Decimal("99999999999999999999999999999999999999")), Literal(Decimal(0.00001)))), - null) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(DecimalPrecision.decimalAndDecimal.apply(IntegralDivide( + Literal(Decimal("99999999999999999999999999999999999999")), Literal(Decimal(0.00001)))), + null) + } } test("SPARK-24598: overflow on long returns wrong result") { @@ -701,13 +736,25 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } test("SPARK-36921: Support YearMonthIntervalType by div") { - checkEvaluation(IntegralDivide(Literal(Period.ZERO), Literal(Period.ZERO)), null) - checkEvaluation(IntegralDivide(Literal(Period.ofYears(1)), - Literal(Period.ZERO)), null) - checkEvaluation(IntegralDivide(Period.ofMonths(Int.MinValue), - Literal(Period.ZERO)), null) - checkEvaluation(IntegralDivide(Period.ofMonths(Int.MaxValue), - Literal(Period.ZERO)), null) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(IntegralDivide(Literal(Period.ZERO), Literal(Period.ZERO)), null) + checkEvaluation(IntegralDivide(Literal(Period.ofYears(1)), + Literal(Period.ZERO)), null) + checkEvaluation(IntegralDivide(Period.ofMonths(Int.MinValue), + Literal(Period.ZERO)), null) + checkEvaluation(IntegralDivide(Period.ofMonths(Int.MaxValue), + Literal(Period.ZERO)), null) + } + 
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(Literal(Period.ZERO), Literal(Period.ZERO)), "divide by zero") + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(Literal(Period.ofYears(1)), Literal(Period.ZERO)), "divide by zero") + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(Period.ofMonths(Int.MinValue), Literal(Period.ZERO)), "divide by zero") + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(Period.ofMonths(Int.MaxValue), Literal(Period.ZERO)), "divide by zero") + } checkEvaluation(IntegralDivide(Literal.create(null, YearMonthIntervalType()), Literal.create(null, YearMonthIntervalType())), null) @@ -741,13 +788,28 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Period.ofMonths(-5))), -2L) } test("SPARK-36921: Support DayTimeIntervalType by div") { - checkEvaluation(IntegralDivide(Literal(Duration.ZERO), Literal(Duration.ZERO)), null) - checkEvaluation(IntegralDivide(Literal(Duration.ofDays(1)), - Literal(Duration.ZERO)), null) - checkEvaluation(IntegralDivide(Literal(Duration.of(Long.MaxValue, ChronoUnit.MICROS)), - Literal(Duration.ZERO)), null) - checkEvaluation(IntegralDivide(Literal(Duration.of(Long.MinValue, ChronoUnit.MICROS)), - Literal(Duration.ZERO)), null) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(IntegralDivide(Literal(Duration.ZERO), Literal(Duration.ZERO)), null) + checkEvaluation(IntegralDivide(Literal(Duration.ofDays(1)), + Literal(Duration.ZERO)), null) + checkEvaluation(IntegralDivide(Literal(Duration.of(Long.MaxValue, ChronoUnit.MICROS)), + Literal(Duration.ZERO)), null) + checkEvaluation(IntegralDivide(Literal(Duration.of(Long.MinValue, ChronoUnit.MICROS)), + Literal(Duration.ZERO)), null) + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(Literal(Duration.ZERO), Literal(Duration.ZERO)), "divide by zero") + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(Literal(Duration.ofDays(1)), + Literal(Duration.ZERO)), "divide by zero") + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(Literal(Duration.of(Long.MaxValue, ChronoUnit.MICROS)), + Literal(Duration.ZERO)), "divide by zero") + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(Literal(Duration.of(Long.MinValue, ChronoUnit.MICROS)), + Literal(Duration.ZERO)), "divide by zero") + } checkEvaluation(IntegralDivide(Literal.create(null, DayTimeIntervalType()), Literal.create(null, DayTimeIntervalType())), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 1805189b268db..83307c9022dd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} class CanonicalizeSuite extends SparkFunSuite { @@ -177,4 +177,17 @@ 
class CanonicalizeSuite extends SparkFunSuite { assert(expr.semanticEquals(attr)) assert(attr.semanticEquals(expr)) } + + test("SPARK-38030: Canonicalization should not remove nullability of AttributeReference" + + " dataType") { + val structType = StructType(Seq(StructField("name", StringType, nullable = false))) + val attr = AttributeReference("col", structType)() + // AttributeReference dataType should not be converted to nullable + assert(attr.canonicalized.dataType === structType) + + val cast = Cast(attr, structType) + assert(cast.resolved) + // canonicalization should not converted resolved cast to unresolved + assert(cast.canonicalized.resolved) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index ba36fa0314cb8..ca110502c6b3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -39,6 +39,15 @@ import org.apache.spark.unsafe.types.UTF8String * in `CastSuiteBase` instead of this file to ensure the test coverage. */ class CastSuite extends CastSuiteBase { + override def beforeAll(): Unit = { + super.beforeAll() + SQLConf.get.setConf(SQLConf.ANSI_ENABLED, false) + } + + override def afterAll(): Unit = { + super.afterAll() + SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED) + } override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = { v match { @@ -233,6 +242,11 @@ class CastSuite extends CastSuiteBase { assert(ret.resolved) checkEvaluation(ret, Seq(null, true, false)) } + + { + val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) + assert(ret.resolved === false) + } } test("cast from map II") { @@ -254,6 +268,21 @@ class CastSuite extends CastSuiteBase { assert(ret.resolved) checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } + + { + val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(ret.resolved === false) + } + + { + val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) + assert(ret.resolved === false) + } + + { + val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(ret.resolved === false) + } } test("cast from struct II") { @@ -304,6 +333,41 @@ class CastSuite extends CastSuiteBase { assert(ret.resolved) checkEvaluation(ret, InternalRow(null, true, false)) } + + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = false)))) + assert(ret.resolved === false) + } + } + + test("complex casting") { + val complex = Literal.create( + Row( + Seq("123", "true", "f"), + Map("a" -> "123", "b" -> "true", "c" -> "f"), + Row(0)), + StructType(Seq( + StructField("a", + ArrayType(StringType, containsNull = false), nullable = true), + StructField("m", + MapType(StringType, StringType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("i", IntegerType, nullable = true))))))) + + val ret = cast(complex, StructType(Seq( + StructField("a", + ArrayType(IntegerType, containsNull = true), nullable = true), + StructField("m", + MapType(StringType, BooleanType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + 
StructField("l", LongType, nullable = true))))))) + + assert(ret.resolved === false) } test("SPARK-31227: Non-nullable null type should not coerce to nullable type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 54497f1b21edb..ba8ab708046d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -427,11 +427,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === false) } - { - val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) - assert(ret.resolved === false) - } - { val ret = cast(array, IntegerType) assert(ret.resolved === false) @@ -452,18 +447,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) assert(ret.resolved === false) } - { - val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) - assert(ret.resolved === false) - } - { - val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(ret.resolved === false) - } - { - val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) - assert(ret.resolved === false) - } { val ret = cast(map, IntegerType) @@ -510,14 +493,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === false) } - { - val ret = cast(struct_notNull, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = false)))) - assert(ret.resolved === false) - } - { val ret = cast(struct, StructType(Seq( StructField("a", StringType, nullable = true), @@ -541,33 +516,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(inp, targetSchema), expected) } - test("complex casting") { - val complex = Literal.create( - Row( - Seq("123", "true", "f"), - Map("a" -> "123", "b" -> "true", "c" -> "f"), - Row(0)), - StructType(Seq( - StructField("a", - ArrayType(StringType, containsNull = false), nullable = true), - StructField("m", - MapType(StringType, StringType, valueContainsNull = false), nullable = true), - StructField("s", - StructType(Seq( - StructField("i", IntegerType, nullable = true))))))) - - val ret = cast(complex, StructType(Seq( - StructField("a", - ArrayType(IntegerType, containsNull = true), nullable = true), - StructField("m", - MapType(StringType, BooleanType, valueContainsNull = false), nullable = true), - StructField("s", - StructType(Seq( - StructField("l", LongType, nullable = true))))))) - - assert(ret.resolved === false) - } - test("cast between string and interval") { import org.apache.spark.unsafe.types.CalendarInterval diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 2b59d723ab66b..1e4499a0ee3fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -330,7 +330,9 @@ class 
CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) GenerateUnsafeProjection.generate( ValidateExternalType( - GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"), IntegerType) :: Nil) + GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"), + IntegerType, + lenient = false) :: Nil) } test("SPARK-17160: field names are properly escaped by AssertTrue") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7dfa3ea6f5a15..3cf3b4469a4d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -66,7 +66,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("Array and Map Size - legacy") { - withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + withSQLConf( + SQLConf.LEGACY_SIZE_OF_NULL.key -> "true", + SQLConf.ANSI_ENABLED.key -> "false") { testSize(sizeOfNull = -1) } } @@ -1437,8 +1439,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(a0, Literal(0)), null) }.getMessage.contains("SQL array indices start at 1") intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) } - checkEvaluation(ElementAt(a0, Literal(4)), null) - checkEvaluation(ElementAt(a0, Literal(-4)), null) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(ElementAt(a0, Literal(4)), null) + checkEvaluation(ElementAt(a0, Literal(-4)), null) + } checkEvaluation(ElementAt(a0, Literal(1)), 1) checkEvaluation(ElementAt(a0, Literal(2)), 2) @@ -1464,9 +1468,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure) - checkEvaluation(ElementAt(m0, Literal("d")), null) - - checkEvaluation(ElementAt(m1, Literal("a")), null) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(ElementAt(m0, Literal("d")), null) + checkEvaluation(ElementAt(m1, Literal("a")), null) + } checkEvaluation(ElementAt(m0, Literal("a")), "1") checkEvaluation(ElementAt(m0, Literal("b")), "2") @@ -1480,9 +1485,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper MapType(BinaryType, StringType)) val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) - checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) - - checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) + checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) + } checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index 7945974a1f3dc..1d174ed214523 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -158,13 +158,13 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P } test("infer schema of CSV strings") { - checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "STRUCT<`_c0`: INT, `_c1`: STRING>") + checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "STRUCT<_c0: INT, _c1: STRING>") } test("infer schema of CSV strings by using options") { checkEvaluation( new SchemaOfCsv(Literal.create("1|abc"), Map("delimiter" -> "|")), - "STRUCT<`_c0`: INT, `_c1`: STRING>") + "STRUCT<_c0: INT, _c1: STRING>") } test("to_csv - struct") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index d0c0a1948b442..c5d559c4501af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -1462,7 +1462,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal("yyyy-MM-dd'T'HH:mm:ss.SSSz"), TimestampType), 1580184371847000L) } - withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> "corrected") { + withSQLConf( + SQLConf.LEGACY_TIME_PARSER_POLICY.key -> "corrected", + SQLConf.ANSI_ENABLED.key -> "false") { checkEvaluation( GetTimestamp( Literal("2020-01-27T20:06:11.847-0800"), @@ -1481,8 +1483,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Consistent error handling for datetime formatting and parsing functions") { def checkException[T <: Exception : ClassTag](c: String): Unit = { - checkExceptionInExpression[T](new ParseToTimestamp(Literal("1"), Literal(c)).child, c) - checkExceptionInExpression[T](new ParseToDate(Literal("1"), Literal(c)).child, c) + checkExceptionInExpression[T](new ParseToTimestamp(Literal("1"), Literal(c)).replacement, c) + checkExceptionInExpression[T](new ParseToDate(Literal("1"), Literal(c)).replacement, c) checkExceptionInExpression[T](ToUnixTimestamp(Literal("1"), Literal(c)), c) checkExceptionInExpression[T](UnixTimestamp(Literal("1"), Literal(c)), c) if (!Set("E", "F", "q", "Q").contains(c)) { @@ -1502,10 +1504,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-31896: Handle am-pm timestamp parsing when hour is missing") { checkEvaluation( - new ParseToTimestamp(Literal("PM"), Literal("a")).child, + new ParseToTimestamp(Literal("PM"), Literal("a")).replacement, Timestamp.valueOf("1970-01-01 12:00:00.0")) checkEvaluation( - new ParseToTimestamp(Literal("11:11 PM"), Literal("mm:ss a")).child, + new ParseToTimestamp(Literal("11:11 PM"), Literal("mm:ss a")).replacement, Timestamp.valueOf("1970-01-01 12:11:11.0")) } @@ -1735,7 +1737,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exprSeq2.foreach(pair => checkExceptionInExpression[SparkUpgradeException]( pair._1, - "You may get a different result due to the upgrading of Spark 3.0")) + "You may get a different result due to the upgrading to Spark >= 3.0")) } else { if (ansiEnabled) { exprSeq2.foreach(pair => @@ -1885,4 +1887,117 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + + test("SPARK-38195: add a quantity of interval units to a timestamp") { + // Check case-insensitivity + checkEvaluation( + 
TimestampAdd("Hour", Literal(1), Literal(LocalDateTime.of(2022, 2, 15, 12, 57, 0))), + LocalDateTime.of(2022, 2, 15, 13, 57, 0)) + // Check nulls as input values + checkEvaluation( + TimestampAdd( + "MINUTE", + Literal.create(null, IntegerType), + Literal(LocalDateTime.of(2022, 2, 15, 12, 57, 0))), + null) + checkEvaluation( + TimestampAdd( + "MINUTE", + Literal(1), + Literal.create(null, TimestampType)), + null) + // Check crossing the daylight saving time + checkEvaluation( + TimestampAdd( + "HOUR", + Literal(6), + Literal(Instant.parse("2022-03-12T23:30:00Z")), + Some("America/Los_Angeles")), + Instant.parse("2022-03-13T05:30:00Z")) + // Check the leap year + checkEvaluation( + TimestampAdd( + "DAY", + Literal(2), + Literal(LocalDateTime.of(2020, 2, 28, 10, 11, 12)), + Some("America/Los_Angeles")), + LocalDateTime.of(2020, 3, 1, 10, 11, 12)) + + Seq( + "YEAR", "QUARTER", "MONTH", + "WEEK", "DAY", + "HOUR", "MINUTE", "SECOND", + "MILLISECOND", "MICROSECOND" + ).foreach { unit => + outstandingTimezonesIds.foreach { tz => + Seq(TimestampNTZType, TimestampType).foreach { tsType => + checkConsistencyBetweenInterpretedAndCodegenAllowingException( + (quantity: Expression, timestamp: Expression) => + TimestampAdd( + unit, + quantity, + timestamp, + Some(tz)), + IntegerType, tsType) + } + } + } + } + + test("SPARK-38284: difference between two timestamps in units") { + // Check case-insensitivity + checkEvaluation( + TimestampDiff( + "Hour", + Literal(Instant.parse("2022-02-15T12:57:00Z")), + Literal(Instant.parse("2022-02-15T13:57:00Z"))), + 1L) + // Check nulls as input values + checkEvaluation( + TimestampDiff( + "MINUTE", + Literal.create(null, TimestampType), + Literal(Instant.parse("2022-02-15T12:57:00Z"))), + null) + checkEvaluation( + TimestampDiff( + "MINUTE", + Literal(Instant.parse("2021-02-15T12:57:00Z")), + Literal.create(null, TimestampType)), + null) + // Check crossing the daylight saving time + checkEvaluation( + TimestampDiff( + "HOUR", + Literal(Instant.parse("2022-03-12T23:30:00Z")), + Literal(Instant.parse("2022-03-13T05:30:00Z")), + Some("America/Los_Angeles")), + 6L) + // Check the leap year + checkEvaluation( + TimestampDiff( + "DAY", + Literal(Instant.parse("2020-02-28T10:11:12Z")), + Literal(Instant.parse("2020-03-01T10:21:12Z")), + Some("America/Los_Angeles")), + 2L) + + Seq( + "YEAR", "QUARTER", "MONTH", + "WEEK", "DAY", + "HOUR", "MINUTE", "SECOND", + "MILLISECOND", "MICROSECOND" + ).foreach { unit => + outstandingTimezonesIds.foreach { tz => + checkConsistencyBetweenInterpretedAndCodegenAllowingException( + (startTs: Expression, endTs: Expression) => + TimestampDiff( + unit, + startTs, + endTs, + Some(tz)), + TimestampType, TimestampType) + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 2ae7c76599e5a..af071727b10dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -736,17 +736,17 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-24709: infer schema of json strings") { checkEvaluation(new SchemaOfJson(Literal.create("""{"col":0}""")), - "STRUCT<`col`: BIGINT>") + "STRUCT") checkEvaluation( new SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")), - 
"STRUCT<`col0`: ARRAY, `col1`: STRUCT<`col2`: STRING>>") + "STRUCT, col1: STRUCT>") } test("infer schema of JSON strings by using options") { checkEvaluation( new SchemaOfJson(Literal.create("""{"col":01}"""), CreateMap(Seq(Literal.create("allowNumericLeadingZeros"), Literal.create("true")))), - "STRUCT<`col`: BIGINT>") + "STRUCT") } test("parse date with locale") { @@ -811,7 +811,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with } Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach { - checkDecimalInfer(_, """STRUCT<`d`: DECIMAL(7,3)>""") + checkDecimalInfer(_, """STRUCT""") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 4081e138d2b62..6ce51f1eec8ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType._ import org.apache.spark.sql.types.YearMonthIntervalType._ -import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -231,17 +231,6 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkStructLiteral((Period.ZERO, ("abc", Duration.ofDays(1)))) } - test("unsupported types (map and struct) in Literal.apply") { - def checkUnsupportedTypeInLiteral(v: Any): Unit = { - val errMsgMap = intercept[RuntimeException] { - Literal(v) - } - assert(errMsgMap.getMessage.startsWith("Unsupported literal type")) - } - checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) - checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) - } - test("SPARK-24571: char literals") { checkEvaluation(Literal('X'), "X") checkEvaluation(Literal.create('0'), "0") @@ -465,4 +454,10 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(duration, dt), result) } } + + test("SPARK-37967: Literal.create support ObjectType") { + checkEvaluation( + Literal.create(UTF8String.fromString("Spark SQL"), ObjectType(classOf[UTF8String])), + UTF8String.fromString("Spark SQL")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index ea0d619ad4c15..5281643b7b107 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -321,11 +321,21 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) } - def checkDataTypeAndCast(expression: UnaryMathExpression): Expression = { + def checkDataTypeAndCast(expression: Expression): Expression = expression match { + case e: UnaryMathExpression => checkDataTypeAndCastUnaryMathExpression(e) + case e: RoundBase => checkDataTypeAndCastRoundBase(e) + } + + def checkDataTypeAndCastUnaryMathExpression(expression: UnaryMathExpression): Expression = { val expNew = 
implicitCast(expression.child, expression.inputTypes(0)).getOrElse(expression) expression.withNewChildren(Seq(expNew)) } + def checkDataTypeAndCastRoundBase(expression: RoundBase): Expression = { + val expNewLeft = implicitCast(expression.left, expression.inputTypes(0)).getOrElse(expression) + expression.withNewChildren(Seq(expNewLeft, expression.right)) + } + test("ceil") { testUnary(Ceil, (d: Double) => math.ceil(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) @@ -630,7 +640,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType) } - test("round/bround") { + test("round/bround/floor/ceil") { val scales = -6 to 6 val doublePi: Double = math.Pi val shortPi: Short = 31415 @@ -658,6 +668,66 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val intResultsB: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, 314159260) ++ Seq.fill(7)(314159265) + def doubleResultsFloor(i: Int): Decimal = { + val results = Seq(0, 0, 0, 0, 0, 0, 3, + 3.1, 3.14, 3.141, 3.1415, 3.14159, 3.141592) + Decimal(results(i)) + } + + def doubleResultsCeil(i: Int): Any = { + val results = Seq(1000000, 100000, 10000, 1000, 100, 10, + 4, 3.2, 3.15, 3.142, 3.1416, 3.1416, 3.141593) + Decimal(results(i)) + } + + def floatResultsFloor(i: Int): Any = { + val results = Seq(0, 0, 0, 0, 0, 0, 3, + 3.1, 3.14, 3.141, 3.1415, 3.1415, 3.1415) + Decimal(results(i)) + } + + def floatResultsCeil(i: Int): Any = { + val results = Seq(1000000, 100000, 10000, 1000, 100, 10, 4, + 3.2, 3.15, 3.142, 3.1415, 3.1415, 3.1415) + Decimal(results(i)) + } + + def shortResultsFloor(i: Int): Decimal = { + val results = Seq(0, 0, 30000, 31000, 31400, 31410) ++ Seq.fill(7)(31415) + Decimal(results(i)) + } + + def shortResultsCeil(i: Int): Decimal = { + val results = Seq(1000000, 100000, 40000, 32000, 31500, 31420) ++ Seq.fill(7)(31415) + Decimal(results(i)) + } + + def longResultsFloor(i: Int): Decimal = { + val results = Seq(31415926535000000L, 31415926535800000L, 31415926535890000L, + 31415926535897000L, 31415926535897900L, 31415926535897930L, 31415926535897932L) ++ + Seq.fill(6)(31415926535897932L) + Decimal(results(i)) + } + + def longResultsCeil(i: Int): Decimal = { + val results = Seq(31415926536000000L, 31415926535900000L, 31415926535900000L, + 31415926535898000L, 31415926535898000L, 31415926535897940L) ++ + Seq.fill(7)(31415926535897932L) + Decimal(results(i)) + } + + def intResultsFloor(i: Int): Decimal = { + val results = Seq(314000000, 314100000, 314150000, 314159000, + 314159200, 314159260) ++ Seq.fill(7)(314159265) + Decimal(results(i)) + } + + def intResultsCeil(i: Int): Decimal = { + val results = Seq(315000000, 314200000, 314160000, 314160000, + 314159300, 314159270) ++ Seq.fill(7)(314159265) + Decimal(results(i)) + } + scales.zipWithIndex.foreach { case (scale, i) => checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) @@ -669,19 +739,52 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow) checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow) checkEvaluation(BRound(floatPi, scale), floatResults(i), EmptyRow) + checkEvaluation(checkDataTypeAndCast( + RoundFloor(Literal(doublePi), Literal(scale))), doubleResultsFloor(i), EmptyRow) + checkEvaluation(checkDataTypeAndCast( + 
RoundFloor(Literal(shortPi), Literal(scale))), shortResultsFloor(i), EmptyRow) + checkEvaluation(checkDataTypeAndCast( + RoundFloor(Literal(intPi), Literal(scale))), intResultsFloor(i), EmptyRow) + checkEvaluation(checkDataTypeAndCast( + RoundFloor(Literal(longPi), Literal(scale))), longResultsFloor(i), EmptyRow) + checkEvaluation(checkDataTypeAndCast( + RoundFloor(Literal(floatPi), Literal(scale))), floatResultsFloor(i), EmptyRow) + checkEvaluation(checkDataTypeAndCast( + RoundCeil(Literal(doublePi), Literal(scale))), doubleResultsCeil(i), EmptyRow) + checkEvaluation(checkDataTypeAndCast( + RoundCeil(Literal(shortPi), Literal(scale))), shortResultsCeil(i), EmptyRow) + checkEvaluation(checkDataTypeAndCast( + RoundCeil(Literal(intPi), Literal(scale))), intResultsCeil(i), EmptyRow) + checkEvaluation(checkDataTypeAndCast( + RoundCeil(Literal(longPi), Literal(scale))), longResultsCeil(i), EmptyRow) + checkEvaluation(checkDataTypeAndCast( + RoundCeil(Literal(floatPi), Literal(scale))), floatResultsCeil(i), EmptyRow) } val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"), BigDecimal("3.1416"), BigDecimal("3.14159"), BigDecimal("3.141593"), BigDecimal("3.1415927")) + val bdResultsFloor: Seq[BigDecimal] = + Seq(BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), + BigDecimal("3.141"), BigDecimal("3.1415"), BigDecimal("3.14159"), + BigDecimal("3.141592"), BigDecimal("3.1415927")) + + val bdResultsCeil: Seq[BigDecimal] = Seq(BigDecimal(4), BigDecimal("3.2"), BigDecimal("3.15"), + BigDecimal("3.142"), BigDecimal("3.1416"), BigDecimal("3.14160"), + BigDecimal("3.141593"), BigDecimal("3.1415927")) + (0 to 7).foreach { i => checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow) + checkEvaluation(RoundFloor(bdPi, i), bdResultsFloor(i), EmptyRow) + checkEvaluation(RoundCeil(bdPi, i), bdResultsCeil(i), EmptyRow) } (8 to 10).foreach { scale => checkEvaluation(Round(bdPi, scale), bdPi, EmptyRow) checkEvaluation(BRound(bdPi, scale), bdPi, EmptyRow) + checkEvaluation(RoundFloor(bdPi, scale), bdPi, EmptyRow) + checkEvaluation(RoundCeil(bdPi, scale), bdPi, EmptyRow) } DataTypeTestUtils.numericTypes.foreach { dataType => @@ -691,6 +794,10 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BRound(Literal.create(null, dataType), Literal(2)), null) checkEvaluation(BRound(Literal.create(null, dataType), Literal.create(null, IntegerType)), null) + checkEvaluation(checkDataTypeAndCast( + RoundFloor(Literal.create(null, dataType), Literal(2))), null) + checkEvaluation(checkDataTypeAndCast( + RoundCeil(Literal.create(null, dataType), Literal(2))), null) } checkEvaluation(Round(2.5, 0), 3.0) @@ -705,6 +812,26 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BRound(-3.5, 0), -4.0) checkEvaluation(BRound(-0.35, 1), -0.4) checkEvaluation(BRound(-35, -1), -40) + checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(2.5), Literal(0))), Decimal(2)) + checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(3.5), Literal(0))), Decimal(3)) + checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(-2.5), Literal(0))), Decimal(-3L)) + checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(-3.5), Literal(0))), Decimal(-4L)) + checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(-0.35), Literal(1))), Decimal(-0.4)) + checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(-35), Literal(-1))), Decimal(-40)) + 
checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(-0.1), Literal(0))), Decimal(-1)) + checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(5), Literal(0))), Decimal(5)) + checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(3.1411), Literal(-3))), Decimal(0)) + checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(135.135), Literal(-2))), Decimal(100)) + checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(2.5), Literal(0))), Decimal(3)) + checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(3.5), Literal(0))), Decimal(4L)) + checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(-2.5), Literal(0))), Decimal(-2L)) + checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(-3.5), Literal(0))), Decimal(-3L)) + checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(-0.35), Literal(1))), Decimal(-0.3)) + checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(-35), Literal(-1))), Decimal(-30)) + checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(-0.1), Literal(0))), Decimal(0)) + checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(5), Literal(0))), Decimal(5)) + checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(3.1411), Literal(-3))), Decimal(1000)) + checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(135.135), Literal(-2))), Decimal(200)) } test("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 7abea96915d2f..da8e11c0433eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -112,25 +113,31 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType) val decimalLit = Literal.create(BigDecimal.valueOf(10.2), DecimalType(20, 2)) - assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType) + } assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType) assert(analyze(new Nvl(decimalLit, doubleLit)).dataType == DoubleType) assert(analyze(new Nvl(decimalLit, floatLit)).dataType == DoubleType) assert(analyze(new Nvl(floatLit, decimalLit)).dataType == DoubleType) - assert(analyze(new Nvl(timestampLit, stringLit)).dataType == StringType) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + assert(analyze(new Nvl(timestampLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, stringLit)).dataType == StringType) + } assert(analyze(new Nvl(intLit, doubleLit)).dataType == DoubleType) - assert(analyze(new Nvl(intLit, 
stringLit)).dataType == StringType) - assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType) - assert(analyze(new Nvl(doubleLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType) assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType) assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType) - assert(analyze(new Nvl(floatLit, stringLit)).dataType == StringType) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + assert(analyze(new Nvl(floatLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(floatNullLit, intLit)).dataType == FloatType) + } assert(analyze(new Nvl(floatLit, doubleLit)).dataType == DoubleType) - assert(analyze(new Nvl(floatNullLit, intLit)).dataType == FloatType) } test("AtLeastNNonNulls") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 8d98965f2be81..585191faf18bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -498,13 +498,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { (Array(3, 2, 1), ArrayType(IntegerType)) ).foreach { case (input, dt) => val validateType = ValidateExternalType( - GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt) + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), + dt, + lenient = false) checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input)))) } checkExceptionInExpression[RuntimeException]( ValidateExternalType( - GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType), + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), + DoubleType, + lenient = false), InternalRow.fromSeq(Seq(Row(1))), "java.lang.Integer is not a valid external type for schema of double") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala index c67a9622b61fd..b64bc49f95446 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala @@ -31,7 +31,7 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper { // `derivedFromAtt` doesn't affect the result of pruned schema. 
SchemaPruning.RootField(field = f, derivedFromAtt = true) } - val prunedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields) + val prunedSchema = SchemaPruning.pruneSchema(schema, requestedRootFields) assert(prunedSchema === expectedSchema) } @@ -140,7 +140,7 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper { assert(field.metadata.getString("foo") == "bar") val schema = StructType(Seq(field)) - val prunedSchema = SchemaPruning.pruneDataSchema(schema, rootFields) + val prunedSchema = SchemaPruning.pruneSchema(schema, rootFields) assert(prunedSchema.head.metadata.getString("foo") == "bar") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 443a94b2ee08c..b05142add0bab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.internal.SQLConf @@ -119,9 +120,12 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testElt(null, null, "hello", "world") // Invalid ranges - testElt(null, 3, "hello", "world") - testElt(null, 0, "hello", "world") - testElt(null, -1, "hello", "world") + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + // ANSI will throw SparkArrayIndexOutOfBoundsException with invalid index + testElt(null, 3, "hello", "world") + testElt(null, 0, "hello", "world") + testElt(null, -1, "hello", "world") + } // type checking assert(Elt(Seq.empty).checkInputDataTypes().isFailure) @@ -888,6 +892,172 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } + test("ToNumber") { + ToNumber(Literal("454"), Literal("")).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("Number format cannot be empty")) + } + ToNumber(Literal("454"), NonFoldableLiteral.create("999", StringType)) + .checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("Format expression must be foldable")) + } + + // Test '0' and '9' + + Seq("454", "054", "54", "450").foreach { input => + val invalidFormat1 = 0.until(input.length - 1).map(_ => '0').mkString + val invalidFormat2 = 0.until(input.length - 2).map(_ => '0').mkString + val invalidFormat3 = 0.until(input.length - 1).map(_ => '9').mkString + val invalidFormat4 = 0.until(input.length - 2).map(_ => '9').mkString + Seq(invalidFormat1, invalidFormat2, invalidFormat3, invalidFormat4) + .filter(_.nonEmpty).foreach { format => + checkExceptionInExpression[IllegalArgumentException]( + ToNumber(Literal(input), Literal(format)), + s"The input string '$input' does not match the given number format: '$format'") + } + + val format1 = 0.until(input.length).map(_ => '0').mkString + val format2 = 0.until(input.length).map(_ => '9').mkString + val format3 = 0.until(input.length).map(i => i % 2 * 9).mkString + val format4 = 0.until(input.length + 1).map(_ => '0').mkString + 
val format5 = 0.until(input.length + 1).map(_ => '9').mkString + val format6 = 0.until(input.length + 1).map(i => i % 2 * 9).mkString + Seq(format1, format2, format3, format4, format5, format6).foreach { format => + checkEvaluation(ToNumber(Literal(input), Literal(format)), Decimal(input)) + } + } + + // Test '.' and 'D' + checkExceptionInExpression[IllegalArgumentException]( + ToNumber(Literal("454.2"), Literal("999")), + "The input string '454.2' does not match the given number format: '999'") + Seq("999.9", "000.0", "99.99", "00.00", "0000.0", "9999.9", "00.000", "99.999") + .foreach { format => + checkExceptionInExpression[IllegalArgumentException]( + ToNumber(Literal("454.23"), Literal(format)), + s"The input string '454.23' does not match the given number format: '$format'") + val format2 = format.replace('.', 'D') + checkExceptionInExpression[IllegalArgumentException]( + ToNumber(Literal("454.23"), Literal(format2)), + s"The input string '454.23' does not match the given number format: '$format2'") + } + + Seq( + ("454.2", "000.0") -> Decimal(454.2), + ("454.23", "000.00") -> Decimal(454.23), + ("454.2", "000.00") -> Decimal(454.2), + ("454.0", "000.0") -> Decimal(454), + ("454.00", "000.00") -> Decimal(454), + (".4542", ".0000") -> Decimal(0.4542), + ("4542.", "0000.") -> Decimal(4542) + ).foreach { case ((str, format), expected) => + checkEvaluation(ToNumber(Literal(str), Literal(format)), expected) + val format2 = format.replace('.', 'D') + checkEvaluation(ToNumber(Literal(str), Literal(format2)), expected) + val format3 = format.replace('0', '9') + checkEvaluation(ToNumber(Literal(str), Literal(format3)), expected) + val format4 = format3.replace('.', 'D') + checkEvaluation(ToNumber(Literal(str), Literal(format4)), expected) + } + + Seq("999.9.9", "999D9D9", "999.9D9", "999D9.9").foreach { str => + ToNumber(Literal("454.3.2"), Literal(str)).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains(s"At most one 'D' or '.' 
is allowed in the number format: '$str'")) + } + } + + // Test ',' and 'G' + checkExceptionInExpression[IllegalArgumentException]( + ToNumber(Literal("123,456"), Literal("9G9")), + "The input string '123,456' does not match the given number format: '9G9'") + checkExceptionInExpression[IllegalArgumentException]( + ToNumber(Literal("123,456,789"), Literal("999,999")), + "The input string '123,456,789' does not match the given number format: '999,999'") + + Seq( + ("12,454", "99,999") -> Decimal(12454), + ("12,454", "99,999,999") -> Decimal(12454), + ("12,454,367", "99,999,999") -> Decimal(12454367), + ("12,454,", "99,999,") -> Decimal(12454), + (",454,367", ",999,999") -> Decimal(454367), + (",454,367", "999,999") -> Decimal(454367) + ).foreach { case ((str, format), expected) => + checkEvaluation(ToNumber(Literal(str), Literal(format)), expected) + val format2 = format.replace(',', 'G') + checkEvaluation(ToNumber(Literal(str), Literal(format2)), expected) + val format3 = format.replace('9', '0') + checkEvaluation(ToNumber(Literal(str), Literal(format3)), expected) + val format4 = format3.replace(',', 'G') + checkEvaluation(ToNumber(Literal(str), Literal(format4)), expected) + val format5 = s"${format}9" + checkEvaluation(ToNumber(Literal(str), Literal(format5)), expected) + val format6 = s"${format}0" + checkEvaluation(ToNumber(Literal(str), Literal(format6)), expected) + val format7 = s"9${format}9" + checkEvaluation(ToNumber(Literal(str), Literal(format7)), expected) + val format8 = s"0${format}0" + checkEvaluation(ToNumber(Literal(str), Literal(format8)), expected) + val format9 = s"${format3}9" + checkEvaluation(ToNumber(Literal(str), Literal(format9)), expected) + val format10 = s"${format3}0" + checkEvaluation(ToNumber(Literal(str), Literal(format10)), expected) + val format11 = s"9${format3}9" + checkEvaluation(ToNumber(Literal(str), Literal(format11)), expected) + val format12 = s"0${format3}0" + checkEvaluation(ToNumber(Literal(str), Literal(format12)), expected) + } + + // Test '$' + Seq( + ("$78.12", "$99.99") -> Decimal(78.12), + ("$78.12", "$00.00") -> Decimal(78.12), + ("78.12$", "99.99$") -> Decimal(78.12), + ("78.12$", "00.00$") -> Decimal(78.12) + ).foreach { case ((str, format), expected) => + checkEvaluation(ToNumber(Literal(str), Literal(format)), expected) + } + + ToNumber(Literal("$78$.12"), Literal("$99$.99")).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("At most one '$' is allowed in the number format: '$99$.99'")) + } + ToNumber(Literal("78$.12"), Literal("99$.99")).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("'$' must be the first or last char in the number format: '99$.99'")) + } + + // Test '-' and 'S' + Seq( + ("454-", "999-") -> Decimal(-454), + ("-454", "-999") -> Decimal(-454), + ("12,454.8-", "99G999D9-") -> Decimal(-12454.8), + ("00,454.8-", "99G999.9-") -> Decimal(-454.8) + ).foreach { case ((str, format), expected) => + checkEvaluation(ToNumber(Literal(str), Literal(format)), expected) + val format2 = format.replace('9', '0') + checkEvaluation(ToNumber(Literal(str), Literal(format2)), expected) + val format3 = format.replace('-', 'S') + checkEvaluation(ToNumber(Literal(str), Literal(format3)), expected) + val format4 = format2.replace('-', 'S') + checkEvaluation(ToNumber(Literal(str), Literal(format4)), expected) + } + + ToNumber(Literal("454.3--"), Literal("999D9SS")).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => 
+ assert(msg.contains("At most one 'S' or '-' is allowed in the number format: '999D9SS'")) + } + + Seq("9S99", "9-99").foreach { str => + ToNumber(Literal("-454"), Literal(str)).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains( + s"'S' or '-' must be the first or last char in the number format: '$str'")) + } + } + } + test("find in set") { checkEvaluation( FindInSet(Literal.create(null, StringType), Literal.create(null, StringType)), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala index 4633b63cab7f0..1eccd46d960f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala @@ -46,26 +46,28 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("try_element_at: array") { - val left = Literal(Array(1, 2, 3)) + test("try_subtract") { Seq( - (0, null), - (1, 1), - (4, null) - ).foreach { case (index, expected) => - val input = TryEval(ElementAt(left, Literal(index), failOnError = false)) + (1, 1, 0), + (Int.MaxValue, -1, null), + (Int.MinValue, 1, null) + ).foreach { case (a, b, expected) => + val left = Literal(a) + val right = Literal(b) + val input = TryEval(Subtract(left, right, failOnError = true)) checkEvaluation(input, expected) } } - test("try_element_at: map") { - val left = Literal.create(Map(1 -> 1)) + test("try_multiply") { Seq( - (0, null), - (1, 1), - (4, null) - ).foreach { case (index, expected) => - val input = TryEval(ElementAt(left, Literal(index), failOnError = false)) + (2, 3, 6), + (Int.MaxValue, -10, null), + (Int.MinValue, 10, null) + ).foreach { case (a, b, expected) => + val left = Literal(a) + val right = Literal(b) + val input = TryEval(Multiply(left, right, failOnError = true)) checkEvaluation(input, expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala index 60b53c660f6ef..f603563ee3d0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala @@ -17,18 +17,25 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import java.sql.Timestamp +import java.time.{Duration, Period} + import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions.{DslString, DslSymbol} import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{DoubleType, IntegerType} +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ import 
org.apache.spark.sql.util.NumericHistogram -class HistogramNumericSuite extends SparkFunSuite { +class HistogramNumericSuite extends SparkFunSuite with SQLHelper with Logging { private val random = new java.util.Random() @@ -76,7 +83,6 @@ class HistogramNumericSuite extends SparkFunSuite { } test("class HistogramNumeric, sql string") { - val defaultAccuracy = ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY assertEqual(s"histogram_numeric(a, 3)", new HistogramNumeric("a".attr, Literal(3)).sql: String) @@ -106,23 +112,47 @@ class HistogramNumericSuite extends SparkFunSuite { } test("class HistogramNumeric, automatically add type casting for parameters") { - val testRelation = LocalRelation('a.int) + // These are the types of input relations under test. We exercise the unit test with several + // input column types to inspect the behavior of query analysis for the aggregate function. + val relations = Seq(LocalRelation('a.double), + LocalRelation('a.int), + LocalRelation('a.timestamp), + LocalRelation('a.dayTimeInterval()), + LocalRelation('a.yearMonthInterval())) - // accuracy types must be integral, no type casting + // These are the types of the second 'nbins' argument to the aggregate function. + // These accuracy types must be integral, no type casting is allowed. val nBinsExpressions = Seq( Literal(2.toByte), Literal(100.toShort), Literal(100), Literal(1000L)) - nBinsExpressions.foreach { nBins => + // Iterate through each of the input relation column types and 'nbins' expression types under + // test. + for { + relation <- relations + nBins <- nBinsExpressions + } { + // We expect each relation under test to have exactly one output attribute. + assert(relation.output.length == 1) + val relationAttributeType = relation.output(0).dataType val agg = new HistogramNumeric(UnresolvedAttribute("a"), nBins) - val analyzed = testRelation.select(agg).analyze.expressions.head + val analyzed = relation.select(agg).analyze.expressions.head analyzed match { case Alias(agg: HistogramNumeric, _) => assert(agg.resolved) - assert(agg.child.dataType == IntegerType) + assert(agg.child.dataType == relationAttributeType) assert(agg.nBins.dataType == IntegerType) + // We expect the output type of the histogram aggregate function to be an array of structs + // where the first element of each struct has the same type as the original input + // attribute. 
+ val expectedType = + ArrayType( + StructType(Seq( + StructField("x", relationAttributeType, nullable = true), + StructField("y", DoubleType, nullable = true)))) + assert(agg.dataType == expectedType) case _ => fail() } } @@ -151,6 +181,84 @@ class HistogramNumericSuite extends SparkFunSuite { assert(agg.eval(buffer) != null) } + test("class HistogramNumeric, exercise many different numeric input types") { + val inputs = Seq( + (Literal(null), + Literal(null), + Literal(null)), + (Literal(0), + Literal(1), + Literal(2)), + (Literal(0L), + Literal(1L), + Literal(2L)), + (Literal(0.toShort), + Literal(1.toShort), + Literal(2.toShort)), + (Literal(0F), + Literal(1F), + Literal(2F)), + (Literal(0D), + Literal(1D), + Literal(2D)), + (Literal(Timestamp.valueOf("2017-03-01 00:00:00")), + Literal(Timestamp.valueOf("2017-03-02 00:00:00")), + Literal(Timestamp.valueOf("2017-03-03 00:00:00"))), + (Literal(Duration.ofSeconds(1111)), + Literal(Duration.ofSeconds(1211)), + Literal(Duration.ofSeconds(1311))), + (Literal(Period.ofMonths(10)), + Literal(Period.ofMonths(11)), + Literal(Period.ofMonths(12)))) + for ((left, middle, right) <- inputs) { + // Check that the 'propagateInputType' bit correctly toggles the output type. + withSQLConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE.key -> "false") { + val aggDoubleOutputType = new HistogramNumeric( + BoundReference(0, left.dataType, nullable = true), Literal(5)) + assert(aggDoubleOutputType.dataType match { + case ArrayType(StructType(Array( + StructField("x", DoubleType, _, _), + StructField("y", _, _, _))), true) => true + }) + } + val aggPropagateOutputType = new HistogramNumeric( + BoundReference(0, left.dataType, nullable = true), Literal(5)) + assert(aggPropagateOutputType.left.dataType == + (aggPropagateOutputType.dataType match { + case + ArrayType(StructType(Array( + StructField("x", lhs@_, true, _), + StructField("y", _, true, _))), true) => lhs + })) + // Now consume some input values and check the result. + val buffer = new GenericInternalRow(new Array[Any](1)) + aggPropagateOutputType.initialize(buffer) + // Consume three non-empty rows in the aggregation. + aggPropagateOutputType.update(buffer, InternalRow(left.value)) + aggPropagateOutputType.update(buffer, InternalRow(middle.value)) + aggPropagateOutputType.update(buffer, InternalRow(right.value)) + // Evaluate the aggregate function. + val result = aggPropagateOutputType.eval(buffer) + if (left.dataType != NullType) { + assert(result != null) + // Sanity-check the sum of the heights. + var ys = 0.0 + result match { + case v: GenericArrayData => + for (row <- v.array) { + row match { + case r: GenericInternalRow => + assert(r.values.length == 2) + ys += r.values(1).asInstanceOf[Double] + } + } + } + assert(ys > 1) + } + // As a basic sanity check, the sum of the heights of the bins should be greater than one. 
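// Editor's note -- illustrative sketch, not part of the patch: the test above checks that
// histogram_numeric now propagates the input column's type into the 'x' field of the result
// (array<struct<x: <input type>, y: double>>), and that setting
// SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE to "false" falls back to the old
// array<struct<x: double, y: double>> shape. Observed through the SQL API (an active `spark`
// session is assumed):
val dt = spark.sql(
  "SELECT histogram_numeric(ts, 5) FROM VALUES (TIMESTAMP '2017-03-01') AS t(ts)")
  .schema.head.dataType
// With propagation enabled, dt is an array of structs whose 'x' field is TimestampType;
// with the flag set to "false" the 'x' field would be DoubleType.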
+ } + } + private def compareEquals(left: NumericHistogram, right: NumericHistogram): Boolean = { left.getNumBins == right.getNumBins && left.getUsedBins == right.getUsedBins && (0 until left.getUsedBins).forall { i => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index dd67a61015e72..c13cb33201ad7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -206,7 +206,7 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { if (actualFixedLength !== expectedFixedLength) { actualFixedLength.grouped(8) .zip(expectedFixedLength.grouped(8)) - .zip(mergedSchema.fields.toIterator) + .zip(mergedSchema.fields.iterator) .foreach { case ((actual, expected), field) => assert(actual === expected, s"Fixed length sections are not equal for field $field") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 7981dda495de4..1db04d2f5a7ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -122,8 +123,7 @@ class AggregateOptimizeSuite extends AnalysisTest { Optimize.execute( x.join(y, LeftOuter, Some("x.a".attr === "y.a".attr)) .groupBy("x.a".attr)("x.a".attr, Literal(1)).analyze), - x.join(y, LeftOuter, Some("x.a".attr === "y.a".attr)) - .groupBy("x.a".attr)("x.a".attr, Literal(1)).analyze) + x.groupBy("x.a".attr)("x.a".attr, Literal(1)).analyze) } test("SPARK-37292: Removes outer join if it only has DISTINCT on streamed side with alias") { @@ -148,4 +148,17 @@ class AggregateOptimizeSuite extends AnalysisTest { x.select("x.b".attr.as("newAlias1"), "x.b".attr.as("newAlias2")) .groupBy("newAlias1".attr, "newAlias2".attr)("newAlias1".attr, "newAlias2".attr).analyze) } + + test("SPARK-38489: Aggregate.groupOnly support foldable expressions") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + comparePlans( + Optimize.execute( + Distinct(x.join(y, LeftOuter, Some("x.a".attr === "y.a".attr)) + .select("x.b".attr, TrueLiteral, FalseLiteral.as("newAlias"))) + .analyze), + x.select("x.b".attr, TrueLiteral, FalseLiteral.as("newAlias")) + .groupBy("x.b".attr)("x.b".attr, TrueLiteral, FalseLiteral.as("newAlias")) + .analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 
07f16f438cc56..41fc6e93cab4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -138,8 +138,8 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with 'a > 1 && 'b > 3 && 'c > 1) checkCondition( - ('a > 1 || 'b > 3) && (('a > 1 || 'b > 3) && 'd > 0 && (('a > 1 || 'b > 3) && 'c > 1)), - ('a > 1 || 'b > 3) && 'd > 0 && 'c > 1) + ('a > 1 || 'b > 3) && (('a > 1 || 'b > 3) && 'd > 0L && (('a > 1 || 'b > 3) && 'c > 1)), + ('a > 1 || 'b > 3) && 'd > 0L && 'c > 1) checkCondition( 'a > 1 && 'b > 2 && 'a > 1 && 'c > 3, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala index 177545faa212f..dd5d6d48bcd3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala @@ -207,4 +207,18 @@ class CollapseRepartitionSuite extends PlanTest { .distribute('a)(20) comparePlans(Optimize.execute(originalQuery2.analyze), originalQuery2.analyze) } + + test("SPARK-37904: Improve rebalance in CollapseRepartition") { + Seq(testRelation.sortBy($"a".asc), + testRelation.orderBy($"a".asc), + testRelation.coalesce(1), + testRelation.repartition(1), + testRelation.distribute($"a")(1), + testRelation.rebalance($"a")).foreach { prefix => + val plan = prefix.rebalance($"a").analyze + val optimized = Optimize.execute(plan) + val expected = testRelation.rebalance($"a").analyze + comparePlans(optimized, expected) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 46e9dea730eb7..d3cbaa8c41e2d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -159,6 +159,19 @@ class CombiningLimitsSuite extends PlanTest { ) } + test("SPARK-38271: PoissonSampler may output more rows than child.maxRows") { + val query = testRelation.select().sample(0, 0.2, true, 1) + assert(query.maxRows.isEmpty) + val optimized = Optimize.execute(query.analyze) + assert(optimized.maxRows.isEmpty) + // can not eliminate Limit since Sample.maxRows is None + checkPlanAndMaxRow( + query.limit(10), + query.limit(10), + 10 + ) + } + test("SPARK-33497: Eliminate Limit if Deduplicate max rows not larger than Limit") { checkPlanAndMaxRow( testRelation.deduplicate("a".attr).limit(10), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ae644c1110740..a2ee2a2fb6813 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -21,10 +21,13 @@ import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, Unresol import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import 
org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance, StaticInvoke} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.ByteArray class ConstantFoldingSuite extends PlanTest { @@ -141,7 +144,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Cast(Literal("2"), IntegerType) + Literal(3) + 'a as Symbol("c1"), - Coalesce(Seq(Cast(Literal("abc"), IntegerType), Literal(3))) as Symbol("c2")) + Coalesce(Seq(TryCast(Literal("abc"), IntegerType), Literal(3))) as Symbol("c2")) val optimized = Optimize.execute(originalQuery.analyze) @@ -299,4 +302,41 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("SPARK-37907: InvokeLike support ConstantFolding") { + val originalQuery = + testRelation + .select( + StaticInvoke( + classOf[ByteArray], + BinaryType, + "lpad", + Seq(Literal("Spark".getBytes), Literal(7), Literal("W".getBytes)), + Seq(BinaryType, IntegerType, BinaryType), + returnNullable = false).as("c1"), + Invoke( + Literal.create("a", StringType), + "substring", + StringType, + Seq(Literal(0), Literal(1))).as("c2"), + NewInstance( + cls = classOf[GenericArrayData], + arguments = Literal.fromObject(List(1, 2, 3)) :: Nil, + inputTypes = Nil, + propagateNull = false, + dataType = ArrayType(IntegerType), + outerPointer = None).as("c3")) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Literal("WWSpark".getBytes()).as("c1"), + Literal.create("a", StringType).as("c2"), + Literal.create(new GenericArrayData(List(1, 2, 3)), ArrayType(IntegerType)).as("c3")) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala index b8886a5c0b2fe..c74eeea349b2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala @@ -37,7 +37,7 @@ class DecorrelateInnerQuerySuite extends PlanTest { val testRelation2 = LocalRelation(x, y, z) private def hasOuterReferences(plan: LogicalPlan): Boolean = { - plan.find(_.expressions.exists(SubExprUtils.containsOuter)).isDefined + plan.exists(_.expressions.exists(SubExprUtils.containsOuter)) } private def check( @@ -282,4 +282,18 @@ class DecorrelateInnerQuerySuite extends PlanTest { ).analyze check(innerPlan, outerPlan, correctAnswer, Seq(y <=> y, x === a, y === z)) } + + test("SPARK-38155: distinct with non-equality correlated predicates") { + val outerPlan = testRelation2 + val innerPlan = + Distinct( + Project(Seq(b), + Filter(OuterReference(x) > a, testRelation))) + val correctAnswer = + Distinct( + Project(Seq(b, x), + Filter(x > a, + DomainJoin(Seq(x), testRelation)))) + check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateAggregateFilterSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateAggregateFilterSuite.scala index ec9b876f78e1d..1bd4550e2c077 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateAggregateFilterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateAggregateFilterSuite.scala @@ -72,4 +72,15 @@ class EliminateAggregateFilterSuite extends PlanTest { comparePlans(Optimize.execute(query), answer) } + test("SPARK-38177: Eliminate Filter in non-root node") { + val query = testRelation + .select(countDistinctWithFilter(GreaterThan(Literal(1), Literal(2)), 'a).as('result)) + .limit(1) + .analyze + val answer = testRelation + .groupBy()(Literal.create(0L, LongType).as('result)) + .limit(1) + .analyze + comparePlans(Optimize.execute(query), answer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala index 08773720d717b..cf4761d561162 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala @@ -57,5 +57,18 @@ class EliminateDistinctSuite extends PlanTest { assert(query != answer) comparePlans(Optimize.execute(query), answer) } + + test(s"SPARK-38177: Eliminate Distinct in non-root $agg") { + val query = testRelation + .select(agg.toAggregateExpression(isDistinct = true).as('result)) + .limit(1) + .analyze + val answer = testRelation + .select(agg.toAggregateExpression(isDistinct = false).as('result)) + .limit(1) + .analyze + assert(query != answer) + comparePlans(Optimize.execute(query), answer) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala index bbb860086557a..5927cc2dfff6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala @@ -196,3 +196,19 @@ class EliminateSortsBeforeRepartitionByExprsSuite extends EliminateSortsBeforeRe class EliminateSortsBeforeCoalesceSuite extends EliminateSortsBeforeRepartitionSuite { override def repartition(plan: LogicalPlan): LogicalPlan = plan.coalesce(1) } + +class EliminateSortsBeforeRebalanceSuite extends EliminateSortsBeforeRepartitionSuite { + override def repartition(plan: LogicalPlan): LogicalPlan = plan.rebalance($"a") + + test("sortBy before rebalance with non-deterministic expressions") { + val plan = testRelation.sortBy($"a".asc, $"b".asc).limit(10) + val planWithRepartition = plan.rebalance(rand(1).asc, $"a".asc) + checkRepartitionCases(plan = planWithRepartition, optimizedPlan = planWithRepartition) + } + + test("orderBy before rebalance with non-deterministic expressions") { + val plan = testRelation.orderBy($"a".asc, $"b".asc).limit(10) + val planWithRebalance = plan.rebalance(rand(1).asc, $"a".asc) + checkRepartitionCases(plan = planWithRebalance, optimizedPlan = planWithRebalance) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 6dc464c1cd582..01ecbd808c251 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -422,14 +422,4 @@ class EliminateSortsSuite extends AnalysisTest { comparePlans(optimized, correctAnswer) } } - - test("SPARK-35906: Remove order by if the maximum number of rows less than or equal to 1") { - comparePlans( - Optimize.execute(testRelation.groupBy()(count(1).as("cnt")).orderBy('cnt.asc)).analyze, - testRelation.groupBy()(count(1).as("cnt")).analyze) - - comparePlans( - Optimize.execute(testRelation.limit(Literal(1)).orderBy('a.asc).orderBy('a.asc)).analyze, - testRelation.limit(Literal(1)).analyze) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala index 77bfc0b3682a3..65c8f5d300c62 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala @@ -188,7 +188,8 @@ class ExtractPythonUDFFromJoinConditionSuite extends PlanTest { Optimize.execute(query.analyze) } assert(e.message.contentEquals( - s"Using PythonUDF in join condition of join type $joinType is not supported.")) + "The feature is not supported: " + + s"Using PythonUDF in join condition of join type $joinType is not supported")) val query2 = testRelationLeft.join( testRelationRight, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index 92e4fa345e2ad..732c50e225550 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -204,4 +204,12 @@ class FoldablePropagationSuite extends PlanTest { .select('a, 'b, Literal(1).as('c)).analyze comparePlans(optimized, correctAnswer) } + + test("SPARK-37904: Improve rebalance in FoldablePropagation") { + val foldableAttr = Literal(1).as("x") + val plan = testRelation.select(foldableAttr, $"a").rebalance($"x", $"a").analyze + val optimized = Optimize.execute(plan) + val expected = testRelation.select(foldableAttr, $"a").rebalance(foldableAttr, $"a").analyze + comparePlans(optimized, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 848416b09813e..4cfc90a7d32fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -216,9 +216,9 @@ class LimitPushdownSuite extends PlanTest { test("SPARK-34514: Push down limit through LEFT SEMI and LEFT ANTI join") { // Push down when condition is empty Seq(LeftSemi, LeftAnti).foreach { joinType => - val originalQuery = x.join(y, joinType).limit(1) + val originalQuery = x.join(y, joinType).limit(5) val 
optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, x).join(y, joinType)).analyze + val correctAnswer = Limit(5, LocalLimit(5, x).join(LocalLimit(1, y), joinType)).analyze comparePlans(optimized, correctAnswer) } @@ -254,6 +254,13 @@ class LimitPushdownSuite extends PlanTest { Optimize.execute(x.union(y).groupBy("x.a".attr)("x.a".attr).limit(1).analyze), LocalLimit(1, LocalLimit(1, x).union(LocalLimit(1, y))).select("x.a".attr).limit(1).analyze) + comparePlans( + Optimize.execute( + x.groupBy("x.a".attr)("x.a".attr) + .select("x.a".attr.as("a1"), "x.a".attr.as("a2")).limit(1).analyze), + LocalLimit(1, x).select("x.a".attr) + .select("x.a".attr.as("a1"), "x.a".attr.as("a2")).limit(1).analyze) + // No push down comparePlans( Optimize.execute(x.groupBy("x.a".attr)("x.a".attr).limit(2).analyze), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala index 40ab72c89f3bf..ff3414d901208 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala @@ -790,6 +790,28 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { comparePlans(optimized, expected) } + + test("SPARK-37904: Improve rebalance in NestedColumnAliasing") { + // alias nested columns through rebalance + val plan1 = contact.rebalance($"id").select($"name.first").analyze + val optimized1 = Optimize.execute(plan1) + val expected1 = contact.select($"id", $"name.first".as("_extract_first")) + .rebalance($"id").select($"_extract_first".as("first")).analyze + comparePlans(optimized1, expected1) + + // also alias rebalance nested columns + val plan2 = contact.rebalance($"name.first").select($"name.first").analyze + val optimized2 = Optimize.execute(plan2) + val expected2 = contact.select($"name.first".as("_extract_first")) + .rebalance($"_extract_first".as("first")).select($"_extract_first".as("first")).analyze + comparePlans(optimized2, expected2) + + // do not alias nested columns if its child contains root reference + val plan3 = contact.rebalance($"name").select($"name.first").analyze + val optimized3 = Optimize.execute(plan3) + val expected3 = contact.select($"name").rebalance($"name").select($"name.first").analyze + comparePlans(optimized3, expected3) + } } object NestedColumnAliasingSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NotPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NotPropagationSuite.scala deleted file mode 100644 index d9506098b1d00..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NotPropagationSuite.scala +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.BooleanType - -class NotPropagationSuite extends PlanTest with ExpressionEvalHelper { - - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: - Batch("Not Propagation", FixedPoint(50), - NullPropagation, - NullDownPropagation, - ConstantFolding, - SimplifyConditionals, - BooleanSimplification, - NotPropagation, - PruneFilters) :: Nil - } - - val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string, - 'e.boolean, 'f.boolean, 'g.boolean, 'h.boolean) - - val testRelationWithData = LocalRelation.fromExternalRows( - testRelation.output, Seq(Row(1, 2, 3, "abc")) - ) - - private def checkCondition(input: Expression, expected: LogicalPlan): Unit = { - val plan = testRelationWithData.where(input).analyze - val actual = Optimize.execute(plan) - comparePlans(actual, expected) - } - - private def checkCondition(input: Expression, expected: Expression): Unit = { - val plan = testRelation.where(input).analyze - val actual = Optimize.execute(plan) - val correctAnswer = testRelation.where(expected).analyze - comparePlans(actual, correctAnswer) - } - - test("Using (Not(a) === b) == (a === Not(b)), (Not(a) <=> b) == (a <=> Not(b)) rules") { - checkCondition(Not('e) === Literal(true), 'e === Literal(false)) - checkCondition(Not('e) === Literal(false), 'e === Literal(true)) - checkCondition(Not('e) === Literal(null, BooleanType), testRelation) - checkCondition(Literal(true) === Not('e), Literal(false) === 'e) - checkCondition(Literal(false) === Not('e), Literal(true) === 'e) - checkCondition(Literal(null, BooleanType) === Not('e), testRelation) - checkCondition(Not('e) <=> Literal(true), 'e <=> Literal(false)) - checkCondition(Not('e) <=> Literal(false), 'e <=> Literal(true)) - checkCondition(Not('e) <=> Literal(null, BooleanType), IsNull('e)) - checkCondition(Literal(true) <=> Not('e), Literal(false) <=> 'e) - checkCondition(Literal(false) <=> Not('e), Literal(true) <=> 'e) - checkCondition(Literal(null, BooleanType) <=> Not('e), IsNull('e)) - - checkCondition(Not('e) === Not('f), 'e === 'f) - checkCondition(Not('e) <=> Not('f), 'e <=> 'f) - - checkCondition(IsNull('e) === Not('f), IsNotNull('e) === 'f) - checkCondition(Not('e) === IsNull('f), 'e === IsNotNull('f)) - checkCondition(IsNull('e) <=> Not('f), IsNotNull('e) <=> 'f) - checkCondition(Not('e) <=> IsNull('f), 'e <=> IsNotNull('f)) - - checkCondition(IsNotNull('e) === Not('f), IsNull('e) === 'f) - checkCondition(Not('e) === IsNotNull('f), 'e === IsNull('f)) - checkCondition(IsNotNull('e) <=> Not('f), IsNull('e) <=> 'f) - checkCondition(Not('e) <=> IsNotNull('f), 'e <=> 
IsNull('f)) - - checkCondition(Not('e) === Not(And('f, 'g)), 'e === And('f, 'g)) - checkCondition(Not(And('e, 'f)) === Not('g), And('e, 'f) === 'g) - checkCondition(Not('e) <=> Not(And('f, 'g)), 'e <=> And('f, 'g)) - checkCondition(Not(And('e, 'f)) <=> Not('g), And('e, 'f) <=> 'g) - - checkCondition(Not('e) === Not(Or('f, 'g)), 'e === Or('f, 'g)) - checkCondition(Not(Or('e, 'f)) === Not('g), Or('e, 'f) === 'g) - checkCondition(Not('e) <=> Not(Or('f, 'g)), 'e <=> Or('f, 'g)) - checkCondition(Not(Or('e, 'f)) <=> Not('g), Or('e, 'f) <=> 'g) - - checkCondition(('a > 'b) === Not('f), ('a <= 'b) === 'f) - checkCondition(Not('e) === ('a > 'b), 'e === ('a <= 'b)) - checkCondition(('a > 'b) <=> Not('f), ('a <= 'b) <=> 'f) - checkCondition(Not('e) <=> ('a > 'b), 'e <=> ('a <= 'b)) - - checkCondition(('a >= 'b) === Not('f), ('a < 'b) === 'f) - checkCondition(Not('e) === ('a >= 'b), 'e === ('a < 'b)) - checkCondition(('a >= 'b) <=> Not('f), ('a < 'b) <=> 'f) - checkCondition(Not('e) <=> ('a >= 'b), 'e <=> ('a < 'b)) - - checkCondition(('a < 'b) === Not('f), ('a >= 'b) === 'f) - checkCondition(Not('e) === ('a < 'b), 'e === ('a >= 'b)) - checkCondition(('a < 'b) <=> Not('f), ('a >= 'b) <=> 'f) - checkCondition(Not('e) <=> ('a < 'b), 'e <=> ('a >= 'b)) - - checkCondition(('a <= 'b) === Not('f), ('a > 'b) === 'f) - checkCondition(Not('e) === ('a <= 'b), 'e === ('a > 'b)) - checkCondition(('a <= 'b) <=> Not('f), ('a > 'b) <=> 'f) - checkCondition(Not('e) <=> ('a <= 'b), 'e <=> ('a > 'b)) - } - - test("Using (a =!= b) == (a === Not(b)), Not(a <=> b) == (a <=> Not(b)) rules") { - checkCondition('e =!= Literal(true), 'e === Literal(false)) - checkCondition('e =!= Literal(false), 'e === Literal(true)) - checkCondition('e =!= Literal(null, BooleanType), testRelation) - checkCondition(Literal(true) =!= 'e, Literal(false) === 'e) - checkCondition(Literal(false) =!= 'e, Literal(true) === 'e) - checkCondition(Literal(null, BooleanType) =!= 'e, testRelation) - checkCondition(Not(('a <=> 'b) <=> Literal(true)), ('a <=> 'b) <=> Literal(false)) - checkCondition(Not(('a <=> 'b) <=> Literal(false)), ('a <=> 'b) <=> Literal(true)) - checkCondition(Not(('a <=> 'b) <=> Literal(null, BooleanType)), testRelationWithData) - checkCondition(Not(Literal(true) <=> ('a <=> 'b)), Literal(false) <=> ('a <=> 'b)) - checkCondition(Not(Literal(false) <=> ('a <=> 'b)), Literal(true) <=> ('a <=> 'b)) - checkCondition(Not(Literal(null, BooleanType) <=> IsNull('e)), testRelationWithData) - - checkCondition('e =!= Not('f), 'e === 'f) - checkCondition(Not('e) =!= 'f, 'e === 'f) - checkCondition(Not(('a <=> 'b) <=> Not(('b <=> 'c))), ('a <=> 'b) <=> ('b <=> 'c)) - checkCondition(Not(Not(('a <=> 'b)) <=> ('b <=> 'c)), ('a <=> 'b) <=> ('b <=> 'c)) - - checkCondition('e =!= IsNull('f), 'e === IsNotNull('f)) - checkCondition(IsNull('e) =!= 'f, IsNotNull('e) === 'f) - checkCondition(Not(('a <=> 'b) <=> IsNull('f)), ('a <=> 'b) <=> IsNotNull('f)) - checkCondition(Not(IsNull('e) <=> ('b <=> 'c)), IsNotNull('e) <=> ('b <=> 'c)) - - checkCondition('e =!= IsNotNull('f), 'e === IsNull('f)) - checkCondition(IsNotNull('e) =!= 'f, IsNull('e) === 'f) - checkCondition(Not(('a <=> 'b) <=> IsNotNull('f)), ('a <=> 'b) <=> IsNull('f)) - checkCondition(Not(IsNotNull('e) <=> ('b <=> 'c)), IsNull('e) <=> ('b <=> 'c)) - - checkCondition('e =!= Not(And('f, 'g)), 'e === And('f, 'g)) - checkCondition(Not(And('e, 'f)) =!= 'g, And('e, 'f) === 'g) - checkCondition('e =!= Not(Or('f, 'g)), 'e === Or('f, 'g)) - checkCondition(Not(Or('e, 'f)) =!= 'g, Or('e, 'f) === 'g) - - 
checkCondition(('a > 'b) =!= 'f, ('a <= 'b) === 'f) - checkCondition('e =!= ('a > 'b), 'e === ('a <= 'b)) - checkCondition(('a >= 'b) =!= 'f, ('a < 'b) === 'f) - checkCondition('e =!= ('a >= 'b), 'e === ('a < 'b)) - checkCondition(('a < 'b) =!= 'f, ('a >= 'b) === 'f) - checkCondition('e =!= ('a < 'b), 'e === ('a >= 'b)) - checkCondition(('a <= 'b) =!= 'f, ('a > 'b) === 'f) - checkCondition('e =!= ('a <= 'b), 'e === ('a > 'b)) - - checkCondition('e =!= ('f === ('g === Not('h))), 'e === ('f === ('g === 'h))) - - } - - test("Properly avoid non optimize-able cases") { - checkCondition(Not(('a > 'b) <=> 'f), Not(('a > 'b) <=> 'f)) - checkCondition(Not('e <=> ('a > 'b)), Not('e <=> ('a > 'b))) - checkCondition(('a === 'b) =!= ('a === 'c), ('a === 'b) =!= ('a === 'c)) - checkCondition(('a === 'b) =!= ('c in(1, 2, 3)), ('a === 'b) =!= ('c in(1, 2, 3))) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullDownPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullDownPropagationSuite.scala index c9d1f3357dc8a..7097ebd4c0c63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullDownPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullDownPropagationSuite.scala @@ -36,7 +36,6 @@ class NullDownPropagationSuite extends PlanTest with ExpressionEvalHelper { ConstantFolding, SimplifyConditionals, BooleanSimplification, - NotPropagation, PruneFilters) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlanSuite.scala new file mode 100644 index 0000000000000..3266febb9ed69 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlanSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class OptimizeOneRowPlanSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Replace Operators", Once, ReplaceDistinctWithAggregate) :: + Batch("Eliminate Sorts", Once, EliminateSorts) :: + Batch("Optimize One Row Plan", FixedPoint(10), OptimizeOneRowPlan) :: Nil + } + + private val t1 = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1))) + private val t2 = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1), Row(2))) + + test("SPARK-35906: Remove order by if the maximum number of rows less than or equal to 1") { + comparePlans( + Optimize.execute(t2.groupBy()(count(1).as("cnt")).orderBy('cnt.asc)).analyze, + t2.groupBy()(count(1).as("cnt")).analyze) + + comparePlans( + Optimize.execute(t2.limit(Literal(1)).orderBy('a.asc).orderBy('a.asc)).analyze, + t2.limit(Literal(1)).analyze) + } + + test("Remove sort") { + // remove local sort + val plan1 = LocalLimit(0, t1).union(LocalLimit(0, t2)).sortBy($"a".desc).analyze + val expected = LocalLimit(0, t1).union(LocalLimit(0, t2)).analyze + comparePlans(Optimize.execute(plan1), expected) + + // do not remove + val plan2 = t2.orderBy($"a".desc).analyze + comparePlans(Optimize.execute(plan2), plan2) + + val plan3 = t2.sortBy($"a".desc).analyze + comparePlans(Optimize.execute(plan3), plan3) + } + + test("Convert group only aggregate to project") { + val plan1 = t1.groupBy($"a")($"a").analyze + comparePlans(Optimize.execute(plan1), t1.select($"a").analyze) + + val plan2 = t1.groupBy($"a" + 1)($"a" + 1).analyze + comparePlans(Optimize.execute(plan2), t1.select($"a" + 1).analyze) + + // do not remove + val plan3 = t2.groupBy($"a")($"a").analyze + comparePlans(Optimize.execute(plan3), plan3) + + val plan4 = t1.groupBy($"a")(sum($"a")).analyze + comparePlans(Optimize.execute(plan4), plan4) + + val plan5 = t1.groupBy()(sum($"a")).analyze + comparePlans(Optimize.execute(plan5), plan5) + } + + test("Remove distinct in aggregate expression") { + val plan1 = t1.groupBy($"a")(sumDistinct($"a").as("s")).analyze + val expected1 = t1.groupBy($"a")(sum($"a").as("s")).analyze + comparePlans(Optimize.execute(plan1), expected1) + + val plan2 = t1.groupBy()(sumDistinct($"a").as("s")).analyze + val expected2 = t1.groupBy()(sum($"a").as("s")).analyze + comparePlans(Optimize.execute(plan2), expected2) + + // do not remove + val plan3 = t2.groupBy($"a")(sumDistinct($"a").as("s")).analyze + comparePlans(Optimize.execute(plan3), plan3) + } + + test("Remove in complex case") { + val plan1 = t1.groupBy($"a")($"a").orderBy($"a".asc).analyze + val expected1 = t1.select($"a").analyze + comparePlans(Optimize.execute(plan1), expected1) + + val plan2 = t1.groupBy($"a")(sumDistinct($"a").as("s")).orderBy($"s".asc).analyze + val expected2 = t1.groupBy($"a")(sum($"a").as("s")).analyze + comparePlans(Optimize.execute(plan2), expected2) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 1aa4f4cbceae8..8277e44458bb1 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -294,4 +294,19 @@ class PropagateEmptyRelationSuite extends PlanTest { val expected = LocalRelation.fromExternalRows(Seq('a.int, 'b.int, 'c.int), Nil) comparePlans(optimized, expected) } + + test("SPARK-37904: Improve rebalance in PropagateEmptyRelation") { + val emptyRelation = LocalRelation($"a".int) + val expected = emptyRelation.analyze + + // test root node + val plan1 = emptyRelation.rebalance($"a").analyze + val optimized1 = Optimize.execute(plan1) + comparePlans(optimized1, expected) + + // test non-root node + val plan2 = emptyRelation.rebalance($"a").where($"a" > 0).select($"a").analyze + val optimized2 = Optimize.execute(plan2) + comparePlans(optimized2, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 250a62d5eeb0b..7b9041a904a60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -342,10 +342,12 @@ class PushFoldableIntoBranchesSuite assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal.create(null, IntegerType)))), Literal(2)), Literal.create(null, BooleanType)) - assertEquivalent( - EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal("str")))).cast(IntegerType), - Literal(2)), - CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal.create(null, BooleanType))))) + if (!conf.ansiEnabled) { + assertEquivalent( + EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal("str")))).cast(IntegerType), + Literal(2)), + CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal.create(null, BooleanType))))) + } } test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala index d11ff16229b14..963332103b6cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala @@ -21,8 +21,9 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF} -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.IntegerType @@ -33,6 +34,10 @@ class RemoveRedundantAggregatesSuite extends PlanTest { RemoveRedundantAggregates) :: Nil } + private val relation = LocalRelation('a.int, 'b.int) + private val x = relation.subquery('x) + 
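// Editor's note -- illustrative sketch, not part of the patch: the SPARK-36194 tests below
// exercise RemoveRedundantAggregates when the child's distinct keys already cover the outer
// grouping keys (for example after a left semi/anti join on the full grouping key), in which
// case the outer aggregate is expected to be rewritten into a plain project. A SQL shape that
// hits this pattern (temp views x and y assumed to exist):
spark.sql("""
  SELECT t.a, t.b
  FROM (SELECT a, b FROM x GROUP BY a, b) t
  LEFT SEMI JOIN y ON t.a = y.a AND t.b = y.b
  GROUP BY t.a, t.b
""").queryExecution.optimizedPlan
// The outer GROUP BY adds nothing: (a, b) is already distinct after the inner aggregate and
// the semi join preserves that, so the rule is expected to drop the second Aggregate.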
private val y = relation.subquery('y) + private def aggregates(e: Expression): Seq[Expression] = { Seq( count(e), @@ -42,7 +47,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { } test("Remove redundant aggregate") { - val relation = LocalRelation('a.int, 'b.int) for (agg <- aggregates('b)) { val query = relation .groupBy('a)('a, agg) @@ -57,7 +61,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { } test("Remove 2 redundant aggregates") { - val relation = LocalRelation('a.int, 'b.int) for (agg <- aggregates('b)) { val query = relation .groupBy('a)('a, agg) @@ -73,7 +76,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { } test("Remove redundant aggregate with different grouping") { - val relation = LocalRelation('a.int, 'b.int) val query = relation .groupBy('a, 'b)('a) .groupBy('a)('a) @@ -86,7 +88,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { } test("Remove redundant aggregate with aliases") { - val relation = LocalRelation('a.int, 'b.int) for (agg <- aggregates('b)) { val query = relation .groupBy('a + 'b)(('a + 'b) as 'c, agg) @@ -101,7 +102,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { } test("Remove redundant aggregate with non-deterministic upper") { - val relation = LocalRelation('a.int, 'b.int) val query = relation .groupBy('a)('a) .groupBy('a)('a, rand(0) as 'c) @@ -114,7 +114,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { } test("Remove redundant aggregate with non-deterministic lower") { - val relation = LocalRelation('a.int, 'b.int) val query = relation .groupBy('a, 'c)('a, rand(0) as 'c) .groupBy('a, 'c)('a, 'c) @@ -127,7 +126,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { } test("Keep non-redundant aggregate - upper has duplicate sensitive agg expression") { - val relation = LocalRelation('a.int, 'b.int) for (agg <- aggregates('b)) { val query = relation .groupBy('a, 'b)('a, 'b) @@ -140,7 +138,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { } test("Remove redundant aggregate - upper has duplicate agnostic agg expression") { - val relation = LocalRelation('a.int, 'b.int) val query = relation .groupBy('a, 'b)('a, 'b) // The max and countDistinct does not change if there are duplicate values @@ -153,8 +150,14 @@ class RemoveRedundantAggregatesSuite extends PlanTest { comparePlans(optimized, expected) } + test("Remove redundant aggregate - upper has contains foldable expressions") { + val originalQuery = x.groupBy('a, 'b)('a, 'b).groupBy('a)('a, TrueLiteral).analyze + val correctAnswer = x.groupBy('a)('a, TrueLiteral).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + test("Keep non-redundant aggregate - upper references agg expression") { - val relation = LocalRelation('a.int, 'b.int) for (agg <- aggregates('b)) { val query = relation .groupBy('a)('a, agg as 'c) @@ -165,13 +168,123 @@ class RemoveRedundantAggregatesSuite extends PlanTest { } } - test("Keep non-redundant aggregate - upper references non-deterministic non-grouping") { - val relation = LocalRelation('a.int, 'b.int) + test("Remove non-redundant aggregate - upper references non-deterministic non-grouping") { val query = relation .groupBy('a)('a, ('a + rand(0)) as 'c) .groupBy('a, 'c)('a, 'c) .analyze + val expected = relation + .groupBy('a)('a, ('a + rand(0)) as 'c) + .select('a, 'c) + .analyze val optimized = Optimize.execute(query) - comparePlans(optimized, query) + comparePlans(optimized, expected) + } + + test("SPARK-36194: Remove aggregation from left semi/anti 
join if aggregation the same") { + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery = x.groupBy('a, 'b)('a, 'b) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .groupBy("x.a".attr, "x.b".attr)("x.a".attr, "x.b".attr) + val correctAnswer = x.groupBy('a, 'b)('a, 'b) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .select("x.a".attr, "x.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, correctAnswer.analyze) + } + } + + test("SPARK-36194: Remove aggregation from left semi/anti join with alias") { + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery = x.groupBy('a, 'b)('a, 'b.as("d")) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "d".attr === "y.b".attr)) + .groupBy("x.a".attr, "d".attr)("x.a".attr, "d".attr) + val correctAnswer = x.groupBy('a, 'b)('a, 'b.as("d")) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "d".attr === "y.b".attr)) + .select("x.a".attr, "d".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, correctAnswer.analyze) + } + } + + test("SPARK-36194: Remove aggregation from left semi/anti join if it is the sub aggregateExprs") { + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery = x.groupBy('a, 'b)('a, 'b) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .groupBy("x.a".attr, "x.b".attr)("x.a".attr) + val correctAnswer = x.groupBy('a, 'b)('a, 'b) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .select("x.a".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, correctAnswer.analyze) + } + } + + test("SPARK-36194: Transform down to remove more aggregates") { + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery = x.groupBy('a, 'b)('a, 'b) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .groupBy("x.a".attr, "x.b".attr)("x.a".attr, "x.b".attr) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .groupBy("x.a".attr, "x.b".attr)("x.a".attr) + val correctAnswer = x.groupBy('a, 'b)('a, 'b) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .select("x.a".attr, "x.b".attr) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .select("x.a".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, correctAnswer.analyze) + } + } + + test("SPARK-36194: Child distinct keys is the subset of required keys") { + val originalQuery = relation + .groupBy('a)('a, count('b).as("cnt")) + .groupBy('a, 'cnt)('a, 'cnt) + .analyze + val correctAnswer = relation + .groupBy('a)('a, count('b).as("cnt")) + .select('a, 'cnt) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-36194: Child distinct keys are subsets and aggregateExpressions are foldable") { + val originalQuery = x.groupBy('a, 'b)('a, 'b) + .join(y, LeftSemi, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .groupBy("x.a".attr, "x.b".attr)(TrueLiteral) + .analyze + val correctAnswer = x.groupBy('a, 'b)('a, 'b) + .join(y, LeftSemi, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .select(TrueLiteral) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-36194: Negative 
case: child distinct keys is not the subset of required keys") { + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery1 = x.groupBy('a, 'b)('a, 'b) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .groupBy("x.a".attr)("x.a".attr) + .analyze + comparePlans(Optimize.execute(originalQuery1), originalQuery1) + + val originalQuery2 = x.groupBy('a, 'b)('a, 'b) + .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) + .groupBy("x.a".attr)(count("x.b".attr)) + .analyze + comparePlans(Optimize.execute(originalQuery2), originalQuery2) + } + } + + test("SPARK-36194: Negative case: child distinct keys is empty") { + val originalQuery = Distinct(x.groupBy('a, 'b)('a, TrueLiteral)).analyze + comparePlans(Optimize.execute(originalQuery), originalQuery) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 3fa7df3c94949..c4113e734c704 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, GreaterThanO import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BooleanType, DecimalType} class SetOperationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -328,4 +328,66 @@ class SetOperationSuite extends PlanTest { Union(testRelation :: testRelation :: testRelation :: testRelation :: Nil, true, false) comparePlans(unionOptimized2, unionCorrectAnswer2, false) } + + test("SPARK-37915: combine unions if there is a project between them") { + val relation1 = LocalRelation('a.decimal(18, 1), 'b.int) + val relation2 = LocalRelation('a.decimal(18, 2), 'b.int) + val relation3 = LocalRelation('a.decimal(18, 3), 'b.int) + val relation4 = LocalRelation('a.decimal(18, 4), 'b.int) + val relation5 = LocalRelation('a.decimal(18, 5), 'b.int) + + val optimizedRelation1 = relation1.select('a.cast(DecimalType(19, 2)).cast(DecimalType(20, 3)) + .cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b) + val optimizedRelation2 = relation2.select('a.cast(DecimalType(19, 2)).cast(DecimalType(20, 3)) + .cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b) + val optimizedRelation3 = relation3.select('a.cast(DecimalType(20, 3)) + .cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b) + val optimizedRelation4 = relation4 + .select('a.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b) + val optimizedRelation5 = relation5.select('a.cast(DecimalType(22, 5)).as("a"), 'b) + + // SQL UNION ALL + comparePlans( + Optimize.execute(relation1.union(relation2) + .union(relation3).union(relation4).union(relation5).analyze), + Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3, + optimizedRelation4, optimizedRelation5)).analyze) + + // SQL UNION + comparePlans( + Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2)) + .union(relation3)).union(relation4)).union(relation5)).analyze), + Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3, + optimizedRelation4, optimizedRelation5))).analyze) 
+ + // Deduplicate + comparePlans( + Optimize.execute(relation1.union(relation2).deduplicate('a, 'b).union(relation3) + .deduplicate('a, 'b).union(relation4).deduplicate('a, 'b).union(relation5) + .deduplicate('a, 'b).analyze), + Deduplicate( + Seq('a, 'b), + Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3, + optimizedRelation4, optimizedRelation5))).analyze) + + // Other cases + comparePlans( + Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2)) + .union(relation3)).union(relation4)).union(relation5)).select('a % 2).analyze), + Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3, + optimizedRelation4, optimizedRelation5))).select('a % 2).analyze) + + comparePlans( + Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2)) + .union(relation3)).union(relation4)).union(relation5)).select('a + 'b).analyze), + Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3, + optimizedRelation4, optimizedRelation5))).select('a + 'b).analyze) + + comparePlans( + Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2)) + .union(relation3)).union(relation4)).union(relation5)).select('a).analyze), + Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3, + optimizedRelation4, optimizedRelation5))).select('a).analyze) + + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index c981cee55d0fa..3c1815043df7f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -68,4 +68,41 @@ class SimplifyCastsSuite extends PlanTest { // `SimplifyCasts` rule respect the plan. 
comparePlans(optimized, plan, checkAnalysis = false) } + + test("SPARK-37922: Combine to one cast if we can safely up-cast two casts") { + val input = LocalRelation('a.int, 'b.decimal(18, 2), 'c.date, 'd.timestamp) + + // Combine casts + comparePlans( + Optimize.execute( + input.select('a.cast(DecimalType(18, 1)).cast(DecimalType(19, 1)).as("casted")).analyze), + input.select('a.cast(DecimalType(19, 1)).as("casted")).analyze) + comparePlans( + Optimize.execute( + input.select('a.cast(LongType).cast(DecimalType(22, 1)).as("casted")).analyze), + input.select('a.cast(DecimalType(22, 1)).as("casted")).analyze) + comparePlans( + Optimize.execute( + input.select('b.cast(DecimalType(20, 2)).cast(DecimalType(24, 2)).as("casted")).analyze), + input.select('b.cast(DecimalType(24, 2)).as("casted")).analyze) + + // Can not combine casts + comparePlans( + Optimize.execute( + input.select('a.cast(DecimalType(2, 1)).cast(DecimalType(3, 1)).as("casted")).analyze), + input.select('a.cast(DecimalType(2, 1)).cast(DecimalType(3, 1)).as("casted")).analyze) + comparePlans( + Optimize.execute( + input.select('b.cast(DecimalType(10, 2)).cast(DecimalType(24, 2)).as("casted")).analyze), + input.select('b.cast(DecimalType(10, 2)).cast(DecimalType(24, 2)).as("casted")).analyze) + + comparePlans( + Optimize.execute( + input.select('c.cast(TimestampType).cast(StringType).as("casted")).analyze), + input.select('c.cast(TimestampType).cast(StringType).as("casted")).analyze) + comparePlans( + Optimize.execute( + input.select('d.cast(LongType).cast(StringType).as("casted")).analyze), + input.select('d.cast(LongType).cast(StringType).as("casted")).analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala index 4fd681d4cedc8..a53e04da19d41 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.expressions.{Concat, Rand} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -91,7 +91,7 @@ class TransposeWindowSuite extends PlanTest { test("don't transpose two adjacent windows with intersection of partition and output set") { val query = testRelation - .window(Seq(('a + 'b).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) + .window(Seq(Concat(Seq('a, 'b)).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) .window(Seq(sum(c).as('sum_a_1)), Seq(a, 'e), Seq.empty) val analyzed = query.analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala index 8da6d373eb3bc..a51be57db6fa7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala @@ -212,7 +212,9 @@ class 
UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp } test("unwrap cast should skip if cannot coerce type") { - assertEquivalent(Cast(f, ByteType) > 100.toByte, Cast(f, ByteType) > 100.toByte) + if (!conf.ansiEnabled) { + assertEquivalent(Cast(f, ByteType) > 100.toByte, Cast(f, ByteType) > 100.toByte) + } } test("test getRange()") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 00a4212f661d9..11d1b30b4f8cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -441,9 +442,12 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { MapType(BinaryType, StringType)) val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) - checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + // ANSI will throw exception + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null) + checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null) + } - checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null) checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 04309297bb4fc..a339e6d33f5f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -21,11 +21,11 @@ import java.util.Locale import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.{EqualTo, Hex, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{after, first} import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} +import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -44,7 +44,10 @@ class DDLParserSuite extends AnalysisTest { } private def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parsePlan)(sqlCommand, messages: _*) + interceptParseException(parsePlan)(sqlCommand, messages: _*)() + + private def intercept(sqlCommand: String, errorClass: Option[String], messages: String*): 
Unit = + interceptParseException(parsePlan)(sqlCommand, messages: _*)(errorClass) private def parseCompare(sql: String, expected: LogicalPlan): Unit = { comparePlans(parsePlan(sql), expected, checkAnalysis = false) @@ -59,7 +62,6 @@ class DDLParserSuite extends AnalysisTest { .add("a", IntegerType, nullable = true, "test") .add("b", StringType, nullable = false)), Seq.empty[Transform], - None, Map.empty[String, String], Some("parquet"), Map.empty[String, String], @@ -83,7 +85,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("a", IntegerType).add("b", StringType)), Seq.empty[Transform], - None, Map.empty[String, String], Some("parquet"), Map.empty[String, String], @@ -104,7 +105,6 @@ class DDLParserSuite extends AnalysisTest { .add("a", IntegerType, nullable = true, "test") .add("b", StringType)), Seq(IdentityTransform(FieldReference("a"))), - None, Map.empty[String, String], Some("parquet"), Map.empty[String, String], @@ -159,7 +159,6 @@ class DDLParserSuite extends AnalysisTest { FieldReference("a"), LiteralValue(UTF8String.fromString("bar"), StringType), LiteralValue(34, IntegerType)))), - None, Map.empty[String, String], Some("parquet"), Map.empty[String, String], @@ -181,14 +180,14 @@ class DDLParserSuite extends AnalysisTest { val expectedTableSpec = TableSpec( Seq("my_tab"), Some(new StructType().add("a", IntegerType).add("b", StringType)), - Seq.empty[Transform], - Some(BucketSpec(5, Seq("a"), Seq("b"))), + List(bucket(5, Array(FieldReference.column("a")), Array(FieldReference.column("b")))), Map.empty[String, String], Some("parquet"), Map.empty[String, String], None, None, None) + Seq(createSql, replaceSql).foreach { sql => testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) } @@ -201,7 +200,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("a", IntegerType).add("b", StringType)), Seq.empty[Transform], - None, Map.empty[String, String], Some("parquet"), Map.empty[String, String], @@ -222,7 +220,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("a", IntegerType).add("b", StringType)), Seq.empty[Transform], - None, Map("test" -> "test"), Some("parquet"), Map.empty[String, String], @@ -241,7 +238,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("a", IntegerType).add("b", StringType)), Seq.empty[Transform], - None, Map.empty[String, String], Some("parquet"), Map.empty[String, String], @@ -260,7 +256,6 @@ class DDLParserSuite extends AnalysisTest { Seq("1m", "2g"), Some(new StructType().add("a", IntegerType)), Seq.empty[Transform], - None, Map.empty[String, String], Some("parquet"), Map.empty[String, String], @@ -279,7 +274,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("id", LongType).add("part", StringType)), Seq(IdentityTransform(FieldReference("part"))), - None, Map.empty[String, String], None, Map.empty[String, String], @@ -298,7 +292,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("part", StringType)), Seq(IdentityTransform(FieldReference("part"))), - None, Map.empty[String, String], None, Map.empty[String, String], @@ -317,7 +310,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("id", LongType).add("part", StringType)), Seq(IdentityTransform(FieldReference("part"))), - None, Map.empty[String, String], Some("parquet"), Map.empty[String, String], @@ -361,7 +353,6 @@ class 
DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("id", LongType).add("part", StringType)), Seq(IdentityTransform(FieldReference("part"))), - None, Map.empty[String, String], None, Map.empty[String, String], @@ -387,7 +378,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("id", LongType).add("part", StringType)), Seq(IdentityTransform(FieldReference("part"))), - None, Map.empty[String, String], None, Map.empty[String, String], @@ -430,7 +420,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("id", LongType).add("part", StringType)), Seq(IdentityTransform(FieldReference("part"))), - None, Map.empty[String, String], None, Map.empty[String, String], @@ -469,7 +458,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("id", LongType).add("part", StringType)), Seq(IdentityTransform(FieldReference("part"))), - None, Map.empty[String, String], None, Map.empty[String, String], @@ -493,7 +481,6 @@ class DDLParserSuite extends AnalysisTest { Seq("my_tab"), Some(new StructType().add("id", LongType).add("part", StringType)), Seq(IdentityTransform(FieldReference("part"))), - None, Map.empty[String, String], None, Map.empty[String, String], @@ -627,7 +614,6 @@ class DDLParserSuite extends AnalysisTest { Seq("table_name"), Some(new StructType), Seq.empty[Transform], - Option.empty[BucketSpec], Map.empty[String, String], Some("json"), Map("a" -> "1", "b" -> "0.1", "c" -> "true"), @@ -683,7 +669,6 @@ class DDLParserSuite extends AnalysisTest { Seq("mydb", "page_view"), None, Seq.empty[Transform], - None, Map("p1" -> "v1", "p2" -> "v2"), Some("parquet"), Map.empty[String, String], @@ -1792,7 +1777,7 @@ class DDLParserSuite extends AnalysisTest { allColumns = true)) intercept("ANALYZE TABLE a.b.c COMPUTE STATISTICS FOR ALL COLUMNS key, value", - "mismatched input 'key' expecting {, ';'}") + Some("PARSE_INPUT_MISMATCHED"), "Syntax error at or near 'key'") // expecting {, ';'} intercept("ANALYZE TABLE a.b.c COMPUTE STATISTICS FOR ALL", "missing 'COLUMNS' at ''") } @@ -1839,19 +1824,6 @@ class DDLParserSuite extends AnalysisTest { Some(Map("ds" -> "2017-06-10")))) } - test("SHOW CREATE table") { - comparePlans( - parsePlan("SHOW CREATE TABLE a.b.c"), - ShowCreateTable( - UnresolvedTableOrView(Seq("a", "b", "c"), "SHOW CREATE TABLE", allowTempView = false))) - - comparePlans( - parsePlan("SHOW CREATE TABLE a.b.c AS SERDE"), - ShowCreateTable( - UnresolvedTableOrView(Seq("a", "b", "c"), "SHOW CREATE TABLE", allowTempView = false), - asSerde = true)) - } - test("CACHE TABLE") { comparePlans( parsePlan("CACHE TABLE a.b.c"), @@ -2137,7 +2109,6 @@ class DDLParserSuite extends AnalysisTest { name: Seq[String], schema: Option[StructType], partitioning: Seq[Transform], - bucketSpec: Option[BucketSpec], properties: Map[String, String], provider: Option[String], options: Map[String, String], @@ -2154,7 +2125,6 @@ class DDLParserSuite extends AnalysisTest { create.name.asInstanceOf[UnresolvedDBObjectName].nameParts, Some(create.tableSchema), create.partitioning, - create.tableSpec.bucketSpec, create.tableSpec.properties, create.tableSpec.provider, create.tableSpec.options, @@ -2167,7 +2137,6 @@ class DDLParserSuite extends AnalysisTest { replace.name.asInstanceOf[UnresolvedDBObjectName].nameParts, Some(replace.tableSchema), replace.partitioning, - replace.tableSpec.bucketSpec, replace.tableSpec.properties, replace.tableSpec.provider, replace.tableSpec.options, @@ -2179,7 +2148,6 @@ 
class DDLParserSuite extends AnalysisTest { ctas.name.asInstanceOf[UnresolvedDBObjectName].nameParts, Some(ctas.query).filter(_.resolved).map(_.schema), ctas.partitioning, - ctas.tableSpec.bucketSpec, ctas.tableSpec.properties, ctas.tableSpec.provider, ctas.tableSpec.options, @@ -2192,7 +2160,6 @@ class DDLParserSuite extends AnalysisTest { rtas.name.asInstanceOf[UnresolvedDBObjectName].nameParts, Some(rtas.query).filter(_.resolved).map(_.schema), rtas.partitioning, - rtas.tableSpec.bucketSpec, rtas.tableSpec.properties, rtas.tableSpec.provider, rtas.tableSpec.options, @@ -2230,7 +2197,6 @@ class DDLParserSuite extends AnalysisTest { Seq("1m", "2g"), Some(new StructType().add("a", IntegerType)), Seq.empty[Transform], - None, Map.empty[String, String], None, Map.empty[String, String], @@ -2272,4 +2238,57 @@ class DDLParserSuite extends AnalysisTest { comparePlans(parsePlan(timestampTypeSql), insertPartitionPlan(timestamp)) comparePlans(parsePlan(binaryTypeSql), insertPartitionPlan(binaryStr)) } + + test("SPARK-38335: Implement parser support for DEFAULT values for columns in tables") { + // The following commands will support DEFAULT columns, but this has not been implemented yet. + for (sql <- Seq( + "ALTER TABLE t1 ADD COLUMN x int NOT NULL DEFAULT 42", + "ALTER TABLE t1 ALTER COLUMN a.b.c SET DEFAULT 42", + "ALTER TABLE t1 ALTER COLUMN a.b.c DROP DEFAULT", + "ALTER TABLE t1 REPLACE COLUMNS (x STRING DEFAULT 42)", + "CREATE TABLE my_tab(a INT COMMENT 'test', b STRING NOT NULL DEFAULT \"abc\") USING parquet", + "REPLACE TABLE my_tab(a INT COMMENT 'test', b STRING NOT NULL DEFAULT \"xyz\") USING parquet" + )) { + val exc = intercept[ParseException] { + parsePlan(sql); + } + assert(exc.getMessage.contains("Support for DEFAULT column values is not implemented yet")); + } + // In each of the following cases, the DEFAULT reference parses as an unresolved attribute + // reference. We can handle these cases after the parsing stage, at later phases of analysis. 
+ comparePlans(parsePlan("VALUES (1, 2, DEFAULT) AS val"), + SubqueryAlias("val", + UnresolvedInlineTable(Seq("col1", "col2", "col3"), Seq(Seq(Literal(1), Literal(2), + UnresolvedAttribute("DEFAULT")))))) + comparePlans(parsePlan( + "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)"), + InsertIntoStatement( + UnresolvedRelation(Seq("t")), + Map("part" -> Some("2019-01-02")), + userSpecifiedCols = Seq.empty[String], + query = UnresolvedInlineTable(Seq("col1", "col2"), Seq(Seq(Literal("a"), + UnresolvedAttribute("DEFAULT")))), + overwrite = false, ifPartitionNotExists = false)) + parseCompare( + """ + |MERGE INTO testcat1.ns1.ns2.tbl AS target + |USING testcat2.ns1.ns2.tbl AS source + |ON target.col1 = source.col1 + |WHEN MATCHED AND (target.col2='delete') THEN DELETE + |WHEN MATCHED AND (target.col2='update') THEN UPDATE SET target.col2 = DEFAULT + |WHEN NOT MATCHED AND (target.col2='insert') + |THEN INSERT (target.col1, target.col2) VALUES (source.col1, DEFAULT) + """.stripMargin, + MergeIntoTable( + SubqueryAlias("target", UnresolvedRelation(Seq("testcat1", "ns1", "ns2", "tbl"))), + SubqueryAlias("source", UnresolvedRelation(Seq("testcat2", "ns1", "ns2", "tbl"))), + EqualTo(UnresolvedAttribute("target.col1"), UnresolvedAttribute("source.col1")), + Seq(DeleteAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("delete")))), + UpdateAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("update"))), + Seq(Assignment(UnresolvedAttribute("target.col2"), + UnresolvedAttribute("DEFAULT"))))), + Seq(InsertAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("insert"))), + Seq(Assignment(UnresolvedAttribute("target.col1"), UnresolvedAttribute("source.col1")), + Assignment(UnresolvedAttribute("target.col2"), UnresolvedAttribute("DEFAULT"))))))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index dfc5edc82ef5b..71296f0a26e4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.catalyst.parser +import org.apache.spark.SparkThrowableHelper import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -31,26 +32,49 @@ class ErrorParserSuite extends AnalysisTest { assert(parsePlan(sqlCommand) == plan) } - def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(CatalystSqlParser.parsePlan)(sqlCommand, messages: _*) - - def intercept(sql: String, line: Int, startPosition: Int, stopPosition: Int, - messages: String*): Unit = { + private def interceptImpl(sql: String, messages: String*)( + line: Option[Int] = None, + startPosition: Option[Int] = None, + stopPosition: Option[Int] = None, + errorClass: Option[String] = None): Unit = { val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) - // Check position. - assert(e.line.isDefined) - assert(e.line.get === line) - assert(e.startPosition.isDefined) - assert(e.startPosition.get === startPosition) - assert(e.stop.startPosition.isDefined) - assert(e.stop.startPosition.get === stopPosition) - // Check messages. val error = e.getMessage messages.foreach { message => assert(error.contains(message)) } + + // Check position. 
+ if (line.isDefined) { + assert(line.isDefined && startPosition.isDefined && stopPosition.isDefined) + assert(e.line.isDefined) + assert(e.line.get === line.get) + assert(e.startPosition.isDefined) + assert(e.startPosition.get === startPosition.get) + assert(e.stop.startPosition.isDefined) + assert(e.stop.startPosition.get === stopPosition.get) + } + + // Check error class. + if (errorClass.isDefined) { + assert(e.getErrorClass == errorClass.get) + } + } + + def intercept(sqlCommand: String, errorClass: Option[String], messages: String*): Unit = { + interceptImpl(sqlCommand, messages: _*)(errorClass = errorClass) + } + + def intercept( + sql: String, line: Int, startPosition: Int, stopPosition: Int, messages: String*): Unit = { + interceptImpl(sql, messages: _*)(Some(line), Some(startPosition), Some(stopPosition)) + } + + def intercept(sql: String, errorClass: String, line: Int, startPosition: Int, stopPosition: Int, + messages: String*): Unit = { + interceptImpl(sql, messages: _*)( + Some(line), Some(startPosition), Some(stopPosition), Some(errorClass)) } test("no viable input") { @@ -64,10 +88,29 @@ class ErrorParserSuite extends AnalysisTest { } test("mismatched input") { - intercept("select * from r order by q from t", 1, 27, 31, - "mismatched input", - "---------------------------^^^") - intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, 4, "mismatched input", "^^^") + intercept("select * from r order by q from t", "PARSE_INPUT_MISMATCHED", + 1, 27, 31, + "Syntax error at or near", + "---------------------------^^^" + ) + intercept("select *\nfrom r\norder by q\nfrom t", "PARSE_INPUT_MISMATCHED", + 4, 0, 4, + "Syntax error at or near", "^^^") + } + + test("empty input") { + val expectedErrMsg = SparkThrowableHelper.getMessage("PARSE_EMPTY_STATEMENT", Array[String]()) + intercept("", Some("PARSE_EMPTY_STATEMENT"), expectedErrMsg) + intercept(" ", Some("PARSE_EMPTY_STATEMENT"), expectedErrMsg) + intercept(" \n", Some("PARSE_EMPTY_STATEMENT"), expectedErrMsg) + } + + test("jargon token substitute to user-facing language") { + // '' -> end of input + intercept("select count(*", "PARSE_INPUT_MISMATCHED", + 1, 14, 14, "Syntax error at or near end of input") + intercept("select 1 as a from", "PARSE_INPUT_MISMATCHED", + 1, 18, 18, "Syntax error at or near end of input") } test("semantic errors") { @@ -77,9 +120,11 @@ class ErrorParserSuite extends AnalysisTest { } test("SPARK-21136: misleading error message due to problematic antlr grammar") { - intercept("select * from a left join_ b on a.id = b.id", "missing 'JOIN' at 'join_'") - intercept("select * from test where test.t is like 'test'", "mismatched input 'is' expecting") - intercept("SELECT * FROM test WHERE x NOT NULL", "mismatched input 'NOT' expecting") + intercept("select * from a left join_ b on a.id = b.id", None, "missing 'JOIN' at 'join_'") + intercept("select * from test where test.t is like 'test'", Some("PARSE_INPUT_MISMATCHED"), + SparkThrowableHelper.getMessage("PARSE_INPUT_MISMATCHED", Array("'is'"))) + intercept("SELECT * FROM test WHERE x NOT NULL", Some("PARSE_INPUT_MISMATCHED"), + SparkThrowableHelper.getMessage("PARSE_INPUT_MISMATCHED", Array("'NOT'"))) } test("hyphen in identifier - DDL tests") { @@ -208,14 +253,4 @@ class ErrorParserSuite extends AnalysisTest { |SELECT b """.stripMargin, 2, 9, 10, msg + " test-table") } - - test("SPARK-35789: lateral join with non-subquery relations") { - val msg = "LATERAL can only be used with subquery" - intercept("SELECT * FROM t1, LATERAL t2", msg) - intercept("SELECT * 
FROM t1 JOIN LATERAL t2", msg) - intercept("SELECT * FROM t1, LATERAL (t2 JOIN t3)", msg) - intercept("SELECT * FROM t1, LATERAL (LATERAL t2)", msg) - intercept("SELECT * FROM t1, LATERAL VALUES (0, 1)", msg) - intercept("SELECT * FROM t1, LATERAL RANGE(0, 1)", msg) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 93b6ca64ca2db..754ac8b91f738 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -58,7 +58,10 @@ class ExpressionParserSuite extends AnalysisTest { } private def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(defaultParser.parseExpression)(sqlCommand, messages: _*) + interceptParseException(defaultParser.parseExpression)(sqlCommand, messages: _*)() + + private def intercept(sqlCommand: String, errorClass: Option[String], messages: String*): Unit = + interceptParseException(defaultParser.parseExpression)(sqlCommand, messages: _*)(errorClass) def assertEval( sqlCommand: String, @@ -863,7 +866,8 @@ class ExpressionParserSuite extends AnalysisTest { test("composed expressions") { assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) - intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") + intercept("1 - f('o', o(bar)) hello * world", Some("PARSE_INPUT_MISMATCHED"), + "Syntax error at or near '*'") } test("SPARK-17364, fully qualified column name which starts with number") { @@ -882,7 +886,8 @@ class ExpressionParserSuite extends AnalysisTest { test("SPARK-17832 function identifier contains backtick") { val complexName = FunctionIdentifier("`ba`r", Some("`fo`o")) assertEqual(complexName.quotedString, UnresolvedAttribute(Seq("`fo`o", "`ba`r"))) - intercept(complexName.unquotedString, "mismatched input") + intercept(complexName.unquotedString, Some("PARSE_INPUT_MISMATCHED"), + "Syntax error at or near") // Function identifier contains continuous backticks should be treated correctly. 
val complexName2 = FunctionIdentifier("ba``r", Some("fo``o")) assertEqual(complexName2.quotedString, UnresolvedAttribute(Seq("fo``o", "ba``r"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index e8088a62ecdde..70138a3e688c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -41,7 +41,10 @@ class PlanParserSuite extends AnalysisTest { } private def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parsePlan)(sqlCommand, messages: _*) + interceptParseException(parsePlan)(sqlCommand, messages: _*)() + + private def intercept(sqlCommand: String, errorClass: Option[String], messages: String*): Unit = + interceptParseException(parsePlan)(sqlCommand, messages: _*)(errorClass) private def cte( plan: LogicalPlan, @@ -289,11 +292,11 @@ class PlanParserSuite extends AnalysisTest { "from a select * select * where s < 10", table("a").select(star()).union(table("a").where('s < 10).select(star()))) intercept( - "from a select * select * from x where a.s < 10", - "mismatched input 'from' expecting") + "from a select * select * from x where a.s < 10", Some("PARSE_INPUT_MISMATCHED"), + "Syntax error at or near 'from'") intercept( - "from a select * from b", - "mismatched input 'from' expecting") + "from a select * from b", Some("PARSE_INPUT_MISMATCHED"), + "Syntax error at or near 'from'") assertEqual( "from a insert into tbl1 select * insert into tbl2 select * where s < 10", table("a").select(star()).insertInto("tbl1").union( @@ -775,16 +778,12 @@ class PlanParserSuite extends AnalysisTest { test("select hint syntax") { // Hive compatibility: Missing parameter raises ParseException. - val m = intercept[ParseException] { - parsePlan("SELECT /*+ HINT() */ * FROM t") - }.getMessage - assert(m.contains("mismatched input")) + intercept("SELECT /*+ HINT() */ * FROM t", Some("PARSE_INPUT_MISMATCHED"), + "Syntax error at or near") // Disallow space as the delimiter. 
- val m3 = intercept[ParseException] { - parsePlan("SELECT /*+ INDEX(a b c) */ * from default.t") - }.getMessage - assert(m3.contains("mismatched input 'b' expecting")) + intercept("SELECT /*+ INDEX(a b c) */ * from default.t", Some("PARSE_INPUT_MISMATCHED"), + "Syntax error at or near 'b'") comparePlans( parsePlan("SELECT /*+ HINT */ * FROM t"), @@ -841,7 +840,8 @@ class PlanParserSuite extends AnalysisTest { UnresolvedHint("REPARTITION", Seq(Literal(100)), table("t").select(star())))) - intercept("SELECT /*+ COALESCE(30 + 50) */ * FROM t", "mismatched input") + intercept("SELECT /*+ COALESCE(30 + 50) */ * FROM t", Some("PARSE_INPUT_MISMATCHED"), + "Syntax error at or near") comparePlans( parsePlan("SELECT /*+ REPARTITION(c) */ * FROM t"), @@ -965,8 +965,10 @@ class PlanParserSuite extends AnalysisTest { ) } - intercept("select ltrim(both 'S' from 'SS abc S'", "mismatched input 'from' expecting {')'") - intercept("select rtrim(trailing 'S' from 'SS abc S'", "mismatched input 'from' expecting {')'") + intercept("select ltrim(both 'S' from 'SS abc S'", Some("PARSE_INPUT_MISMATCHED"), + "Syntax error at or near 'from'") // expecting {')' + intercept("select rtrim(trailing 'S' from 'SS abc S'", Some("PARSE_INPUT_MISMATCHED"), + "Syntax error at or near 'from'") // expecting {')' assertTrimPlans( "SELECT TRIM(BOTH '@$%&( )abc' FROM '@ $ % & ()abc ' )", @@ -1079,7 +1081,7 @@ class PlanParserSuite extends AnalysisTest { val m1 = intercept[ParseException] { parsePlan("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)") }.getMessage - assert(m1.contains("mismatched input 'INSERT' expecting")) + assert(m1.contains("Syntax error at or near 'INSERT'")) // Multi insert query val m2 = intercept[ParseException] { parsePlan( @@ -1089,11 +1091,11 @@ class PlanParserSuite extends AnalysisTest { |INSERT INTO tbl2 SELECT * WHERE jt.id > 4 """.stripMargin) }.getMessage - assert(m2.contains("mismatched input 'INSERT' expecting")) + assert(m2.contains("Syntax error at or near 'INSERT'")) val m3 = intercept[ParseException] { parsePlan("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)") }.getMessage - assert(m3.contains("mismatched input 'INSERT' expecting")) + assert(m3.contains("Syntax error at or near 'INSERT'")) // Multi insert query val m4 = intercept[ParseException] { parsePlan( @@ -1104,7 +1106,7 @@ class PlanParserSuite extends AnalysisTest { """.stripMargin ) }.getMessage - assert(m4.contains("mismatched input 'INSERT' expecting")) + assert(m4.contains("Syntax error at or near 'INSERT'")) } test("Invalid insert constructs in the query") { @@ -1115,7 +1117,7 @@ class PlanParserSuite extends AnalysisTest { val m2 = intercept[ParseException] { parsePlan("SELECT * FROM S WHERE C1 IN (INSERT INTO T VALUES (2))") }.getMessage - assert(m2.contains("mismatched input 'IN' expecting")) + assert(m2.contains("Syntax error at or near 'IN'")) } test("relation in v2 catalog") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 0cd6d8164fe8d..acb41b097efbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -105,4 +105,12 @@ class LogicalPlanSuite extends SparkFunSuite { assert(Range(0, 100, 1, 3).select('id).maxRowsPerPartition === Some(34)) assert(Range(0, 100, 1, 3).where('id % 2 === 1).maxRowsPerPartition === Some(34)) } + + 
test("SPARK-38286: Union's maxRows and maxRowsPerPartition may overflow") { + val query1 = Range(0, Long.MaxValue, 1, 1) + val query2 = Range(0, 100, 1, 10) + val query = query1.union(query2) + assert(query.maxRows.isEmpty) + assert(query.maxRowsPerPartition.isEmpty) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala index fb014bb8391f3..0839092119da3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala @@ -129,4 +129,35 @@ class QueryPlanSuite extends SparkFunSuite { ) assert(!nonDeterministicPlan.deterministic) } + + test("SPARK-38347: Nullability propagation in transformUpWithNewOutput") { + // A test rule that replaces Attributes in Project's project list. + val testRule = new Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithNewOutput { + case p @ Project(projectList, _) => + val newProjectList = projectList.map { + case a: AttributeReference => a.newInstance() + case ne => ne + } + val newProject = p.copy(projectList = newProjectList) + newProject -> p.output.zip(newProject.output) + } + } + + // Test a Left Outer Join plan in which right-hand-side input attributes are not nullable. + // Those attributes should be nullable after join even with a `transformUpWithNewOutput` + // started below the Left Outer join. + val t1 = LocalRelation('a.int.withNullability(false), + 'b.int.withNullability(false), 'c.int.withNullability(false)) + val t2 = LocalRelation('c.int.withNullability(false), + 'd.int.withNullability(false), 'e.int.withNullability(false)) + val plan = t1.select($"a", $"b") + .join(t2.select($"c", $"d"), LeftOuter, Some($"a" === $"c")) + .select($"a" + $"d").analyze + // The output Attribute of `plan` is nullable even though `d` is not nullable before the join. + assert(plan.output(0).nullable) + // The test rule with `transformUpWithNewOutput` should not change the nullability. + val planAfterTestRule = testRule(plan) + assert(planAfterTestRule.output(0).nullable) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala index 0a3f86ebf6808..4a426458e5bfe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala @@ -41,7 +41,7 @@ class AnalysisHelperSuite extends SparkFunSuite { test("setAnalyze is recursive") { val plan = Project(Nil, LocalRelation()) plan.setAnalyzed() - assert(plan.find(!_.analyzed).isEmpty) + assert(!plan.exists(!_.analyzed)) } test("resolveOperator runs on operators recursively") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala new file mode 100644 index 0000000000000..131155f8c04d1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import scala.collection.mutable +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExpressionSet, UnspecifiedFrame} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.types.IntegerType + +class DistinctKeyVisitorSuite extends PlanTest { + + private val a = AttributeReference("a", IntegerType)() + private val b = AttributeReference("b", IntegerType)() + private val c = AttributeReference("c", IntegerType)() + private val d = a.as("aliased_a") + private val e = b.as("aliased_b") + private val f = Alias(a + 1, (a + 1).toString)() + private val x = AttributeReference("x", IntegerType)() + private val y = AttributeReference("y", IntegerType)() + private val z = AttributeReference("z", IntegerType)() + + + private val t1 = LocalRelation(a, b, c).as("t1") + private val t2 = LocalRelation(x, y, z).as("t2") + + private def checkDistinctAttributes(plan: LogicalPlan, distinctKeys: Set[ExpressionSet]) = { + assert(plan.analyze.distinctKeys === distinctKeys) + } + + implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + + test("Aggregate's distinct attributes") { + checkDistinctAttributes(t1.groupBy('a, 'b)('a, 'b, 1), Set(ExpressionSet(Seq(a, b)))) + checkDistinctAttributes(t1.groupBy('a)('a), Set(ExpressionSet(Seq(a)))) + checkDistinctAttributes(t1.groupBy('a, 'b)('a, 'b), Set(ExpressionSet(Seq(a, b)))) + checkDistinctAttributes(t1.groupBy('a, 'b, 1)('a, 'b), Set(ExpressionSet(Seq(a, b)))) + checkDistinctAttributes(t1.groupBy('a, 'b)('a, 'b, 1), Set(ExpressionSet(Seq(a, b)))) + checkDistinctAttributes(t1.groupBy('a, 'b, 1)('a, 'b, 1), Set(ExpressionSet(Seq(a, b)))) + checkDistinctAttributes(t1.groupBy('a, 'b)('a, 'a), Set.empty) + checkDistinctAttributes(t1.groupBy('a, 'b)('a), Set.empty) + checkDistinctAttributes(t1.groupBy('a)('a, max('b)), Set(ExpressionSet(Seq(a)))) + checkDistinctAttributes(t1.groupBy('a, 'b)('a, 'b, d, e), + Set(ExpressionSet(Seq(a, b)), ExpressionSet(Seq(d.toAttribute, e.toAttribute)))) + checkDistinctAttributes(t1.groupBy()(sum('c)), Set.empty) + checkDistinctAttributes(t1.groupBy('a)('a, 'a % 10, d, sum('b)), + Set(ExpressionSet(Seq(a)), ExpressionSet(Seq(d.toAttribute)))) + checkDistinctAttributes(t1.groupBy(f.child, 'b)(f, 'b, sum('c)), + Set(ExpressionSet(Seq(f.toAttribute, b)))) + } + + test("Distinct's distinct attributes") { + checkDistinctAttributes(Distinct(t1), Set(ExpressionSet(Seq(a, b, c)))) + checkDistinctAttributes(Distinct(t1.select('a, 'c)), Set(ExpressionSet(Seq(a, c)))) + } + + test("Except's distinct attributes") { 
+ checkDistinctAttributes(Except(t1, t2, false), Set(ExpressionSet(Seq(a, b, c)))) + checkDistinctAttributes(Except(t1, t2, true), Set.empty) + } + + test("Filter's distinct attributes") { + checkDistinctAttributes(Filter('a > 1, t1), Set.empty) + checkDistinctAttributes(Filter('a > 1, Distinct(t1)), Set(ExpressionSet(Seq(a, b, c)))) + } + + test("Limit's distinct attributes") { + checkDistinctAttributes(Distinct(t1).limit(10), Set(ExpressionSet(Seq(a, b, c)))) + checkDistinctAttributes(LocalLimit(10, Distinct(t1)), Set(ExpressionSet(Seq(a, b, c)))) + checkDistinctAttributes(t1.limit(1), Set(ExpressionSet(Seq(a, b, c)))) + } + + test("Intersect's distinct attributes") { + checkDistinctAttributes(Intersect(t1, t2, false), Set(ExpressionSet(Seq(a, b, c)))) + checkDistinctAttributes(Intersect(t1, t2, true), Set.empty) + } + + test("Join's distinct attributes") { + Seq(LeftSemi, LeftAnti).foreach { joinType => + checkDistinctAttributes( + Distinct(t1).join(t2, joinType, Some('a === 'x)), Set(ExpressionSet(Seq(a, b, c)))) + } + + checkDistinctAttributes( + Distinct(t1).join(Distinct(t2), Inner, Some('a === 'x && 'b === 'y && 'c === 'z)), + Set(ExpressionSet(Seq(a, b, c)), ExpressionSet(Seq(x, y, z)))) + + checkDistinctAttributes( + Distinct(t1).join(Distinct(t2), LeftOuter, Some('a === 'x && 'b === 'y && 'c === 'z)), + Set(ExpressionSet(Seq(a, b, c)))) + + checkDistinctAttributes( + Distinct(t1).join(Distinct(t2), RightOuter, Some('a === 'x && 'b === 'y && 'c === 'z)), + Set(ExpressionSet(Seq(x, y, z)))) + + Seq(Inner, Cross, LeftOuter, RightOuter).foreach { joinType => + checkDistinctAttributes(t1.join(t2, joinType, Some('a === 'x)), + Set.empty) + checkDistinctAttributes( + Distinct(t1).join(Distinct(t2), joinType, Some('a === 'x && 'b === 'y)), + Set.empty) + checkDistinctAttributes( + Distinct(t1).join(Distinct(t2), joinType, + Some('a === 'x && 'b === 'y && 'c % 5 === 'z % 5)), + Set.empty) + } + + checkDistinctAttributes( + Distinct(t1).join(Distinct(t2), Cross, Some('a === 'x && 'b === 'y && 'c === 'z)), + Set.empty) + } + + test("Project's distinct attributes") { + checkDistinctAttributes(t1.select('a, 'b), Set.empty) + checkDistinctAttributes(Distinct(t1).select('a), Set.empty) + checkDistinctAttributes(Distinct(t1).select('a, 'b, d, e), Set.empty) + checkDistinctAttributes(Distinct(t1).select('a, 'b, 'c, 1), Set(ExpressionSet(Seq(a, b, c)))) + checkDistinctAttributes(Distinct(t1).select('a, 'b, c, d), + Set(ExpressionSet(Seq(a, b, c)), ExpressionSet(Seq(b, c, d.toAttribute)))) + checkDistinctAttributes(t1.groupBy('a, 'b)('a, 'b, d).select('a, 'b, e), + Set(ExpressionSet(Seq(a, b)), ExpressionSet(Seq(a, e.toAttribute)))) + } + + test("Repartition's distinct attributes") { + checkDistinctAttributes(t1.repartition(8), Set.empty) + checkDistinctAttributes(Distinct(t1).repartition(8), Set(ExpressionSet(Seq(a, b, c)))) + checkDistinctAttributes(RepartitionByExpression(Seq(a), Distinct(t1), None), + Set(ExpressionSet(Seq(a, b, c)))) + } + + test("Sample's distinct attributes") { + checkDistinctAttributes(t1.sample(0, 0.2, false, 1), Set.empty) + checkDistinctAttributes(Distinct(t1).sample(0, 0.2, false, 1), Set(ExpressionSet(Seq(a, b, c)))) + } + + test("Window's distinct attributes") { + val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + + checkDistinctAttributes( + Distinct(t1).select('a, 'b, 'c, winExpr.as('window)), Set(ExpressionSet(Seq(a, b, c)))) + checkDistinctAttributes( + Distinct(t1).select('a, 'b, winExpr.as('window)), Set()) + } + + 
test("Tail's distinct attributes") { + checkDistinctAttributes(Tail(10, Distinct(t1)), Set(ExpressionSet(Seq(a, b, c)))) + } + + test("Sort's distinct attributes") { + checkDistinctAttributes(t1.sortBy('a.asc), Set.empty) + checkDistinctAttributes(Distinct(t1).sortBy('a.asc), Set(ExpressionSet(Seq(a, b, c)))) + } + + test("RebalancePartitions's distinct attributes") { + checkDistinctAttributes(RebalancePartitions(Seq(a), Distinct(t1)), + Set(ExpressionSet(Seq(a, b, c)))) + } + + test("WithCTE's distinct attributes") { + checkDistinctAttributes(WithCTE(Distinct(t1), mutable.ArrayBuffer.empty[CTERelationDef].toSeq), + Set(ExpressionSet(Seq(a, b, c)))) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 31e289e052586..bc61a76ecfc22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -259,12 +259,16 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { expectedStatsCboOff = Statistics.DUMMY) } - test("SPARK-35203: Improve Repartition statistics estimation") { + test("Improve Repartition statistics estimation") { + // SPARK-35203 for repartition and repartitionByExpr + // SPARK-37949 for rebalance Seq( RepartitionByExpression(plan.output, plan, 10), RepartitionByExpression(Nil, plan, None), plan.repartition(2), - plan.coalesce(3)).foreach { rep => + plan.coalesce(3), + plan.rebalance(), + plan.rebalance(plan.output: _*)).foreach { rep => val expectedStats = Statistics(plan.size.get, Some(plan.rowCount), plan.attributeStats) checkStats( rep, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index b52ecb56ad995..b6087c54e664b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -248,6 +248,44 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { assert(expected === actual) } + test("exists") { + val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) + // Check the top node. + var exists = expression.exists { + case _: Add => true + case _ => false + } + assert(exists) + + // Check the first children. + exists = expression.exists { + case Literal(1, IntegerType) => true + case _ => false + } + assert(exists) + + // Check an internal node (Subtract). + exists = expression.exists { + case _: Subtract => true + case _ => false + } + assert(exists) + + // Check a leaf node. + exists = expression.exists { + case Literal(3, IntegerType) => true + case _ => false + } + assert(exists) + + // Check not exists. 
+ exists = expression.exists { + case Literal(100, IntegerType) => true + case _ => false + } + assert(!exists) + } + test("collectFirst") { val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index af0c26e39a7c8..41da5409feb06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -766,12 +766,15 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { assert(daysToMicros(16800, UTC) === expected) // There are some days are skipped entirely in some timezone, skip them here. + // JDK-8274407 and its backport commits renamed 'Pacific/Enderbury' to 'Pacific/Kanton' + // in Java 8u311, 11.0.14, and 17.0.2 val skipped_days = Map[String, Set[Int]]( "Kwajalein" -> Set(8632, 8633, 8634), "Pacific/Apia" -> Set(15338), "Pacific/Enderbury" -> Set(9130, 9131), "Pacific/Fakaofo" -> Set(15338), "Pacific/Kiritimati" -> Set(9130, 9131), + "Pacific/Kanton" -> Set(9130, 9131), "Pacific/Kwajalein" -> Set(8632, 8633, 8634), MIT.getId -> Set(15338)) for (zid <- ALL_TIMEZONES) { @@ -952,4 +955,92 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { s"The difference is ${(result - expectedMicros) / MICROS_PER_HOUR} hours") } } + + test("SPARK-38195: add a quantity of interval units to a timestamp") { + outstandingZoneIds.foreach { zid => + assert(timestampAdd("MICROSECOND", 1, date(2022, 2, 14, 11, 27, 0, 0, zid), zid) === + date(2022, 2, 14, 11, 27, 0, 1, zid)) + assert(timestampAdd("MILLISECOND", -1, date(2022, 2, 14, 11, 27, 0, 1000, zid), zid) === + date(2022, 2, 14, 11, 27, 0, 0, zid)) + assert(timestampAdd("SECOND", 0, date(2022, 2, 14, 11, 27, 0, 1001, zid), zid) === + date(2022, 2, 14, 11, 27, 0, 1001, zid)) + assert(timestampAdd("MINUTE", -90, date(2022, 2, 14, 11, 0, 1, 1, zid), zid) === + date(2022, 2, 14, 9, 30, 1, 1, zid)) + assert(timestampAdd("HOUR", 24, date(2022, 2, 14, 11, 0, 1, 0, zid), zid) === + date(2022, 2, 15, 11, 0, 1, 0, zid)) + assert(timestampAdd("DAY", 1, date(2022, 2, 28, 11, 1, 0, 0, zid), zid) === + date(2022, 3, 1, 11, 1, 0, 0, zid)) + assert(timestampAdd("DAYOFYEAR", 364, date(2022, 1, 1, 0, 0, 0, 0, zid), zid) === + date(2022, 12, 31, 0, 0, 0, 0, zid)) + assert(timestampAdd("WEEK", 1, date(2022, 2, 14, 11, 43, 0, 1, zid), zid) === + date(2022, 2, 21, 11, 43, 0, 1, zid)) + assert(timestampAdd("MONTH", 10, date(2022, 2, 14, 11, 43, 0, 1, zid), zid) === + date(2022, 12, 14, 11, 43, 0, 1, zid)) + assert(timestampAdd("QUARTER", 1, date(1900, 2, 1, 0, 0, 0, 1, zid), zid) === + date(1900, 5, 1, 0, 0, 0, 1, zid)) + assert(timestampAdd("YEAR", 1, date(9998, 1, 1, 0, 0, 0, 1, zid), zid) === + date(9999, 1, 1, 0, 0, 0, 1, zid)) + assert(timestampAdd("YEAR", -9998, date(9999, 1, 1, 0, 0, 0, 1, zid), zid) === + date(1, 1, 1, 0, 0, 0, 1, zid)) + } + + val e = intercept[IllegalStateException] { + timestampAdd("SECS", 1, date(1969, 1, 1, 0, 0, 0, 1, getZoneId("UTC")), getZoneId("UTC")) + } + assert(e.getMessage === "Got the unexpected unit 'SECS'.") + } + + test("SPARK-38284: difference between two timestamps in units") { + outstandingZoneIds.foreach { zid => + assert(timestampDiff("MICROSECOND", + date(2022, 2, 14, 11, 27, 0, 0, zid), date(2022, 2, 14, 11, 27, 0, 1, zid), zid) 
=== 1) + assert(timestampDiff("MILLISECOND", + date(2022, 2, 14, 11, 27, 0, 1000, zid), date(2022, 2, 14, 11, 27, 0, 0, zid), zid) === -1) + assert(timestampDiff( + "SECOND", + date(2022, 2, 14, 11, 27, 0, 1001, zid), + date(2022, 2, 14, 11, 27, 0, 1002, zid), + zid) === 0) + assert(timestampDiff( + "MINUTE", + date(2022, 2, 14, 11, 0, 1, 1, zid), + date(2022, 2, 14, 9, 30, 1, 1, zid), + zid) === -90) + assert(timestampDiff( + "HOUR", + date(2022, 2, 14, 11, 0, 1, 0, zid), + date(2022, 2, 15, 11, 0, 1, 2, zid), + zid) === 24) + assert(timestampDiff( + "DAY", + date(2022, 2, 28, 11, 1, 0, 0, zid), + date(2022, 3, 1, 11, 1, 0, 0, zid), + zid) === 1) + assert(timestampDiff("WEEK", + date(2022, 2, 14, 11, 43, 0, 1, zid), date(2022, 2, 21, 11, 42, 59, 1, zid), zid) === 0) + assert(timestampDiff("MONTH", + date(2022, 2, 14, 11, 43, 0, 1, zid), date(2022, 12, 14, 11, 43, 0, 1, zid), zid) === 10) + assert(timestampDiff("QUARTER", + date(1900, 2, 1, 0, 0, 0, 1, zid), date(1900, 5, 1, 2, 0, 0, 1, zid), zid) === 1) + assert(timestampDiff( + "YEAR", + date(9998, 1, 1, 0, 0, 0, 1, zid), + date(9999, 1, 1, 0, 0, 1, 2, zid), + zid) === 1) + assert(timestampDiff( + "YEAR", + date(9999, 1, 1, 0, 0, 0, 1, zid), + date(1, 1, 1, 0, 0, 0, 1, zid), + zid) === -9998) + } + + val e = intercept[IllegalStateException] { + timestampDiff( + "SECS", + date(1969, 1, 1, 0, 0, 0, 1, getZoneId("UTC")), + date(2022, 1, 1, 0, 0, 0, 1, getZoneId("UTC")), + getZoneId("UTC")) + } + assert(e.getMessage === "Got the unexpected unit 'SECS'.") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberFormatterSuite.scala similarity index 65% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberUtilsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberFormatterSuite.scala index 66a17dceed745..81264f4e85080 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberFormatterSuite.scala @@ -19,43 +19,37 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.util.NumberUtils.{format, parse} import org.apache.spark.sql.types.Decimal import org.apache.spark.unsafe.types.UTF8String -class NumberUtilsSuite extends SparkFunSuite { +class NumberFormatterSuite extends SparkFunSuite { - private def failParseWithInvalidInput( - input: UTF8String, numberFormat: String, errorMsg: String): Unit = { - val e = intercept[IllegalArgumentException](parse(input, numberFormat)) + private def invalidNumberFormat(numberFormat: String, errorMsg: String): Unit = { + val testNumberFormatter = new TestNumberFormatter(numberFormat) + val e = intercept[AnalysisException](testNumberFormatter.checkWithException()) assert(e.getMessage.contains(errorMsg)) } - private def failParseWithAnalysisException( + private def failParseWithInvalidInput( input: UTF8String, numberFormat: String, errorMsg: String): Unit = { - val e = intercept[AnalysisException](parse(input, numberFormat)) - assert(e.getMessage.contains(errorMsg)) - } - - private def failFormatWithAnalysisException( - input: Decimal, numberFormat: String, errorMsg: String): Unit = { - val e = intercept[AnalysisException](format(input, numberFormat)) + val testNumberFormatter = new 
TestNumberFormatter(numberFormat) + val e = intercept[IllegalArgumentException](testNumberFormatter.parse(input)) assert(e.getMessage.contains(errorMsg)) } test("parse") { - failParseWithInvalidInput(UTF8String.fromString("454"), "", - "Format '' used for parsing string to number or formatting number to string is invalid") + invalidNumberFormat("", "Number format cannot be empty") // Test '9' and '0' failParseWithInvalidInput(UTF8String.fromString("454"), "9", - "Format '9' used for parsing string to number or formatting number to string is invalid") + "The input string '454' does not match the given number format: '9'") failParseWithInvalidInput(UTF8String.fromString("454"), "99", - "Format '99' used for parsing string to number or formatting number to string is invalid") + "The input string '454' does not match the given number format: '99'") Seq( ("454", "999") -> Decimal(454), ("054", "999") -> Decimal(54), + ("54", "999") -> Decimal(54), ("404", "999") -> Decimal(404), ("450", "999") -> Decimal(450), ("454", "9999") -> Decimal(454), @@ -63,17 +57,20 @@ class NumberUtilsSuite extends SparkFunSuite { ("404", "9999") -> Decimal(404), ("450", "9999") -> Decimal(450) ).foreach { case ((str, format), expected) => - assert(parse(UTF8String.fromString(str), format) === expected) + val builder = new TestNumberFormatter(format) + builder.check() + assert(builder.parse(UTF8String.fromString(str)) === expected) } failParseWithInvalidInput(UTF8String.fromString("454"), "0", - "Format '0' used for parsing string to number or formatting number to string is invalid") + "The input string '454' does not match the given number format: '0'") failParseWithInvalidInput(UTF8String.fromString("454"), "00", - "Format '00' used for parsing string to number or formatting number to string is invalid") + "The input string '454' does not match the given number format: '00'") Seq( ("454", "000") -> Decimal(454), ("054", "000") -> Decimal(54), + ("54", "000") -> Decimal(54), ("404", "000") -> Decimal(404), ("450", "000") -> Decimal(450), ("454", "0000") -> Decimal(454), @@ -81,14 +78,16 @@ class NumberUtilsSuite extends SparkFunSuite { ("404", "0000") -> Decimal(404), ("450", "0000") -> Decimal(450) ).foreach { case ((str, format), expected) => - assert(parse(UTF8String.fromString(str), format) === expected) + val builder = new TestNumberFormatter(format) + builder.check() + assert(builder.parse(UTF8String.fromString(str)) === expected) } // Test '.' and 'D' failParseWithInvalidInput(UTF8String.fromString("454.2"), "999", - "Format '999' used for parsing string to number or formatting number to string is invalid") + "The input string '454.2' does not match the given number format: '999'") failParseWithInvalidInput(UTF8String.fromString("454.23"), "999.9", - "Format '999.9' used for parsing string to number or formatting number to string is invalid") + "The input string '454.23' does not match the given number format: '999.9'") Seq( ("454.2", "999.9") -> Decimal(454.2), @@ -116,17 +115,19 @@ class NumberUtilsSuite extends SparkFunSuite { ("4542.", "9999D") -> Decimal(4542), ("4542.", "0000D") -> Decimal(4542) ).foreach { case ((str, format), expected) => - assert(parse(UTF8String.fromString(str), format) === expected) + val builder = new TestNumberFormatter(format) + builder.check() + assert(builder.parse(UTF8String.fromString(str)) === expected) } - failParseWithAnalysisException(UTF8String.fromString("454.3.2"), "999.9.9", - "Multiple 'D' or '.' 
in '999.9.9'") - failParseWithAnalysisException(UTF8String.fromString("454.3.2"), "999D9D9", - "Multiple 'D' or '.' in '999D9D9'") - failParseWithAnalysisException(UTF8String.fromString("454.3.2"), "999.9D9", - "Multiple 'D' or '.' in '999.9D9'") - failParseWithAnalysisException(UTF8String.fromString("454.3.2"), "999D9.9", - "Multiple 'D' or '.' in '999D9.9'") + invalidNumberFormat( + "999.9.9", "At most one 'D' or '.' is allowed in the number format: '999.9.9'") + invalidNumberFormat( + "999D9D9", "At most one 'D' or '.' is allowed in the number format: '999D9D9'") + invalidNumberFormat( + "999.9D9", "At most one 'D' or '.' is allowed in the number format: '999.9D9'") + invalidNumberFormat( + "999D9.9", "At most one 'D' or '.' is allowed in the number format: '999D9.9'") // Test ',' and 'G' Seq( @@ -145,9 +146,15 @@ class NumberUtilsSuite extends SparkFunSuite { (",454,367", ",999,999") -> Decimal(454367), (",454,367", ",000,000") -> Decimal(454367), (",454,367", "G999G999") -> Decimal(454367), - (",454,367", "G000G000") -> Decimal(454367) + (",454,367", "G000G000") -> Decimal(454367), + (",454,367", "999,999") -> Decimal(454367), + (",454,367", "000,000") -> Decimal(454367), + (",454,367", "999G999") -> Decimal(454367), + (",454,367", "000G000") -> Decimal(454367) ).foreach { case ((str, format), expected) => - assert(parse(UTF8String.fromString(str), format) === expected) + val builder = new TestNumberFormatter(format) + builder.check() + assert(builder.parse(UTF8String.fromString(str)) === expected) } // Test '$' @@ -157,13 +164,14 @@ class NumberUtilsSuite extends SparkFunSuite { ("78.12$", "99.99$") -> Decimal(78.12), ("78.12$", "00.00$") -> Decimal(78.12) ).foreach { case ((str, format), expected) => - assert(parse(UTF8String.fromString(str), format) === expected) + val builder = new TestNumberFormatter(format) + builder.check() + assert(builder.parse(UTF8String.fromString(str)) === expected) } - failParseWithAnalysisException(UTF8String.fromString("78$.12"), "99$.99", - "'$' must be the first or last char in '99$.99'") - failParseWithAnalysisException(UTF8String.fromString("$78.12$"), "$99.99$", - "Multiple '$' in '$99.99$'") + invalidNumberFormat( + "99$.99", "'$' must be the first or last char in the number format: '99$.99'") + invalidNumberFormat("$99.99$", "At most one '$' is allowed in the number format: '$99.99$'") // Test '-' and 'S' Seq( @@ -178,19 +186,20 @@ class NumberUtilsSuite extends SparkFunSuite { ("12,454.8-", "99G999D9S") -> Decimal(-12454.8), ("00,454.8-", "99G999.9S") -> Decimal(-454.8) ).foreach { case ((str, format), expected) => - assert(parse(UTF8String.fromString(str), format) === expected) + val builder = new TestNumberFormatter(format) + builder.check() + assert(builder.parse(UTF8String.fromString(str)) === expected) } - failParseWithAnalysisException(UTF8String.fromString("4-54"), "9S99", - "'S' or '-' must be the first or last char in '9S99'") - failParseWithAnalysisException(UTF8String.fromString("4-54"), "9-99", - "'S' or '-' must be the first or last char in '9-99'") - failParseWithAnalysisException(UTF8String.fromString("454.3--"), "999D9SS", - "Multiple 'S' or '-' in '999D9SS'") + invalidNumberFormat( + "9S99", "'S' or '-' must be the first or last char in the number format: '9S99'") + invalidNumberFormat( + "9-99", "'S' or '-' must be the first or last char in the number format: '9-99'") + invalidNumberFormat( + "999D9SS", "At most one 'S' or '-' is allowed in the number format: '999D9SS'") } test("format") { - assert(format(Decimal(454), "") === 
"") // Test '9' and '0' Seq( @@ -214,8 +223,10 @@ class NumberUtilsSuite extends SparkFunSuite { (Decimal(54), "0000") -> "0054", (Decimal(404), "0000") -> "0404", (Decimal(450), "0000") -> "0450" - ).foreach { case ((decimal, str), expected) => - assert(format(decimal, str) === expected) + ).foreach { case ((decimal, format), expected) => + val builder = new TestNumberFormatter(format, false) + builder.check() + assert(builder.format(decimal) === expected) } // Test '.' and 'D' @@ -240,19 +251,12 @@ class NumberUtilsSuite extends SparkFunSuite { (Decimal(4542), "0000.") -> "4542.", (Decimal(4542), "9999D") -> "4542.", (Decimal(4542), "0000D") -> "4542." - ).foreach { case ((decimal, str), expected) => - assert(format(decimal, str) === expected) + ).foreach { case ((decimal, format), expected) => + val builder = new TestNumberFormatter(format, false) + builder.check() + assert(builder.format(decimal) === expected) } - failFormatWithAnalysisException(Decimal(454.32), "999.9.9", - "Multiple 'D' or '.' in '999.9.9'") - failFormatWithAnalysisException(Decimal(454.32), "999D9D9", - "Multiple 'D' or '.' in '999D9D9'") - failFormatWithAnalysisException(Decimal(454.32), "999.9D9", - "Multiple 'D' or '.' in '999.9D9'") - failFormatWithAnalysisException(Decimal(454.32), "999D9.9", - "Multiple 'D' or '.' in '999D9.9'") - // Test ',' and 'G' Seq( (Decimal(12454), "99,999") -> "12,454", @@ -271,8 +275,10 @@ class NumberUtilsSuite extends SparkFunSuite { (Decimal(454367), ",000,000") -> ",454,367", (Decimal(454367), "G999G999") -> ",454,367", (Decimal(454367), "G000G000") -> ",454,367" - ).foreach { case ((decimal, str), expected) => - assert(format(decimal, str) === expected) + ).foreach { case ((decimal, format), expected) => + val builder = new TestNumberFormatter(format, false) + builder.check() + assert(builder.format(decimal) === expected) } // Test '$' @@ -281,15 +287,12 @@ class NumberUtilsSuite extends SparkFunSuite { (Decimal(78.12), "$00.00") -> "$78.12", (Decimal(78.12), "99.99$") -> "78.12$", (Decimal(78.12), "00.00$") -> "78.12$" - ).foreach { case ((decimal, str), expected) => - assert(format(decimal, str) === expected) + ).foreach { case ((decimal, format), expected) => + val builder = new TestNumberFormatter(format, false) + builder.check() + assert(builder.format(decimal) === expected) } - failFormatWithAnalysisException(Decimal(78.12), "99$.99", - "'$' must be the first or last char in '99$.99'") - failFormatWithAnalysisException(Decimal(78.12), "$99.99$", - "Multiple '$' in '$99.99$'") - // Test '-' and 'S' Seq( (Decimal(-454), "999-") -> "454-", @@ -302,16 +305,11 @@ class NumberUtilsSuite extends SparkFunSuite { (Decimal(-454), "S000") -> "-454", (Decimal(-12454.8), "99G999D9S") -> "12,454.8-", (Decimal(-454.8), "99G999.9S") -> "454.8-" - ).foreach { case ((decimal, str), expected) => - assert(format(decimal, str) === expected) + ).foreach { case ((decimal, format), expected) => + val builder = new TestNumberFormatter(format, false) + builder.check() + assert(builder.format(decimal) === expected) } - - failFormatWithAnalysisException(Decimal(-454), "9S99", - "'S' or '-' must be the first or last char in '9S99'") - failFormatWithAnalysisException(Decimal(-454), "9-99", - "'S' or '-' must be the first or last char in '9-99'") - failFormatWithAnalysisException(Decimal(-454.3), "999D9SS", - "Multiple 'S' or '-' in '999D9SS'") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala index 0cca1cc9bebf2..d00bc31e07f19 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala @@ -820,7 +820,7 @@ class CatalogSuite extends SparkFunSuite { assert(catalog.namespaceExists(testNs) === false) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === false) } @@ -833,7 +833,7 @@ class CatalogSuite extends SparkFunSuite { assert(catalog.namespaceExists(testNs) === true) assert(catalog.loadNamespaceMetadata(testNs).asScala === Map("property" -> "value")) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === true) assert(catalog.namespaceExists(testNs) === false) @@ -845,7 +845,7 @@ class CatalogSuite extends SparkFunSuite { catalog.createNamespace(testNs, Map("property" -> "value").asJava) catalog.createTable(testIdent, schema, Array.empty, emptyProps) - assert(catalog.dropNamespace(testNs)) + assert(catalog.dropNamespace(testNs, cascade = true)) assert(!catalog.namespaceExists(testNs)) intercept[NoSuchNamespaceException](catalog.listTables(testNs)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/EnumTypeSetBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/EnumTypeSetBenchmark.scala index a23ff6eaa2a55..a918bae4a8402 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/EnumTypeSetBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/EnumTypeSetBenchmark.scala @@ -106,7 +106,9 @@ object EnumTypeSetBenchmark extends BenchmarkBase { } benchmark.addCase("Use EnumSet") { _: Int => - capabilities.foreach(enumSet.contains) + for (_ <- 0L until valuesPerIteration) { + capabilities.foreach(enumSet.contains) + } } benchmark.run() } @@ -131,7 +133,9 @@ object EnumTypeSetBenchmark extends BenchmarkBase { } benchmark.addCase("Use EnumSet") { _: Int => - capabilities.foreach(creatEnumSetFunctions.apply().contains) + for (_ <- 0L until valuesPerIteration) { + capabilities.foreach(creatEnumSetFunctions.apply().contains) + } } benchmark.run() } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala index 58dc4847111e2..671d22040e169 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala @@ -43,7 +43,7 @@ class InMemoryPartitionTable( new ConcurrentHashMap[InternalRow, util.Map[String, String]]() def partitionSchema: StructType = { - val partitionColumnNames = partitioning.toSeq.asPartitionColumns + val partitionColumnNames = partitioning.toSeq.convertTransforms._1 new StructType(schema.filter(p => partitionColumnNames.contains(p.name)).toArray) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index fa8be1b8fa3c0..5d72b2060bfd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric} import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} +import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -80,6 +81,7 @@ class InMemoryTable( case _: DaysTransform => case _: HoursTransform => case _: BucketTransform => + case _: SortedBucketTransform => case t if !allowUnsupportedTransforms => throw new IllegalArgumentException(s"Transform $t is not a supported transform") } @@ -161,10 +163,15 @@ class InMemoryTable( case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } - case BucketTransform(numBuckets, ref, _) => - val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) - val valueHashCode = if (value == null) 0 else value.hashCode - ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets + case BucketTransform(numBuckets, cols, _) => + val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row)) + var valueHashCode = 0 + valueTypePairs.foreach( pair => + if ( pair._1 != null) valueHashCode += pair._1.hashCode() + ) + var dataTypeHashCode = 0 + valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode()) + ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets } } @@ -305,7 +312,9 @@ class InMemoryTable( InMemoryTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties)) InMemoryTable.maybeSimulateFailedTableWrite(info.options) - new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite { + new WriteBuilder with SupportsTruncate with SupportsOverwrite + with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend { + private var writer: BatchWrite = Append private var streamingWriter: StreamingWrite = StreamingAppend diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index d8e6bc4149d98..428aec703674d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NonEmptyNamespaceException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.types.StructType @@ -213,10 +213,16 @@ class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamesp namespaces.put(namespace.toList, CatalogV2Util.applyNamespaceChanges(metadata, changes)) } - override def dropNamespace(namespace: 
Array[String]): Boolean = { - listNamespaces(namespace).foreach(dropNamespace) + override def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = { try { - listTables(namespace).foreach(dropTable) + if (!cascade) { + if (listTables(namespace).nonEmpty || listNamespaces(namespace).nonEmpty) { + throw new NonEmptyNamespaceException(namespace) + } + } else { + listNamespaces(namespace).foreach(namespace => dropNamespace(namespace, cascade)) + listTables(namespace).foreach(dropTable) + } } catch { case _: NoSuchNamespaceException => } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index b2371ce667ffc..54ab1df3fa8f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst +import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket import org.apache.spark.sql.types.DataType class TransformExtractorSuite extends SparkFunSuite { @@ -139,9 +140,9 @@ class TransformExtractorSuite extends SparkFunSuite { } bucketTransform match { - case BucketTransform(numBuckets, FieldReference(seq), _) => + case BucketTransform(numBuckets, cols, _) => assert(numBuckets === 16) - assert(seq === Seq("a", "b")) + assert(cols(0).fieldNames === Seq("a", "b")) case _ => fail("Did not match BucketTransform extractor") } @@ -153,4 +154,61 @@ class TransformExtractorSuite extends SparkFunSuite { // expected } } + + test("Sorted Bucket extractor") { + val col = Array(ref("a"), ref("b")) + val sortedCol = Array(ref("c"), ref("d")) + + val sortedBucketTransform = new Transform { + override def name: String = "sorted_bucket" + override def references: Array[NamedReference] = col ++ sortedCol + override def arguments: Array[Expression] = (col :+ lit(16)) ++ sortedCol + override def describe: String = s"bucket(16, ${col(0).describe}, ${col(1).describe} " + + s"${sortedCol(0).describe} ${sortedCol(1).describe})" + } + + sortedBucketTransform match { + case BucketTransform(numBuckets, cols, sortCols) => + assert(numBuckets === 16) + assert(cols.flatMap(c => c.fieldNames()) === Seq("a", "b")) + assert(sortCols.flatMap(c => c.fieldNames()) === Seq("c", "d")) + case _ => + fail("Did not match BucketTransform extractor") + } + } + + test("test bucket") { + val col = Array(ref("a"), ref("b")) + val sortedCol = Array(ref("c"), ref("d")) + + val bucketTransform = bucket(16, col) + val reference1 = bucketTransform.references + assert(reference1.length == 2) + assert(reference1(0).fieldNames() === Seq("a")) + assert(reference1(1).fieldNames() === Seq("b")) + val arguments1 = bucketTransform.arguments + assert(arguments1.length == 3) + assert(arguments1(0).asInstanceOf[LiteralValue[Integer]].value === 16) + assert(arguments1(1).asInstanceOf[NamedReference].fieldNames() === Seq("a")) + assert(arguments1(2).asInstanceOf[NamedReference].fieldNames() === Seq("b")) + val copied1 = bucketTransform.withReferences(reference1) + assert(copied1.equals(bucketTransform)) + + val sortedBucketTransform = bucket(16, col, sortedCol) + val reference2 = sortedBucketTransform.references + assert(reference2.length == 4) + assert(reference2(0).fieldNames() === 
Seq("a")) + assert(reference2(1).fieldNames() === Seq("b")) + assert(reference2(2).fieldNames() === Seq("c")) + assert(reference2(3).fieldNames() === Seq("d")) + val arguments2 = sortedBucketTransform.arguments + assert(arguments2.length == 5) + assert(arguments2(0).asInstanceOf[NamedReference].fieldNames() === Seq("a")) + assert(arguments2(1).asInstanceOf[NamedReference].fieldNames() === Seq("b")) + assert(arguments2(2).asInstanceOf[LiteralValue[Integer]].value === 16) + assert(arguments2(3).asInstanceOf[NamedReference].fieldNames() === Seq("c")) + assert(arguments2(4).asInstanceOf[NamedReference].fieldNames() === Seq("d")) + val copied2 = sortedBucketTransform.withReferences(reference2) + assert(copied2.equals(sortedBucketTransform)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index a7e22e9403275..16f122334f370 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -51,7 +51,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { test("SPARK-24849: toDDL - simple struct") { val struct = StructType(Seq(StructField("a", IntegerType))) - assert(struct.toDDL == "`a` INT") + assert(struct.toDDL == "a INT") } test("SPARK-24849: round trip toDDL - fromDDL") { @@ -61,7 +61,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } test("SPARK-24849: round trip fromDDL - toDDL") { - val struct = "`a` MAP,`b` INT" + val struct = "a MAP,b INT" assert(fromDDL(struct).toDDL === struct) } @@ -70,14 +70,14 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { val struct = new StructType() .add("metaData", new StructType().add("eventId", StringType)) - assert(struct.toDDL == "`metaData` STRUCT<`eventId`: STRING>") + assert(struct.toDDL == "metaData STRUCT") } test("SPARK-24849: toDDL should output field's comment") { val struct = StructType(Seq( StructField("b", BooleanType).withComment("Field's comment"))) - assert(struct.toDDL == """`b` BOOLEAN COMMENT 'Field\'s comment'""") + assert(struct.toDDL == """b BOOLEAN COMMENT 'Field\'s comment'""") } private val nestedStruct = new StructType() @@ -89,7 +89,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { ).withComment("comment")) test("SPARK-33846: toDDL should output nested field's comment") { - val ddl = "`a` STRUCT<`b`: STRUCT<`c`: STRING COMMENT 'Deep Nested comment'> " + + val ddl = "a STRUCT " + "COMMENT 'Nested comment'> COMMENT 'comment'" assert(nestedStruct.toDDL == ddl) } @@ -153,7 +153,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } test("interval keyword in schema string") { - val interval = "`a` INTERVAL" + val interval = "a INTERVAL" assert(fromDDL(interval).toDDL === interval) } @@ -250,10 +250,10 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } test("SPARK-35285: ANSI interval types in schema") { - val yearMonthInterval = "`ymi` INTERVAL YEAR TO MONTH" + val yearMonthInterval = "ymi INTERVAL YEAR TO MONTH" assert(fromDDL(yearMonthInterval).toDDL === yearMonthInterval) - val dayTimeInterval = "`dti` INTERVAL DAY TO SECOND" + val dayTimeInterval = "dti INTERVAL DAY TO SECOND" assert(fromDDL(dayTimeInterval).toDDL === dayTimeInterval) } diff --git a/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-jdk11-results.txt b/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-jdk11-results.txt index 
d1395ef07eb0d..8ed23d4ba5c31 100644 --- a/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-jdk11-results.txt @@ -2,59 +2,69 @@ Parquet writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1027-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +Parquet(PARQUET_1_0) writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 2617 2756 197 6.0 166.4 1.0X -Output Single Double Column 2753 2782 41 5.7 175.0 1.0X -Output Int and String Column 7625 7664 54 2.1 484.8 0.3X -Output Partitions 4964 5023 84 3.2 315.6 0.5X -Output Buckets 6988 7051 88 2.3 444.3 0.4X +Output Single Int Column 2199 2291 130 7.2 139.8 1.0X +Output Single Double Column 2724 2753 40 5.8 173.2 0.8X +Output Int and String Column 6836 6998 229 2.3 434.6 0.3X +Output Partitions 4936 4970 49 3.2 313.8 0.4X +Output Buckets 6672 6708 50 2.4 424.2 0.3X + +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1027-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet(PARQUET_2_0) writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Output Single Int Column 2610 2622 17 6.0 166.0 1.0X +Output Single Double Column 2389 2425 51 6.6 151.9 1.1X +Output Int and String Column 7516 7540 35 2.1 477.9 0.3X +Output Partitions 5190 5195 8 3.0 329.9 0.5X +Output Buckets 6444 6446 1 2.4 409.7 0.4X ================================================================================================ ORC writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1027-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz ORC writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 1972 1988 23 8.0 125.4 1.0X -Output Single Double Column 2230 2312 116 7.1 141.8 0.9X -Output Int and String Column 5748 5858 156 2.7 365.4 0.3X -Output Partitions 4083 4104 30 3.9 259.6 0.5X -Output Buckets 6062 6083 29 2.6 385.4 0.3X +Output Single Int Column 1589 1624 49 9.9 101.0 1.0X +Output Single Double Column 2221 2243 32 7.1 141.2 0.7X +Output Int and String Column 5543 5640 138 2.8 352.4 0.3X +Output Partitions 4135 4284 212 3.8 262.9 0.4X +Output Buckets 6100 6234 190 2.6 387.8 0.3X ================================================================================================ JSON writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1027-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz JSON writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) 
Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 2444 2495 72 6.4 155.4 1.0X -Output Single Double Column 3384 3388 5 4.6 215.1 0.7X -Output Int and String Column 5762 5771 13 2.7 366.4 0.4X -Output Partitions 4727 4777 70 3.3 300.6 0.5X -Output Buckets 6420 6541 171 2.4 408.2 0.4X +Output Single Int Column 2475 2492 24 6.4 157.3 1.0X +Output Single Double Column 3524 3525 3 4.5 224.0 0.7X +Output Int and String Column 5480 5533 74 2.9 348.4 0.5X +Output Partitions 4735 4748 19 3.3 301.0 0.5X +Output Buckets 6251 6264 19 2.5 397.4 0.4X ================================================================================================ CSV writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.14+9-LTS on Linux 5.11.0-1027-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz CSV writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 3301 3325 34 4.8 209.8 1.0X -Output Single Double Column 3897 3923 37 4.0 247.8 0.8X -Output Int and String Column 6484 6487 4 2.4 412.3 0.5X -Output Partitions 5896 5899 5 2.7 374.8 0.6X -Output Buckets 7919 7927 12 2.0 503.5 0.4X +Output Single Int Column 3293 3301 11 4.8 209.4 1.0X +Output Single Double Column 4085 4095 14 3.9 259.7 0.8X +Output Int and String Column 6369 6375 8 2.5 404.9 0.5X +Output Partitions 6067 6090 32 2.6 385.7 0.5X +Output Buckets 7736 7863 180 2.0 491.8 0.4X diff --git a/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-jdk17-results.txt b/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-jdk17-results.txt index b4e0345fa5644..5f64bf7b624cb 100644 --- a/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-jdk17-results.txt +++ b/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-jdk17-results.txt @@ -2,59 +2,69 @@ Parquet writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz -Parquet writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1027-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +Parquet(PARQUET_1_0) writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 2948 2954 8 5.3 187.4 1.0X -Output Single Double Column 2978 3012 48 5.3 189.3 1.0X -Output Int and String Column 8568 8651 117 1.8 544.8 0.3X -Output Partitions 5196 5273 110 3.0 330.3 0.6X -Output Buckets 6761 6800 55 2.3 429.8 0.4X +Output Single Int Column 3119 3167 68 5.0 198.3 1.0X +Output Single Double Column 3156 3298 201 5.0 200.7 1.0X +Output Int and String Column 8070 8207 193 1.9 513.1 0.4X +Output Partitions 5636 5887 355 2.8 358.3 0.6X +Output Buckets 7523 7541 25 2.1 478.3 0.4X + +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1027-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +Parquet(PARQUET_2_0) writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) 
Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Output Single Int Column 3678 3787 154 4.3 233.9 1.0X +Output Single Double Column 3201 3229 39 4.9 203.5 1.1X +Output Int and String Column 8322 8333 15 1.9 529.1 0.4X +Output Partitions 6184 6202 26 2.5 393.1 0.6X +Output Buckets 7341 7406 93 2.1 466.7 0.5X ================================================================================================ ORC writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1027-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz ORC writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 2234 2244 14 7.0 142.1 1.0X -Output Single Double Column 2824 2876 73 5.6 179.6 0.8X -Output Int and String Column 7665 7753 124 2.1 487.3 0.3X -Output Partitions 4985 5004 28 3.2 316.9 0.4X -Output Buckets 6765 6814 69 2.3 430.1 0.3X +Output Single Int Column 2264 2301 53 6.9 143.9 1.0X +Output Single Double Column 2929 3092 230 5.4 186.2 0.8X +Output Int and String Column 7562 7713 212 2.1 480.8 0.3X +Output Partitions 5265 5318 74 3.0 334.8 0.4X +Output Buckets 7117 7160 61 2.2 452.5 0.3X ================================================================================================ JSON writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1027-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz JSON writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 2783 2826 61 5.7 177.0 1.0X -Output Single Double Column 3983 4009 37 3.9 253.3 0.7X -Output Int and String Column 6656 6679 32 2.4 423.2 0.4X -Output Partitions 5289 5305 22 3.0 336.3 0.5X -Output Buckets 6584 6695 156 2.4 418.6 0.4X +Output Single Int Column 2881 2964 118 5.5 183.2 1.0X +Output Single Double Column 4568 4578 14 3.4 290.4 0.6X +Output Int and String Column 6943 7078 192 2.3 441.4 0.4X +Output Partitions 5862 5883 30 2.7 372.7 0.5X +Output Buckets 7176 7297 170 2.2 456.3 0.4X ================================================================================================ CSV writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.2+8-LTS on Linux 5.11.0-1027-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz CSV writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 4271 4338 95 3.7 271.5 1.0X -Output Single Double Column 5145 5207 87 3.1 327.1 0.8X -Output Int and String Column 7573 7682 154 
2.1 481.5 0.6X -Output Partitions 6644 6675 44 2.4 422.4 0.6X -Output Buckets 8497 8539 59 1.9 540.2 0.5X +Output Single Int Column 4571 4577 8 3.4 290.6 1.0X +Output Single Double Column 5769 5794 34 2.7 366.8 0.8X +Output Int and String Column 8372 8414 59 1.9 532.3 0.5X +Output Partitions 7186 7215 41 2.2 456.9 0.6X +Output Buckets 9297 9319 31 1.7 591.1 0.5X diff --git a/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-results.txt b/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-results.txt index 442b0cce429dc..88b82991c2d16 100644 --- a/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-results.txt +++ b/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-results.txt @@ -2,59 +2,69 @@ Parquet writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1027-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +Parquet(PARQUET_1_0) writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 1968 2146 251 8.0 125.1 1.0X -Output Single Double Column 1921 2073 215 8.2 122.1 1.0X -Output Int and String Column 5630 6171 766 2.8 357.9 0.3X -Output Partitions 3699 3733 48 4.3 235.2 0.5X -Output Buckets 4705 4746 59 3.3 299.1 0.4X +Output Single Int Column 2089 2185 135 7.5 132.8 1.0X +Output Single Double Column 2156 2212 80 7.3 137.1 1.0X +Output Int and String Column 5673 5705 46 2.8 360.7 0.4X +Output Partitions 3917 4052 192 4.0 249.0 0.5X +Output Buckets 4782 5108 461 3.3 304.0 0.4X + +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1027-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet(PARQUET_2_0) writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Output Single Int Column 2201 2208 10 7.1 139.9 1.0X +Output Single Double Column 2057 2066 13 7.6 130.8 1.1X +Output Int and String Column 5969 6011 60 2.6 379.5 0.4X +Output Partitions 3777 3823 65 4.2 240.1 0.6X +Output Buckets 4889 4895 8 3.2 310.8 0.5X ================================================================================================ ORC writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1027-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz ORC writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 1507 1546 54 10.4 95.8 1.0X -Output Single Double Column 1641 1650 12 9.6 104.4 0.9X -Output Int and String Column 5671 5738 95 2.8 360.6 0.3X -Output Partitions 3068 3112 63 5.1 195.0 0.5X -Output Buckets 4635 4894 366 3.4 294.7 0.3X +Output Single Int Column 1634 1645 16 9.6 103.9 1.0X +Output Single Double Column 1680 1691 15 9.4 106.8 1.0X +Output Int and String Column 5603 5611 11 2.8 356.3 0.3X +Output Partitions 3091 3116 36 5.1 196.5 0.5X 
+Output Buckets 4472 4734 372 3.5 284.3 0.4X ================================================================================================ JSON writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1027-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz JSON writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 2206 2243 51 7.1 140.3 1.0X -Output Single Double Column 2868 2876 11 5.5 182.3 0.8X -Output Int and String Column 6017 6140 175 2.6 382.5 0.4X -Output Partitions 3602 3602 0 4.4 229.0 0.6X -Output Buckets 5308 5340 46 3.0 337.5 0.4X +Output Single Int Column 2359 2380 29 6.7 150.0 1.0X +Output Single Double Column 2971 2991 29 5.3 188.9 0.8X +Output Int and String Column 6070 6244 246 2.6 385.9 0.4X +Output Partitions 3635 3686 73 4.3 231.1 0.6X +Output Buckets 5066 5082 22 3.1 322.1 0.5X ================================================================================================ CSV writer benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 1.8.0_322-b06 on Linux 5.11.0-1027-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz CSV writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 3136 3137 2 5.0 199.4 1.0X -Output Single Double Column 3504 3505 2 4.5 222.8 0.9X -Output Int and String Column 7075 7473 562 2.2 449.8 0.4X -Output Partitions 5067 5228 227 3.1 322.2 0.6X -Output Buckets 6695 6718 33 2.3 425.7 0.5X +Output Single Int Column 3116 3117 2 5.0 198.1 1.0X +Output Single Double Column 3575 3695 170 4.4 227.3 0.9X +Output Int and String Column 7040 7482 626 2.2 447.6 0.4X +Output Partitions 4819 4995 249 3.3 306.4 0.6X +Output Buckets 6638 6656 25 2.4 422.0 0.5X diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt index fb152e20c9449..25c43d8273df8 100644 --- a/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt @@ -2,269 +2,322 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 11834 11929 134 1.3 752.4 1.0X -SQL Json 8574 8597 32 1.8 545.1 1.4X -SQL Parquet Vectorized 116 136 17 135.5 7.4 102.0X -SQL Parquet MR 1703 1715 17 9.2 108.2 7.0X -SQL ORC Vectorized 172 215 48 91.2 11.0 68.6X -SQL ORC MR 1819 1825 8 8.6 115.7 6.5X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 9636 9771 191 1.6 612.6 1.0X +SQL Json 7960 
8227 378 2.0 506.1 1.2X +SQL Parquet Vectorized: DataPageV1 113 129 12 139.7 7.2 85.6X +SQL Parquet Vectorized: DataPageV2 84 93 12 186.6 5.4 114.3X +SQL Parquet MR: DataPageV1 1466 1470 6 10.7 93.2 6.6X +SQL Parquet MR: DataPageV2 1334 1347 18 11.8 84.8 7.2X +SQL ORC Vectorized 163 197 27 96.3 10.4 59.0X +SQL ORC MR 1554 1558 6 10.1 98.8 6.2X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 117 126 17 134.9 7.4 1.0X -ParquetReader Vectorized -> Row 47 49 3 336.5 3.0 2.5X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +Parquet Reader Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 94 103 13 167.1 6.0 1.0X +ParquetReader Vectorized: DataPageV2 77 86 11 204.3 4.9 1.2X +ParquetReader Vectorized -> Row: DataPageV1 44 47 4 357.0 2.8 2.1X +ParquetReader Vectorized -> Row: DataPageV2 35 37 3 445.2 2.2 2.7X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 13434 13590 220 1.2 854.1 1.0X -SQL Json 10056 10073 24 1.6 639.3 1.3X -SQL Parquet Vectorized 212 229 12 74.3 13.5 63.4X -SQL Parquet MR 1883 1916 47 8.4 119.7 7.1X -SQL ORC Vectorized 200 241 30 78.8 12.7 67.3X -SQL ORC MR 1529 1549 28 10.3 97.2 8.8X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 11479 11919 622 1.4 729.8 1.0X +SQL Json 9894 9922 39 1.6 629.1 1.2X +SQL Parquet Vectorized: DataPageV1 123 156 30 128.3 7.8 93.6X +SQL Parquet Vectorized: DataPageV2 126 138 19 125.2 8.0 91.4X +SQL Parquet MR: DataPageV1 1986 2500 726 7.9 126.3 5.8X +SQL Parquet MR: DataPageV2 1810 1898 126 8.7 115.1 6.3X +SQL ORC Vectorized 174 210 30 90.5 11.0 66.1X +SQL ORC MR 1645 1652 9 9.6 104.6 7.0X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 229 254 13 68.6 14.6 1.0X -ParquetReader Vectorized -> Row 162 171 14 96.9 10.3 1.4X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +Parquet Reader Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 166 177 14 94.9 10.5 1.0X +ParquetReader Vectorized: DataPageV2 165 172 11 95.3 10.5 1.0X +ParquetReader Vectorized -> Row: DataPageV1 95 100 5 165.7 6.0 1.7X +ParquetReader Vectorized -> Row: DataPageV2 85 89 6 186.0 5.4 2.0X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) 
Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 14320 14476 221 1.1 910.4 1.0X -SQL Json 9769 10067 423 1.6 621.1 1.5X -SQL Parquet Vectorized 187 228 28 84.3 11.9 76.8X -SQL Parquet MR 2230 2240 14 7.1 141.8 6.4X -SQL ORC Vectorized 221 265 36 71.1 14.1 64.8X -SQL ORC MR 1763 1779 23 8.9 112.1 8.1X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 12176 12646 664 1.3 774.1 1.0X +SQL Json 9696 9729 46 1.6 616.5 1.3X +SQL Parquet Vectorized: DataPageV1 151 201 33 103.9 9.6 80.4X +SQL Parquet Vectorized: DataPageV2 216 235 15 72.7 13.8 56.3X +SQL Parquet MR: DataPageV1 1915 2017 145 8.2 121.8 6.4X +SQL Parquet MR: DataPageV2 1954 1978 33 8.0 124.3 6.2X +SQL ORC Vectorized 197 235 25 79.7 12.6 61.7X +SQL ORC MR 1769 1829 85 8.9 112.5 6.9X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 246 255 12 64.1 15.6 1.0X -ParquetReader Vectorized -> Row 249 294 21 63.1 15.8 1.0X +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 230 237 12 68.5 14.6 1.0X +ParquetReader Vectorized: DataPageV2 293 298 9 53.6 18.7 0.8X +ParquetReader Vectorized -> Row: DataPageV1 215 265 23 73.2 13.7 1.1X +ParquetReader Vectorized -> Row: DataPageV2 279 301 32 56.3 17.8 0.8X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 15460 15543 116 1.0 982.9 1.0X -SQL Json 10199 10393 274 1.5 648.4 1.5X -SQL Parquet Vectorized 163 203 30 96.5 10.4 94.8X -SQL Parquet MR 1914 2025 157 8.2 121.7 8.1X -SQL ORC Vectorized 324 355 23 48.5 20.6 47.7X -SQL ORC MR 1673 1701 39 9.4 106.4 9.2X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 13069 13409 482 1.2 830.9 1.0X +SQL Json 10599 10621 32 1.5 673.9 1.2X +SQL Parquet Vectorized: DataPageV1 142 177 34 110.6 9.0 91.9X +SQL Parquet Vectorized: DataPageV2 313 359 28 50.2 19.9 41.7X +SQL Parquet MR: DataPageV1 1979 2044 92 7.9 125.8 6.6X +SQL Parquet MR: DataPageV2 1958 2030 101 8.0 124.5 6.7X +SQL ORC Vectorized 277 303 21 56.7 17.6 47.1X +SQL ORC MR 1692 1782 128 9.3 107.6 7.7X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 209 223 17 75.2 13.3 1.0X -ParquetReader Vectorized -> Row 303 307 6 51.9 19.3 0.7X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +Parquet Reader Single INT Column Scan: 
Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 253 269 18 62.1 16.1 1.0X +ParquetReader Vectorized: DataPageV2 1197 1199 3 13.1 76.1 0.2X +ParquetReader Vectorized -> Row: DataPageV1 273 361 110 57.7 17.3 0.9X +ParquetReader Vectorized -> Row: DataPageV2 379 438 37 41.5 24.1 0.7X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 19075 19147 101 0.8 1212.8 1.0X -SQL Json 12181 12369 265 1.3 774.5 1.6X -SQL Parquet Vectorized 230 268 25 68.5 14.6 83.1X -SQL Parquet MR 2160 2244 120 7.3 137.3 8.8X -SQL ORC Vectorized 396 444 41 39.7 25.2 48.2X -SQL ORC MR 1924 1939 21 8.2 122.3 9.9X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 17143 17467 458 0.9 1089.9 1.0X +SQL Json 11507 12198 977 1.4 731.6 1.5X +SQL Parquet Vectorized: DataPageV1 238 253 19 66.0 15.2 71.9X +SQL Parquet Vectorized: DataPageV2 502 567 48 31.3 31.9 34.1X +SQL Parquet MR: DataPageV1 2333 2335 3 6.7 148.4 7.3X +SQL Parquet MR: DataPageV2 1948 1972 34 8.1 123.8 8.8X +SQL ORC Vectorized 389 408 20 40.5 24.7 44.1X +SQL ORC MR 1726 1817 128 9.1 109.7 9.9X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 273 311 43 57.5 17.4 1.0X -ParquetReader Vectorized -> Row 316 322 8 49.8 20.1 0.9X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +Parquet Reader Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 289 340 43 54.4 18.4 1.0X +ParquetReader Vectorized: DataPageV2 572 609 27 27.5 36.4 0.5X +ParquetReader Vectorized -> Row: DataPageV1 329 353 48 47.8 20.9 0.9X +ParquetReader Vectorized -> Row: DataPageV2 639 654 18 24.6 40.6 0.5X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 15439 15605 235 1.0 981.6 1.0X -SQL Json 11709 11852 201 1.3 744.5 1.3X -SQL Parquet Vectorized 157 199 33 99.9 10.0 98.0X -SQL Parquet MR 1996 2120 176 7.9 126.9 7.7X -SQL ORC Vectorized 439 466 28 35.8 27.9 35.1X -SQL ORC MR 1965 1991 36 8.0 124.9 7.9X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 13721 13812 129 1.1 872.4 1.0X +SQL Json 12147 17632 2196 1.3 772.3 1.1X +SQL Parquet Vectorized: DataPageV1 138 164 25 113.9 8.8 99.4X +SQL Parquet Vectorized: DataPageV2 151 180 26 104.4 9.6 91.1X +SQL Parquet MR: DataPageV1 2006 2078 101 7.8 127.6 6.8X +SQL Parquet MR: DataPageV2 2038 2040 2 7.7 
129.6 6.7X +SQL ORC Vectorized 465 475 10 33.8 29.6 29.5X +SQL ORC MR 1814 1860 64 8.7 115.4 7.6X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 206 212 8 76.4 13.1 1.0X -ParquetReader Vectorized -> Row 220 266 29 71.4 14.0 0.9X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +Parquet Reader Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 275 404 187 57.2 17.5 1.0X +ParquetReader Vectorized: DataPageV2 275 287 12 57.2 17.5 1.0X +ParquetReader Vectorized -> Row: DataPageV1 227 265 24 69.2 14.4 1.2X +ParquetReader Vectorized -> Row: DataPageV2 228 259 28 69.1 14.5 1.2X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 20048 20816 1086 0.8 1274.6 1.0X -SQL Json 16265 16314 69 1.0 1034.1 1.2X -SQL Parquet Vectorized 238 296 29 66.1 15.1 84.3X -SQL Parquet MR 2414 2418 7 6.5 153.5 8.3X -SQL ORC Vectorized 555 604 38 28.4 35.3 36.2X -SQL ORC MR 2225 2242 24 7.1 141.5 9.0X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 17269 17620 496 0.9 1097.9 1.0X +SQL Json 15636 15952 447 1.0 994.1 1.1X +SQL Parquet Vectorized: DataPageV1 238 267 18 66.0 15.1 72.5X +SQL Parquet Vectorized: DataPageV2 222 260 21 70.9 14.1 77.9X +SQL Parquet MR: DataPageV1 2418 2457 56 6.5 153.7 7.1X +SQL Parquet MR: DataPageV2 2194 2207 18 7.2 139.5 7.9X +SQL ORC Vectorized 519 528 14 30.3 33.0 33.3X +SQL ORC MR 1760 1770 14 8.9 111.9 9.8X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 317 352 35 49.6 20.2 1.0X -ParquetReader Vectorized -> Row 346 356 9 45.4 22.0 0.9X +Parquet Reader Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 284 305 30 55.3 18.1 1.0X +ParquetReader Vectorized: DataPageV2 286 286 1 55.1 18.2 1.0X +ParquetReader Vectorized -> Row: DataPageV1 325 337 16 48.4 20.6 0.9X +ParquetReader Vectorized -> Row: DataPageV2 346 361 16 45.5 22.0 0.8X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL 
CPU @ 2.60GHz Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 13981 14223 342 0.7 1333.4 1.0X -SQL Json 11241 11293 74 0.9 1072.0 1.2X -SQL Parquet Vectorized 2060 2076 23 5.1 196.4 6.8X -SQL Parquet MR 3779 3931 216 2.8 360.4 3.7X -SQL ORC Vectorized 2085 2088 4 5.0 198.8 6.7X -SQL ORC MR 3739 3767 39 2.8 356.6 3.7X +SQL CSV 12428 12714 405 0.8 1185.2 1.0X +SQL Json 11088 11251 231 0.9 1057.4 1.1X +SQL Parquet Vectorized: DataPageV1 1990 1997 10 5.3 189.8 6.2X +SQL Parquet Vectorized: DataPageV2 2551 2618 95 4.1 243.3 4.9X +SQL Parquet MR: DataPageV1 3903 3913 15 2.7 372.2 3.2X +SQL Parquet MR: DataPageV2 3734 3920 263 2.8 356.1 3.3X +SQL ORC Vectorized 2153 2155 3 4.9 205.3 5.8X +SQL ORC MR 3485 3549 91 3.0 332.4 3.6X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 8544 8579 50 1.2 814.8 1.0X -SQL Json 6705 6952 348 1.6 639.5 1.3X -SQL Parquet Vectorized 603 615 9 17.4 57.5 14.2X -SQL Parquet MR 1722 1725 4 6.1 164.2 5.0X -SQL ORC Vectorized 515 547 24 20.4 49.1 16.6X -SQL ORC MR 1827 1845 25 5.7 174.2 4.7X +SQL CSV 7116 7167 72 1.5 678.7 1.0X +SQL Json 6700 6741 58 1.6 639.0 1.1X +SQL Parquet Vectorized: DataPageV1 526 556 36 19.9 50.1 13.5X +SQL Parquet Vectorized: DataPageV2 518 533 15 20.2 49.4 13.7X +SQL Parquet MR: DataPageV1 1504 1656 216 7.0 143.4 4.7X +SQL Parquet MR: DataPageV2 1676 1676 1 6.3 159.8 4.2X +SQL ORC Vectorized 497 518 20 21.1 47.4 14.3X +SQL ORC MR 1657 1787 183 6.3 158.1 4.3X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Data column - CSV 18854 19521 943 0.8 1198.7 1.0X -Data column - Json 12579 12688 154 1.3 799.8 1.5X -Data column - Parquet Vectorized 246 298 28 63.9 15.7 76.5X -Data column - Parquet MR 2693 2699 9 5.8 171.2 7.0X -Data column - ORC Vectorized 434 463 25 36.2 27.6 43.4X -Data column - ORC MR 2249 2303 77 7.0 143.0 8.4X -Partition column - CSV 6045 6199 217 2.6 384.3 3.1X -Partition column - Json 9463 9679 305 1.7 601.7 2.0X -Partition column - Parquet Vectorized 64 92 36 244.3 4.1 292.9X -Partition column - Parquet MR 1238 1252 20 12.7 78.7 15.2X -Partition column - ORC Vectorized 60 85 25 263.7 3.8 316.1X -Partition column - ORC MR 1440 1458 26 10.9 91.5 13.1X -Both columns - CSV 19647 20381 1038 0.8 1249.1 1.0X -Both columns - Json 12615 12654 55 1.2 802.0 1.5X -Both 
columns - Parquet Vectorized 337 345 9 46.7 21.4 56.0X -Both columns - Parquet MR 2461 2573 158 6.4 156.5 7.7X -Both columns - ORC Vectorized 432 470 54 36.4 27.5 43.6X -Both columns - ORC MR 2507 2536 40 6.3 159.4 7.5X +Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------------- +Data column - CSV 18247 18411 232 0.9 1160.1 1.0X +Data column - Json 10860 11264 571 1.4 690.5 1.7X +Data column - Parquet Vectorized: DataPageV1 223 274 26 70.6 14.2 81.9X +Data column - Parquet Vectorized: DataPageV2 537 559 23 29.3 34.1 34.0X +Data column - Parquet MR: DataPageV1 2411 2517 150 6.5 153.3 7.6X +Data column - Parquet MR: DataPageV2 2299 2356 81 6.8 146.2 7.9X +Data column - ORC Vectorized 417 433 11 37.7 26.5 43.8X +Data column - ORC MR 2107 2178 101 7.5 134.0 8.7X +Partition column - CSV 6090 6186 136 2.6 387.2 3.0X +Partition column - Json 9479 9603 176 1.7 602.7 1.9X +Partition column - Parquet Vectorized: DataPageV1 49 69 28 322.0 3.1 373.6X +Partition column - Parquet Vectorized: DataPageV2 49 63 23 322.1 3.1 373.7X +Partition column - Parquet MR: DataPageV1 1200 1225 36 13.1 76.3 15.2X +Partition column - Parquet MR: DataPageV2 1199 1240 57 13.1 76.3 15.2X +Partition column - ORC Vectorized 53 77 26 295.0 3.4 342.2X +Partition column - ORC MR 1287 1346 83 12.2 81.8 14.2X +Both columns - CSV 17671 18140 663 0.9 1123.5 1.0X +Both columns - Json 11675 12167 696 1.3 742.3 1.6X +Both columns - Parquet Vectorized: DataPageV1 298 303 9 52.9 18.9 61.3X +Both columns - Parquet Vectorized: DataPageV2 541 580 36 29.1 34.4 33.7X +Both columns - Parquet MR: DataPageV1 2448 2491 60 6.4 155.6 7.5X +Both columns - Parquet MR: DataPageV2 2303 2352 69 6.8 146.4 7.9X +Both columns - ORC Vectorized 385 406 25 40.9 24.5 47.4X +Both columns - ORC MR 2118 2202 120 7.4 134.6 8.6X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 10199 10226 38 1.0 972.6 1.0X -SQL Json 10744 10925 256 1.0 1024.6 0.9X -SQL Parquet Vectorized 1251 1261 15 8.4 119.3 8.2X -SQL Parquet MR 3306 3315 13 3.2 315.3 3.1X -ParquetReader Vectorized 849 904 48 12.4 80.9 12.0X -SQL ORC Vectorized 1184 1204 28 8.9 112.9 8.6X -SQL ORC MR 2895 2945 71 3.6 276.1 3.5X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 7966 12723 2892 1.3 759.7 1.0X +SQL Json 9897 10008 157 1.1 943.9 0.8X +SQL Parquet Vectorized: DataPageV1 1176 1264 125 8.9 112.1 6.8X +SQL Parquet Vectorized: DataPageV2 2224 2326 144 4.7 212.1 3.6X +SQL Parquet MR: DataPageV1 3431 3483 73 3.1 327.2 2.3X +SQL Parquet MR: DataPageV2 3845 4043 280 2.7 366.7 2.1X +ParquetReader Vectorized: DataPageV1 1055 1056 2 9.9 100.6 7.6X +ParquetReader Vectorized: DataPageV2 2093 2119 37 5.0 199.6 3.8X +SQL ORC Vectorized 1129 1217 125 9.3 107.7 7.1X +SQL ORC MR 2931 2982 72 3.6 279.5 2.7X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 
5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 7949 8052 145 1.3 758.1 1.0X -SQL Json 7750 7868 167 1.4 739.1 1.0X -SQL Parquet Vectorized 949 976 24 11.0 90.5 8.4X -SQL Parquet MR 2700 2722 31 3.9 257.5 2.9X -ParquetReader Vectorized 916 940 31 11.4 87.3 8.7X -SQL ORC Vectorized 1240 1249 13 8.5 118.2 6.4X -SQL ORC MR 2856 2929 103 3.7 272.4 2.8X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 6338 6508 240 1.7 604.4 1.0X +SQL Json 7149 7247 138 1.5 681.8 0.9X +SQL Parquet Vectorized: DataPageV1 937 984 45 11.2 89.3 6.8X +SQL Parquet Vectorized: DataPageV2 1582 1608 37 6.6 150.9 4.0X +SQL Parquet MR: DataPageV1 2525 2721 277 4.2 240.8 2.5X +SQL Parquet MR: DataPageV2 2969 2974 7 3.5 283.1 2.1X +ParquetReader Vectorized: DataPageV1 933 940 12 11.2 88.9 6.8X +ParquetReader Vectorized: DataPageV2 1535 1549 20 6.8 146.4 4.1X +SQL ORC Vectorized 1144 1204 86 9.2 109.1 5.5X +SQL ORC MR 2816 2822 8 3.7 268.6 2.3X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 5416 5542 179 1.9 516.5 1.0X -SQL Json 4760 4980 311 2.2 454.0 1.1X -SQL Parquet Vectorized 222 236 8 47.2 21.2 24.4X -SQL Parquet MR 1669 1685 22 6.3 159.2 3.2X -ParquetReader Vectorized 248 252 3 42.3 23.6 21.9X -SQL ORC Vectorized 409 472 81 25.6 39.0 13.2X -SQL ORC MR 1686 1687 0 6.2 160.8 3.2X +SQL CSV 4443 4504 86 2.4 423.7 1.0X +SQL Json 4528 4563 49 2.3 431.8 1.0X +SQL Parquet Vectorized: DataPageV1 213 233 15 49.2 20.3 20.8X +SQL Parquet Vectorized: DataPageV2 267 294 22 39.3 25.4 16.7X +SQL Parquet MR: DataPageV1 1691 1700 13 6.2 161.2 2.6X +SQL Parquet MR: DataPageV2 1515 1565 70 6.9 144.5 2.9X +ParquetReader Vectorized: DataPageV1 228 231 2 46.0 21.7 19.5X +ParquetReader Vectorized: DataPageV2 285 296 9 36.8 27.1 15.6X +SQL ORC Vectorized 369 425 82 28.4 35.2 12.1X +SQL ORC MR 1457 1463 9 7.2 138.9 3.0X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 2244 2282 53 0.5 2140.4 1.0X -SQL Json 3015 3099 119 0.3 2875.6 0.7X -SQL Parquet Vectorized 50 77 29 20.9 47.9 44.7X -SQL Parquet MR 190 209 27 5.5 180.8 11.8X -SQL ORC Vectorized 57 76 20 18.5 54.0 39.6X -SQL ORC MR 158 195 40 6.6 151.0 14.2X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 2374 2377 5 0.4 2264.2 1.0X +SQL Json 2693 2726 46 0.4 2568.5 0.9X +SQL Parquet Vectorized: DataPageV1 44 62 16 23.8 42.0 54.0X +SQL Parquet Vectorized: DataPageV2 63 81 21 16.5 60.5 37.5X +SQL Parquet MR: 
DataPageV1 173 198 27 6.1 164.6 13.8X +SQL Parquet MR: DataPageV2 161 193 30 6.5 153.5 14.8X +SQL ORC Vectorized 53 71 18 19.9 50.2 45.1X +SQL ORC MR 149 182 34 7.0 142.3 15.9X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 50 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 5114 5296 257 0.2 4876.7 1.0X -SQL Json 11564 11828 373 0.1 11028.4 0.4X -SQL Parquet Vectorized 60 93 26 17.3 57.6 84.6X -SQL Parquet MR 198 232 31 5.3 188.9 25.8X -SQL ORC Vectorized 69 103 35 15.2 65.9 74.0X -SQL ORC MR 175 212 36 6.0 166.9 29.2X - -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +SQL CSV 5149 5193 62 0.2 4910.9 1.0X +SQL Json 10556 10891 475 0.1 10066.5 0.5X +SQL Parquet Vectorized: DataPageV1 64 96 28 16.3 61.3 80.1X +SQL Parquet Vectorized: DataPageV2 83 106 22 12.6 79.1 62.0X +SQL Parquet MR: DataPageV1 196 232 25 5.3 187.4 26.2X +SQL Parquet MR: DataPageV2 184 221 28 5.7 175.1 28.0X +SQL ORC Vectorized 74 98 31 14.1 70.8 69.3X +SQL ORC MR 182 214 38 5.8 173.9 28.2X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 9072 9324 357 0.1 8651.4 1.0X -SQL Json 23444 23735 411 0.0 22358.1 0.4X -SQL Parquet Vectorized 91 129 28 11.5 86.7 99.8X -SQL Parquet MR 220 270 56 4.8 209.6 41.3X -SQL ORC Vectorized 96 110 20 10.9 91.8 94.2X -SQL ORC MR 216 240 33 4.8 206.2 41.9X +SQL CSV 9077 9107 43 0.1 8656.2 1.0X +SQL Json 20131 20886 1067 0.1 19198.5 0.5X +SQL Parquet Vectorized: DataPageV1 93 124 26 11.3 88.8 97.5X +SQL Parquet Vectorized: DataPageV2 103 128 29 10.2 98.5 87.9X +SQL Parquet MR: DataPageV1 218 257 35 4.8 207.6 41.7X +SQL Parquet MR: DataPageV2 213 255 29 4.9 202.7 42.7X +SQL ORC Vectorized 80 95 20 13.0 76.6 112.9X +SQL ORC MR 187 207 20 5.6 178.0 48.6X diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-jdk17-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-jdk17-results.txt index 85d506ec3454e..ecba57c0c3cc3 100644 --- a/sql/core/benchmarks/DataSourceReadBenchmark-jdk17-results.txt +++ b/sql/core/benchmarks/DataSourceReadBenchmark-jdk17-results.txt @@ -2,269 +2,322 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz SQL Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 11737 11812 106 1.3 746.2 1.0X -SQL Json 7827 7904 109 2.0 497.6 1.5X -SQL Parquet Vectorized 98 116 12 160.6 6.2 119.8X -SQL Parquet MR 1529 1541 18 10.3 97.2 7.7X -SQL ORC Vectorized 165 185 14 95.5 10.5 71.2X -SQL ORC MR 1433 1440 9 11.0 91.1 8.2X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader 
Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 38 40 3 416.2 2.4 1.0X -ParquetReader Vectorized -> Row 38 39 3 419.1 2.4 1.0X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL CSV 15972 16369 561 1.0 1015.5 1.0X +SQL Json 9543 9580 54 1.6 606.7 1.7X +SQL Parquet Vectorized: DataPageV1 115 144 19 136.3 7.3 138.4X +SQL Parquet Vectorized: DataPageV2 95 109 15 165.1 6.1 167.6X +SQL Parquet MR: DataPageV1 2098 2119 30 7.5 133.4 7.6X +SQL Parquet MR: DataPageV2 2007 2012 6 7.8 127.6 8.0X +SQL ORC Vectorized 211 225 16 74.5 13.4 75.7X +SQL ORC MR 2077 2103 36 7.6 132.1 7.7X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +Parquet Reader Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 43 47 2 369.4 2.7 1.0X +ParquetReader Vectorized: DataPageV2 30 34 2 518.5 1.9 1.4X +ParquetReader Vectorized -> Row: DataPageV1 47 50 2 333.6 3.0 0.9X +ParquetReader Vectorized -> Row: DataPageV2 31 35 2 504.8 2.0 1.4X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 13156 13192 51 1.2 836.4 1.0X -SQL Json 8690 8784 133 1.8 552.5 1.5X -SQL Parquet Vectorized 196 207 8 80.4 12.4 67.2X -SQL Parquet MR 1831 1834 4 8.6 116.4 7.2X -SQL ORC Vectorized 157 167 7 100.2 10.0 83.8X -SQL ORC MR 1381 1387 8 11.4 87.8 9.5X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 147 153 6 107.0 9.3 1.0X -ParquetReader Vectorized -> Row 149 162 24 105.7 9.5 1.0X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL CSV 17468 17543 105 0.9 1110.6 1.0X +SQL Json 11059 11065 8 1.4 703.1 1.6X +SQL Parquet Vectorized: DataPageV1 128 142 15 123.1 8.1 136.7X +SQL Parquet Vectorized: DataPageV2 126 141 8 125.2 8.0 139.1X +SQL Parquet MR: DataPageV1 2305 2331 36 6.8 146.5 7.6X +SQL Parquet MR: DataPageV2 2075 2095 28 7.6 131.9 8.4X +SQL ORC Vectorized 172 191 16 91.5 10.9 101.6X +SQL ORC MR 1777 1796 26 8.8 113.0 9.8X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +Parquet Reader Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 72 77 5 219.4 4.6 1.0X +ParquetReader Vectorized: DataPageV2 72 77 3 217.9 4.6 1.0X +ParquetReader Vectorized -> Row: DataPageV1 76 83 6 
206.6 4.8 0.9X +ParquetReader Vectorized -> Row: DataPageV2 75 80 3 210.3 4.8 1.0X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 14024 14291 378 1.1 891.6 1.0X -SQL Json 9777 9849 102 1.6 621.6 1.4X -SQL Parquet Vectorized 153 175 18 102.9 9.7 91.8X -SQL Parquet MR 1971 1979 11 8.0 125.3 7.1X -SQL ORC Vectorized 193 211 15 81.4 12.3 72.5X -SQL ORC MR 1665 1693 39 9.4 105.9 8.4X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL CSV 18330 18332 3 0.9 1165.4 1.0X +SQL Json 11383 11429 66 1.4 723.7 1.6X +SQL Parquet Vectorized: DataPageV1 179 197 13 88.0 11.4 102.5X +SQL Parquet Vectorized: DataPageV2 239 263 18 65.7 15.2 76.6X +SQL Parquet MR: DataPageV1 2552 2567 21 6.2 162.3 7.2X +SQL Parquet MR: DataPageV2 2389 2436 67 6.6 151.9 7.7X +SQL ORC Vectorized 246 263 14 64.0 15.6 74.6X +SQL ORC MR 1965 2002 52 8.0 124.9 9.3X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Parquet Reader Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 217 227 7 72.6 13.8 1.0X -ParquetReader Vectorized -> Row 214 216 2 73.5 13.6 1.0X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 253 263 11 62.2 16.1 1.0X +ParquetReader Vectorized: DataPageV2 306 317 7 51.4 19.4 0.8X +ParquetReader Vectorized -> Row: DataPageV1 246 250 4 64.0 15.6 1.0X +ParquetReader Vectorized -> Row: DataPageV2 316 321 4 49.8 20.1 0.8X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 15107 15205 139 1.0 960.5 1.0X -SQL Json 9699 9773 104 1.6 616.7 1.6X -SQL Parquet Vectorized 144 160 24 109.6 9.1 105.2X -SQL Parquet MR 1903 1906 4 8.3 121.0 7.9X -SQL ORC Vectorized 227 234 6 69.4 14.4 66.6X -SQL ORC MR 1566 1578 17 10.0 99.5 9.6X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 209 214 4 75.2 13.3 1.0X -ParquetReader Vectorized -> Row 192 194 2 81.9 12.2 1.1X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL CSV 19573 19822 352 0.8 1244.4 1.0X +SQL Json 12141 12217 107 1.3 771.9 1.6X +SQL Parquet Vectorized: DataPageV1 192 222 28 81.8 12.2 101.8X +SQL Parquet Vectorized: DataPageV2 345 373 24 45.6 21.9 56.7X +SQL Parquet MR: 
DataPageV1 2736 2741 7 5.7 173.9 7.2X +SQL Parquet MR: DataPageV2 2467 2536 97 6.4 156.9 7.9X +SQL ORC Vectorized 332 356 20 47.4 21.1 59.0X +SQL ORC MR 2188 2193 7 7.2 139.1 8.9X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +Parquet Reader Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 291 295 4 54.1 18.5 1.0X +ParquetReader Vectorized: DataPageV2 493 518 39 31.9 31.3 0.6X +ParquetReader Vectorized -> Row: DataPageV1 300 306 8 52.5 19.1 1.0X +ParquetReader Vectorized -> Row: DataPageV2 471 483 11 33.4 30.0 0.6X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 19711 19743 44 0.8 1253.2 1.0X -SQL Json 11459 11500 59 1.4 728.5 1.7X -SQL Parquet Vectorized 202 210 7 77.9 12.8 97.6X -SQL Parquet MR 2093 2120 37 7.5 133.1 9.4X -SQL ORC Vectorized 356 384 22 44.2 22.6 55.4X -SQL ORC MR 1832 1844 17 8.6 116.4 10.8X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 290 290 0 54.3 18.4 1.0X -ParquetReader Vectorized -> Row 308 314 8 51.1 19.6 0.9X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL CSV 24692 24718 37 0.6 1569.9 1.0X +SQL Json 14839 14875 50 1.1 943.5 1.7X +SQL Parquet Vectorized: DataPageV1 295 316 29 53.3 18.7 83.7X +SQL Parquet Vectorized: DataPageV2 477 505 24 32.9 30.4 51.7X +SQL Parquet MR: DataPageV1 2841 2981 197 5.5 180.6 8.7X +SQL Parquet MR: DataPageV2 2616 2632 23 6.0 166.3 9.4X +SQL ORC Vectorized 388 403 11 40.5 24.7 63.6X +SQL ORC MR 2274 2372 138 6.9 144.6 10.9X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +Parquet Reader Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 376 387 9 41.9 23.9 1.0X +ParquetReader Vectorized: DataPageV2 585 591 6 26.9 37.2 0.6X +ParquetReader Vectorized -> Row: DataPageV1 377 387 9 41.8 23.9 1.0X +ParquetReader Vectorized -> Row: DataPageV2 576 586 10 27.3 36.6 0.7X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 16396 16602 292 1.0 1042.4 1.0X -SQL Json 11284 11591 433 1.4 717.4 1.5X -SQL Parquet Vectorized 137 168 14 114.7 8.7 119.6X -SQL Parquet MR 1901 1907 8 8.3 120.9 8.6X -SQL ORC Vectorized 429 447 12 36.6 27.3 38.2X -SQL ORC MR 
1769 1841 102 8.9 112.4 9.3X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 234 253 10 67.2 14.9 1.0X -ParquetReader Vectorized -> Row 214 238 15 73.5 13.6 1.1X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL CSV 20566 20651 119 0.8 1307.6 1.0X +SQL Json 14337 14409 101 1.1 911.5 1.4X +SQL Parquet Vectorized: DataPageV1 154 167 8 101.9 9.8 133.2X +SQL Parquet Vectorized: DataPageV2 157 178 14 99.9 10.0 130.6X +SQL Parquet MR: DataPageV1 2730 2730 1 5.8 173.5 7.5X +SQL Parquet MR: DataPageV2 2459 2491 45 6.4 156.3 8.4X +SQL ORC Vectorized 479 501 15 32.9 30.4 43.0X +SQL ORC MR 2293 2343 71 6.9 145.8 9.0X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +Parquet Reader Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 272 283 9 57.9 17.3 1.0X +ParquetReader Vectorized: DataPageV2 250 288 27 62.9 15.9 1.1X +ParquetReader Vectorized -> Row: DataPageV1 291 301 6 54.1 18.5 0.9X +ParquetReader Vectorized -> Row: DataPageV2 293 305 14 53.6 18.6 0.9X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 20303 20621 449 0.8 1290.9 1.0X -SQL Json 14630 14734 147 1.1 930.1 1.4X -SQL Parquet Vectorized 212 246 23 74.0 13.5 95.6X -SQL Parquet MR 2073 2212 198 7.6 131.8 9.8X -SQL ORC Vectorized 445 455 9 35.4 28.3 45.6X -SQL ORC MR 1835 1902 95 8.6 116.7 11.1X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Parquet Reader Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 279 297 12 56.3 17.8 1.0X -ParquetReader Vectorized -> Row 280 292 12 56.1 17.8 1.0X +SQL CSV 25753 25874 171 0.6 1637.3 1.0X +SQL Json 19097 19391 416 0.8 1214.2 1.3X +SQL Parquet Vectorized: DataPageV1 273 288 11 57.6 17.4 94.3X +SQL Parquet Vectorized: DataPageV2 240 277 25 65.5 15.3 107.3X +SQL Parquet MR: DataPageV1 2969 3042 103 5.3 188.8 8.7X +SQL Parquet MR: DataPageV2 2692 2747 78 5.8 171.1 9.6X +SQL ORC Vectorized 601 626 20 26.2 38.2 42.8X +SQL ORC MR 2458 2467 13 6.4 156.3 10.5X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +Parquet Reader Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 354 363 7 44.4 22.5 1.0X +ParquetReader Vectorized: DataPageV2 345 359 
12 45.5 22.0 1.0X +ParquetReader Vectorized -> Row: DataPageV1 337 345 8 46.7 21.4 1.1X +ParquetReader Vectorized -> Row: DataPageV2 335 364 21 46.9 21.3 1.1X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 14027 14143 164 0.7 1337.7 1.0X -SQL Json 10476 10606 183 1.0 999.1 1.3X -SQL Parquet Vectorized 1969 2040 100 5.3 187.8 7.1X -SQL Parquet MR 3743 3834 128 2.8 357.0 3.7X -SQL ORC Vectorized 1926 1936 14 5.4 183.6 7.3X -SQL ORC MR 3383 3403 28 3.1 322.6 4.1X +SQL CSV 18074 18101 37 0.6 1723.7 1.0X +SQL Json 13211 13214 5 0.8 1259.9 1.4X +SQL Parquet Vectorized: DataPageV1 2249 2286 53 4.7 214.5 8.0X +SQL Parquet Vectorized: DataPageV2 2804 2818 20 3.7 267.4 6.4X +SQL Parquet MR: DataPageV1 4708 4779 100 2.2 449.0 3.8X +SQL Parquet MR: DataPageV2 4868 5046 251 2.2 464.3 3.7X +SQL ORC Vectorized 2145 2160 20 4.9 204.6 8.4X +SQL ORC MR 4180 4308 182 2.5 398.6 4.3X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 8672 8905 330 1.2 827.0 1.0X -SQL Json 6369 6374 7 1.6 607.4 1.4X -SQL Parquet Vectorized 556 579 25 18.9 53.0 15.6X -SQL Parquet MR 1574 1585 14 6.7 150.2 5.5X -SQL ORC Vectorized 420 427 4 25.0 40.1 20.6X -SQL ORC MR 1711 1733 31 6.1 163.2 5.1X +SQL CSV 11320 11376 78 0.9 1079.6 1.0X +SQL Json 7593 7664 101 1.4 724.1 1.5X +SQL Parquet Vectorized: DataPageV1 633 639 9 16.6 60.3 17.9X +SQL Parquet Vectorized: DataPageV2 621 644 20 16.9 59.2 18.2X +SQL Parquet MR: DataPageV1 2111 2157 65 5.0 201.3 5.4X +SQL Parquet MR: DataPageV2 2018 2064 65 5.2 192.4 5.6X +SQL ORC Vectorized 505 540 36 20.8 48.2 22.4X +SQL ORC MR 2302 2360 82 4.6 219.5 4.9X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz -Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Data column - CSV 21008 21367 508 0.7 1335.7 1.0X -Data column - Json 12091 12412 455 1.3 768.7 1.7X -Data column - Parquet Vectorized 210 217 6 75.0 13.3 100.1X -Data column - Parquet MR 2434 2450 22 6.5 154.8 8.6X 
-Data column - ORC Vectorized 323 347 26 48.7 20.5 65.1X -Data column - ORC MR 2223 2231 11 7.1 141.3 9.5X -Partition column - CSV 5889 5992 146 2.7 374.4 3.6X -Partition column - Json 9706 9870 233 1.6 617.1 2.2X -Partition column - Parquet Vectorized 51 58 8 306.3 3.3 409.2X -Partition column - Parquet MR 1237 1241 5 12.7 78.7 17.0X -Partition column - ORC Vectorized 53 61 8 294.1 3.4 392.9X -Partition column - ORC MR 1322 1336 20 11.9 84.1 15.9X -Both columns - CSV 20362 20389 39 0.8 1294.6 1.0X -Both columns - Json 12267 12512 346 1.3 779.9 1.7X -Both columns - Parquet Vectorized 254 262 9 61.9 16.2 82.6X -Both columns - Parquet MR 2649 2745 136 5.9 168.4 7.9X -Both columns - ORC Vectorized 348 379 32 45.2 22.1 60.4X -Both columns - ORC MR 2339 2343 6 6.7 148.7 9.0X +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------------- +Data column - CSV 24867 25261 556 0.6 1581.0 1.0X +Data column - Json 13937 13987 70 1.1 886.1 1.8X +Data column - Parquet Vectorized: DataPageV1 252 264 8 62.3 16.0 98.5X +Data column - Parquet Vectorized: DataPageV2 547 560 13 28.8 34.7 45.5X +Data column - Parquet MR: DataPageV1 3492 3509 25 4.5 222.0 7.1X +Data column - Parquet MR: DataPageV2 3148 3208 84 5.0 200.2 7.9X +Data column - ORC Vectorized 493 512 21 31.9 31.3 50.5X +Data column - ORC MR 2925 2943 26 5.4 185.9 8.5X +Partition column - CSV 7847 7851 5 2.0 498.9 3.2X +Partition column - Json 11759 11908 210 1.3 747.6 2.1X +Partition column - Parquet Vectorized: DataPageV1 60 67 7 262.3 3.8 414.7X +Partition column - Parquet Vectorized: DataPageV2 57 65 9 274.2 3.6 433.5X +Partition column - Parquet MR: DataPageV1 1762 1768 8 8.9 112.1 14.1X +Partition column - Parquet MR: DataPageV2 1742 1783 59 9.0 110.7 14.3X +Partition column - ORC Vectorized 59 71 7 265.6 3.8 419.9X +Partition column - ORC MR 1743 1764 29 9.0 110.8 14.3X +Both columns - CSV 25859 25924 92 0.6 1644.1 1.0X +Both columns - Json 14693 14764 101 1.1 934.2 1.7X +Both columns - Parquet Vectorized: DataPageV1 341 395 66 46.2 21.7 73.0X +Both columns - Parquet Vectorized: DataPageV2 624 643 13 25.2 39.7 39.9X +Both columns - Parquet MR: DataPageV1 3541 3611 99 4.4 225.2 7.0X +Both columns - Parquet MR: DataPageV2 3279 3301 32 4.8 208.4 7.6X +Both columns - ORC Vectorized 434 483 40 36.2 27.6 57.3X +Both columns - ORC MR 2946 2964 26 5.3 187.3 8.4X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 9872 9917 64 1.1 941.4 1.0X -SQL Json 8698 8793 134 1.2 829.5 1.1X -SQL Parquet Vectorized 1277 1281 6 8.2 121.8 7.7X -SQL Parquet MR 3649 3679 42 2.9 348.0 2.7X -ParquetReader Vectorized 969 1015 66 10.8 92.4 10.2X -SQL ORC Vectorized 1022 1038 23 10.3 97.4 9.7X -SQL ORC MR 
3103 3122 27 3.4 295.9 3.2X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL CSV 13698 13783 121 0.8 1306.3 1.0X +SQL Json 11030 11144 161 1.0 1051.9 1.2X +SQL Parquet Vectorized: DataPageV1 1695 1699 7 6.2 161.6 8.1X +SQL Parquet Vectorized: DataPageV2 2740 2744 5 3.8 261.3 5.0X +SQL Parquet MR: DataPageV1 4547 4594 66 2.3 433.7 3.0X +SQL Parquet MR: DataPageV2 5382 5455 103 1.9 513.3 2.5X +ParquetReader Vectorized: DataPageV1 1238 1238 0 8.5 118.0 11.1X +ParquetReader Vectorized: DataPageV2 2312 2325 19 4.5 220.5 5.9X +SQL ORC Vectorized 1134 1147 18 9.2 108.2 12.1X +SQL ORC MR 3966 4015 69 2.6 378.2 3.5X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 7321 7550 324 1.4 698.2 1.0X -SQL Json 6939 6962 32 1.5 661.8 1.1X -SQL Parquet Vectorized 906 917 17 11.6 86.4 8.1X -SQL Parquet MR 2617 2655 54 4.0 249.6 2.8X -ParquetReader Vectorized 832 837 5 12.6 79.4 8.8X -SQL ORC Vectorized 1101 1109 11 9.5 105.0 6.6X -SQL ORC MR 2777 2778 2 3.8 264.8 2.6X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL CSV 10613 10658 64 1.0 1012.1 1.0X +SQL Json 8973 8996 33 1.2 855.7 1.2X +SQL Parquet Vectorized: DataPageV1 1208 1221 18 8.7 115.2 8.8X +SQL Parquet Vectorized: DataPageV2 1949 1950 1 5.4 185.9 5.4X +SQL Parquet MR: DataPageV1 3701 3716 21 2.8 353.0 2.9X +SQL Parquet MR: DataPageV2 4150 4192 60 2.5 395.8 2.6X +ParquetReader Vectorized: DataPageV1 1191 1192 1 8.8 113.6 8.9X +ParquetReader Vectorized: DataPageV2 1874 1917 61 5.6 178.7 5.7X +SQL ORC Vectorized 1338 1365 38 7.8 127.6 7.9X +SQL ORC MR 3659 3674 21 2.9 349.0 2.9X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 5670 5691 30 1.8 540.7 1.0X -SQL Json 4309 4327 27 2.4 410.9 1.3X -SQL Parquet Vectorized 212 217 5 49.5 20.2 26.8X -SQL Parquet MR 1634 1672 53 6.4 155.9 3.5X -ParquetReader Vectorized 212 214 3 49.5 20.2 26.8X -SQL ORC Vectorized 356 359 4 29.5 33.9 15.9X -SQL ORC MR 1519 1561 59 6.9 144.9 3.7X +SQL CSV 8714 8809 134 1.2 831.0 1.0X +SQL Json 5801 5819 25 1.8 553.2 1.5X +SQL Parquet Vectorized: DataPageV1 297 316 11 35.3 28.3 29.3X +SQL Parquet Vectorized: DataPageV2 363 382 12 28.9 34.6 24.0X +SQL Parquet MR: DataPageV1 2350 2366 22 4.5 224.1 3.7X +SQL Parquet MR: DataPageV2 2132 2183 73 4.9 203.3 4.1X +ParquetReader Vectorized: DataPageV1 296 310 13 35.4 28.2 29.4X +ParquetReader Vectorized: DataPageV2 368 372 3 28.5 35.1 23.7X +SQL ORC Vectorized 474 487 10 22.1 45.2 18.4X +SQL ORC MR 2025 2031 9 5.2 193.1 4.3X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 
5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Single Column Scan from 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 2172 2213 58 0.5 2071.4 1.0X -SQL Json 2916 2934 26 0.4 2780.7 0.7X -SQL Parquet Vectorized 43 48 6 24.5 40.7 50.8X -SQL Parquet MR 175 182 9 6.0 167.1 12.4X -SQL ORC Vectorized 51 56 6 20.5 48.9 42.4X -SQL ORC MR 152 157 5 6.9 144.9 14.3X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL CSV 2677 2687 14 0.4 2553.2 1.0X +SQL Json 3581 3588 10 0.3 3414.8 0.7X +SQL Parquet Vectorized: DataPageV1 52 59 7 20.2 49.6 51.5X +SQL Parquet Vectorized: DataPageV2 68 75 7 15.4 65.0 39.3X +SQL Parquet MR: DataPageV1 245 257 9 4.3 233.6 10.9X +SQL Parquet MR: DataPageV2 224 237 8 4.7 213.7 11.9X +SQL ORC Vectorized 64 70 5 16.3 61.3 41.7X +SQL ORC MR 208 216 8 5.0 198.2 12.9X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Single Column Scan from 50 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 4658 4737 112 0.2 4442.6 1.0X -SQL Json 12114 12242 181 0.1 11552.8 0.4X -SQL Parquet Vectorized 59 66 9 17.8 56.3 78.9X -SQL Parquet MR 196 206 10 5.3 187.3 23.7X -SQL ORC Vectorized 68 77 6 15.3 65.2 68.1X -SQL ORC MR 171 183 9 6.1 163.4 27.2X - -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL CSV 5753 5771 25 0.2 5486.7 1.0X +SQL Json 13801 13851 71 0.1 13161.9 0.4X +SQL Parquet Vectorized: DataPageV1 75 83 9 14.1 71.1 77.2X +SQL Parquet Vectorized: DataPageV2 84 93 7 12.4 80.6 68.1X +SQL Parquet MR: DataPageV1 269 280 7 3.9 256.5 21.4X +SQL Parquet MR: DataPageV2 251 258 8 4.2 238.9 23.0X +SQL ORC Vectorized 82 88 6 12.8 78.3 70.1X +SQL ORC MR 223 239 8 4.7 213.0 25.8X + +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 8008 8070 88 0.1 7636.6 1.0X -SQL Json 22795 23224 607 0.0 21739.5 0.4X -SQL Parquet Vectorized 81 88 7 13.0 77.2 99.0X -SQL Parquet MR 225 244 16 4.7 214.9 35.5X -SQL ORC Vectorized 77 82 5 13.6 73.3 104.2X -SQL ORC MR 185 190 6 5.7 176.2 43.3X +SQL CSV 9487 9503 24 0.1 9047.1 1.0X +SQL Json 26109 26240 186 0.0 24899.2 0.4X +SQL Parquet Vectorized: DataPageV1 100 110 10 10.4 95.8 94.5X +SQL Parquet Vectorized: DataPageV2 113 119 6 9.3 107.3 84.3X +SQL Parquet MR: DataPageV1 280 296 11 3.7 267.2 33.9X +SQL Parquet MR: DataPageV2 281 321 68 3.7 268.0 33.8X +SQL ORC Vectorized 92 101 8 11.4 87.5 103.4X +SQL ORC MR 228 245 10 4.6 217.7 41.6X diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt index 1dd99011ba273..6a2b6bfb4a0a8 100644 --- a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt +++ b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt @@ -2,269 +2,322 @@ SQL Single Numeric Column Scan 
================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 13046 13274 322 1.2 829.5 1.0X -SQL Json 10585 10610 37 1.5 672.9 1.2X -SQL Parquet Vectorized 147 168 27 106.7 9.4 88.5X -SQL Parquet MR 1891 1897 7 8.3 120.3 6.9X -SQL ORC Vectorized 200 213 15 78.8 12.7 65.4X -SQL ORC MR 1939 1944 7 8.1 123.3 6.7X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz -Parquet Reader Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 164 165 3 96.2 10.4 1.0X -ParquetReader Vectorized -> Row 71 72 2 220.6 4.5 2.3X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 11570 12144 812 1.4 735.6 1.0X +SQL Json 7542 7568 37 2.1 479.5 1.5X +SQL Parquet Vectorized: DataPageV1 129 144 16 121.9 8.2 89.7X +SQL Parquet Vectorized: DataPageV2 92 106 20 170.3 5.9 125.2X +SQL Parquet MR: DataPageV1 1416 1419 3 11.1 90.0 8.2X +SQL Parquet MR: DataPageV2 1281 1359 110 12.3 81.4 9.0X +SQL ORC Vectorized 161 176 10 97.4 10.3 71.6X +SQL ORC MR 1525 1545 29 10.3 96.9 7.6X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet Reader Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 111 118 6 142.3 7.0 1.0X +ParquetReader Vectorized: DataPageV2 116 117 2 135.7 7.4 1.0X +ParquetReader Vectorized -> Row: DataPageV1 48 49 1 324.9 3.1 2.3X +ParquetReader Vectorized -> Row: DataPageV2 39 39 1 405.8 2.5 2.9X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 16466 16494 40 1.0 1046.9 1.0X -SQL Json 12509 12528 28 1.3 795.3 1.3X -SQL Parquet Vectorized 170 179 11 92.7 10.8 97.1X -SQL Parquet MR 2154 2167 19 7.3 136.9 7.6X -SQL ORC Vectorized 203 213 9 77.4 12.9 81.1X -SQL ORC MR 1977 1980 4 8.0 125.7 8.3X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz -Parquet Reader Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 216 218 3 72.8 13.7 1.0X -ParquetReader Vectorized -> Row 123 124 2 127.6 7.8 1.8X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 
8171M CPU @ 2.60GHz +SQL CSV 13807 14535 1030 1.1 877.8 1.0X +SQL Json 8079 8094 21 1.9 513.6 1.7X +SQL Parquet Vectorized: DataPageV1 139 152 12 113.0 8.9 99.2X +SQL Parquet Vectorized: DataPageV2 140 147 5 112.5 8.9 98.7X +SQL Parquet MR: DataPageV1 1637 1741 148 9.6 104.1 8.4X +SQL Parquet MR: DataPageV2 1522 1636 161 10.3 96.8 9.1X +SQL ORC Vectorized 147 160 10 106.9 9.4 93.8X +SQL ORC MR 1542 1545 4 10.2 98.1 9.0X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet Reader Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 166 171 8 94.7 10.6 1.0X +ParquetReader Vectorized: DataPageV2 166 169 4 94.7 10.6 1.0X +ParquetReader Vectorized -> Row: DataPageV1 156 157 2 100.7 9.9 1.1X +ParquetReader Vectorized -> Row: DataPageV2 156 157 2 100.7 9.9 1.1X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 17321 17358 53 0.9 1101.2 1.0X -SQL Json 12964 13001 52 1.2 824.2 1.3X -SQL Parquet Vectorized 243 251 7 64.8 15.4 71.3X -SQL Parquet MR 2491 2499 12 6.3 158.4 7.0X -SQL ORC Vectorized 214 217 3 73.4 13.6 80.9X -SQL ORC MR 1960 1963 3 8.0 124.6 8.8X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 15327 15421 133 1.0 974.5 1.0X +SQL Json 8564 8799 332 1.8 544.5 1.8X +SQL Parquet Vectorized: DataPageV1 202 219 11 77.8 12.8 75.8X +SQL Parquet Vectorized: DataPageV2 203 210 8 77.7 12.9 75.7X +SQL Parquet MR: DataPageV1 1874 2004 183 8.4 119.2 8.2X +SQL Parquet MR: DataPageV2 1606 1709 146 9.8 102.1 9.5X +SQL ORC Vectorized 167 179 10 94.1 10.6 91.7X +SQL ORC MR 1404 1408 6 11.2 89.3 10.9X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 361 365 6 43.6 22.9 1.0X -ParquetReader Vectorized -> Row 323 329 10 48.7 20.5 1.1X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 222 236 13 70.7 14.1 1.0X +ParquetReader Vectorized: DataPageV2 259 268 14 60.8 16.5 0.9X +ParquetReader Vectorized -> Row: DataPageV1 228 248 11 68.9 14.5 1.0X +ParquetReader Vectorized -> Row: DataPageV2 264 293 13 59.5 16.8 0.8X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 19098 19123 36 0.8 1214.2 1.0X -SQL Json 13719 13736 23 1.1 872.3 1.4X 
-SQL Parquet Vectorized 188 192 5 83.5 12.0 101.4X -SQL Parquet MR 2515 2536 30 6.3 159.9 7.6X -SQL ORC Vectorized 287 295 5 54.8 18.3 66.5X -SQL ORC MR 2034 2036 2 7.7 129.3 9.4X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz -Parquet Reader Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 309 311 3 50.9 19.7 1.0X -ParquetReader Vectorized -> Row 270 272 5 58.4 17.1 1.1X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 17479 17651 243 0.9 1111.3 1.0X +SQL Json 9565 9582 25 1.6 608.1 1.8X +SQL Parquet Vectorized: DataPageV1 152 159 8 103.2 9.7 114.7X +SQL Parquet Vectorized: DataPageV2 290 308 18 54.2 18.4 60.3X +SQL Parquet MR: DataPageV1 1861 1980 169 8.5 118.3 9.4X +SQL Parquet MR: DataPageV2 1647 1748 142 9.5 104.7 10.6X +SQL ORC Vectorized 230 251 12 68.3 14.6 75.9X +SQL ORC MR 1645 1648 3 9.6 104.6 10.6X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet Reader Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 208 213 9 75.7 13.2 1.0X +ParquetReader Vectorized: DataPageV2 355 382 14 44.3 22.6 0.6X +ParquetReader Vectorized -> Row: DataPageV1 212 233 8 74.1 13.5 1.0X +ParquetReader Vectorized -> Row: DataPageV2 350 353 7 45.0 22.2 0.6X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 25565 25574 13 0.6 1625.4 1.0X -SQL Json 17510 17518 11 0.9 1113.3 1.5X -SQL Parquet Vectorized 259 266 9 60.7 16.5 98.6X -SQL Parquet MR 2628 2647 28 6.0 167.1 9.7X -SQL ORC Vectorized 357 365 6 44.1 22.7 71.6X -SQL ORC MR 2144 2151 10 7.3 136.3 11.9X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz -Parquet Reader Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 385 390 8 40.8 24.5 1.0X -ParquetReader Vectorized -> Row 345 350 6 45.6 21.9 1.1X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 21825 21944 169 0.7 1387.6 1.0X +SQL Json 11877 11927 71 1.3 755.1 1.8X +SQL Parquet Vectorized: DataPageV1 229 242 18 68.8 14.5 95.5X +SQL Parquet Vectorized: DataPageV2 435 452 23 36.1 27.7 50.1X +SQL Parquet MR: DataPageV1 2050 2184 190 7.7 130.3 10.6X +SQL Parquet MR: DataPageV2 1829 1927 138 8.6 116.3 11.9X +SQL ORC Vectorized 287 308 14 54.8 18.3 76.0X +SQL ORC MR 1579 1603 34 10.0 100.4 13.8X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet Reader Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) 
Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 299 341 86 52.6 19.0 1.0X +ParquetReader Vectorized: DataPageV2 551 607 110 28.5 35.1 0.5X +ParquetReader Vectorized -> Row: DataPageV1 341 344 4 46.2 21.7 0.9X +ParquetReader Vectorized -> Row: DataPageV2 508 557 33 31.0 32.3 0.6X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 19931 19941 13 0.8 1267.2 1.0X -SQL Json 17274 17302 40 0.9 1098.2 1.2X -SQL Parquet Vectorized 175 182 10 90.0 11.1 114.1X -SQL Parquet MR 2496 2502 9 6.3 158.7 8.0X -SQL ORC Vectorized 432 436 4 36.4 27.5 46.1X -SQL ORC MR 2184 2187 5 7.2 138.8 9.1X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz -Parquet Reader Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 287 289 5 54.9 18.2 1.0X -ParquetReader Vectorized -> Row 281 283 3 55.9 17.9 1.0X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 17585 17926 482 0.9 1118.0 1.0X +SQL Json 11927 12180 357 1.3 758.3 1.5X +SQL Parquet Vectorized: DataPageV1 150 161 11 104.6 9.6 116.9X +SQL Parquet Vectorized: DataPageV2 150 160 8 104.7 9.5 117.1X +SQL Parquet MR: DataPageV1 1830 1867 52 8.6 116.4 9.6X +SQL Parquet MR: DataPageV2 1715 1828 160 9.2 109.1 10.3X +SQL ORC Vectorized 328 358 15 48.0 20.8 53.6X +SQL ORC MR 1584 1687 145 9.9 100.7 11.1X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet Reader Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 207 211 8 76.0 13.2 1.0X +ParquetReader Vectorized: DataPageV2 207 220 11 75.8 13.2 1.0X +ParquetReader Vectorized -> Row: DataPageV1 208 214 9 75.7 13.2 1.0X +ParquetReader Vectorized -> Row: DataPageV2 208 213 9 75.6 13.2 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 26664 26695 44 0.6 1695.3 1.0X -SQL Json 22655 22657 3 0.7 1440.4 1.2X -SQL Parquet Vectorized 249 254 8 63.2 15.8 107.1X -SQL Parquet MR 2689 2750 86 5.8 171.0 9.9X -SQL ORC Vectorized 517 523 7 30.4 32.9 51.6X -SQL ORC MR 2269 2270 1 6.9 144.3 11.8X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz -Parquet Reader Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 359 404 100 43.8 22.8 1.0X -ParquetReader Vectorized -> Row 325 329 5 48.4 20.7 1.1X +SQL CSV 22569 22614 63 0.7 1434.9 1.0X +SQL Json 15590 15600 15 1.0 991.2 1.4X +SQL Parquet Vectorized: DataPageV1 225 241 17 69.9 14.3 100.3X +SQL Parquet Vectorized: DataPageV2 219 236 13 72.0 13.9 103.3X +SQL Parquet MR: DataPageV1 2013 2109 136 7.8 128.0 11.2X +SQL Parquet MR: DataPageV2 1850 1967 165 8.5 117.6 12.2X +SQL ORC Vectorized 396 416 25 39.7 25.2 56.9X +SQL ORC MR 1707 1763 79 9.2 108.5 13.2X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet Reader Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized: DataPageV1 280 298 13 56.2 17.8 1.0X +ParquetReader Vectorized: DataPageV2 278 300 21 56.6 17.7 1.0X +ParquetReader Vectorized -> Row: DataPageV1 280 299 13 56.2 17.8 1.0X +ParquetReader Vectorized -> Row: DataPageV2 304 307 4 51.8 19.3 0.9X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 18336 18703 519 0.6 1748.7 1.0X -SQL Json 15924 16092 238 0.7 1518.6 1.2X -SQL Parquet Vectorized 2534 2540 9 4.1 241.6 7.2X -SQL Parquet MR 4768 4772 5 2.2 454.7 3.8X -SQL ORC Vectorized 2477 2513 51 4.2 236.3 7.4X -SQL ORC MR 4451 4470 27 2.4 424.5 4.1X +SQL CSV 15548 16002 641 0.7 1482.8 1.0X +SQL Json 10801 11108 434 1.0 1030.1 1.4X +SQL Parquet Vectorized: DataPageV1 1858 1966 152 5.6 177.2 8.4X +SQL Parquet Vectorized: DataPageV2 2342 2466 175 4.5 223.4 6.6X +SQL Parquet MR: DataPageV1 3873 3908 49 2.7 369.4 4.0X +SQL Parquet MR: DataPageV2 3764 3869 148 2.8 358.9 4.1X +SQL ORC Vectorized 2018 2020 3 5.2 192.5 7.7X +SQL ORC MR 3247 3450 287 3.2 309.7 4.8X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 9701 9753 74 1.1 925.1 1.0X -SQL Json 9562 9566 6 1.1 911.9 1.0X -SQL Parquet Vectorized 907 916 8 11.6 86.5 10.7X -SQL Parquet MR 2020 2021 2 5.2 192.6 4.8X -SQL ORC Vectorized 536 539 3 19.6 51.1 18.1X -SQL ORC MR 2211 2218 9 4.7 210.9 4.4X +SQL CSV 8028 8337 436 1.3 765.6 1.0X +SQL Json 6362 
6488 178 1.6 606.7 1.3X +SQL Parquet Vectorized: DataPageV1 642 673 51 16.3 61.3 12.5X +SQL Parquet Vectorized: DataPageV2 646 678 40 16.2 61.6 12.4X +SQL Parquet MR: DataPageV1 1504 1604 141 7.0 143.5 5.3X +SQL Parquet MR: DataPageV2 1645 1646 1 6.4 156.9 4.9X +SQL ORC Vectorized 386 415 25 27.2 36.8 20.8X +SQL ORC MR 1704 1730 37 6.2 162.5 4.7X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz -Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Data column - CSV 25664 25733 97 0.6 1631.7 1.0X -Data column - Json 17014 17023 13 0.9 1081.7 1.5X -Data column - Parquet Vectorized 261 268 8 60.2 16.6 98.2X -Data column - Parquet MR 3173 3182 14 5.0 201.7 8.1X -Data column - ORC Vectorized 363 365 1 43.3 23.1 70.7X -Data column - ORC MR 2672 2675 4 5.9 169.9 9.6X -Partition column - CSV 8197 8202 7 1.9 521.2 3.1X -Partition column - Json 12495 12501 9 1.3 794.4 2.1X -Partition column - Parquet Vectorized 67 69 2 236.1 4.2 385.3X -Partition column - Parquet MR 1465 1466 1 10.7 93.2 17.5X -Partition column - ORC Vectorized 68 71 4 232.7 4.3 379.7X -Partition column - ORC MR 1625 1625 0 9.7 103.3 15.8X -Both columns - CSV 26284 26309 36 0.6 1671.1 1.0X -Both columns - Json 19343 19369 37 0.8 1229.8 1.3X -Both columns - Parquet Vectorized 311 321 10 50.5 19.8 82.5X -Both columns - Parquet MR 3355 3356 2 4.7 213.3 7.6X -Both columns - ORC Vectorized 415 418 5 37.9 26.4 61.9X -Both columns - ORC MR 2739 2743 6 5.7 174.1 9.4X +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------------- +Data column - CSV 21472 21514 59 0.7 1365.2 1.0X +Data column - Json 11537 11606 97 1.4 733.5 1.9X +Data column - Parquet Vectorized: DataPageV1 238 256 11 66.1 15.1 90.2X +Data column - Parquet Vectorized: DataPageV2 482 507 17 32.6 30.6 44.6X +Data column - Parquet MR: DataPageV1 2213 2355 200 7.1 140.7 9.7X +Data column - Parquet MR: DataPageV2 2036 2163 179 7.7 129.4 10.5X +Data column - ORC Vectorized 289 310 20 54.4 18.4 74.3X +Data column - ORC MR 1898 1936 54 8.3 120.7 11.3X +Partition column - CSV 6307 6364 80 2.5 401.0 3.4X +Partition column - Json 9167 9253 121 1.7 582.8 2.3X +Partition column - Parquet Vectorized: DataPageV1 62 66 3 253.5 3.9 346.1X +Partition column - Parquet Vectorized: DataPageV2 61 65 2 259.2 3.9 353.8X +Partition column - Parquet MR: DataPageV1 1086 1088 3 14.5 69.0 19.8X +Partition column - Parquet MR: DataPageV2 1091 1146 78 14.4 69.4 19.7X +Partition column - ORC Vectorized 63 67 2 251.1 4.0 342.9X +Partition column - ORC MR 1173 1175 3 13.4 74.6 18.3X +Both columns - CSV 21458 22038 820 0.7 1364.3 1.0X +Both columns - Json 12697 12712 22 1.2 807.2 1.7X +Both columns - Parquet Vectorized: DataPageV1 275 288 10 57.2 17.5 78.0X +Both columns - Parquet Vectorized: DataPageV2 505 525 24 31.2 32.1 42.5X +Both columns - Parquet MR: DataPageV1 2541 2547 9 6.2 161.5 8.5X +Both columns - 
Parquet MR: DataPageV2 2059 2060 2 7.6 130.9 10.4X +Both columns - ORC Vectorized 326 349 16 48.3 20.7 66.0X +Both columns - ORC MR 2116 2151 50 7.4 134.5 10.1X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 12006 12014 11 0.9 1145.0 1.0X -SQL Json 19062 19074 16 0.6 1817.9 0.6X -SQL Parquet Vectorized 1608 1612 6 6.5 153.3 7.5X -SQL Parquet MR 3986 4005 27 2.6 380.1 3.0X -ParquetReader Vectorized 1199 1203 7 8.7 114.3 10.0X -SQL ORC Vectorized 1114 1114 0 9.4 106.2 10.8X -SQL ORC MR 3806 3806 1 2.8 362.9 3.2X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 10074 10372 422 1.0 960.7 1.0X +SQL Json 10037 10147 156 1.0 957.2 1.0X +SQL Parquet Vectorized: DataPageV1 1192 1226 47 8.8 113.7 8.4X +SQL Parquet Vectorized: DataPageV2 2349 2423 105 4.5 224.0 4.3X +SQL Parquet MR: DataPageV1 2995 3114 168 3.5 285.6 3.4X +SQL Parquet MR: DataPageV2 3847 3900 75 2.7 366.9 2.6X +ParquetReader Vectorized: DataPageV1 888 918 51 11.8 84.7 11.3X +ParquetReader Vectorized: DataPageV2 2128 2159 43 4.9 203.0 4.7X +SQL ORC Vectorized 837 908 61 12.5 79.8 12.0X +SQL ORC MR 2792 2882 127 3.8 266.3 3.6X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 8707 8791 118 1.2 830.4 1.0X -SQL Json 14505 14532 39 0.7 1383.3 0.6X -SQL Parquet Vectorized 1245 1265 27 8.4 118.8 7.0X -SQL Parquet MR 3019 3028 12 3.5 287.9 2.9X -ParquetReader Vectorized 1143 1156 20 9.2 109.0 7.6X -SQL ORC Vectorized 1543 1549 8 6.8 147.1 5.6X -SQL ORC MR 3672 3685 18 2.9 350.2 2.4X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 7808 7810 3 1.3 744.6 1.0X +SQL Json 7434 7491 82 1.4 708.9 1.1X +SQL Parquet Vectorized: DataPageV1 1037 1044 10 10.1 98.9 7.5X +SQL Parquet Vectorized: DataPageV2 1528 1529 3 6.9 145.7 5.1X +SQL Parquet MR: DataPageV1 2300 2411 156 4.6 219.4 3.4X +SQL Parquet MR: DataPageV2 2637 2639 4 4.0 251.5 3.0X +ParquetReader Vectorized: DataPageV1 843 907 56 12.4 80.4 9.3X +ParquetReader Vectorized: DataPageV2 1424 1446 30 7.4 135.8 5.5X +SQL ORC Vectorized 1131 1132 1 9.3 107.8 6.9X +SQL ORC MR 2781 2856 106 3.8 265.3 2.8X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 5845 5848 4 1.8 557.4 1.0X -SQL Json 8854 8858 5 1.2 844.4 0.7X -SQL Parquet Vectorized 272 278 8 38.6 
25.9 21.5X -SQL Parquet MR 1916 1936 27 5.5 182.7 3.1X -ParquetReader Vectorized 283 285 3 37.0 27.0 20.6X -SQL ORC Vectorized 548 551 3 19.1 52.3 10.7X -SQL ORC MR 1942 1944 2 5.4 185.2 3.0X +SQL CSV 5357 5538 255 2.0 510.9 1.0X +SQL Json 4354 4387 47 2.4 415.2 1.2X +SQL Parquet Vectorized: DataPageV1 212 226 15 49.5 20.2 25.3X +SQL Parquet Vectorized: DataPageV2 265 276 16 39.6 25.2 20.2X +SQL Parquet MR: DataPageV1 1575 1578 4 6.7 150.2 3.4X +SQL Parquet MR: DataPageV2 1624 1638 21 6.5 154.8 3.3X +ParquetReader Vectorized: DataPageV1 219 234 14 47.8 20.9 24.4X +ParquetReader Vectorized: DataPageV2 274 294 17 38.2 26.2 19.5X +SQL ORC Vectorized 370 393 12 28.4 35.3 14.5X +SQL ORC MR 1540 1545 7 6.8 146.9 3.5X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 3388 3395 10 0.3 3231.0 1.0X -SQL Json 4079 4087 11 0.3 3889.6 0.8X -SQL Parquet Vectorized 55 59 7 19.2 52.1 62.0X -SQL Parquet MR 226 229 2 4.6 215.2 15.0X -SQL ORC Vectorized 62 67 13 17.0 58.7 55.0X -SQL ORC MR 194 198 5 5.4 185.0 17.5X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 2159 2212 74 0.5 2059.3 1.0X +SQL Json 2836 2896 84 0.4 2704.5 0.8X +SQL Parquet Vectorized: DataPageV1 54 59 9 19.5 51.4 40.1X +SQL Parquet Vectorized: DataPageV2 66 72 8 15.9 63.1 32.7X +SQL Parquet MR: DataPageV1 173 186 10 6.1 164.5 12.5X +SQL Parquet MR: DataPageV2 159 172 8 6.6 151.8 13.6X +SQL ORC Vectorized 54 60 10 19.2 52.0 39.6X +SQL ORC MR 150 161 7 7.0 143.3 14.4X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 50 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 8141 8142 1 0.1 7764.3 1.0X -SQL Json 15614 15694 113 0.1 14890.4 0.5X -SQL Parquet Vectorized 70 78 12 14.9 67.0 115.8X -SQL Parquet MR 245 250 4 4.3 234.0 33.2X -SQL ORC Vectorized 77 83 9 13.5 73.8 105.2X -SQL ORC MR 212 215 2 4.9 202.1 38.4X - -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 5877 5883 8 0.2 5605.0 1.0X +SQL Json 11474 11587 159 0.1 10942.9 0.5X +SQL Parquet Vectorized: DataPageV1 66 72 7 15.9 63.1 88.9X +SQL Parquet Vectorized: DataPageV2 83 90 8 12.6 79.4 70.6X +SQL Parquet MR: DataPageV1 191 201 9 5.5 182.6 30.7X +SQL Parquet MR: DataPageV2 179 187 9 5.9 170.3 32.9X +SQL ORC Vectorized 70 76 12 14.9 67.1 83.5X +SQL ORC MR 167 175 7 6.3 159.2 35.2X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -SQL CSV 14087 14102 20 0.1 13434.7 1.0X -SQL Json 30069 30223 218 0.0 28676.2 0.5X -SQL Parquet Vectorized 107 113 8 9.8 101.9 131.9X -SQL Parquet MR 289 295 4 3.6 275.9 48.7X -SQL ORC Vectorized 99 105 14 10.6 94.4 142.3X -SQL ORC MR 236 239 3 4.4 225.5 59.6X +SQL CSV 9695 9965 382 0.1 9245.8 1.0X +SQL Json 22119 23566 2045 0.0 21094.6 0.4X +SQL Parquet Vectorized: DataPageV1 96 104 7 10.9 91.6 100.9X +SQL Parquet Vectorized: DataPageV2 113 121 8 9.3 107.8 85.8X +SQL Parquet MR: DataPageV1 227 243 9 4.6 216.2 42.8X +SQL Parquet MR: DataPageV2 210 225 12 5.0 200.2 46.2X +SQL ORC Vectorized 90 96 10 11.7 85.7 107.9X +SQL ORC MR 188 199 9 5.6 178.9 51.7X diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 7842ab36bb1b7..3002a3b4a876d 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -153,7 +153,7 @@ com.h2database h2 - 2.0.204 + 2.1.210 test diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index 1f243406c77e0..e91873a008860 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -24,6 +24,7 @@ public final class RecordBinaryComparator extends RecordComparator { + private static final boolean UNALIGNED = Platform.unaligned(); private static final boolean LITTLE_ENDIAN = ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN); @@ -41,7 +42,7 @@ public int compare( // we have guaranteed `leftLen` == `rightLen`. // check if stars align and we can get both offsets to be aligned - if ((leftOff % 8) == (rightOff % 8)) { + if (!UNALIGNED && ((leftOff % 8) == (rightOff % 8))) { while ((leftOff + i) % 8 != 0 && i < leftLen) { final int v1 = Platform.getByte(leftObj, leftOff + i); final int v2 = Platform.getByte(rightObj, rightOff + i); @@ -52,7 +53,7 @@ public int compare( } } // for architectures that support unaligned accesses, chew it up 8 bytes at a time - if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { + if (UNALIGNED || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { while (i <= leftLen - 8) { long v1 = Platform.getLong(leftObj, leftOff + i); long v2 = Platform.getLong(rightObj, rightOff + i); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 117e98f33a0ec..31e10af38a42b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -226,10 +226,10 @@ public void free() { } /** - * Gets the average bucket list iterations per lookup in the underlying `BytesToBytesMap`. + * Gets the average number of hash probes per key lookup in the underlying `BytesToBytesMap`. 
*/ - public double getAvgHashProbeBucketListIterations() { - return map.getAvgHashProbeBucketListIterations(); + public double getAvgHashProbesPerKey() { + return map.getAvgHashProbesPerKey(); } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index e1a0607d37c2c..07e35c158c8cb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -19,10 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.Closeable; -import java.io.File; import java.io.IOException; import java.lang.reflect.InvocationTargetException; -import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -121,25 +119,6 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont } } - /** - * Returns the list of files at 'path' recursively. This skips files that are ignored normally - * by MapReduce. - */ - public static List listDirectory(File path) { - List result = new ArrayList<>(); - if (path.isDirectory()) { - for (File f: path.listFiles()) { - result.addAll(listDirectory(f)); - } - } else { - char c = path.getName().charAt(0); - if (c != '.' && c != '_') { - result.add(path.getAbsolutePath()); - } - } - return result; - } - /** * Initializes the reader to read the file at `path` with `columns` projected. If columns is * null, all the columns are projected. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 0a7b929dafea3..57a307b1b7b6b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -39,6 +39,7 @@ import org.apache.spark.sql.types.Decimal; import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; /** @@ -292,6 +293,16 @@ private ValuesReader getValuesReader(Encoding encoding) { return new VectorizedDeltaByteArrayReader(); case DELTA_BINARY_PACKED: return new VectorizedDeltaBinaryPackedReader(); + case RLE: + PrimitiveType.PrimitiveTypeName typeName = + this.descriptor.getPrimitiveType().getPrimitiveTypeName(); + // RLE encoding only supports boolean type `Values`, and `bitwidth` is always 1. 
+ if (typeName == BOOLEAN) { + return new VectorizedRleValuesReader(1); + } else { + throw new UnsupportedOperationException( + "RLE encoding is not supported for values of type: " + typeName); + } default: throw new UnsupportedOperationException("Unsupported encoding: " + encoding); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedDeltaBinaryPackedReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedDeltaBinaryPackedReader.java index 62fb5f8c96bbf..7b2aac3118e5f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedDeltaBinaryPackedReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedDeltaBinaryPackedReader.java @@ -73,10 +73,10 @@ public class VectorizedDeltaBinaryPackedReader extends VectorizedReaderBase { private ByteBufferInputStream in; // temporary buffers used by readByte, readShort, readInteger, and readLong - byte byteVal; - short shortVal; - int intVal; - long longVal; + private byte byteVal; + private short shortVal; + private int intVal; + private long longVal; @Override public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 50056bf4073e9..0e976be2f652e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -305,7 +305,7 @@ public boolean nextBatch() throws IOException { if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned); + int num = (int) Math.min(capacity, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; columnReaders[i].readBatch(num, columnVectors[i]); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index cd97fb6c3cd55..bd7cbc7e17188 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -40,6 +40,7 @@ * This encoding is used in multiple places: * - Definition/Repetition levels * - Dictionary ids. 
+ * - Boolean type values of Parquet DataPageV2 */ public final class VectorizedRleValuesReader extends ValuesReader implements VectorizedValuesReader { @@ -369,7 +370,25 @@ public void readBinary(int total, WritableColumnVector c, int rowId) { @Override public void readBooleans(int total, WritableColumnVector c, int rowId) { - throw new UnsupportedOperationException("only readInts is valid."); + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + c.putBooleans(rowId, n, currentValue != 0); + break; + case PACKED: + for (int i = 0; i < n; ++i) { + // For Boolean types, `currentBuffer[currentBufferIdx++]` can only be 0 or 1 + c.putByte(rowId + i, (byte) currentBuffer[currentBufferIdx++]); + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } } @Override @@ -389,25 +408,12 @@ public Binary readBinary(int len) { @Override public void skipIntegers(int total) { - int left = total; - while (left > 0) { - if (this.currentCount == 0) this.readNextGroup(); - int n = Math.min(left, this.currentCount); - switch (mode) { - case RLE: - break; - case PACKED: - currentBufferIdx += n; - break; - } - currentCount -= n; - left -= n; - } + skipValues(total); } @Override public void skipBooleans(int total) { - throw new UnsupportedOperationException("only skipIntegers is valid"); + skipValues(total); } @Override @@ -533,4 +539,24 @@ private void readNextGroup() { throw new ParquetDecodingException("Failed to read from input stream", e); } } + + /** + * Skip `n` values from the current reader. + */ + private void skipValues(int n) { + int left = n; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int num = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + break; + case PACKED: + currentBufferIdx += num; + break; + } + currentCount -= num; + left -= num; + } + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 37c348cf4ed66..353a128254412 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -91,7 +91,8 @@ public static void populate(WritableColumnVector col, InternalRow row, int field } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t); col.getChild(0).putInts(0, capacity, c.months); - col.getChild(1).putLongs(0, capacity, c.microseconds); + col.getChild(1).putInts(0, capacity, c.days); + col.getChild(2).putLongs(0, capacity, c.microseconds); } else if (t instanceof DateType || t instanceof YearMonthIntervalType) { col.putInts(0, capacity, row.getInt(fieldIdx)); } else if (t instanceof TimestampType || t instanceof TimestampNTZType || diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java new file mode 100644 index 0000000000000..3a5dea479cab5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.math.BigDecimal; +import java.math.BigInteger; + +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * This class adds the constant support to ColumnVector. + * It supports all the types and contains `set` APIs, + * which will set the exact same value to all rows. + * + * Capacity: The vector stores only one copy of the data. + */ +public class ConstantColumnVector extends ColumnVector { + + // The data stored in this ConstantColumnVector, the vector stores only one copy of the data. + private byte nullData; + private byte byteData; + private short shortData; + private int intData; + private long longData; + private float floatData; + private double doubleData; + private UTF8String stringData; + private byte[] byteArrayData; + private ConstantColumnVector[] childData; + private ColumnarArray arrayData; + private ColumnarMap mapData; + + private final int numRows; + + /** + * @param numRows: The number of rows for this ConstantColumnVector + * @param type: The data type of this ConstantColumnVector + */ + public ConstantColumnVector(int numRows, DataType type) { + super(type); + this.numRows = numRows; + + if (type instanceof StructType) { + this.childData = new ConstantColumnVector[((StructType) type).fields().length]; + } else if (type instanceof CalendarIntervalType) { + // Three columns. Months as int. Days as Int. Microseconds as Long. + this.childData = new ConstantColumnVector[3]; + } else { + this.childData = null; + } + } + + @Override + public void close() { + stringData = null; + byteArrayData = null; + if (childData != null) { + for (int i = 0; i < childData.length; i++) { + if (childData[i] != null) { + childData[i].close(); + childData[i] = null; + } + } + childData = null; + } + arrayData = null; + mapData = null; + } + + @Override + public boolean hasNull() { + return nullData == 1; + } + + @Override + public int numNulls() { + return hasNull() ? numRows : 0; + } + + @Override + public boolean isNullAt(int rowId) { + return nullData == 1; + } + + /** + * Sets all rows as `null` + */ + public void setNull() { + nullData = (byte) 1; + } + + /** + * Sets all rows as not `null` + */ + public void setNotNull() { + nullData = (byte) 0; + } + + @Override + public boolean getBoolean(int rowId) { + return byteData == 1; + } + + /** + * Sets the boolean `value` for all rows + */ + public void setBoolean(boolean value) { + byteData = (byte) ((value) ? 
1 : 0); + } + + @Override + public byte getByte(int rowId) { + return byteData; + } + + /** + * Sets the byte `value` for all rows + */ + public void setByte(byte value) { + byteData = value; + } + + @Override + public short getShort(int rowId) { + return shortData; + } + + /** + * Sets the short `value` for all rows + */ + public void setShort(short value) { + shortData = value; + } + + @Override + public int getInt(int rowId) { + return intData; + } + + /** + * Sets the int `value` for all rows + */ + public void setInt(int value) { + intData = value; + } + + @Override + public long getLong(int rowId) { + return longData; + } + + /** + * Sets the long `value` for all rows + */ + public void setLong(long value) { + longData = value; + } + + @Override + public float getFloat(int rowId) { + return floatData; + } + + /** + * Sets the float `value` for all rows + */ + public void setFloat(float value) { + floatData = value; + } + + @Override + public double getDouble(int rowId) { + return doubleData; + } + + /** + * Sets the double `value` for all rows + */ + public void setDouble(double value) { + doubleData = value; + } + + @Override + public ColumnarArray getArray(int rowId) { + return arrayData; + } + + /** + * Sets the `ColumnarArray` `value` for all rows + */ + public void setArray(ColumnarArray value) { + arrayData = value; + } + + @Override + public ColumnarMap getMap(int ordinal) { + return mapData; + } + + /** + * Sets the `ColumnarMap` `value` for all rows + */ + public void setMap(ColumnarMap value) { + mapData = value; + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + // copy and modify from WritableColumnVector + if (precision <= Decimal.MAX_INT_DIGITS()) { + return Decimal.createUnsafe(getInt(rowId), precision, scale); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.createUnsafe(getLong(rowId), precision, scale); + } else { + byte[] bytes = getBinary(rowId); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } + } + + /** + * Sets the `Decimal` `value` with the precision for all rows + */ + public void setDecimal(Decimal value, int precision) { + // copy and modify from WritableColumnVector + if (precision <= Decimal.MAX_INT_DIGITS()) { + setInt((int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + setLong(value.toUnscaledLong()); + } else { + BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue(); + setByteArray(bigInteger.toByteArray()); + } + } + + @Override + public UTF8String getUTF8String(int rowId) { + return stringData; + } + + /** + * Sets the `UTF8String` `value` for all rows + */ + public void setUtf8String(UTF8String value) { + stringData = value; + } + + /** + * Sets the byte array `value` for all rows + */ + private void setByteArray(byte[] value) { + byteArrayData = value; + } + + @Override + public byte[] getBinary(int rowId) { + return byteArrayData; + } + + /** + * Sets the binary `value` for all rows + */ + public void setBinary(byte[] value) { + setByteArray(value); + } + + @Override + public ColumnVector getChild(int ordinal) { + return childData[ordinal]; + } + + /** + * Sets the child `ConstantColumnVector` `value` at the given ordinal for all rows + */ + public void setChild(int ordinal, ConstantColumnVector value) { + childData[ordinal] = value; + } +} diff --git 
a/sql/core/src/main/resources/META-INF/services/org.apache.spark.deploy.history.EventFilterBuilder b/sql/core/src/main/resources/META-INF/services/org.apache.spark.deploy.history.EventFilterBuilder index 5025616b752d1..03cddd94645d6 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.deploy.history.EventFilterBuilder +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.deploy.history.EventFilterBuilder @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.execution.history.SQLEventFilterBuilder \ No newline at end of file diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider index 6e42517a6d40c..b3f30a6650017 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.execution.datasources.jdbc.connection.BasicConnectionProvider org.apache.spark.sql.execution.datasources.jdbc.connection.DB2ConnectionProvider org.apache.spark.sql.execution.datasources.jdbc.connection.MariaDBConnectionProvider diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index fe4554a9c50b3..1365134641758 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.execution.datasources.v2.csv.CSVDataSourceV2 org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider org.apache.spark.sql.execution.datasources.v2.json.JsonDataSourceV2 diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin index 6771eef525307..2fca64c565d16 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin @@ -1,2 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.execution.ui.SQLHistoryServerPlugin org.apache.spark.sql.execution.ui.StreamingQueryHistoryServerPlugin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 396cb26259f68..e4d6dd2297f9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -327,7 +327,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { supportsExtract, catalogManager, dsOptions) val tableSpec = TableSpec( - bucketSpec = None, properties = Map.empty, provider = Some(source), options = Map.empty, @@ -596,7 +595,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { case (SaveMode.Overwrite, _) => val tableSpec = TableSpec( - bucketSpec = None, properties = Map.empty, provider = Some(source), options = Map.empty, @@ -617,7 +615,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { // created between our existence check and physical execution, but this can't be helped // in any case. 
val tableSpec = TableSpec( - bucketSpec = None, properties = Map.empty, provider = Some(source), options = Map.empty, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 22b2eb978d917..93127e6288a3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -108,7 +108,6 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) override def create(): Unit = { val tableSpec = TableSpec( - bucketSpec = None, properties = properties.toMap, provider = provider, options = Map.empty, @@ -198,7 +197,6 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) private def internalReplace(orCreate: Boolean): Unit = { val tableSpec = TableSpec( - bucketSpec = None, properties = properties.toMap, provider = provider, options = Map.empty, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 9dd38d850e329..62dea96614a5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1503,10 +1503,10 @@ class Dataset[T] private[sql]( case typedCol: TypedColumn[_, _] => // Checks if a `TypedColumn` has been inserted with // specific input type and schema by `withInputType`. - val needInputType = typedCol.expr.find { + val needInputType = typedCol.expr.exists { case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true case _ => false - }.isDefined + } if (!needInputType) { typedCol @@ -2478,6 +2478,35 @@ class Dataset[T] private[sql]( */ def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col)) + /** + * (Scala-specific) Returns a new Dataset by adding columns or replacing the existing columns + * that have the same names. + * + * `colsMap` is a map of column name and column. The columns must only refer to attributes + * supplied by this Dataset. It is an error to add columns that refer to some other Dataset. + * + * @group untypedrel + * @since 3.3.0 + */ + def withColumns(colsMap: Map[String, Column]): DataFrame = { + val (colNames, newCols) = colsMap.toSeq.unzip + withColumns(colNames, newCols) + } + + /** + * (Java-specific) Returns a new Dataset by adding columns or replacing the existing columns + * that have the same names. + * + * `colsMap` is a map of column name and column. The columns must only refer to attributes + * supplied by this Dataset. It is an error to add columns that refer to some other Dataset. + * + * @group untypedrel + * @since 3.3.0 + */ + def withColumns(colsMap: java.util.Map[String, Column]): DataFrame = withColumns( + colsMap.asScala.toMap + ) + /** * Returns a new Dataset by adding columns or replacing the existing columns that has * the same names. 
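To illustrate the `withColumns` overloads added in the Dataset.scala hunk above, here is a minimal usage sketch in Scala. Only the `withColumns(Map[String, Column])` signature comes from the patch itself; the session setup, column names, and sample data below are invented for the example.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, upper}

object WithColumnsSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical local session; any existing SparkSession works the same way.
    val spark = SparkSession.builder().master("local[*]").appName("withColumns-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("id", "label")

    // One call adds a new column and replaces an existing one ("label"),
    // instead of chaining two separate withColumn calls.
    val result = df.withColumns(Map(
      "id_times_ten" -> (col("id") * 10),
      "label"        -> upper(col("label"))
    ))

    result.show()
    spark.stop()
  }
}

Compared with chaining several `withColumn` calls, a single `withColumns` call applies all additions and replacements at once instead of building one intermediate Dataset per column.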
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 96bb1b3027f15..7e3c622196173 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -21,6 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ +import org.apache.spark.SparkRuntimeException import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast @@ -452,7 +453,13 @@ class RelationalGroupedDataset protected[sql]( case RelationalGroupedDataset.GroupByType => val valueExprs = values.map(_ match { case c: Column => c.expr - case v => Literal.apply(v) + case v => + try { + Literal.apply(v) + } catch { + case _: SparkRuntimeException => + throw QueryExecutionErrors.pivotColumnUnsupportedError(v, pivotColumn.expr.dataType) + } }) new RelationalGroupedDataset( df, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 490ab9f8956cb..ab43aa49944c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.PythonRDDServer import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Column, DataFrame, SQLContext} +import org.apache.spark.sql.{Column, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{CastTimestampNTZToLong, ExpressionInfo} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -59,8 +59,8 @@ private[sql] object PythonSQLUtils extends Logging { * Python callable function to read a file in Arrow stream format and create a [[RDD]] * using each serialized ArrowRecordBatch as a partition. */ - def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): JavaRDD[Array[Byte]] = { - ArrowConverters.readArrowStreamFromFile(sqlContext, filename) + def readArrowStreamFromFile(session: SparkSession, filename: String): JavaRDD[Array[Byte]] = { + ArrowConverters.readArrowStreamFromFile(session, filename) } /** @@ -70,8 +70,8 @@ private[sql] object PythonSQLUtils extends Logging { def toDataFrame( arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, - sqlContext: SQLContext): DataFrame = { - ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, sqlContext) + session: SparkSession): DataFrame = { + ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, session) } def explainString(queryExecution: QueryExecution, mode: String): String = { @@ -85,13 +85,13 @@ private[sql] object PythonSQLUtils extends Logging { * Helper for making a dataframe from arrow data from data sent from python over a socket. This is * used when encryption is enabled, and we don't want to write data to a file. 
*/ -private[sql] class ArrowRDDServer(sqlContext: SQLContext) extends PythonRDDServer { +private[sql] class ArrowRDDServer(session: SparkSession) extends PythonRDDServer { override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = { // Create array to consume iterator so that we can safely close the inputStream val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray // Parallelize the record batches to create an RDD - JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) + JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index befaea24e0002..7831ddee4f9b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -230,7 +230,7 @@ private[sql] object SQLUtils extends Logging { def readArrowStreamFromFile( sparkSession: SparkSession, filename: String): JavaRDD[Array[Byte]] = { - ArrowConverters.readArrowStreamFromFile(sparkSession.sqlContext, filename) + ArrowConverters.readArrowStreamFromFile(sparkSession, filename) } /** @@ -241,6 +241,6 @@ private[sql] object SQLUtils extends Logging { arrowBatchRDD: JavaRDD[Array[Byte]], schema: StructType, sparkSession: SparkSession): DataFrame = { - ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession.sqlContext) + ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index aaf2ead592c98..13237eb75c9a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils} +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -146,30 +146,28 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) // For CREATE TABLE [AS SELECT], we should use the v1 command if the catalog is resolved to the // session catalog and the table provider is not v2. 
- case c @ CreateTable(ResolvedDBObjectName(catalog, name), _, _, _, _) => + case c @ CreateTable(ResolvedDBObjectName(catalog, name), _, _, _, _) + if isSessionCatalog(catalog) => val (storageFormat, provider) = getStorageFormatAndProvider( c.tableSpec.provider, c.tableSpec.options, c.tableSpec.location, c.tableSpec.serde, ctas = false) - if (isSessionCatalog(catalog) && !isV2Provider(provider)) { + if (!isV2Provider(provider)) { constructV1TableCmd(None, c.tableSpec, name, c.tableSchema, c.partitioning, c.ignoreIfExists, storageFormat, provider) } else { - val newTableSpec = c.tableSpec.copy(bucketSpec = None) - c.copy(partitioning = c.partitioning ++ c.tableSpec.bucketSpec.map(_.asTransform), - tableSpec = newTableSpec) + c } - case c @ CreateTableAsSelect(ResolvedDBObjectName(catalog, name), _, _, _, _, _) => + case c @ CreateTableAsSelect(ResolvedDBObjectName(catalog, name), _, _, _, _, _) + if isSessionCatalog(catalog) => val (storageFormat, provider) = getStorageFormatAndProvider( c.tableSpec.provider, c.tableSpec.options, c.tableSpec.location, c.tableSpec.serde, ctas = true) - if (isSessionCatalog(catalog) && !isV2Provider(provider)) { + if (!isV2Provider(provider)) { constructV1TableCmd(Some(c.query), c.tableSpec, name, new StructType, c.partitioning, c.ignoreIfExists, storageFormat, provider) } else { - val newTableSpec = c.tableSpec.copy(bucketSpec = None) - c.copy(partitioning = c.partitioning ++ c.tableSpec.bucketSpec.map(_.asTransform), - tableSpec = newTableSpec) + c } case RefreshTable(ResolvedV1TableIdentifier(ident)) => @@ -180,26 +178,23 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) // For REPLACE TABLE [AS SELECT], we should fail if the catalog is resolved to the // session catalog and the table provider is not v2. 
- case c @ ReplaceTable( - ResolvedDBObjectName(catalog, _), _, _, _, _) => + case c @ ReplaceTable(ResolvedDBObjectName(catalog, _), _, _, _, _) + if isSessionCatalog(catalog) => val provider = c.tableSpec.provider.getOrElse(conf.defaultDataSourceName) - if (isSessionCatalog(catalog) && !isV2Provider(provider)) { + if (!isV2Provider(provider)) { throw QueryCompilationErrors.operationOnlySupportedWithV2TableError("REPLACE TABLE") } else { - val newTableSpec = c.tableSpec.copy(bucketSpec = None) - c.copy(partitioning = c.partitioning ++ c.tableSpec.bucketSpec.map(_.asTransform), - tableSpec = newTableSpec) + c } - case c @ ReplaceTableAsSelect(ResolvedDBObjectName(catalog, _), _, _, _, _, _) => + case c @ ReplaceTableAsSelect(ResolvedDBObjectName(catalog, _), _, _, _, _, _) + if isSessionCatalog(catalog) => val provider = c.tableSpec.provider.getOrElse(conf.defaultDataSourceName) - if (isSessionCatalog(catalog) && !isV2Provider(provider)) { + if (!isV2Provider(provider)) { throw QueryCompilationErrors .operationOnlySupportedWithV2TableError("REPLACE TABLE AS SELECT") } else { - val newTableSpec = c.tableSpec.copy(bucketSpec = None) - c.copy(partitioning = c.partitioning ++ c.tableSpec.bucketSpec.map(_.asTransform), - tableSpec = newTableSpec) + c } case DropTable(ResolvedV1TableIdentifier(ident), ifExists, purge) => @@ -221,7 +216,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) val newProperties = c.properties -- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES CreateDatabaseCommand(name, c.ifNotExists, location, comment, newProperties) - case d @ DropNamespace(DatabaseInSessionCatalog(db), _, _) => + case d @ DropNamespace(DatabaseInSessionCatalog(db), _, _) if conf.useV1Command => DropDatabaseCommand(db, d.ifExists, d.cascade) case ShowTables(DatabaseInSessionCatalog(db), pattern, output) if conf.useV1Command => @@ -266,12 +261,19 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) isOverwrite, partition) - case ShowCreateTable(ResolvedV1TableOrViewIdentifier(ident), asSerde, output) => - if (asSerde) { - ShowCreateTableAsSerdeCommand(ident.asTableIdentifier, output) - } else { + case ShowCreateTable(ResolvedV1TableOrViewIdentifier(ident), asSerde, output) if asSerde => + ShowCreateTableAsSerdeCommand(ident.asTableIdentifier, output) + + // If target is view, force use v1 command + case ShowCreateTable(ResolvedViewIdentifier(ident), _, output) => + ShowCreateTableCommand(ident.asTableIdentifier, output) + + case ShowCreateTable(ResolvedV1TableIdentifier(ident), _, output) + if conf.useV1Command => ShowCreateTableCommand(ident.asTableIdentifier, output) + + case ShowCreateTable(ResolvedTable(catalog, ident, table: V1Table, _), _, output) + if isSessionCatalog(catalog) && DDLUtils.isHiveTable(table.catalogTable) => ShowCreateTableCommand(ident.asTableIdentifier, output) - } case TruncateTable(ResolvedV1TableIdentifier(ident)) => TruncateTableCommand(ident.asTableIdentifier, None) @@ -446,7 +448,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) storageFormat: CatalogStorageFormat, provider: String): CreateTableV1 = { val tableDesc = buildCatalogTable(name.asTableIdentifier, tableSchema, - partitioning, tableSpec.bucketSpec, tableSpec.properties, provider, + partitioning, tableSpec.properties, provider, tableSpec.location, tableSpec.comment, storageFormat, tableSpec.external) val mode = if (ignoreIfExists) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTableV1(tableDesc, mode, query) @@ -518,7 +520,6 @@ class ResolveSessionCatalog(val 
catalogManager: CatalogManager) table: TableIdentifier, schema: StructType, partitioning: Seq[Transform], - bucketSpec: Option[BucketSpec], properties: Map[String, String], provider: String, location: Option[String], @@ -530,6 +531,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } else { CatalogTableType.MANAGED } + val (partitionColumns, maybeBucketSpec) = partitioning.toSeq.convertTransforms CatalogTable( identifier = table, @@ -537,8 +539,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) storage = storageFormat, schema = schema, provider = Some(provider), - partitionColumnNames = partitioning.asPartitionColumns, - bucketSpec = bucketSpec, + partitionColumnNames = partitionColumns, + bucketSpec = maybeBucketSpec, properties = properties, comment = comment) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala new file mode 100644 index 0000000000000..1e361695056a7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Divide, EqualTo, Expression, IsNotNull, IsNull, Literal, Multiply, Not, Or, Remainder, Subtract, UnaryMinus} +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} + +/** + * The builder to generate V2 expressions from catalyst expressions. 
+ */ +class V2ExpressionBuilder(e: Expression) { + + def build(): Option[V2Expression] = generateExpression(e) + + private def canTranslate(b: BinaryOperator) = b match { + case _: And | _: Or => true + case _: BinaryComparison => true + case _: BitwiseAnd | _: BitwiseOr | _: BitwiseXor => true + case add: Add => add.failOnError + case sub: Subtract => sub.failOnError + case mul: Multiply => mul.failOnError + case div: Divide => div.failOnError + case r: Remainder => r.failOnError + case _ => false + } + + private def generateExpression(expr: Expression): Option[V2Expression] = expr match { + case Literal(value, dataType) => Some(LiteralValue(value, dataType)) + case attr: Attribute => Some(FieldReference.column(attr.name)) + case IsNull(col) => generateExpression(col) + .map(c => new GeneralScalarExpression("IS_NULL", Array[V2Expression](c))) + case IsNotNull(col) => generateExpression(col) + .map(c => new GeneralScalarExpression("IS_NOT_NULL", Array[V2Expression](c))) + case b: BinaryOperator if canTranslate(b) => + val left = generateExpression(b.left) + val right = generateExpression(b.right) + if (left.isDefined && right.isDefined) { + Some(new GeneralScalarExpression(b.sqlOperator, Array[V2Expression](left.get, right.get))) + } else { + None + } + case Not(eq: EqualTo) => + val left = generateExpression(eq.left) + val right = generateExpression(eq.right) + if (left.isDefined && right.isDefined) { + Some(new GeneralScalarExpression("!=", Array[V2Expression](left.get, right.get))) + } else { + None + } + case Not(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("NOT", Array[V2Expression](v))) + case UnaryMinus(child, true) => generateExpression(child) + .map(v => new GeneralScalarExpression("-", Array[V2Expression](v))) + case BitwiseNot(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("~", Array[V2Expression](v))) + case CaseWhen(branches, elseValue) => + val conditions = branches.map(_._1).flatMap(generateExpression) + val values = branches.map(_._2).flatMap(generateExpression) + if (conditions.length == branches.length && values.length == branches.length) { + val branchExpressions = conditions.zip(values).flatMap { case (c, v) => + Seq[V2Expression](c, v) + } + if (elseValue.isDefined) { + elseValue.flatMap(generateExpression).map { v => + val children = (branchExpressions :+ v).toArray[V2Expression] + // The children looks like [condition1, value1, ..., conditionN, valueN, elseValue] + new GeneralScalarExpression("CASE_WHEN", children) + } + } else { + // The children looks like [condition1, value1, ..., conditionN, valueN] + Some(new GeneralScalarExpression("CASE_WHEN", branchExpressions.toArray[V2Expression])) + } + } else { + None + } + // TODO supports other expressions + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 426b2337f76a3..27d6bedad47c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -161,7 +161,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { blocking: Boolean = false): Unit = { val shouldRemove: LogicalPlan => Boolean = if (cascade) { - _.find(_.sameResult(plan)).isDefined + _.exists(_.sameResult(plan)) } else { _.sameResult(plan) } @@ -187,7 +187,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { // will keep 
it as it is. It means the physical plan has been re-compiled already in the // other thread. val cacheAlreadyLoaded = cd.cachedRepresentation.cacheBuilder.isCachedColumnBuffersLoaded - cd.plan.find(_.sameResult(plan)).isDefined && !cacheAlreadyLoaded + cd.plan.exists(_.sameResult(plan)) && !cacheAlreadyLoaded }) } } @@ -207,7 +207,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { * Tries to re-cache all the cache entries that refer to the given plan. */ def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = { - recacheByCondition(spark, _.plan.find(_.sameResult(plan)).isDefined) + recacheByCondition(spark, _.plan.exists(_.sameResult(plan))) } /** @@ -288,7 +288,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { */ def recacheByPath(spark: SparkSession, resourcePath: Path, fs: FileSystem): Unit = { val qualifiedPath = fs.makeQualified(resourcePath) - recacheByCondition(spark, _.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined) + recacheByCondition(spark, _.plan.exists(lookupAndRefresh(_, fs, qualifiedPath))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala index 70a508e6b7ec9..1971b8b1baf09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala @@ -264,7 +264,7 @@ private object RowToColumnConverter { case ShortType => ShortConverter case IntegerType | DateType | _: YearMonthIntervalType => IntConverter case FloatType => FloatConverter - case LongType | TimestampType | _: DayTimeIntervalType => LongConverter + case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => LongConverter case DoubleType => DoubleConverter case StringType => StringConverter case CalendarIntervalType => CalendarConverter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 4bd6c239a3367..1e2fa41ef0f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.execution.vectorized.ConstantColumnVector import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType @@ -200,7 +200,7 @@ case class FileSourceScanExec( extends DataSourceScanExec { lazy val metadataColumns: Seq[AttributeReference] = - output.collect { case MetadataAttribute(attr) => attr } + output.collect { case FileSourceMetadataAttribute(attr) => attr } // Note that some vals referring the file-based relation are lazy intentionally // so that this plan can be canonicalized on executor side too. See SPARK-23731. 
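The import change above in DataSourceScanExec.scala (OnHeapColumnVector replaced by ConstantColumnVector) relies on the ConstantColumnVector class added earlier in this patch. Below is a minimal sketch, based only on the `set`/`get` API shown in that new file and using an invented file name, of how a value that is constant for a whole batch (such as a metadata or partition column) can be held without materializing one copy per row.

import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object ConstantColumnVectorSketch {
  def main(args: Array[String]): Unit = {
    val numRows = 4096
    // Backs a column whose value is identical for every row in the batch:
    // the vector stores a single copy instead of numRows copies.
    val fileName = new ConstantColumnVector(numRows, StringType)
    fileName.setUtf8String(UTF8String.fromString("part-00000.parquet")) // invented value

    // Every rowId reads the same data, and no rows are null by default.
    assert(fileName.getUTF8String(0) == fileName.getUTF8String(numRows - 1))
    assert(fileName.numNulls() == 0)
    fileName.close()
  }
}

Because the vector holds one shared value, `getUTF8String` returns the same data for every rowId, which is what a per-file metadata column needs.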
@@ -221,8 +221,8 @@ case class FileSourceScanExec( requiredSchema = requiredSchema, partitionSchema = relation.partitionSchema, relation.sparkSession.sessionState.conf).map { vectorTypes => - // for column-based file format, append metadata struct column's vector type classes if any - vectorTypes ++ Seq.fill(metadataColumns.size)(classOf[OnHeapColumnVector].getName) + // for column-based file format, append metadata column's vector type classes if any + vectorTypes ++ Seq.fill(metadataColumns.size)(classOf[ConstantColumnVector].getName) } private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty @@ -239,7 +239,7 @@ case class FileSourceScanExec( } private def isDynamicPruningFilter(e: Expression): Boolean = - e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + e.exists(_.isInstanceOf[PlanExpression[_]]) @transient lazy val selectedPartitions: Array[PartitionDirectory] = { val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) @@ -366,9 +366,11 @@ case class FileSourceScanExec( @transient private lazy val pushedDownFilters = { val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) - // TODO: should be able to push filters containing metadata columns down to skip files + // `dataFilters` should not include any metadata col filters + // because the metadata struct has been flatted in FileSourceStrategy + // and thus metadata col filters are invalid to be pushed down dataFilters.filterNot(_.references.exists { - case MetadataAttribute(_) => true + case FileSourceMetadataAttribute(_) => true case _ => false }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala index 1eea0cd777ed9..12ffbc8554e63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala @@ -247,6 +247,8 @@ object ExplainUtils extends AdaptiveSparkPlanHelper { plan.foreach { case a: AdaptiveSparkPlanExec => getSubqueries(a.executedPlan, subqueries) + case q: QueryStageExec => + getSubqueries(q.plan, subqueries) case p: SparkPlan => p.expressions.foreach (_.collect { case e: PlanExpression[_] => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 6c7929437ffdd..f6dbf5fda1816 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -142,8 +142,8 @@ case class GenerateExec( case (attr, _) => requiredAttrSet.contains(attr) }.map(_._2) boundGenerator match { - case e: CollectionGenerator => codeGenCollection(ctx, e, requiredInput, row) - case g => codeGenTraversableOnce(ctx, g, requiredInput, row) + case e: CollectionGenerator => codeGenCollection(ctx, e, requiredInput) + case g => codeGenTraversableOnce(ctx, g, requiredInput) } } @@ -153,8 +153,7 @@ case class GenerateExec( private def codeGenCollection( ctx: CodegenContext, e: CollectionGenerator, - input: Seq[ExprCode], - row: ExprCode): String = { + input: Seq[ExprCode]): String = { // Generate code for the generator. 
val data = e.genCode(ctx) @@ -241,8 +240,7 @@ case class GenerateExec( private def codeGenTraversableOnce( ctx: CodegenContext, e: Expression, - requiredInput: Seq[ExprCode], - row: ExprCode): String = { + requiredInput: Seq[ExprCode]): String = { // Generate the code for the generator val data = e.genCode(ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 26c6904a896a5..1b089943a680e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -304,8 +304,6 @@ class QueryExecution( } private def stringWithStats(maxFields: Int, append: String => Unit): Unit = { - val maxFields = SQLConf.get.maxToStringFields - // trigger to compute stats for logical plans try { // This will trigger to compute stats for all the nodes in the plan, including subqueries, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index a1b093f88f862..4b561b813067e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -43,7 +43,7 @@ object SortPrefixUtils { case StringType => stringPrefixComparator(sortOrder) case BinaryType => binaryPrefixComparator(sortOrder) case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType | - _: AnsiIntervalType => + TimestampNTZType | _: AnsiIntervalType => longPrefixComparator(sortOrder) case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => longPrefixComparator(sortOrder) @@ -123,7 +123,7 @@ object SortPrefixUtils { def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = { sortOrder.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | - TimestampType | FloatType | DoubleType | _: AnsiIntervalType => + TimestampType | TimestampNTZType | FloatType | DoubleType | _: AnsiIntervalType => true case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index dc3ceb5c595d0..7e8fb4a157262 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -47,6 +47,8 @@ class SparkOptimizer( PushDownPredicates) :+ Batch("Cleanup filters that cannot be pushed down", Once, CleanupDynamicPruningFilters, + // cleanup the unnecessary TrueLiteral predicates + BooleanSimplification, PruneFilters)) ++ postHocOptimizationBatches :+ Batch("Extract Python UDFs", Once, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index f56beeb79db72..bb1c5c3873cd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch +import 
org.apache.spark.util.NextIterator object SparkPlan { /** The original [[LogicalPlan]] from which this [[SparkPlan]] is converted. */ @@ -384,10 +385,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val bis = new ByteArrayInputStream(bytes) val ins = new DataInputStream(codec.compressedInputStream(bis)) - new Iterator[InternalRow] { + new NextIterator[InternalRow] { private var sizeOfNextRow = ins.readInt() - override def hasNext: Boolean = sizeOfNextRow >= 0 - override def next(): InternalRow = { + private def _next(): InternalRow = { val bs = new Array[Byte](sizeOfNextRow) ins.readFully(bs) val row = new UnsafeRow(nFields) @@ -395,6 +395,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ sizeOfNextRow = ins.readInt() row } + + override def getNext(): InternalRow = { + if (sizeOfNextRow >= 0) { + try { + _next() + } catch { + case t: Throwable if ins != null => + ins.close() + throw t + } + } else { + finished = true + null + } + } + override def close(): Unit = ins.close() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index fed02dddecf78..a4e72e04507b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -318,7 +318,7 @@ class SparkSqlAstBuilder extends AstBuilder { val (_, _, _, _, options, location, _, _) = visitCreateTableClauses(ctx.createTableClauses()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText).getOrElse( throw QueryParsingErrors.createTempTableNotSpecifyProviderError(ctx)) - val schema = Option(ctx.colTypeList()).map(createSchema) + val schema = Option(ctx.createOrReplaceTableColTypeList()).map(createSchema) logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + "CREATE TEMPORARY VIEW ... USING ... instead") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9c2195d42786c..675b158100394 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.util.Locale import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Strategy} +import org.apache.spark.sql.{execution, AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ @@ -373,7 +373,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) { - throw QueryCompilationErrors.groupAggPandasUDFUnsupportedByStreamingAggError() + throw new AnalysisException( + "Streaming aggregation doesn't support group aggregate pandas UDF") } val sessionWindowOption = namedGroupingExpressions.find { p => @@ -440,7 +441,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { /** Ensures that this plan does not have a streaming aggregate in it. 
*/ def hasNoStreamingAgg: Boolean = { - plan.collectFirst { case a: Aggregate if a.isStreaming => a }.isEmpty + !plan.exists { + case a: Aggregate => a.isStreaming + case _ => false + } } // The following cases of limits on a streaming plan has to be executed with a stateful diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index dde976c951718..7d36fd5d412a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -892,7 +892,7 @@ case class CollapseCodegenStages( private def supportCodegen(plan: SparkPlan): Boolean = plan match { case plan: CodegenSupport if plan.supportCodegen => - val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) + val willFallback = plan.expressions.exists(_.exists(e => !supportCodegen(e))) // the generated code will be huge if there are too many columns val hasTooManyOutputFields = WholeStageCodegenExec.isTooManyFields(conf, plan.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala index ea1ab8e5755a2..5533bb1cd7916 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability -import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, EliminateLimits} +import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, EliminateLimits, OptimizeOneRowPlan} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf @@ -40,7 +40,8 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] { ConvertToLocalRelation, UpdateAttributeNullability), Batch("Dynamic Join Selection", Once, DynamicJoinSelection), - Batch("Eliminate Limits", Once, EliminateLimits) + Batch("Eliminate Limits", fixedPoint, EliminateLimits), + Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan) ) final override protected def batches: Seq[Batch] = { @@ -70,7 +71,7 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] { previousPlan: LogicalPlan, currentPlan: LogicalPlan): Boolean = { !Utils.isTesting || (currentPlan.resolved && - currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && + !currentPlan.exists(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty) && LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) && DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala index cbd4ee698df28..51833012a128e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala @@ -37,7 +37,7 @@ object AQEUtils { } else { None } - Some(ClusteredDistribution(h.expressions, 
numPartitions)) + Some(ClusteredDistribution(h.expressions, requiredNumPartitions = numPartitions)) case f: FilterExec => getRequiredDistribution(f.child) case s: SortExec if !s.global => getRequiredDistribution(s.child) case c: CollectMetricsExec => getRequiredDistribution(c.child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 2b42804e784ed..c6505a0ea5f73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.exchange._ import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SparkFatalException, ThreadUtils} /** * A root node to execute the query plan adaptively. It splits the query plan into independent @@ -332,7 +332,7 @@ case class AdaptiveSparkPlanExec( // Subqueries that don't belong to any query stage of the main query will execute after the // last UI update in `getFinalPhysicalPlan`, so we need to update UI here again to make sure // the newly generated nodes of those subqueries are updated. - if (!isSubquery && currentPhysicalPlan.find(_.subqueries.nonEmpty).isDefined) { + if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) { getExecutionId.foreach(onUpdatePlan(_, Seq.empty)) } logOnLevel(s"Final plan: $currentPhysicalPlan") @@ -410,7 +410,6 @@ case class AdaptiveSparkPlanExec( if (isFinalPlan) "Final Plan" else "Current Plan", currentPhysicalPlan, depth, - lastChildren, append, verbose, maxFields, @@ -419,7 +418,6 @@ case class AdaptiveSparkPlanExec( "Initial Plan", initialPlan, depth, - lastChildren, append, verbose, maxFields, @@ -432,7 +430,6 @@ case class AdaptiveSparkPlanExec( header: String, plan: SparkPlan, depth: Int, - lastChildren: Seq[Boolean], append: String => Unit, verbose: Boolean, maxFields: Int, @@ -614,7 +611,7 @@ case class AdaptiveSparkPlanExec( stagesToReplace: Seq[QueryStageExec]): LogicalPlan = { var logicalPlan = plan stagesToReplace.foreach { - case stage if currentPhysicalPlan.find(_.eq(stage)).isDefined => + case stage if currentPhysicalPlan.exists(_.eq(stage)) => val logicalNodeOpt = stage.getTagValue(TEMP_LOGICAL_PLAN_TAG).orElse(stage.logicalLink) assert(logicalNodeOpt.isDefined) val logicalNode = logicalNodeOpt.get @@ -702,7 +699,7 @@ case class AdaptiveSparkPlanExec( p.flatMap(_.metrics.values.map(m => SQLPlanMetric(m.name.get, m.id, m.metricType))) } context.session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveSQLMetricUpdates( - executionId.toLong, newMetrics)) + executionId, newMetrics)) } else { val planDescriptionMode = ExplainMode.fromString(conf.uiExplainMode) context.session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate( @@ -731,11 +728,16 @@ case class AdaptiveSparkPlanExec( } case _ => } - val e = if (errors.size == 1) { - errors.head + // Respect SparkFatalException which can be thrown by BroadcastExchangeExec + val originalErrors = errors.map { + case fatal: SparkFatalException => fatal.throwable + case other => other + } + val e = if (originalErrors.size == 1) { + 
originalErrors.head } else { - val se = QueryExecutionErrors.multiFailuresInStageMaterializationError(errors.head) - errors.tail.foreach(se.addSuppressed) + val se = QueryExecutionErrors.multiFailuresInStageMaterializationError(originalErrors.head) + originalErrors.tail.foreach(se.addSuppressed) se } throw e diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DynamicJoinSelection.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DynamicJoinSelection.scala index 6106dff99b2ac..217569ae645c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DynamicJoinSelection.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DynamicJoinSelection.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.MapOutputStatistics +import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, JoinStrategyHint, LogicalPlan, NO_BROADCAST_HASH, PREFER_SHUFFLE_HASH, SHUFFLE_HASH} +import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Join, JoinStrategyHint, LogicalPlan, NO_BROADCAST_HASH, PREFER_SHUFFLE_HASH, SHUFFLE_HASH} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -33,9 +35,9 @@ import org.apache.spark.sql.internal.SQLConf * 3. if a join satisfies both NO_BROADCAST_HASH and PREFER_SHUFFLE_HASH, * then add a SHUFFLE_HASH hint. */ -object DynamicJoinSelection extends Rule[LogicalPlan] { +object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper { - private def shouldDemoteBroadcastHashJoin(mapStats: MapOutputStatistics): Boolean = { + private def hasManyEmptyPartitions(mapStats: MapOutputStatistics): Boolean = { val partitionCnt = mapStats.bytesByPartitionId.length val nonZeroCnt = mapStats.bytesByPartitionId.count(_ > 0) partitionCnt > 0 && nonZeroCnt > 0 && @@ -50,35 +52,69 @@ object DynamicJoinSelection extends Rule[LogicalPlan] { mapStats.bytesByPartitionId.forall(_ <= maxShuffledHashJoinLocalMapThreshold) } - private def selectJoinStrategy(plan: LogicalPlan): Option[JoinStrategyHint] = plan match { - case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.isMaterialized - && stage.mapStats.isDefined => - val demoteBroadcastHash = shouldDemoteBroadcastHashJoin(stage.mapStats.get) - val preferShuffleHash = preferShuffledHashJoin(stage.mapStats.get) - if (demoteBroadcastHash && preferShuffleHash) { - Some(SHUFFLE_HASH) - } else if (demoteBroadcastHash) { - Some(NO_BROADCAST_HASH) - } else if (preferShuffleHash) { - Some(PREFER_SHUFFLE_HASH) - } else { - None - } + private def selectJoinStrategy( + join: Join, + isLeft: Boolean): Option[JoinStrategyHint] = { + val plan = if (isLeft) join.left else join.right + plan match { + case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.isMaterialized + && stage.mapStats.isDefined => + + val manyEmptyInPlan = hasManyEmptyPartitions(stage.mapStats.get) + val canBroadcastPlan = (isLeft && canBuildBroadcastLeft(join.joinType)) || + (!isLeft && canBuildBroadcastRight(join.joinType)) + val manyEmptyInOther = (if (isLeft) join.right else join.left) match { + case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.isMaterialized + && stage.mapStats.isDefined => hasManyEmptyPartitions(stage.mapStats.get) + case _ => false + } + + 
val demoteBroadcastHash = if (manyEmptyInPlan && canBroadcastPlan) { + join.joinType match { + // don't demote BHJ since you cannot short circuit local join if inner (null-filled) + // side is empty + case LeftOuter | RightOuter | LeftAnti => false + case _ => true + } + } else if (manyEmptyInOther && canBroadcastPlan) { + // for example, LOJ, !isLeft but it's the LHS that has many empty partitions if we + // proceed with shuffle. But if we proceed with BHJ, the OptimizeShuffleWithLocalRead + // will assemble partitions as they were before the shuffle and that may no longer have + // many empty partitions and thus cannot short-circuit local join + join.joinType match { + case LeftOuter | RightOuter | LeftAnti => true + case _ => false + } + } else { + false + } + + val preferShuffleHash = preferShuffledHashJoin(stage.mapStats.get) + if (demoteBroadcastHash && preferShuffleHash) { + Some(SHUFFLE_HASH) + } else if (demoteBroadcastHash) { + Some(NO_BROADCAST_HASH) + } else if (preferShuffleHash) { + Some(PREFER_SHUFFLE_HASH) + } else { + None + } - case _ => None + case _ => None + } } def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown { - case j @ ExtractEquiJoinKeys(_, _, _, _, _, left, right, hint) => + case j @ ExtractEquiJoinKeys(_, _, _, _, _, _, _, hint) => var newHint = hint if (!hint.leftHint.exists(_.strategy.isDefined)) { - selectJoinStrategy(left).foreach { strategy => + selectJoinStrategy(j, true).foreach { strategy => newHint = newHint.copy(leftHint = Some(hint.leftHint.getOrElse(HintInfo()).copy(strategy = Some(strategy)))) } } if (!hint.rightHint.exists(_.strategy.isDefined)) { - selectJoinStrategy(right).foreach { strategy => + selectJoinStrategy(j, false).foreach { strategy => newHint = newHint.copy(rightHint = Some(hint.rightHint.getOrElse(HintInfo()).copy(strategy = Some(strategy)))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index 68042d8384102..4410f7fea81af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ListQuery, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, SCALAR_SUBQUERY} +import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} import org.apache.spark.sql.execution.datasources.v2.V2CommandExec @@ -88,14 +88,14 @@ case class InsertAdaptiveSparkPlan( // - The query contains sub-query. 
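The rewritten `selectJoinStrategy` above decides whether to demote broadcast hash join from three inputs: which side has many empty shuffle partitions, whether that side can legally be the broadcast side, and the join type; the result is then combined with the shuffled-hash preference to pick a hint. A condensed sketch of that decision table, using plain strings in place of Spark's JoinType and JoinStrategyHint classes:

object DemoteDecision {
  // Illustrative stand-ins; the real code works on JoinType and JoinStrategyHint.
  def selectHint(
      manyEmptyInPlan: Boolean,
      manyEmptyInOther: Boolean,
      canBroadcastPlan: Boolean,
      joinType: String, // e.g. "Inner", "LeftOuter", "RightOuter", "LeftAnti"
      preferShuffleHash: Boolean): Option[String] = {
    val outerLike = Set("LeftOuter", "RightOuter", "LeftAnti")
    val demoteBroadcastHash =
      if (manyEmptyInPlan && canBroadcastPlan) !outerLike(joinType)
      else if (manyEmptyInOther && canBroadcastPlan) outerLike(joinType)
      else false
    if (demoteBroadcastHash && preferShuffleHash) Some("SHUFFLE_HASH")
    else if (demoteBroadcastHash) Some("NO_BROADCAST_HASH")
    else if (preferShuffleHash) Some("PREFER_SHUFFLE_HASH")
    else None
  }

  def main(args: Array[String]): Unit = {
    // Inner join whose build side has many empty partitions: broadcast gets demoted.
    println(selectHint(manyEmptyInPlan = true, manyEmptyInOther = false,
      canBroadcastPlan = true, joinType = "Inner", preferShuffleHash = false))
  }
}
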
private def shouldApplyAQE(plan: SparkPlan, isSubquery: Boolean): Boolean = { conf.getConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY) || isSubquery || { - plan.find { + plan.exists { case _: Exchange => true case p if !p.requiredChildDistribution.forall(_ == UnspecifiedDistribution) => true - case p => p.expressions.exists(_.find { + case p => p.expressions.exists(_.exists { case _: SubqueryExpression => true case _ => false - }.isDefined) - }.isDefined + }) + } } } @@ -118,7 +118,7 @@ case class InsertAdaptiveSparkPlan( if (!plan.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) { return subqueryMap.toMap } - plan.foreach(_.expressions.foreach(_.foreach { + plan.foreach(_.expressions.filter(_.containsPattern(PLAN_EXPRESSION)).foreach(_.foreach { case expressions.ScalarSubquery(p, _, exprId, _) if !subqueryMap.contains(exprId.id) => val executedPlan = compileSubquery(p) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 1c1ee7d03a4df..d4a173bb9cceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ValidateRequirements} import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils /** * A rule to optimize skewed joins to avoid straggler tasks whose share of data are significantly @@ -66,16 +67,6 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) medianSize * conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR)) } - private def medianSize(sizes: Array[Long]): Long = { - val numPartitions = sizes.length - val bytes = sizes.sorted - numPartitions match { - case _ if (numPartitions % 2 == 0) => - math.max((bytes(numPartitions / 2) + bytes(numPartitions / 2 - 1)) / 2, 1) - case _ => math.max(bytes(numPartitions / 2), 1) - } - } - /** * The goal of skew join optimization is to make the data distribution more even. The target size * to split skewed partitions is the average size of non-skewed partition, or the @@ -130,8 +121,8 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) assert(leftSizes.length == rightSizes.length) val numPartitions = leftSizes.length // We use the median size of the original shuffle partitions to detect skewed partitions. - val leftMedSize = medianSize(leftSizes) - val rightMedSize = medianSize(rightSizes) + val leftMedSize = Utils.median(leftSizes, false) + val rightMedSize = Utils.median(rightSizes, false) logDebug( s""" |Optimizing skewed join. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index e2f763eb71502..ac1968dab6998 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -124,6 +124,10 @@ abstract class QueryStageExec extends LeafExecNode { protected override def stringArgs: Iterator[Any] = Iterator.single(id) + override def simpleStringWithNodeId(): String = { + super.simpleStringWithNodeId() + computeStats().map(", " + _.toString).getOrElse("") + } + override def generateTreeString( depth: Int, lastChildren: Seq[Boolean], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala index 0251f803786c8..af689db337987 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -317,7 +317,7 @@ object ShufflePartitionsUtil extends Logging { */ // Visible for testing private[sql] def splitSizeListByTargetSize( - sizes: Seq[Long], + sizes: Array[Long], targetSize: Long, smallPartitionFactor: Double): Array[Int] = { val partitionStartIndices = ArrayBuffer[Int]() @@ -394,7 +394,12 @@ object ShufflePartitionsUtil extends Logging { } else { mapStartIndices(i + 1) } - val dataSize = startMapIndex.until(endMapIndex).map(mapPartitionSizes(_)).sum + var dataSize = 0L + var mapIndex = startMapIndex + while (mapIndex < endMapIndex) { + dataSize += mapPartitionSizes(mapIndex) + mapIndex += 1 + } PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) }) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 32db622c9f931..26161acae30b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -45,8 +45,28 @@ object AggUtils { } } + private def createStreamingAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]] = None, + groupingExpressions: Seq[NamedExpression] = Nil, + aggregateExpressions: Seq[AggregateExpression] = Nil, + aggregateAttributes: Seq[Attribute] = Nil, + initialInputBufferOffset: Int = 0, + resultExpressions: Seq[NamedExpression] = Nil, + child: SparkPlan): SparkPlan = { + createAggregate( + requiredChildDistributionExpressions, + isStreaming = true, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } + private def createAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]] = None, + isStreaming: Boolean = false, groupingExpressions: Seq[NamedExpression] = Nil, aggregateExpressions: Seq[AggregateExpression] = Nil, aggregateAttributes: Seq[Attribute] = Nil, @@ -60,6 +80,8 @@ object AggUtils { if (useHash && !forceSortAggregate) { HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, + isStreaming = isStreaming, + 
numShufflePartitions = None, groupingExpressions = groupingExpressions, aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), aggregateAttributes = aggregateAttributes, @@ -73,6 +95,8 @@ object AggUtils { if (objectHashEnabled && useObjectHash && !forceSortAggregate) { ObjectHashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, + isStreaming = isStreaming, + numShufflePartitions = None, groupingExpressions = groupingExpressions, aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), aggregateAttributes = aggregateAttributes, @@ -82,6 +106,8 @@ object AggUtils { } else { SortAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, + isStreaming = isStreaming, + numShufflePartitions = None, groupingExpressions = groupingExpressions, aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), aggregateAttributes = aggregateAttributes, @@ -290,7 +316,7 @@ object AggUtils { val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( + createStreamingAggregate( groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, @@ -302,7 +328,7 @@ object AggUtils { val partialMerged1: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( + createStreamingAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, @@ -320,7 +346,7 @@ object AggUtils { val partialMerged2: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( + createStreamingAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, @@ -348,7 +374,7 @@ object AggUtils { // projection: val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) - createAggregate( + createStreamingAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, aggregateExpressions = finalAggregateExpressions, @@ -407,7 +433,7 @@ object AggUtils { val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( + createStreamingAggregate( groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, @@ -424,7 +450,8 @@ object AggUtils { // this is to reduce amount of rows to shuffle MergingSessionsExec( requiredChildDistributionExpressions = None, - requiredChildDistributionOption = None, + isStreaming = true, + numShufflePartitions = None, groupingExpressions = groupingAttributes, sessionExpression = sessionExpression, aggregateExpressions = aggregateExpressions, @@ -447,8 +474,10 @@ object AggUtils { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) MergingSessionsExec( - requiredChildDistributionExpressions = None, - requiredChildDistributionOption = 
Some(restored.requiredChildDistribution), + requiredChildDistributionExpressions = Some(groupingWithoutSessionAttributes), + isStreaming = true, + // This will be replaced with actual value in state rule. + numShufflePartitions = None, groupingExpressions = groupingAttributes, sessionExpression = sessionExpression, aggregateExpressions = aggregateExpressions, @@ -476,8 +505,8 @@ object AggUtils { // projection: val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) - createAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), + createStreamingAggregate( + requiredChildDistributionExpressions = Some(groupingWithoutSessionAttributes), groupingExpressions = groupingAttributes, aggregateExpressions = finalAggregateExpressions, aggregateAttributes = finalAggregateAttributes, @@ -491,10 +520,15 @@ object AggUtils { private def mayAppendUpdatingSessionExec( groupingExpressions: Seq[NamedExpression], - maybeChildPlan: SparkPlan): SparkPlan = { + maybeChildPlan: SparkPlan, + isStreaming: Boolean = false): SparkPlan = { groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { case Some(sessionExpression) => UpdatingSessionsExec( + isStreaming = isStreaming, + // numShufflePartitions will be set to None, and replaced to the actual value in the + // state rule if the query is streaming. + numShufflePartitions = None, groupingExpressions.map(_.toAttribute), sessionExpression.toAttribute, maybeChildPlan) @@ -506,7 +540,8 @@ object AggUtils { private def mayAppendMergingSessionExec( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - partialAggregate: SparkPlan): SparkPlan = { + partialAggregate: SparkPlan, + isStreaming: Boolean = false): SparkPlan = { groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { case Some(sessionExpression) => val aggExpressions = aggregateExpressions.map(_.copy(mode = PartialMerge)) @@ -519,7 +554,10 @@ object AggUtils { MergingSessionsExec( requiredChildDistributionExpressions = Some(groupingWithoutSessionsAttributes), - requiredChildDistributionOption = None, + isStreaming = isStreaming, + // numShufflePartitions will be set to None, and replaced to the actual value in the + // state rule if the query is streaming. + numShufflePartitions = None, groupingExpressions = groupingAttributes, sessionExpression = sessionExpression, aggregateExpressions = aggExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala index 6304363d7888e..1377a98422317 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala @@ -47,6 +47,11 @@ trait AggregateCodegenSupport */ private var bufVars: Seq[Seq[ExprCode]] = _ + /** + * Whether this operator needs to build hash table. + */ + protected def needHashTable: Boolean + /** * The generated code for `doProduce` call when aggregate has grouping keys. 
*/ @@ -154,14 +159,23 @@ trait AggregateCodegenSupport """.stripMargin) val numOutput = metricTerm(ctx, "numOutputRows") - val aggTime = metricTerm(ctx, "aggTime") - val beforeAgg = ctx.freshName("beforeAgg") + val doAggWithRecordMetric = + if (needHashTable) { + val aggTime = metricTerm(ctx, "aggTime") + val beforeAgg = ctx.freshName("beforeAgg") + s""" + |long $beforeAgg = System.nanoTime(); + |$doAggFuncName(); + |$aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); + """.stripMargin + } else { + s"$doAggFuncName();" + } + s""" |while (!$initAgg) { | $initAgg = true; - | long $beforeAgg = System.nanoTime(); - | $doAggFuncName(); - | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); + | $doAggWithRecordMetric | | // output the result | ${genResult.trim} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index b709c8092e46d..756b5eb09d0b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -21,12 +21,15 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtils, UnaryExecNode} +import org.apache.spark.sql.execution.streaming.StatefulOperatorPartitioning /** * Holds common logic for aggregate operators */ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning { def requiredChildDistributionExpressions: Option[Seq[Expression]] + def isStreaming: Boolean + def numShufflePartitions: Option[Int] def groupingExpressions: Seq[NamedExpression] def aggregateExpressions: Seq[AggregateExpression] def aggregateAttributes: Seq[Attribute] @@ -92,7 +95,20 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) => ClusteredDistribution(exprs) :: Nil + case Some(exprs) => + if (isStreaming) { + numShufflePartitions match { + case Some(parts) => + StatefulOperatorPartitioning.getCompatibleDistribution( + exprs, parts, conf) :: Nil + + case _ => + throw new IllegalStateException("Expected to set the number of partitions before " + + "constructing required child distribution!") + } + } else { + ClusteredDistribution(exprs) :: Nil + } case None => UnspecifiedDistribution :: Nil } } @@ -102,7 +118,8 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning */ def toSortAggregate: SortAggregateExec = { SortAggregateExec( - requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions, - aggregateAttributes, initialInputBufferOffset, resultExpressions, child) + requiredChildDistributionExpressions, isStreaming, numShufflePartitions, groupingExpressions, + aggregateExpressions, aggregateAttributes, initialInputBufferOffset, resultExpressions, + child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index d4a4502badd09..8be3a018cee58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -45,6 +45,8 @@ import org.apache.spark.util.Utils */ case class HashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], + isStreaming: Boolean, + numShufflePartitions: Option[Int], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], aggregateAttributes: Seq[Attribute], @@ -65,7 +67,7 @@ case class HashAggregateExec( "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"), "avgHashProbe" -> - SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"), + SQLMetrics.createAverageMetric(sparkContext, "avg hash probes per key"), "numTasksFallBacked" -> SQLMetrics.createMetric(sparkContext, "number of sort fallback tasks")) // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash @@ -204,7 +206,7 @@ case class HashAggregateExec( metrics.incPeakExecutionMemory(maxMemory) // Update average hashmap probe - avgHashProbe.set(hashMap.getAvgHashProbeBucketListIterations) + avgHashProbe.set(hashMap.getAvgHashProbesPerKey) if (sorter == null) { // not spilled @@ -376,7 +378,7 @@ case class HashAggregateExec( * Currently fast hash map is supported for primitive data types during partial aggregation. * This list of supported use-cases should be expanded over time. */ - private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = { + private def checkIfFastHashMapSupported(): Boolean = { val isSupported = (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) || f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] || @@ -402,8 +404,8 @@ case class HashAggregateExec( isSupported && isNotByteArrayDecimalType && isEnabledForAggModes } - private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = { - if (!checkIfFastHashMapSupported(ctx)) { + private def enableTwoLevelHashMap(): Unit = { + if (!checkIfFastHashMapSupported()) { if (!Utils.isTesting) { logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but" + " current version of codegened fast hashmap does not support this aggregate.") @@ -417,10 +419,12 @@ case class HashAggregateExec( } } + protected override def needHashTable: Boolean = true + protected override def doProduceWithKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") if (conf.enableTwoLevelAggMap) { - enableTwoLevelHashMap(ctx) + enableTwoLevelHashMap() } else if (conf.enableVectorizedHashMap) { logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala index 08e8b59a17828..31245c5451857 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala @@ -21,7 +21,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, MutableProjection, NamedExpression, SortOrder, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetrics @@ -41,7 +40,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class MergingSessionsExec( requiredChildDistributionExpressions: Option[Seq[Expression]], - requiredChildDistributionOption: Option[Seq[Distribution]], + isStreaming: Boolean, + numShufflePartitions: Option[Int], groupingExpressions: Seq[NamedExpression], sessionExpression: NamedExpression, aggregateExpressions: Seq[AggregateExpression], @@ -59,17 +59,6 @@ case class MergingSessionsExec( override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) => ClusteredDistribution(exprs) :: Nil - case None => requiredChildDistributionOption match { - case Some(distributions) => distributions.toList - case None => UnspecifiedDistribution :: Nil - } - } - } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index c98c9f42e69da..9da0ca93c1819 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -59,6 +59,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class ObjectHashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], + isStreaming: Boolean, + numShufflePartitions: Option[Int], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], aggregateAttributes: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index f5462d226c3ae..3cf63a5318dcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.aggregate -import java.util.concurrent.TimeUnit.NANOSECONDS - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -34,6 +32,8 @@ import org.apache.spark.sql.internal.SQLConf */ case class SortAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], + isStreaming: Boolean, + numShufflePartitions: Option[Int], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], aggregateAttributes: Seq[Attribute], @@ -44,8 +44,7 @@ case class SortAggregateExec( with AliasAwareOutputOrdering { override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, 
"time in aggregation build")) + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def requiredChildOrdering: Seq[Seq[SortOrder]] = { groupingExpressions.map(SortOrder(_, Ascending)) :: Nil @@ -57,14 +56,11 @@ case class SortAggregateExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val aggTime = longMetric("aggTime") - child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) => - val beforeAgg = System.nanoTime() // Because the constructor of an aggregation iterator will read at least the first row, // we need to get the value of iter.hasNext first. val hasInput = iter.hasNext - val res = if (!hasInput && groupingExpressions.nonEmpty) { + if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. Iterator[UnsafeRow]() @@ -90,8 +86,6 @@ case class SortAggregateExec( outputIter } } - aggTime += NANOSECONDS.toMillis(System.nanoTime() - beforeAgg) - res } } @@ -101,11 +95,13 @@ case class SortAggregateExec( groupingExpressions.isEmpty } - protected def doProduceWithKeys(ctx: CodegenContext): String = { + protected override def needHashTable: Boolean = false + + protected override def doProduceWithKeys(ctx: CodegenContext): String = { throw new UnsupportedOperationException("SortAggregate code-gen does not support grouping keys") } - protected def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { + protected override def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { throw new UnsupportedOperationException("SortAggregate code-gen does not support grouping keys") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 0a5e8838e1531..36405fe927273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -389,7 +389,7 @@ class TungstenAggregationIterator( metrics.incPeakExecutionMemory(maxMemory) // Updating average hashmap probe - avgHashProbe.set(hashMap.getAvgHashProbeBucketListIterations) + avgHashProbe.set(hashMap.getAvgHashProbesPerKey) }) /////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsExec.scala index f15a22403cfb4..fee7e29f8add1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.streaming.StatefulOperatorPartitioning /** * This node updates the session window spec of each input rows via analyzing neighbor rows and @@ -35,6 +36,8 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} * 
Refer [[UpdatingSessionsIterator]] for more details. */ case class UpdatingSessionsExec( + isStreaming: Boolean, + numShufflePartitions: Option[Int], groupingExpression: Seq[Attribute], sessionExpression: Attribute, child: SparkPlan) extends UnaryExecNode { @@ -63,7 +66,20 @@ case class UpdatingSessionsExec( if (groupingWithoutSessionExpression.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(groupingWithoutSessionExpression) :: Nil + if (isStreaming) { + numShufflePartitions match { + case Some(parts) => + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingWithoutSessionExpression, parts, conf) :: Nil + + case _ => + throw new IllegalStateException("Expected to set the number of partitions before " + + "constructing required child distribution!") + } + + } else { + ClusteredDistribution(groupingWithoutSessionExpression) :: Nil + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 8879e1499f930..c1e225200f7b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -83,7 +83,7 @@ sealed trait BufferSetterGetterUtils { (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getInt(ordinal) - case TimestampType => + case TimestampType | TimestampNTZType => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getLong(ordinal) @@ -187,7 +187,7 @@ sealed trait BufferSetterGetterUtils { row.setNullAt(ordinal) } - case TimestampType => + case TimestampType | TimestampNTZType => (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setLong(ordinal, value.asInstanceOf[Long]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala index 17ea93e5ffede..7e9628c385130 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala @@ -78,7 +78,7 @@ object DetectAmbiguousSelfJoin extends Rule[LogicalPlan] { // We always remove the special metadata from `AttributeReference` at the end of this rule, so // Dataset column reference only exists in the root node via Dataset transformations like // `Dataset#select`. 
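The aggregate operators and `UpdatingSessionsExec` above now carry `isStreaming` and `numShufflePartitions` and switch their required child distribution on them: streaming plans ask `StatefulOperatorPartitioning` for a compatible distribution (failing fast if the state rule has not filled in the partition count yet), while batch plans keep `ClusteredDistribution`. A simplified sketch of that branching, with stand-in distribution types rather than Spark's:

sealed trait Dist
case class Clustered(exprs: Seq[String]) extends Dist
case class StatefulCompatible(exprs: Seq[String], numPartitions: Int) extends Dist

object RequiredDistribution {
  def forChild(exprs: Seq[String], isStreaming: Boolean, numShufflePartitions: Option[Int]): Dist =
    if (isStreaming) {
      numShufflePartitions match {
        case Some(parts) => StatefulCompatible(exprs, parts)
        case None => throw new IllegalStateException(
          "Expected to set the number of partitions before constructing required child distribution!")
      }
    } else {
      Clustered(exprs)
    }

  def main(args: Array[String]): Unit = {
    println(forChild(Seq("sessionStart"), isStreaming = false, numShufflePartitions = None))
    println(forChild(Seq("sessionStart"), isStreaming = true, numShufflePartitions = Some(200)))
  }
}
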
- if (plan.find(_.isInstanceOf[Join]).isEmpty) return stripColumnReferenceMetadataInPlan(plan) + if (!plan.exists(_.isInstanceOf[Join])) return stripColumnReferenceMetadataInPlan(plan) val colRefAttrs = plan.expressions.flatMap(_.collect { case a: AttributeReference if isColumnReference(a) => a diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 8e22c429c24e4..93ff276529dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -31,7 +31,7 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, Message import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils @@ -195,27 +195,27 @@ private[sql] object ArrowConverters { private[sql] def toDataFrame( arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, - sqlContext: SQLContext): DataFrame = { + session: SparkSession): DataFrame = { val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] - val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone + val timeZoneId = session.sessionState.conf.sessionLocalTimeZone val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } - sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema) + session.internalCreateDataFrame(rdd.setName("arrow"), schema) } /** * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. 
*/ private[sql] def readArrowStreamFromFile( - sqlContext: SQLContext, + session: SparkSession, filename: String): JavaRDD[Array[Byte]] = { Utils.tryWithResource(new FileInputStream(filename)) { fileStream => // Create array to consume iterator so that we can safely close the file val batches = getBatchesFromStream(fileStream.getChannel).toArray // Parallelize the record batches to create an RDD - JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) + JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala index 479bc21e5e6c8..1eb1082402972 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala @@ -141,10 +141,10 @@ object DisableUnnecessaryBucketedScan extends Rule[SparkPlan] { } def apply(plan: SparkPlan): SparkPlan = { - lazy val hasBucketedScan = plan.find { + lazy val hasBucketedScan = plan.exists { case scan: FileSourceScanExec => scan.bucketedScan case _ => false - }.isDefined + } if (!conf.bucketingEnabled || !conf.autoBucketedScanEnabled || !hasBucketedScan) { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 2f68e89d9c1f9..770b2442e403c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -106,7 +106,7 @@ private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) with NullableColumnAccessor -private[columnar] class IntervalColumnAccessor(buffer: ByteBuffer, dataType: CalendarIntervalType) +private[columnar] class IntervalColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[CalendarInterval](buffer, CALENDAR_INTERVAL) with NullableColumnAccessor @@ -158,11 +158,11 @@ private[sql] object ColumnAccessor { def decompress(columnAccessor: ColumnAccessor, columnVector: WritableColumnVector, numRows: Int): Unit = { - if (columnAccessor.isInstanceOf[NativeColumnAccessor[_]]) { - val nativeAccessor = columnAccessor.asInstanceOf[NativeColumnAccessor[_]] - nativeAccessor.decompress(columnVector, numRows) - } else { - throw QueryExecutionErrors.notSupportNonPrimitiveTypeError() + columnAccessor match { + case nativeAccessor: NativeColumnAccessor[_] => + nativeAccessor.decompress(columnVector, numRows) + case _ => + throw QueryExecutionErrors.notSupportNonPrimitiveTypeError() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 419dcc6cdeca7..c029786637687 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -473,23 +473,25 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType // copy the bytes from ByteBuffer to UnsafeRow override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { - if 
(row.isInstanceOf[MutableUnsafeRow]) { - val numBytes = buffer.getInt - val cursor = buffer.position() - buffer.position(cursor + numBytes) - row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(), - buffer.arrayOffset() + cursor, numBytes) - } else { - setField(row, ordinal, extract(buffer)) + row match { + case mutable: MutableUnsafeRow => + val numBytes = buffer.getInt + val cursor = buffer.position() + buffer.position(cursor + numBytes) + mutable.writer.write(ordinal, buffer.array(), + buffer.arrayOffset() + cursor, numBytes) + case _ => + setField(row, ordinal, extract(buffer)) } } // copy the bytes from UnsafeRow to ByteBuffer override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - if (row.isInstanceOf[UnsafeRow]) { - row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer) - } else { - super.append(row, ordinal, buffer) + row match { + case unsafe: UnsafeRow => + unsafe.writeFieldTo(ordinal, buffer) + case _ => + super.append(row, ordinal, buffer) } } } @@ -514,10 +516,11 @@ private[columnar] object STRING } override def setField(row: InternalRow, ordinal: Int, value: UTF8String): Unit = { - if (row.isInstanceOf[MutableUnsafeRow]) { - row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value) - } else { - row.update(ordinal, value.clone()) + row match { + case mutable: MutableUnsafeRow => + mutable.writer.write(ordinal, value) + case _ => + row.update(ordinal, value.clone()) } } @@ -792,13 +795,14 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval] // copy the bytes from ByteBuffer to UnsafeRow override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { - if (row.isInstanceOf[MutableUnsafeRow]) { - val cursor = buffer.position() - buffer.position(cursor + defaultSize) - row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(), - buffer.arrayOffset() + cursor, defaultSize) - } else { - setField(row, ordinal, extract(buffer)) + row match { + case mutable: MutableUnsafeRow => + val cursor = buffer.position() + buffer.position(cursor + defaultSize) + mutable.writer.write(ordinal, buffer.array(), + buffer.arrayOffset() + cursor, defaultSize) + case _ => + setField(row, ordinal, extract(buffer)) } } @@ -829,7 +833,7 @@ private[columnar] object ColumnType { case arr: ArrayType => ARRAY(arr) case map: MapType => MAP(map) case struct: StructType => STRUCT(struct) - case udt: UserDefinedType[_] => apply(udt.sqlType) + case udt: UserDefinedType[_] => ColumnType(udt.sqlType) case other => throw QueryExecutionErrors.unsupportedTypeError(other) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 6e666d4e1f9fc..33918bcee738b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -100,7 +100,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val createCode = dt match { case t if CodeGenerator.isPrimitiveType(dt) => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" - case NullType | StringType | BinaryType => + case NullType | StringType | BinaryType | CalendarIntervalType => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case other => s"""$accessorName = new 
$accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index eed8e039eddd0..c21f330be0647 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -83,7 +83,7 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode { override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray - override def executeToIterator(): Iterator[InternalRow] = sideEffectResult.toIterator + override def executeToIterator(): Iterator[InternalRow] = sideEffectResult.iterator override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray @@ -124,7 +124,7 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan) override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray - override def executeToIterator(): Iterator[InternalRow] = sideEffectResult.toIterator + override def executeToIterator(): Iterator[InternalRow] = sideEffectResult.iterator override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 295838eda5a72..14d0e9753f2b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.command import java.util.Locale import java.util.concurrent.TimeUnit._ -import scala.collection.{GenMap, GenSeq} import scala.collection.parallel.ForkJoinTaskSupport import scala.collection.parallel.immutable.ParVector import scala.util.control.NonFatal @@ -469,7 +468,7 @@ case class AlterTableAddPartitionCommand( // Also the request to metastore times out when adding lot of partitions in one shot. 
// we should split them into smaller batches val batchSize = conf.getConf(SQLConf.ADD_PARTITION_BATCH_SIZE) - parts.toIterator.grouped(batchSize).foreach { batch => + parts.iterator.grouped(batchSize).foreach { batch => catalog.createPartitions(table.identifier, batch, ignoreIfExists = ifNotExists) } @@ -643,7 +642,7 @@ case class RepairTableCommand( val pathFilter = getPathFilter(hadoopConf) val evalPool = ThreadUtils.newForkJoinPool("RepairTableCommand", 8) - val partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)] = + val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] = try { scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq @@ -656,7 +655,7 @@ case class RepairTableCommand( val partitionStats = if (spark.sqlContext.conf.gatherFastStats) { gatherPartitionStats(spark, partitionSpecsAndLocs, fs, pathFilter, threshold) } else { - GenMap.empty[String, PartitionStatistics] + Map.empty[String, PartitionStatistics] } logInfo(s"Finished to gather the fast stats for all $total partitions.") @@ -689,13 +688,13 @@ case class RepairTableCommand( partitionNames: Seq[String], threshold: Int, resolver: Resolver, - evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = { + evalTaskSupport: ForkJoinTaskSupport): Seq[(TablePartitionSpec, Path)] = { if (partitionNames.isEmpty) { return Seq(spec -> path) } val statuses = fs.listStatus(path, filter) - val statusPar: GenSeq[FileStatus] = + val statusPar: Seq[FileStatus] = if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) { // parallelize the list of partitions here, then we can have better parallelism later. val parArray = new ParVector(statuses.toVector) @@ -728,10 +727,10 @@ case class RepairTableCommand( private def gatherPartitionStats( spark: SparkSession, - partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)], + partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)], fs: FileSystem, pathFilter: PathFilter, - threshold: Int): GenMap[String, PartitionStatistics] = { + threshold: Int): Map[String, PartitionStatistics] = { if (partitionSpecsAndLocs.length > threshold) { val hadoopConf = spark.sessionState.newHadoopConf() val serializableConfiguration = new SerializableConfiguration(hadoopConf) @@ -752,7 +751,7 @@ case class RepairTableCommand( val statuses = fs.listStatus(path, pathFilter) (path.toString, PartitionStatistics(statuses.length, statuses.map(_.getLen).sum)) } - }.collectAsMap() + }.collectAsMap().toMap } else { partitionSpecsAndLocs.map { case (_, location) => val statuses = fs.listStatus(location, pathFilter) @@ -764,15 +763,15 @@ case class RepairTableCommand( private def addPartitions( spark: SparkSession, table: CatalogTable, - partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)], - partitionStats: GenMap[String, PartitionStatistics]): Unit = { + partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)], + partitionStats: Map[String, PartitionStatistics]): Unit = { val total = partitionSpecsAndLocs.length var done = 0L // Hive metastore may not have enough memory to handle millions of partitions in single RPC, // we should split them into smaller batches. Since Hive client is not thread safe, we cannot // do this in parallel. 
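The batching pattern used in both AlterTableAddPartitionCommand and RepairTableCommand above simply walks the partition specs through `iterator.grouped(batchSize)` so that each metastore RPC carries a bounded number of partitions. A minimal, self-contained sketch of the same idea, with hypothetical names and a println standing in for the metastore call:

    object AddPartitionBatchingSketch {
      // Group a potentially huge list of partition specs into fixed-size batches and
      // issue one "createPartitions" call per batch, mirroring
      // parts.iterator.grouped(batchSize).foreach { ... } above.
      def createInBatches[T](parts: Seq[T], batchSize: Int)(createBatch: Seq[T] => Unit): Unit =
        parts.iterator.grouped(batchSize).foreach(createBatch)

      def main(args: Array[String]): Unit = {
        val specs = (1 to 7).map(i => s"dt=2022-01-0$i")
        createInBatches(specs, batchSize = 3) { batch =>
          println(s"createPartitions(${batch.mkString(", ")})") // stand-in for the metastore RPC
        }
      }
    }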
val batchSize = spark.conf.get(SQLConf.ADD_PARTITION_BATCH_SIZE) - partitionSpecsAndLocs.toIterator.grouped(batchSize).foreach { batch => + partitionSpecsAndLocs.iterator.grouped(batchSize).foreach { batch => val now = MILLISECONDS.toSeconds(System.currentTimeMillis()) val parts = batch.map { case (spec, location) => val params = partitionStats.get(location.toString).map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 761a0d508e877..ac4bb8395a3b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -35,7 +35,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier, CaseInsensitiveMap, CharVarcharUtils} +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -773,8 +774,10 @@ case class DescribeColumnCommand( ) if (isExtended) { // Show column stats when EXTENDED or FORMATTED is specified. - buffer += Row("min", cs.flatMap(_.min.map(_.toString)).getOrElse("NULL")) - buffer += Row("max", cs.flatMap(_.max.map(_.toString)).getOrElse("NULL")) + buffer += Row("min", cs.flatMap(_.min.map( + toZoneAwareExternalString(_, field.name, field.dataType))).getOrElse("NULL")) + buffer += Row("max", cs.flatMap(_.max.map( + toZoneAwareExternalString(_, field.name, field.dataType))).getOrElse("NULL")) buffer += Row("num_nulls", cs.flatMap(_.nullCount.map(_.toString)).getOrElse("NULL")) buffer += Row("distinct_count", cs.flatMap(_.distinctCount.map(_.toString)).getOrElse("NULL")) @@ -789,6 +792,27 @@ case class DescribeColumnCommand( buffer.toSeq } + private def toZoneAwareExternalString( + valueStr: String, + name: String, + dataType: DataType): String = { + dataType match { + case TimestampType => + // When writing to metastore, we always format timestamp value in the default UTC time zone. + // So here we need to first convert to internal value, then format it using the current + // time zone. 
+ val internalValue = + CatalogColumnStat.fromExternalString(valueStr, name, dataType, CatalogColumnStat.VERSION) + val curZoneId = DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone) + CatalogColumnStat + .getTimestampFormatter( + isParsing = false, format = "yyyy-MM-dd HH:mm:ss.SSSSSS Z", zoneId = curZoneId) + .format(internalValue.asInstanceOf[Long]) + case _ => + valueStr + } + } + private def histogramDescription(histogram: Histogram): Seq[Row] = { val header = Row("histogram", s"height: ${histogram.height}, num_of_bins: ${histogram.bins.length}") @@ -1034,7 +1058,7 @@ trait ShowCreateTableCommandBase { .map(" COMMENT '" + _ + "'") // view columns shouldn't have data type info - s"${quoteIdentifier(f.name)}${comment.getOrElse("")}" + s"${quoteIfNeeded(f.name)}${comment.getOrElse("")}" } builder ++= concatByMultiLines(viewColumns) } @@ -1043,7 +1067,7 @@ trait ShowCreateTableCommandBase { private def showViewProperties(metadata: CatalogTable, builder: StringBuilder): Unit = { val viewProps = metadata.properties.filterKeys(!_.startsWith(CatalogTable.VIEW_PREFIX)) if (viewProps.nonEmpty) { - val props = viewProps.map { case (key, value) => + val props = viewProps.toSeq.sortBy(_._1).map { case (key, value) => s"'${escapeSingleQuotedString(key)}' = '${escapeSingleQuotedString(value)}'" } @@ -1101,15 +1125,15 @@ case class ShowCreateTableCommand( } } - val builder = StringBuilder.newBuilder + val builder = new StringBuilder val stmt = if (tableMetadata.tableType == VIEW) { - builder ++= s"CREATE VIEW ${table.quotedString} " + builder ++= s"CREATE VIEW ${table.quoted} " showCreateView(metadata, builder) builder.toString() } else { - builder ++= s"CREATE TABLE ${table.quotedString} " + builder ++= s"CREATE TABLE ${table.quoted} " showCreateDataSourceTable(metadata, builder) builder.toString() @@ -1129,7 +1153,7 @@ case class ShowCreateTableCommand( // TODO: some Hive fileformat + row serde might be mapped to Spark data source, e.g. CSV. 
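The DescribeColumnCommand change above re-renders persisted timestamp min/max statistics in the session-local time zone instead of the UTC form stored in the metastore. A hedged sketch of the same conversion using only plain java.time (Spark's internal CatalogColumnStat and TimestampFormatter helpers are not used here), assuming the statistic is a microsecond epoch value:

    import java.time.{Instant, ZoneId}
    import java.time.format.DateTimeFormatter

    object ZoneAwareStatSketch {
      // Format a microsecond-precision epoch timestamp in the given session time zone,
      // using the same "yyyy-MM-dd HH:mm:ss.SSSSSS Z" pattern as the change above.
      def format(micros: Long, sessionTimeZone: String): String = {
        val instant = Instant.ofEpochSecond(micros / 1000000L, (micros % 1000000L) * 1000L)
        DateTimeFormatter
          .ofPattern("yyyy-MM-dd HH:mm:ss.SSSSSS Z")
          .withZone(ZoneId.of(sessionTimeZone))
          .format(instant)
      }

      def main(args: Array[String]): Unit = {
        val micros = 1640995200123456L // 2022-01-01 00:00:00.123456 UTC
        println(format(micros, "UTC"))
        println(format(micros, "America/Los_Angeles")) // same instant, session-local rendering
      }
    }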
val source = HiveSerDe.serdeToSource(hiveSerde) if (source.isEmpty) { - val builder = StringBuilder.newBuilder + val builder = new StringBuilder hiveSerde.serde.foreach { serde => builder ++= s" SERDE: $serde" } @@ -1236,7 +1260,7 @@ case class ShowCreateTableAsSerdeCommand( reportUnsupportedError(metadata.unsupportedFeatures) } - val builder = StringBuilder.newBuilder + val builder = new StringBuilder val tableTypeString = metadata.tableType match { case EXTERNAL => " EXTERNAL TABLE" @@ -1247,7 +1271,7 @@ case class ShowCreateTableAsSerdeCommand( s"Unknown table type is found at showCreateHiveTable: $t") } - builder ++= s"CREATE$tableTypeString ${table.quotedString}" + builder ++= s"CREATE$tableTypeString ${table.quoted} " if (metadata.tableType == VIEW) { showCreateView(metadata, builder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 145287158a58c..eca48a6992433 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -612,12 +612,17 @@ object ViewHelper extends SQLConfHelper with Logging { val uncache = getRawTempView(name.table).map { r => needsToUncache(r, aliasedPlan) }.getOrElse(false) + val storeAnalyzedPlanForView = conf.storeAnalyzedPlanForView || originalText.isEmpty if (replace && uncache) { logDebug(s"Try to uncache ${name.quotedString} before replacing.") - checkCyclicViewReference(analyzedPlan, Seq(name), name) + if (!storeAnalyzedPlanForView) { + // Skip cyclic check because when stored analyzed plan for view, the depended + // view is already converted to the underlying tables. So no cyclic views. 
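The SHOW CREATE TABLE changes above switch identifier rendering to quoteIfNeeded / table.quoted, so names are only backquoted when they actually need it. A simplified, hand-rolled sketch of that rule; the real helper is imported from org.apache.spark.sql.catalyst.util and may differ in edge cases:

    object QuoteIfNeededSketch {
      // Backquote an identifier only when it contains characters outside [A-Za-z0-9_]
      // or is purely numeric; embedded backticks are escaped by doubling them.
      def quoteIfNeeded(part: String): String =
        if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) part
        else s"`${part.replace("`", "``")}`"

      def main(args: Array[String]): Unit = {
        println(quoteIfNeeded("event_time"))  // event_time  (left as-is)
        println(quoteIfNeeded("order count")) // `order count`
        println(quoteIfNeeded("a`b"))         // `a``b`
        println(quoteIfNeeded("123"))         // `123`
      }
    }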
+ checkCyclicViewReference(analyzedPlan, Seq(name), name) + } CommandUtils.uncacheTableOrView(session, name.quotedString) } - if (!conf.storeAnalyzedPlanForView && originalText.nonEmpty) { + if (!storeAnalyzedPlanForView) { TemporaryViewRelation( prepareTemporaryView( name, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala index e7069137f31cb..4779a3eaf2531 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow} -import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -42,27 +42,28 @@ object AggregatePushDownUtils { var finalSchema = new StructType() - def getStructFieldForCol(col: NamedReference): StructField = { - schema.apply(col.fieldNames.head) + def getStructFieldForCol(colName: String): StructField = { + schema.apply(colName) } - def isPartitionCol(col: NamedReference) = { - partitionNames.contains(col.fieldNames.head) + def isPartitionCol(colName: String) = { + partitionNames.contains(colName) } def processMinOrMax(agg: AggregateFunc): Boolean = { - val (column, aggType) = agg match { - case max: Max => (max.column, "max") - case min: Min => (min.column, "min") - case _ => - throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") + val (columnName, aggType) = agg match { + case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined => + (V2ColumnUtils.extractV2Column(max.column).get, "max") + case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined => + (V2ColumnUtils.extractV2Column(min.column).get, "min") + case _ => return false } - if (isPartitionCol(column)) { + if (isPartitionCol(columnName)) { // don't push down partition column, footer doesn't have max/min for partition column return false } - val structField = getStructFieldForCol(column) + val structField = getStructFieldForCol(columnName) structField.dataType match { // not push down complex type @@ -108,8 +109,8 @@ object AggregatePushDownUtils { aggregation.groupByColumns.foreach { col => // don't push down if the group by columns are not the same as the partition columns (orders // doesn't matter because reorder can be done at data source layer) - if (col.fieldNames.length != 1 || !isPartitionCol(col)) return None - finalSchema = finalSchema.add(getStructFieldForCol(col)) + if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None + finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head)) } aggregation.aggregateExpressions.foreach { @@ -117,10 +118,10 @@ object 
AggregatePushDownUtils { if (!processMinOrMax(max)) return None case min: Min => if (!processMinOrMax(min)) return None - case count: Count => - if (count.column.fieldNames.length != 1 || count.isDistinct) return None - finalSchema = - finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) + case count: Count + if V2ColumnUtils.extractV2Column(count.column).isDefined && !count.isDistinct => + val columnName = V2ColumnUtils.extractV2Column(count.column).get + finalSchema = finalSchema.add(StructField(s"count($columnName)", LongType)) case _: CountStar => finalSchema = finalSchema.add(StructField("count(*)", LongType)) case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index a7e505ebd93da..2bb3d48c1458c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -353,7 +353,7 @@ case class DataSource( case (dataSource: RelationProvider, Some(schema)) => val baseRelation = dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) - if (baseRelation.schema != schema) { + if (!DataType.equalsIgnoreCompatibleNullability(baseRelation.schema, schema)) { throw QueryCompilationErrors.userSpecifiedSchemaMismatchActualSchemaError( schema, baseRelation.schema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index e734de32d232f..4e5014cc83e13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -38,10 +38,11 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 +import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -60,7 +61,7 @@ import org.apache.spark.unsafe.types.UTF8String * Note that, this rule must be run after `PreprocessTableCreation` and * `PreprocessTableInsertion`. 
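The DataSource change above relaxes the user-specified schema check from strict equality to DataType.equalsIgnoreCompatibleNullability. That helper is not part of the public API, so the sketch below uses a simplified, top-level-fields-only stand-in to illustrate the rule being applied: a field the source reports as non-nullable may satisfy a nullable user field, but not the reverse.

    import org.apache.spark.sql.types._

    object SchemaCompatSketch {
      // Simplified stand-in for equalsIgnoreCompatibleNullability: same names and
      // types, and a non-nullable source field may fill a nullable target field.
      def compatible(from: StructType, to: StructType): Boolean =
        from.length == to.length && from.fields.zip(to.fields).forall { case (f, t) =>
          f.name == t.name && f.dataType == t.dataType && (t.nullable || !f.nullable)
        }

      def main(args: Array[String]): Unit = {
        val reported = StructType(Seq(StructField("id", LongType, nullable = false)))
        val userSpec = StructType(Seq(StructField("id", LongType, nullable = true)))
        println(reported == userSpec)           // false: strict equality would reject this
        println(compatible(reported, userSpec)) // true: only nullability differs, compatibly
      }
    }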
*/ -object DataSourceAnalysis extends Rule[LogicalPlan] with CastSupport { +object DataSourceAnalysis extends Rule[LogicalPlan] { def resolver: Resolver = conf.resolver @@ -114,7 +115,10 @@ object DataSourceAnalysis extends Rule[LogicalPlan] with CastSupport { Some(Alias(AnsiCast(Literal(partValue), field.dataType, Option(conf.sessionLocalTimeZone)), field.name)()) case _ => - Some(Alias(cast(Literal(partValue), field.dataType), field.name)()) + val castExpression = + Cast(Literal(partValue), field.dataType, Option(conf.sessionLocalTimeZone), + ansiEnabled = false) + Some(Alias(castExpression, field.name)()) } } else { throw QueryCompilationErrors.multiplePartitionColumnValuesSpecifiedError( @@ -705,22 +709,17 @@ object DataSourceStrategy protected[sql] def translateAggregate(agg: AggregateExpression): Option[AggregateFunc] = { if (agg.filter.isEmpty) { agg.aggregateFunction match { - case aggregate.Min(PushableColumnWithoutNestedColumn(name)) => - Some(new Min(FieldReference.column(name))) - case aggregate.Max(PushableColumnWithoutNestedColumn(name)) => - Some(new Max(FieldReference.column(name))) + case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) + case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) case count: aggregate.Count if count.children.length == 1 => count.children.head match { // COUNT(any literal) is the same as COUNT(*) case Literal(_, _) => Some(new CountStar()) - case PushableColumnWithoutNestedColumn(name) => - Some(new Count(FieldReference.column(name), agg.isDistinct)) + case PushableExpression(expr) => Some(new Count(expr, agg.isDistinct)) case _ => None } - case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => - Some(new Sum(FieldReference.column(name), agg.isDistinct)) - case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) => - Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference.column(name)))) + case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, agg.isDistinct)) + case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, agg.isDistinct)) case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) => Some(new GeneralAggregateFunc( "VAR_POP", agg.isDistinct, Array(FieldReference.column(name)))) @@ -752,8 +751,33 @@ object DataSourceStrategy } } - protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = { - def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match { + /** + * Translate aggregate expressions and group by expressions. + * + * @return translated aggregation. 
+ */ + protected[sql] def translateAggregation( + aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = { + + def columnAsString(e: Expression): Option[FieldReference] = e match { + case PushableColumnWithoutNestedColumn(name) => + Some(FieldReference.column(name).asInstanceOf[FieldReference]) + case _ => None + } + + val translatedAggregates = aggregates.flatMap(translateAggregate) + val translatedGroupBys = groupBy.flatMap(columnAsString) + + if (translatedAggregates.length != aggregates.length || + translatedGroupBys.length != groupBy.length) { + return None + } + + Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)) + } + + protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[V2SortOrder] = { + def translateOortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match { case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => val directionV2 = directionV1 match { case Ascending => SortDirection.ASCENDING @@ -778,7 +802,7 @@ object DataSourceStrategy output: Seq[Attribute], rdd: RDD[Row]): RDD[InternalRow] = { if (relation.needConversion) { - val toRow = RowEncoder(StructType.fromAttributes(output)).createSerializer() + val toRow = RowEncoder(StructType.fromAttributes(output), lenient = true).createSerializer() rdd.mapPartitions { iterator => iterator.map(toRow) } @@ -835,3 +859,13 @@ object PushableColumnAndNestedColumn extends PushableColumnBase { object PushableColumnWithoutNestedColumn extends PushableColumnBase { override val nestedPredicatePushdownEnabled = false } + +/** + * Get the expression of DS V2 to represent catalyst expression that can be pushed down. + */ +object PushableExpression { + def unapply(e: Expression): Option[V2Expression] = e match { + case PushableColumnWithoutNestedColumn(name) => Some(FieldReference.column(name)) + case _ => new V2ExpressionBuilder(e).build() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index 6ceb44ab15020..15d40a78f2346 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -47,6 +47,12 @@ object DataSourceUtils extends PredicateHelper { */ val PARTITIONING_COLUMNS_KEY = "__partition_columns" + /** + * The key to use for specifying partition overwrite mode when + * INSERT OVERWRITE a partitioned data source table. + */ + val PARTITION_OVERWRITE_MODE = "partitionOverwriteMode" + /** * Utility methods for converting partitionBy columns to options and back. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index c3bcf06b6e5fb..f9b37fb5d9fcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.{DataType, LongType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.unsafe.types.UTF8String /** @@ -191,7 +192,38 @@ object FileFormat { .add(StructField(FILE_MODIFICATION_TIME, TimestampType)) // create a file metadata struct col - def createFileMetadataCol: AttributeReference = MetadataAttribute(METADATA_NAME, METADATA_STRUCT) + def createFileMetadataCol: AttributeReference = + FileSourceMetadataAttribute(METADATA_NAME, METADATA_STRUCT) + + // create an internal row given required metadata fields and file information + def createMetadataInternalRow( + fieldNames: Seq[String], + filePath: Path, + fileSize: Long, + fileModificationTime: Long): InternalRow = + updateMetadataInternalRow(new GenericInternalRow(fieldNames.length), fieldNames, + filePath, fileSize, fileModificationTime) + + // update an internal row given required metadata fields and file information + def updateMetadataInternalRow( + row: InternalRow, + fieldNames: Seq[String], + filePath: Path, + fileSize: Long, + fileModificationTime: Long): InternalRow = { + fieldNames.zipWithIndex.foreach { case (name, i) => + name match { + case FILE_PATH => row.update(i, UTF8String.fromString(filePath.toString)) + case FILE_NAME => row.update(i, UTF8String.fromString(filePath.getName)) + case FILE_SIZE => row.update(i, fileSize) + case FILE_MODIFICATION_TIME => + // the modificationTime from the file is in millisecond, + // while internally, the TimestampType `file_modification_time` is stored in microsecond + row.update(i, fileModificationTime * 1000L) + } + } + row + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 409e33448acf8..643902e7cbcb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -111,7 +111,11 @@ object FileFormatWriter extends Logging { FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) val partitionSet = AttributeSet(partitionColumns) - val dataColumns = outputSpec.outputColumns.filterNot(partitionSet.contains) + // cleanup the internal metadata information of + // the file source metadata attribute if any before write out + val finalOutputSpec = outputSpec.copy(outputColumns = outputSpec.outputColumns + .map(FileSourceMetadataAttribute.cleanupFileSourceMetadataInformation)) + val dataColumns = finalOutputSpec.outputColumns.filterNot(partitionSet.contains) var needConvert = false val projectList: Seq[NamedExpression] = plan.output.map { @@ -167,12 +171,12 @@ object FileFormatWriter extends Logging { uuid = UUID.randomUUID.toString, serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), outputWriterFactory = outputWriterFactory, - allColumns = 
outputSpec.outputColumns, + allColumns = finalOutputSpec.outputColumns, dataColumns = dataColumns, partitionColumns = partitionColumns, bucketSpec = writerBucketSpec, - path = outputSpec.outputPath, - customPartitionLocations = outputSpec.customPartitionLocations, + path = finalOutputSpec.outputPath, + customPartitionLocations = finalOutputSpec.customPartitionLocations, maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) @@ -212,7 +216,7 @@ object FileFormatWriter extends Logging { // the physical plan may have different attribute ids due to optimizer removing some // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. val orderingExpr = bindReferences( - requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns) + requiredOrdering.map(SortOrder(_, Ascending)), finalOutputSpec.outputColumns) val sortPlan = SortExec( orderingExpr, global = false, @@ -324,7 +328,6 @@ object FileFormatWriter extends Logging { try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - val taskAttemptID = taskAttemptContext.getTaskAttemptID dataWriter.writeWithIterator(iterator) dataWriter.commit() })(catchBlock = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 47f279babef58..20c393a5c0e60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.FileFormat._ -import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector} +import org.apache.spark.sql.execution.vectorized.ConstantColumnVector import org.apache.spark.sql.types.{LongType, StringType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.UTF8String @@ -93,7 +93,7 @@ class FileScanRDD( inputMetrics.setBytesRead(existingBytesRead + getBytesReadCallback()) } - private[this] val files = split.asInstanceOf[FilePartition].files.toIterator + private[this] val files = split.asInstanceOf[FilePartition].files.iterator private[this] var currentFile: PartitionedFile = null private[this] var currentIterator: Iterator[Object] = null @@ -134,59 +134,35 @@ class FileScanRDD( * For each partitioned file, metadata columns for each record in the file are exactly same. * Only update metadata row when `currentFile` is changed. 
*/ - private def updateMetadataRow(): Unit = { + private def updateMetadataRow(): Unit = if (metadataColumns.nonEmpty && currentFile != null) { - val path = new Path(currentFile.filePath) - metadataColumns.zipWithIndex.foreach { case (attr, i) => - attr.name match { - case FILE_PATH => metadataRow.update(i, UTF8String.fromString(path.toString)) - case FILE_NAME => metadataRow.update(i, UTF8String.fromString(path.getName)) - case FILE_SIZE => metadataRow.update(i, currentFile.fileSize) - case FILE_MODIFICATION_TIME => - // the modificationTime from the file is in millisecond, - // while internally, the TimestampType is stored in microsecond - metadataRow.update(i, currentFile.modificationTime * 1000L) - } - } + updateMetadataInternalRow(metadataRow, metadataColumns.map(_.name), + new Path(currentFile.filePath), currentFile.fileSize, currentFile.modificationTime) } - } /** - * Create a writable column vector containing all required metadata columns + * Create an array of constant column vectors containing all required metadata columns */ - private def createMetadataColumnVector(c: ColumnarBatch): Array[WritableColumnVector] = { + private def createMetadataColumnVector(c: ColumnarBatch): Array[ConstantColumnVector] = { val path = new Path(currentFile.filePath) - val filePathBytes = path.toString.getBytes - val fileNameBytes = path.getName.getBytes - var rowId = 0 metadataColumns.map(_.name).map { case FILE_PATH => - val columnVector = new OnHeapColumnVector(c.numRows(), StringType) - rowId = 0 - // use a tight-loop for better performance - while (rowId < c.numRows()) { - columnVector.putByteArray(rowId, filePathBytes) - rowId += 1 - } + val columnVector = new ConstantColumnVector(c.numRows(), StringType) + columnVector.setUtf8String(UTF8String.fromString(path.toString)) columnVector case FILE_NAME => - val columnVector = new OnHeapColumnVector(c.numRows(), StringType) - rowId = 0 - // use a tight-loop for better performance - while (rowId < c.numRows()) { - columnVector.putByteArray(rowId, fileNameBytes) - rowId += 1 - } + val columnVector = new ConstantColumnVector(c.numRows(), StringType) + columnVector.setUtf8String(UTF8String.fromString(path.getName)) columnVector case FILE_SIZE => - val columnVector = new OnHeapColumnVector(c.numRows(), LongType) - columnVector.putLongs(0, c.numRows(), currentFile.fileSize) + val columnVector = new ConstantColumnVector(c.numRows(), LongType) + columnVector.setLong(currentFile.fileSize) columnVector case FILE_MODIFICATION_TIME => - val columnVector = new OnHeapColumnVector(c.numRows(), LongType) + val columnVector = new ConstantColumnVector(c.numRows(), LongType) // the modificationTime from the file is in millisecond, // while internally, the TimestampType is stored in microsecond - columnVector.putLongs(0, c.numRows(), currentFile.modificationTime * 1000L) + columnVector.setLong(currentFile.modificationTime * 1000L) columnVector }.toArray } @@ -198,10 +174,9 @@ class FileScanRDD( private def addMetadataColumnsIfNeeded(nextElement: Object): Object = { if (metadataColumns.nonEmpty) { nextElement match { - case c: ColumnarBatch => - new ColumnarBatch( - Array.tabulate(c.numCols())(c.column) ++ createMetadataColumnVector(c), - c.numRows()) + case c: ColumnarBatch => new ColumnarBatch( + Array.tabulate(c.numCols())(c.column) ++ createMetadataColumnVector(c), + c.numRows()) case u: UnsafeRow => projection.apply(new JoinedRow(u, metadataRow)) case i: InternalRow => new JoinedRow(i, metadataRow) } @@ -214,16 +189,17 @@ class FileScanRDD( val nextElement = 
currentIterator.next() // TODO: we should have a better separation of row based and batch based scan, so that we // don't need to run this `if` for every record. - if (nextElement.isInstanceOf[ColumnarBatch]) { - incTaskInputMetricsBytesRead() - inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows()) - } else { - // too costly to update every record - if (inputMetrics.recordsRead % - SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + nextElement match { + case batch: ColumnarBatch => incTaskInputMetricsBytesRead() - } - inputMetrics.incRecordsRead(1) + inputMetrics.incRecordsRead(batch.numRows()) + case _ => + // too costly to update every record + if (inputMetrics.recordsRead % + SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + incTaskInputMetricsBytesRead() + } + inputMetrics.incRecordsRead(1) } addMetadataColumnsIfNeeded(nextElement) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index c1282fa69ca80..9356e46a69187 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -213,14 +213,13 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { val outputSchema = readDataColumns.toStructType logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}") - val metadataStructOpt = requiredAttributes.collectFirst { - case MetadataAttribute(attr) => attr + val metadataStructOpt = l.output.collectFirst { + case FileSourceMetadataAttribute(attr) => attr } - // TODO (yaohua): should be able to prune the metadata struct only containing what needed val metadataColumns = metadataStructOpt.map { metadataStruct => metadataStruct.dataType.asInstanceOf[StructType].fields.map { field => - MetadataAttribute(field.name, field.dataType) + FileSourceMetadataAttribute(field.name, field.dataType) }.toSeq }.getOrElse(Seq.empty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 267b360b474ca..74be483cd7c37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -62,7 +62,7 @@ case class InsertIntoHadoopFsRelationCommand( private lazy val parameters = CaseInsensitiveMap(options) private[sql] lazy val dynamicPartitionOverwrite: Boolean = { - val partitionOverwriteMode = parameters.get("partitionOverwriteMode") + val partitionOverwriteMode = parameters.get(DataSourceUtils.PARTITION_OVERWRITE_MODE) // scalastyle:off caselocale .map(mode => PartitionOverwriteMode.withName(mode.toUpperCase)) // scalastyle:on caselocale diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 5b0d0606da093..d70c4b11bc0d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -27,6 
+27,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.execution.datasources.FileFormat.createMetadataInternalRow import org.apache.spark.sql.types.StructType /** @@ -71,8 +72,39 @@ abstract class PartitioningAwareFileIndex( def isNonEmptyFile(f: FileStatus): Boolean = { isDataPath(f.getPath) && f.getLen > 0 } + + // retrieve the file metadata filters and reduce to a final filter expression + val fileMetadataFilterOpt = dataFilters.filter { f => + f.references.nonEmpty && f.references.forall { + case FileSourceMetadataAttribute(_) => true + case _ => false + } + }.reduceOption(expressions.And) + + // - create a bound references for filters: put the metadata struct at 0 position for each file + // - retrieve the final metadata struct (could be pruned) from filters + val boundedFilterMetadataStructOpt = fileMetadataFilterOpt.map { fileMetadataFilter => + val metadataStruct = fileMetadataFilter.references.head.dataType + val boundedFilter = Predicate.createInterpreted(fileMetadataFilter.transform { + case _: AttributeReference => BoundReference(0, metadataStruct, nullable = true) + }) + (boundedFilter, metadataStruct) + } + + def matchFileMetadataPredicate(f: FileStatus): Boolean = { + // use option.forall, so if there is no filter no metadata struct, return true + boundedFilterMetadataStructOpt.forall { case (boundedFilter, metadataStruct) => + val row = InternalRow.fromSeq(Seq( + createMetadataInternalRow(metadataStruct.asInstanceOf[StructType].names, + f.getPath, f.getLen, f.getModificationTime) + )) + boundedFilter.eval(row) + } + } + val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) { - PartitionDirectory(InternalRow.empty, allFiles().filter(isNonEmptyFile)) :: Nil + PartitionDirectory(InternalRow.empty, allFiles() + .filter(f => isNonEmptyFile(f) && matchFileMetadataPredicate(f))) :: Nil } else { if (recursiveFileLookup) { throw new IllegalArgumentException( @@ -83,7 +115,8 @@ abstract class PartitioningAwareFileIndex( val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { case Some(existingDir) => // Directory has children files in it, return them - existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f)) + existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f) && + matchFileMetadataPredicate(f)) case None => // Directory does not exist, or has no children files diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 88543bd19bb4f..8d71cf65807c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -262,7 +262,7 @@ object PartitioningUtils extends SQLConfHelper{ // Once we get the string, we try to parse it and find the partition column and value. val maybeColumn = parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes, - validatePartitionColumns, zoneId, dateFormatter, timestampFormatter) + zoneId, dateFormatter, timestampFormatter) maybeColumn.foreach(columns += _) // Now, we determine if we should stop. 
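The PartitioningAwareFileIndex change above evaluates filters that reference only the hidden file metadata struct against each FileStatus, so files can be skipped before they are read. From the user side this is the `_metadata` column for file sources; a hedged usage sketch (the output path is illustrative):

    import org.apache.spark.sql.SparkSession

    object FileMetadataFilterSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
        import spark.implicits._

        val dir = "/tmp/file_metadata_sketch" // illustrative output path
        Seq(1, 2, 3).toDF("id").write.mode("overwrite").parquet(dir)

        spark.read.parquet(dir)
          .select($"id", $"_metadata.file_name", $"_metadata.file_size")
          // A filter that only touches _metadata fields is exactly what the new
          // matchFileMetadataPredicate can use to prune the file listing.
          .where($"_metadata.file_size" > 0)
          .show(truncate = false)

        spark.stop()
      }
    }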
@@ -296,7 +296,6 @@ object PartitioningUtils extends SQLConfHelper{ columnSpec: String, typeInference: Boolean, userSpecifiedDataTypes: Map[String, DataType], - validatePartitionColumns: Boolean, zoneId: ZoneId, dateFormatter: DateFormatter, timestampFormatter: TimestampFormatter): Option[(String, TypedPartValue)] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala index 93bd1acc7377d..a49c10c852b08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala @@ -31,58 +31,68 @@ import org.apache.spark.sql.util.SchemaUtils._ * By "physical column", we mean a column as defined in the data source format like Parquet format * or ORC format. For example, in Spark SQL, a root-level Parquet column corresponds to a SQL * column, and a nested Parquet column corresponds to a [[StructField]]. + * + * Also prunes the unnecessary metadata columns if any for all file formats. */ object SchemaPruning extends Rule[LogicalPlan] { import org.apache.spark.sql.catalyst.expressions.SchemaPruning._ override def apply(plan: LogicalPlan): LogicalPlan = - if (conf.nestedSchemaPruningEnabled) { - apply0(plan) - } else { - plan - } - - private def apply0(plan: LogicalPlan): LogicalPlan = plan transformDown { case op @ PhysicalOperation(projects, filters, - l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) - if canPruneRelation(hadoopFsRelation) => - - prunePhysicalColumns(l.output, projects, filters, hadoopFsRelation.dataSchema, - prunedDataSchema => { + l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) => + prunePhysicalColumns(l, projects, filters, hadoopFsRelation, + (prunedDataSchema, prunedMetadataSchema) => { val prunedHadoopRelation = hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession) - buildPrunedRelation(l, prunedHadoopRelation) + buildPrunedRelation(l, prunedHadoopRelation, prunedMetadataSchema) }).getOrElse(op) } /** * This method returns optional logical plan. `None` is returned if no nested field is required or * all nested fields are required. + * + * This method will prune both the data schema and the metadata schema */ private def prunePhysicalColumns( - output: Seq[AttributeReference], + relation: LogicalRelation, projects: Seq[NamedExpression], filters: Seq[Expression], - dataSchema: StructType, - leafNodeBuilder: StructType => LeafNode): Option[LogicalPlan] = { + hadoopFsRelation: HadoopFsRelation, + leafNodeBuilder: (StructType, StructType) => LeafNode): Option[LogicalPlan] = { + val (normalizedProjects, normalizedFilters) = - normalizeAttributeRefNames(output, projects, filters) + normalizeAttributeRefNames(relation.output, projects, filters) val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters) // If requestedRootFields includes a nested field, continue. Otherwise, // return op if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) { - val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields) - // If the data schema is different from the pruned data schema, continue. Otherwise, - // return op. We effect this comparison by counting the number of "leaf" fields in - // each schemata, assuming the fields in prunedDataSchema are a subset of the fields - // in dataSchema. 
- if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) { - val prunedRelation = leafNodeBuilder(prunedDataSchema) - val projectionOverSchema = ProjectionOverSchema(prunedDataSchema) + val prunedDataSchema = if (canPruneDataSchema(hadoopFsRelation)) { + pruneSchema(hadoopFsRelation.dataSchema, requestedRootFields) + } else { + hadoopFsRelation.dataSchema + } + + val metadataSchema = + relation.output.collect { case FileSourceMetadataAttribute(attr) => attr }.toStructType + val prunedMetadataSchema = if (metadataSchema.nonEmpty) { + pruneSchema(metadataSchema, requestedRootFields) + } else { + metadataSchema + } + // If the data schema is different from the pruned data schema + // OR + // the metadata schema is different from the pruned metadata schema, continue. + // Otherwise, return None. + if (countLeaves(hadoopFsRelation.dataSchema) > countLeaves(prunedDataSchema) || + countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)) { + val prunedRelation = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema) + val projectionOverSchema = + ProjectionOverSchema(prunedDataSchema.merge(prunedMetadataSchema)) Some(buildNewProjection(projects, normalizedProjects, normalizedFilters, prunedRelation, projectionOverSchema)) } else { @@ -96,9 +106,10 @@ object SchemaPruning extends Rule[LogicalPlan] { /** * Checks to see if the given relation can be pruned. Currently we support Parquet and ORC v1. */ - private def canPruneRelation(fsRelation: HadoopFsRelation) = - fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] || - fsRelation.fileFormat.isInstanceOf[OrcFileFormat] + private def canPruneDataSchema(fsRelation: HadoopFsRelation): Boolean = + conf.nestedSchemaPruningEnabled && ( + fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] || + fsRelation.fileFormat.isInstanceOf[OrcFileFormat]) /** * Normalizes the names of the attribute references in the given projects and filters to reflect @@ -162,29 +173,25 @@ object SchemaPruning extends Rule[LogicalPlan] { */ private def buildPrunedRelation( outputRelation: LogicalRelation, - prunedBaseRelation: HadoopFsRelation) = { - val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema) - // also add the metadata output if any - // TODO: should be able to prune the metadata schema - val metaOutput = outputRelation.output.collect { - case MetadataAttribute(attr) => attr - } - outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput ++ metaOutput) + prunedBaseRelation: HadoopFsRelation, + prunedMetadataSchema: StructType) = { + val finalSchema = prunedBaseRelation.schema.merge(prunedMetadataSchema) + val prunedOutput = getPrunedOutput(outputRelation.output, finalSchema) + outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput) } // Prune the given output to make it consistent with `requiredSchema`. private def getPrunedOutput( output: Seq[AttributeReference], requiredSchema: StructType): Seq[AttributeReference] = { - // We need to replace the expression ids of the pruned relation output attributes - // with the expression ids of the original relation output attributes so that - // references to the original relation's output are not broken - val outputIdMap = output.map(att => (att.name, att.exprId)).toMap + // We need to update the data type of the output attributes to use the pruned ones. 
+ // so that references to the original relation's output are not broken + val nameAttributeMap = output.map(att => (att.name, att)).toMap requiredSchema .toAttributes .map { - case att if outputIdMap.contains(att.name) => - att.withExprId(outputIdMap(att.name)) + case att if nameAttributeMap.contains(att.name) => + nameAttributeMap(att.name).withDataType(att.dataType) case att => att } } @@ -203,6 +210,4 @@ object SchemaPruning extends Rule[LogicalPlan] { case _ => 1 } } - - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index d40ad9d1bf0e9..8d9525078402e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.util.CompressionCodecs -import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -101,21 +100,20 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - + val columnPruning = sparkSession.sessionState.conf.csvColumnPruning && + !requiredSchema.exists(_.name == sparkSession.sessionState.conf.columnNameOfCorruptRecord) val parsedOptions = new CSVOptions( options, - sparkSession.sessionState.conf.csvColumnPruning, + columnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) // Check a field requirement for corrupt records here to throw an exception in a driver side ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) - - if (requiredSchema.length == 1 && - requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) { - throw QueryCompilationErrors.queryFromRawFilesIncludeCorruptRecordColumnError() - } - val columnPruning = sparkSession.sessionState.conf.csvColumnPruning + // Don't push any filter which refers to the "virtual" column which cannot present in the input. + // Such filters will be applied later on the upper layer. 
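The CSVFileFormat change above stops pushing filters that reference the "virtual" corrupt record column into the parser and disables CSV column pruning whenever that column is requested. A hedged end-to-end sketch of the user-visible behaviour (path and data are illustrative):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types._

    object CorruptRecordSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
        import spark.implicits._

        val path = "/tmp/corrupt_record_sketch" // illustrative input path
        Seq("1,alice", "not-a-number,bob").toDS().write.mode("overwrite").text(path)

        val schema = new StructType()
          .add("id", IntegerType)
          .add("name", StringType)
          .add("_corrupt_record", StringType)

        // The filter below refers to the corrupt record column, so it is applied
        // after parsing rather than being pushed into the CSV reader.
        spark.read.schema(schema).csv(path)
          .where($"_corrupt_record".isNotNull)
          .show(truncate = false)

        spark.stop()
      }
    }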
+ val actualFilters = + filters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord)) (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -127,7 +125,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { actualDataSchema, actualRequiredSchema, parsedOptions, - filters) + actualFilters) val schema = if (columnPruning) actualRequiredSchema else actualDataSchema val isStartOfFile = file.start == 0 val headerChecker = new CSVHeaderChecker( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index d081e0ace0e44..ad44048ce9c6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -216,7 +216,7 @@ class JDBCOptions( // The principal name of user's keytab file val principal = parameters.getOrElse(JDBC_PRINCIPAL, null) - val tableComment = parameters.getOrElse(JDBC_TABLE_COMMENT, "").toString + val tableComment = parameters.getOrElse(JDBC_TABLE_COMMENT, "") val refreshKrb5Config = parameters.getOrElse(JDBC_REFRESH_KRB5_CONFIG, "false").toBoolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index baee53847a5a4..3cd2e03828212 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -24,8 +24,10 @@ import scala.util.control.NonFatal import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ @@ -60,7 +62,7 @@ object JDBCRDD extends Logging { def getQueryOutputSchema( query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = { - val conn: Connection = JdbcUtils.createConnectionFactory(options)() + val conn: Connection = dialect.createConnectionFactory(options)(-1) try { val statement = conn.prepareStatement(query) try { @@ -97,7 +99,14 @@ object JDBCRDD extends Logging { * Returns None for an unhandled filter. 
*/ def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = { - def quote(colName: String): String = dialect.quoteIdentifier(colName) + + def quote(colName: String): String = { + val nameParts = SparkSession.active.sessionState.sqlParser.parseMultipartIdentifier(colName) + if(nameParts.length > 1) { + throw QueryCompilationErrors.commandNotSupportNestedColumnError("Filter push down", colName) + } + dialect.quoteIdentifier(nameParts.head) + } Option(f match { case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}" @@ -182,7 +191,7 @@ object JDBCRDD extends Logging { } new JDBCRDD( sc, - JdbcUtils.createConnectionFactory(options), + dialect.createConnectionFactory(options), outputSchema.getOrElse(pruneSchema(schema, requiredColumns)), quotedColumns, filters, @@ -204,7 +213,7 @@ object JDBCRDD extends Logging { */ private[jdbc] class JDBCRDD( sc: SparkContext, - getConnection: () => Connection, + getConnection: Int => Connection, schema: StructType, columns: Array[String], filters: Array[Filter], @@ -318,7 +327,7 @@ private[jdbc] class JDBCRDD( val inputMetrics = context.taskMetrics().inputMetrics val part = thePart.asInstanceOf[JDBCPartition] - conn = getConnection() + conn = getConnection(part.idx) val dialect = JdbcDialects.get(url) import scala.collection.JavaConverters._ dialect.beforeFetch(conn, options.asProperties.asScala.toMap) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index d953ba45cc2fb..2760c7ac3019c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._ +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} class JdbcRelationProvider extends CreatableRelationProvider @@ -45,8 +46,8 @@ class JdbcRelationProvider extends CreatableRelationProvider df: DataFrame): BaseRelation = { val options = new JdbcOptionsInWrite(parameters) val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis - - val conn = JdbcUtils.createConnectionFactory(options)() + val dialect = JdbcDialects.get(options.url) + val conn = dialect.createConnectionFactory(options)(-1) try { val tableExists = JdbcUtils.tableExists(conn, options) if (tableExists) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 7f68a73f8950a..6c67a22b8e3ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Driver, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} import java.time.{Instant, LocalDate} 
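compileFilter above now parses the attribute as a multipart identifier and rejects nested columns instead of quoting the dotted name as a single identifier. Below is a simplified, dialect-free sketch of the compile-to-WHERE-fragment idea; real value rendering goes through JdbcDialect.compileValue, and the quoting and literal helpers here are illustrative only:

    import org.apache.spark.sql.sources._

    object JdbcFilterSketch {
      // Reject nested references (a simplification of parseMultipartIdentifier) and
      // double-quote the column name.
      private def quote(colName: String): String = {
        require(!colName.contains("."), s"Filter push down does not support nested column: $colName")
        "\"" + colName + "\""
      }

      // Naive literal rendering, standing in for dialect.compileValue.
      private def literal(value: Any): String = value match {
        case s: String => s"'${s.replace("'", "''")}'"
        case other => other.toString
      }

      def compile(f: Filter): Option[String] = f match {
        case EqualTo(attr, value)     => Some(s"${quote(attr)} = ${literal(value)}")
        case GreaterThan(attr, value) => Some(s"${quote(attr)} > ${literal(value)}")
        case IsNull(attr)             => Some(s"${quote(attr)} IS NULL")
        case _                        => None // unhandled filters are evaluated by Spark instead
      }

      def main(args: Array[String]): Unit = {
        println(compile(EqualTo("name", "spark")))       // Some("name" = 'spark')
        println(compile(GreaterThan("age", 21)))         // Some("age" > 21)
        println(compile(In("id", Array[Any](1, 2, 3))))  // None
      }
    }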
import java.util import java.util.Locale @@ -43,7 +43,6 @@ import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex} import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -54,24 +53,6 @@ import org.apache.spark.util.NextIterator * Util functions for JDBC tables. */ object JdbcUtils extends Logging with SQLConfHelper { - /** - * Returns a factory for creating connections to the given JDBC URL. - * - * @param options - JDBC options that contains url, table and other information. - * @throws IllegalArgumentException if the driver could not open a JDBC connection. - */ - def createConnectionFactory(options: JDBCOptions): () => Connection = { - val driverClass: String = options.driverClass - () => { - DriverRegistry.register(driverClass) - val driver: Driver = DriverRegistry.get(driverClass) - val connection = - ConnectionProvider.create(driver, options.parameters, options.connectionProviderName) - require(connection != null, - s"The driver could not open a JDBC connection. Check the URL: ${options.url}") - connection - } - } /** * Returns true if the table already exists in the JDBC database. @@ -651,7 +632,6 @@ object JdbcUtils extends Logging with SQLConfHelper { * updated even with error if it doesn't support transaction, as there're dirty outputs. */ def savePartition( - getConnection: () => Connection, table: String, iterator: Iterator[Row], rddSchema: StructType, @@ -667,7 +647,7 @@ object JdbcUtils extends Logging with SQLConfHelper { val outMetrics = TaskContext.get().taskMetrics().outputMetrics - val conn = getConnection() + val conn = dialect.createConnectionFactory(options)(-1) var committed = false var finalIsolationLevel = Connection.TRANSACTION_NONE @@ -874,7 +854,6 @@ object JdbcUtils extends Logging with SQLConfHelper { val table = options.table val dialect = JdbcDialects.get(url) val rddSchema = df.schema - val getConnection: () => Connection = createConnectionFactory(options) val batchSize = options.batchSize val isolationLevel = options.isolationLevel @@ -886,8 +865,7 @@ object JdbcUtils extends Logging with SQLConfHelper { case _ => df } repartitionedDF.rdd.foreachPartition { iterator => savePartition( - getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, - options) + table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, options) } } @@ -971,52 +949,57 @@ object JdbcUtils extends Logging with SQLConfHelper { } /** - * Creates a namespace. + * Creates a schema. 
*/ - def createNamespace( + def createSchema( conn: Connection, options: JDBCOptions, - namespace: String, + schema: String, comment: String): Unit = { + val statement = conn.createStatement + try { + statement.setQueryTimeout(options.queryTimeout) + val dialect = JdbcDialects.get(options.url) + dialect.createSchema(statement, schema, comment) + } finally { + statement.close() + } + } + + def schemaExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = { val dialect = JdbcDialects.get(options.url) - executeStatement(conn, options, s"CREATE SCHEMA ${dialect.quoteIdentifier(namespace)}") - if (!comment.isEmpty) createNamespaceComment(conn, options, namespace, comment) + dialect.schemasExists(conn, options, schema) } - def createNamespaceComment( + def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = { + val dialect = JdbcDialects.get(options.url) + dialect.listSchemas(conn, options) + } + + def alterSchemaComment( conn: Connection, options: JDBCOptions, - namespace: String, + schema: String, comment: String): Unit = { val dialect = JdbcDialects.get(options.url) - try { - executeStatement( - conn, options, dialect.getSchemaCommentQuery(namespace, comment)) - } catch { - case e: Exception => - logWarning("Cannot create JDBC catalog comment. The catalog comment will be ignored.") - } + executeStatement(conn, options, dialect.getSchemaCommentQuery(schema, comment)) } - def removeNamespaceComment( + def removeSchemaComment( conn: Connection, options: JDBCOptions, - namespace: String): Unit = { + schema: String): Unit = { val dialect = JdbcDialects.get(options.url) - try { - executeStatement(conn, options, dialect.removeSchemaCommentQuery(namespace)) - } catch { - case e: Exception => - logWarning("Cannot drop JDBC catalog comment.") - } + executeStatement(conn, options, dialect.removeSchemaCommentQuery(schema)) } /** - * Drops a namespace from the JDBC database. + * Drops a schema from the JDBC database. 
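
// Editor's sketch (not part of the patch): the loan pattern that the reworked
// executeQuery(conn, options, sql)(f) just below follows -- the caller borrows the result set
// inside a function and the resource is always closed afterwards. FakeResult is an assumed
// toy stand-in for java.sql.ResultSet so the example runs without a database.
object LoanPatternSketch {
  def withResource[R <: AutoCloseable, T](resource: R)(f: R => T): T =
    try f(resource) finally resource.close()

  final class FakeResult extends AutoCloseable {
    private var remaining = 2
    def next(): Boolean = { remaining -= 1; remaining >= 0 }
    override def close(): Unit = println("closed")
  }

  def main(args: Array[String]): Unit = {
    val rows = withResource(new FakeResult) { rs =>
      Iterator.continually(rs.next()).takeWhile(identity).size
    }
    // "closed" is printed before this line: the resource never escapes the block.
    println(s"consumed $rows rows")
  }
}
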
*/ - def dropNamespace(conn: Connection, options: JDBCOptions, namespace: String): Unit = { + def dropSchema( + conn: Connection, options: JDBCOptions, schema: String, cascade: Boolean): Unit = { val dialect = JdbcDialects.get(options.url) - executeStatement(conn, options, s"DROP SCHEMA ${dialect.quoteIdentifier(namespace)}") + executeStatement(conn, options, dialect.dropSchema(schema, cascade)) } /** @@ -1147,11 +1130,17 @@ object JdbcUtils extends Logging with SQLConfHelper { } } - def executeQuery(conn: Connection, options: JDBCOptions, sql: String): ResultSet = { + def executeQuery(conn: Connection, options: JDBCOptions, sql: String)( + f: ResultSet => Unit): Unit = { val statement = conn.createStatement try { statement.setQueryTimeout(options.queryTimeout) - statement.executeQuery(sql) + val rs = statement.executeQuery(sql) + try { + f(rs) + } finally { + rs.close() + } } finally { statement.close() } @@ -1166,7 +1155,8 @@ object JdbcUtils extends Logging with SQLConfHelper { } def withConnection[T](options: JDBCOptions)(f: Connection => T): T = { - val conn = createConnectionFactory(options)() + val dialect = JdbcDialects.get(options.url) + val conn = dialect.createConnectionFactory(options)(-1) try { f(conn) } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala index 84a62693a6e7d..0d8c80c9fc15c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala @@ -104,4 +104,4 @@ protected abstract class ConnectionProviderBase extends Logging { } } -private[jdbc] object ConnectionProvider extends ConnectionProviderBase +private[sql] object ConnectionProvider extends ConnectionProviderBase diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index ce851c58cc4fa..39a8763160530 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -228,13 +228,4 @@ class OrcFileFormat case _ => false } - - override def supportFieldName(name: String): Boolean = { - try { - TypeDescription.fromString(s"struct<`$name`:int>") - true - } catch { - case _: IllegalArgumentException => false - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index d1b7e8db619b1..1f05117462db8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, SchemaMergeUtils} +import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils import org.apache.spark.sql.types._ import 
org.apache.spark.util.{ThreadUtils, Utils} @@ -146,7 +147,7 @@ object OrcUtils extends Logging { : Option[StructType] = { val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles val conf = sparkSession.sessionState.newHadoopConfWithOptions(options) - files.toIterator.map(file => readSchema(file.getPath, conf, ignoreCorruptFiles)).collectFirst { + files.iterator.map(file => readSchema(file.getPath, conf, ignoreCorruptFiles)).collectFirst { case Some(schema) => logDebug(s"Reading schema from file $files, got Hive schema string: $schema") toCatalystSchema(schema) @@ -204,6 +205,8 @@ object OrcUtils extends Logging { orcCatalystSchema.fields.map(_.dataType).zip(dataSchema.fields.map(_.dataType)).foreach { case (TimestampType, TimestampNTZType) => throw QueryExecutionErrors.cannotConvertOrcTimestampToTimestampNTZError() + case (TimestampNTZType, TimestampType) => + throw QueryExecutionErrors.cannotConvertOrcTimestampNTZToTimestampLTZError() case (t1: StructType, t2: StructType) => checkTimestampCompatibility(t1, t2) case _ => } @@ -487,18 +490,18 @@ object OrcUtils extends Logging { val aggORCValues: Seq[WritableComparable[_]] = aggregation.aggregateExpressions.zipWithIndex.map { - case (max: Max, index) => - val columnName = max.column.fieldNames.head + case (max: Max, index) if V2ColumnUtils.extractV2Column(max.column).isDefined => + val columnName = V2ColumnUtils.extractV2Column(max.column).get val statistics = getColumnStatistics(columnName) val dataType = schemaWithoutGroupBy(index).dataType getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) - case (min: Min, index) => - val columnName = min.column.fieldNames.head + case (min: Min, index) if V2ColumnUtils.extractV2Column(min.column).isDefined => + val columnName = V2ColumnUtils.extractV2Column(min.column).get val statistics = getColumnStatistics(columnName) val dataType = schemaWithoutGroupBy.apply(index).dataType getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) - case (count: Count, _) => - val columnName = count.column.fieldNames.head + case (count: Count, _) if V2ColumnUtils.extractV2Column(count.column).isDefined => + val columnName = V2ColumnUtils.extractV2Column(count.column).get val isPartitionColumn = partitionSchema.fields.map(_.name).contains(columnName) // NOTE: Count(columnName) doesn't include null values. 
// org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 4515387bdaa90..18876dedb951e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -119,6 +119,10 @@ class ParquetFileFormat SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, sparkSession.sessionState.conf.parquetOutputTimestampType.toString) + conf.set( + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, + sparkSession.sessionState.conf.parquetFieldIdWriteEnabled.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) @@ -218,8 +222,6 @@ class ParquetFileFormat SQLConf.CASE_SENSITIVE.key, sparkSession.sessionState.conf.caseSensitiveAnalysis) - ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) - // Sets flags for `ParquetToSparkSchemaConverter` hadoopConf.setBoolean( SQLConf.PARQUET_BINARY_AS_STRING.key, @@ -407,10 +409,6 @@ class ParquetFileFormat case _ => false } - - override def supportFieldName(name: String): Boolean = { - !name.matches(".*[ ,;{}()\n\t=].*") - } } object ParquetFileFormat extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index bdab0f7892f00..69684f9466f98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.time.ZoneId -import java.util.{Locale, Map => JMap} +import java.util +import java.util.{Locale, Map => JMap, UUID} import scala.collection.JavaConverters._ @@ -85,13 +86,71 @@ class ParquetReadSupport( StructType.fromString(schemaString) } + val parquetRequestedSchema = ParquetReadSupport.getRequestedSchema( + context.getFileSchema, catalystRequestedSchema, conf, enableVectorizedReader) + new ReadContext(parquetRequestedSchema, new util.HashMap[String, String]()) + } + + /** + * Called on executor side after [[init()]], before instantiating actual Parquet record readers. + * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet + * records to Catalyst [[InternalRow]]s. 
+ */ + override def prepareForRead( + conf: Configuration, + keyValueMetaData: JMap[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[InternalRow] = { + val parquetRequestedSchema = readContext.getRequestedSchema + new ParquetRecordMaterializer( + parquetRequestedSchema, + ParquetReadSupport.expandUDT(catalystRequestedSchema), + new ParquetToSparkSchemaConverter(conf), + convertTz, + datetimeRebaseSpec, + int96RebaseSpec) + } +} + +object ParquetReadSupport extends Logging { + val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + + val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" + + def generateFakeColumnName: String = s"_fake_name_${UUID.randomUUID()}" + + def getRequestedSchema( + parquetFileSchema: MessageType, + catalystRequestedSchema: StructType, + conf: Configuration, + enableVectorizedReader: Boolean): MessageType = { val caseSensitive = conf.getBoolean(SQLConf.CASE_SENSITIVE.key, SQLConf.CASE_SENSITIVE.defaultValue.get) val schemaPruningEnabled = conf.getBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get) - val parquetFileSchema = context.getFileSchema + val useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key, + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get) + val ignoreMissingIds = conf.getBoolean(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key, + SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.defaultValue.get) + + if (!ignoreMissingIds && + !containsFieldIds(parquetFileSchema) && + ParquetUtils.hasFieldIds(catalystRequestedSchema)) { + throw new RuntimeException( + "Spark read schema expects field Ids, " + + "but Parquet file schema doesn't contain any field Ids.\n" + + "Please remove the field ids from Spark schema or ignore missing ids by " + + s"setting `${SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key} = true`\n" + + s""" + |Spark read schema: + |${catalystRequestedSchema.prettyJson} + | + |Parquet file schema: + |${parquetFileSchema.toString} + |""".stripMargin) + } val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema, - catalystRequestedSchema, caseSensitive) + catalystRequestedSchema, caseSensitive, useFieldId) // We pass two schema to ParquetRecordMaterializer: // - parquetRequestedSchema: the schema of the file data we want to read @@ -109,6 +168,7 @@ class ParquetReadSupport( // in parquetRequestedSchema which are not present in the file. parquetClippedSchema } + logDebug( s"""Going to read the following fields from the Parquet file with the following schema: |Parquet file schema: @@ -120,34 +180,20 @@ class ParquetReadSupport( |Catalyst requested schema: |${catalystRequestedSchema.treeString} """.stripMargin) - new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) + + parquetRequestedSchema } /** - * Called on executor side after [[init()]], before instantiating actual Parquet record readers. - * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet - * records to Catalyst [[InternalRow]]s. 
+ * Overloaded method for backward compatibility with + * `caseSensitive` default to `true` and `useFieldId` default to `false` */ - override def prepareForRead( - conf: Configuration, - keyValueMetaData: JMap[String, String], - fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[InternalRow] = { - val parquetRequestedSchema = readContext.getRequestedSchema - new ParquetRecordMaterializer( - parquetRequestedSchema, - ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetToSparkSchemaConverter(conf), - convertTz, - datetimeRebaseSpec, - int96RebaseSpec) + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + clipParquetSchema(parquetSchema, catalystSchema, caseSensitive, useFieldId = false) } -} - -object ParquetReadSupport { - val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" - - val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" /** * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist @@ -156,9 +202,10 @@ object ParquetReadSupport { def clipParquetSchema( parquetSchema: MessageType, catalystSchema: StructType, - caseSensitive: Boolean = true): MessageType = { + caseSensitive: Boolean, + useFieldId: Boolean): MessageType = { val clippedParquetFields = clipParquetGroupFields( - parquetSchema.asGroupType(), catalystSchema, caseSensitive) + parquetSchema.asGroupType(), catalystSchema, caseSensitive, useFieldId) if (clippedParquetFields.isEmpty) { ParquetSchemaConverter.EMPTY_MESSAGE } else { @@ -170,26 +217,36 @@ object ParquetReadSupport { } private def clipParquetType( - parquetType: Type, catalystType: DataType, caseSensitive: Boolean): Type = { - catalystType match { + parquetType: Type, + catalystType: DataType, + caseSensitive: Boolean, + useFieldId: Boolean): Type = { + val newParquetType = catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. - clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive, useFieldId) case t: MapType if !isPrimitiveCatalystType(t.keyType) || !isPrimitiveCatalystType(t.valueType) => // Only clips map types with nested key type or value type - clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) + clipParquetMapType( + parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive, useFieldId) case t: StructType => - clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive, useFieldId) case _ => // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able // to be mapped to desired user-space types. So UDTs shouldn't participate schema merging. parquetType } + + if (useFieldId && parquetType.getId != null) { + newParquetType.withId(parquetType.getId.intValue()) + } else { + newParquetType + } } /** @@ -210,7 +267,10 @@ object ParquetReadSupport { * [[StructType]]. */ private def clipParquetListType( - parquetList: GroupType, elementType: DataType, caseSensitive: Boolean): Type = { + parquetList: GroupType, + elementType: DataType, + caseSensitive: Boolean, + useFieldId: Boolean): Type = { // Precondition of this method, should only be called for lists with nested element types. 
assert(!isPrimitiveCatalystType(elementType)) @@ -218,7 +278,7 @@ object ParquetReadSupport { // list element type is just the group itself. Clip it. if (parquetList.getLogicalTypeAnnotation == null && parquetList.isRepetition(Repetition.REPEATED)) { - clipParquetType(parquetList, elementType, caseSensitive) + clipParquetType(parquetList, elementType, caseSensitive, useFieldId) } else { assert( parquetList.getLogicalTypeAnnotation.isInstanceOf[ListLogicalTypeAnnotation], @@ -250,19 +310,28 @@ object ParquetReadSupport { Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive, useFieldId)) .named(parquetList.getName) } else { + val newRepeatedGroup = Types + .repeatedGroup() + .addField( + clipParquetType( + repeatedGroup.getType(0), elementType, caseSensitive, useFieldId)) + .named(repeatedGroup.getName) + + val newElementType = if (useFieldId && repeatedGroup.getId != null) { + newRepeatedGroup.withId(repeatedGroup.getId.intValue()) + } else { + newRepeatedGroup + } + // Otherwise, the repeated field's type is the element type with the repeated field's // repetition. Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField( - Types - .repeatedGroup() - .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) - .named(repeatedGroup.getName)) + .addField(newElementType) .named(parquetList.getName) } } @@ -277,7 +346,8 @@ object ParquetReadSupport { parquetMap: GroupType, keyType: DataType, valueType: DataType, - caseSensitive: Boolean): GroupType = { + caseSensitive: Boolean, + useFieldId: Boolean): GroupType = { // Precondition of this method, only handles maps with nested key types or value types. assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) @@ -285,13 +355,19 @@ object ParquetReadSupport { val parquetKeyType = repeatedGroup.getType(0) val parquetValueType = repeatedGroup.getType(1) - val clippedRepeatedGroup = - Types + val clippedRepeatedGroup = { + val newRepeatedGroup = Types .repeatedGroup() .as(repeatedGroup.getLogicalTypeAnnotation) - .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) - .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive, useFieldId)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive, useFieldId)) .named(repeatedGroup.getName) + if (useFieldId && repeatedGroup.getId != null) { + newRepeatedGroup.withId(repeatedGroup.getId.intValue()) + } else { + newRepeatedGroup + } + } Types .buildGroup(parquetMap.getRepetition) @@ -309,8 +385,12 @@ object ParquetReadSupport { * pruning. */ private def clipParquetGroup( - parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): GroupType = { - val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean, + useFieldId: Boolean): GroupType = { + val clippedParquetFields = + clipParquetGroupFields(parquetRecord, structType, caseSensitive, useFieldId) Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getLogicalTypeAnnotation) @@ -324,23 +404,29 @@ object ParquetReadSupport { * @return A list of clipped [[GroupType]] fields, which can be empty. 
*/ private def clipParquetGroupFields( - parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): Seq[Type] = { - val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) - if (caseSensitive) { - val caseSensitiveParquetFieldMap = + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean, + useFieldId: Boolean): Seq[Type] = { + val toParquet = new SparkToParquetSchemaConverter( + writeLegacyParquetFormat = false, useFieldId = useFieldId) + lazy val caseSensitiveParquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap - structType.map { f => - caseSensitiveParquetFieldMap + lazy val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + lazy val idToParquetFieldMap = + parquetRecord.getFields.asScala.filter(_.getId != null).groupBy(f => f.getId.intValue()) + + def matchCaseSensitiveField(f: StructField): Type = { + caseSensitiveParquetFieldMap .get(f.name) - .map(clipParquetType(_, f.dataType, caseSensitive)) + .map(clipParquetType(_, f.dataType, caseSensitive, useFieldId)) .getOrElse(toParquet.convertField(f)) - } - } else { + } + + def matchCaseInsensitiveField(f: StructField): Type = { // Do case-insensitive resolution only if in case-insensitive mode - val caseInsensitiveParquetFieldMap = - parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) - structType.map { f => - caseInsensitiveParquetFieldMap + caseInsensitiveParquetFieldMap .get(f.name.toLowerCase(Locale.ROOT)) .map { parquetTypes => if (parquetTypes.size > 1) { @@ -349,9 +435,39 @@ object ParquetReadSupport { throw QueryExecutionErrors.foundDuplicateFieldInCaseInsensitiveModeError( f.name, parquetTypesString) } else { - clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + clipParquetType(parquetTypes.head, f.dataType, caseSensitive, useFieldId) } }.getOrElse(toParquet.convertField(f)) + } + + def matchIdField(f: StructField): Type = { + val fieldId = ParquetUtils.getFieldId(f) + idToParquetFieldMap + .get(fieldId) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError( + fieldId, parquetTypesString) + } else { + clipParquetType(parquetTypes.head, f.dataType, caseSensitive, useFieldId) + } + }.getOrElse { + // When there is no ID match, we use a fake name to avoid a name match by accident + // We need this name to be unique as well, otherwise there will be type conflicts + toParquet.convertField(f.copy(name = generateFakeColumnName)) + } + } + + val shouldMatchById = useFieldId && ParquetUtils.hasFieldIds(structType) + structType.map { f => + if (shouldMatchById && ParquetUtils.hasFieldId(f)) { + matchIdField(f) + } else if (caseSensitive) { + matchCaseSensitiveField(f) + } else { + matchCaseInsensitiveField(f) } } } @@ -410,4 +526,13 @@ object ParquetReadSupport { expand(schema).asInstanceOf[StructType] } + + /** + * Whether the parquet schema contains any field IDs. + */ + def containsFieldIds(schema: Type): Boolean = schema match { + case p: PrimitiveType => p.getId != null + // We don't require all fields to have IDs, so we use `exists` here. 
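
// Editor's sketch (not part of the patch): the id-first lookup used by matchIdField above,
// reduced to plain Scala. The Map is an assumed toy stand-in for the Parquet file's
// id -> field mapping; an unmatched id gets a unique placeholder name so it can never
// collide with a real column by accident.
import java.util.UUID

object FieldIdLookupSketch {
  private val parquetFieldsById = Map(1 -> "id", 2 -> "name")   // toy file schema

  def resolve(fieldId: Int): String =
    parquetFieldsById.getOrElse(fieldId, s"_fake_name_${UUID.randomUUID()}")

  def main(args: Array[String]): Unit = {
    println(resolve(2))   // name
    println(resolve(9))   // _fake_name_<random uuid>, no accidental name match
  }
}
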
+ case g: GroupType => g.getId != null || g.getFields.asScala.exists(containsFieldIds) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index b12898360dcf4..63ad5ed6db82e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -203,16 +203,38 @@ private[parquet] class ParquetRowConverter( private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = { // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false // to prevent throwing IllegalArgumentException when searching catalyst type's field index - val catalystFieldNameToIndex = if (SQLConf.get.caseSensitiveAnalysis) { - catalystType.fieldNames.zipWithIndex.toMap + def nameToIndex: Map[String, Int] = catalystType.fieldNames.zipWithIndex.toMap + + val catalystFieldIdxByName = if (SQLConf.get.caseSensitiveAnalysis) { + nameToIndex } else { - CaseInsensitiveMap(catalystType.fieldNames.zipWithIndex.toMap) + CaseInsensitiveMap(nameToIndex) } + + // (SPARK-38094) parquet field ids, if exist, should be prioritized for matching + val catalystFieldIdxByFieldId = + if (SQLConf.get.parquetFieldIdReadEnabled && ParquetUtils.hasFieldIds(catalystType)) { + catalystType.fields + .zipWithIndex + .filter { case (f, _) => ParquetUtils.hasFieldId(f) } + .map { case (f, idx) => (ParquetUtils.getFieldId(f), idx) } + .toMap + } else { + Map.empty[Int, Int] + } + parquetType.getFields.asScala.map { parquetField => - val fieldIndex = catalystFieldNameToIndex(parquetField.getName) - val catalystField = catalystType(fieldIndex) + val catalystFieldIndex = Option(parquetField.getId).flatMap { fieldId => + // field has id, try to match by id first before falling back to match by name + catalystFieldIdxByFieldId.get(fieldId.intValue()) + }.getOrElse { + // field doesn't have id, just match by name + catalystFieldIdxByName(parquetField.getName) + } + val catalystField = catalystType(catalystFieldIndex) // Converted field value should be set to the `fieldIndex`-th cell of `currentRow` - newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex)) + newConverter(parquetField, + catalystField.dataType, new RowUpdater(currentRow, catalystFieldIndex)) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 352e5f01172f2..34a4eb8c002d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -434,20 +434,25 @@ class ParquetToSparkSchemaConverter( * When set to false, use standard format defined in parquet-format spec. This argument only * affects Parquet write path. * @param outputTimestampType which parquet timestamp type to use when writing. + * @param useFieldId whether we should include write field id to Parquet schema. Set this to false + * via `spark.sql.parquet.fieldId.write.enabled = false` to disable writing field ids. 
*/ class SparkToParquetSchemaConverter( writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = - SQLConf.ParquetOutputTimestampType.INT96) { + SQLConf.ParquetOutputTimestampType.INT96, + useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.defaultValue.get) { def this(conf: SQLConf) = this( writeLegacyParquetFormat = conf.writeLegacyParquetFormat, - outputTimestampType = conf.parquetOutputTimestampType) + outputTimestampType = conf.parquetOutputTimestampType, + useFieldId = conf.parquetFieldIdWriteEnabled) def this(conf: Configuration) = this( writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( - conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key))) + conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)), + useFieldId = conf.get(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key).toBoolean) /** * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. @@ -463,11 +468,15 @@ class SparkToParquetSchemaConverter( * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. */ def convertField(field: StructField): Type = { - convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + val converted = convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + if (useFieldId && ParquetUtils.hasFieldId(field)) { + converted.withId(ParquetUtils.getFieldId(field)) + } else { + converted + } } private def convertField(field: StructField, repetition: Type.Repetition): Type = { - ParquetSchemaConverter.checkFieldName(field.name) field.dataType match { // =================== @@ -698,23 +707,6 @@ private[sql] object ParquetSchemaConverter { val EMPTY_MESSAGE: MessageType = Types.buildMessage().named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) - def checkFieldName(name: String): Unit = { - // ,;{}()\n\t= and space are special characters in Parquet schema - if (name.matches(".*[ ,;{}()\n\t=].*")) { - throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(name) - } - } - - def checkFieldNames(schema: StructType): Unit = { - schema.foreach { field => - checkFieldName(field.name) - field.dataType match { - case s: StructType => checkFieldNames(s) - case _ => - } - } - } - def checkConversionRequirement(f: => Boolean, message: String): Unit = { if (!f) { throw new AnalysisException(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 87a0d9c860f31..2c565c8890e70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -35,8 +35,9 @@ import org.apache.spark.sql.catalyst.expressions.JoinedRow import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils +import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} object ParquetUtils { def 
inferSchema( @@ -144,6 +145,48 @@ object ParquetUtils { file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + /** + * A StructField metadata key used to set the field id of a column in the Parquet schema. + */ + val FIELD_ID_METADATA_KEY = "parquet.field.id" + + /** + * Whether there exists a field in the schema, whether inner or leaf, has the parquet field + * ID metadata. + */ + def hasFieldIds(schema: StructType): Boolean = { + def recursiveCheck(schema: DataType): Boolean = { + schema match { + case st: StructType => + st.exists(field => hasFieldId(field) || recursiveCheck(field.dataType)) + + case at: ArrayType => recursiveCheck(at.elementType) + + case mt: MapType => recursiveCheck(mt.keyType) || recursiveCheck(mt.valueType) + + case _ => + // No need to really check primitive types, just to terminate the recursion + false + } + } + if (schema.isEmpty) false else recursiveCheck(schema) + } + + def hasFieldId(field: StructField): Boolean = + field.metadata.contains(FIELD_ID_METADATA_KEY) + + def getFieldId(field: StructField): Int = { + require(hasFieldId(field), + s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + field) + try { + Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY)) + } catch { + case _: ArithmeticException | _: ClassCastException => + throw new IllegalArgumentException( + s"The key `$FIELD_ID_METADATA_KEY` must be a 32-bit integer") + } + } + /** * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to * createRowBaseReader to read data from Parquet and aggregate at Spark layer. Instead we want @@ -248,32 +291,33 @@ object ParquetUtils { blocks.forEach { block => val blockMetaData = block.getColumns agg match { - case max: Max => - val colName = max.column.fieldNames.head + case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined => + val colName = V2ColumnUtils.extractV2Column(max.column).get index = dataSchema.fieldNames.toList.indexOf(colName) schemaName = "max(" + colName + ")" val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true) if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) { value = currentMax } - case min: Min => - val colName = min.column.fieldNames.head + case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined => + val colName = V2ColumnUtils.extractV2Column(min.column).get index = dataSchema.fieldNames.toList.indexOf(colName) schemaName = "min(" + colName + ")" val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false) if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) { value = currentMin } - case count: Count => - schemaName = "count(" + count.column.fieldNames.head + ")" + case count: Count if V2ColumnUtils.extractV2Column(count.column).isDefined => + val colName = V2ColumnUtils.extractV2Column(count.column).get + schemaName = "count(" + colName + ")" rowCount += block.getRowCount var isPartitionCol = false - if (partitionSchema.fields.map(_.name).toSet.contains(count.column.fieldNames.head)) { + if (partitionSchema.fields.map(_.name).toSet.contains(colName)) { isPartitionCol = true } isCount = true if (!isPartitionCol) { - index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head) + index = dataSchema.fieldNames.toList.indexOf(colName) // Count(*) includes the null values, but Count(colName) doesn't. 
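
// Editor's sketch (not part of the patch): how a Spark schema carries Parquet field ids in
// StructField metadata under the "parquet.field.id" key defined above. Uses only the public
// Spark SQL types API; the concrete ids and column names are made up for illustration.
import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StringType, StructField, StructType}

object FieldIdSchemaSketch {
  private def withFieldId(id: Long): Metadata =
    new MetadataBuilder().putLong("parquet.field.id", id).build()

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false, withFieldId(1L)),
      StructField("name", StringType, nullable = true, withFieldId(2L))))
    // When PARQUET_FIELD_ID_READ_ENABLED is set, such fields match Parquet columns by id.
    schema.fields.foreach(f => println(s"${f.name} -> ${f.metadata.getLong("parquet.field.id")}"))
  }
}
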
rowCount -= getNumNulls(filePath, blockMetaData, index) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 605f7c17fed30..af43f8d1c1bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -319,15 +319,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi conf.resolver) if (schema.nonEmpty && normalizedPartitionCols.length == schema.length) { - if (DDLUtils.isHiveTable(table)) { - // When we hit this branch, it means users didn't specify schema for the table to be - // created, as we always include partition columns in table schema for hive serde tables. - // The real schema will be inferred at hive metastore by hive serde, plus the given - // partition columns, so we should not fail the analysis here. - } else { - failAnalysis("Cannot use all columns for partition columns") - } - + failAnalysis("Cannot use all columns for partition columns") } schema.filter(f => normalizedPartitionCols.contains(f.name)).map(_.dataType).foreach { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 217a1d5750d42..a1eb857c4ed41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -70,6 +70,7 @@ class DataSourceRDD( // In case of early stopping before consuming the entire iterator, // we need to do one more metric update at the end of the task. CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics) + iter.forceUpdateMetrics() reader.close() } // TODO: SPARK-25083 remove the type erasure hack in data source scan @@ -130,10 +131,12 @@ private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] if (iter.hasNext) { true } else { - metricsHandler.updateMetrics(0, force = true) + forceUpdateMetrics() false } } + + def forceUpdateMetrics(): Unit = metricsHandler.updateMetrics(0, force = true) } private class MetricsRowIterator( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala index 9a9d8e1d4d57d..5d302055e7d91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.connector.catalog.CatalogPlugin import org.apache.spark.sql.errors.QueryCompilationErrors @@ -37,17 +38,11 @@ case class DropNamespaceExec( val nsCatalog = catalog.asNamespaceCatalog val ns = namespace.toArray if (nsCatalog.namespaceExists(ns)) { - // The default behavior of `SupportsNamespace.dropNamespace()` is cascading, - // so make sure the namespace to drop is empty. 
- if (!cascade) { - if (catalog.asTableCatalog.listTables(ns).nonEmpty - || nsCatalog.listNamespaces(ns).nonEmpty) { + try { + nsCatalog.dropNamespace(ns, cascade) + } catch { + case _: NonEmptyNamespaceException => throw QueryCompilationErrors.cannotDropNonemptyNamespaceError(namespace) - } - } - - if (!nsCatalog.dropNamespace(ns)) { - throw QueryCompilationErrors.cannotDropNonemptyNamespaceError(namespace) } } else if (!ifExists) { throw QueryCompilationErrors.noSuchNamespaceError(ns) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala index 5e160228c60e3..da4f9e89fde8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala @@ -26,7 +26,7 @@ abstract class FilePartitionReaderFactory extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { assert(partition.isInstanceOf[FilePartition]) val filePartition = partition.asInstanceOf[FilePartition] - val iter = filePartition.files.toIterator.map { file => + val iter = filePartition.files.iterator.map { file => PartitionedFileReader(file, buildReader(file)) } new FilePartitionReader[InternalRow](iter) @@ -35,7 +35,7 @@ abstract class FilePartitionReaderFactory extends PartitionReaderFactory { override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { assert(partition.isInstanceOf[FilePartition]) val filePartition = partition.asInstanceOf[FilePartition] - val iter = filePartition.files.toIterator.map { file => + val iter = filePartition.files.iterator.map { file => PartitionedFileReader(file, buildColumnarReader(file)) } new FilePartitionReader[ColumnarBatch](iter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index db7b3dc7248f3..9953658b65488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder} -import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, 
SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -106,34 +104,6 @@ object PushDownUtils extends PredicateHelper { } } - /** - * Pushes down aggregates to the data source reader - * - * @return pushed aggregation. - */ - def pushAggregates( - scanBuilder: SupportsPushDownAggregates, - aggregates: Seq[AggregateExpression], - groupBy: Seq[Expression]): Option[Aggregation] = { - - def columnAsString(e: Expression): Option[FieldReference] = e match { - case PushableColumnWithoutNestedColumn(name) => - Some(FieldReference.column(name).asInstanceOf[FieldReference]) - case _ => None - } - - val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate) - val translatedGroupBys = groupBy.flatMap(columnAsString) - - if (translatedAggregates.length != aggregates.length || - translatedGroupBys.length != groupBy.length) { - return None - } - - val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray) - Some(agg).filter(scanBuilder.pushAggregation) - } - /** * Pushes down TableSample to the data source Scan */ @@ -187,7 +157,7 @@ object PushDownUtils extends PredicateHelper { case r: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled => val rootFields = SchemaPruning.identifyRootFields(projects, filters) val prunedSchema = if (rootFields.nonEmpty) { - SchemaPruning.pruneDataSchema(relation.schema, rootFields) + SchemaPruning.pruneSchema(relation.schema, rootFields) } else { new StructType() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateTableExec.scala index 5eaa16961886b..06f5a08ffd9c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateTableExec.scala @@ -21,9 +21,11 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.util.escapeSingleQuotedString +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.BucketTransform import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.unsafe.types.UTF8String @@ -34,7 +36,7 @@ case class ShowCreateTableExec( output: Seq[Attribute], table: Table) extends V2CommandExec with LeafExecNode { override protected def run(): Seq[InternalRow] = { - val builder = StringBuilder.newBuilder + val builder = new StringBuilder showCreateTable(table, builder) Seq(InternalRow(UTF8String.fromString(builder.toString))) } @@ -57,7 +59,7 @@ case class ShowCreateTableExec( } private def showTableDataColumns(table: Table, builder: StringBuilder): Unit = { - val columns = table.schema().fields.map(_.toDDL) + val columns = CharVarcharUtils.getRawSchema(table.schema(), conf).fields.map(_.toDDL) builder ++= concatByMultiLines(columns) } @@ -71,10 +73,11 @@ case class 
ShowCreateTableExec( builder: StringBuilder, tableOptions: Map[String, String]): Unit = { if (tableOptions.nonEmpty) { - val props = tableOptions.toSeq.sortBy(_._1).map { case (key, value) => - s"'${escapeSingleQuotedString(key)}' = '${escapeSingleQuotedString(value)}'" + val props = conf.redactOptions(tableOptions).toSeq.sortBy(_._1).map { + case (key, value) => + s"'${escapeSingleQuotedString(key)}' = '${escapeSingleQuotedString(value)}'" } - builder ++= "OPTIONS" + builder ++= "OPTIONS " builder ++= concatByMultiLines(props) } } @@ -82,8 +85,31 @@ case class ShowCreateTableExec( private def showTablePartitioning(table: Table, builder: StringBuilder): Unit = { if (!table.partitioning.isEmpty) { val transforms = new ArrayBuffer[String] - table.partitioning.foreach(t => transforms += t.describe()) - builder ++= s"PARTITIONED BY ${transforms.mkString("(", ", ", ")")}\n" + var bucketSpec = Option.empty[BucketSpec] + table.partitioning.map { + case BucketTransform(numBuckets, col, sortCol) => + if (sortCol.isEmpty) { + bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")), Nil)) + } else { + bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")), + sortCol.map(_.fieldNames.mkString(".")))) + } + case t => + transforms += t.describe() + } + if (transforms.nonEmpty) { + builder ++= s"PARTITIONED BY ${transforms.mkString("(", ", ", ")")}\n" + } + + // compatible with v1 + bucketSpec.map { bucket => + assert(bucket.bucketColumnNames.nonEmpty) + builder ++= s"CLUSTERED BY ${bucket.bucketColumnNames.mkString("(", ", ", ")")}\n" + if (bucket.sortColumnNames.nonEmpty) { + builder ++= s"SORTED BY ${bucket.sortColumnNames.mkString("(", ", ", ")")}\n" + } + builder ++= s"INTO ${bucket.numBuckets} BUCKETS\n" + } } } @@ -98,7 +124,6 @@ case class ShowCreateTableExec( builder: StringBuilder, tableOptions: Map[String, String]): Unit = { - val showProps = table.properties.asScala .filterKeys(key => !CatalogV2Util.TABLE_RESERVED_PROPERTIES.contains(key) && !key.startsWith(TableCatalog.OPTION_PREFIX) @@ -123,5 +148,4 @@ case class ShowCreateTableExec( private def concatByMultiLines(iter: Iterable[String]): String = { iter.mkString("(\n ", ",\n ", ")\n") } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala new file mode 100644 index 0000000000000..9fc220f440bc1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.connector.expressions.{Expression, NamedReference}
+
+object V2ColumnUtils {
+  def extractV2Column(expr: Expression): Option[String] = expr match {
+    case r: NamedReference if r.fieldNames.length == 1 => Some(r.fieldNames.head)
+    case _ => None
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala
index fee9137c6ba1d..31e4a772dc1a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala
@@ -48,7 +48,7 @@ abstract class V2CommandExec extends SparkPlan {
    */
   override def executeCollect(): Array[InternalRow] = result.toArray

-  override def executeToIterator(): Iterator[InternalRow] = result.toIterator
+  override def executeToIterator(): Iterator[InternalRow] = result.iterator

   override def executeTake(limit: Int): Array[InternalRow] = result.take(limit).toArray

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index dec7189ac698d..3ff917664b486 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,26 +19,35 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.mutable

-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.SortOrder
-import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, GeneralAggregateFunc}
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
 import org.apache.spark.sql.util.SchemaUtils._

 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._

   def apply(plan: LogicalPlan): LogicalPlan = {
-    applyColumnPruning(
-
applyLimit(pushDownAggregates(pushDownFilters(pushDownSample(createScanBuilder(plan)))))) + val pushdownRules = Seq[LogicalPlan => LogicalPlan] ( + createScanBuilder, + pushDownSample, + pushDownFilters, + pushDownAggregates, + pushDownLimits, + pruneColumns) + + pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) => + pushDownRule(newPlan) + } } private def createScanBuilder(plan: LogicalPlan) = plan.transform { @@ -88,25 +97,66 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { sHolder.builder match { case r: SupportsPushDownAggregates => val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - var ordinal = 0 - val aggregates = resultExpressions.flatMap { expr => - expr.collect { - // Do not push down duplicated aggregate expressions. For example, - // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one - // `max(a)` to the data source. - case agg: AggregateExpression - if !aggExprToOutputOrdinal.contains(agg.canonicalized) => - aggExprToOutputOrdinal(agg.canonicalized) = ordinal - ordinal += 1 - agg - } - } + val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal) val normalizedAggregates = DataSourceStrategy.normalizeExprs( aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs( groupingExpressions, sHolder.relation.output) - val pushedAggregates = PushDownUtils.pushAggregates( - r, normalizedAggregates, normalizedGroupingExpressions) + val translatedAggregates = DataSourceStrategy.translateAggregation( + normalizedAggregates, normalizedGroupingExpressions) + val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { + if (translatedAggregates.isEmpty || + r.supportCompletePushDown(translatedAggregates.get) || + translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { + (resultExpressions, aggregates, translatedAggregates) + } else { + // scalastyle:off + // The data source doesn't support the complete push-down of this aggregation. + // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be + // pushed, completely or partially. + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT avg(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // + // After convert avg(c1#9) to sum(c1#9)/count(c1#9) + // we have the following + // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // scalastyle:on + val newResultExpressions = resultExpressions.map { expr => + expr.transform { + case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => + val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) + val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) + // Closely follow `Average.evaluateExpression` + avg.dataType match { + case _: YearMonthIntervalType => + If(EqualTo(count, Literal(0L)), + Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count)) + case _: DayTimeIntervalType => + If(EqualTo(count, Literal(0L)), + Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count)) + case _ => + // TODO deal with the overflow issue + Divide(addCastIfNeeded(sum, avg.dataType), + addCastIfNeeded(count, avg.dataType), false) + } + } + }.asInstanceOf[Seq[NamedExpression]] + // Because aggregate expressions changed, translate them again. 
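
// Editor's sketch (not part of the patch): the rule-chaining style used by the rewritten
// apply() above -- a sequence of Plan => Plan functions folded over the input plan. Plan and
// the rule names below are toy stand-ins; the real rules operate on LogicalPlan.
object RuleChainSketch {
  type Plan = List[String]

  def main(args: Array[String]): Unit = {
    val pushdownRules: Seq[Plan => Plan] = Seq(
      plan => "createScanBuilder" :: plan,
      plan => "pushDownSample" :: plan,
      plan => "pushDownFilters" :: plan)

    val result = pushdownRules.foldLeft(List("relation")) { (newPlan, rule) => rule(newPlan) }
    println(result.reverse.mkString(" -> "))   // relation -> createScanBuilder -> pushDownSample -> pushDownFilters
  }
}
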
+ aggExprToOutputOrdinal.clear() + val newAggregates = + collectAggregates(newResultExpressions, aggExprToOutputOrdinal) + val newNormalizedAggregates = DataSourceStrategy.normalizeExprs( + newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] + (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation( + newNormalizedAggregates, normalizedGroupingExpressions)) + } + } + + val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) if (pushedAggregates.isEmpty) { aggNode // return original plan node } else if (!supportPartialAggPushDown(pushedAggregates.get) && @@ -129,7 +179,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] // scalastyle:on val newOutput = scan.readSchema().toAttributes - assert(newOutput.length == groupingExpressions.length + aggregates.length) + assert(newOutput.length == groupingExpressions.length + finalAggregates.length) val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) case (_, b) => b @@ -164,7 +214,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { Project(projectExpressions, scanRelation) } else { val plan = Aggregate( - output.take(groupingExpressions.length), resultExpressions, scanRelation) + output.take(groupingExpressions.length), finalResultExpressions, scanRelation) // scalastyle:off // Change the optimized logical plan to reflect the pushed down aggregate @@ -210,19 +260,36 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def collectAggregates(resultExpressions: Seq[NamedExpression], + aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { + var ordinal = 0 + resultExpressions.flatMap { expr => + expr.collect { + // Do not push down duplicated aggregate expressions. For example, + // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one + // `max(a)` to the data source. + case agg: AggregateExpression + if !aggExprToOutputOrdinal.contains(agg.canonicalized) => + aggExprToOutputOrdinal(agg.canonicalized) = ordinal + ordinal += 1 + agg + } + } + } + private def supportPartialAggPushDown(agg: Aggregation): Boolean = { // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. 
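The AVG rewrite above mirrors Average.evaluateExpression: avg(c) is recomputed from the pushed sum(c) and count(c), returning null when the count is zero. The arithmetic in isolation (plain Scala, not Catalyst expressions):

    // Recompute the average from partial results; None stands in for SQL NULL,
    // matching the If(EqualTo(count, 0), null, sum / count) shape built above.
    def avgFromPartials(sum: Double, count: Long): Option[Double] =
      if (count == 0L) None else Some(sum / count)

    avgFromPartials(sum = 30.0, count = 3L)  // Some(10.0)
    avgFromPartials(sum = 0.0, count = 0L)   // None, i.e. NULL
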
agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc]) } - private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) = - if (aggAttribute.dataType == aggDataType) { - aggAttribute + private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) = + if (expression.dataType == expectedDataType) { + expression } else { - Cast(aggAttribute, aggDataType) + Cast(expression, expectedDataType) } - def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { + def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => // column pruning val normalizedProjects = DataSourceStrategy @@ -308,7 +375,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case other => other } - def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform { + def pushDownLimits(plan: LogicalPlan): LogicalPlan = plan.transform { case globalLimit @ Limit(IntegerLiteral(limitValue), child) => val newChild = pushDownLimit(child, limitValue) val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index 906107a1227f8..b9a4e0e6ba30b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -25,11 +25,11 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog} import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, FunctionCatalog, Identifier, NamespaceChange, SupportsNamespaces, Table, TableCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.catalog.NamespaceChange.RemoveProperty import org.apache.spark.sql.connector.catalog.functions.UnboundFunction -import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.internal.connector.V1Function @@ -96,8 +96,8 @@ class V2SessionCatalog(catalog: SessionCatalog) schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): Table = { - - val (partitionColumns, maybeBucketSpec) = V2SessionCatalog.convertTransforms(partitions) + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TransformHelper + val (partitionColumns, maybeBucketSpec) = partitions.toSeq.convertTransforms val provider = properties.getOrDefault(TableCatalog.PROP_PROVIDER, conf.defaultDataSourceName) val tableProperties = properties.asScala val location = Option(properties.get(TableCatalog.PROP_LOCATION)) @@ -286,12 +286,11 @@ class V2SessionCatalog(catalog: SessionCatalog) } } - 
override def dropNamespace(namespace: Array[String]): Boolean = namespace match { + override def dropNamespace( + namespace: Array[String], + cascade: Boolean): Boolean = namespace match { case Array(db) if catalog.databaseExists(db) => - if (catalog.listTables(db).nonEmpty) { - throw QueryExecutionErrors.namespaceNotEmptyError(namespace) - } - catalog.dropDatabase(db, ignoreIfNotExists = false, cascade = false) + catalog.dropDatabase(db, ignoreIfNotExists = false, cascade) true case Array(_) => @@ -331,27 +330,6 @@ class V2SessionCatalog(catalog: SessionCatalog) private[sql] object V2SessionCatalog { - /** - * Convert v2 Transforms to v1 partition columns and an optional bucket spec. - */ - private def convertTransforms(partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = { - val identityCols = new mutable.ArrayBuffer[String] - var bucketSpec = Option.empty[BucketSpec] - - partitions.map { - case IdentityTransform(FieldReference(Seq(col))) => - identityCols += col - - case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) => - bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil)) - - case transform => - throw QueryExecutionErrors.unsupportedPartitionTransformError(transform) - } - - (identityCols.toSeq, bucketSpec) - } - private def toCatalogDatabase( db: String, metadata: util.Map[String, String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index 8494bba078552..38f741532d786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -22,11 +22,15 @@ import java.util.UUID import org.apache.spark.sql.catalyst.expressions.PredicateHelper import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} +import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, Write, WriteBuilder} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, WriteToMicroBatchDataSource} +import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend import org.apache.spark.sql.sources.{AlwaysTrue, Filter} +import org.apache.spark.sql.streaming.OutputMode /** * A rule that constructs logical writes. 
@@ -77,6 +81,36 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { } val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf) o.copy(write = Some(write), query = newQuery) + + case WriteToMicroBatchDataSource( + relation, table, query, queryId, writeOptions, outputMode, Some(batchId)) => + + val writeBuilder = newWriteBuilder(table, query, writeOptions, queryId) + val write = buildWriteForMicroBatch(table, writeBuilder, outputMode) + val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming) + val customMetrics = write.supportedCustomMetrics.toSeq + val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf) + WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics) + } + + private def buildWriteForMicroBatch( + table: SupportsWrite, + writeBuilder: WriteBuilder, + outputMode: OutputMode): Write = { + + outputMode match { + case Append => + writeBuilder.build() + case Complete => + // TODO: we should do this check earlier when we have capability API. + require(writeBuilder.isInstanceOf[SupportsTruncate], + table.name + " does not support Complete mode.") + writeBuilder.asInstanceOf[SupportsTruncate].truncate().build() + case Update => + require(writeBuilder.isInstanceOf[SupportsStreamingUpdateAsAppend], + table.name + " does not support Update mode.") + writeBuilder.asInstanceOf[SupportsStreamingUpdateAsAppend].build() + } } private def isTruncate(filters: Array[Filter]): Boolean = { @@ -86,12 +120,10 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { private def newWriteBuilder( table: Table, query: LogicalPlan, - writeOptions: Map[String, String]): WriteBuilder = { + writeOptions: Map[String, String], + queryId: String = UUID.randomUUID().toString): WriteBuilder = { - val info = LogicalWriteInfoImpl( - queryId = UUID.randomUUID().toString, - query.schema, - writeOptions.asOptions) + val info = LogicalWriteInfoImpl(queryId, query.schema, writeOptions.asOptions) table.asWritable.newWriteBuilder(info) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala index 31d31bd43f453..bf996ab1b3111 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala @@ -46,7 +46,6 @@ case class CSVPartitionReaderFactory( partitionSchema: StructType, parsedOptions: CSVOptions, filters: Seq[Filter]) extends FilePartitionReaderFactory { - private val columnPruning = sqlConf.csvColumnPruning override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { val conf = broadcastedConf.value.value @@ -59,7 +58,7 @@ case class CSVPartitionReaderFactory( actualReadDataSchema, parsedOptions, filters) - val schema = if (columnPruning) actualReadDataSchema else actualDataSchema + val schema = if (parsedOptions.columnPruning) actualReadDataSchema else actualDataSchema val isStartOfFile = file.start == 0 val headerChecker = new CSVHeaderChecker( schema, parsedOptions, source = s"CSV file: ${file.filePath}", isStartOfFile) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index cc3c146106670..5c33a1047a12f 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils} import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.csv.CSVDataSource import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan @@ -45,9 +44,11 @@ case class CSVScan( dataFilters: Seq[Expression] = Seq.empty) extends TextBasedFileScan(sparkSession, options) { + val columnPruning = sparkSession.sessionState.conf.csvColumnPruning && + !readDataSchema.exists(_.name == sparkSession.sessionState.conf.columnNameOfCorruptRecord) private lazy val parsedOptions: CSVOptions = new CSVOptions( options.asScala.toMap, - columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + columnPruning = columnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -67,11 +68,10 @@ case class CSVScan( override def createReaderFactory(): PartitionReaderFactory = { // Check a field requirement for corrupt records here to throw an exception in a driver side ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) - - if (readDataSchema.length == 1 && - readDataSchema.head.name == parsedOptions.columnNameOfCorruptRecord) { - throw QueryCompilationErrors.queryFromRawFilesIncludeCorruptRecordColumnError() - } + // Don't push any filter which refers to the "virtual" column which cannot present in the input. + // Such filters will be applied later on the upper layer. + val actualFilters = + pushedFilters.filterNot(_.references.contains(parsedOptions.columnNameOfCorruptRecord)) val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -81,7 +81,7 @@ case class CSVScan( // The partition values are already truncated in `FileScan.partitions`. // We should use `readPartitionSchema` as the partition schema here. 
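The effect of the CSVScan change, in user terms: when the requested schema includes the configured corrupt-record column, column pruning is turned off and no filter that references that column is pushed into the scan, because the column never exists in the files themselves. An illustrative read (hypothetical path, default column name, assumes a SparkSession named spark):

    import org.apache.spark.sql.types._
    import spark.implicits._

    val schema = new StructType()
      .add("id", IntegerType)
      .add("_corrupt_record", StringType)

    // Filters on "_corrupt_record" are evaluated above the scan, not pushed into it.
    val df = spark.read
      .schema(schema)
      .option("mode", "PERMISSIVE")
      .csv("/tmp/input.csv")

    df.filter($"_corrupt_record".isNotNull).show()
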
CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) + dataSchema, readDataSchema, readPartitionSchema, parsedOptions, actualFilters) } override def equals(obj: Any): Boolean = obj match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala index 566706486d3f0..03200d5a6f371 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala @@ -21,7 +21,6 @@ import java.util import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuilder import org.apache.spark.internal.Logging import org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange, SupportsNamespaces, Table, TableCatalog, TableChange} @@ -173,23 +172,14 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def namespaceExists(namespace: Array[String]): Boolean = namespace match { case Array(db) => JdbcUtils.withConnection(options) { conn => - val rs = conn.getMetaData.getSchemas(null, db) - while (rs.next()) { - if (rs.getString(1) == db) return true; - } - false + JdbcUtils.schemaExists(conn, options, db) } case _ => false } override def listNamespaces(): Array[Array[String]] = { JdbcUtils.withConnection(options) { conn => - val schemaBuilder = ArrayBuilder.make[Array[String]] - val rs = conn.getMetaData.getSchemas() - while (rs.next()) { - schemaBuilder += Array(rs.getString(1)) - } - schemaBuilder.result + JdbcUtils.listSchemas(conn, options) } } @@ -236,7 +226,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging } JdbcUtils.withConnection(options) { conn => JdbcUtils.classifyException(s"Failed create name space: $db", dialect) { - JdbcUtils.createNamespace(conn, options, db, comment) + JdbcUtils.createSchema(conn, options, db, comment) } } @@ -254,7 +244,9 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging case set: NamespaceChange.SetProperty => if (set.property() == SupportsNamespaces.PROP_COMMENT) { JdbcUtils.withConnection(options) { conn => - JdbcUtils.createNamespaceComment(conn, options, db, set.value) + JdbcUtils.classifyException(s"Failed create comment on name space: $db", dialect) { + JdbcUtils.alterSchemaComment(conn, options, db, set.value) + } } } else { throw QueryCompilationErrors.cannotSetJDBCNamespaceWithPropertyError(set.property) @@ -263,7 +255,9 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging case unset: NamespaceChange.RemoveProperty => if (unset.property() == SupportsNamespaces.PROP_COMMENT) { JdbcUtils.withConnection(options) { conn => - JdbcUtils.removeNamespaceComment(conn, options, db) + JdbcUtils.classifyException(s"Failed remove comment on name space: $db", dialect) { + JdbcUtils.removeSchemaComment(conn, options, db) + } } } else { throw QueryCompilationErrors.cannotUnsetJDBCNamespaceWithPropertyError(unset.property) @@ -278,14 +272,13 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging } } - override def dropNamespace(namespace: Array[String]): Boolean = namespace match { + override def dropNamespace( + namespace: Array[String], + cascade: Boolean): Boolean = 
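The inline metadata scan removed above is a reasonable picture of what JdbcUtils.schemaExists has to do; a standalone plain-JDBC sketch of the same check (not the actual JdbcUtils code):

    import java.sql.Connection

    // True if a schema with exactly this name is reported by the driver's metadata.
    def schemaExists(conn: Connection, schema: String): Boolean = {
      val rs = conn.getMetaData.getSchemas(null, schema)
      try {
        var found = false
        while (!found && rs.next()) {
          found = rs.getString("TABLE_SCHEM") == schema
        }
        found
      } finally {
        rs.close()
      }
    }
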
namespace match { case Array(db) if namespaceExists(namespace) => - if (listTables(Array(db)).nonEmpty) { - throw QueryExecutionErrors.namespaceNotEmptyError(namespace) - } JdbcUtils.withConnection(options) { conn => JdbcUtils.classifyException(s"Failed drop name space: $db", dialect) { - JdbcUtils.dropNamespace(conn, options, db) + JdbcUtils.dropSchema(conn, options, db, cascade) true } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala index 0e6c72c2cc331..7449f66ee020f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala @@ -20,6 +20,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.execution.datasources.jdbc.{JdbcOptionsInWrite, JdbcUtils} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.StructType @@ -37,7 +38,8 @@ case class JDBCWriteBuilder(schema: StructType, options: JdbcOptionsInWrite) ext override def toInsertableRelation: InsertableRelation = (data: DataFrame, _: Boolean) => { // TODO (SPARK-32595): do truncate and append atomically. if (isTruncate) { - val conn = JdbcUtils.createConnectionFactory(options)() + val dialect = JdbcDialects.get(options.url) + val conn = dialect.createConnectionFactory(options)(-1) JdbcUtils.truncateTable(conn, options) } JdbcUtils.saveTable(data, Some(schema), SQLConf.get.caseSensitiveAnalysis, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 41ee98a4f47b8..12b8a631196ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -202,7 +202,7 @@ case class ParquetPartitionReaderFactory( private def buildReaderBase[T]( file: PartitionedFile, buildReaderFunc: ( - FileSplit, InternalRow, TaskAttemptContextImpl, + InternalRow, Option[FilterPredicate], Option[ZoneId], RebaseSpec, RebaseSpec) => RecordReader[Void, T]): RecordReader[Void, T] = { @@ -261,9 +261,7 @@ case class ParquetPartitionReaderFactory( footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) val reader = buildReaderFunc( - split, file.partitionValues, - hadoopAttemptContext, pushed, convertTz, datetimeRebaseSpec, @@ -277,9 +275,7 @@ case class ParquetPartitionReaderFactory( } private def createRowBaseParquetReader( - split: FileSplit, partitionValues: InternalRow, - hadoopAttemptContext: TaskAttemptContextImpl, pushed: Option[FilterPredicate], convertTz: Option[ZoneId], datetimeRebaseSpec: RebaseSpec, @@ -312,9 +308,7 @@ case class ParquetPartitionReaderFactory( } private def createParquetVectorizedReader( - split: FileSplit, partitionValues: InternalRow, - hadoopAttemptContext: TaskAttemptContextImpl, pushed: Option[FilterPredicate], convertTz: Option[ZoneId], datetimeRebaseSpec: RebaseSpec, diff --git 
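JdbcUtils.dropSchema now receives the cascade flag directly; a plausible shape for the statement it issues (illustrative only, the real helper goes through the active JdbcDialect for identifier quoting and dialect-specific syntax):

    // Hypothetical helper: append CASCADE only when the caller asked for it.
    def dropSchemaSql(schema: String, cascade: Boolean): String =
      if (cascade) s"DROP SCHEMA $schema CASCADE" else s"DROP SCHEMA $schema"

    dropSchemaSql("sales", cascade = true)   // DROP SCHEMA sales CASCADE
    dropSchemaSql("sales", cascade = false)  // DROP SCHEMA sales
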
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 617faad8ab6d7..6b35f2406a82f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -78,8 +78,6 @@ case class ParquetScan( SQLConf.CASE_SENSITIVE.key, sparkSession.sessionState.conf.caseSensitiveAnalysis) - ParquetWriteSupport.setSchema(readDataSchema, hadoopConf) - // Sets flags for `ParquetToSparkSchemaConverter` hadoopConf.setBoolean( SQLConf.PARQUET_BINARY_AS_STRING.key, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala index b2b6d313e1bcd..d84acedb962e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala @@ -72,7 +72,6 @@ case class ParquetWrite( ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) - ParquetSchemaConverter.checkFieldNames(dataSchema) // This metadata is useful for keeping UDTs like Vector/Matrix. ParquetWriteSupport.setSchema(dataSchema, conf) @@ -82,6 +81,10 @@ case class ParquetWrite( conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, sqlConf.parquetOutputTimestampType.toString) + conf.set( + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, + sqlConf.parquetFieldIdWriteEnabled.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala index abf0cf63a0bb0..65621fb1860e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.dynamicpruning import org.apache.spark.sql.catalyst.catalog.HiveTableRelation -import org.apache.spark.sql.catalyst.expressions.{DynamicPruning, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{DynamicPruning, DynamicPruningSubquery, EqualNullSafe, EqualTo, Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} @@ -34,6 +34,33 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation */ object CleanupDynamicPruningFilters extends Rule[LogicalPlan] with PredicateHelper { + private def collectEqualityConditionExpressions(condition: Expression): Seq[Expression] = { + splitConjunctivePredicates(condition).flatMap(_.collect { + case EqualTo(l, r) if l.deterministic && r.foldable => l + case EqualTo(l, r) if r.deterministic && l.foldable => r + case EqualNullSafe(l, r) if l.deterministic && r.foldable => l + case EqualNullSafe(l, r) if r.deterministic && l.foldable => r + }) + } + + /** + * If a partition key already has 
equality conditions, then its DPP filter is useless and + * can't prune anything. So we should remove it. + */ + private def removeUnnecessaryDynamicPruningSubquery(plan: LogicalPlan): LogicalPlan = { + plan.transformWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) { + case f @ Filter(condition, _) => + val unnecessaryPruningKeys = ExpressionSet(collectEqualityConditionExpressions(condition)) + val newCondition = condition.transformWithPruning( + _.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) { + case dynamicPruning: DynamicPruningSubquery + if unnecessaryPruningKeys.contains(dynamicPruning.pruningKey) => + TrueLiteral + } + f.copy(condition = newCondition) + } + } + override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.dynamicPartitionPruningEnabled) { return plan @@ -43,10 +70,13 @@ object CleanupDynamicPruningFilters extends Rule[LogicalPlan] with PredicateHelp // No-op for trees that do not contain dynamic pruning. _.containsAnyPattern(DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY)) { // pass through anything that is pushed down into PhysicalOperation - case p @ PhysicalOperation(_, _, LogicalRelation(_: HadoopFsRelation, _, _, _)) => p + case p @ PhysicalOperation(_, _, LogicalRelation(_: HadoopFsRelation, _, _, _)) => + removeUnnecessaryDynamicPruningSubquery(p) // pass through anything that is pushed down into PhysicalOperation - case p @ PhysicalOperation(_, _, HiveTableRelation(_, _, _, _, _)) => p - case p @ PhysicalOperation(_, _, _: DataSourceV2ScanRelation) => p + case p @ PhysicalOperation(_, _, HiveTableRelation(_, _, _, _, _)) => + removeUnnecessaryDynamicPruningSubquery(p) + case p @ PhysicalOperation(_, _, _: DataSourceV2ScanRelation) => + removeUnnecessaryDynamicPruningSubquery(p) // remove any Filters with DynamicPruning that didn't get pushed down to PhysicalOperation. 
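A concrete query the new cleanup targets (illustrative table names): once the partition key is pinned by a literal equality, a dynamic pruning subquery on the same key cannot remove any further partitions, so it is replaced with TrueLiteral.

    // Hypothetical example: f.part is already fixed to a single value, so the DPP
    // filter that would be injected from the dim side is redundant and gets dropped.
    spark.sql("""
      SELECT f.value
      FROM fact f JOIN dim d ON f.part = d.id
      WHERE f.part = 42 AND d.region = 'EU'
    """)
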
case f @ Filter(condition, _) => val newCondition = condition.transformWithPruning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 4b5f724ba6f85..3b5fc4aea5d8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -205,6 +205,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper { case _: BinaryComparison => true case _: In | _: InSet => true case _: StringPredicate => true + case BinaryPredicate(_) => true case _: MultiLikeBase => true case _ => false } @@ -213,10 +214,10 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper { * Search a filtering predicate in a given logical plan */ private def hasSelectivePredicate(plan: LogicalPlan): Boolean = { - plan.find { + plan.exists { case f: Filter => isLikelySelective(f.condition) case _ => false - }.isDefined + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index 9a05e396d4a70..252565fd9077b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -58,13 +58,13 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) // Using `sparkPlan` is a little hacky as it is based on the assumption that this rule is // the first to be applied (apart from `InsertAdaptiveSparkPlan`). val canReuseExchange = conf.exchangeReuseEnabled && buildKeys.nonEmpty && - plan.find { + plan.exists { case BroadcastHashJoinExec(_, _, _, BuildLeft, _, left, _, _) => left.sameResult(sparkPlan) case BroadcastHashJoinExec(_, _, _, BuildRight, _, _, right, _) => right.sameResult(sparkPlan) case _ => false - }.isDefined + } if (canReuseExchange) { val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, sparkPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala index 9538199590477..1ac6b809fd250 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ValidateRequirements.scala @@ -45,17 +45,7 @@ object ValidateRequirements extends Logging { assert(requiredChildDistributions.length == children.length) assert(requiredChildOrderings.length == children.length) - // Verify partition number. For (hash) clustered distribution, the corresponding children must - // have the same number of partitions. 
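Several call sites in this patch swap find(...).isDefined for exists(...); the two are equivalent for this purpose, exists just short-circuits without materializing the matching element. In miniature:

    val xs = Seq(1, 2, 3)
    xs.find(_ > 2).isDefined  // true, but builds Some(3) first
    xs.exists(_ > 2)          // true, no intermediate Option
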
- val numPartitions = requiredChildDistributions.zipWithIndex.collect { - case (_: ClusteredDistribution, i) => i - }.map(i => children(i).outputPartitioning.numPartitions) - if (numPartitions.length > 1 && !numPartitions.tail.forall(_ == numPartitions.head)) { - logDebug(s"ValidateRequirements failed: different partition num in\n$plan") - return false - } - - children.zip(requiredChildDistributions.zip(requiredChildOrderings)).forall { + val satisfied = children.zip(requiredChildDistributions.zip(requiredChildOrderings)).forall { case (child, (distribution, ordering)) if !child.outputPartitioning.satisfies(distribution) || !SortOrder.orderingSatisfies(child.outputOrdering, ordering) => @@ -63,5 +53,21 @@ object ValidateRequirements extends Logging { false case _ => true } + + if (satisfied && children.length > 1 && + requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])) { + // Check the co-partitioning requirement. + val specs = children.map(_.outputPartitioning).zip(requiredChildDistributions).map { + case (p, d) => p.createShuffleSpec(d.asInstanceOf[ClusteredDistribution]) + } + if (specs.tail.forall(_.isCompatibleWith(specs.head))) { + true + } else { + logDebug(s"ValidateRequirements failed: children not co-partitioned in\n$plan") + false + } + } else { + satisfied + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 4de35b9e06c5d..23b5b614369fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -356,7 +356,7 @@ case class BroadcastNestedLoopJoinExec( i += 1 } } - Seq(matched).toIterator + Seq(matched).iterator } matchedBuildRows.fold(new BitSet(relation.value.length))(_ | _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 0e8bb84ee5d81..4595ea049ef70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -705,6 +705,13 @@ trait HashJoin extends JoinCodegenSupport { } object HashJoin extends CastSupport with SQLConfHelper { + + private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = { + // TODO: support BooleanType, DateType and TimestampType + keys.forall(_.dataType.isInstanceOf[IntegralType]) && + keys.map(_.dataType.defaultSize).sum <= 8 + } + /** * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. 
* @@ -712,9 +719,7 @@ object HashJoin extends CastSupport with SQLConfHelper { */ def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { assert(keys.nonEmpty) - // TODO: support BooleanType, DateType and TimestampType - if (keys.exists(!_.dataType.isInstanceOf[IntegralType]) - || keys.map(_.dataType.defaultSize).sum > 8) { + if (!canRewriteAsLongType(keys)) { return keys } @@ -736,18 +741,28 @@ object HashJoin extends CastSupport with SQLConfHelper { * determine the number of bits to shift */ def extractKeyExprAt(keys: Seq[Expression], index: Int): Expression = { + assert(canRewriteAsLongType(keys)) // jump over keys that have a higher index value than the required key if (keys.size == 1) { assert(index == 0) - cast(BoundReference(0, LongType, nullable = false), keys(index).dataType) + Cast( + child = BoundReference(0, LongType, nullable = false), + dataType = keys(index).dataType, + timeZoneId = Option(conf.sessionLocalTimeZone), + ansiEnabled = false) } else { val shiftedBits = keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1 // build the schema for unpacking the required key - cast(BitwiseAnd( + val castChild = BitwiseAnd( ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)), - Literal(mask)), keys(index).dataType) + Literal(mask)) + Cast( + child = castChild, + dataType = keys(index).dataType, + timeZoneId = Option(conf.sessionLocalTimeZone), + ansiEnabled = false) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0ea245093e3ed..253f16e39d352 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -110,6 +110,11 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { */ def keys(): Iterator[InternalRow] + /** + * Returns the average number of hash probes per key lookup. + */ + def getAvgHashProbesPerKey(): Double + /** * Returns a read-only copy of this, to be safely used in current thread. */ @@ -202,7 +207,7 @@ private[execution] class ValueRowWithKeyIndex { * A HashedRelation for UnsafeRow, which is backed BytesToBytesMap. 
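rewriteKeyExpr packs several integral join keys into one long when their byte widths sum to at most 8, and extractKeyExprAt recovers each key with an unsigned shift plus a mask. The same shift-and-mask idea on concrete values (toy code, not the Catalyst expressions):

    // Pack an Int (4 bytes) and a Short (2 bytes) into one Long, then unpack them.
    def pack(a: Int, b: Short): Long = (a.toLong << 16) | (b.toLong & 0xFFFFL)

    def first(packed: Long): Int    = (packed >>> 16).toInt
    def second(packed: Long): Short = (packed & 0xFFFFL).toShort

    val packed = pack(123456, -7)
    first(packed)   // 123456
    second(packed)  // -7
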
* * It's serialized in the following format: - * [number of keys] + * [number of keys] [number of fields] * [size of key] [size of value] [key bytes] [bytes for value] */ private[joins] class UnsafeHashedRelation( @@ -221,6 +226,8 @@ private[joins] class UnsafeHashedRelation( override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption + override def getAvgHashProbesPerKey(): Double = binaryMap.getAvgHashProbesPerKey + // re-used in get()/getValue()/getWithKeyIndex()/getValueWithKeyIndex()/valuesWithKeyIndex() var resultRow = new UnsafeRow(numFields) @@ -357,6 +364,7 @@ private[joins] class UnsafeHashedRelation( writeInt: (Int) => Unit, writeLong: (Long) => Unit, writeBuffer: (Array[Byte], Int, Int) => Unit) : Unit = { + writeInt(numKeys) writeInt(numFields) // TODO: move these into BytesToBytesMap writeLong(binaryMap.numKeys()) @@ -390,6 +398,7 @@ private[joins] class UnsafeHashedRelation( readInt: () => Int, readLong: () => Long, readBuffer: (Array[Byte], Int, Int) => Unit): Unit = { + numKeys = readInt() numFields = readInt() resultRow = new UnsafeRow(numFields) val nKeys = readLong() @@ -566,6 +575,12 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap // The number of unique keys. private var numKeys = 0L + // The number of hash probes for keys. + private var numProbes = 0L + + // The number of keys lookups. + private var numKeyLookups = 0L + // needed by serializer def this() = { this( @@ -614,6 +629,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def getTotalMemoryConsumption: Long = array.length * 8L + page.length * 8L + /** + * Returns the average number of hash probes per key lookup. + */ + def getAvgHashProbesPerKey: Double = (1.0 * numProbes) / numKeyLookups + /** * Returns the first slot of array that store the keys (sparse mode). */ @@ -648,7 +668,9 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap * Returns the single UnsafeRow for given key, or null if not found. */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { + numKeyLookups += 1 if (isDense) { + numProbes += 1 if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -656,12 +678,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } } else { + numProbes += 1 var pos = firstSlot(key) while (array(pos + 1) != 0) { if (array(pos) == key) { return getRow(array(pos + 1), resultRow) } pos = nextSlot(pos) + numProbes += 1 } } null @@ -688,7 +712,9 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap * Returns an iterator for all the values for the given key, or null if no value found. */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + numKeyLookups += 1 if (isDense) { + numProbes += 1 if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -696,12 +722,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } } else { + numProbes += 1 var pos = firstSlot(key) while (array(pos + 1) != 0) { if (array(pos) == key) { return valueIter(array(pos + 1), resultRow) } pos = nextSlot(pos) + numProbes += 1 } } null @@ -780,10 +808,13 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap * Update the address in array for given key. 
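The new counters make the metric a plain ratio: probes accumulated across lookups divided by the number of lookups. A toy linear-probing lookup that counts the same way (self-contained, nothing Spark-specific; assumes the set is never full):

    // Counts how many slots each lookup touches; avgHashProbesPerKey = probes / lookups.
    class ProbeCountingSet(capacity: Int) {
      private val slots = Array.fill[Option[Long]](capacity)(None)
      private var probes = 0L
      private var lookups = 0L

      private def slotOf(key: Long): Int = ((key % capacity + capacity) % capacity).toInt

      def add(key: Long): Unit = {
        var pos = slotOf(key)
        while (slots(pos).exists(_ != key)) pos = (pos + 1) % capacity
        slots(pos) = Some(key)
      }

      def contains(key: Long): Boolean = {
        lookups += 1
        var pos = slotOf(key)
        probes += 1
        while (slots(pos).exists(_ != key)) {
          pos = (pos + 1) % capacity
          probes += 1
        }
        slots(pos).contains(key)
      }

      def avgHashProbesPerKey: Double = if (lookups == 0) 0.0 else probes.toDouble / lookups
    }
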
*/ private def updateIndex(key: Long, address: Long): Unit = { + numKeyLookups += 1 + numProbes += 1 var pos = firstSlot(key) assert(numKeys < array.length / 2) while (array(pos) != key && array(pos + 1) != 0) { pos = nextSlot(pos) + numProbes += 1 } if (array(pos + 1) == 0) { // this is the first value for this key, put the address in array. @@ -986,6 +1017,8 @@ class LongHashedRelation( override def estimatedSize: Long = map.getTotalMemoryConsumption + override def getAvgHashProbesPerKey(): Double = map.getAvgHashProbesPerKey + override def get(key: InternalRow): Iterator[InternalRow] = { if (key.isNullAt(0)) { null @@ -1103,6 +1136,8 @@ case object EmptyHashedRelation extends HashedRelation { override def close(): Unit = {} override def estimatedSize: Long = 0 + + override def getAvgHashProbesPerKey(): Double = 0 } /** @@ -1129,6 +1164,8 @@ case object HashedRelationWithAllNullKeys extends HashedRelation { override def close(): Unit = {} override def estimatedSize: Long = 0 + + override def getAvgHashProbesPerKey(): Double = 0 } /** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index cfe35d04778fb..38c9c82f77e07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -49,7 +49,8 @@ case class ShuffledHashJoinExec( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"), - "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map")) + "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"), + "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probes per key")) override def output: Seq[Attribute] = super[ShuffledJoin].output @@ -77,6 +78,7 @@ case class ShuffledHashJoinExec( def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val buildDataSize = longMetric("buildDataSize") val buildTime = longMetric("buildTime") + val avgHashProbe = longMetric("avgHashProbe") val start = System.nanoTime() val context = TaskContext.get() val relation = HashedRelation( @@ -89,7 +91,11 @@ case class ShuffledHashJoinExec( buildTime += NANOSECONDS.toMillis(System.nanoTime() - start) buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. 
- context.addTaskCompletionListener[Unit](_ => relation.close()) + context.addTaskCompletionListener[Unit](_ => { + // Update average hashmap probe + avgHashProbe.set(relation.getAvgHashProbesPerKey()) + relation.close() + }) relation } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 69802b143c113..a7f63aafc9f1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -88,7 +88,7 @@ case class AggregateInPandasExec( (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) case children => // There should not be any other UDFs, or the children can't be evaluated directly. - assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + assert(children.forall(!_.exists(_.isInstanceOf[PythonUDF]))) (ChainedPythonFunctions(Seq(udf.func)), udf.children) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index fca43e454bff5..c567a70e1d3cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -72,7 +72,7 @@ trait EvalPythonExec extends UnaryExecNode { (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) case children => // There should not be any other UDFs, or the children can't be evaluated directly. - assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + assert(children.forall(!_.exists(_.isInstanceOf[PythonUDF]))) (ChainedPythonFunctions(Seq(udf.func)), udf.children) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 407c498c81759..a809ea07d0ec6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -45,10 +45,10 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { } private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { - expr.find { + expr.exists { e => PythonUDF.isScalarPythonUDF(e) && - (e.references.isEmpty || e.find(belongAggregate(_, agg)).isDefined) - }.isDefined + (e.references.isEmpty || e.exists(belongAggregate(_, agg))) + } } private def extract(agg: Aggregate): LogicalPlan = { @@ -90,7 +90,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] { private def hasScalarPythonUDF(e: Expression): Boolean = { - e.find(PythonUDF.isScalarPythonUDF).isDefined + e.exists(PythonUDF.isScalarPythonUDF) } private def extract(agg: Aggregate): LogicalPlan = { @@ -164,7 +164,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { private type EvalTypeChecker = EvalType => Boolean private def hasScalarPythonUDF(e: Expression): Boolean = { - e.find(PythonUDF.isScalarPythonUDF).isDefined + e.exists(PythonUDF.isScalarPythonUDF) } @scala.annotation.tailrec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 07c0aab1b6b74..e73da99786ceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -27,7 +27,6 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} import org.apache.spark.sql.execution.window._ @@ -87,27 +86,6 @@ case class WindowInPandasExec( child: SparkPlan) extends WindowExecBase { - override def output: Seq[Attribute] = - child.output ++ windowExpression.map(_.toAttribute) - - override def requiredChildDistribution: Seq[Distribution] = { - if (partitionSpec.isEmpty) { - // Only show warning when the number of bytes is larger than 100 MiB? - logWarning("No Partition Defined for Window operation! Moving all data to a single " - + "partition, this can cause serious performance degradation.") - AllTuples :: Nil - } else { - ClusteredDistribution(partitionSpec) :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputPartitioning: Partitioning = child.outputPartitioning - /** * Helper functions and data structures for window bounds * @@ -135,7 +113,7 @@ case class WindowInPandasExec( (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) case children => // There should not be any other UDFs, or the children can't be evaluated directly. 
- assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + assert(children.forall(!_.exists(_.isInstanceOf[PythonUDF]))) (ChainedPythonFunctions(Seq(udf.func)), udf.children) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 5dc0ff0ac4d1d..9155c1cb6e7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal, TryCast} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries} @@ -246,6 +246,11 @@ object StatFunctions extends Logging { } require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") + def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) { + TryCast(e, DoubleType) + } else { + e + } var percentileIndex = 0 val statisticFns = selectedStatistics.map { stats => if (stats.endsWith("%")) { @@ -253,7 +258,7 @@ object StatFunctions extends Logging { percentileIndex += 1 (child: Expression) => GetArrayItem( - new ApproximatePercentile(child, + new ApproximatePercentile(castAsDoubleIfNecessary(child), Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false))) .toAggregateExpression(), Literal(index)) @@ -264,8 +269,10 @@ object StatFunctions extends Logging { Count(child).toAggregateExpression(isDistinct = true) case "approx_count_distinct" => (child: Expression) => HyperLogLogPlusPlus(child).toAggregateExpression() - case "mean" => (child: Expression) => Average(child).toAggregateExpression() - case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression() + case "mean" => (child: Expression) => + Average(castAsDoubleIfNecessary(child)).toAggregateExpression() + case "stddev" => (child: Expression) => + StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression() case "min" => (child: Expression) => Min(child).toAggregateExpression() case "max" => (child: Expression) => Max(child).toAggregateExpression() case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AcceptsLatestSeenOffsetHandler.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AcceptsLatestSeenOffsetHandler.scala new file mode 100644 index 0000000000000..69795cc82c477 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AcceptsLatestSeenOffsetHandler.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
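With TryCast in place, describe()/summary() can report mean, stddev and percentiles for string columns: values that do not parse as doubles become null and simply drop out of the aggregate. Illustrative usage (assumes a SparkSession named spark with implicits imported):

    import spark.implicits._

    // "x" fails the try-cast, becomes null, and is ignored by mean/stddev/percentiles
    // instead of failing the whole summary.
    Seq("1", "2", "x").toDF("c").summary("mean", "stddev", "50%").show()
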
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, SparkDataStream} + +/** + * This feeds "latest seen offset" to the sources that implement AcceptsLatestSeenOffset. + */ +object AcceptsLatestSeenOffsetHandler { + def setLatestSeenOffsetOnSources( + offsets: Option[OffsetSeq], + sources: Seq[SparkDataStream]): Unit = { + assertNoAcceptsLatestSeenOffsetWithDataSourceV1(sources) + + offsets.map(_.toStreamProgress(sources)) match { + case Some(streamProgress) => + streamProgress.foreach { + case (src: AcceptsLatestSeenOffset, offset) => + src.setLatestSeenOffset(offset) + + case _ => // no-op + } + case _ => // no-op + } + } + + private def assertNoAcceptsLatestSeenOffsetWithDataSourceV1( + sources: Seq[SparkDataStream]): Unit = { + val unsupportedSources = sources + .filter(_.isInstanceOf[AcceptsLatestSeenOffset]) + .filter(_.isInstanceOf[Source]) + + if (unsupportedSources.nonEmpty) { + throw new UnsupportedOperationException( + "AcceptsLatestSeenOffset is not supported with DSv1 streaming source: " + + unsupportedSources) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index 6f43542fd6595..a5c1c735cbd7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -33,9 +33,9 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) - checkDisallowedOptions(parameters) + checkDisallowedOptions() - private def checkDisallowedOptions(options: Map[String, String]): Unit = { + private def checkDisallowedOptions(): Unit = { Seq(ModifiedBeforeFilter.PARAM_NAME, ModifiedAfterFilter.PARAM_NAME).foreach { param => if (parameters.contains(param)) { throw new IllegalArgumentException(s"option '$param' is not allowed in file stream sources") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index a00a62216f3dc..3ff539b9ef32b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical._ -import 
org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} +import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ @@ -93,8 +93,10 @@ case class FlatMapGroupsWithStateExec( * to have the same grouping so that the data are co-lacated on the same task. */ override def requiredChildDistribution: Seq[Distribution] = { - ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) :: + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, getStateInfo, conf) :: + StatefulOperatorPartitioning.getCompatibleDistribution( + initialStateGroupAttrs, getStateInfo, conf) :: Nil } @@ -167,12 +169,20 @@ case class FlatMapGroupsWithStateExec( timeoutProcessingStartTimeNs = System.nanoTime }) - val timeoutProcessorIter = - CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), { - // Note: `timeoutLatencyMs` also includes the time the parent operator took for - // processing output returned through iterator. - timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs) - }) + // SPARK-38320: Late-bind the timeout processing iterator so it is created *after* the input is + // processed (the input iterator is exhausted) and the state updates are written into the + // state store. Otherwise the iterator may not see the updates (e.g. with RocksDB state store). + val timeoutProcessorIter = new Iterator[InternalRow] { + private lazy val itr = getIterator() + override def hasNext = itr.hasNext + override def next() = itr.next() + private def getIterator(): Iterator[InternalRow] = + CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), { + // Note: `timeoutLatencyMs` also includes the time the parent operator took for + // processing output returned through iterator. 
+ timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs) + }) + } // Generate a iterator that returns the rows grouped by the grouping function // Note that this code ensures that the filtering for timeout occurs only after diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 3e772e104648b..9670c774a74c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, UpdatingSessionsExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode @@ -132,6 +133,22 @@ class IncrementalExecution( } override def apply(plan: SparkPlan): SparkPlan = plan transform { + // NOTE: we should include all aggregate execs here which are used in streaming aggregations + case a: SortAggregateExec if a.isStreaming => + a.copy(numShufflePartitions = Some(numStateStores)) + + case a: HashAggregateExec if a.isStreaming => + a.copy(numShufflePartitions = Some(numStateStores)) + + case a: ObjectHashAggregateExec if a.isStreaming => + a.copy(numShufflePartitions = Some(numStateStores)) + + case a: MergingSessionsExec if a.isStreaming => + a.copy(numShufflePartitions = Some(numStateStores)) + + case a: UpdatingSessionsExec if a.isStreaming => + a.copy(numShufflePartitions = Some(numStateStores)) + case StateStoreSaveExec(keys, None, None, None, stateFormatVersion, UnaryExecNode(agg, StateStoreRestoreExec(_, None, _, child))) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index e9e4be90a0449..3b409fa2f6a72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, LocalTimestamp} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, WriteToStream} import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE @@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Tabl import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, 
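The SPARK-38320 fix above defers building the timeout iterator until its first use, so the input iterator has been drained and the state updates are already in the store by the time processTimedOutState runs. The pattern in isolation (generic Scala, no Spark types):

    // Builds the underlying iterator lazily, on the first hasNext/next call.
    def lateBound[A](build: () => Iterator[A]): Iterator[A] = new Iterator[A] {
      private lazy val underlying = build()
      override def hasNext: Boolean = underlying.hasNext
      override def next(): A = underlying.next()
    }

    val buffer = scala.collection.mutable.ArrayBuffer.empty[Int]
    val late = lateBound(() => buffer.iterator)  // created up front, built later
    Iterator(1, 2, 3).foreach(buffer += _)       // "process the input" first
    late.toList                                  // List(1, 2, 3): sees the completed state
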
SparkDataStream, SupportsAdmissionControl, SupportsTriggerAvailableNow} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSource import org.apache.spark.sql.internal.SQLConf @@ -156,11 +157,16 @@ class MicroBatchExecution( // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. sink match { case s: SupportsWrite => - val (streamingWrite, customMetrics) = createStreamingWrite(s, extraOptions, _logicalPlan) val relationOpt = plan.catalogAndIdent.map { case (catalog, ident) => DataSourceV2Relation.create(s, Some(catalog), Some(ident)) } - WriteToMicroBatchDataSource(relationOpt, streamingWrite, _logicalPlan, customMetrics) + WriteToMicroBatchDataSource( + relationOpt, + table = s, + query = _logicalPlan, + queryId = id.toString, + extraOptions, + outputMode) case _ => _logicalPlan } @@ -212,6 +218,8 @@ class MicroBatchExecution( reportTimeTaken("triggerExecution") { // We'll do this initialization only once every start / restart if (currentBatchId < 0) { + AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources( + offsetLog.getLatest().map(_._2), sources) populateStartOffsets(sparkSessionForStream) logInfo(s"Stream started from $committedOffsets") } @@ -308,6 +316,12 @@ class MicroBatchExecution( * is the second latest batch id in the offset log. */ if (latestBatchId != 0) { val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse { + logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " + + s"which is required to restart the query from the latest batch $latestBatchId " + + "from the offset log. Please ensure there are two subsequent offset logs " + + "available for the latest batch via manually deleting the offset file(s). " + + "Please also ensure the latest batch for commit log is equal or one batch " + + "earlier than the latest batch for offset log.") throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist") } committedOffsets = secondLatestOffsets.toStreamProgress(sources) @@ -564,15 +578,23 @@ class MicroBatchExecution( // For v1 sources. 
case StreamingExecutionRelation(source, output) => newData.get(source).map { dataPlan => + val hasFileMetadata = output.exists { + case FileSourceMetadataAttribute(_) => true + case _ => false + } + val finalDataPlan = dataPlan match { + case l: LogicalRelation if hasFileMetadata => l.withMetadataColumns() + case _ => dataPlan + } val maxFields = SQLConf.get.maxToStringFields - assert(output.size == dataPlan.output.size, + assert(output.size == finalDataPlan.output.size, s"Invalid batch: ${truncatedString(output, ",", maxFields)} != " + - s"${truncatedString(dataPlan.output, ",", maxFields)}") + s"${truncatedString(finalDataPlan.output, ",", maxFields)}") - val aliases = output.zip(dataPlan.output).map { case (to, from) => + val aliases = output.zip(finalDataPlan.output).map { case (to, from) => Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) } - Project(aliases, dataPlan) + Project(aliases, finalDataPlan) }.getOrElse { LocalRelation(output, isStreaming = true) } @@ -607,7 +629,7 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan case _: SupportsWrite => - newAttributePlan.asInstanceOf[WriteToMicroBatchDataSource].createPlan(currentBatchId) + newAttributePlan.asInstanceOf[WriteToMicroBatchDataSource].withNewBatchId(currentBatchId) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index c08a14c65b772..913805d1a074d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -98,7 +98,7 @@ object OffsetSeqMetadata extends Logging { SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION, STREAMING_JOIN_STATE_FORMAT_VERSION, STATE_STORE_COMPRESSION_CODEC, - STATE_STORE_ROCKSDB_FORMAT_VERSION) + STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION) /** * Default values of relevant configurations that are used for backward compatibility. @@ -118,7 +118,8 @@ object OffsetSeqMetadata extends Logging { StreamingAggregationStateManager.legacyVersion.toString, STREAMING_JOIN_STATE_FORMAT_VERSION.key -> SymmetricHashJoinStateManager.legacyVersion.toString, - STATE_STORE_COMPRESSION_CODEC.key -> "lz4" + STATE_STORE_COMPRESSION_CODEC.key -> "lz4", + STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false" ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorPartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorPartitioning.scala new file mode 100644 index 0000000000000..527349201574e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorPartitioning.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, StatefulOpClusteredDistribution} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION + +/** + * This object is to provide clustered distribution for stateful operator with ensuring backward + * compatibility. Please read through the NOTE on the classdoc of + * [[StatefulOpClusteredDistribution]] before making any changes. Please refer SPARK-38204 + * for details. + * + * Do not use methods in this object for stateful operators which already uses + * [[StatefulOpClusteredDistribution]] as its required child distribution. + */ +object StatefulOperatorPartitioning { + + def getCompatibleDistribution( + expressions: Seq[Expression], + stateInfo: StatefulOperatorStateInfo, + conf: SQLConf): Distribution = { + getCompatibleDistribution(expressions, stateInfo.numPartitions, conf) + } + + def getCompatibleDistribution( + expressions: Seq[Expression], + numPartitions: Int, + conf: SQLConf): Distribution = { + if (conf.getConf(STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION)) { + StatefulOpClusteredDistribution(expressions, numPartitions) + } else { + ClusteredDistribution(expressions, requiredNumPartitions = Some(numPartitions)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index d1dfcdc514a10..f9ae65cdc47d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -37,10 +37,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} -import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream} -import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate} -import org.apache.spark.sql.connector.write.streaming.StreamingWrite +import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write} import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress import org.apache.spark.sql.internal.SQLConf @@ -289,6 +287,11 @@ abstract class StreamExecution( // Disable cost-based join optimization as we do not want stateful operations // to be rearranged sparkSessionForStream.conf.set(SQLConf.CBO_ENABLED.key, "false") + // Disable any config affecting the 
required child distribution of stateful operators. + // Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution for + // details. + sparkSessionForStream.conf.set(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key, + "false") updateStatusMessage("Initializing sources") // force initialization of the logical plan so that the sources can be created @@ -579,16 +582,16 @@ abstract class StreamExecution( |batch = $batchDescription""".stripMargin } - protected def createStreamingWrite( + protected def createWrite( table: SupportsWrite, options: Map[String, String], - inputPlan: LogicalPlan): (StreamingWrite, Seq[CustomMetric]) = { + inputPlan: LogicalPlan): Write = { val info = LogicalWriteInfoImpl( queryId = id.toString, inputPlan.schema, new CaseInsensitiveStringMap(options.asJava)) val writeBuilder = table.newWriteBuilder(info) - val write = outputMode match { + outputMode match { case Append => writeBuilder.build() @@ -603,8 +606,6 @@ abstract class StreamExecution( table.name + " does not support Update mode.") writeBuilder.asInstanceOf[SupportsStreamingUpdateAsAppend].build() } - - (write.toStreaming, write.supportedCustomMetrics().toSeq) } protected def purge(threshold: Long): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 5d4b811defeeb..00962a4f4cdf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -21,12 +21,12 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -43,7 +43,7 @@ object StreamingRelation { * passing to [[StreamExecution]] to run a query. 
*/ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute]) - extends LeafNode with MultiInstanceRelation { + extends LeafNode with MultiInstanceRelation with ExposesMetadataColumns { override def isStreaming: Boolean = true override def toString: String = sourceName @@ -56,6 +56,31 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: ) override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance())) + + override lazy val metadataOutput: Seq[AttributeReference] = { + dataSource.providingClass match { + // If the dataSource provided class is a same or subclass of FileFormat class + case f if classOf[FileFormat].isAssignableFrom(f) => + val resolve = conf.resolver + val outputNames = outputSet.map(_.name) + def isOutputColumn(col: AttributeReference): Boolean = { + outputNames.exists(name => resolve(col.name, name)) + } + // filter out the metadata struct column if it has the name conflicting with output columns. + // if the file has a column "_metadata", + // then the data column should be returned not the metadata struct column + Seq(FileFormat.createFileMetadataCol).filterNot(isOutputColumn) + case _ => Nil + } + } + + override def withMetadataColumns(): LogicalPlan = { + if (metadataOutput.nonEmpty) { + this.copy(output = output ++ metadataOutput) + } else { + this + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 74b82451e029f..81888e0f7e189 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -174,7 +174,13 @@ case class StreamingSymmetricHashJoinExec( joinType == Inner || joinType == LeftOuter || joinType == RightOuter || joinType == FullOuter || joinType == LeftSemi, errorMessageForJoinType) - require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType)) + + // The assertion against join keys is same as hash join for batch query. 
+ require(leftKeys.length == rightKeys.length && + leftKeys.map(_.dataType) + .zip(rightKeys.map(_.dataType)) + .forall(types => types._1.sameType(types._2)), + "Join keys from two sides should have same length and types") private val storeConf = new StateStoreConf(conf) private val hadoopConfBcast = sparkContext.broadcast( @@ -185,8 +191,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil + StatefulOpClusteredDistribution(leftKeys, getStateInfo.numPartitions) :: + StatefulOpClusteredDistribution(rightKeys, getStateInfo.numPartitions) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 5101bdf46ed1f..665ed77007bb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,8 +31,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, WriteToStream} import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability} +import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, PartitionOffset, ReadLimit} -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write} +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ @@ -85,11 +87,36 @@ class ContinuousExecution( uniqueSources = sources.distinct.map(s => s -> ReadLimit.allAvailable()).toMap // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. 
- val (streamingWrite, customMetrics) = createStreamingWrite( - plan.sink.asInstanceOf[SupportsWrite], extraOptions, _logicalPlan) + val write = createWrite(plan.sink.asInstanceOf[SupportsWrite], extraOptions, _logicalPlan) + + if (hasDistributionRequirements(write) || hasOrderingRequirements(write)) { + throw QueryCompilationErrors.writeDistributionAndOrderingNotSupportedInContinuousExecution() + } + + val streamingWrite = write.toStreaming + val customMetrics = write.supportedCustomMetrics.toSeq WriteToContinuousDataSource(streamingWrite, _logicalPlan, customMetrics) } + private def hasDistributionRequirements(write: Write): Boolean = write match { + case w: RequiresDistributionAndOrdering if w.requiredNumPartitions == 0 => + w.requiredDistribution match { + case _: UnspecifiedDistribution => + false + case _ => + true + } + case _ => + false + } + + private def hasOrderingRequirements(write: Write): Boolean = write match { + case w: RequiresDistributionAndOrdering if w.requiredOrdering.nonEmpty => + true + case _ => + false + } + private val triggerExecutor = trigger match { case ContinuousTrigger(t) => ProcessingTimeExecutor(ProcessingTimeTrigger(t), triggerClock) case _ => throw new IllegalStateException(s"Unsupported type of trigger: $trigger") @@ -130,7 +157,7 @@ class ContinuousExecution( * Start a new query log * DONE */ - private def getStartOffsets(sparkSessionToRunBatches: SparkSession): OffsetSeq = { + private def getStartOffsets(): OffsetSeq = { // Note that this will need a slight modification for exactly once. If ending offsets were // reported but not committed for any epochs, we must replay exactly to those offsets. // For at least once, we can just ignore those reports and risk duplicates. @@ -161,7 +188,11 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { - val offsets = getStartOffsets(sparkSessionForQuery) + val offsets = getStartOffsets() + + if (currentBatchId > 0) { + AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources(Some(offsets), sources) + } val withNewSources: LogicalPlan = logicalPlan transform { case relation: StreamingDataSourceV2Relation => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala index b8b85a7ded877..0a33093dcbcea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala @@ -19,27 +19,31 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} -import org.apache.spark.sql.connector.metric.CustomMetric -import org.apache.spark.sql.connector.write.streaming.StreamingWrite -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.connector.catalog.SupportsWrite +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.streaming.OutputMode /** * The logical plan for writing data to a micro-batch stream. 
* * Note that this logical plan does not have a corresponding physical plan, as it will be converted - * to [[WriteToDataSourceV2]] with [[MicroBatchWrite]] before execution. + * to [[org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 WriteToDataSourceV2]] + * with [[MicroBatchWrite]] before execution. */ case class WriteToMicroBatchDataSource( relation: Option[DataSourceV2Relation], - write: StreamingWrite, + table: SupportsWrite, query: LogicalPlan, - customMetrics: Seq[CustomMetric]) + queryId: String, + writeOptions: Map[String, String], + outputMode: OutputMode, + batchId: Option[Long] = None) extends UnaryNode { override def child: LogicalPlan = query override def output: Seq[Attribute] = Nil - def createPlan(batchId: Long): WriteToDataSourceV2 = { - WriteToDataSourceV2(relation, new MicroBatchWrite(batchId, write), query, customMetrics) + def withNewBatchId(batchId: Long): WriteToMicroBatchDataSource = { + copy(batchId = Some(batchId)) } override protected def withNewChildInternal(newChild: LogicalPlan): WriteToMicroBatchDataSource = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index ea25342cc8a1c..a5bd489e04fda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -370,6 +370,9 @@ class RocksDB( val totalSSTFilesBytes = getDBProperty("rocksdb.total-sst-files-size") val readerMemUsage = getDBProperty("rocksdb.estimate-table-readers-mem") val memTableMemUsage = getDBProperty("rocksdb.size-all-mem-tables") + val blockCacheUsage = getDBProperty("rocksdb.block-cache-usage") + // Get the approximate memory usage of this writeBatchWithIndex + val writeBatchMemUsage = writeBatch.getWriteBatch.getDataSize val nativeOpsHistograms = Seq( "get" -> DB_GET, "put" -> DB_WRITE, @@ -403,7 +406,8 @@ class RocksDB( RocksDBMetrics( numKeysOnLoadedVersion, numKeysOnWritingVersion, - readerMemUsage + memTableMemUsage, + readerMemUsage + memTableMemUsage + blockCacheUsage + writeBatchMemUsage, + writeBatchMemUsage, totalSSTFilesBytes, nativeOpsLatencyMicros.toMap, commitLatencyMs, @@ -616,7 +620,8 @@ object RocksDBConf { case class RocksDBMetrics( numCommittedKeys: Long, numUncommittedKeys: Long, - memUsageBytes: Long, + totalMemUsageBytes: Long, + writeBatchMemUsageBytes: Long, totalSSTFilesBytes: Long, nativeOpsHistograms: Map[String, RocksDBNativeHistogram], lastCommitLatencyMs: Map[String, Long], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c88e6ae3f477c..79614df629927 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -148,7 +148,7 @@ private[sql] class RocksDBStateStoreProvider StateStoreMetrics( rocksDBMetrics.numUncommittedKeys, - rocksDBMetrics.memUsageBytes, + rocksDBMetrics.totalMemUsageBytes, stateStoreCustomMetrics) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 
20625e10f321e..0c8cabb75ed65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -72,7 +72,7 @@ class StateSchemaCompatibilityChecker( } private def schemasCompatible(storedSchema: StructType, schema: StructType): Boolean = - DataType.equalsIgnoreNameAndCompatibleNullability(storedSchema, schema) + DataType.equalsIgnoreNameAndCompatibleNullability(schema, storedSchema) // Visible for testing private[sql] def readSchemaFile(): (StructType, StructType) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3431823765c1b..e367637671cc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ @@ -337,7 +337,8 @@ case class StateStoreRestoreExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + StatefulOperatorPartitioning.getCompatibleDistribution( + keyExpressions, getStateInfo, conf) :: Nil } } @@ -496,7 +497,8 @@ case class StateStoreSaveExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + StatefulOperatorPartitioning.getCompatibleDistribution( + keyExpressions, getStateInfo, conf) :: Nil } } @@ -527,6 +529,12 @@ case class SessionWindowStateStoreRestoreExec( child: SparkPlan) extends UnaryExecNode with StateStoreReader with WatermarkSupport { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numRowsDroppedByWatermark" -> SQLMetrics.createMetric(sparkContext, + "number of rows which are dropped by watermark") + ) + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions assert(keyExpressions.nonEmpty, "Grouping key must be specified when using sessionWindow") @@ -547,7 +555,11 @@ case class SessionWindowStateStoreRestoreExec( // We need to filter out outdated inputs val filteredIterator = watermarkPredicateForData match { - case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case Some(predicate) => iter.filter((row: InternalRow) => { + val shouldKeep = !predicate.eval(row) + if (!shouldKeep) longMetric("numRowsDroppedByWatermark") += 1 + shouldKeep + }) case None => iter } @@ -573,7 +585,8 @@ case class SessionWindowStateStoreRestoreExec( } override def requiredChildDistribution: Seq[Distribution] = { - ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil + 
StatefulOperatorPartitioning.getCompatibleDistribution( + keyWithoutSessionExpressions, getStateInfo, conf) :: Nil } override def requiredChildOrdering: Seq[Seq[SortOrder]] = { @@ -684,7 +697,8 @@ case class SessionWindowStateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + StatefulOperatorPartitioning.getCompatibleDistribution( + keyWithoutSessionExpressions, getStateInfo, conf) :: Nil } override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { @@ -741,8 +755,10 @@ case class StreamingDeduplicateExec( extends UnaryExecNode with StateStoreWriter with WatermarkSupport { /** Distribute by grouping attributes */ - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + override def requiredChildDistribution: Seq[Distribution] = { + StatefulOperatorPartitioning.getCompatibleDistribution( + keyExpressions, getStateInfo, conf) :: Nil + } override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 887867766ea92..afd0aba00680e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -46,10 +46,10 @@ object ExecSubqueryExpression { * Returns true when an expression contains a subquery */ def hasSubquery(e: Expression): Boolean = { - e.find { + e.exists { case _: ExecSubqueryExpression => true case _ => false - }.isDefined + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/StreamingQueryStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/StreamingQueryStatusStore.scala index 9eb14a6a63063..6a3b4eeb67275 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/StreamingQueryStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/StreamingQueryStatusStore.scala @@ -43,7 +43,7 @@ class StreamingQueryStatusStore(store: KVStore) { } private def makeUIData(summary: StreamingQueryData): StreamingQueryUIData = { - val runId = summary.runId.toString + val runId = summary.runId val view = store.view(classOf[StreamingQueryProgressWrapper]) .index("runId").first(runId).last(runId) val recentProgress = KVUtils.viewToSeq(view, Int.MaxValue)(_ => true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 374659e03a3fd..33c37e871e385 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.window import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} /** @@ -91,25 +90,6 @@ case class WindowExec( child: SparkPlan) extends WindowExecBase { - override def output: Seq[Attribute] = - child.output ++ windowExpression.map(_.toAttribute) - - 
override def requiredChildDistribution: Seq[Distribution] = { - if (partitionSpec.isEmpty) { - // Only show warning when the number of bytes is larger than 100 MiB? - logWarning("No Partition Defined for Window operation! Moving all data to a single " - + "partition, this can cause serious performance degradation.") - AllTuples :: Nil - } else ClusteredDistribution(partitionSpec) :: Nil - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputPartitioning: Partitioning = child.outputPartitioning - protected override def doExecute(): RDD[InternalRow] = { // Unwrap the window expressions and window frame factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala index f3b3b3494f2cc..5f1758d12fd5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala @@ -23,14 +23,39 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.UnaryExecNode import org.apache.spark.sql.types._ +/** + * Holds common logic for window operators + */ trait WindowExecBase extends UnaryExecNode { def windowExpression: Seq[NamedExpression] def partitionSpec: Seq[Expression] def orderSpec: Seq[SortOrder] + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MiB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else { + ClusteredDistribution(partitionSpec) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + /** * Create the resulting projection. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ec28d8dde38e3..58e855e2314ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1768,15 +1768,25 @@ object functions { def cbrt(columnName: String): Column = cbrt(Column(columnName)) /** - * Computes the ceiling of the given value. + * Computes the ceiling of the given value of `e` to `scale` decimal places. + * + * @group math_funcs + * @since 3.3.0 + */ + def ceil(e: Column, scale: Column): Column = withExpr { + UnresolvedFunction(Seq("ceil"), Seq(e.expr, scale.expr), isDistinct = false) + } + + /** + * Computes the ceiling of the given value of `e` to 0 decimal places. 
* * @group math_funcs * @since 1.4.0 */ - def ceil(e: Column): Column = withExpr { Ceil(e.expr) } + def ceil(e: Column): Column = ceil(e, lit(0)) /** - * Computes the ceiling of the given column. + * Computes the ceiling of the given value of `e` to 0 decimal places. * * @group math_funcs * @since 1.4.0 @@ -1888,15 +1898,25 @@ object functions { def factorial(e: Column): Column = withExpr { Factorial(e.expr) } /** - * Computes the floor of the given value. + * Computes the floor of the given value of `e` to `scale` decimal places. + * + * @group math_funcs + * @since 3.3.0 + */ + def floor(e: Column, scale: Column): Column = withExpr { + UnresolvedFunction(Seq("floor"), Seq(e.expr, scale.expr), isDistinct = false) + } + + /** + * Computes the floor of the given value of `e` to 0 decimal places. * * @group math_funcs * @since 1.4.0 */ - def floor(e: Column): Column = withExpr { Floor(e.expr) } + def floor(e: Column): Column = floor(e, lit(0)) /** - * Computes the floor of the given column. + * Computes the floor of the given column value to 0 decimal places. * * @group math_funcs * @since 1.4.0 @@ -2752,7 +2772,7 @@ object functions { * @since 3.3.0 */ def lpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr { - new BinaryLPad(str.expr, lit(len).expr, lit(pad).expr) + BinaryPad("lpad", str.expr, lit(len).expr, lit(pad).expr) } /** @@ -2841,7 +2861,7 @@ object functions { * @since 3.3.0 */ def rpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr { - new BinaryRPad(str.expr, lit(len).expr, lit(pad).expr) + BinaryPad("rpad", str.expr, lit(len).expr, lit(pad).expr) } /** @@ -3621,7 +3641,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time column must be of TimestampType. + * The time column must be of TimestampType or TimestampNTZType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for * valid duration identifiers. Note that the duration is a fixed length of @@ -3677,7 +3697,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time column must be of TimestampType. + * The time column must be of TimestampType or TimestampNTZType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for * valid duration identifiers. Note that the duration is a fixed length of @@ -3722,7 +3742,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time column must be of TimestampType. + * The time column must be of TimestampType or TimestampNTZType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for * valid duration identifiers. @@ -3750,7 +3770,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time column must be of TimestampType. + * The time column must be of TimestampType or TimestampNTZType. * @param gapDuration A string specifying the timeout of the session, e.g. `10 minutes`, * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for * valid duration identifiers. 
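Not part of the patch, but a minimal usage sketch of the two-argument `ceil`/`floor` overloads added to functions.scala in the hunks above; the local SparkSession, the sample values and the column name `x` are assumptions made purely for illustration, and it requires a Spark build containing these 3.3.0 APIs.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{ceil, floor, lit}

object CeilFloorScaleSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical local session used only to demonstrate the new overloads.
    val spark = SparkSession.builder().master("local[1]").appName("ceil-floor-scale").getOrCreate()
    import spark.implicits._

    val df = Seq(3.14159, -2.71828).toDF("x")

    // ceil(e) is now defined as ceil(e, lit(0)); the two-argument forms take the
    // number of decimal places as a Column, mirroring the SQL ceil/floor functions.
    df.select(
      ceil($"x").as("ceil_0"),
      ceil($"x", lit(2)).as("ceil_2"),
      floor($"x", lit(3)).as("floor_3")
    ).show()

    spark.stop()
  }
}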
@@ -3787,7 +3807,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time column must be of TimestampType. + * The time column must be of TimestampType or TimestampNTZType. * @param gapDuration A column specifying the timeout of the session. It could be static value, * e.g. `10 minutes`, `1 second`, or an expression/UDF that specifies gap * duration dynamically based on the input row. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 0b394db5c8932..6af5cc00ef5db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{SQLException, Types} import java.util.Locale +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { @@ -27,6 +30,37 @@ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2") + // See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARIANCE($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARIANCE_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"COVARIANCE(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"COVARIANCE_SAMP(${f.inputs().head}, ${f.inputs().last})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, @@ -52,6 +86,15 @@ private object DB2Dialect extends JdbcDialect { override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + // scalastyle:off line.size.limit + // See https://www.ibm.com/support/knowledgecenter/en/SSEPGG_11.5.0/com.ibm.db2.luw.sql.ref.doc/doc/r0053474.html + // scalastyle:on line.size.limit + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"TRUNCATE TABLE $table IMMEDIATE" + } + // scalastyle:off line.size.limit // See https://www.ibm.com/support/knowledgecenter/en/SSEPGG_11.5.0/com.ibm.db2.luw.sql.ref.doc/doc/r0000980.html // 
scalastyle:on line.size.limit @@ -79,4 +122,28 @@ private object DB2Dialect extends JdbcDialect { val nullable = if (isNullable) "DROP NOT NULL" else "SET NOT NULL" s"ALTER TABLE $tableName ALTER COLUMN ${quoteIdentifier(columnName)} $nullable" } + + override def removeSchemaCommentQuery(schema: String): String = { + s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS ''" + } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + e match { + case sqlException: SQLException => + sqlException.getSQLState match { + // https://www.ibm.com/docs/en/db2/11.5?topic=messages-sqlstate + case "42893" => throw NonEmptyNamespaceException(message, cause = Some(e)) + case _ => super.classifyException(message, e) + } + case _ => super.classifyException(message, e) + } + } + + override def dropSchema(schema: String, cascade: Boolean): String = { + if (cascade) { + s"DROP SCHEMA ${quoteIdentifier(schema)} CASCADE" + } else { + s"DROP SCHEMA ${quoteIdentifier(schema)} RESTRICT" + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index f19ef7ead5f8e..bf838b8ed66eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -29,6 +30,27 @@ private object DerbyDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby") + // See https://db.apache.org/derby/docs/10.15/ref/index.html + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"VAR_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"VAR_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"STDDEV_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"STDDEV_SAMP(${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.REAL) Option(FloatType) else None @@ -47,7 +69,7 @@ private object DerbyDialect extends JdbcDialect { override def isCascadingTruncateTable(): Option[Boolean] = Some(false) - // See https://db.apache.org/derby/docs/10.5/ref/rrefsqljrenametablestatement.html + // See https://db.apache.org/derby/docs/10.15/ref/rrefsqljrenametablestatement.html override def renameTable(oldTable: String, newTable: String): String = { s"RENAME TABLE $oldTable TO $newTable" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 1f422e5a59cf8..7bd51f809cd04 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -65,20 +65,22 @@ private object H2Dialect extends JdbcDialect { } override def classifyException(message: String, e: Throwable): AnalysisException = { - if (e.isInstanceOf[SQLException]) { - // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html - e.asInstanceOf[SQLException].getErrorCode match { - // TABLE_OR_VIEW_ALREADY_EXISTS_1 - case 42101 => - throw new TableAlreadyExistsException(message, cause = Some(e)) - // TABLE_OR_VIEW_NOT_FOUND_1 - case 42102 => - throw new NoSuchTableException(message, cause = Some(e)) - // SCHEMA_NOT_FOUND_1 - case 90079 => - throw new NoSuchNamespaceException(message, cause = Some(e)) - case _ => - } + e match { + case exception: SQLException => + // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html + exception.getErrorCode match { + // TABLE_OR_VIEW_ALREADY_EXISTS_1 + case 42101 => + throw new TableAlreadyExistsException(message, cause = Some(e)) + // TABLE_OR_VIEW_NOT_FOUND_1 + case 42102 => + throw new NoSuchTableException(message, cause = Some(e)) + // SCHEMA_NOT_FOUND_1 + case 90079 => + throw new NoSuchNamespaceException(message, cause = Some(e)) + case _ => // do nothing + } + case _ => // do nothing } super.classifyException(message, e) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 344842d30b232..c9dcbb2706cd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, Date, Timestamp} +import java.sql.{Connection, Date, Driver, Statement, Timestamp} import java.time.{Instant, LocalDate} import java.util @@ -32,10 +32,12 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, Timesta import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.index.TableIndex -import org.apache.spark.sql.connector.expressions.NamedReference -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -99,6 +101,29 @@ abstract class JdbcDialect extends Serializable with Logging{ */ def getJDBCType(dt: DataType): Option[JdbcType] = None + /** + * Returns a factory for creating connections to the given JDBC URL. + * In general, creating a connection has nothing to do with JDBC partition id.
+ * But sometimes it is needed, such as a database with multiple shard nodes. + * @param options - JDBC options that contains url, table and other information. + * @return The factory method for creating JDBC connections with the RDD partition ID. -1 means + the connection is being created at the driver side. + * @throws IllegalArgumentException if the driver could not open a JDBC connection. + */ + @Since("3.3.0") + def createConnectionFactory(options: JDBCOptions): Int => Connection = { + val driverClass: String = options.driverClass + (partitionId: Int) => { + DriverRegistry.register(driverClass) + val driver: Driver = DriverRegistry.get(driverClass) + val connection = + ConnectionProvider.create(driver, options.parameters, options.connectionProviderName) + require(connection != null, + s"The driver could not open a JDBC connection. Check the URL: ${options.url}") + connection + } + } + /** * Quotes the identifier. This is used to put quotes around the identifier in case the column * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space). @@ -194,6 +219,31 @@ abstract class JdbcDialect extends Serializable with Logging{ case _ => value } + class JDBCSQLBuilder extends V2ExpressionSQLBuilder { + override def visitFieldReference(fieldRef: FieldReference): String = { + if (fieldRef.fieldNames().length != 1) { + throw new IllegalArgumentException( + "FieldReference with field name has multiple or zero parts unsupported: " + fieldRef); + } + quoteIdentifier(fieldRef.fieldNames.head) + } + } + + /** + * Converts V2 expression to String representing a SQL expression. + * @param expr The V2 expression to be converted. + * @return Converted value. + */ + @Since("3.3.0") + def compileExpression(expr: Expression): Option[String] = { + val jdbcSQLBuilder = new JDBCSQLBuilder() + try { + Some(jdbcSQLBuilder.build(expr)) + } catch { + case _: IllegalArgumentException => None + } + } + /** * Converts aggregate function to String representing a SQL expression. * @param aggFunction The aggregate function to be converted. 
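Not part of the patch, but a minimal sketch of how the new `compileExpression` hook surfaces through `compileAggregate` (rewritten in the next hunk); the H2 URL, the column name `price`, and the expected output are assumptions based on the default `JdbcDialect` double-quote identifier quoting.

import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.connector.expressions.aggregate.Min
import org.apache.spark.sql.jdbc.JdbcDialects

object CompileAggregateSketch {
  def main(args: Array[String]): Unit = {
    // Any registered dialect works; the H2 URL here is just an example.
    val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")

    // The MIN column is rendered through compileExpression/JDBCSQLBuilder, which
    // quotes single-part field references; a multi-part reference makes
    // compileExpression return None, so the aggregate is simply not compiled.
    val compiled = dialect.compileAggregate(new Min(Expressions.column("price")))
    println(compiled) // expected: Some(MIN("price"))
  }
}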
@@ -203,31 +253,63 @@ abstract class JdbcDialect extends Serializable with Logging{ def compileAggregate(aggFunction: AggregateFunc): Option[String] = { aggFunction match { case min: Min => - if (min.column.fieldNames.length != 1) return None - Some(s"MIN(${quoteIdentifier(min.column.fieldNames.head)})") + compileExpression(min.column).map(v => s"MIN($v)") case max: Max => - if (max.column.fieldNames.length != 1) return None - Some(s"MAX(${quoteIdentifier(max.column.fieldNames.head)})") + compileExpression(max.column).map(v => s"MAX($v)") case count: Count => - if (count.column.fieldNames.length != 1) return None val distinct = if (count.isDistinct) "DISTINCT " else "" - val column = quoteIdentifier(count.column.fieldNames.head) - Some(s"COUNT($distinct$column)") + compileExpression(count.column).map(v => s"COUNT($distinct$v)") case sum: Sum => - if (sum.column.fieldNames.length != 1) return None val distinct = if (sum.isDistinct) "DISTINCT " else "" - val column = quoteIdentifier(sum.column.fieldNames.head) - Some(s"SUM($distinct$column)") + compileExpression(sum.column).map(v => s"SUM($distinct$v)") case _: CountStar => Some("COUNT(*)") - case f: GeneralAggregateFunc if f.name() == "AVG" => - assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"AVG($distinct${f.inputs().head})") + case avg: Avg => + val distinct = if (avg.isDistinct) "DISTINCT " else "" + compileExpression(avg.column).map(v => s"AVG($distinct$v)") case _ => None } } + /** + * Create schema with an optional comment. Empty string means no comment. + */ + def createSchema(statement: Statement, schema: String, comment: String): Unit = { + val schemaCommentQuery = if (comment.nonEmpty) { + // We generate comment query here so that it can fail earlier without creating the schema. + getSchemaCommentQuery(schema, comment) + } else { + comment + } + statement.executeUpdate(s"CREATE SCHEMA ${quoteIdentifier(schema)}") + if (comment.nonEmpty) { + statement.executeUpdate(schemaCommentQuery) + } + } + + /** + * Check schema exists or not. + */ + def schemasExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = { + val rs = conn.getMetaData.getSchemas(null, schema) + while (rs.next()) { + if (rs.getString(1) == schema) return true; + } + false + } + + /** + * Lists all the schemas in this table. + */ + def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = { + val schemaBuilder = ArrayBuilder.make[Array[String]] + val rs = conn.getMetaData.getSchemas() + while (rs.next()) { + schemaBuilder += Array(rs.getString(1)) + } + schemaBuilder.result + } + /** * Return Some[true] iff `TRUNCATE TABLE` causes cascading default. * Some[true] : TRUNCATE TABLE causes cascading. @@ -326,6 +408,14 @@ abstract class JdbcDialect extends Serializable with Logging{ s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS NULL" } + def dropSchema(schema: String, cascade: Boolean): String = { + if (cascade) { + s"DROP SCHEMA ${quoteIdentifier(schema)} CASCADE" + } else { + s"DROP SCHEMA ${quoteIdentifier(schema)}" + } + } + /** * Build a create index SQL statement. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 8e5674a181e7a..841f1c87319b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.jdbc +import java.sql.SQLException import java.util.Locale +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -36,6 +40,33 @@ private object MsSqlServerDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver") + // scalastyle:off line.size.limit + // See https://docs.microsoft.com/en-us/sql/t-sql/functions/aggregate-functions-transact-sql?view=sql-server-ver15 + // scalastyle:on line.size.limit + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDEVP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDEV($distinct${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (typeName.contains("datetimeoffset")) { @@ -122,4 +153,15 @@ private object MsSqlServerDialect extends JdbcDialect { override def getLimitClause(limit: Integer): String = { "" } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + e match { + case sqlException: SQLException => + sqlException.getErrorCode match { + case 3729 => throw NonEmptyNamespaceException(message, cause = Some(e)) + case _ => super.classifyException(message, e) + } + case _ => super.classifyException(message, e) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index fb98996e6bf8b..d73721de962d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -21,11 +21,14 @@ import java.sql.{Connection, SQLException, Types} import java.util import java.util.Locale +import scala.collection.mutable.ArrayBuilder + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} import 
org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder} @@ -35,6 +38,27 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url : String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql") + // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"VAR_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"VAR_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"STDDEV_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"STDDEV_SAMP(${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { @@ -51,6 +75,25 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { s"`$colName`" } + override def schemasExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = { + listSchemas(conn, options).exists(_.head == schema) + } + + override def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = { + val schemaBuilder = ArrayBuilder.make[Array[String]] + try { + JdbcUtils.executeQuery(conn, options, "SHOW SCHEMAS") { rs => + while (rs.next()) { + schemaBuilder += Array(rs.getString("Database")) + } + } + } catch { + case _: Exception => + logWarning("Cannot show schemas.") + } + schemaBuilder.result + } + override def getTableExistsQuery(table: String): String = { s"SELECT 1 FROM $table LIMIT 1" } @@ -109,6 +152,14 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { case _ => JdbcUtils.getCommonJDBCType(dt) } + override def getSchemaCommentQuery(schema: String, comment: String): String = { + throw QueryExecutionErrors.unsupportedCreateNamespaceCommentError() + } + + override def removeSchemaCommentQuery(schema: String): String = { + throw QueryExecutionErrors.unsupportedRemoveNamespaceCommentError() + } + // CREATE INDEX syntax // https://dev.mysql.com/doc/refman/8.0/en/create-index.html override def createIndex( @@ -150,26 +201,27 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { val sql = s"SHOW INDEXES FROM $tableName" var indexMap: Map[String, TableIndex] = Map() try { - val rs = JdbcUtils.executeQuery(conn, options, sql) - while (rs.next()) { - val indexName = rs.getString("key_name") - val colName = rs.getString("column_name") - val indexType = rs.getString("index_type") - val indexComment = rs.getString("Index_comment") - if (indexMap.contains(indexName)) { - 
val index = indexMap.get(indexName).get - val newIndex = new TableIndex(indexName, indexType, - index.columns() :+ FieldReference(colName), - index.columnProperties, index.properties) - indexMap += (indexName -> newIndex) - } else { - // The only property we are building here is `COMMENT` because it's the only one - // we can get from `SHOW INDEXES`. - val properties = new util.Properties(); - if (indexComment.nonEmpty) properties.put("COMMENT", indexComment) - val index = new TableIndex(indexName, indexType, Array(FieldReference(colName)), - new util.HashMap[NamedReference, util.Properties](), properties) - indexMap += (indexName -> index) + JdbcUtils.executeQuery(conn, options, sql) { rs => + while (rs.next()) { + val indexName = rs.getString("key_name") + val colName = rs.getString("column_name") + val indexType = rs.getString("index_type") + val indexComment = rs.getString("Index_comment") + if (indexMap.contains(indexName)) { + val index = indexMap.get(indexName).get + val newIndex = new TableIndex(indexName, indexType, + index.columns() :+ FieldReference(colName), + index.columnProperties, index.properties) + indexMap += (indexName -> newIndex) + } else { + // The only property we are building here is `COMMENT` because it's the only one + // we can get from `SHOW INDEXES`. + val properties = new util.Properties(); + if (indexComment.nonEmpty) properties.put("COMMENT", indexComment) + val index = new TableIndex(indexName, indexType, Array(FieldReference(colName)), + new util.HashMap[NamedReference, util.Properties](), properties) + indexMap += (indexName -> index) + } } } } catch { @@ -194,4 +246,12 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { case _ => super.classifyException(message, e) } } + + override def dropSchema(schema: String, cascade: Boolean): String = { + if (cascade) { + s"DROP SCHEMA ${quoteIdentifier(schema)}" + } else { + throw QueryExecutionErrors.unsupportedDropNamespaceRestrictError() + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index b741ece8dda9b..71db7e9285f5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp, Types} import java.util.{Locale, TimeZone} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -33,6 +34,38 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle") + // scalastyle:off line.size.limit + // https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848 + // scalastyle:on line.size.limit + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"VAR_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"VAR_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if 
f.name() == "STDDEV_POP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"STDDEV_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 1) + Some(s"STDDEV_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"COVAR_POP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"COVAR_SAMP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"CORR(${f.inputs().head}, ${f.inputs().last})") + case _ => None + } + ) + } + private def supportTimeZoneTypes: Boolean = { val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone) // TODO: support timezone types when users are not using the JVM timezone, which diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 356cb4ddbd008..e2023d110ae4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -23,8 +23,9 @@ import java.util.Locale import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} +import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException} import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.types._ @@ -35,6 +36,43 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql") + // See https://www.postgresql.org/docs/8.4/functions-aggregate.html + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_POP($distinct${f.inputs().head}, 
${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.REAL) { @@ -215,6 +253,7 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { // https://www.postgresql.org/docs/14/errcodes-appendix.html case "42P07" => throw new IndexAlreadyExistsException(message, cause = Some(e)) case "42704" => throw new NoSuchIndexException(message, cause = Some(e)) + case "2BP01" => throw NonEmptyNamespaceException(message, cause = Some(e)) case _ => super.classifyException(message, e) } case unsupported: UnsupportedOperationException => throw unsupported diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 13f4c5fe9c926..13e16d24d048d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ @@ -27,6 +28,42 @@ private case object TeradataDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata") + // scalastyle:off line.size.limit + // See https://docs.teradata.com/r/Teradata-VantageTM-SQL-Functions-Expressions-and-Predicates/March-2019/Aggregate-Functions + // scalastyle:on line.size.limit + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"COVAR_POP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"COVAR_SAMP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"CORR(${f.inputs().head}, ${f.inputs().last})") + 
case _ => None + } + ) + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index f72d03ecc62be..af058315f7caf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -290,7 +290,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * TODO (SPARK-33638): Full support of v2 table creation */ val tableSpec = TableSpec( - None, Map.empty[String, String], Some(source), Map.empty[String, String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index 6ca9aacab7247..fe187917ec021 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -61,7 +61,7 @@ class StreamingQueryStatus protected[sql]( } private[sql] def jsonValue: JValue = { - ("message" -> JString(message.toString)) ~ + ("message" -> JString(message)) ~ ("isDataAvailable" -> JBool(isDataAvailable)) ~ ("isTriggerActive" -> JBool(isTriggerActive)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala index 97691d9d7e827..e13ac4e487c95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming.ui import java.{util => ju} import java.lang.{Long => JLong} -import java.util.{Locale, UUID} +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.collection.JavaConverters._ @@ -59,7 +59,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val query = parent.store.allQueryUIData.find { uiData => - uiData.summary.runId.equals(UUID.fromString(parameterId)) + uiData.summary.runId.equals(parameterId) }.getOrElse(throw new IllegalArgumentException(s"Failed to find streaming query $parameterId")) val resources = generateLoadResources(request) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListener.scala index fdd3754344108..55ceab245a968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListener.scala @@ -64,7 +64,7 @@ private[sql] class StreamingQueryStatusListener( .take(numInactiveQueries - inactiveQueryStatusRetention) val runIds = toDelete.map { e => store.delete(e.getClass, e.runId) - e.runId.toString + e.runId } // Delete wrappers in one pass, as deleting them for each summary is slow store.removeAllByIndexValues(classOf[StreamingQueryProgressWrapper], "runId", runIds) @@ -75,7 
+75,7 @@ private[sql] class StreamingQueryStatusListener( store.write(new StreamingQueryData( event.name, event.id, - event.runId, + event.runId.toString, isActive = true, None, startTimestamp @@ -100,7 +100,7 @@ private[sql] class StreamingQueryStatusListener( override def onQueryTerminated( event: StreamingQueryListener.QueryTerminatedEvent): Unit = { - val querySummary = store.read(classOf[StreamingQueryData], event.runId) + val querySummary = store.read(classOf[StreamingQueryData], event.runId.toString) val curTime = System.currentTimeMillis() store.write(new StreamingQueryData( querySummary.name, @@ -118,7 +118,7 @@ private[sql] class StreamingQueryStatusListener( private[sql] class StreamingQueryData( val name: String, val id: UUID, - @KVIndexParam val runId: UUID, + @KVIndexParam val runId: String, @KVIndexParam("active") val isActive: Boolean, val exception: Option[String], @KVIndexParam("startTimestamp") val startTimestamp: Long, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index da7c62251b385..c0b4690dd6260 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -33,6 +33,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -318,6 +319,21 @@ public void testSampleBy() { Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13); } + @Test + public void testwithColumns() { + Dataset<Row> df = spark.table("testData2"); + Map<String, Column> colMaps = new HashMap<>(); + colMaps.put("a1", col("a")); + colMaps.put("b1", col("b")); + + StructType expected = df.withColumn("a1", col("a")).withColumn("b1", col("b")).schema(); + StructType actual = df.withColumns(colMaps).schema(); + // Validate getting the same result as chained withColumn calls + Assert.assertEquals(expected, actual); + // Validate the col names + Assert.assertArrayEquals(actual.fieldNames(), new String[] {"a", "b", "a1", "b1"}); + } + @Test public void testSampleByColumn() { Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 6ff53a84f328c..22978fb8c286e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -279,7 +279,7 @@ public void testMappingFunctionWithTestGroupState() throws Exception { Assert.assertTrue(prevState.isUpdated()); Assert.assertFalse(prevState.isRemoved()); Assert.assertTrue(prevState.exists()); - Assert.assertEquals(new Integer(9), prevState.get()); + Assert.assertEquals(Integer.valueOf(9), prevState.get()); Assert.assertEquals(0L, prevState.getCurrentProcessingTimeMs()); Assert.assertEquals(1000L, prevState.getCurrentWatermarkMs()); Assert.assertEquals(Optional.of(1500L), prevState.getTimeoutTimestampMs()); @@ -289,7 +289,7 @@ public void testMappingFunctionWithTestGroupState() throws Exception { Assert.assertTrue(prevState.isUpdated()); Assert.assertFalse(prevState.isRemoved()); Assert.assertTrue(prevState.exists()); - Assert.assertEquals(new Integer(18), prevState.get()); +
Assert.assertEquals(Integer.valueOf(18), prevState.get()); prevState = TestGroupState.create( Optional.of(9), GroupStateTimeout.EventTimeTimeout(), 0L, Optional.of(1000L), true); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index ca78d6489ef5c..37f49ce5705de 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -145,7 +145,6 @@ public void constructComplexRow() { doubleValue, stringValue, timestampValue, null); // Complex array - @SuppressWarnings("unchecked") List> arrayOfMaps = Arrays.asList(simpleMap); List arrayOfRows = Arrays.asList(simpleStruct); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java index 08dc129f27a0c..1da5fb4b64cbb 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ -43,7 +43,6 @@ public void tearDown() { spark = null; } - @SuppressWarnings("unchecked") @Test public void udf1Test() { spark.range(1, 10).toDF("value").createOrReplaceTempView("df"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 7e938ca88d8b9..cd64f858b1473 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -55,7 +55,6 @@ public void tearDown() { spark = null; } - @SuppressWarnings("unchecked") @Test public void udf1Test() { spark.udf().register("stringLengthTest", (String str) -> str.length(), DataTypes.IntegerType); @@ -64,7 +63,6 @@ public void udf1Test() { Assert.assertEquals(4, result.getInt(0)); } - @SuppressWarnings("unchecked") @Test public void udf2Test() { spark.udf().register("stringLengthTest", @@ -81,7 +79,6 @@ public Integer call(String str1, String str2) { } } - @SuppressWarnings("unchecked") @Test public void udf3Test() { spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(), @@ -95,7 +92,6 @@ public void udf3Test() { Assert.assertEquals(9, result.getInt(0)); } - @SuppressWarnings("unchecked") @Test public void udf4Test() { spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType); @@ -111,14 +107,12 @@ public void udf4Test() { Assert.assertEquals(55, sum); } - @SuppressWarnings("unchecked") @Test(expected = AnalysisException.class) public void udf5Test() { spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType); List results = spark.sql("SELECT inc(1, 5)").collectAsList(); } - @SuppressWarnings("unchecked") @Test public void udf6Test() { spark.udf().register("returnOne", () -> 1, DataTypes.IntegerType); @@ -126,7 +120,6 @@ public void udf6Test() { Assert.assertEquals(1, result.getInt(0)); } - @SuppressWarnings("unchecked") @Test public void udf7Test() { String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key()); @@ -142,7 +135,6 @@ public void udf7Test() { } } - @SuppressWarnings("unchecked") @Test public void sourceTest() { spark.udf().register("stringLengthTest", (String str) -> str.length(), DataTypes.IntegerType); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java index e5b9c7f5bafaa..75ef5275684d6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java @@ -66,7 +66,7 @@ public String description() { return "long_add"; } - private abstract static class JavaLongAddBase implements ScalarFunction<Long> { + public abstract static class JavaLongAddBase implements ScalarFunction<Long> { private final boolean isResultNullable; JavaLongAddBase(boolean isResultNullable) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java new file mode 100644 index 0000000000000..b315fafd8ece8 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.connector.catalog.functions; + +import java.util.Random; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.StructType; + +/** + * Test V2 function which adds a random number to the input integer.
+ */ +public class JavaRandomAdd implements UnboundFunction { + private final BoundFunction fn; + + public JavaRandomAdd(BoundFunction fn) { + this.fn = fn; + } + + @Override + public String name() { + return "rand"; + } + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.fields().length != 1) { + throw new UnsupportedOperationException("Expect exactly one argument"); + } + if (inputType.fields()[0].dataType() instanceof IntegerType) { + return fn; + } + throw new UnsupportedOperationException("Expect IntegerType"); + } + + @Override + public String description() { + return "rand_add: add a random integer to the input\n" + + "rand_add(int) -> int"; + } + + public abstract static class JavaRandomAddBase implements ScalarFunction { + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.IntegerType }; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String name() { + return "rand_add"; + } + + @Override + public boolean isDeterministic() { + return false; + } + } + + public static class JavaRandomAddDefault extends JavaRandomAddBase { + private final Random rand = new Random(); + + @Override + public Integer produceResult(InternalRow input) { + return input.getInt(0) + rand.nextInt(); + } + } + + public static class JavaRandomAddMagic extends JavaRandomAddBase { + private final Random rand = new Random(); + + public int invoke(int input) { + return input + rand.nextInt(); + } + } + + public static class JavaRandomAddStaticMagic extends JavaRandomAddBase { + private static final Random rand = new Random(); + + public static int invoke(int input) { + return input + rand.nextInt(); + } + } +} + diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java index 1b1689668e1f6..dade2a113ef45 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java @@ -49,7 +49,7 @@ public BoundFunction bind(StructType inputType) { return fn; } - throw new UnsupportedOperationException("Except StringType"); + throw new UnsupportedOperationException("Expect StringType"); } @Override diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory b/sql/core/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory index f436622b5fb42..246058e0bed70 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory +++ b/sql/core/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory @@ -1,3 +1,4 @@ +# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. @@ -5,12 +6,13 @@ # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +# test.org.apache.spark.sql.execution.datasources.orc.FakeKeyProvider$Factory diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider index b5b01a09e6995..0584b8c8b4f0d 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.YourExtensions diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider index afb48e1a3511f..bf8d78edef4ce 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.jdbc.JdbcConnectionProvider @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.execution.datasources.jdbc.connection.IntentionallyFaultyConnectionProvider \ No newline at end of file diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index dd22970203b3c..c1fc7234d7c19 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.sources.FakeSourceOne org.apache.spark.sql.sources.FakeSourceTwo org.apache.spark.sql.sources.FakeSourceThree diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 07e1d00ca545d..386dd1fe0ae17 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,6 +1,6 @@ ## Summary - - Number of queries: 375 + - Number of queries: 382 - Number of expressions that missing example: 12 - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint ## Schema of Built-in Functions @@ -11,8 +11,8 @@ | org.apache.spark.sql.catalyst.expressions.Acosh | acosh | SELECT acosh(1) | struct | | org.apache.spark.sql.catalyst.expressions.Add | + | SELECT 1 + 2 | struct<(1 + 2):int> | | org.apache.spark.sql.catalyst.expressions.AddMonths | add_months | SELECT add_months('2016-08-31', 1) | struct | -| org.apache.spark.sql.catalyst.expressions.AesDecrypt | aes_decrypt | SELECT aes_decrypt(unhex('83F16B2AA704794132802D248E6BFD4E380078182D1544813898AC97E709B28A94'), '0000111122223333') | struct | -| org.apache.spark.sql.catalyst.expressions.AesEncrypt | aes_encrypt | SELECT hex(aes_encrypt('Spark', '0000111122223333')) | struct | +| org.apache.spark.sql.catalyst.expressions.AesDecrypt | aes_decrypt | SELECT aes_decrypt(unhex('83F16B2AA704794132802D248E6BFD4E380078182D1544813898AC97E709B28A94'), '0000111122223333') | struct | +| org.apache.spark.sql.catalyst.expressions.AesEncrypt | aes_encrypt | SELECT hex(aes_encrypt('Spark', '0000111122223333')) | struct | | org.apache.spark.sql.catalyst.expressions.And | and | SELECT true and true | struct<(true AND true):boolean> | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | @@ -28,6 +28,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | array_repeat | SELECT array_repeat('123', 2) | struct> | +| org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct | | org.apache.spark.sql.catalyst.expressions.ArraySort | array_sort | SELECT array_sort(array(5, 6, 1), (left, right) -> case when left < right then -1 when left > right then 1 else 0 end) | struct namedlambdavariable()) THEN 1 ELSE 0 END, namedlambdavariable(), namedlambdavariable())):array> | | 
org.apache.spark.sql.catalyst.expressions.ArrayTransform | transform | SELECT transform(array(1, 2, 3), x -> x + 1) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayUnion | array_union | SELECT array_union(array(1, 2, 3), array(1, 3, 5)) | struct> | @@ -68,14 +69,14 @@ | org.apache.spark.sql.catalyst.expressions.Cast | timestamp | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Cast | tinyint | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Cbrt | cbrt | SELECT cbrt(27.0) | struct | -| org.apache.spark.sql.catalyst.expressions.Ceil | ceil | SELECT ceil(-0.1) | struct | -| org.apache.spark.sql.catalyst.expressions.Ceil | ceiling | SELECT ceiling(-0.1) | struct | +| org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$ | ceil | SELECT ceil(-0.1) | struct | +| org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$ | ceiling | SELECT ceiling(-0.1) | struct | | org.apache.spark.sql.catalyst.expressions.Chr | char | SELECT char(65) | struct | | org.apache.spark.sql.catalyst.expressions.Chr | chr | SELECT chr(65) | struct | | org.apache.spark.sql.catalyst.expressions.Coalesce | coalesce | SELECT coalesce(NULL, 1, NULL) | struct | | org.apache.spark.sql.catalyst.expressions.Concat | concat | SELECT concat('Spark', 'SQL') | struct | | org.apache.spark.sql.catalyst.expressions.ConcatWs | concat_ws | SELECT concat_ws(' ', 'Spark', 'SQL') | struct | -| org.apache.spark.sql.catalyst.expressions.Contains | contains | SELECT contains('Spark SQL', 'Spark') | struct | +| org.apache.spark.sql.catalyst.expressions.ContainsExpressionBuilder$ | contains | SELECT contains('Spark SQL', 'Spark') | struct | | org.apache.spark.sql.catalyst.expressions.Conv | conv | SELECT conv('100', 2, 10) | struct | | org.apache.spark.sql.catalyst.expressions.ConvertTimezone | convert_timezone | SELECT convert_timezone('Europe/Amsterdam', 'America/Los_Angeles', timestamp_ntz'2021-12-06 00:00:00') | struct | | org.apache.spark.sql.catalyst.expressions.Cos | cos | SELECT cos(0) | struct | @@ -99,7 +100,7 @@ | org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT datediff('2009-07-31', '2009-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.DateFormatClass | date_format | SELECT date_format('2016-04-08', 'y') | struct | | org.apache.spark.sql.catalyst.expressions.DateFromUnixDate | date_from_unix_date | SELECT date_from_unix_date(1) | struct | -| org.apache.spark.sql.catalyst.expressions.DatePart | date_part | SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456') | struct | +| org.apache.spark.sql.catalyst.expressions.DatePartExpressionBuilder$ | date_part | SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456') | struct | | org.apache.spark.sql.catalyst.expressions.DateSub | date_sub | SELECT date_sub('2016-07-30', 1) | struct | | org.apache.spark.sql.catalyst.expressions.DayOfMonth | day | SELECT day('2009-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.DayOfMonth | dayofmonth | SELECT dayofmonth('2009-07-30') | struct | @@ -111,7 +112,7 @@ | org.apache.spark.sql.catalyst.expressions.ElementAt | element_at | SELECT element_at(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.Elt | elt | SELECT elt(1, 'scala', 'java') | struct | | org.apache.spark.sql.catalyst.expressions.Encode | encode | SELECT encode('abc', 'utf-8') | struct | -| org.apache.spark.sql.catalyst.expressions.EndsWith | endswith | SELECT endswith('Spark SQL', 'SQL') | struct | +| 
org.apache.spark.sql.catalyst.expressions.EndsWithExpressionBuilder$ | endswith | SELECT endswith('Spark SQL', 'SQL') | struct | | org.apache.spark.sql.catalyst.expressions.EqualNullSafe | <=> | SELECT 2 <=> 2 | struct<(2 <=> 2):boolean> | | org.apache.spark.sql.catalyst.expressions.EqualTo | = | SELECT 2 = 2 | struct<(2 = 2):boolean> | | org.apache.spark.sql.catalyst.expressions.EqualTo | == | SELECT 2 == 2 | struct<(2 = 2):boolean> | @@ -124,7 +125,7 @@ | org.apache.spark.sql.catalyst.expressions.Factorial | factorial | SELECT factorial(5) | struct | | org.apache.spark.sql.catalyst.expressions.FindInSet | find_in_set | SELECT find_in_set('ab','abc,b,ab,c,def') | struct | | org.apache.spark.sql.catalyst.expressions.Flatten | flatten | SELECT flatten(array(array(1, 2), array(3, 4))) | struct> | -| org.apache.spark.sql.catalyst.expressions.Floor | floor | SELECT floor(-0.1) | struct | +| org.apache.spark.sql.catalyst.expressions.FloorExpressionBuilder$ | floor | SELECT floor(-0.1) | struct | | org.apache.spark.sql.catalyst.expressions.FormatNumber | format_number | SELECT format_number(12332.123456, 4) | struct | | org.apache.spark.sql.catalyst.expressions.FormatString | format_string | SELECT format_string("Hello World %d %s", 100, "days") | struct | | org.apache.spark.sql.catalyst.expressions.FormatString | printf | SELECT printf("Hello World %d %s", 100, "days") | struct | @@ -141,7 +142,6 @@ | org.apache.spark.sql.catalyst.expressions.Hypot | hypot | SELECT hypot(3, 4) | struct | | org.apache.spark.sql.catalyst.expressions.ILike | ilike | SELECT ilike('Spark', '_Park') | struct | | org.apache.spark.sql.catalyst.expressions.If | if | SELECT if(1 < 2, 'a', 'b') | struct<(IF((1 < 2), a, b)):string> | -| org.apache.spark.sql.catalyst.expressions.IfNull | ifnull | SELECT ifnull(NULL, array('2')) | struct> | | org.apache.spark.sql.catalyst.expressions.In | in | SELECT 1 in(1, 2, 3) | struct<(1 IN (1, 2, 3)):boolean> | | org.apache.spark.sql.catalyst.expressions.InitCap | initcap | SELECT initcap('sPark sql') | struct | | org.apache.spark.sql.catalyst.expressions.Inline | inline | SELECT inline(array(struct(1, 'a'), struct(2, 'b'))) | struct | @@ -182,8 +182,8 @@ | org.apache.spark.sql.catalyst.expressions.MakeDate | make_date | SELECT make_date(2013, 7, 15) | struct | | org.apache.spark.sql.catalyst.expressions.MakeInterval | make_interval | SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001) | struct | | org.apache.spark.sql.catalyst.expressions.MakeTimestamp | make_timestamp | SELECT make_timestamp(2014, 12, 28, 6, 30, 45.887) | struct | -| org.apache.spark.sql.catalyst.expressions.MakeTimestampLTZ | make_timestamp_ltz | SELECT make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887) | struct | -| org.apache.spark.sql.catalyst.expressions.MakeTimestampNTZ | make_timestamp_ntz | SELECT make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887) | struct | +| org.apache.spark.sql.catalyst.expressions.MakeTimestampLTZExpressionBuilder$ | make_timestamp_ltz | SELECT make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887) | struct | +| org.apache.spark.sql.catalyst.expressions.MakeTimestampNTZExpressionBuilder$ | make_timestamp_ntz | SELECT make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887) | struct | | org.apache.spark.sql.catalyst.expressions.MakeYMInterval | make_ym_interval | SELECT make_ym_interval(1, 2) | struct | | org.apache.spark.sql.catalyst.expressions.MapConcat | map_concat | SELECT map_concat(map(1, 'a', 2, 'b'), map(3, 'c')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapContainsKey | 
map_contains_key | SELECT map_contains_key(map(1, 'a', 2, 'b'), 1) | struct | @@ -211,6 +211,7 @@ | org.apache.spark.sql.catalyst.expressions.Now | now | SELECT now() | struct | | org.apache.spark.sql.catalyst.expressions.NthValue | nth_value | SELECT a, b, nth_value(b, 2) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | | org.apache.spark.sql.catalyst.expressions.NullIf | nullif | SELECT nullif(2, 2) | struct | +| org.apache.spark.sql.catalyst.expressions.Nvl | ifnull | SELECT ifnull(NULL, array('2')) | struct> | | org.apache.spark.sql.catalyst.expressions.Nvl | nvl | SELECT nvl(NULL, array('2')) | struct> | | org.apache.spark.sql.catalyst.expressions.Nvl2 | nvl2 | SELECT nvl2(NULL, 2, 1) | struct | | org.apache.spark.sql.catalyst.expressions.OctetLength | octet_length | SELECT octet_length('Spark SQL') | struct | @@ -218,8 +219,8 @@ | org.apache.spark.sql.catalyst.expressions.Overlay | overlay | SELECT overlay('Spark SQL' PLACING '_' FROM 6) | struct | | org.apache.spark.sql.catalyst.expressions.ParseToDate | to_date | SELECT to_date('2009-07-30 04:17:52') | struct | | org.apache.spark.sql.catalyst.expressions.ParseToTimestamp | to_timestamp | SELECT to_timestamp('2016-12-31 00:12:00') | struct | -| org.apache.spark.sql.catalyst.expressions.ParseToTimestampLTZ | to_timestamp_ltz | SELECT to_timestamp_ltz('2016-12-31 00:12:00') | struct | -| org.apache.spark.sql.catalyst.expressions.ParseToTimestampNTZ | to_timestamp_ntz | SELECT to_timestamp_ntz('2016-12-31 00:12:00') | struct | +| org.apache.spark.sql.catalyst.expressions.ParseToTimestampLTZExpressionBuilder$ | to_timestamp_ltz | SELECT to_timestamp_ltz('2016-12-31 00:12:00') | struct | +| org.apache.spark.sql.catalyst.expressions.ParseToTimestampNTZExpressionBuilder$ | to_timestamp_ntz | SELECT to_timestamp_ntz('2016-12-31 00:12:00') | struct | | org.apache.spark.sql.catalyst.expressions.ParseUrl | parse_url | SELECT parse_url('http://spark.apache.org/path?query=1', 'HOST') | struct | | org.apache.spark.sql.catalyst.expressions.PercentRank | percent_rank | SELECT a, b, percent_rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | | org.apache.spark.sql.catalyst.expressions.Pi | pi | SELECT pi() | struct | @@ -276,7 +277,7 @@ | org.apache.spark.sql.catalyst.expressions.SparkVersion | version | SELECT version() | struct | | org.apache.spark.sql.catalyst.expressions.Sqrt | sqrt | SELECT sqrt(4) | struct | | org.apache.spark.sql.catalyst.expressions.Stack | stack | SELECT stack(2, 1, 2, 3) | struct | -| org.apache.spark.sql.catalyst.expressions.StartsWith | startswith | SELECT startswith('Spark SQL', 'Spark') | struct | +| org.apache.spark.sql.catalyst.expressions.StartsWithExpressionBuilder$ | startswith | SELECT startswith('Spark SQL', 'Spark') | struct | | org.apache.spark.sql.catalyst.expressions.StringInstr | instr | SELECT instr('SparkSQL', 'SQL') | struct | | org.apache.spark.sql.catalyst.expressions.StringLocate | locate | SELECT locate('bar', 'foobarbar') | struct | | org.apache.spark.sql.catalyst.expressions.StringLocate | position | SELECT position('bar', 'foobarbar') | struct | @@ -299,7 +300,9 @@ | org.apache.spark.sql.catalyst.expressions.Tan | tan | SELECT tan(0) | struct | | org.apache.spark.sql.catalyst.expressions.Tanh | tanh | SELECT tanh(0) | struct | | org.apache.spark.sql.catalyst.expressions.TimeWindow | window | SELECT a, window.start, window.end, count(*) as cnt FROM VALUES ('A1', '2021-01-01 
00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, start | struct | +| org.apache.spark.sql.catalyst.expressions.ToBinary | to_binary | SELECT to_binary('abc', 'utf-8') | struct | | org.apache.spark.sql.catalyst.expressions.ToDegrees | degrees | SELECT degrees(3.141592653589793) | struct | +| org.apache.spark.sql.catalyst.expressions.ToNumber | to_number | SELECT to_number('454', '999') | struct | | org.apache.spark.sql.catalyst.expressions.ToRadians | radians | SELECT radians(180) | struct | | org.apache.spark.sql.catalyst.expressions.ToUTCTimestamp | to_utc_timestamp | SELECT to_utc_timestamp('2016-08-31', 'Asia/Seoul') | struct | | org.apache.spark.sql.catalyst.expressions.ToUnixTimestamp | to_unix_timestamp | SELECT to_unix_timestamp('2016-04-08', 'yyyy-MM-dd') | struct | @@ -310,6 +313,8 @@ | org.apache.spark.sql.catalyst.expressions.TryAdd | try_add | SELECT try_add(1, 2) | struct | | org.apache.spark.sql.catalyst.expressions.TryDivide | try_divide | SELECT try_divide(3, 2) | struct | | org.apache.spark.sql.catalyst.expressions.TryElementAt | try_element_at | SELECT try_element_at(array(1, 2, 3), 2) | struct | +| org.apache.spark.sql.catalyst.expressions.TryMultiply | try_multiply | SELECT try_multiply(2, 3) | struct | +| org.apache.spark.sql.catalyst.expressions.TrySubtract | try_subtract | SELECT try_subtract(2, 1) | struct | | org.apache.spark.sql.catalyst.expressions.TypeOf | typeof | SELECT typeof(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnBase64 | unbase64 | SELECT unbase64('U3BhcmsgU1FM') | struct | | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | @@ -352,7 +357,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.CovSample | covar_samp | SELECT covar_samp(c1, c2) FROM VALUES (1,1), (2,2), (3,3) AS tab(c1, c2) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.First | first | SELECT first(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.First | first_value | SELECT first_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | -| org.apache.spark.sql.catalyst.expressions.aggregate.HistogramNumeric | histogram_numeric | SELECT histogram_numeric(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct>> | +| org.apache.spark.sql.catalyst.expressions.aggregate.HistogramNumeric | histogram_numeric | SELECT histogram_numeric(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct>> | | org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus | approx_count_distinct | SELECT approx_count_distinct(col1) FROM VALUES (1), (1), (2), (2), (3) tab(col1) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Kurtosis | kurtosis | SELECT kurtosis(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | @@ -362,6 +367,8 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.Min | min | SELECT min(col) FROM VALUES (10), (-1), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.MinBy | min_by | SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Percentile | percentile | SELECT percentile(col, 0.3) FROM VALUES (0), (10) AS 
tab(col) | struct | +| org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgX | regr_avgx | SELECT regr_avgx(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct | +| org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgY | regr_avgy | SELECT regr_avgy(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.RegrCount | regr_count | SELECT regr_count(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Skewness | skewness | SELECT skewness(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.StddevPop | stddev_pop | SELECT stddev_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct | diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/datetime-parsing-invalid.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/datetime-parsing-invalid.sql new file mode 100644 index 0000000000000..70022f33337d4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/datetime-parsing-invalid.sql @@ -0,0 +1,2 @@ +--IMPORT datetime-parsing-invalid.sql + diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index f73b653659eb4..dfcf1742feb6f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -99,6 +99,17 @@ select element_at(array(1, 2, 3), 0); select elt(4, '123', '456'); select elt(0, '123', '456'); select elt(-1, '123', '456'); +select elt(null, '123', '456'); +select elt(null, '123', null); +select elt(1, '123', null); +select elt(2, '123', null); select array(1, 2, 3)[5]; select array(1, 2, 3)[-1]; + +-- array_size +select array_size(array()); +select array_size(array(true)); +select array_size(array(2, 1)); +select array_size(NULL); +select array_size(map('a', 1, 'b', 2)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/ceil-floor-with-scale-param.sql b/sql/core/src/test/resources/sql-tests/inputs/ceil-floor-with-scale-param.sql new file mode 100644 index 0000000000000..1baee30a8cf9a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/ceil-floor-with-scale-param.sql @@ -0,0 +1,27 @@ +-- Tests different scenarios of ceil and floor functions with scale parameters +SELECT CEIL(2.5, 0); +SELECT CEIL(3.5, 0); +SELECT CEIL(-2.5, 0); +SELECT CEIL(-3.5, 0); +SELECT CEIL(-0.35, 1); +SELECT CEIL(-35, -1); +SELECT CEIL(-0.1, 0); +SELECT CEIL(5, 0); +SELECT CEIL(3.14115, -3); +SELECT CEIL(2.5, null); +SELECT CEIL(2.5, 'a'); +SELECT CEIL(2.5, 0, 0); + +-- Same inputs with floor function +SELECT FLOOR(2.5, 0); +SELECT FLOOR(3.5, 0); +SELECT FLOOR(-2.5, 0); +SELECT FLOOR(-3.5, 0); +SELECT FLOOR(-0.35, 1); +SELECT FLOOR(-35, -1); +SELECT FLOOR(-0.1, 0); +SELECT FLOOR(5, 0); +SELECT FLOOR(3.14115, -3); +SELECT FLOOR(2.5, null); +SELECT FLOOR(2.5, 'a'); +SELECT FLOOR(2.5, 0, 0); \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/date.sql b/sql/core/src/test/resources/sql-tests/inputs/date.sql index 57049eb461325..ab57c7c754c67 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/date.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/date.sql @@ -140,3 +140,27 @@ select date '2012-01-01' - interval '2-2' year to month, select to_date('26/October/2015', 'dd/MMMMM/yyyy'); select from_json('{"d":"26/October/2015"}', 'd Date', 
map('dateFormat', 'dd/MMMMM/yyyy')); select from_csv('26/October/2015', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')); + +-- Add a number of units to a timestamp or a date +select dateadd(MICROSECOND, 1001, timestamp'2022-02-25 01:02:03.123'); +select dateadd(MILLISECOND, -1, timestamp'2022-02-25 01:02:03.456'); +select dateadd(SECOND, 58, timestamp'2022-02-25 01:02:03'); +select dateadd(MINUTE, -100, date'2022-02-25'); +select dateadd(HOUR, -1, timestamp'2022-02-25 01:02:03'); +select dateadd(DAY, 367, date'2022-02-25'); +select dateadd(WEEK, -4, timestamp'2022-02-25 01:02:03'); +select dateadd(MONTH, -1, timestamp'2022-02-25 01:02:03'); +select dateadd(QUARTER, 5, date'2022-02-25'); +select dateadd(YEAR, 1, date'2022-02-25'); + +-- Get the difference between timestamps or dates in the specified units +select datediff(MICROSECOND, timestamp'2022-02-25 01:02:03.123', timestamp'2022-02-25 01:02:03.124001'); +select datediff(MILLISECOND, timestamp'2022-02-25 01:02:03.456', timestamp'2022-02-25 01:02:03.455'); +select datediff(SECOND, timestamp'2022-02-25 01:02:03', timestamp'2022-02-25 01:03:01'); +select datediff(MINUTE, date'2022-02-25', timestamp'2022-02-24 22:20:00'); +select datediff(HOUR, timestamp'2022-02-25 01:02:03', timestamp'2022-02-25 00:02:03'); +select datediff(DAY, date'2022-02-25', timestamp'2023-02-27 00:00:00'); +select datediff(WEEK, timestamp'2022-02-25 01:02:03', timestamp'2022-01-28 01:02:03'); +select datediff(MONTH, timestamp'2022-02-25 01:02:03', timestamp'2022-01-25 01:02:03'); +select datediff(QUARTER, date'2022-02-25', date'2023-05-25'); +select datediff(YEAR, date'2022-02-25', date'2023-02-25'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime-parsing-invalid.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime-parsing-invalid.sql index a6d743cab5480..1d1e2a5282c81 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime-parsing-invalid.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime-parsing-invalid.sql @@ -14,7 +14,8 @@ select to_timestamp('366', 'D'); select to_timestamp('9', 'DD'); -- in java 8 this case is invalid, but valid in java 11, disabled for jenkins -- select to_timestamp('100', 'DD'); -select to_timestamp('366', 'DD'); +-- The error message is changed since Java 11+ +-- select to_timestamp('366', 'DD'); select to_timestamp('9', 'DDD'); select to_timestamp('99', 'DDD'); select to_timestamp('30-365', 'dd-DDD'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index cb82bfa310122..ef3d523a23d84 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -199,6 +199,7 @@ FROM testData GROUP BY a IS NULL; +-- Histogram aggregates with different numeric input types SELECT histogram_numeric(col, 2) as histogram_2, histogram_numeric(col, 3) as histogram_3, @@ -210,6 +211,32 @@ FROM VALUES (21), (22), (23), (24), (25), (26), (27), (28), (29), (30), (31), (32), (33), (34), (35), (3), (37), (38), (39), (40), (41), (42), (43), (44), (45), (46), (47), (48), (49), (50) AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES (1), (2), (3) AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES (1L), (2L), (3L) AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES (1F), (2F), (3F) AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES (1D), (2D), (3D) AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES (1S), (2S), (3S) AS tab(col); 
+SELECT histogram_numeric(col, 3) FROM VALUES + (CAST(1 AS BYTE)), (CAST(2 AS BYTE)), (CAST(3 AS BYTE)) AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES + (CAST(1 AS TINYINT)), (CAST(2 AS TINYINT)), (CAST(3 AS TINYINT)) AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES + (CAST(1 AS SMALLINT)), (CAST(2 AS SMALLINT)), (CAST(3 AS SMALLINT)) AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES + (CAST(1 AS BIGINT)), (CAST(2 AS BIGINT)), (CAST(3 AS BIGINT)) AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES (TIMESTAMP '2017-03-01 00:00:00'), + (TIMESTAMP '2017-04-01 00:00:00'), (TIMESTAMP '2017-05-01 00:00:00') AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES (INTERVAL '100-00' YEAR TO MONTH), + (INTERVAL '110-00' YEAR TO MONTH), (INTERVAL '120-00' YEAR TO MONTH) AS tab(col); +SELECT histogram_numeric(col, 3) FROM VALUES (INTERVAL '12 20:4:0' DAY TO SECOND), + (INTERVAL '12 21:4:0' DAY TO SECOND), (INTERVAL '12 22:4:0' DAY TO SECOND) AS tab(col); +SELECT histogram_numeric(col, 3) +FROM VALUES (NULL), (NULL), (NULL) AS tab(col); +SELECT histogram_numeric(col, 3) +FROM VALUES (CAST(NULL AS DOUBLE)), (CAST(NULL AS DOUBLE)), (CAST(NULL AS DOUBLE)) AS tab(col); +SELECT histogram_numeric(col, 3) +FROM VALUES (CAST(NULL AS INT)), (CAST(NULL AS INT)), (CAST(NULL AS INT)) AS tab(col); + -- SPARK-37613: Support ANSI Aggregate Function: regr_count SELECT regr_count(y, x) FROM testRegression; @@ -231,6 +258,12 @@ FROM VALUES (1,4),(2,3),(1,4),(2,4) AS v(a,b) GROUP BY a; +-- SPARK-37614: Support ANSI Aggregate Function: regr_avgx & regr_avgy +SELECT regr_avgx(y, x), regr_avgy(y, x) FROM testRegression; +SELECT regr_avgx(y, x), regr_avgy(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL; +SELECT k, avg(x), avg(y), regr_avgx(y, x), regr_avgy(y, x) FROM testRegression GROUP BY k; +SELECT k, avg(x) FILTER (WHERE x IS NOT NULL AND y IS NOT NULL), avg(y) FILTER (WHERE x IS NOT NULL AND y IS NOT NULL), regr_avgx(y, x), regr_avgy(y, x) FROM testRegression GROUP BY k; + -- SPARK-37676: Support ANSI Aggregation Function: percentile_cont SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v), diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql index 4cc24c00435cc..d08037268c95c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql @@ -84,7 +84,7 @@ SELECT regr_count(b, a) FROM aggtest; -- SELECT regr_sxx(b, a) FROM aggtest; -- SELECT regr_syy(b, a) FROM aggtest; -- SELECT regr_sxy(b, a) FROM aggtest; --- SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest; +SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest; -- SELECT regr_r2(b, a) FROM aggtest; -- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest; SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql index 53f2aa41ae3fa..14a89d526b512 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql @@ -895,22 +895,22 @@ DROP TABLE width_bucket_test; -- TO_NUMBER() -- -- SET lc_numeric = 'C'; --- SELECT '' AS to_number_1, to_number('-34,338,492', '99G999G999'); --- SELECT '' AS to_number_2, to_number('-34,338,492.654,878', 
'99G999G999D999G999'); +SELECT '' AS to_number_1, to_number('-34,338,492', '99G999G999'); +SELECT '' AS to_number_2, to_number('-34,338,492.654,878', '99G999G999D999G999'); -- SELECT '' AS to_number_3, to_number('<564646.654564>', '999999.999999PR'); --- SELECT '' AS to_number_4, to_number('0.00001-', '9.999999S'); +SELECT '' AS to_number_4, to_number('0.00001-', '9.999999S'); -- SELECT '' AS to_number_5, to_number('5.01-', 'FM9.999999S'); -- SELECT '' AS to_number_5, to_number('5.01-', 'FM9.999999MI'); -- SELECT '' AS to_number_7, to_number('5 4 4 4 4 8 . 7 8', '9 9 9 9 9 9 . 9 9'); -- SELECT '' AS to_number_8, to_number('.01', 'FM9.99'); --- SELECT '' AS to_number_9, to_number('.0', '99999999.99999999'); --- SELECT '' AS to_number_10, to_number('0', '99.99'); +SELECT '' AS to_number_9, to_number('.0', '99999999.99999999'); +SELECT '' AS to_number_10, to_number('0', '99.99'); -- SELECT '' AS to_number_11, to_number('.-01', 'S99.99'); --- SELECT '' AS to_number_12, to_number('.01-', '99.99S'); +SELECT '' AS to_number_12, to_number('.01-', '99.99S'); -- SELECT '' AS to_number_13, to_number(' . 0 1-', ' 9 9 . 9 9 S'); --- SELECT '' AS to_number_14, to_number('34,50','999,99'); --- SELECT '' AS to_number_15, to_number('123,000','999G'); --- SELECT '' AS to_number_16, to_number('123456','999G999'); +SELECT '' AS to_number_14, to_number('34,50','999,99'); +SELECT '' AS to_number_15, to_number('123,000','999G'); +SELECT '' AS to_number_16, to_number('123456','999G999'); -- SELECT '' AS to_number_17, to_number('$1234.56','L9,999.99'); -- SELECT '' AS to_number_18, to_number('$1234.56','L99,999.99'); -- SELECT '' AS to_number_19, to_number('$1,234.56','L99,999.99'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 4b5f1204b15e9..e7c01a69bc838 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -124,4 +124,53 @@ SELECT endswith('Spark SQL', 'QL'); SELECT endswith('Spark SQL', 'Spa'); SELECT endswith(null, 'Spark'); SELECT endswith('Spark', null); -SELECT endswith(null, null); \ No newline at end of file +SELECT endswith(null, null); + +SELECT contains(x'537061726b2053514c', x'537061726b'); +SELECT contains(x'', x''); +SELECT contains(x'537061726b2053514c', null); +SELECT contains(12, '1'); +SELECT contains(true, 'ru'); +SELECT contains(x'12', 12); +SELECT contains(true, false); + +SELECT startswith(x'537061726b2053514c', x'537061726b'); +SELECT startswith(x'537061726b2053514c', x''); +SELECT startswith(x'', x''); +SELECT startswith(x'537061726b2053514c', null); + +SELECT endswith(x'537061726b2053514c', x'53516c'); +SELECT endsWith(x'537061726b2053514c', x'537061726b'); +SELECT endsWith(x'537061726b2053514c', x''); +SELECT endsWith(x'', x''); +SELECT endsWith(x'537061726b2053514c', null); + +-- to_number +select to_number('454', '000'); +select to_number('454.2', '000.0'); +select to_number('12,454', '00,000'); +select to_number('$78.12', '$00.00'); +select to_number('-454', '-000'); +select to_number('-454', 'S000'); +select to_number('12,454.8-', '00,000.9-'); +select to_number('00,454.8-', '00,000.9-'); + +-- to_binary +select to_binary('abc'); +select to_binary('abc', 'utf-8'); +select to_binary('abc', 'base64'); +select to_binary('abc', 'hex'); +-- 'format' parameter can be any foldable string value, not just literal. 
+select to_binary('abc', concat('utf', '-8')); +-- 'format' parameter is case insensitive. +select to_binary('abc', 'Hex'); +-- null inputs lead to null result. +select to_binary('abc', null); +select to_binary(null, 'utf-8'); +select to_binary(null, null); +select to_binary(null, cast(null as string)); +-- 'format' parameter must be string type or void type. +select to_binary(null, cast(null as int)); +select to_binary('abc', 1); +-- invalid inputs. +select to_binary('abc', 'invalidFormat'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/timestamp-ntz.sql b/sql/core/src/test/resources/sql-tests/inputs/timestamp-ntz.sql index 14266db65a971..b7dc2872e50d3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/timestamp-ntz.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/timestamp-ntz.sql @@ -17,3 +17,9 @@ SELECT make_timestamp_ntz(2021, 07, 11, 6, 30, 45.678, 'CET'); SELECT make_timestamp_ntz(2021, 07, 11, 6, 30, 60.007); SELECT convert_timezone('Europe/Moscow', 'America/Los_Angeles', timestamp_ntz'2022-01-01 00:00:00'); + +-- Get the difference between timestamps w/o time zone in the specified units +select timestampdiff(QUARTER, timestamp_ntz'2022-01-01 01:02:03', timestamp_ntz'2022-05-02 05:06:07'); +select timestampdiff(HOUR, timestamp_ntz'2022-02-14 01:02:03', timestamp_ltz'2022-02-14 02:03:04'); +select timestampdiff(YEAR, date'2022-02-15', timestamp_ntz'2023-02-15 10:11:12'); +select timestampdiff(MILLISECOND, timestamp_ntz'2022-02-14 23:59:59.123', date'2022-02-15'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/timestamp.sql b/sql/core/src/test/resources/sql-tests/inputs/timestamp.sql index 0bc77a8c971f8..21d27e98ab440 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/timestamp.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/timestamp.sql @@ -142,3 +142,15 @@ select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE'); select unix_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE'); select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')); select from_csv('26/October/2015', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')); + +-- Add a number of units to a timestamp or a date +select timestampadd(MONTH, -1, timestamp'2022-02-14 01:02:03'); +select timestampadd(MINUTE, 58, timestamp'2022-02-14 01:02:03'); +select timestampadd(YEAR, 1, date'2022-02-15'); +select timestampadd(SECOND, -1, date'2022-02-15'); + +-- Get the difference between timestamps in the specified units +select timestampdiff(MONTH, timestamp'2022-02-14 01:02:03', timestamp'2022-01-14 01:02:03'); +select timestampdiff(MINUTE, timestamp'2022-02-14 01:02:03', timestamp'2022-02-14 02:00:03'); +select timestampdiff(YEAR, date'2022-02-15', date'2023-02-15'); +select timestampdiff(SECOND, date'2022-02-15', timestamp'2022-02-14 23:59:59'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/timestampNTZ/timestamp.sql b/sql/core/src/test/resources/sql-tests/inputs/timestampNTZ/timestamp.sql index 79193c900d046..47988ee65fb7c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/timestampNTZ/timestamp.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/timestampNTZ/timestamp.sql @@ -1 +1,2 @@ +--SET spark.sql.ansi.enabled = false --IMPORT timestamp.sql diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql index 5962a5d55bb89..586680f550761 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql +++ 
b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql @@ -40,3 +40,31 @@ SELECT try_divide(interval 2 year, 0); SELECT try_divide(interval 2 second, 0); SELECT try_divide(interval 2147483647 month, 0.5); SELECT try_divide(interval 106751991 day, 0.5); + +-- Numeric - Numeric +SELECT try_subtract(1, 1); +SELECT try_subtract(2147483647, -1); +SELECT try_subtract(-2147483648, 1); +SELECT try_subtract(9223372036854775807L, -1); +SELECT try_subtract(-9223372036854775808L, 1); + +-- Interval - Interval +SELECT try_subtract(interval 2 year, interval 3 year); +SELECT try_subtract(interval 3 second, interval 2 second); +SELECT try_subtract(interval 2147483647 month, interval -2 month); +SELECT try_subtract(interval 106751991 day, interval -3 day); + +-- Numeric * Numeric +SELECT try_multiply(2, 3); +SELECT try_multiply(2147483647, -2); +SELECT try_multiply(-2147483648, 2); +SELECT try_multiply(9223372036854775807L, 2); +SELECT try_multiply(-9223372036854775808L, -2); + +-- Interval * Numeric +SELECT try_multiply(interval 2 year, 2); +SELECT try_multiply(interval 2 second, 2); +SELECT try_multiply(interval 2 year, 0); +SELECT try_multiply(interval 2 second, 0); +SELECT try_multiply(interval 2147483647 month, 2); +SELECT try_multiply(interval 106751991 day, 2); diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql index 2b00815bba2e3..6e2ffae48a6ca 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql @@ -84,7 +84,7 @@ SELECT regr_count(b, a) FROM aggtest; -- SELECT regr_sxx(b, a) FROM aggtest; -- SELECT regr_syy(b, a) FROM aggtest; -- SELECT regr_sxy(b, a) FROM aggtest; --- SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest; +SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest; -- SELECT regr_r2(b, a) FROM aggtest; -- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest; SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out index b412493b60aa3..00ac2eeba7ffd 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 29 +-- Number of queries: 38 -- !query @@ -216,6 +216,38 @@ org.apache.spark.SparkArrayIndexOutOfBoundsException Invalid index: -1, numElements: 2. If necessary set spark.sql.ansi.enabled to false to bypass this error. +-- !query +select elt(null, '123', '456') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select elt(null, '123', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select elt(1, '123', null) +-- !query schema +struct +-- !query output +123 + + +-- !query +select elt(2, '123', null) +-- !query schema +struct +-- !query output +NULL + + -- !query select array(1, 2, 3)[5] -- !query schema @@ -234,6 +266,47 @@ org.apache.spark.SparkArrayIndexOutOfBoundsException Invalid index: -1, numElements: 3. If necessary set spark.sql.ansi.strictIndexOperator to false to bypass this error. 
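(Aside, not part of the patch: the array_size golden results that follow exercise the new function on arrays, NULL, and a map. A minimal, hedged usage sketch outside the golden files; the VALUES data and alias names are illustrative only.)

-- Hypothetical data; keeps rows whose array column has more than one element.
-- array_size(NULL) yields NULL, so the NULL row is filtered out as well.
SELECT id, tags
FROM VALUES (1, array('a', 'b')), (2, array('c')), (3, CAST(NULL AS ARRAY<STRING>)) AS t(id, tags)
WHERE array_size(tags) > 1;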
+-- !query +select array_size(array()) +-- !query schema +struct +-- !query output +0 + + +-- !query +select array_size(array(true)) +-- !query schema +struct +-- !query output +1 + + +-- !query +select array_size(array(2, 1)) +-- !query schema +struct +-- !query output +2 + + +-- !query +select array_size(NULL) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select array_size(map('a', 1, 'b', 2)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'array_size(map('a', 1, 'b', 2))' due to data type mismatch: argument 1 requires array type, however, 'map('a', 1, 'b', 2)' is of map type.; line 1 pos 7 + + -- !query set spark.sql.ansi.strictIndexOperator=false -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out index 75cc31856d56c..437b56e2ffa3e 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 77 +-- Number of queries: 97 -- !query @@ -470,39 +470,33 @@ struct -- !query select date_add('2011-11-11', int_str) from date_view -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_add(CAST('2011-11-11' AS DATE), date_view.int_str)' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, 'date_view.int_str' is of string type.; line 1 pos 7 +2011-11-12 -- !query select date_sub('2011-11-11', int_str) from date_view -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_sub(CAST('2011-11-11' AS DATE), date_view.int_str)' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, 'date_view.int_str' is of string type.; line 1 pos 7 +2011-11-10 -- !query select date_add(date_str, 1) from date_view -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_add(date_view.date_str, 1)' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type. -To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 +2011-11-12 -- !query select date_sub(date_str, 1) from date_view -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_sub(date_view.date_str, 1)' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type. -To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 +2011-11-10 -- !query @@ -581,20 +575,17 @@ NULL -- !query select date_str - date '2001-09-28' from date_view -- !query schema -struct<> +struct<(date_str - DATE '2001-09-28'):interval day> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve '(date_view.date_str - DATE '2001-09-28')' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type. -To fix the error, you might need to add explicit type casts. 
If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 +3696 00:00:00.000000000 -- !query select date '2001-09-28' - date_str from date_view -- !query schema -struct<> +struct<(DATE '2001-09-28' - date_str):interval day> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve '(DATE '2001-09-28' - date_view.date_str)' due to data type mismatch: differing types in '(DATE '2001-09-28' - date_view.date_str)' (date and string).; line 1 pos 7 +-3696 00:00:00.000000000 -- !query @@ -650,7 +641,7 @@ select to_date('26/October/2015', 'dd/MMMMM/yyyy') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -659,7 +650,7 @@ select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMM struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -668,4 +659,164 @@ select from_csv('26/October/2015', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select dateadd(MICROSECOND, 1001, timestamp'2022-02-25 01:02:03.123') +-- !query schema +struct +-- !query output +2022-02-25 01:02:03.124001 + + +-- !query +select dateadd(MILLISECOND, -1, timestamp'2022-02-25 01:02:03.456') +-- !query schema +struct +-- !query output +2022-02-25 01:02:03.455 + + +-- !query +select dateadd(SECOND, 58, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-02-25 01:03:01 + + +-- !query +select dateadd(MINUTE, -100, date'2022-02-25') +-- !query schema +struct +-- !query output +2022-02-24 22:20:00 + + +-- !query +select dateadd(HOUR, -1, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-02-25 00:02:03 + + +-- !query +select dateadd(DAY, 367, date'2022-02-25') +-- !query schema +struct +-- !query output +2023-02-27 00:00:00 + + +-- !query +select dateadd(WEEK, -4, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-01-28 01:02:03 + + +-- !query +select dateadd(MONTH, -1, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-01-25 01:02:03 + + +-- !query +select dateadd(QUARTER, 5, date'2022-02-25') +-- !query schema +struct +-- !query output +2023-05-25 00:00:00 + + +-- !query +select dateadd(YEAR, 1, date'2022-02-25') +-- !query schema +struct +-- !query output +2023-02-25 00:00:00 + + +-- !query +select datediff(MICROSECOND, timestamp'2022-02-25 01:02:03.123', timestamp'2022-02-25 01:02:03.124001') +-- !query schema +struct +-- !query output +1001 + + +-- !query +select datediff(MILLISECOND, timestamp'2022-02-25 01:02:03.456', timestamp'2022-02-25 01:02:03.455') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select datediff(SECOND, timestamp'2022-02-25 01:02:03', timestamp'2022-02-25 01:03:01') +-- !query schema +struct +-- !query output +58 + + +-- !query +select datediff(MINUTE, date'2022-02-25', timestamp'2022-02-24 22:20:00') +-- !query schema +struct +-- !query output +-100 + + +-- !query +select datediff(HOUR, timestamp'2022-02-25 01:02:03', timestamp'2022-02-25 00:02:03') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select datediff(DAY, date'2022-02-25', timestamp'2023-02-27 00:00:00') +-- !query schema +struct +-- !query output +367 + + +-- !query +select datediff(WEEK, timestamp'2022-02-25 01:02:03', timestamp'2022-01-28 01:02:03') +-- !query schema +struct +-- !query output +-4 + + +-- !query +select datediff(MONTH, timestamp'2022-02-25 01:02:03', timestamp'2022-01-25 01:02:03') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select datediff(QUARTER, date'2022-02-25', date'2023-05-25') +-- !query schema +struct +-- !query output +5 + + +-- !query +select datediff(YEAR, date'2022-02-25', date'2023-02-25') +-- !query schema +struct +-- !query output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/datetime-parsing-invalid.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/datetime-parsing-invalid.sql.out new file mode 100644 index 0000000000000..59761d5ac53f0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/ansi/datetime-parsing-invalid.sql.out @@ -0,0 +1,254 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 28 + + +-- !query +select to_timestamp('294248', 'y') +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query 
+select to_timestamp('1', 'yy') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '1' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. + + +-- !query +select to_timestamp('-12', 'yy') +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text '-12' could not be parsed at index 0. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_timestamp('123', 'yy') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '123' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. + + +-- !query +select to_timestamp('1', 'yyy') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '1' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. + + +-- !query +select to_timestamp('1234567', 'yyyyyyy') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'yyyyyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select to_timestamp('366', 'D') +-- !query schema +struct<> +-- !query output +java.time.DateTimeException +Invalid date 'DayOfYear 366' as '1970' is not a leap year. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_timestamp('9', 'DD') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '9' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. + + +-- !query +select to_timestamp('9', 'DDD') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '9' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. + + +-- !query +select to_timestamp('99', 'DDD') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '99' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. 
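(Aside, not part of the golden file: as the hint in the DateTimeParseException outputs above says, these failures are specific to ANSI mode; with spark.sql.ansi.enabled set to false the same parse is expected to return NULL instead of raising. A hedged sketch of that non-failing path:)

-- With ANSI mode off, an unparseable input is expected to yield NULL rather than an exception.
SET spark.sql.ansi.enabled = false;
SELECT to_timestamp('-12', 'yy');  -- expected: NULL
SET spark.sql.ansi.enabled = true;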
+ + +-- !query +select to_timestamp('30-365', 'dd-DDD') +-- !query schema +struct<> +-- !query output +java.time.DateTimeException +Conflict found: Field DayOfMonth 30 differs from DayOfMonth 31 derived from 1970-12-31. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_timestamp('11-365', 'MM-DDD') +-- !query schema +struct<> +-- !query output +java.time.DateTimeException +Conflict found: Field MonthOfYear 11 differs from MonthOfYear 12 derived from 1970-12-31. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_timestamp('2019-366', 'yyyy-DDD') +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text '2019-366' could not be parsed: Invalid date 'DayOfYear 366' as '2019' is not a leap year. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_timestamp('12-30-365', 'MM-dd-DDD') +-- !query schema +struct<> +-- !query output +java.time.DateTimeException +Conflict found: Field DayOfMonth 30 differs from DayOfMonth 31 derived from 1970-12-31. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_timestamp('2020-01-365', 'yyyy-dd-DDD') +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text '2020-01-365' could not be parsed: Conflict found: Field DayOfMonth 30 differs from DayOfMonth 1 derived from 2020-12-30. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_timestamp('2020-10-350', 'yyyy-MM-DDD') +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text '2020-10-350' could not be parsed: Conflict found: Field MonthOfYear 12 differs from MonthOfYear 10 derived from 2020-12-15. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_timestamp('2020-11-31-366', 'yyyy-MM-dd-DDD') +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text '2020-11-31-366' could not be parsed: Invalid date 'NOVEMBER 31'. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select from_csv('2018-366', 'date Date', map('dateFormat', 'yyyy-DDD')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '2018-366' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. + + +-- !query +select to_date("2020-01-27T20:06:11.847", "yyyy-MM-dd HH:mm:ss.SSS") +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text '2020-01-27T20:06:11.847' could not be parsed at index 10. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_date("Unparseable", "yyyy-MM-dd HH:mm:ss.SSS") +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text 'Unparseable' could not be parsed at index 0. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_timestamp("2020-01-27T20:06:11.847", "yyyy-MM-dd HH:mm:ss.SSS") +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text '2020-01-27T20:06:11.847' could not be parsed at index 10. 
If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_timestamp("Unparseable", "yyyy-MM-dd HH:mm:ss.SSS") +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text 'Unparseable' could not be parsed at index 0. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select unix_timestamp("2020-01-27T20:06:11.847", "yyyy-MM-dd HH:mm:ss.SSS") +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text '2020-01-27T20:06:11.847' could not be parsed at index 10. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select unix_timestamp("Unparseable", "yyyy-MM-dd HH:mm:ss.SSS") +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text 'Unparseable' could not be parsed at index 0. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_unix_timestamp("2020-01-27T20:06:11.847", "yyyy-MM-dd HH:mm:ss.SSS") +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text '2020-01-27T20:06:11.847' could not be parsed at index 10. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select to_unix_timestamp("Unparseable", "yyyy-MM-dd HH:mm:ss.SSS") +-- !query schema +struct<> +-- !query output +java.time.format.DateTimeParseException +Text 'Unparseable' could not be parsed at index 0. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select cast("Unparseable" as timestamp) +-- !query schema +struct<> +-- !query output +java.time.DateTimeException +Cannot cast Unparseable to TimestampType. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. + + +-- !query +select cast("Unparseable" as date) +-- !query schema +struct<> +-- !query output +java.time.DateTimeException +Cannot cast Unparseable to DateType. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 12450fa6679dc..cfc77aa45fdeb 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1532,9 +1532,8 @@ select str - interval '4 22:12' day to minute from interval_view -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'interval_view.str + (- INTERVAL '4 22:12' DAY TO MINUTE)' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'interval_view.str' is of string type. -To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 +java.time.DateTimeException +Cannot cast 1 to TimestampType. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. 
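(Aside, not part of the golden file: under ANSI mode the string operand of string plus/minus interval is now cast to TIMESTAMP at runtime, so the former analysis-time type error becomes the runtime cast failure shown above for the view's value '1'. A hedged sketch with a parseable string, assuming that implicit cast:)

-- A string that parses as a timestamp should go through; 'try_cast' or spark.sql.ansi.enabled=false
-- are the escape hatches named in the error message above.
SELECT '2022-03-01 10:00:00' + INTERVAL '4 22:12' DAY TO MINUTE;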
-- !query @@ -1542,9 +1541,8 @@ select str + interval '4 22:12' day to minute from interval_view -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'interval_view.str + INTERVAL '4 22:12' DAY TO MINUTE' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'interval_view.str' is of string type. -To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 +java.time.DateTimeException +Cannot cast 1 to TimestampType. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out index 7a27a89e5bb95..5f7bd9faa79e9 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out @@ -74,7 +74,7 @@ select map_contains_key(map('1', 'a', '2', 'b'), 1) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'array_contains(map_keys(map('1', 'a', '2', 'b')), 1)' due to data type mismatch: Input to function array_contains should have been array followed by a value with same element type, but it's [array, int].; line 1 pos 7 +cannot resolve 'map_contains_key(map('1', 'a', '2', 'b'), 1)' due to data type mismatch: Input to function map_contains_key should have been map followed by a value with same key type, but it's [map, int].; line 1 pos 7 -- !query @@ -83,7 +83,7 @@ select map_contains_key(map(1, 'a', 2, 'b'), '1') struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'array_contains(map_keys(map(1, 'a', 2, 'b')), '1')' due to data type mismatch: Input to function array_contains should have been array followed by a value with same element type, but it's [array, string].; line 1 pos 7 +cannot resolve 'map_contains_key(map(1, 'a', 2, 'b'), '1')' due to data type mismatch: Input to function map_contains_key should have been map followed by a value with same key type, but it's [map, string].; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 6fb9a6d5a47ab..b182b5cb6b390 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 94 +-- Number of queries: 131 -- !query @@ -760,3 +760,302 @@ SELECT endswith(null, null) struct -- !query output NULL + + +-- !query +SELECT contains(x'537061726b2053514c', x'537061726b') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT contains(x'', x'') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT contains(x'537061726b2053514c', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT contains(12, '1') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT contains(true, 'ru') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT contains(x'12', 12) +-- !query schema +struct +-- !query output +false + + +-- !query +SELECT contains(true, false) +-- !query schema +struct +-- !query output +false + + +-- !query +SELECT 
startswith(x'537061726b2053514c', x'537061726b') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT startswith(x'537061726b2053514c', x'') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT startswith(x'', x'') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT startswith(x'537061726b2053514c', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT endswith(x'537061726b2053514c', x'53516c') +-- !query schema +struct +-- !query output +false + + +-- !query +SELECT endsWith(x'537061726b2053514c', x'537061726b') +-- !query schema +struct +-- !query output +false + + +-- !query +SELECT endsWith(x'537061726b2053514c', x'') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT endsWith(x'', x'') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT endsWith(x'537061726b2053514c', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_number('454', '000') +-- !query schema +struct +-- !query output +454 + + +-- !query +select to_number('454.2', '000.0') +-- !query schema +struct +-- !query output +454.2 + + +-- !query +select to_number('12,454', '00,000') +-- !query schema +struct +-- !query output +12454 + + +-- !query +select to_number('$78.12', '$00.00') +-- !query schema +struct +-- !query output +78.12 + + +-- !query +select to_number('-454', '-000') +-- !query schema +struct +-- !query output +-454 + + +-- !query +select to_number('-454', 'S000') +-- !query schema +struct +-- !query output +-454 + + +-- !query +select to_number('12,454.8-', '00,000.9-') +-- !query schema +struct +-- !query output +-12454.8 + + +-- !query +select to_number('00,454.8-', '00,000.9-') +-- !query schema +struct +-- !query output +-454.8 + + +-- !query +select to_binary('abc') +-- !query schema +struct +-- !query output +� + + +-- !query +select to_binary('abc', 'utf-8') +-- !query schema +struct +-- !query output +abc + + +-- !query +select to_binary('abc', 'base64') +-- !query schema +struct +-- !query output +i� + + +-- !query +select to_binary('abc', 'hex') +-- !query schema +struct +-- !query output +� + + +-- !query +select to_binary('abc', concat('utf', '-8')) +-- !query schema +struct +-- !query output +abc + + +-- !query +select to_binary('abc', 'Hex') +-- !query schema +struct +-- !query output +� + + +-- !query +select to_binary('abc', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_binary(null, 'utf-8') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_binary(null, null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_binary(null, cast(null as string)) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_binary(null, cast(null as int)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The 'format' parameter of function 'to_binary' needs to be a string literal.; line 1 pos 7 + + +-- !query +select to_binary('abc', 1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The 'format' parameter of function 'to_binary' needs to be a string literal.; line 1 pos 7 + + +-- !query +select to_binary('abc', 'invalidFormat') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Invalid value for the 'format' parameter of function 'to_binary': invalidformat. The value has to be a case-insensitive string literal of 'hex', 'utf-8', or 'base64'. 
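(Aside, not part of the golden file: the to_binary results above cover the three formats named in the last error message, 'hex', 'utf-8', and 'base64'. A small hedged round-trip sketch using the long-standing hex() and decode() helpers:)

-- hex('Spark SQL') produces a hex string; to_binary(..., 'hex') turns it back into bytes,
-- and decode(..., 'utf-8') renders those bytes as a string again.
SELECT decode(to_binary(hex('Spark SQL'), 'hex'), 'utf-8');  -- expected: Spark SQL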
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out index 6aa70bd599b95..2946842e3f6e4 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 89 +-- Number of queries: 97 -- !query @@ -647,19 +647,17 @@ struct<> -- !query select str - timestamp'2011-11-11 11:11:11' from ts_view -- !query schema -struct<> +struct<(str - TIMESTAMP '2011-11-11 11:11:11'):interval day to second> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve '(ts_view.str - TIMESTAMP '2011-11-11 11:11:11')' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'ts_view.str' is of string type.; line 1 pos 7 +0 00:00:00.000000000 -- !query select timestamp'2011-11-11 11:11:11' - str from ts_view -- !query schema -struct<> +struct<(TIMESTAMP '2011-11-11 11:11:11' - str):interval day to second> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve '(TIMESTAMP '2011-11-11 11:11:11' - ts_view.str)' due to data type mismatch: argument 2 requires (timestamp or timestamp without time zone) type, however, 'ts_view.str' is of string type.; line 1 pos 7 +0 00:00:00.000000000 -- !query @@ -727,7 +725,7 @@ select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyy-MM-dd GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'yyyy-MM-dd GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -736,7 +734,7 @@ select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEEE') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd MM yyyy EEEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -745,7 +743,7 @@ select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 
1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -754,7 +752,7 @@ select unix_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -763,7 +761,7 @@ select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat' struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -772,4 +770,68 @@ select from_csv('26/October/2015', 't Timestamp', map('timestampFormat', 'dd/MMM struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select timestampadd(MONTH, -1, timestamp'2022-02-14 01:02:03') +-- !query schema +struct +-- !query output +2022-01-14 01:02:03 + + +-- !query +select timestampadd(MINUTE, 58, timestamp'2022-02-14 01:02:03') +-- !query schema +struct +-- !query output +2022-02-14 02:00:03 + + +-- !query +select timestampadd(YEAR, 1, date'2022-02-15') +-- !query schema +struct +-- !query output +2023-02-15 00:00:00 + + +-- !query +select timestampadd(SECOND, -1, date'2022-02-15') +-- !query schema +struct +-- !query output +2022-02-14 23:59:59 + + +-- !query +select timestampdiff(MONTH, timestamp'2022-02-14 01:02:03', timestamp'2022-01-14 01:02:03') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select timestampdiff(MINUTE, timestamp'2022-02-14 01:02:03', timestamp'2022-02-14 02:00:03') +-- !query schema +struct +-- !query output +58 + + +-- !query +select timestampdiff(YEAR, date'2022-02-15', date'2023-02-15') +-- !query schema +struct +-- !query output +1 + + +-- !query +select timestampdiff(SECOND, date'2022-02-15', timestamp'2022-02-14 23:59:59') +-- !query schema +struct +-- !query output +-1 diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out index 47faeb3ce9ea4..f3c483cfafea8 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 29 +-- Number of queries: 49 -- !query @@ -233,3 +233,163 @@ SELECT try_divide(interval 106751991 day, 0.5) struct -- !query output NULL + + +-- !query +SELECT try_subtract(1, 1) +-- !query schema +struct +-- !query output +0 + + +-- !query +SELECT try_subtract(2147483647, -1) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_subtract(-2147483648, 1) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_subtract(9223372036854775807L, -1) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_subtract(-9223372036854775808L, 1) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_subtract(interval 2 year, interval 3 year) +-- !query schema +struct +-- !query output +-1-0 + + +-- !query +SELECT try_subtract(interval 3 second, interval 2 second) +-- !query schema +struct +-- !query output +0 00:00:01.000000000 + + +-- !query +SELECT try_subtract(interval 2147483647 month, interval -2 month) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_subtract(interval 106751991 day, interval -3 day) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(2, 3) +-- !query schema +struct +-- !query output +6 + + +-- !query +SELECT try_multiply(2147483647, -2) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(-2147483648, 2) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(9223372036854775807L, 2) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(-9223372036854775808L, -2) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(interval 2 year, 2) +-- !query schema +struct +-- !query output +4-0 + + +-- !query +SELECT try_multiply(interval 2 
second, 2) +-- !query schema +struct +-- !query output +0 00:00:04.000000000 + + +-- !query +SELECT try_multiply(interval 2 year, 0) +-- !query schema +struct +-- !query output +0-0 + + +-- !query +SELECT try_multiply(interval 2 second, 0) +-- !query schema +struct +-- !query output +0 00:00:00.000000000 + + +-- !query +SELECT try_multiply(interval 2147483647 month, 2) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(interval 106751991 day, 2) +-- !query schema +struct +-- !query output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 76fdf035ad4ec..1ff2a1790ceee 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 29 -- !query @@ -211,6 +211,38 @@ struct NULL +-- !query +select elt(null, '123', '456') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select elt(null, '123', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select elt(1, '123', null) +-- !query schema +struct +-- !query output +123 + + +-- !query +select elt(2, '123', null) +-- !query schema +struct +-- !query output +NULL + + -- !query select array(1, 2, 3)[5] -- !query schema @@ -225,3 +257,44 @@ select array(1, 2, 3)[-1] struct -- !query output NULL + + +-- !query +select array_size(array()) +-- !query schema +struct +-- !query output +0 + + +-- !query +select array_size(array(true)) +-- !query schema +struct +-- !query output +1 + + +-- !query +select array_size(array(2, 1)) +-- !query schema +struct +-- !query output +2 + + +-- !query +select array_size(NULL) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select array_size(map('a', 1, 'b', 2)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'array_size(map('a', 1, 'b', 2))' due to data type mismatch: argument 1 requires array type, however, 'map('a', 1, 'b', 2)' is of map type.; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out b/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out new file mode 100644 index 0000000000000..132bd96350fb1 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out @@ -0,0 +1,200 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 24 + + +-- !query +SELECT CEIL(2.5, 0) +-- !query schema +struct +-- !query output +3 + + +-- !query +SELECT CEIL(3.5, 0) +-- !query schema +struct +-- !query output +4 + + +-- !query +SELECT CEIL(-2.5, 0) +-- !query schema +struct +-- !query output +-2 + + +-- !query +SELECT CEIL(-3.5, 0) +-- !query schema +struct +-- !query output +-3 + + +-- !query +SELECT CEIL(-0.35, 1) +-- !query schema +struct +-- !query output +-0.3 + + +-- !query +SELECT CEIL(-35, -1) +-- !query schema +struct +-- !query output +-30 + + +-- !query +SELECT CEIL(-0.1, 0) +-- !query schema +struct +-- !query output +0 + + +-- !query +SELECT CEIL(5, 0) +-- !query schema +struct +-- !query output +5 + + +-- !query +SELECT CEIL(3.14115, -3) +-- !query schema +struct +-- !query output +1000 + + +-- !query +SELECT CEIL(2.5, null) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The 'scale' parameter of function 'ceil' 
needs to be a int literal.; line 1 pos 7 + + +-- !query +SELECT CEIL(2.5, 'a') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The 'scale' parameter of function 'ceil' needs to be a int literal.; line 1 pos 7 + + +-- !query +SELECT CEIL(2.5, 0, 0) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function ceil. Expected: 2; Found: 3; line 1 pos 7 + + +-- !query +SELECT FLOOR(2.5, 0) +-- !query schema +struct +-- !query output +2 + + +-- !query +SELECT FLOOR(3.5, 0) +-- !query schema +struct +-- !query output +3 + + +-- !query +SELECT FLOOR(-2.5, 0) +-- !query schema +struct +-- !query output +-3 + + +-- !query +SELECT FLOOR(-3.5, 0) +-- !query schema +struct +-- !query output +-4 + + +-- !query +SELECT FLOOR(-0.35, 1) +-- !query schema +struct +-- !query output +-0.4 + + +-- !query +SELECT FLOOR(-35, -1) +-- !query schema +struct +-- !query output +-40 + + +-- !query +SELECT FLOOR(-0.1, 0) +-- !query schema +struct +-- !query output +-1 + + +-- !query +SELECT FLOOR(5, 0) +-- !query schema +struct +-- !query output +5 + + +-- !query +SELECT FLOOR(3.14115, -3) +-- !query schema +struct +-- !query output +0 + + +-- !query +SELECT FLOOR(2.5, null) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The 'scale' parameter of function 'floor' needs to be a int literal.; line 1 pos 7 + + +-- !query +SELECT FLOOR(2.5, 'a') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The 'scale' parameter of function 'floor' needs to be a int literal.; line 1 pos 7 + + +-- !query +SELECT FLOOR(2.5, 0, 0) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function floor. 
Expected: 2; Found: 3; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out index fcd207cd15001..6345702e00ea2 100644 --- a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out @@ -51,9 +51,9 @@ show create table char_tbl -- !query schema struct -- !query output -CREATE TABLE `default`.`char_tbl` ( - `c` CHAR(5), - `v` VARCHAR(6)) +CREATE TABLE default.char_tbl ( + c CHAR(5), + v VARCHAR(6)) USING parquet @@ -70,9 +70,9 @@ show create table char_tbl2 -- !query schema struct -- !query output -CREATE TABLE `default`.`char_tbl2` ( - `c` CHAR(5), - `v` VARCHAR(6)) +CREATE TABLE default.char_tbl2 ( + c CHAR(5), + v VARCHAR(6)) USING parquet @@ -161,9 +161,9 @@ show create table char_tbl3 -- !query schema struct -- !query output -CREATE TABLE `default`.`char_tbl3` ( - `c` CHAR(5), - `v` VARCHAR(6)) +CREATE TABLE default.char_tbl3 ( + c CHAR(5), + v VARCHAR(6)) USING parquet @@ -218,9 +218,9 @@ show create table char_view -- !query schema struct -- !query output -CREATE VIEW `default`.`char_view` ( - `c`, - `v`) +CREATE VIEW default.char_view ( + c, + v) AS select * from char_tbl diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index 2ca44d51244a5..53cae3f935568 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -89,7 +89,7 @@ select schema_of_csv('1|abc', map('delimiter', '|')) -- !query schema struct -- !query output -STRUCT<`_c0`: INT, `_c1`: STRING> +STRUCT<_c0: INT, _c1: STRING> -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/date.sql.out b/sql/core/src/test/resources/sql-tests/results/date.sql.out index 562028945103e..91c89ef5a93d7 100644 --- a/sql/core/src/test/resources/sql-tests/results/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/date.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 77 +-- Number of queries: 97 -- !query @@ -640,7 +640,7 @@ select to_date('26/October/2015', 'dd/MMMMM/yyyy') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -649,7 +649,7 @@ select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMM struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -658,4 +658,164 @@ select from_csv('26/October/2015', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select dateadd(MICROSECOND, 1001, timestamp'2022-02-25 01:02:03.123') +-- !query schema +struct +-- !query output +2022-02-25 01:02:03.124001 + + +-- !query +select dateadd(MILLISECOND, -1, timestamp'2022-02-25 01:02:03.456') +-- !query schema +struct +-- !query output +2022-02-25 01:02:03.455 + + +-- !query +select dateadd(SECOND, 58, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-02-25 01:03:01 + + +-- !query +select dateadd(MINUTE, -100, date'2022-02-25') +-- !query schema +struct +-- !query output +2022-02-24 22:20:00 + + +-- !query +select dateadd(HOUR, -1, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-02-25 00:02:03 + + +-- !query +select dateadd(DAY, 367, date'2022-02-25') +-- !query schema +struct +-- !query output +2023-02-27 00:00:00 + + +-- !query +select dateadd(WEEK, -4, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-01-28 01:02:03 + + +-- !query +select dateadd(MONTH, -1, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-01-25 01:02:03 + + +-- !query +select dateadd(QUARTER, 5, date'2022-02-25') +-- !query schema +struct +-- !query output +2023-05-25 00:00:00 + + +-- !query +select dateadd(YEAR, 1, date'2022-02-25') +-- !query schema +struct +-- !query output +2023-02-25 00:00:00 + + +-- !query +select datediff(MICROSECOND, timestamp'2022-02-25 01:02:03.123', timestamp'2022-02-25 01:02:03.124001') +-- !query schema +struct +-- !query output +1001 + + +-- !query +select datediff(MILLISECOND, timestamp'2022-02-25 01:02:03.456', timestamp'2022-02-25 01:02:03.455') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select datediff(SECOND, timestamp'2022-02-25 01:02:03', timestamp'2022-02-25 01:03:01') +-- !query schema +struct +-- !query output +58 + + +-- !query +select datediff(MINUTE, date'2022-02-25', timestamp'2022-02-24 22:20:00') +-- !query schema +struct +-- !query output +-100 + + +-- !query +select datediff(HOUR, timestamp'2022-02-25 01:02:03', timestamp'2022-02-25 00:02:03') +-- !query schema +struct +-- !query output +-1 + + 
+-- !query +select datediff(DAY, date'2022-02-25', timestamp'2023-02-27 00:00:00') +-- !query schema +struct +-- !query output +367 + + +-- !query +select datediff(WEEK, timestamp'2022-02-25 01:02:03', timestamp'2022-01-28 01:02:03') +-- !query schema +struct +-- !query output +-4 + + +-- !query +select datediff(MONTH, timestamp'2022-02-25 01:02:03', timestamp'2022-01-25 01:02:03') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select datediff(QUARTER, date'2022-02-25', date'2023-05-25') +-- !query schema +struct +-- !query output +5 + + +-- !query +select datediff(YEAR, date'2022-02-25', date'2023-02-25') +-- !query schema +struct +-- !query output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/datetime-formatting-invalid.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime-formatting-invalid.sql.out index 9c8553dc0f01f..6649ae3dbaf1c 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime-formatting-invalid.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime-formatting-invalid.sql.out @@ -8,7 +8,7 @@ select date_format('2018-11-17 13:33:33.333', 'GGGGG') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -17,7 +17,7 @@ select date_format('2018-11-17 13:33:33.333', 'yyyyyyy') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyyyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'yyyyyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -44,7 +44,7 @@ select date_format('2018-11-17 13:33:33.333', 'MMMMM') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'MMMMM' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'MMMMM' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -53,7 +53,7 @@ select date_format('2018-11-17 13:33:33.333', 'LLLLL') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'LLLLL' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'LLLLL' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -62,7 +62,7 @@ select date_format('2018-11-17 13:33:33.333', 'EEEEE') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -71,7 +71,7 @@ select date_format('2018-11-17 13:33:33.333', 'FF') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'FF' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'FF' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -80,7 +80,7 @@ select date_format('2018-11-17 13:33:33.333', 'ddd') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'ddd' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'ddd' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -89,7 +89,7 @@ select date_format('2018-11-17 13:33:33.333', 'DDDD') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'DDDD' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'DDDD' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -98,7 +98,7 @@ select date_format('2018-11-17 13:33:33.333', 'HHH') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'HHH' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'HHH' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -107,7 +107,7 @@ select date_format('2018-11-17 13:33:33.333', 'hhh') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'hhh' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'hhh' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -116,7 +116,7 @@ select date_format('2018-11-17 13:33:33.333', 'kkk') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'kkk' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'kkk' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -125,7 +125,7 @@ select date_format('2018-11-17 13:33:33.333', 'KKK') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'KKK' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'KKK' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -134,7 +134,7 @@ select date_format('2018-11-17 13:33:33.333', 'mmm') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'mmm' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'mmm' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -143,7 +143,7 @@ select date_format('2018-11-17 13:33:33.333', 'sss') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'sss' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'sss' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -152,7 +152,7 @@ select date_format('2018-11-17 13:33:33.333', 'SSSSSSSSSS') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'SSSSSSSSSS' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'SSSSSSSSSS' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -161,7 +161,7 @@ select date_format('2018-11-17 13:33:33.333', 'aa') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -179,7 +179,7 @@ select date_format('2018-11-17 13:33:33.333', 'zzzzz') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'zzzzz' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'zzzzz' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -197,7 +197,7 @@ select date_format('2018-11-17 13:33:33.333', 'ZZZZZZ') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'ZZZZZZ' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'ZZZZZZ' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -260,7 +260,7 @@ select date_format('2018-11-17 13:33:33.333', 'Y') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'Y' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'Y' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -269,7 +269,7 @@ select date_format('2018-11-17 13:33:33.333', 'w') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'w' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'w' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -278,7 +278,7 @@ select date_format('2018-11-17 13:33:33.333', 'W') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'W' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'W' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -287,7 +287,7 @@ select date_format('2018-11-17 13:33:33.333', 'u') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'u' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'u' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out index 74480ab6cc2b4..ebfdf60effdae 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 166 +-- Number of queries: 194 -- !query @@ -658,6 +658,166 @@ struct> {"d":2015-10-26} +-- !query +select dateadd(MICROSECOND, 1001, timestamp'2022-02-25 01:02:03.123') +-- !query schema +struct +-- !query output +2022-02-25 01:02:03.124001 + + +-- !query +select dateadd(MILLISECOND, -1, timestamp'2022-02-25 01:02:03.456') +-- !query schema +struct +-- !query output +2022-02-25 01:02:03.455 + + +-- !query +select dateadd(SECOND, 58, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-02-25 01:03:01 + + +-- !query +select dateadd(MINUTE, -100, date'2022-02-25') +-- !query schema +struct +-- !query output +2022-02-24 22:20:00 + + +-- !query +select dateadd(HOUR, -1, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-02-25 00:02:03 + + +-- !query +select dateadd(DAY, 367, date'2022-02-25') +-- !query schema +struct +-- !query output +2023-02-27 00:00:00 + + +-- !query +select dateadd(WEEK, -4, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-01-28 01:02:03 + + +-- !query +select dateadd(MONTH, -1, timestamp'2022-02-25 01:02:03') +-- !query schema +struct +-- !query output +2022-01-25 01:02:03 + + +-- !query +select dateadd(QUARTER, 5, date'2022-02-25') +-- !query schema +struct +-- !query output +2023-05-25 00:00:00 + + +-- !query +select dateadd(YEAR, 1, date'2022-02-25') +-- !query schema +struct +-- !query output +2023-02-25 00:00:00 + + +-- !query +select datediff(MICROSECOND, timestamp'2022-02-25 01:02:03.123', timestamp'2022-02-25 01:02:03.124001') +-- !query schema +struct +-- !query output +1001 + + +-- !query +select datediff(MILLISECOND, timestamp'2022-02-25 01:02:03.456', timestamp'2022-02-25 01:02:03.455') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select datediff(SECOND, timestamp'2022-02-25 01:02:03', timestamp'2022-02-25 01:03:01') +-- !query schema +struct +-- !query output +58 + + +-- !query +select datediff(MINUTE, date'2022-02-25', timestamp'2022-02-24 22:20:00') +-- !query schema +struct +-- !query output +-100 + + +-- !query +select datediff(HOUR, timestamp'2022-02-25 01:02:03', timestamp'2022-02-25 00:02:03') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select datediff(DAY, date'2022-02-25', timestamp'2023-02-27 00:00:00') +-- !query schema +struct +-- !query output +367 + + +-- !query +select datediff(WEEK, timestamp'2022-02-25 01:02:03', timestamp'2022-01-28 01:02:03') +-- !query schema +struct +-- !query output +-4 + + +-- !query +select datediff(MONTH, timestamp'2022-02-25 01:02:03', timestamp'2022-01-25 01:02:03') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select datediff(QUARTER, date'2022-02-25', date'2023-05-25') +-- !query schema +struct +-- !query output +5 + + +-- !query +select datediff(YEAR, date'2022-02-25', date'2023-02-25') +-- !query schema +struct +-- !query output +1 + + -- !query select timestamp '2019-01-01\t' -- !query schema @@ 
-1415,3 +1575,67 @@ select from_csv('26/October/2015', 't Timestamp', map('timestampFormat', 'dd/MMM struct> -- !query output {"t":2015-10-26 00:00:00} + + +-- !query +select timestampadd(MONTH, -1, timestamp'2022-02-14 01:02:03') +-- !query schema +struct +-- !query output +2022-01-14 01:02:03 + + +-- !query +select timestampadd(MINUTE, 58, timestamp'2022-02-14 01:02:03') +-- !query schema +struct +-- !query output +2022-02-14 02:00:03 + + +-- !query +select timestampadd(YEAR, 1, date'2022-02-15') +-- !query schema +struct +-- !query output +2023-02-15 00:00:00 + + +-- !query +select timestampadd(SECOND, -1, date'2022-02-15') +-- !query schema +struct +-- !query output +2022-02-14 23:59:59 + + +-- !query +select timestampdiff(MONTH, timestamp'2022-02-14 01:02:03', timestamp'2022-01-14 01:02:03') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select timestampdiff(MINUTE, timestamp'2022-02-14 01:02:03', timestamp'2022-02-14 02:00:03') +-- !query schema +struct +-- !query output +58 + + +-- !query +select timestampdiff(YEAR, date'2022-02-15', date'2023-02-15') +-- !query schema +struct +-- !query output +1 + + +-- !query +select timestampdiff(SECOND, date'2022-02-15', timestamp'2022-02-14 23:59:59') +-- !query schema +struct +-- !query output +-1 diff --git a/sql/core/src/test/resources/sql-tests/results/datetime-parsing-invalid.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime-parsing-invalid.sql.out index c1e1a2c4b2143..9fc28876a5b2a 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime-parsing-invalid.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime-parsing-invalid.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 29 +-- Number of queries: 28 -- !query @@ -17,7 +17,7 @@ select to_timestamp('1', 'yy') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to parse '1' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '1' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. -- !query @@ -34,7 +34,7 @@ select to_timestamp('123', 'yy') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to parse '123' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '123' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. -- !query @@ -43,7 +43,7 @@ select to_timestamp('1', 'yyy') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to parse '1' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. 
+You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '1' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. -- !query @@ -52,7 +52,7 @@ select to_timestamp('1234567', 'yyyyyyy') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyyyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'yyyyyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -69,15 +69,7 @@ select to_timestamp('9', 'DD') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to parse '9' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. - - --- !query -select to_timestamp('366', 'DD') --- !query schema -struct --- !query output -NULL +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '9' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. -- !query @@ -86,7 +78,7 @@ select to_timestamp('9', 'DDD') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to parse '9' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '9' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. -- !query @@ -95,7 +87,7 @@ select to_timestamp('99', 'DDD') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to parse '99' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '99' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. -- !query @@ -160,7 +152,7 @@ select from_csv('2018-366', 'date Date', map('dateFormat', 'yyyy-DDD')) struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to parse '2018-366' in the new parser. 
You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '2018-366' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out index 2199fc0312d25..322b24877a57e 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out @@ -112,7 +112,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'desc_temp1' expecting {, ';'}(line 1, pos 21) +Syntax error at or near 'desc_temp1'(line 1, pos 21) == SQL == DESCRIBE INSERT INTO desc_temp1 values (1, 'val1') @@ -126,7 +126,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'desc_temp1' expecting {, ';'}(line 1, pos 21) +Syntax error at or near 'desc_temp1'(line 1, pos 21) == SQL == DESCRIBE INSERT INTO desc_temp1 SELECT * FROM desc_temp2 @@ -143,7 +143,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'insert' expecting {'MAP', 'REDUCE', 'SELECT'}(line 3, pos 5) +Syntax error at or near 'insert'(line 3, pos 5) == SQL == DESCRIBE diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out index e3f676dfd1f5e..55776d3243689 100644 --- a/sql/core/src/test/resources/sql-tests/results/extract.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out @@ -660,7 +660,7 @@ select date_part(c, c) from t struct<> -- !query output org.apache.spark.sql.AnalysisException -The field parameter needs to be a foldable string value.; line 1 pos 7 +The 'field' parameter of function 'date_part' needs to be a string literal.; line 1 pos 7 -- !query @@ -677,7 +677,7 @@ select date_part(i, i) from t struct<> -- !query output org.apache.spark.sql.AnalysisException -The field parameter needs to be a foldable string value.; line 1 pos 7 +The 'field' parameter of function 'date_part' needs to be a string literal.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index cd0fa486cdb6f..7ae9199701536 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 76 +-- Number of queries: 95 -- !query @@ -470,7 +470,7 @@ SELECT every(1) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'every(1)' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7 +cannot resolve 'every(1)' due to data type mismatch: argument 1 requires boolean type, however, '1' is of int type.; line 1 pos 7 -- !query @@ -479,7 +479,7 @@ SELECT some(1S) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'some(1S)' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 
pos 7 +cannot resolve 'some(1S)' due to data type mismatch: argument 1 requires boolean type, however, '1S' is of smallint type.; line 1 pos 7 -- !query @@ -488,7 +488,7 @@ SELECT any(1L) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'any(1L)' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7 +cannot resolve 'any(1L)' due to data type mismatch: argument 1 requires boolean type, however, '1L' is of bigint type.; line 1 pos 7 -- !query @@ -497,7 +497,7 @@ SELECT every("true") struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 7 +cannot resolve 'every('true')' due to data type mismatch: argument 1 requires boolean type, however, ''true'' is of string type.; line 1 pos 7 -- !query @@ -506,7 +506,7 @@ SELECT bool_and(1.0) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'bool_and(1.0BD)' due to data type mismatch: Input to function 'bool_and' should have been boolean, but it's [decimal(2,1)].; line 1 pos 7 +cannot resolve 'bool_and(1.0BD)' due to data type mismatch: argument 1 requires boolean type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7 -- !query @@ -515,7 +515,7 @@ SELECT bool_or(1.0D) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'bool_or(1.0D)' due to data type mismatch: Input to function 'bool_or' should have been boolean, but it's [double].; line 1 pos 7 +cannot resolve 'bool_or(1.0D)' due to data type mismatch: argument 1 requires boolean type, however, '1.0D' is of double type.; line 1 pos 7 -- !query @@ -708,9 +708,139 @@ FROM VALUES (31), (32), (33), (34), (35), (3), (37), (38), (39), (40), (41), (42), (43), (44), (45), (46), (47), (48), (49), (50) AS tab(col) -- !query schema -struct>,histogram_3:array>,histogram_5:array>,histogram_10:array>> +struct>,histogram_3:array>,histogram_5:array>,histogram_10:array>> -- !query output -[{"x":12.615384615384613,"y":26.0},{"x":38.083333333333336,"y":24.0}] [{"x":9.649999999999999,"y":20.0},{"x":25.0,"y":11.0},{"x":40.736842105263165,"y":19.0}] [{"x":5.272727272727273,"y":11.0},{"x":14.5,"y":8.0},{"x":22.0,"y":7.0},{"x":30.499999999999996,"y":10.0},{"x":43.5,"y":14.0}] [{"x":3.0,"y":6.0},{"x":8.5,"y":6.0},{"x":13.5,"y":4.0},{"x":17.0,"y":3.0},{"x":20.5,"y":4.0},{"x":25.5,"y":6.0},{"x":31.999999999999996,"y":7.0},{"x":39.0,"y":5.0},{"x":43.5,"y":4.0},{"x":48.0,"y":5.0}] +[{"x":12,"y":26.0},{"x":38,"y":24.0}] [{"x":9,"y":20.0},{"x":25,"y":11.0},{"x":40,"y":19.0}] [{"x":5,"y":11.0},{"x":14,"y":8.0},{"x":22,"y":7.0},{"x":30,"y":10.0},{"x":43,"y":14.0}] [{"x":3,"y":6.0},{"x":8,"y":6.0},{"x":13,"y":4.0},{"x":17,"y":3.0},{"x":20,"y":4.0},{"x":25,"y":6.0},{"x":31,"y":7.0},{"x":39,"y":5.0},{"x":43,"y":4.0},{"x":48,"y":5.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES (1), (2), (3) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":1,"y":1.0},{"x":2,"y":1.0},{"x":3,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES (1L), (2L), (3L) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":1,"y":1.0},{"x":2,"y":1.0},{"x":3,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES (1F), (2F), (3F) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":1.0,"y":1.0},{"x":2.0,"y":1.0},{"x":3.0,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) 
FROM VALUES (1D), (2D), (3D) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":1.0,"y":1.0},{"x":2.0,"y":1.0},{"x":3.0,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES (1S), (2S), (3S) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":1,"y":1.0},{"x":2,"y":1.0},{"x":3,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES + (CAST(1 AS BYTE)), (CAST(2 AS BYTE)), (CAST(3 AS BYTE)) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":1,"y":1.0},{"x":2,"y":1.0},{"x":3,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES + (CAST(1 AS TINYINT)), (CAST(2 AS TINYINT)), (CAST(3 AS TINYINT)) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":1,"y":1.0},{"x":2,"y":1.0},{"x":3,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES + (CAST(1 AS SMALLINT)), (CAST(2 AS SMALLINT)), (CAST(3 AS SMALLINT)) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":1,"y":1.0},{"x":2,"y":1.0},{"x":3,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES + (CAST(1 AS BIGINT)), (CAST(2 AS BIGINT)), (CAST(3 AS BIGINT)) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":1,"y":1.0},{"x":2,"y":1.0},{"x":3,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES (TIMESTAMP '2017-03-01 00:00:00'), + (TIMESTAMP '2017-04-01 00:00:00'), (TIMESTAMP '2017-05-01 00:00:00') AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":2017-03-01 00:00:00,"y":1.0},{"x":2017-04-01 00:00:00,"y":1.0},{"x":2017-05-01 00:00:00,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES (INTERVAL '100-00' YEAR TO MONTH), + (INTERVAL '110-00' YEAR TO MONTH), (INTERVAL '120-00' YEAR TO MONTH) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":100-0,"y":1.0},{"x":110-0,"y":1.0},{"x":120-0,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) FROM VALUES (INTERVAL '12 20:4:0' DAY TO SECOND), + (INTERVAL '12 21:4:0' DAY TO SECOND), (INTERVAL '12 22:4:0' DAY TO SECOND) AS tab(col) +-- !query schema +struct>> +-- !query output +[{"x":12 20:04:00.000000000,"y":1.0},{"x":12 21:04:00.000000000,"y":1.0},{"x":12 22:04:00.000000000,"y":1.0}] + + +-- !query +SELECT histogram_numeric(col, 3) +FROM VALUES (NULL), (NULL), (NULL) AS tab(col) +-- !query schema +struct>> +-- !query output +NULL + + +-- !query +SELECT histogram_numeric(col, 3) +FROM VALUES (CAST(NULL AS DOUBLE)), (CAST(NULL AS DOUBLE)), (CAST(NULL AS DOUBLE)) AS tab(col) +-- !query schema +struct>> +-- !query output +NULL + + +-- !query +SELECT histogram_numeric(col, 3) +FROM VALUES (CAST(NULL AS INT)), (CAST(NULL AS INT)), (CAST(NULL AS INT)) AS tab(col) +-- !query schema +struct>> +-- !query output +NULL -- !query @@ -774,6 +904,40 @@ struct,collect_list(b):array> 2 [3,4] [3,4] +-- !query +SELECT regr_avgx(y, x), regr_avgy(y, x) FROM testRegression +-- !query schema +struct +-- !query output +22.666666666666668 20.0 + + +-- !query +SELECT regr_avgx(y, x), regr_avgy(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL +-- !query schema +struct +-- !query output +22.666666666666668 20.0 + + +-- !query +SELECT k, avg(x), avg(y), regr_avgx(y, x), regr_avgy(y, x) FROM testRegression GROUP BY k +-- !query schema +struct +-- !query output +1 NULL 10.0 NULL NULL +2 22.666666666666668 21.25 22.666666666666668 20.0 + + +-- !query +SELECT k, avg(x) FILTER (WHERE x IS NOT NULL AND y IS NOT NULL), avg(y) FILTER (WHERE x IS NOT NULL AND y IS NOT NULL), regr_avgx(y, x), 
regr_avgy(y, x) FROM testRegression GROUP BY k +-- !query schema +struct +-- !query output +1 NULL NULL NULL NULL +2 22.666666666666668 20.0 22.666666666666668 20.0 + + -- !query SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v), diff --git a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out index c1b595ec4fe61..cc1619813dd55 100644 --- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out @@ -153,7 +153,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -LATERAL join with NATURAL join is not supported(line 1, pos 14) +The feature is not supported: LATERAL join with NATURAL join.(line 1, pos 14) == SQL == SELECT * FROM t1 NATURAL JOIN LATERAL (SELECT c1 + c2 AS c2) @@ -167,7 +167,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -LATERAL join with USING join is not supported(line 1, pos 14) +The feature is not supported: LATERAL join with USING join.(line 1, pos 14) == SQL == SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c2) USING (c2) diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index ff59553e4e9d9..84610834fa7e7 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -236,7 +236,7 @@ select schema_of_json('{"c1":0, "c2":[1]}') -- !query schema struct -- !query output -STRUCT<`c1`: BIGINT, `c2`: ARRAY> +STRUCT> -- !query @@ -339,7 +339,7 @@ select from_json( struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to parse '02-29' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '02-29' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. -- !query @@ -351,7 +351,7 @@ select from_json( struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to parse '02-29' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. +You may get a different result due to the upgrading to Spark >= 3.0: Fail to parse '02-29' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. 
-- !query @@ -375,7 +375,7 @@ select schema_of_json('{"c1":1}', map('primitivesAsString', 'true')) -- !query schema struct -- !query output -STRUCT<`c1`: STRING> +STRUCT -- !query @@ -383,7 +383,7 @@ select schema_of_json('{"c1":01, "c2":0.1}', map('allowNumericLeadingZeros', 'tr -- !query schema struct -- !query output -STRUCT<`c1`: BIGINT, `c2`: DECIMAL(1,1)> +STRUCT -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/map.sql.out b/sql/core/src/test/resources/sql-tests/results/map.sql.out index aa13fee451d11..b615a62581108 100644 --- a/sql/core/src/test/resources/sql-tests/results/map.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/map.sql.out @@ -72,7 +72,7 @@ select map_contains_key(map('1', 'a', '2', 'b'), 1) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'array_contains(map_keys(map('1', 'a', '2', 'b')), 1)' due to data type mismatch: Input to function array_contains should have been array followed by a value with same element type, but it's [array, int].; line 1 pos 7 +cannot resolve 'map_contains_key(map('1', 'a', '2', 'b'), 1)' due to data type mismatch: Input to function map_contains_key should have been map followed by a value with same key type, but it's [map, int].; line 1 pos 7 -- !query @@ -81,4 +81,4 @@ select map_contains_key(map(1, 'a', 2, 'b'), '1') struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'array_contains(map_keys(map(1, 'a', 2, 'b')), '1')' due to data type mismatch: Input to function array_contains should have been array followed by a value with same element type, but it's [array, string].; line 1 pos 7 +cannot resolve 'map_contains_key(map(1, 'a', 2, 'b'), '1')' due to data type mismatch: Input to function map_contains_key should have been map followed by a value with same key type, but it's [map, string].; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out index 91f8185ff9055..f2c20bced6e1f 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 45 +-- Number of queries: 46 -- !query @@ -296,6 +296,14 @@ struct 4 +-- !query +SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest +-- !query schema +struct +-- !query output +49.5 107.94315227307379 + + -- !query SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out index 6aa890efe5c95..690fd7cd2cbbc 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out @@ -171,6 +171,7 @@ struct 0.0 1.2345679E-20 1.2345679E20 + 1004.3 -- !query @@ -178,7 +179,7 @@ SELECT '' AS one, f.* FROM FLOAT4_TBL f WHERE f.f1 = '1004.3' -- !query schema struct -- !query output - 1004.3 + -- !query @@ -189,6 +190,7 @@ struct -34.84 0.0 1.2345679E-20 + 1004.3 -- !query @@ -199,6 +201,7 @@ struct -34.84 0.0 1.2345679E-20 + 1004.3 -- !query @@ -227,22 +230,22 @@ struct SELECT '' AS three, f.f1, f.f1 * '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' -- !query schema -struct +struct -- !query output - 1.2345679E-20 
-1.2345678E-19 - 1.2345679E20 -1.2345678E21 - 1004.3 -10043.0 + 1.2345679E-20 -1.2345678720289608E-19 + 1.2345679E20 -1.2345678955701443E21 + 1004.3 -10042.999877929688 -- !query SELECT '' AS three, f.f1, f.f1 + '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' -- !query schema -struct +struct -- !query output 1.2345679E-20 -10.0 - 1.2345679E20 1.2345679E20 - 1004.3 994.3 + 1.2345679E20 1.2345678955701443E20 + 1004.3 994.2999877929688 -- !query @@ -260,11 +263,11 @@ struct SELECT '' AS three, f.f1, f.f1 - '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' -- !query schema -struct +struct -- !query output 1.2345679E-20 10.0 - 1.2345679E20 1.2345679E20 - 1004.3 1014.3 + 1.2345679E20 1.2345678955701443E20 + 1004.3 1014.2999877929688 -- !query @@ -375,7 +378,7 @@ SELECT bigint(float('-9223380000000000000')) struct<> -- !query output org.apache.spark.SparkArithmeticException -Casting -9.22338E18 to int causes overflow. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. +Casting -9.22338E18 to bigint causes overflow. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out index 2e4fbc2dfa537..2b71be5a5d96c 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out @@ -833,7 +833,7 @@ SELECT bigint(double('-9223372036854780000')) struct<> -- !query output org.apache.spark.SparkArithmeticException -Casting -9.22337203685478E18 to long causes overflow. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. +Casting -9.22337203685478E18 to bigint causes overflow. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out index 24f0b3c5ed3bf..427e89a8d1b41 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out @@ -661,7 +661,7 @@ SELECT CAST(double('922337203685477580700.0') AS bigint) struct<> -- !query output org.apache.spark.SparkArithmeticException -Casting 9.223372036854776E20 to long causes overflow. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. +Casting 9.223372036854776E20 to bigint causes overflow. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. 
-- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out index bc13bb893b118..41fc9908d0c2b 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 592 +-- Number of queries: 601 -- !query @@ -4594,6 +4594,80 @@ struct<> +-- !query +SELECT '' AS to_number_1, to_number('-34,338,492', '99G999G999') +-- !query schema +struct +-- !query output + -34338492 + + +-- !query +SELECT '' AS to_number_2, to_number('-34,338,492.654,878', '99G999G999D999G999') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +The input string '-34,338,492.654,878' does not match the given number format: '99G999G999D999G999' + + +-- !query +SELECT '' AS to_number_4, to_number('0.00001-', '9.999999S') +-- !query schema +struct +-- !query output + -0.000010 + + +-- !query +SELECT '' AS to_number_9, to_number('.0', '99999999.99999999') +-- !query schema +struct +-- !query output + 0.00000000 + + +-- !query +SELECT '' AS to_number_10, to_number('0', '99.99') +-- !query schema +struct +-- !query output + 0.00 + + +-- !query +SELECT '' AS to_number_12, to_number('.01-', '99.99S') +-- !query schema +struct +-- !query output + -0.01 + + +-- !query +SELECT '' AS to_number_14, to_number('34,50','999,99') +-- !query schema +struct +-- !query output + 3450 + + +-- !query +SELECT '' AS to_number_15, to_number('123,000','999G') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +The input string '123,000' does not match the given number format: '999G' + + +-- !query +SELECT '' AS to_number_16, to_number('123456','999G999') +-- !query schema +struct +-- !query output + 123456 + + -- !query CREATE TABLE num_input_test (n1 decimal(38, 18)) USING parquet -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out index a3f7b35fa27ed..99b6ea78cace1 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out @@ -80,7 +80,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'SELECT' expecting {, ';'}(line 1, pos 39) +Syntax error at or near 'SELECT'(line 1, pos 39) == SQL == SELECT 1 AS three UNION SELECT 2 UNION SELECT 3 ORDER BY 1 @@ -94,7 +94,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'SELECT' expecting {, ';'}(line 1, pos 37) +Syntax error at or near 'SELECT'(line 1, pos 37) == SQL == SELECT 1 AS two UNION SELECT 2 UNION SELECT 2 ORDER BY 1 @@ -171,7 +171,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'SELECT' expecting {, ';'}(line 1, pos 41) +Syntax error at or near 'SELECT'(line 1, pos 41) == SQL == SELECT 1.1 AS three UNION SELECT 2 UNION SELECT 3 ORDER BY 1 @@ -185,7 +185,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'SELECT' expecting {, ';'}(line 1, pos 47) +Syntax error at or near 'SELECT'(line 1, pos 47) == SQL == SELECT double(1.1) AS two UNION SELECT 2 UNION SELECT double(2.0) ORDER BY 1 @@ -381,7 +381,7 @@ struct<> -- !query output 
org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'SELECT' expecting {')', ',', 'CLUSTER', 'DISTRIBUTE', 'EXCEPT', 'FROM', 'GROUP', 'HAVING', 'INTERSECT', 'LATERAL', 'LIMIT', 'ORDER', 'MINUS', 'SORT', 'UNION', 'WHERE', 'WINDOW', '-'}(line 1, pos 20) +Syntax error at or near 'SELECT'(line 1, pos 20) == SQL == (SELECT 1,2,3 UNION SELECT 4,5,6) INTERSECT SELECT 4,5,6 @@ -395,7 +395,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'SELECT' expecting {')', ',', 'CLUSTER', 'DISTRIBUTE', 'EXCEPT', 'FROM', 'GROUP', 'HAVING', 'INTERSECT', 'LATERAL', 'LIMIT', 'ORDER', 'MINUS', 'SORT', 'UNION', 'WHERE', 'WINDOW', '-'}(line 1, pos 20) +Syntax error at or near 'SELECT'(line 1, pos 20) == SQL == (SELECT 1,2,3 UNION SELECT 4,5,6 ORDER BY 1,2) INTERSECT SELECT 4,5,6 @@ -409,7 +409,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'SELECT' expecting {')', ',', 'CLUSTER', 'DISTRIBUTE', 'EXCEPT', 'FROM', 'GROUP', 'HAVING', 'INTERSECT', 'LATERAL', 'LIMIT', 'ORDER', 'MINUS', 'SORT', 'UNION', 'WHERE', 'WINDOW', '-'}(line 1, pos 20) +Syntax error at or near 'SELECT'(line 1, pos 20) == SQL == (SELECT 1,2,3 UNION SELECT 4,5,6) EXCEPT SELECT 4,5,6 @@ -423,7 +423,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'SELECT' expecting {')', ',', 'CLUSTER', 'DISTRIBUTE', 'EXCEPT', 'FROM', 'GROUP', 'HAVING', 'INTERSECT', 'LATERAL', 'LIMIT', 'ORDER', 'MINUS', 'SORT', 'UNION', 'WHERE', 'WINDOW', '-'}(line 1, pos 20) +Syntax error at or near 'SELECT'(line 1, pos 20) == SQL == (SELECT 1,2,3 UNION SELECT 4,5,6 ORDER BY 1,2) EXCEPT SELECT 4,5,6 @@ -728,7 +728,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'SELECT' expecting {, ';'}(line 1, pos 44) +Syntax error at or near 'SELECT'(line 1, pos 44) == SQL == SELECT cast('3.4' as decimal(38, 18)) UNION SELECT 'foo' diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out index a76b4088fb818..fc19471bb5b32 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out @@ -329,7 +329,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'BY' expecting {')', ',', '-'}(line 1, pos 33) +Syntax error at or near 'BY'(line 1, pos 33) == SQL == SELECT * FROM rank() OVER (ORDER BY random()) @@ -374,7 +374,7 @@ SELECT range(1, 100) OVER () FROM empsalary struct<> -- !query output org.apache.spark.sql.AnalysisException -Undefined function: 'range'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 +Undefined function: range. 
This function is neither a built-in/temporary function, nor a persistent function that is qualified as spark_catalog.default.range.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out b/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out index e7399e45c3579..ded27abc4c14d 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out @@ -15,10 +15,10 @@ SHOW CREATE TABLE tbl -- !query schema struct -- !query output -CREATE TABLE `default`.`tbl` ( - `a` INT, - `b` STRING, - `c` INT) +CREATE TABLE default.tbl ( + a INT, + b STRING, + c INT) USING parquet @@ -44,10 +44,10 @@ SHOW CREATE TABLE tbl -- !query schema struct -- !query output -CREATE TABLE `default`.`tbl` ( - `a` INT, - `b` STRING, - `c` INT) +CREATE TABLE default.tbl ( + a INT, + b STRING, + c INT) USING parquet OPTIONS ( 'a' = '1') @@ -75,12 +75,12 @@ SHOW CREATE TABLE tbl -- !query schema struct -- !query output -CREATE TABLE `default`.`tbl` ( - `a` INT, - `b` STRING, - `c` INT) +CREATE TABLE default.tbl ( + a INT, + b STRING, + c INT) USING parquet -LOCATION 'file:/path/to/table' +LOCATION 'file:///path/to/table' -- !query @@ -105,12 +105,12 @@ SHOW CREATE TABLE tbl -- !query schema struct -- !query output -CREATE TABLE `default`.`tbl` ( - `a` INT, - `b` STRING, - `c` INT) +CREATE TABLE default.tbl ( + a INT, + b STRING, + c INT) USING parquet -LOCATION 'file:/path/to/table' +LOCATION 'file:///path/to/table' -- !query @@ -135,10 +135,10 @@ SHOW CREATE TABLE tbl -- !query schema struct -- !query output -CREATE TABLE `default`.`tbl` ( - `b` STRING, - `c` INT, - `a` INT) +CREATE TABLE default.tbl ( + b STRING, + c INT, + a INT) USING parquet PARTITIONED BY (a) @@ -165,10 +165,10 @@ SHOW CREATE TABLE tbl -- !query schema struct -- !query output -CREATE TABLE `default`.`tbl` ( - `a` INT, - `b` STRING, - `c` INT) +CREATE TABLE default.tbl ( + a INT, + b STRING, + c INT) USING parquet CLUSTERED BY (a) SORTED BY (b) @@ -197,10 +197,10 @@ SHOW CREATE TABLE tbl -- !query schema struct -- !query output -CREATE TABLE `default`.`tbl` ( - `a` INT, - `b` STRING, - `c` INT) +CREATE TABLE default.tbl ( + a INT, + b STRING, + c INT) USING parquet COMMENT 'This is a comment' @@ -227,10 +227,10 @@ SHOW CREATE TABLE tbl -- !query schema struct -- !query output -CREATE TABLE `default`.`tbl` ( - `a` INT, - `b` STRING, - `c` INT) +CREATE TABLE default.tbl ( + a INT, + b STRING, + c INT) USING parquet TBLPROPERTIES ( 'a' = '1') @@ -257,11 +257,11 @@ SHOW CREATE TABLE tbl -- !query schema struct -- !query output -CREATE TABLE `default`.`tbl` ( - `a` FLOAT, - `b` DECIMAL(10,0), - `c` DECIMAL(10,0), - `d` DECIMAL(10,1)) +CREATE TABLE default.tbl ( + a FLOAT, + b DECIMAL(10,0), + c DECIMAL(10,0), + d DECIMAL(10,1)) USING parquet @@ -295,9 +295,9 @@ SHOW CREATE TABLE view_SPARK_30302 AS SERDE -- !query schema struct -- !query output -CREATE VIEW `default`.`view_SPARK_30302`( - `aaa`, - `bbb`) +CREATE VIEW default.view_SPARK_30302 ( + aaa, + bbb) AS SELECT a, b FROM tbl @@ -306,9 +306,9 @@ SHOW CREATE TABLE view_SPARK_30302 -- !query schema struct -- !query output -CREATE VIEW `default`.`view_SPARK_30302` ( - `aaa`, - `bbb`) +CREATE VIEW default.view_SPARK_30302 ( + aaa, + bbb) AS SELECT a, b FROM tbl @@ -335,9 +335,9 @@ SHOW CREATE TABLE view_SPARK_30302 AS SERDE -- !query schema struct -- !query output -CREATE VIEW `default`.`view_SPARK_30302`( - `aaa` COMMENT 'comment with \'quoted 
text\' for aaa', - `bbb`) +CREATE VIEW default.view_SPARK_30302 ( + aaa COMMENT 'comment with \'quoted text\' for aaa', + bbb) COMMENT 'This is a comment with \'quoted text\' for view' AS SELECT a, b FROM tbl @@ -347,9 +347,9 @@ SHOW CREATE TABLE view_SPARK_30302 -- !query schema struct -- !query output -CREATE VIEW `default`.`view_SPARK_30302` ( - `aaa` COMMENT 'comment with \'quoted text\' for aaa', - `bbb`) +CREATE VIEW default.view_SPARK_30302 ( + aaa COMMENT 'comment with \'quoted text\' for aaa', + bbb) COMMENT 'This is a comment with \'quoted text\' for view' AS SELECT a, b FROM tbl @@ -377,9 +377,9 @@ SHOW CREATE TABLE view_SPARK_30302 AS SERDE -- !query schema struct -- !query output -CREATE VIEW `default`.`view_SPARK_30302`( - `aaa`, - `bbb`) +CREATE VIEW default.view_SPARK_30302 ( + aaa, + bbb) TBLPROPERTIES ( 'a' = '1', 'b' = '2') @@ -391,9 +391,9 @@ SHOW CREATE TABLE view_SPARK_30302 -- !query schema struct -- !query output -CREATE VIEW `default`.`view_SPARK_30302` ( - `aaa`, - `bbb`) +CREATE VIEW default.view_SPARK_30302 ( + aaa, + bbb) TBLPROPERTIES ( 'a' = '1', 'b' = '2') diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 139004345accb..70a4822ff916d 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -168,7 +168,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input '' expecting {'FROM', 'IN', 'LIKE'}(line 1, pos 19) +Syntax error at or near end of input(line 1, pos 19) == SQL == SHOW TABLE EXTENDED @@ -193,7 +193,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'PARTITION' expecting {'FROM', 'IN', 'LIKE'}(line 1, pos 20) +Syntax error at or near 'PARTITION'(line 1, pos 20) == SQL == SHOW TABLE EXTENDED PARTITION(c='Us', d=1) diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 2aa2e80a1244e..4307df7e61683 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 94 +-- Number of queries: 131 -- !query @@ -756,3 +756,302 @@ SELECT endswith(null, null) struct -- !query output NULL + + +-- !query +SELECT contains(x'537061726b2053514c', x'537061726b') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT contains(x'', x'') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT contains(x'537061726b2053514c', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT contains(12, '1') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT contains(true, 'ru') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT contains(x'12', 12) +-- !query schema +struct +-- !query output +false + + +-- !query +SELECT contains(true, false) +-- !query schema +struct +-- !query output +false + + +-- !query +SELECT startswith(x'537061726b2053514c', x'537061726b') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT startswith(x'537061726b2053514c', x'') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT startswith(x'', x'') +-- !query schema +struct +-- !query output 
+true + + +-- !query +SELECT startswith(x'537061726b2053514c', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT endswith(x'537061726b2053514c', x'53516c') +-- !query schema +struct +-- !query output +false + + +-- !query +SELECT endsWith(x'537061726b2053514c', x'537061726b') +-- !query schema +struct +-- !query output +false + + +-- !query +SELECT endsWith(x'537061726b2053514c', x'') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT endsWith(x'', x'') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT endsWith(x'537061726b2053514c', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_number('454', '000') +-- !query schema +struct +-- !query output +454 + + +-- !query +select to_number('454.2', '000.0') +-- !query schema +struct +-- !query output +454.2 + + +-- !query +select to_number('12,454', '00,000') +-- !query schema +struct +-- !query output +12454 + + +-- !query +select to_number('$78.12', '$00.00') +-- !query schema +struct +-- !query output +78.12 + + +-- !query +select to_number('-454', '-000') +-- !query schema +struct +-- !query output +-454 + + +-- !query +select to_number('-454', 'S000') +-- !query schema +struct +-- !query output +-454 + + +-- !query +select to_number('12,454.8-', '00,000.9-') +-- !query schema +struct +-- !query output +-12454.8 + + +-- !query +select to_number('00,454.8-', '00,000.9-') +-- !query schema +struct +-- !query output +-454.8 + + +-- !query +select to_binary('abc') +-- !query schema +struct +-- !query output +� + + +-- !query +select to_binary('abc', 'utf-8') +-- !query schema +struct +-- !query output +abc + + +-- !query +select to_binary('abc', 'base64') +-- !query schema +struct +-- !query output +i� + + +-- !query +select to_binary('abc', 'hex') +-- !query schema +struct +-- !query output +� + + +-- !query +select to_binary('abc', concat('utf', '-8')) +-- !query schema +struct +-- !query output +abc + + +-- !query +select to_binary('abc', 'Hex') +-- !query schema +struct +-- !query output +� + + +-- !query +select to_binary('abc', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_binary(null, 'utf-8') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_binary(null, null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_binary(null, cast(null as string)) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_binary(null, cast(null as int)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The 'format' parameter of function 'to_binary' needs to be a string literal.; line 1 pos 7 + + +-- !query +select to_binary('abc', 1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The 'format' parameter of function 'to_binary' needs to be a string literal.; line 1 pos 7 + + +-- !query +select to_binary('abc', 'invalidFormat') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Invalid value for the 'format' parameter of function 'to_binary': invalidformat. The value has to be a case-insensitive string literal of 'hex', 'utf-8', or 'base64'. 
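The string-functions results above add golden output for to_number, to_binary, and the binary-typed contains/startswith/endswith expressions. As a minimal sketch of the same behaviour outside the golden-file harness (assuming a Spark build that already includes these expressions; the inputs simply mirror cases recorded in the file, and the expected values in the comments are copied from that golden output rather than re-derived), one could run:

-- to_number parses a string against an explicit number format
SELECT to_number('$78.12', '$00.00');          -- 78.12
SELECT to_number('12,454.8-', '00,000.9-');    -- -12454.8

-- to_binary decodes a string; the optional format argument selects the encoding
SELECT to_binary('abc', 'utf-8');              -- the UTF-8 bytes of 'abc'
SELECT to_binary('abc', 'base64');             -- the base64-decoded bytes

-- contains/startswith/endswith now also accept BINARY arguments
SELECT startswith(x'537061726b2053514c', x'537061726b');   -- true ("Spark SQL" starts with "Spark")
SELECT endswith(x'537061726b2053514c', x'');               -- true (empty suffix)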
diff --git a/sql/core/src/test/resources/sql-tests/results/timestamp-ltz.sql.out b/sql/core/src/test/resources/sql-tests/results/timestamp-ltz.sql.out index 48036c6a34808..057cdf1db845c 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestamp-ltz.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestamp-ltz.sql.out @@ -45,7 +45,7 @@ struct -- !query SELECT make_timestamp_ltz(2021, 07, 11, 6, 30, 45.678, 'CET') -- !query schema -struct +struct -- !query output 2021-07-10 21:30:45.678 diff --git a/sql/core/src/test/resources/sql-tests/results/timestamp-ntz.sql.out b/sql/core/src/test/resources/sql-tests/results/timestamp-ntz.sql.out index 0ed5beeaddf72..c4fcff4c2b81b 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestamp-ntz.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestamp-ntz.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 12 -- !query @@ -65,3 +65,35 @@ SELECT convert_timezone('Europe/Moscow', 'America/Los_Angeles', timestamp_ntz'20 struct -- !query output 2021-12-31 13:00:00 + + +-- !query +select timestampdiff(QUARTER, timestamp_ntz'2022-01-01 01:02:03', timestamp_ntz'2022-05-02 05:06:07') +-- !query schema +struct +-- !query output +1 + + +-- !query +select timestampdiff(HOUR, timestamp_ntz'2022-02-14 01:02:03', timestamp_ltz'2022-02-14 02:03:04') +-- !query schema +struct +-- !query output +1 + + +-- !query +select timestampdiff(YEAR, date'2022-02-15', timestamp_ntz'2023-02-15 10:11:12') +-- !query schema +struct +-- !query output +1 + + +-- !query +select timestampdiff(MILLISECOND, timestamp_ntz'2022-02-14 23:59:59.123', date'2022-02-15') +-- !query schema +struct +-- !query output +877 diff --git a/sql/core/src/test/resources/sql-tests/results/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/timestamp.sql.out index 77b4f73d179f0..0ebdf4cc01615 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestamp.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 89 +-- Number of queries: 97 -- !query @@ -719,7 +719,7 @@ select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyy-MM-dd GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'yyyy-MM-dd GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -728,7 +728,7 @@ select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEEE') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd MM yyyy EEEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -737,7 +737,7 @@ select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -746,7 +746,7 @@ select unix_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -755,7 +755,7 @@ select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat' struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -764,4 +764,68 @@ select from_csv('26/October/2015', 't Timestamp', map('timestampFormat', 'dd/MMM struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 
1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select timestampadd(MONTH, -1, timestamp'2022-02-14 01:02:03') +-- !query schema +struct +-- !query output +2022-01-14 01:02:03 + + +-- !query +select timestampadd(MINUTE, 58, timestamp'2022-02-14 01:02:03') +-- !query schema +struct +-- !query output +2022-02-14 02:00:03 + + +-- !query +select timestampadd(YEAR, 1, date'2022-02-15') +-- !query schema +struct +-- !query output +2023-02-15 00:00:00 + + +-- !query +select timestampadd(SECOND, -1, date'2022-02-15') +-- !query schema +struct +-- !query output +2022-02-14 23:59:59 + + +-- !query +select timestampdiff(MONTH, timestamp'2022-02-14 01:02:03', timestamp'2022-01-14 01:02:03') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select timestampdiff(MINUTE, timestamp'2022-02-14 01:02:03', timestamp'2022-02-14 02:00:03') +-- !query schema +struct +-- !query output +58 + + +-- !query +select timestampdiff(YEAR, date'2022-02-15', date'2023-02-15') +-- !query schema +struct +-- !query output +1 + + +-- !query +select timestampdiff(SECOND, date'2022-02-15', timestamp'2022-02-14 23:59:59') +-- !query schema +struct +-- !query output +-1 diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out index 371b0e00f532e..f7552ed4f62cc 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 89 +-- Number of queries: 97 -- !query @@ -647,19 +647,17 @@ struct<> -- !query select str - timestamp'2011-11-11 11:11:11' from ts_view -- !query schema -struct<> +struct<(str - TIMESTAMP_NTZ '2011-11-11 11:11:11'):interval day to second> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve '(ts_view.str - TIMESTAMP_NTZ '2011-11-11 11:11:11')' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'ts_view.str' is of string type.; line 1 pos 7 +0 00:00:00.000000000 -- !query select timestamp'2011-11-11 11:11:11' - str from ts_view -- !query schema -struct<> +struct<(TIMESTAMP_NTZ '2011-11-11 11:11:11' - str):interval day to second> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve '(TIMESTAMP_NTZ '2011-11-11 11:11:11' - ts_view.str)' due to data type mismatch: argument 2 requires (timestamp or timestamp without time zone) type, however, 'ts_view.str' is of string type.; line 1 pos 7 +0 00:00:00.000000000 -- !query @@ -754,7 +752,7 @@ select unix_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 
1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -771,3 +769,67 @@ select from_csv('26/October/2015', 't Timestamp', map('timestampFormat', 'dd/MMM struct> -- !query output {"t":null} + + +-- !query +select timestampadd(MONTH, -1, timestamp'2022-02-14 01:02:03') +-- !query schema +struct +-- !query output +2022-01-14 01:02:03 + + +-- !query +select timestampadd(MINUTE, 58, timestamp'2022-02-14 01:02:03') +-- !query schema +struct +-- !query output +2022-02-14 02:00:03 + + +-- !query +select timestampadd(YEAR, 1, date'2022-02-15') +-- !query schema +struct +-- !query output +2023-02-15 00:00:00 + + +-- !query +select timestampadd(SECOND, -1, date'2022-02-15') +-- !query schema +struct +-- !query output +2022-02-14 23:59:59 + + +-- !query +select timestampdiff(MONTH, timestamp'2022-02-14 01:02:03', timestamp'2022-01-14 01:02:03') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select timestampdiff(MINUTE, timestamp'2022-02-14 01:02:03', timestamp'2022-02-14 02:00:03') +-- !query schema +struct +-- !query output +58 + + +-- !query +select timestampdiff(YEAR, date'2022-02-15', date'2023-02-15') +-- !query schema +struct +-- !query output +1 + + +-- !query +select timestampdiff(SECOND, date'2022-02-15', timestamp'2022-02-14 23:59:59') +-- !query schema +struct +-- !query output +-1 diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out index d8958d66cef4b..06e255a09c3e3 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 89 +-- Number of queries: 97 -- !query @@ -746,7 +746,7 @@ select unix_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -763,3 +763,67 @@ select from_csv('26/October/2015', 't Timestamp', map('timestampFormat', 'dd/MMM struct> -- !query output {"t":null} + + +-- !query +select timestampadd(MONTH, -1, timestamp'2022-02-14 01:02:03') +-- !query schema +struct +-- !query output +2022-01-14 01:02:03 + + +-- !query +select timestampadd(MINUTE, 58, timestamp'2022-02-14 01:02:03') +-- !query schema +struct +-- !query output +2022-02-14 02:00:03 + + +-- !query +select timestampadd(YEAR, 1, date'2022-02-15') +-- !query schema +struct +-- !query output +2023-02-15 00:00:00 + + +-- !query +select timestampadd(SECOND, -1, date'2022-02-15') +-- !query schema +struct +-- !query output +2022-02-14 23:59:59 + + +-- !query +select timestampdiff(MONTH, timestamp'2022-02-14 01:02:03', timestamp'2022-01-14 01:02:03') +-- !query schema +struct +-- !query output +-1 + + +-- !query +select timestampdiff(MINUTE, timestamp'2022-02-14 01:02:03', timestamp'2022-02-14 02:00:03') +-- !query schema +struct +-- !query output +58 + + +-- !query +select timestampdiff(YEAR, date'2022-02-15', date'2023-02-15') +-- !query schema +struct +-- !query output +1 + + +-- !query +select timestampdiff(SECOND, date'2022-02-15', timestamp'2022-02-14 23:59:59') +-- !query schema +struct +-- !query output +-1 diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index c1c13cdf276c0..c9a04c99b9fb2 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -719,7 +719,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -TRANSFORM does not support DISTINCT/ALL in inputs(line 1, pos 17) +The feature is not supported: TRANSFORM does not support DISTINCT/ALL in inputs(line 1, pos 17) == SQL == SELECT TRANSFORM(DISTINCT b, a, c) @@ -739,7 +739,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -TRANSFORM does not support DISTINCT/ALL in inputs(line 1, pos 17) +The feature is not supported: TRANSFORM does not support DISTINCT/ALL in inputs(line 1, pos 17) == SQL == SELECT TRANSFORM(ALL b, a, c) diff --git a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out index 47faeb3ce9ea4..f3c483cfafea8 100644 --- a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 29 +-- Number of queries: 49 -- !query @@ -233,3 +233,163 @@ SELECT try_divide(interval 106751991 day, 0.5) struct -- !query output NULL + + +-- !query +SELECT try_subtract(1, 1) +-- !query schema +struct +-- !query output +0 + + +-- !query +SELECT try_subtract(2147483647, -1) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_subtract(-2147483648, 1) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_subtract(9223372036854775807L, -1) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_subtract(-9223372036854775808L, 1) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_subtract(interval 2 year, interval 3 year) +-- !query schema +struct +-- !query output +-1-0 + + +-- 
!query +SELECT try_subtract(interval 3 second, interval 2 second) +-- !query schema +struct +-- !query output +0 00:00:01.000000000 + + +-- !query +SELECT try_subtract(interval 2147483647 month, interval -2 month) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_subtract(interval 106751991 day, interval -3 day) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(2, 3) +-- !query schema +struct +-- !query output +6 + + +-- !query +SELECT try_multiply(2147483647, -2) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(-2147483648, 2) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(9223372036854775807L, 2) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(-9223372036854775808L, -2) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(interval 2 year, 2) +-- !query schema +struct +-- !query output +4-0 + + +-- !query +SELECT try_multiply(interval 2 second, 2) +-- !query schema +struct +-- !query output +0 00:00:04.000000000 + + +-- !query +SELECT try_multiply(interval 2 year, 0) +-- !query schema +struct +-- !query output +0-0 + + +-- !query +SELECT try_multiply(interval 2 second, 0) +-- !query schema +struct +-- !query output +0 00:00:00.000000000 + + +-- !query +SELECT try_multiply(interval 2147483647 month, 2) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_multiply(interval 106751991 day, 2) +-- !query schema +struct +-- !query output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out index 14e941c074041..fd4f8b2c7a0e3 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -139,7 +139,7 @@ select to_timestamp('2018-01-01', a) from t struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -156,7 +156,7 @@ select to_unix_timestamp('2018-01-01', a) from t struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 
1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -173,7 +173,7 @@ select unix_timestamp('2018-01-01', a) from t struct<> -- !query output org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +You may get a different result due to the upgrading to Spark >= 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out index b75bd58d93c9f..09cf6ee218969 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 44 +-- Number of queries: 45 -- !query @@ -287,6 +287,14 @@ struct 4 +-- !query +SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest +-- !query schema +struct +-- !query output +49.5 107.94315227307379 + + -- !query SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out index 5db0f4dac54a7..d543c6a1bb742 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out @@ -380,7 +380,7 @@ SELECT every(udf(1)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'every(CAST(udf(cast(1 as string)) AS INT))' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7 +cannot resolve 'every(CAST(udf(cast(1 as string)) AS INT))' due to data type mismatch: argument 1 requires boolean type, however, 'CAST(udf(cast(1 as string)) AS INT)' is of int type.; line 1 pos 7 -- !query @@ -389,7 +389,7 @@ SELECT some(udf(1S)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'some(CAST(udf(cast(1 as string)) AS SMALLINT))' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 pos 7 +cannot resolve 'some(CAST(udf(cast(1 as string)) AS SMALLINT))' due to data type mismatch: argument 1 requires boolean type, however, 'CAST(udf(cast(1 as string)) AS SMALLINT)' is of smallint type.; line 1 pos 7 -- !query @@ -398,7 +398,7 @@ SELECT any(udf(1L)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'any(CAST(udf(cast(1 as string)) AS BIGINT))' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7 +cannot resolve 
'any(CAST(udf(cast(1 as string)) AS BIGINT))' due to data type mismatch: argument 1 requires boolean type, however, 'CAST(udf(cast(1 as string)) AS BIGINT)' is of bigint type.; line 1 pos 7 -- !query @@ -407,7 +407,7 @@ SELECT udf(every("true")) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 11 +cannot resolve 'every('true')' due to data type mismatch: argument 1 requires boolean type, however, ''true'' is of string type.; line 1 pos 11 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index d781245227ec4..d13411e333371 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -898,7 +898,7 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -The definition of window 'w' is repetitive(line 8, pos 0) +Invalid SQL syntax: The definition of window 'w' is repetitive.(line 8, pos 0) == SQL == SELECT diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/commits/0 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/commits/1 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/commits/1 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/commits/2 b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/commits/2 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/commits/2 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/metadata new file mode 100644 index 0000000000000..4691bccd0a792 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/metadata @@ -0,0 +1 @@ +{"id":"d4358946-170c-49a7-823b-d8e4e9126616"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/0 new file mode 100644 index 0000000000000..807d7b0063b96 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/0 @@ -0,0 +1,3 @@ 
+v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531292029003,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/1 new file mode 100644 index 0000000000000..cce541073fb4b --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531292030005,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/2 b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/2 new file mode 100644 index 0000000000000..dd9a1936aba55 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/2 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531292030005,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +2 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/4 b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/4 new file mode 100644 index 0000000000000..54a6fecef7d52 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/offsets/4 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531292030005,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +4 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/commits/0 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/metadata new file mode 100644 index 0000000000000..019111c307024 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/metadata @@ -0,0 +1 @@ +{"id":"dc9af96e-870c-4dc6-ad09-1b84b62caac3"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/offsets/0 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/offsets/0 new file mode 100644 index 0000000000000..d00e8a5a4134a --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1000,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/0/_metadata/schema b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/0/_metadata/schema new file mode 100644 index 0000000000000..d3948722c3258 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/0/_metadata/schema differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/1/1.delta new file mode 100644 index 0000000000000..2639d3211decf Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/2/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/4/1.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/.0.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/.0.crc new file mode 100644 index 0000000000000..1aee7033161ec Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/.0.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/.1.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/.1.crc new file mode 100644 index 0000000000000..1aee7033161ec Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/.1.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/0 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/1 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/commits/1 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/metadata new file mode 100644 index 0000000000000..81acb4439e8f5 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/metadata @@ -0,0 +1 @@ +{"id":"9538ada3-a233-4697-8b02-cc66250189a3"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/.0.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/.0.crc new file mode 100644 index 0000000000000..b8a9976585811 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/.0.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/.1.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/.1.crc new 
file mode 100644 index 0000000000000..81716485cf023 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/.1.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/0 new file mode 100644 index 0000000000000..852130a526e08 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1645693797622,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.streaming.stateStore.rocksdb.formatVersion":"5","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/1 new file mode 100644 index 0000000000000..2d894644897bf --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1645693802625,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.streaming.stateStore.rocksdb.formatVersion":"5","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/1.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/_metadata/.schema.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/_metadata/.schema.crc new file mode 100644 index 0000000000000..f03866c573c15 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/_metadata/.schema.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/_metadata/schema b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/_metadata/schema new file mode 100644 index 0000000000000..e4695f58d7de9 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/0/_metadata/schema differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/.2.delta.crc new file mode 100644 index 0000000000000..dc5c3a4905a5b Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/2.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/2.delta new file mode 100644 index 0000000000000..00c03b0f2aaa5 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/1/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/.2.delta.crc new file mode 100644 index 0000000000000..0df89359466b4 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/2.delta new file mode 100644 index 0000000000000..0a0f74c944036 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/.1.delta.crc new file mode 100644 index 0000000000000..fcb13666a0ad8 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/1.delta new file mode 100644 index 
0000000000000..4e033f8786aa8 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/.1.delta.crc new file mode 100644 index 0000000000000..eb2b6be4e5e55 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/1.delta new file mode 100644 index 0000000000000..7b6e9c175b8cf Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/state/0/4/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/commits/.0.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/commits/.0.crc new file mode 100644 index 0000000000000..1aee7033161ec Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/commits/.0.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/commits/0 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/commits/0 @@ -0,0 +1,2 @@ +v1 
+{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/metadata new file mode 100644 index 0000000000000..54698e5f8afa9 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/metadata @@ -0,0 +1 @@ +{"id":"b36205c7-696a-4fe9-86d4-a4efdf05795b"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/offsets/.0.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/offsets/.0.crc new file mode 100644 index 0000000000000..04523a6882fdb Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/offsets/.0.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/offsets/0 new file mode 100644 index 0000000000000..321a56f4d3707 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1000,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.streaming.stateStore.rocksdb.formatVersion":"5","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/_metadata/.schema.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/_metadata/.schema.crc new file mode 100644 index 0000000000000..4d339e472ac25 Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/_metadata/.schema.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/_metadata/schema b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/_metadata/schema new file mode 100644 index 0000000000000..bf902e50cf260 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/0/_metadata/schema differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/1/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/1/.1.delta.crc new file mode 100644 index 0000000000000..7029bc3ccdf2f Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/1/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/1/1.delta new file mode 100644 index 0000000000000..610e2c0250d4e Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/2/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/2/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/2/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/2/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/3/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/3/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/3/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/4/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/4/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/4/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/commits/.0.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/commits/.0.crc new file mode 100644 index 0000000000000..1aee7033161ec Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/commits/.0.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/commits/0 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/metadata new file mode 100644 index 0000000000000..fa78985cb8778 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/metadata @@ -0,0 +1 @@ +{"id":"f4795695-2b3e-4864-983a-f7bf52c0e29d"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/offsets/.0.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/offsets/.0.crc new file mode 100644 index 0000000000000..04523a6882fdb Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/offsets/.0.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/offsets/0 new file mode 100644 index 0000000000000..321a56f4d3707 --- /dev/null +++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1000,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.streaming.stateStore.rocksdb.formatVersion":"5","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/_metadata/.schema.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/_metadata/.schema.crc new file mode 100644 index 0000000000000..4d339e472ac25 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/_metadata/.schema.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/_metadata/schema b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/_metadata/schema new file mode 100644 index 0000000000000..bf902e50cf260 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/0/_metadata/schema differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/1/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/1/.1.delta.crc new file mode 100644 index 0000000000000..421f95ae9dd8b Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/1/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/1/1.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/1/1.delta new file mode 100644 index 0000000000000..2639d3211decf Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/2/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/2/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/2/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/2/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/3/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/3/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/3/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/4/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/4/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/4/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/.0.crc 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/.0.crc new file mode 100644 index 0000000000000..ba56986ebd219 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/.0.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/.1.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/.1.crc new file mode 100644 index 0000000000000..ba56986ebd219 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/.1.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/0 new file mode 100644 index 0000000000000..00b8a64995dde --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":11000} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/1 new file mode 100644 index 0000000000000..00b8a64995dde --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/commits/1 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":11000} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/metadata new file mode 100644 index 0000000000000..879dac88e351a --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/metadata @@ -0,0 +1 @@ +{"id":"c3d27d93-536b-49ce-a62f-f2777855a1fb"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/.0.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/.0.crc new file mode 100644 index 0000000000000..0d6e0a4778504 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/.0.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/.1.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/.1.crc new file mode 100644 index 0000000000000..24dcb52ef6098 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/.1.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/0 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/0 new file mode 100644 index 0000000000000..6f149ed4ec45c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1645760172709,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.streaming.stateStore.rocksdb.formatVersion":"5","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/1 new file mode 100644 index 0000000000000..4a6194c2002bd --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":11000,"batchTimestampMs":1645760174214,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.streaming.stateStore.rocksdb.formatVersion":"5","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/1.delta differ diff --git 
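Note for reviewers: the offsets/<batchId> hunks in these fixtures all follow the same three-line text layout visible above — a version marker ("v1"), a JSON metadata line carrying batchWatermarkMs, batchTimestampMs and the per-batch SQL confs (every fixture pins "spark.sql.shuffle.partitions":"5"), and one offset line per source — while each commits/<batchId> file is a version line plus {"nextBatchWatermarkMs":...}. Spark reads these through its internal OffsetSeqLog/CommitLog classes; the sketch below is only a hand-rolled illustration of that layout (the file path comes from the command line and is a placeholder), not code from Spark or from these tests.

    import scala.io.Source

    // Minimal sketch: dump one checkpoint offsets file, e.g. <checkpoint>/offsets/0.
    // Layout only; the real parsing happens inside Spark's OffsetSeqLog.
    object ReadOffsetFile {
      def main(args: Array[String]): Unit = {
        val src = Source.fromFile(args(0))
        try {
          val lines = src.getLines().toList
          val version = lines.head            // "v1"
          val metadataJson = lines(1)         // {"batchWatermarkMs":...,"batchTimestampMs":...,"conf":{...}}
          val sourceOffsets = lines.drop(2)   // one line per source ("0", "1", ...)
          println(s"version=$version")
          println(s"metadata=$metadataJson")
          sourceOffsets.zipWithIndex.foreach { case (off, i) => println(s"source $i offset: $off") }
        } finally src.close()
      }
    }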
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/_metadata/.schema.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/_metadata/.schema.crc new file mode 100644 index 0000000000000..3f3804f1999c0 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/_metadata/.schema.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/_metadata/schema b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/_metadata/schema new file mode 100644 index 0000000000000..871586884066b Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/0/_metadata/schema differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/.1.delta.crc new file mode 100644 index 0000000000000..9e684a6792e54 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/1.delta new file mode 100644 index 0000000000000..73c35f68c2f28 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/1/2.delta differ diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/.1.delta.crc new file mode 100644 index 0000000000000..816cff99cd156 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/1.delta new file mode 100644 index 0000000000000..3c6d389f04264 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/2.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/state/0/4/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/.0.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/.0.crc new file mode 100644 index 0000000000000..1aee7033161ec Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/.0.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/.1.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/.1.crc new file mode 100644 index 0000000000000..1aee7033161ec Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/.1.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/0 new file mode 100644 index 
0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/1 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/commits/1 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/metadata new file mode 100644 index 0000000000000..0831489d9d02d --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/metadata @@ -0,0 +1 @@ +{"id":"cd462130-c8fb-4212-8b08-4e1b9e10dbcf"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/.0.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/.0.crc new file mode 100644 index 0000000000000..b1cf4d310b245 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/.0.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/.1.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/.1.crc new file mode 100644 index 0000000000000..cf958a5259df3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/.1.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/0 new file mode 100644 index 0000000000000..523d0ce69165e --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1645692626085,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.streaming.stateStore.rocksdb.formatVersion":"5","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/1 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/1 new file mode 100644 index 0000000000000..f69d320e37d2f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1645692630152,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.streaming.stateStore.rocksdb.formatVersion":"5","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/_metadata/.schema.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/_metadata/.schema.crc new file mode 100644 index 0000000000000..701a0a87ad48a Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/_metadata/.schema.crc differ diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/_metadata/schema b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/_metadata/schema new file mode 100644 index 0000000000000..08ee320ef2421 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/0/_metadata/schema differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/.1.delta.crc new file mode 100644 index 0000000000000..f712e4290ad37 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/.2.delta.crc new file mode 100644 index 0000000000000..2a9f3595f24b5 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/1.delta new file mode 100644 index 0000000000000..f5faf01f4dc5c Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/2.delta new file mode 100644 index 0000000000000..ec3f1af46bd49 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/1/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/.2.delta.crc differ diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/.1.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/.2.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/.1.delta.crc new file mode 100644 index 0000000000000..cf1d68e2acee3 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/.1.delta.crc differ diff --git 
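Note for reviewers: judging by their names, these fixture directories (deduplication, two flatMapGroupsWithState variants, session window, streaming aggregation, each "-with-repartition"/"-repartition") appear to capture Spark 3.2.0 checkpoints of stateful queries that add an explicit repartition before the stateful operator, with spark.sql.shuffle.partitions pinned to 5 in the offset log. A minimal sketch of the kind of query that could produce such a checkpoint — using the deduplication case as the representative, a rate source, and a placeholder checkpoint path, none of which are taken from the actual test code — is:

    import org.apache.spark.sql.SparkSession

    object DedupWithRepartition {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("checkpoint-deduplication-with-repartition")
          .master("local[2]")
          .config("spark.sql.shuffle.partitions", "5") // matches the conf recorded in offsets/0
          .getOrCreate()

        val deduped = spark.readStream
          .format("rate")                    // illustrative source only
          .option("rowsPerSecond", "10")
          .load()
          .withWatermark("timestamp", "10 seconds")
          .repartition(5)                    // explicit shuffle before the stateful operator
          .dropDuplicates("value")

        val query = deduped.writeStream
          .format("console")
          .option("checkpointLocation", "/tmp/ckpt-dedup-with-repartition") // placeholder path
          .start()

        query.awaitTermination(10000)        // let a couple of micro-batches commit, then stop
        query.stop()
        spark.stop()
      }
    }

The session-window and aggregation fixtures would differ only in the stateful operator (e.g. groupBy(session_window(...)).count() or a plain groupBy(...).count()), and the flatMapGroupsWithState fixtures in using groupByKey(...).flatMapGroupsWithState(...).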
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/.2.delta.crc new file mode 100644 index 0000000000000..3ffbb7a9133b5 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/.2.delta.crc differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/2.delta new file mode 100644 index 0000000000000..7c8834f659bd9 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/state/0/4/2.delta differ diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q53.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q53.sf100/explain.txt index d100e73a4de24..42b83c9c7d830 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q53.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q53.sf100/explain.txt @@ -146,7 +146,7 @@ Arguments: [avg(_w0#25) windowspecdefinition(i_manufact_id#5, specifiedwindowfra (26) Filter [codegen id : 7] Input [4]: [i_manufact_id#5, sum_sales#24, _w0#25, avg_quarterly_sales#27] -Condition : (isnotnull(avg_quarterly_sales#27) AND ((avg_quarterly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_quarterly_sales#27) AND ((avg_quarterly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (27) Project [codegen id : 7] Output [3]: [i_manufact_id#5, sum_sales#24, avg_quarterly_sales#27] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q53/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q53/explain.txt index 2b7ace43773b6..e7ae5ce6dcfb7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q53/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q53/explain.txt @@ -146,7 +146,7 @@ Arguments: [avg(_w0#25) 
windowspecdefinition(i_manufact_id#5, specifiedwindowfra (26) Filter [codegen id : 7] Input [4]: [i_manufact_id#5, sum_sales#24, _w0#25, avg_quarterly_sales#27] -Condition : (isnotnull(avg_quarterly_sales#27) AND ((avg_quarterly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_quarterly_sales#27) AND ((avg_quarterly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (27) Project [codegen id : 7] Output [3]: [i_manufact_id#5, sum_sales#24, avg_quarterly_sales#27] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q59.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q59.sf100/explain.txt index 8f71448cb76b2..f260becf18e26 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q59.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q59.sf100/explain.txt @@ -283,7 +283,7 @@ Right keys [2]: [s_store_id2#74, (d_week_seq2#73 - 52)] Join condition: None (50) Project [codegen id : 10] -Output [10]: [s_store_name1#44, s_store_id1#46, d_week_seq1#45, CheckOverflow((promote_precision(sun_sales1#47) / promote_precision(sun_sales2#75)), DecimalType(37,20), true) AS (sun_sales1 / sun_sales2)#82, CheckOverflow((promote_precision(mon_sales1#48) / promote_precision(mon_sales2#76)), DecimalType(37,20), true) AS (mon_sales1 / mon_sales2)#83, CheckOverflow((promote_precision(tue_sales1#49) / promote_precision(tue_sales1#49)), DecimalType(37,20), true) AS (tue_sales1 / tue_sales1)#84, CheckOverflow((promote_precision(wed_sales1#50) / promote_precision(wed_sales2#77)), DecimalType(37,20), true) AS (wed_sales1 / wed_sales2)#85, CheckOverflow((promote_precision(thu_sales1#51) / promote_precision(thu_sales2#78)), DecimalType(37,20), true) AS (thu_sales1 / thu_sales2)#86, CheckOverflow((promote_precision(fri_sales1#52) / promote_precision(fri_sales2#79)), DecimalType(37,20), true) AS (fri_sales1 / fri_sales2)#87, CheckOverflow((promote_precision(sat_sales1#53) / promote_precision(sat_sales2#80)), DecimalType(37,20), true) AS (sat_sales1 / sat_sales2)#88] +Output [10]: [s_store_name1#44, s_store_id1#46, d_week_seq1#45, CheckOverflow((promote_precision(sun_sales1#47) / promote_precision(sun_sales2#75)), DecimalType(37,20)) AS (sun_sales1 / sun_sales2)#82, CheckOverflow((promote_precision(mon_sales1#48) / promote_precision(mon_sales2#76)), DecimalType(37,20)) AS (mon_sales1 / mon_sales2)#83, CheckOverflow((promote_precision(tue_sales1#49) / promote_precision(tue_sales1#49)), DecimalType(37,20)) AS (tue_sales1 / tue_sales1)#84, CheckOverflow((promote_precision(wed_sales1#50) / promote_precision(wed_sales2#77)), DecimalType(37,20)) AS (wed_sales1 / wed_sales2)#85, CheckOverflow((promote_precision(thu_sales1#51) / promote_precision(thu_sales2#78)), DecimalType(37,20)) AS (thu_sales1 / thu_sales2)#86, CheckOverflow((promote_precision(fri_sales1#52) / 
promote_precision(fri_sales2#79)), DecimalType(37,20)) AS (fri_sales1 / fri_sales2)#87, CheckOverflow((promote_precision(sat_sales1#53) / promote_precision(sat_sales2#80)), DecimalType(37,20)) AS (sat_sales1 / sat_sales2)#88] Input [18]: [s_store_name1#44, d_week_seq1#45, s_store_id1#46, sun_sales1#47, mon_sales1#48, tue_sales1#49, wed_sales1#50, thu_sales1#51, fri_sales1#52, sat_sales1#53, d_week_seq2#73, s_store_id2#74, sun_sales2#75, mon_sales2#76, wed_sales2#77, thu_sales2#78, fri_sales2#79, sat_sales2#80] (51) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q59/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q59/explain.txt index 8f71448cb76b2..f260becf18e26 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q59/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q59/explain.txt @@ -283,7 +283,7 @@ Right keys [2]: [s_store_id2#74, (d_week_seq2#73 - 52)] Join condition: None (50) Project [codegen id : 10] -Output [10]: [s_store_name1#44, s_store_id1#46, d_week_seq1#45, CheckOverflow((promote_precision(sun_sales1#47) / promote_precision(sun_sales2#75)), DecimalType(37,20), true) AS (sun_sales1 / sun_sales2)#82, CheckOverflow((promote_precision(mon_sales1#48) / promote_precision(mon_sales2#76)), DecimalType(37,20), true) AS (mon_sales1 / mon_sales2)#83, CheckOverflow((promote_precision(tue_sales1#49) / promote_precision(tue_sales1#49)), DecimalType(37,20), true) AS (tue_sales1 / tue_sales1)#84, CheckOverflow((promote_precision(wed_sales1#50) / promote_precision(wed_sales2#77)), DecimalType(37,20), true) AS (wed_sales1 / wed_sales2)#85, CheckOverflow((promote_precision(thu_sales1#51) / promote_precision(thu_sales2#78)), DecimalType(37,20), true) AS (thu_sales1 / thu_sales2)#86, CheckOverflow((promote_precision(fri_sales1#52) / promote_precision(fri_sales2#79)), DecimalType(37,20), true) AS (fri_sales1 / fri_sales2)#87, CheckOverflow((promote_precision(sat_sales1#53) / promote_precision(sat_sales2#80)), DecimalType(37,20), true) AS (sat_sales1 / sat_sales2)#88] +Output [10]: [s_store_name1#44, s_store_id1#46, d_week_seq1#45, CheckOverflow((promote_precision(sun_sales1#47) / promote_precision(sun_sales2#75)), DecimalType(37,20)) AS (sun_sales1 / sun_sales2)#82, CheckOverflow((promote_precision(mon_sales1#48) / promote_precision(mon_sales2#76)), DecimalType(37,20)) AS (mon_sales1 / mon_sales2)#83, CheckOverflow((promote_precision(tue_sales1#49) / promote_precision(tue_sales1#49)), DecimalType(37,20)) AS (tue_sales1 / tue_sales1)#84, CheckOverflow((promote_precision(wed_sales1#50) / promote_precision(wed_sales2#77)), DecimalType(37,20)) AS (wed_sales1 / wed_sales2)#85, CheckOverflow((promote_precision(thu_sales1#51) / promote_precision(thu_sales2#78)), DecimalType(37,20)) AS (thu_sales1 / thu_sales2)#86, CheckOverflow((promote_precision(fri_sales1#52) / promote_precision(fri_sales2#79)), DecimalType(37,20)) AS (fri_sales1 / fri_sales2)#87, CheckOverflow((promote_precision(sat_sales1#53) / promote_precision(sat_sales2#80)), DecimalType(37,20)) AS (sat_sales1 / sat_sales2)#88] Input [18]: [s_store_name1#44, d_week_seq1#45, s_store_id1#46, sun_sales1#47, mon_sales1#48, tue_sales1#49, wed_sales1#50, thu_sales1#51, fri_sales1#52, sat_sales1#53, d_week_seq2#73, s_store_id2#74, sun_sales2#75, mon_sales2#76, wed_sales2#77, thu_sales2#78, fri_sales2#79, sat_sales2#80] (51) TakeOrderedAndProject diff --git 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q63.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q63.sf100/explain.txt index 1e722cf779dab..698d6f41f8871 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q63.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q63.sf100/explain.txt @@ -146,7 +146,7 @@ Arguments: [avg(_w0#25) windowspecdefinition(i_manager_id#5, specifiedwindowfram (26) Filter [codegen id : 7] Input [4]: [i_manager_id#5, sum_sales#24, _w0#25, avg_monthly_sales#27] -Condition : (isnotnull(avg_monthly_sales#27) AND ((avg_monthly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_monthly_sales#27) AND ((avg_monthly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (27) Project [codegen id : 7] Output [3]: [i_manager_id#5, sum_sales#24, avg_monthly_sales#27] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q63/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q63/explain.txt index 35eaebb171a51..99146cf1d2829 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q63/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q63/explain.txt @@ -146,7 +146,7 @@ Arguments: [avg(_w0#25) windowspecdefinition(i_manager_id#5, specifiedwindowfram (26) Filter [codegen id : 7] Input [4]: [i_manager_id#5, sum_sales#24, _w0#25, avg_monthly_sales#27] -Condition : (isnotnull(avg_monthly_sales#27) AND ((avg_monthly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_monthly_sales#27) AND ((avg_monthly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (27) Project [codegen id : 7] Output [3]: [i_manager_id#5, sum_sales#24, avg_monthly_sales#27] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q65.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q65.sf100/explain.txt index 7066bd1ed142e..aabb4fe67f387 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q65.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q65.sf100/explain.txt @@ -158,7 +158,7 @@ 
Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)) (24) BroadcastHashJoin [codegen id : 7] Left keys [1]: [ss_store_sk#2] Right keys [1]: [ss_store_sk#13] -Join condition: (cast(revenue#11 as decimal(23,7)) <= CheckOverflow((0.100000 * promote_precision(ave#28)), DecimalType(23,7), true)) +Join condition: (cast(revenue#11 as decimal(23,7)) <= CheckOverflow((0.100000 * promote_precision(ave#28)), DecimalType(23,7))) (25) Project [codegen id : 7] Output [3]: [ss_store_sk#2, ss_item_sk#1, revenue#11] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q65/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q65/explain.txt index 02c9fdd520c10..019f4fa4c7076 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q65/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q65/explain.txt @@ -212,7 +212,7 @@ Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)) (36) BroadcastHashJoin [codegen id : 9] Left keys [1]: [ss_store_sk#4] Right keys [1]: [ss_store_sk#22] -Join condition: (cast(revenue#13 as decimal(23,7)) <= CheckOverflow((0.100000 * promote_precision(ave#37)), DecimalType(23,7), true)) +Join condition: (cast(revenue#13 as decimal(23,7)) <= CheckOverflow((0.100000 * promote_precision(ave#37)), DecimalType(23,7))) (37) Project [codegen id : 9] Output [6]: [s_store_name#2, i_item_desc#16, revenue#13, i_current_price#17, i_wholesale_cost#18, i_brand#19] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q89.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q89.sf100/explain.txt index e1b716bd2186e..8b19320021538 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q89.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q89.sf100/explain.txt @@ -141,7 +141,7 @@ Arguments: [avg(_w0#22) windowspecdefinition(i_category#15, i_brand#13, s_store_ (25) Filter [codegen id : 7] Input [9]: [i_category#15, i_class#14, i_brand#13, s_store_name#9, s_company_name#10, d_moy#7, sum_sales#21, _w0#22, avg_monthly_sales#24] -Condition : (isnotnull(avg_monthly_sales#24) AND (NOT (avg_monthly_sales#24 = 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_monthly_sales#24) AND (NOT (avg_monthly_sales#24 = 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (26) Project [codegen id : 7] Output [8]: [i_category#15, i_class#14, i_brand#13, s_store_name#9, s_company_name#10, d_moy#7, sum_sales#21, avg_monthly_sales#24] @@ -149,7 +149,7 @@ Input [9]: [i_category#15, i_class#14, i_brand#13, s_store_name#9, s_company_nam (27) TakeOrderedAndProject Input [8]: [i_category#15, i_class#14, i_brand#13, s_store_name#9, s_company_name#10, d_moy#7, sum_sales#21, avg_monthly_sales#24] -Arguments: 100, 
[CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, s_store_name#9 ASC NULLS FIRST], [i_category#15, i_class#14, i_brand#13, s_store_name#9, s_company_name#10, d_moy#7, sum_sales#21, avg_monthly_sales#24] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, s_store_name#9 ASC NULLS FIRST], [i_category#15, i_class#14, i_brand#13, s_store_name#9, s_company_name#10, d_moy#7, sum_sales#21, avg_monthly_sales#24] ===== Subqueries ===== diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q89/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q89/explain.txt index fe910f9157d15..5d3ea6d0cb7be 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q89/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q89/explain.txt @@ -141,7 +141,7 @@ Arguments: [avg(_w0#22) windowspecdefinition(i_category#4, i_brand#2, s_store_na (25) Filter [codegen id : 7] Input [9]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, _w0#22, avg_monthly_sales#24] -Condition : (isnotnull(avg_monthly_sales#24) AND (NOT (avg_monthly_sales#24 = 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_monthly_sales#24) AND (NOT (avg_monthly_sales#24 = 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (26) Project [codegen id : 7] Output [8]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] @@ -149,7 +149,7 @@ Input [9]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name# (27) TakeOrderedAndProject Input [8]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, s_store_name#14 ASC NULLS FIRST], [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, s_store_name#14 ASC NULLS FIRST], [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] ===== Subqueries ===== diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q98.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q98.sf100/explain.txt 
index 554005d706d3d..e630982cc606b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q98.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q98.sf100/explain.txt @@ -123,7 +123,7 @@ Input [8]: [i_item_desc#9, i_category#12, i_class#11, i_current_price#10, itemre Arguments: [sum(_w1#20) windowspecdefinition(i_class#11, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#22], [i_class#11] (22) Project [codegen id : 9] -Output [7]: [i_item_desc#9, i_category#12, i_class#11, i_current_price#10, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17), true) AS revenueratio#23, i_item_id#8] +Output [7]: [i_item_desc#9, i_category#12, i_class#11, i_current_price#10, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17)) AS revenueratio#23, i_item_id#8] Input [9]: [i_item_desc#9, i_category#12, i_class#11, i_current_price#10, itemrevenue#18, _w0#19, _w1#20, i_item_id#8, _we0#22] (23) Exchange diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q98/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q98/explain.txt index 66206ac265399..fc2390f392247 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q98/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-modified/q98/explain.txt @@ -108,7 +108,7 @@ Input [8]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemreve Arguments: [sum(_w1#19) windowspecdefinition(i_class#9, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#21], [i_class#9] (19) Project [codegen id : 6] -Output [7]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17), true) AS revenueratio#22, i_item_id#6] +Output [7]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17)) AS revenueratio#22, i_item_id#6] Input [9]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, _w0#18, _w1#19, i_item_id#6, _we0#21] (20) Exchange diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q1.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q1.sf100/explain.txt index f071af103792d..0ac812675e8f5 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q1.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q1.sf100/explain.txt @@ -154,7 +154,7 @@ Input [3]: [ctr_store_sk#12, sum#19, count#20] Keys [1]: [ctr_store_sk#12] Functions [1]: [avg(ctr_total_return#13)] Aggregate Attributes [1]: [avg(ctr_total_return#13)#22] -Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#13)#22) * 1.200000), DecimalType(24,7), true) AS (avg(ctr_total_return) * 1.2)#23, 
ctr_store_sk#12 AS ctr_store_sk#12#24] +Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#13)#22) * 1.200000), DecimalType(24,7)) AS (avg(ctr_total_return) * 1.2)#23, ctr_store_sk#12 AS ctr_store_sk#12#24] (23) Filter [codegen id : 6] Input [2]: [(avg(ctr_total_return) * 1.2)#23, ctr_store_sk#12#24] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q1/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q1/explain.txt index 33d072fb94143..bfdc1e926597b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q1/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q1/explain.txt @@ -151,7 +151,7 @@ Input [3]: [ctr_store_sk#12, sum#19, count#20] Keys [1]: [ctr_store_sk#12] Functions [1]: [avg(ctr_total_return#13)] Aggregate Attributes [1]: [avg(ctr_total_return#13)#22] -Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#13)#22) * 1.200000), DecimalType(24,7), true) AS (avg(ctr_total_return) * 1.2)#23, ctr_store_sk#12 AS ctr_store_sk#12#24] +Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#13)#22) * 1.200000), DecimalType(24,7)) AS (avg(ctr_total_return) * 1.2)#23, ctr_store_sk#12 AS ctr_store_sk#12#24] (23) Filter [codegen id : 6] Input [2]: [(avg(ctr_total_return) * 1.2)#23, ctr_store_sk#12#24] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11.sf100/explain.txt index 025e881f1bf9e..4d8179a75c6ea 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11.sf100/explain.txt @@ -150,7 +150,7 @@ Input [12]: [ss_customer_sk#1, ss_ext_discount_amt#2, ss_ext_list_price#3, d_yea (16) HashAggregate [codegen id : 6] Input [10]: [c_customer_id#10, c_first_name#11, c_last_name#12, c_preferred_cust_flag#13, c_birth_country#14, c_login#15, c_email_address#16, ss_ext_discount_amt#2, ss_ext_list_price#3, d_year#7] Keys [8]: [c_customer_id#10, c_first_name#11, c_last_name#12, d_year#7, c_preferred_cust_flag#13, c_birth_country#14, c_login#15, c_email_address#16] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#18] Results [9]: [c_customer_id#10, c_first_name#11, c_last_name#12, d_year#7, c_preferred_cust_flag#13, c_birth_country#14, c_login#15, c_email_address#16, sum#19] @@ -161,9 +161,9 @@ Arguments: hashpartitioning(c_customer_id#10, c_first_name#11, c_last_name#12, d (18) HashAggregate [codegen id : 7] Input [9]: [c_customer_id#10, c_first_name#11, c_last_name#12, d_year#7, c_preferred_cust_flag#13, c_birth_country#14, c_login#15, c_email_address#16, sum#19] Keys [8]: [c_customer_id#10, c_first_name#11, c_last_name#12, d_year#7, c_preferred_cust_flag#13, c_birth_country#14, c_login#15, c_email_address#16] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), 
DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2), true)))#21] -Results [2]: [c_customer_id#10 AS customer_id#22, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2), true)))#21,18,2) AS year_total#23] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2))))#21] +Results [2]: [c_customer_id#10 AS customer_id#22, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2))))#21,18,2) AS year_total#23] (19) Filter [codegen id : 7] Input [2]: [customer_id#22, year_total#23] @@ -231,7 +231,7 @@ Input [12]: [ss_customer_sk#25, ss_ext_discount_amt#26, ss_ext_list_price#27, d_ (34) HashAggregate [codegen id : 14] Input [10]: [c_customer_id#34, c_first_name#35, c_last_name#36, c_preferred_cust_flag#37, c_birth_country#38, c_login#39, c_email_address#40, ss_ext_discount_amt#26, ss_ext_list_price#27, d_year#31] Keys [8]: [c_customer_id#34, c_first_name#35, c_last_name#36, d_year#31, c_preferred_cust_flag#37, c_birth_country#38, c_login#39, c_email_address#40] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#41] Results [9]: [c_customer_id#34, c_first_name#35, c_last_name#36, d_year#31, c_preferred_cust_flag#37, c_birth_country#38, c_login#39, c_email_address#40, sum#42] @@ -242,9 +242,9 @@ Arguments: hashpartitioning(c_customer_id#34, c_first_name#35, c_last_name#36, d (36) HashAggregate [codegen id : 15] Input [9]: [c_customer_id#34, c_first_name#35, c_last_name#36, d_year#31, c_preferred_cust_flag#37, c_birth_country#38, c_login#39, c_email_address#40, sum#42] Keys [8]: [c_customer_id#34, c_first_name#35, c_last_name#36, d_year#31, c_preferred_cust_flag#37, c_birth_country#38, c_login#39, c_email_address#40] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2), true)))#21] -Results [3]: [c_customer_id#34 AS customer_id#44, c_preferred_cust_flag#37 AS customer_preferred_cust_flag#45, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), 
DecimalType(8,2), true)))#21,18,2) AS year_total#46] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2))))#21] +Results [3]: [c_customer_id#34 AS customer_id#44, c_preferred_cust_flag#37 AS customer_preferred_cust_flag#45, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2))))#21,18,2) AS year_total#46] (37) Exchange Input [3]: [customer_id#44, customer_preferred_cust_flag#45, year_total#46] @@ -317,7 +317,7 @@ Input [12]: [ws_bill_customer_sk#48, ws_ext_discount_amt#49, ws_ext_list_price#5 (53) HashAggregate [codegen id : 23] Input [10]: [c_customer_id#56, c_first_name#57, c_last_name#58, c_preferred_cust_flag#59, c_birth_country#60, c_login#61, c_email_address#62, ws_ext_discount_amt#49, ws_ext_list_price#50, d_year#53] Keys [8]: [c_customer_id#56, c_first_name#57, c_last_name#58, c_preferred_cust_flag#59, c_birth_country#60, c_login#61, c_email_address#62, d_year#53] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#50 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#49 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#50 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#49 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#63] Results [9]: [c_customer_id#56, c_first_name#57, c_last_name#58, c_preferred_cust_flag#59, c_birth_country#60, c_login#61, c_email_address#62, d_year#53, sum#64] @@ -328,9 +328,9 @@ Arguments: hashpartitioning(c_customer_id#56, c_first_name#57, c_last_name#58, c (55) HashAggregate [codegen id : 24] Input [9]: [c_customer_id#56, c_first_name#57, c_last_name#58, c_preferred_cust_flag#59, c_birth_country#60, c_login#61, c_email_address#62, d_year#53, sum#64] Keys [8]: [c_customer_id#56, c_first_name#57, c_last_name#58, c_preferred_cust_flag#59, c_birth_country#60, c_login#61, c_email_address#62, d_year#53] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#50 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#49 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#50 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#49 as decimal(8,2)))), DecimalType(8,2), true)))#66] -Results [2]: [c_customer_id#56 AS customer_id#67, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#50 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#49 as decimal(8,2)))), DecimalType(8,2), true)))#66,18,2) AS year_total#68] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#50 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#49 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#50 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#49 as decimal(8,2)))), 
DecimalType(8,2))))#66] +Results [2]: [c_customer_id#56 AS customer_id#67, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#50 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#49 as decimal(8,2)))), DecimalType(8,2))))#66,18,2) AS year_total#68] (56) Filter [codegen id : 24] Input [2]: [customer_id#67, year_total#68] @@ -407,7 +407,7 @@ Input [12]: [ws_bill_customer_sk#70, ws_ext_discount_amt#71, ws_ext_list_price#7 (73) HashAggregate [codegen id : 32] Input [10]: [c_customer_id#78, c_first_name#79, c_last_name#80, c_preferred_cust_flag#81, c_birth_country#82, c_login#83, c_email_address#84, ws_ext_discount_amt#71, ws_ext_list_price#72, d_year#75] Keys [8]: [c_customer_id#78, c_first_name#79, c_last_name#80, c_preferred_cust_flag#81, c_birth_country#82, c_login#83, c_email_address#84, d_year#75] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#72 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#71 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#72 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#71 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#85] Results [9]: [c_customer_id#78, c_first_name#79, c_last_name#80, c_preferred_cust_flag#81, c_birth_country#82, c_login#83, c_email_address#84, d_year#75, sum#86] @@ -418,9 +418,9 @@ Arguments: hashpartitioning(c_customer_id#78, c_first_name#79, c_last_name#80, c (75) HashAggregate [codegen id : 33] Input [9]: [c_customer_id#78, c_first_name#79, c_last_name#80, c_preferred_cust_flag#81, c_birth_country#82, c_login#83, c_email_address#84, d_year#75, sum#86] Keys [8]: [c_customer_id#78, c_first_name#79, c_last_name#80, c_preferred_cust_flag#81, c_birth_country#82, c_login#83, c_email_address#84, d_year#75] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#72 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#71 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#72 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#71 as decimal(8,2)))), DecimalType(8,2), true)))#66] -Results [2]: [c_customer_id#78 AS customer_id#88, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#72 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#71 as decimal(8,2)))), DecimalType(8,2), true)))#66,18,2) AS year_total#89] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#72 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#71 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#72 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#71 as decimal(8,2)))), DecimalType(8,2))))#66] +Results [2]: [c_customer_id#78 AS customer_id#88, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#72 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#71 as decimal(8,2)))), DecimalType(8,2))))#66,18,2) AS year_total#89] (76) Exchange Input [2]: [customer_id#88, year_total#89] @@ -433,7 +433,7 @@ Arguments: [customer_id#88 ASC NULLS FIRST], false, 0 (78) SortMergeJoin [codegen id : 35] Left keys [1]: [customer_id#22] Right keys [1]: 
[customer_id#88] -Join condition: (CASE WHEN (year_total#68 > 0.00) THEN CheckOverflow((promote_precision(year_total#89) / promote_precision(year_total#68)), DecimalType(38,20), true) END > CASE WHEN (year_total#23 > 0.00) THEN CheckOverflow((promote_precision(year_total#46) / promote_precision(year_total#23)), DecimalType(38,20), true) END) +Join condition: (CASE WHEN (year_total#68 > 0.00) THEN CheckOverflow((promote_precision(year_total#89) / promote_precision(year_total#68)), DecimalType(38,20)) END > CASE WHEN (year_total#23 > 0.00) THEN CheckOverflow((promote_precision(year_total#46) / promote_precision(year_total#23)), DecimalType(38,20)) END) (79) Project [codegen id : 35] Output [1]: [customer_preferred_cust_flag#45] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11.sf100/simplified.txt index eed9d7158c108..ff149df17d8f4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11.sf100/simplified.txt @@ -17,7 +17,7 @@ TakeOrderedAndProject [customer_preferred_cust_flag] Exchange [customer_id] #1 WholeStageCodegen (7) Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address] #2 WholeStageCodegen (6) @@ -61,7 +61,7 @@ TakeOrderedAndProject [customer_preferred_cust_flag] InputAdapter Exchange [customer_id] #6 WholeStageCodegen (15) - HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,customer_preferred_cust_flag,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,customer_preferred_cust_flag,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address] #7 WholeStageCodegen (14) @@ -101,7 +101,7 @@ TakeOrderedAndProject [customer_preferred_cust_flag] Exchange [customer_id] #10 WholeStageCodegen (24) Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - 
promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #11 WholeStageCodegen (23) @@ -134,7 +134,7 @@ TakeOrderedAndProject [customer_preferred_cust_flag] InputAdapter Exchange [customer_id] #13 WholeStageCodegen (33) - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #14 WholeStageCodegen (32) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11/explain.txt index 87c6b6f7123fe..8cb7c021fb3ea 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11/explain.txt @@ -130,7 +130,7 @@ Input [12]: [c_customer_id#2, c_first_name#3, c_last_name#4, c_preferred_cust_fl (13) HashAggregate [codegen id : 3] Input [10]: [c_customer_id#2, c_first_name#3, c_last_name#4, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, ss_ext_discount_amt#10, ss_ext_list_price#11, d_year#16] Keys [8]: [c_customer_id#2, c_first_name#3, c_last_name#4, d_year#16, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#17] Results [9]: [c_customer_id#2, c_first_name#3, c_last_name#4, d_year#16, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, sum#18] @@ -141,9 +141,9 @@ Arguments: hashpartitioning(c_customer_id#2, c_first_name#3, c_last_name#4, d_ye (15) HashAggregate [codegen id : 16] Input [9]: [c_customer_id#2, c_first_name#3, c_last_name#4, d_year#16, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, sum#18] Keys [8]: [c_customer_id#2, c_first_name#3, c_last_name#4, d_year#16, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8] -Functions [1]: 
[sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2), true)))#20] -Results [2]: [c_customer_id#2 AS customer_id#21, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2), true)))#20,18,2) AS year_total#22] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2))))#20] +Results [2]: [c_customer_id#2 AS customer_id#21, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2))))#20,18,2) AS year_total#22] (16) Filter [codegen id : 16] Input [2]: [customer_id#21, year_total#22] @@ -206,7 +206,7 @@ Input [12]: [c_customer_id#24, c_first_name#25, c_last_name#26, c_preferred_cust (29) HashAggregate [codegen id : 6] Input [10]: [c_customer_id#24, c_first_name#25, c_last_name#26, c_preferred_cust_flag#27, c_birth_country#28, c_login#29, c_email_address#30, ss_ext_discount_amt#32, ss_ext_list_price#33, d_year#38] Keys [8]: [c_customer_id#24, c_first_name#25, c_last_name#26, d_year#38, c_preferred_cust_flag#27, c_birth_country#28, c_login#29, c_email_address#30] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#39] Results [9]: [c_customer_id#24, c_first_name#25, c_last_name#26, d_year#38, c_preferred_cust_flag#27, c_birth_country#28, c_login#29, c_email_address#30, sum#40] @@ -217,9 +217,9 @@ Arguments: hashpartitioning(c_customer_id#24, c_first_name#25, c_last_name#26, d (31) HashAggregate [codegen id : 7] Input [9]: [c_customer_id#24, c_first_name#25, c_last_name#26, d_year#38, c_preferred_cust_flag#27, c_birth_country#28, c_login#29, c_email_address#30, sum#40] Keys [8]: [c_customer_id#24, c_first_name#25, c_last_name#26, d_year#38, c_preferred_cust_flag#27, c_birth_country#28, c_login#29, c_email_address#30] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2), true)))#20] -Results [3]: [c_customer_id#24 AS customer_id#42, c_preferred_cust_flag#27 AS customer_preferred_cust_flag#43, 
MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2), true)))#20,18,2) AS year_total#44] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2))))#20] +Results [3]: [c_customer_id#24 AS customer_id#42, c_preferred_cust_flag#27 AS customer_preferred_cust_flag#43, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2))))#20,18,2) AS year_total#44] (32) BroadcastExchange Input [3]: [customer_id#42, customer_preferred_cust_flag#43, year_total#44] @@ -291,7 +291,7 @@ Input [12]: [c_customer_id#47, c_first_name#48, c_last_name#49, c_preferred_cust (47) HashAggregate [codegen id : 10] Input [10]: [c_customer_id#47, c_first_name#48, c_last_name#49, c_preferred_cust_flag#50, c_birth_country#51, c_login#52, c_email_address#53, ws_ext_discount_amt#55, ws_ext_list_price#56, d_year#60] Keys [8]: [c_customer_id#47, c_first_name#48, c_last_name#49, c_preferred_cust_flag#50, c_birth_country#51, c_login#52, c_email_address#53, d_year#60] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#56 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#55 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#56 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#55 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#61] Results [9]: [c_customer_id#47, c_first_name#48, c_last_name#49, c_preferred_cust_flag#50, c_birth_country#51, c_login#52, c_email_address#53, d_year#60, sum#62] @@ -302,9 +302,9 @@ Arguments: hashpartitioning(c_customer_id#47, c_first_name#48, c_last_name#49, c (49) HashAggregate [codegen id : 11] Input [9]: [c_customer_id#47, c_first_name#48, c_last_name#49, c_preferred_cust_flag#50, c_birth_country#51, c_login#52, c_email_address#53, d_year#60, sum#62] Keys [8]: [c_customer_id#47, c_first_name#48, c_last_name#49, c_preferred_cust_flag#50, c_birth_country#51, c_login#52, c_email_address#53, d_year#60] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#56 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#55 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#56 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#55 as decimal(8,2)))), DecimalType(8,2), true)))#64] -Results [2]: [c_customer_id#47 AS customer_id#65, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#56 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#55 as decimal(8,2)))), DecimalType(8,2), true)))#64,18,2) AS year_total#66] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#56 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#55 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate 
Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#56 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#55 as decimal(8,2)))), DecimalType(8,2))))#64] +Results [2]: [c_customer_id#47 AS customer_id#65, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#56 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#55 as decimal(8,2)))), DecimalType(8,2))))#64,18,2) AS year_total#66] (50) Filter [codegen id : 11] Input [2]: [customer_id#65, year_total#66] @@ -380,7 +380,7 @@ Input [12]: [c_customer_id#69, c_first_name#70, c_last_name#71, c_preferred_cust (66) HashAggregate [codegen id : 14] Input [10]: [c_customer_id#69, c_first_name#70, c_last_name#71, c_preferred_cust_flag#72, c_birth_country#73, c_login#74, c_email_address#75, ws_ext_discount_amt#77, ws_ext_list_price#78, d_year#82] Keys [8]: [c_customer_id#69, c_first_name#70, c_last_name#71, c_preferred_cust_flag#72, c_birth_country#73, c_login#74, c_email_address#75, d_year#82] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#78 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#77 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#78 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#77 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#83] Results [9]: [c_customer_id#69, c_first_name#70, c_last_name#71, c_preferred_cust_flag#72, c_birth_country#73, c_login#74, c_email_address#75, d_year#82, sum#84] @@ -391,9 +391,9 @@ Arguments: hashpartitioning(c_customer_id#69, c_first_name#70, c_last_name#71, c (68) HashAggregate [codegen id : 15] Input [9]: [c_customer_id#69, c_first_name#70, c_last_name#71, c_preferred_cust_flag#72, c_birth_country#73, c_login#74, c_email_address#75, d_year#82, sum#84] Keys [8]: [c_customer_id#69, c_first_name#70, c_last_name#71, c_preferred_cust_flag#72, c_birth_country#73, c_login#74, c_email_address#75, d_year#82] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#78 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#77 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#78 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#77 as decimal(8,2)))), DecimalType(8,2), true)))#64] -Results [2]: [c_customer_id#69 AS customer_id#86, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#78 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#77 as decimal(8,2)))), DecimalType(8,2), true)))#64,18,2) AS year_total#87] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#78 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#77 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#78 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#77 as decimal(8,2)))), DecimalType(8,2))))#64] +Results [2]: [c_customer_id#69 AS customer_id#86, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#78 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#77 as decimal(8,2)))), DecimalType(8,2))))#64,18,2) AS year_total#87] (69) BroadcastExchange Input [2]: [customer_id#86, 
year_total#87] @@ -402,7 +402,7 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (70) BroadcastHashJoin [codegen id : 16] Left keys [1]: [customer_id#21] Right keys [1]: [customer_id#86] -Join condition: (CASE WHEN (year_total#66 > 0.00) THEN CheckOverflow((promote_precision(year_total#87) / promote_precision(year_total#66)), DecimalType(38,20), true) END > CASE WHEN (year_total#22 > 0.00) THEN CheckOverflow((promote_precision(year_total#44) / promote_precision(year_total#22)), DecimalType(38,20), true) END) +Join condition: (CASE WHEN (year_total#66 > 0.00) THEN CheckOverflow((promote_precision(year_total#87) / promote_precision(year_total#66)), DecimalType(38,20)) END > CASE WHEN (year_total#22 > 0.00) THEN CheckOverflow((promote_precision(year_total#44) / promote_precision(year_total#22)), DecimalType(38,20)) END) (71) Project [codegen id : 16] Output [1]: [customer_preferred_cust_flag#43] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11/simplified.txt index e9c0faa7491a0..6e80ebc5a038d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q11/simplified.txt @@ -7,7 +7,7 @@ TakeOrderedAndProject [customer_preferred_cust_flag] Project [customer_id,year_total,customer_preferred_cust_flag,year_total] BroadcastHashJoin [customer_id,customer_id] Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address] #1 WholeStageCodegen (3) @@ -39,7 +39,7 @@ TakeOrderedAndProject [customer_preferred_cust_flag] InputAdapter BroadcastExchange #4 WholeStageCodegen (7) - HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,customer_preferred_cust_flag,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,customer_preferred_cust_flag,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address] #5 WholeStageCodegen (6) @@ -72,7 +72,7 @@ TakeOrderedAndProject [customer_preferred_cust_flag] BroadcastExchange #8 WholeStageCodegen (11) Filter 
[year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #9 WholeStageCodegen (10) @@ -98,7 +98,7 @@ TakeOrderedAndProject [customer_preferred_cust_flag] InputAdapter BroadcastExchange #11 WholeStageCodegen (15) - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #12 WholeStageCodegen (14) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q12.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q12.sf100/explain.txt index 64ee24cf9435c..0f0b678bb7074 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q12.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q12.sf100/explain.txt @@ -121,7 +121,7 @@ Input [8]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrev Arguments: [sum(_w1#20) windowspecdefinition(i_class#10, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#22], [i_class#10] (22) Project [codegen id : 9] -Output [7]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17), true) AS revenueratio#23, i_item_id#7] +Output [7]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17)) AS revenueratio#23, i_item_id#7] Input [9]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, _w0#19, _w1#20, i_item_id#7, _we0#22] (23) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q12/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q12/explain.txt index 306ecd52c1a3b..0b4dfea762918 100644 --- 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q12/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q12/explain.txt @@ -106,7 +106,7 @@ Input [8]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemreve Arguments: [sum(_w1#19) windowspecdefinition(i_class#9, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#21], [i_class#9] (19) Project [codegen id : 6] -Output [7]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17), true) AS revenueratio#22, i_item_id#6] +Output [7]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17)) AS revenueratio#22, i_item_id#6] Input [9]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, _w0#18, _w1#19, i_item_id#6, _we0#21] (20) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q13.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q13.sf100/explain.txt index 7c4e7222a52e7..9d6b17e613ef1 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q13.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q13.sf100/explain.txt @@ -81,7 +81,7 @@ Input [13]: [ss_cdemo_sk#1, ss_hdemo_sk#2, ss_addr_sk#3, ss_store_sk#4, ss_quant Output [2]: [hd_demo_sk#16, hd_dep_count#17] Batched: true Location [not included in comparison]/{warehouse_dir}/household_demographics] -PushedFilters: [IsNotNull(hd_demo_sk), Or(Or(EqualTo(hd_dep_count,3),EqualTo(hd_dep_count,1)),EqualTo(hd_dep_count,1))] +PushedFilters: [IsNotNull(hd_demo_sk), Or(EqualTo(hd_dep_count,3),EqualTo(hd_dep_count,1))] ReadSchema: struct (11) ColumnarToRow [codegen id : 2] @@ -89,7 +89,7 @@ Input [2]: [hd_demo_sk#16, hd_dep_count#17] (12) Filter [codegen id : 2] Input [2]: [hd_demo_sk#16, hd_dep_count#17] -Condition : (isnotnull(hd_demo_sk#16) AND (((hd_dep_count#17 = 3) OR (hd_dep_count#17 = 1)) OR (hd_dep_count#17 = 1))) +Condition : (isnotnull(hd_demo_sk#16) AND ((hd_dep_count#17 = 3) OR (hd_dep_count#17 = 1))) (13) BroadcastExchange Input [2]: [hd_demo_sk#16, hd_dep_count#17] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q13/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q13/explain.txt index 31142b18a09fe..59e8cf7c4d063 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q13/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q13/explain.txt @@ -151,7 +151,7 @@ Input [9]: [ss_cdemo_sk#1, ss_hdemo_sk#2, ss_quantity#5, ss_sales_price#6, ss_ex Output [2]: [hd_demo_sk#23, hd_dep_count#24] Batched: true Location [not included in comparison]/{warehouse_dir}/household_demographics] -PushedFilters: [IsNotNull(hd_demo_sk), Or(Or(EqualTo(hd_dep_count,3),EqualTo(hd_dep_count,1)),EqualTo(hd_dep_count,1))] +PushedFilters: [IsNotNull(hd_demo_sk), Or(EqualTo(hd_dep_count,3),EqualTo(hd_dep_count,1))] ReadSchema: struct (27) ColumnarToRow [codegen id : 5] @@ -159,7 +159,7 @@ Input [2]: [hd_demo_sk#23, 
hd_dep_count#24] (28) Filter [codegen id : 5] Input [2]: [hd_demo_sk#23, hd_dep_count#24] -Condition : (isnotnull(hd_demo_sk#23) AND (((hd_dep_count#24 = 3) OR (hd_dep_count#24 = 1)) OR (hd_dep_count#24 = 1))) +Condition : (isnotnull(hd_demo_sk#23) AND ((hd_dep_count#24 = 3) OR (hd_dep_count#24 = 1))) (29) BroadcastExchange Input [2]: [hd_demo_sk#23, hd_dep_count#24] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt index 536a1cc04222f..4105a94131dda 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt @@ -1,130 +1,127 @@ == Physical Plan == -TakeOrderedAndProject (126) -+- * HashAggregate (125) - +- Exchange (124) - +- * HashAggregate (123) - +- * Expand (122) - +- Union (121) - :- * Project (82) - : +- * Filter (81) - : +- * HashAggregate (80) - : +- Exchange (79) - : +- * HashAggregate (78) - : +- * Project (77) - : +- * BroadcastHashJoin Inner BuildRight (76) - : :- * Project (66) - : : +- * BroadcastHashJoin Inner BuildRight (65) - : : :- * SortMergeJoin LeftSemi (63) +TakeOrderedAndProject (123) ++- * HashAggregate (122) + +- Exchange (121) + +- * HashAggregate (120) + +- * Expand (119) + +- Union (118) + :- * Project (79) + : +- * Filter (78) + : +- * HashAggregate (77) + : +- Exchange (76) + : +- * HashAggregate (75) + : +- * Project (74) + : +- * BroadcastHashJoin Inner BuildRight (73) + : :- * Project (63) + : : +- * BroadcastHashJoin Inner BuildRight (62) + : : :- * SortMergeJoin LeftSemi (60) : : : :- * Sort (5) : : : : +- Exchange (4) : : : : +- * Filter (3) : : : : +- * ColumnarToRow (2) : : : : +- Scan parquet default.store_sales (1) - : : : +- * Sort (62) - : : : +- Exchange (61) - : : : +- * Project (60) - : : : +- * BroadcastHashJoin Inner BuildRight (59) + : : : +- * Sort (59) + : : : +- Exchange (58) + : : : +- * Project (57) + : : : +- * BroadcastHashJoin Inner BuildRight (56) : : : :- * Filter (8) : : : : +- * ColumnarToRow (7) : : : : +- Scan parquet default.item (6) - : : : +- BroadcastExchange (58) - : : : +- * HashAggregate (57) - : : : +- Exchange (56) - : : : +- * HashAggregate (55) - : : : +- * SortMergeJoin LeftSemi (54) - : : : :- * Sort (42) - : : : : +- Exchange (41) - : : : : +- * HashAggregate (40) - : : : : +- Exchange (39) - : : : : +- * HashAggregate (38) - : : : : +- * Project (37) - : : : : +- * BroadcastHashJoin Inner BuildRight (36) - : : : : :- * Project (14) - : : : : : +- * BroadcastHashJoin Inner BuildRight (13) - : : : : : :- * Filter (11) - : : : : : : +- * ColumnarToRow (10) - : : : : : : +- Scan parquet default.store_sales (9) - : : : : : +- ReusedExchange (12) - : : : : +- BroadcastExchange (35) - : : : : +- * SortMergeJoin LeftSemi (34) - : : : : :- * Sort (19) - : : : : : +- Exchange (18) - : : : : : +- * Filter (17) - : : : : : +- * ColumnarToRow (16) - : : : : : +- Scan parquet default.item (15) - : : : : +- * Sort (33) - : : : : +- Exchange (32) - : : : : +- * Project (31) - : : : : +- * BroadcastHashJoin Inner BuildRight (30) - : : : : :- * Project (25) - : : : : : +- * BroadcastHashJoin Inner BuildRight (24) - : : : : : :- * Filter (22) - : : : : : : +- * ColumnarToRow (21) - : : : : : : +- Scan parquet default.catalog_sales (20) - : : : : : +- ReusedExchange (23) - : : : : +- BroadcastExchange (29) - : : : : +- * Filter (28) - : : : : +- 
* ColumnarToRow (27) - : : : : +- Scan parquet default.item (26) - : : : +- * Sort (53) - : : : +- Exchange (52) - : : : +- * Project (51) - : : : +- * BroadcastHashJoin Inner BuildRight (50) - : : : :- * Project (48) - : : : : +- * BroadcastHashJoin Inner BuildRight (47) - : : : : :- * Filter (45) - : : : : : +- * ColumnarToRow (44) - : : : : : +- Scan parquet default.web_sales (43) - : : : : +- ReusedExchange (46) - : : : +- ReusedExchange (49) - : : +- ReusedExchange (64) - : +- BroadcastExchange (75) - : +- * SortMergeJoin LeftSemi (74) - : :- * Sort (71) - : : +- Exchange (70) - : : +- * Filter (69) - : : +- * ColumnarToRow (68) - : : +- Scan parquet default.item (67) - : +- * Sort (73) - : +- ReusedExchange (72) - :- * Project (101) - : +- * Filter (100) - : +- * HashAggregate (99) - : +- Exchange (98) - : +- * HashAggregate (97) - : +- * Project (96) - : +- * BroadcastHashJoin Inner BuildRight (95) - : :- * Project (93) - : : +- * BroadcastHashJoin Inner BuildRight (92) - : : :- * SortMergeJoin LeftSemi (90) - : : : :- * Sort (87) - : : : : +- Exchange (86) - : : : : +- * Filter (85) - : : : : +- * ColumnarToRow (84) - : : : : +- Scan parquet default.catalog_sales (83) - : : : +- * Sort (89) - : : : +- ReusedExchange (88) - : : +- ReusedExchange (91) - : +- ReusedExchange (94) - +- * Project (120) - +- * Filter (119) - +- * HashAggregate (118) - +- Exchange (117) - +- * HashAggregate (116) - +- * Project (115) - +- * BroadcastHashJoin Inner BuildRight (114) - :- * Project (112) - : +- * BroadcastHashJoin Inner BuildRight (111) - : :- * SortMergeJoin LeftSemi (109) - : : :- * Sort (106) - : : : +- Exchange (105) - : : : +- * Filter (104) - : : : +- * ColumnarToRow (103) - : : : +- Scan parquet default.web_sales (102) - : : +- * Sort (108) - : : +- ReusedExchange (107) - : +- ReusedExchange (110) - +- ReusedExchange (113) + : : : +- BroadcastExchange (55) + : : : +- * SortMergeJoin LeftSemi (54) + : : : :- * Sort (42) + : : : : +- Exchange (41) + : : : : +- * HashAggregate (40) + : : : : +- Exchange (39) + : : : : +- * HashAggregate (38) + : : : : +- * Project (37) + : : : : +- * BroadcastHashJoin Inner BuildRight (36) + : : : : :- * Project (14) + : : : : : +- * BroadcastHashJoin Inner BuildRight (13) + : : : : : :- * Filter (11) + : : : : : : +- * ColumnarToRow (10) + : : : : : : +- Scan parquet default.store_sales (9) + : : : : : +- ReusedExchange (12) + : : : : +- BroadcastExchange (35) + : : : : +- * SortMergeJoin LeftSemi (34) + : : : : :- * Sort (19) + : : : : : +- Exchange (18) + : : : : : +- * Filter (17) + : : : : : +- * ColumnarToRow (16) + : : : : : +- Scan parquet default.item (15) + : : : : +- * Sort (33) + : : : : +- Exchange (32) + : : : : +- * Project (31) + : : : : +- * BroadcastHashJoin Inner BuildRight (30) + : : : : :- * Project (25) + : : : : : +- * BroadcastHashJoin Inner BuildRight (24) + : : : : : :- * Filter (22) + : : : : : : +- * ColumnarToRow (21) + : : : : : : +- Scan parquet default.catalog_sales (20) + : : : : : +- ReusedExchange (23) + : : : : +- BroadcastExchange (29) + : : : : +- * Filter (28) + : : : : +- * ColumnarToRow (27) + : : : : +- Scan parquet default.item (26) + : : : +- * Sort (53) + : : : +- Exchange (52) + : : : +- * Project (51) + : : : +- * BroadcastHashJoin Inner BuildRight (50) + : : : :- * Project (48) + : : : : +- * BroadcastHashJoin Inner BuildRight (47) + : : : : :- * Filter (45) + : : : : : +- * ColumnarToRow (44) + : : : : : +- Scan parquet default.web_sales (43) + : : : : +- ReusedExchange (46) + : : : +- ReusedExchange (49) + 
: : +- ReusedExchange (61) + : +- BroadcastExchange (72) + : +- * SortMergeJoin LeftSemi (71) + : :- * Sort (68) + : : +- Exchange (67) + : : +- * Filter (66) + : : +- * ColumnarToRow (65) + : : +- Scan parquet default.item (64) + : +- * Sort (70) + : +- ReusedExchange (69) + :- * Project (98) + : +- * Filter (97) + : +- * HashAggregate (96) + : +- Exchange (95) + : +- * HashAggregate (94) + : +- * Project (93) + : +- * BroadcastHashJoin Inner BuildRight (92) + : :- * Project (90) + : : +- * BroadcastHashJoin Inner BuildRight (89) + : : :- * SortMergeJoin LeftSemi (87) + : : : :- * Sort (84) + : : : : +- Exchange (83) + : : : : +- * Filter (82) + : : : : +- * ColumnarToRow (81) + : : : : +- Scan parquet default.catalog_sales (80) + : : : +- * Sort (86) + : : : +- ReusedExchange (85) + : : +- ReusedExchange (88) + : +- ReusedExchange (91) + +- * Project (117) + +- * Filter (116) + +- * HashAggregate (115) + +- Exchange (114) + +- * HashAggregate (113) + +- * Project (112) + +- * BroadcastHashJoin Inner BuildRight (111) + :- * Project (109) + : +- * BroadcastHashJoin Inner BuildRight (108) + : :- * SortMergeJoin LeftSemi (106) + : : :- * Sort (103) + : : : +- Exchange (102) + : : : +- * Filter (101) + : : : +- * ColumnarToRow (100) + : : : +- Scan parquet default.web_sales (99) + : : +- * Sort (105) + : : +- ReusedExchange (104) + : +- ReusedExchange (107) + +- ReusedExchange (110) (1) Scan parquet default.store_sales @@ -157,10 +154,10 @@ Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_brand_id), IsNotNull(i_class_id), IsNotNull(i_category_id)] ReadSchema: struct -(7) ColumnarToRow [codegen id : 20] +(7) ColumnarToRow [codegen id : 19] Input [4]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10] -(8) Filter [codegen id : 20] +(8) Filter [codegen id : 19] Input [4]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10] Condition : ((isnotnull(i_brand_id#8) AND isnotnull(i_class_id#9)) AND isnotnull(i_category_id#10)) @@ -179,7 +176,7 @@ Input [2]: [ss_item_sk#11, ss_sold_date_sk#12] Input [2]: [ss_item_sk#11, ss_sold_date_sk#12] Condition : isnotnull(ss_item_sk#11) -(12) ReusedExchange [Reuses operator id: 155] +(12) ReusedExchange [Reuses operator id: 152] Output [1]: [d_date_sk#14] (13) BroadcastHashJoin [codegen id : 11] @@ -228,7 +225,7 @@ Input [2]: [cs_item_sk#20, cs_sold_date_sk#21] Input [2]: [cs_item_sk#20, cs_sold_date_sk#21] Condition : isnotnull(cs_item_sk#20) -(23) ReusedExchange [Reuses operator id: 155] +(23) ReusedExchange [Reuses operator id: 152] Output [1]: [d_date_sk#22] (24) BroadcastHashJoin [codegen id : 8] @@ -334,7 +331,7 @@ Input [2]: [ws_item_sk#35, ws_sold_date_sk#36] Input [2]: [ws_item_sk#35, ws_sold_date_sk#36] Condition : isnotnull(ws_item_sk#35) -(46) ReusedExchange [Reuses operator id: 155] +(46) ReusedExchange [Reuses operator id: 152] Output [1]: [d_date_sk#37] (47) BroadcastHashJoin [codegen id : 16] @@ -371,519 +368,501 @@ Left keys [6]: [coalesce(brand_id#30, 0), isnull(brand_id#30), coalesce(class_id Right keys [6]: [coalesce(i_brand_id#39, 0), isnull(i_brand_id#39), coalesce(i_class_id#40, 0), isnull(i_class_id#40), coalesce(i_category_id#41, 0), isnull(i_category_id#41)] Join condition: None -(55) HashAggregate [codegen id : 18] +(55) BroadcastExchange Input [3]: [brand_id#30, class_id#31, category_id#32] -Keys [3]: [brand_id#30, class_id#31, category_id#32] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#30, class_id#31, category_id#32] - -(56) Exchange -Input [3]: 
[brand_id#30, class_id#31, category_id#32] -Arguments: hashpartitioning(brand_id#30, class_id#31, category_id#32, 5), ENSURE_REQUIREMENTS, [id=#43] - -(57) HashAggregate [codegen id : 19] -Input [3]: [brand_id#30, class_id#31, category_id#32] -Keys [3]: [brand_id#30, class_id#31, category_id#32] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#30, class_id#31, category_id#32] - -(58) BroadcastExchange -Input [3]: [brand_id#30, class_id#31, category_id#32] -Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#44] +Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#43] -(59) BroadcastHashJoin [codegen id : 20] +(56) BroadcastHashJoin [codegen id : 19] Left keys [3]: [i_brand_id#8, i_class_id#9, i_category_id#10] Right keys [3]: [brand_id#30, class_id#31, category_id#32] Join condition: None -(60) Project [codegen id : 20] -Output [1]: [i_item_sk#7 AS ss_item_sk#45] +(57) Project [codegen id : 19] +Output [1]: [i_item_sk#7 AS ss_item_sk#44] Input [7]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10, brand_id#30, class_id#31, category_id#32] -(61) Exchange -Input [1]: [ss_item_sk#45] -Arguments: hashpartitioning(ss_item_sk#45, 5), ENSURE_REQUIREMENTS, [id=#46] +(58) Exchange +Input [1]: [ss_item_sk#44] +Arguments: hashpartitioning(ss_item_sk#44, 5), ENSURE_REQUIREMENTS, [id=#45] -(62) Sort [codegen id : 21] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(59) Sort [codegen id : 20] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(63) SortMergeJoin [codegen id : 45] +(60) SortMergeJoin [codegen id : 43] Left keys [1]: [ss_item_sk#1] -Right keys [1]: [ss_item_sk#45] +Right keys [1]: [ss_item_sk#44] Join condition: None -(64) ReusedExchange [Reuses operator id: 150] -Output [1]: [d_date_sk#47] +(61) ReusedExchange [Reuses operator id: 147] +Output [1]: [d_date_sk#46] -(65) BroadcastHashJoin [codegen id : 45] +(62) BroadcastHashJoin [codegen id : 43] Left keys [1]: [ss_sold_date_sk#4] -Right keys [1]: [d_date_sk#47] +Right keys [1]: [d_date_sk#46] Join condition: None -(66) Project [codegen id : 45] +(63) Project [codegen id : 43] Output [3]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3] -Input [5]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, d_date_sk#47] +Input [5]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, d_date_sk#46] -(67) Scan parquet default.item -Output [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] +(64) Scan parquet default.item +Output [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] Batched: true Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_item_sk)] ReadSchema: struct -(68) ColumnarToRow [codegen id : 23] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] +(65) ColumnarToRow [codegen id : 22] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] -(69) Filter [codegen id : 23] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Condition : isnotnull(i_item_sk#48) +(66) Filter [codegen id : 22] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Condition : isnotnull(i_item_sk#47) -(70) Exchange -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: hashpartitioning(i_item_sk#48, 5), ENSURE_REQUIREMENTS, 
[id=#52] +(67) Exchange +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: hashpartitioning(i_item_sk#47, 5), ENSURE_REQUIREMENTS, [id=#51] -(71) Sort [codegen id : 24] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: [i_item_sk#48 ASC NULLS FIRST], false, 0 +(68) Sort [codegen id : 23] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: [i_item_sk#47 ASC NULLS FIRST], false, 0 -(72) ReusedExchange [Reuses operator id: 61] -Output [1]: [ss_item_sk#45] +(69) ReusedExchange [Reuses operator id: 58] +Output [1]: [ss_item_sk#44] -(73) Sort [codegen id : 43] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(70) Sort [codegen id : 41] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(74) SortMergeJoin [codegen id : 44] -Left keys [1]: [i_item_sk#48] -Right keys [1]: [ss_item_sk#45] +(71) SortMergeJoin [codegen id : 42] +Left keys [1]: [i_item_sk#47] +Right keys [1]: [ss_item_sk#44] Join condition: None -(75) BroadcastExchange -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#53] +(72) BroadcastExchange +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#52] -(76) BroadcastHashJoin [codegen id : 45] +(73) BroadcastHashJoin [codegen id : 43] Left keys [1]: [ss_item_sk#1] -Right keys [1]: [i_item_sk#48] +Right keys [1]: [i_item_sk#47] Join condition: None -(77) Project [codegen id : 45] -Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#49, i_class_id#50, i_category_id#51] -Input [7]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] - -(78) HashAggregate [codegen id : 45] -Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#49, i_class_id#50, i_category_id#51] -Keys [3]: [i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] -Aggregate Attributes [3]: [sum#54, isEmpty#55, count#56] -Results [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] - -(79) Exchange -Input [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] -Arguments: hashpartitioning(i_brand_id#49, i_class_id#50, i_category_id#51, 5), ENSURE_REQUIREMENTS, [id=#60] - -(80) HashAggregate [codegen id : 46] -Input [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] -Keys [3]: [i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#61, count(1)#62] -Results [5]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * 
promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#61 AS sales#63, count(1)#62 AS number_sales#64] - -(81) Filter [codegen id : 46] -Input [5]: [i_brand_id#49, i_class_id#50, i_category_id#51, sales#63, number_sales#64] -Condition : (isnotnull(sales#63) AND (cast(sales#63 as decimal(32,6)) > cast(Subquery scalar-subquery#65, [id=#66] as decimal(32,6)))) - -(82) Project [codegen id : 46] -Output [6]: [sales#63, number_sales#64, store AS channel#67, i_brand_id#49, i_class_id#50, i_category_id#51] -Input [5]: [i_brand_id#49, i_class_id#50, i_category_id#51, sales#63, number_sales#64] - -(83) Scan parquet default.catalog_sales -Output [4]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71] +(74) Project [codegen id : 43] +Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#48, i_class_id#49, i_category_id#50] +Input [7]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] + +(75) HashAggregate [codegen id : 43] +Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#48, i_class_id#49, i_category_id#50] +Keys [3]: [i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] +Aggregate Attributes [3]: [sum#53, isEmpty#54, count#55] +Results [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] + +(76) Exchange +Input [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] +Arguments: hashpartitioning(i_brand_id#48, i_class_id#49, i_category_id#50, 5), ENSURE_REQUIREMENTS, [id=#59] + +(77) HashAggregate [codegen id : 44] +Input [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] +Keys [3]: [i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#60, count(1)#61] +Results [5]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#60 AS sales#62, count(1)#61 AS number_sales#63] + +(78) Filter [codegen id : 44] +Input [5]: [i_brand_id#48, i_class_id#49, i_category_id#50, sales#62, number_sales#63] +Condition : (isnotnull(sales#62) AND (cast(sales#62 as decimal(32,6)) > cast(Subquery scalar-subquery#64, [id=#65] as decimal(32,6)))) + +(79) Project [codegen id : 44] +Output [6]: [sales#62, number_sales#63, store AS channel#66, i_brand_id#48, i_class_id#49, i_category_id#50] +Input [5]: [i_brand_id#48, i_class_id#49, i_category_id#50, sales#62, number_sales#63] + +(80) Scan parquet default.catalog_sales +Output [4]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(cs_sold_date_sk#71), dynamicpruningexpression(cs_sold_date_sk#71 IN dynamicpruning#5)] +PartitionFilters: [isnotnull(cs_sold_date_sk#70), dynamicpruningexpression(cs_sold_date_sk#70 IN dynamicpruning#5)] PushedFilters: [IsNotNull(cs_item_sk)] ReadSchema: 
struct -(84) ColumnarToRow [codegen id : 47] -Input [4]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71] +(81) ColumnarToRow [codegen id : 45] +Input [4]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70] -(85) Filter [codegen id : 47] -Input [4]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71] -Condition : isnotnull(cs_item_sk#68) +(82) Filter [codegen id : 45] +Input [4]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70] +Condition : isnotnull(cs_item_sk#67) -(86) Exchange -Input [4]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71] -Arguments: hashpartitioning(cs_item_sk#68, 5), ENSURE_REQUIREMENTS, [id=#72] +(83) Exchange +Input [4]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70] +Arguments: hashpartitioning(cs_item_sk#67, 5), ENSURE_REQUIREMENTS, [id=#71] -(87) Sort [codegen id : 48] -Input [4]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71] -Arguments: [cs_item_sk#68 ASC NULLS FIRST], false, 0 +(84) Sort [codegen id : 46] +Input [4]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70] +Arguments: [cs_item_sk#67 ASC NULLS FIRST], false, 0 -(88) ReusedExchange [Reuses operator id: 61] -Output [1]: [ss_item_sk#45] +(85) ReusedExchange [Reuses operator id: 58] +Output [1]: [ss_item_sk#44] -(89) Sort [codegen id : 67] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(86) Sort [codegen id : 64] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(90) SortMergeJoin [codegen id : 91] -Left keys [1]: [cs_item_sk#68] -Right keys [1]: [ss_item_sk#45] +(87) SortMergeJoin [codegen id : 87] +Left keys [1]: [cs_item_sk#67] +Right keys [1]: [ss_item_sk#44] Join condition: None -(91) ReusedExchange [Reuses operator id: 150] -Output [1]: [d_date_sk#73] +(88) ReusedExchange [Reuses operator id: 147] +Output [1]: [d_date_sk#72] -(92) BroadcastHashJoin [codegen id : 91] -Left keys [1]: [cs_sold_date_sk#71] -Right keys [1]: [d_date_sk#73] +(89) BroadcastHashJoin [codegen id : 87] +Left keys [1]: [cs_sold_date_sk#70] +Right keys [1]: [d_date_sk#72] Join condition: None -(93) Project [codegen id : 91] -Output [3]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70] -Input [5]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71, d_date_sk#73] +(90) Project [codegen id : 87] +Output [3]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69] +Input [5]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70, d_date_sk#72] -(94) ReusedExchange [Reuses operator id: 75] -Output [4]: [i_item_sk#74, i_brand_id#75, i_class_id#76, i_category_id#77] +(91) ReusedExchange [Reuses operator id: 72] +Output [4]: [i_item_sk#73, i_brand_id#74, i_class_id#75, i_category_id#76] -(95) BroadcastHashJoin [codegen id : 91] -Left keys [1]: [cs_item_sk#68] -Right keys [1]: [i_item_sk#74] +(92) BroadcastHashJoin [codegen id : 87] +Left keys [1]: [cs_item_sk#67] +Right keys [1]: [i_item_sk#73] Join condition: None -(96) Project [codegen id : 91] -Output [5]: [cs_quantity#69, cs_list_price#70, i_brand_id#75, i_class_id#76, i_category_id#77] -Input [7]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, i_item_sk#74, i_brand_id#75, i_class_id#76, i_category_id#77] - -(97) HashAggregate [codegen id : 91] -Input [5]: [cs_quantity#69, cs_list_price#70, i_brand_id#75, i_class_id#76, i_category_id#77] -Keys [3]: [i_brand_id#75, i_class_id#76, i_category_id#77] -Functions [2]: 
[partial_sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#70 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] -Aggregate Attributes [3]: [sum#78, isEmpty#79, count#80] -Results [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] - -(98) Exchange -Input [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] -Arguments: hashpartitioning(i_brand_id#75, i_class_id#76, i_category_id#77, 5), ENSURE_REQUIREMENTS, [id=#84] - -(99) HashAggregate [codegen id : 92] -Input [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] -Keys [3]: [i_brand_id#75, i_class_id#76, i_category_id#77] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#70 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#70 as decimal(12,2)))), DecimalType(18,2), true))#85, count(1)#86] -Results [5]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#70 as decimal(12,2)))), DecimalType(18,2), true))#85 AS sales#87, count(1)#86 AS number_sales#88] - -(100) Filter [codegen id : 92] -Input [5]: [i_brand_id#75, i_class_id#76, i_category_id#77, sales#87, number_sales#88] -Condition : (isnotnull(sales#87) AND (cast(sales#87 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#65, [id=#66] as decimal(32,6)))) - -(101) Project [codegen id : 92] -Output [6]: [sales#87, number_sales#88, catalog AS channel#89, i_brand_id#75, i_class_id#76, i_category_id#77] -Input [5]: [i_brand_id#75, i_class_id#76, i_category_id#77, sales#87, number_sales#88] - -(102) Scan parquet default.web_sales -Output [4]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93] +(93) Project [codegen id : 87] +Output [5]: [cs_quantity#68, cs_list_price#69, i_brand_id#74, i_class_id#75, i_category_id#76] +Input [7]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, i_item_sk#73, i_brand_id#74, i_class_id#75, i_category_id#76] + +(94) HashAggregate [codegen id : 87] +Input [5]: [cs_quantity#68, cs_list_price#69, i_brand_id#74, i_class_id#75, i_category_id#76] +Keys [3]: [i_brand_id#74, i_class_id#75, i_category_id#76] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cs_quantity#68 as decimal(12,2))) * promote_precision(cast(cs_list_price#69 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] +Aggregate Attributes [3]: [sum#77, isEmpty#78, count#79] +Results [6]: [i_brand_id#74, i_class_id#75, i_category_id#76, sum#80, isEmpty#81, count#82] + +(95) Exchange +Input [6]: [i_brand_id#74, i_class_id#75, i_category_id#76, sum#80, isEmpty#81, count#82] +Arguments: hashpartitioning(i_brand_id#74, i_class_id#75, i_category_id#76, 5), ENSURE_REQUIREMENTS, [id=#83] + +(96) HashAggregate [codegen id : 88] +Input [6]: [i_brand_id#74, i_class_id#75, i_category_id#76, sum#80, isEmpty#81, count#82] +Keys [3]: [i_brand_id#74, i_class_id#75, i_category_id#76] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#68 as decimal(12,2))) * promote_precision(cast(cs_list_price#69 as decimal(12,2)))), DecimalType(18,2))), count(1)] 
+Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#68 as decimal(12,2))) * promote_precision(cast(cs_list_price#69 as decimal(12,2)))), DecimalType(18,2)))#84, count(1)#85] +Results [5]: [i_brand_id#74, i_class_id#75, i_category_id#76, sum(CheckOverflow((promote_precision(cast(cs_quantity#68 as decimal(12,2))) * promote_precision(cast(cs_list_price#69 as decimal(12,2)))), DecimalType(18,2)))#84 AS sales#86, count(1)#85 AS number_sales#87] + +(97) Filter [codegen id : 88] +Input [5]: [i_brand_id#74, i_class_id#75, i_category_id#76, sales#86, number_sales#87] +Condition : (isnotnull(sales#86) AND (cast(sales#86 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#64, [id=#65] as decimal(32,6)))) + +(98) Project [codegen id : 88] +Output [6]: [sales#86, number_sales#87, catalog AS channel#88, i_brand_id#74, i_class_id#75, i_category_id#76] +Input [5]: [i_brand_id#74, i_class_id#75, i_category_id#76, sales#86, number_sales#87] + +(99) Scan parquet default.web_sales +Output [4]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ws_sold_date_sk#93), dynamicpruningexpression(ws_sold_date_sk#93 IN dynamicpruning#5)] +PartitionFilters: [isnotnull(ws_sold_date_sk#92), dynamicpruningexpression(ws_sold_date_sk#92 IN dynamicpruning#5)] PushedFilters: [IsNotNull(ws_item_sk)] ReadSchema: struct -(103) ColumnarToRow [codegen id : 93] -Input [4]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93] +(100) ColumnarToRow [codegen id : 89] +Input [4]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92] -(104) Filter [codegen id : 93] -Input [4]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93] -Condition : isnotnull(ws_item_sk#90) +(101) Filter [codegen id : 89] +Input [4]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92] +Condition : isnotnull(ws_item_sk#89) -(105) Exchange -Input [4]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93] -Arguments: hashpartitioning(ws_item_sk#90, 5), ENSURE_REQUIREMENTS, [id=#94] +(102) Exchange +Input [4]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92] +Arguments: hashpartitioning(ws_item_sk#89, 5), ENSURE_REQUIREMENTS, [id=#93] -(106) Sort [codegen id : 94] -Input [4]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93] -Arguments: [ws_item_sk#90 ASC NULLS FIRST], false, 0 +(103) Sort [codegen id : 90] +Input [4]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92] +Arguments: [ws_item_sk#89 ASC NULLS FIRST], false, 0 -(107) ReusedExchange [Reuses operator id: 61] -Output [1]: [ss_item_sk#45] +(104) ReusedExchange [Reuses operator id: 58] +Output [1]: [ss_item_sk#44] -(108) Sort [codegen id : 113] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(105) Sort [codegen id : 108] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(109) SortMergeJoin [codegen id : 137] -Left keys [1]: [ws_item_sk#90] -Right keys [1]: [ss_item_sk#45] +(106) SortMergeJoin [codegen id : 131] +Left keys [1]: [ws_item_sk#89] +Right keys [1]: [ss_item_sk#44] Join condition: None -(110) ReusedExchange [Reuses operator id: 150] -Output [1]: [d_date_sk#95] +(107) ReusedExchange [Reuses operator id: 147] +Output [1]: [d_date_sk#94] -(111) BroadcastHashJoin [codegen id : 137] -Left keys [1]: [ws_sold_date_sk#93] -Right keys [1]: [d_date_sk#95] +(108) 
BroadcastHashJoin [codegen id : 131] +Left keys [1]: [ws_sold_date_sk#92] +Right keys [1]: [d_date_sk#94] Join condition: None -(112) Project [codegen id : 137] -Output [3]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92] -Input [5]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93, d_date_sk#95] +(109) Project [codegen id : 131] +Output [3]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91] +Input [5]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92, d_date_sk#94] -(113) ReusedExchange [Reuses operator id: 75] -Output [4]: [i_item_sk#96, i_brand_id#97, i_class_id#98, i_category_id#99] +(110) ReusedExchange [Reuses operator id: 72] +Output [4]: [i_item_sk#95, i_brand_id#96, i_class_id#97, i_category_id#98] -(114) BroadcastHashJoin [codegen id : 137] -Left keys [1]: [ws_item_sk#90] -Right keys [1]: [i_item_sk#96] +(111) BroadcastHashJoin [codegen id : 131] +Left keys [1]: [ws_item_sk#89] +Right keys [1]: [i_item_sk#95] Join condition: None -(115) Project [codegen id : 137] -Output [5]: [ws_quantity#91, ws_list_price#92, i_brand_id#97, i_class_id#98, i_category_id#99] -Input [7]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, i_item_sk#96, i_brand_id#97, i_class_id#98, i_category_id#99] - -(116) HashAggregate [codegen id : 137] -Input [5]: [ws_quantity#91, ws_list_price#92, i_brand_id#97, i_class_id#98, i_category_id#99] -Keys [3]: [i_brand_id#97, i_class_id#98, i_category_id#99] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#91 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#92 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] -Aggregate Attributes [3]: [sum#100, isEmpty#101, count#102] -Results [6]: [i_brand_id#97, i_class_id#98, i_category_id#99, sum#103, isEmpty#104, count#105] - -(117) Exchange -Input [6]: [i_brand_id#97, i_class_id#98, i_category_id#99, sum#103, isEmpty#104, count#105] -Arguments: hashpartitioning(i_brand_id#97, i_class_id#98, i_category_id#99, 5), ENSURE_REQUIREMENTS, [id=#106] - -(118) HashAggregate [codegen id : 138] -Input [6]: [i_brand_id#97, i_class_id#98, i_category_id#99, sum#103, isEmpty#104, count#105] -Keys [3]: [i_brand_id#97, i_class_id#98, i_category_id#99] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#91 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#92 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#91 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#92 as decimal(12,2)))), DecimalType(18,2), true))#107, count(1)#108] -Results [5]: [i_brand_id#97, i_class_id#98, i_category_id#99, sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#91 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#92 as decimal(12,2)))), DecimalType(18,2), true))#107 AS sales#109, count(1)#108 AS number_sales#110] - -(119) Filter [codegen id : 138] -Input [5]: [i_brand_id#97, i_class_id#98, i_category_id#99, sales#109, number_sales#110] -Condition : (isnotnull(sales#109) AND (cast(sales#109 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#65, [id=#66] as decimal(32,6)))) - -(120) Project [codegen id : 138] -Output [6]: [sales#109, number_sales#110, web AS channel#111, i_brand_id#97, i_class_id#98, i_category_id#99] -Input [5]: [i_brand_id#97, i_class_id#98, i_category_id#99, sales#109, number_sales#110] - -(121) Union - -(122) 
Expand [codegen id : 139] -Input [6]: [sales#63, number_sales#64, channel#67, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: [[sales#63, number_sales#64, channel#67, i_brand_id#49, i_class_id#50, i_category_id#51, 0], [sales#63, number_sales#64, channel#67, i_brand_id#49, i_class_id#50, null, 1], [sales#63, number_sales#64, channel#67, i_brand_id#49, null, null, 3], [sales#63, number_sales#64, channel#67, null, null, null, 7], [sales#63, number_sales#64, null, null, null, null, 15]], [sales#63, number_sales#64, channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, spark_grouping_id#116] - -(123) HashAggregate [codegen id : 139] -Input [7]: [sales#63, number_sales#64, channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, spark_grouping_id#116] -Keys [5]: [channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, spark_grouping_id#116] -Functions [2]: [partial_sum(sales#63), partial_sum(number_sales#64)] -Aggregate Attributes [3]: [sum#117, isEmpty#118, sum#119] -Results [8]: [channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, spark_grouping_id#116, sum#120, isEmpty#121, sum#122] - -(124) Exchange -Input [8]: [channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, spark_grouping_id#116, sum#120, isEmpty#121, sum#122] -Arguments: hashpartitioning(channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, spark_grouping_id#116, 5), ENSURE_REQUIREMENTS, [id=#123] - -(125) HashAggregate [codegen id : 140] -Input [8]: [channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, spark_grouping_id#116, sum#120, isEmpty#121, sum#122] -Keys [5]: [channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, spark_grouping_id#116] -Functions [2]: [sum(sales#63), sum(number_sales#64)] -Aggregate Attributes [2]: [sum(sales#63)#124, sum(number_sales#64)#125] -Results [6]: [channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, sum(sales#63)#124 AS sum(sales)#126, sum(number_sales#64)#125 AS sum(number_sales)#127] - -(126) TakeOrderedAndProject -Input [6]: [channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, sum(sales)#126, sum(number_sales)#127] -Arguments: 100, [channel#112 ASC NULLS FIRST, i_brand_id#113 ASC NULLS FIRST, i_class_id#114 ASC NULLS FIRST, i_category_id#115 ASC NULLS FIRST], [channel#112, i_brand_id#113, i_class_id#114, i_category_id#115, sum(sales)#126, sum(number_sales)#127] +(112) Project [codegen id : 131] +Output [5]: [ws_quantity#90, ws_list_price#91, i_brand_id#96, i_class_id#97, i_category_id#98] +Input [7]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, i_item_sk#95, i_brand_id#96, i_class_id#97, i_category_id#98] + +(113) HashAggregate [codegen id : 131] +Input [5]: [ws_quantity#90, ws_list_price#91, i_brand_id#96, i_class_id#97, i_category_id#98] +Keys [3]: [i_brand_id#96, i_class_id#97, i_category_id#98] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ws_quantity#90 as decimal(12,2))) * promote_precision(cast(ws_list_price#91 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] +Aggregate Attributes [3]: [sum#99, isEmpty#100, count#101] +Results [6]: [i_brand_id#96, i_class_id#97, i_category_id#98, sum#102, isEmpty#103, count#104] + +(114) Exchange +Input [6]: [i_brand_id#96, i_class_id#97, i_category_id#98, sum#102, isEmpty#103, count#104] +Arguments: hashpartitioning(i_brand_id#96, i_class_id#97, i_category_id#98, 5), ENSURE_REQUIREMENTS, [id=#105] + +(115) HashAggregate [codegen id : 132] +Input [6]: [i_brand_id#96, i_class_id#97, i_category_id#98, 
sum#102, isEmpty#103, count#104] +Keys [3]: [i_brand_id#96, i_class_id#97, i_category_id#98] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#90 as decimal(12,2))) * promote_precision(cast(ws_list_price#91 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#90 as decimal(12,2))) * promote_precision(cast(ws_list_price#91 as decimal(12,2)))), DecimalType(18,2)))#106, count(1)#107] +Results [5]: [i_brand_id#96, i_class_id#97, i_category_id#98, sum(CheckOverflow((promote_precision(cast(ws_quantity#90 as decimal(12,2))) * promote_precision(cast(ws_list_price#91 as decimal(12,2)))), DecimalType(18,2)))#106 AS sales#108, count(1)#107 AS number_sales#109] + +(116) Filter [codegen id : 132] +Input [5]: [i_brand_id#96, i_class_id#97, i_category_id#98, sales#108, number_sales#109] +Condition : (isnotnull(sales#108) AND (cast(sales#108 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#64, [id=#65] as decimal(32,6)))) + +(117) Project [codegen id : 132] +Output [6]: [sales#108, number_sales#109, web AS channel#110, i_brand_id#96, i_class_id#97, i_category_id#98] +Input [5]: [i_brand_id#96, i_class_id#97, i_category_id#98, sales#108, number_sales#109] + +(118) Union + +(119) Expand [codegen id : 133] +Input [6]: [sales#62, number_sales#63, channel#66, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: [[sales#62, number_sales#63, channel#66, i_brand_id#48, i_class_id#49, i_category_id#50, 0], [sales#62, number_sales#63, channel#66, i_brand_id#48, i_class_id#49, null, 1], [sales#62, number_sales#63, channel#66, i_brand_id#48, null, null, 3], [sales#62, number_sales#63, channel#66, null, null, null, 7], [sales#62, number_sales#63, null, null, null, null, 15]], [sales#62, number_sales#63, channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, spark_grouping_id#115] + +(120) HashAggregate [codegen id : 133] +Input [7]: [sales#62, number_sales#63, channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, spark_grouping_id#115] +Keys [5]: [channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, spark_grouping_id#115] +Functions [2]: [partial_sum(sales#62), partial_sum(number_sales#63)] +Aggregate Attributes [3]: [sum#116, isEmpty#117, sum#118] +Results [8]: [channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, spark_grouping_id#115, sum#119, isEmpty#120, sum#121] + +(121) Exchange +Input [8]: [channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, spark_grouping_id#115, sum#119, isEmpty#120, sum#121] +Arguments: hashpartitioning(channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, spark_grouping_id#115, 5), ENSURE_REQUIREMENTS, [id=#122] + +(122) HashAggregate [codegen id : 134] +Input [8]: [channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, spark_grouping_id#115, sum#119, isEmpty#120, sum#121] +Keys [5]: [channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, spark_grouping_id#115] +Functions [2]: [sum(sales#62), sum(number_sales#63)] +Aggregate Attributes [2]: [sum(sales#62)#123, sum(number_sales#63)#124] +Results [6]: [channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, sum(sales#62)#123 AS sum(sales)#125, sum(number_sales#63)#124 AS sum(number_sales)#126] + +(123) TakeOrderedAndProject +Input [6]: [channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, sum(sales)#125, sum(number_sales)#126] +Arguments: 100, [channel#111 ASC NULLS FIRST, i_brand_id#112 ASC NULLS FIRST, i_class_id#113 
ASC NULLS FIRST, i_category_id#114 ASC NULLS FIRST], [channel#111, i_brand_id#112, i_class_id#113, i_category_id#114, sum(sales)#125, sum(number_sales)#126] ===== Subqueries ===== -Subquery:1 Hosting operator id = 81 Hosting Expression = Subquery scalar-subquery#65, [id=#66] -* HashAggregate (145) -+- Exchange (144) - +- * HashAggregate (143) - +- Union (142) - :- * Project (131) - : +- * BroadcastHashJoin Inner BuildRight (130) - : :- * ColumnarToRow (128) - : : +- Scan parquet default.store_sales (127) - : +- ReusedExchange (129) - :- * Project (136) - : +- * BroadcastHashJoin Inner BuildRight (135) - : :- * ColumnarToRow (133) - : : +- Scan parquet default.catalog_sales (132) - : +- ReusedExchange (134) - +- * Project (141) - +- * BroadcastHashJoin Inner BuildRight (140) - :- * ColumnarToRow (138) - : +- Scan parquet default.web_sales (137) - +- ReusedExchange (139) - - -(127) Scan parquet default.store_sales -Output [3]: [ss_quantity#128, ss_list_price#129, ss_sold_date_sk#130] +Subquery:1 Hosting operator id = 78 Hosting Expression = Subquery scalar-subquery#64, [id=#65] +* HashAggregate (142) ++- Exchange (141) + +- * HashAggregate (140) + +- Union (139) + :- * Project (128) + : +- * BroadcastHashJoin Inner BuildRight (127) + : :- * ColumnarToRow (125) + : : +- Scan parquet default.store_sales (124) + : +- ReusedExchange (126) + :- * Project (133) + : +- * BroadcastHashJoin Inner BuildRight (132) + : :- * ColumnarToRow (130) + : : +- Scan parquet default.catalog_sales (129) + : +- ReusedExchange (131) + +- * Project (138) + +- * BroadcastHashJoin Inner BuildRight (137) + :- * ColumnarToRow (135) + : +- Scan parquet default.web_sales (134) + +- ReusedExchange (136) + + +(124) Scan parquet default.store_sales +Output [3]: [ss_quantity#127, ss_list_price#128, ss_sold_date_sk#129] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#130), dynamicpruningexpression(ss_sold_date_sk#130 IN dynamicpruning#13)] +PartitionFilters: [isnotnull(ss_sold_date_sk#129), dynamicpruningexpression(ss_sold_date_sk#129 IN dynamicpruning#13)] ReadSchema: struct -(128) ColumnarToRow [codegen id : 2] -Input [3]: [ss_quantity#128, ss_list_price#129, ss_sold_date_sk#130] +(125) ColumnarToRow [codegen id : 2] +Input [3]: [ss_quantity#127, ss_list_price#128, ss_sold_date_sk#129] -(129) ReusedExchange [Reuses operator id: 155] -Output [1]: [d_date_sk#131] +(126) ReusedExchange [Reuses operator id: 152] +Output [1]: [d_date_sk#130] -(130) BroadcastHashJoin [codegen id : 2] -Left keys [1]: [ss_sold_date_sk#130] -Right keys [1]: [d_date_sk#131] +(127) BroadcastHashJoin [codegen id : 2] +Left keys [1]: [ss_sold_date_sk#129] +Right keys [1]: [d_date_sk#130] Join condition: None -(131) Project [codegen id : 2] -Output [2]: [ss_quantity#128 AS quantity#132, ss_list_price#129 AS list_price#133] -Input [4]: [ss_quantity#128, ss_list_price#129, ss_sold_date_sk#130, d_date_sk#131] +(128) Project [codegen id : 2] +Output [2]: [ss_quantity#127 AS quantity#131, ss_list_price#128 AS list_price#132] +Input [4]: [ss_quantity#127, ss_list_price#128, ss_sold_date_sk#129, d_date_sk#130] -(132) Scan parquet default.catalog_sales -Output [3]: [cs_quantity#134, cs_list_price#135, cs_sold_date_sk#136] +(129) Scan parquet default.catalog_sales +Output [3]: [cs_quantity#133, cs_list_price#134, cs_sold_date_sk#135] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(cs_sold_date_sk#136), dynamicpruningexpression(cs_sold_date_sk#136 IN dynamicpruning#13)] +PartitionFilters: 
[isnotnull(cs_sold_date_sk#135), dynamicpruningexpression(cs_sold_date_sk#135 IN dynamicpruning#13)] ReadSchema: struct -(133) ColumnarToRow [codegen id : 4] -Input [3]: [cs_quantity#134, cs_list_price#135, cs_sold_date_sk#136] +(130) ColumnarToRow [codegen id : 4] +Input [3]: [cs_quantity#133, cs_list_price#134, cs_sold_date_sk#135] -(134) ReusedExchange [Reuses operator id: 155] -Output [1]: [d_date_sk#137] +(131) ReusedExchange [Reuses operator id: 152] +Output [1]: [d_date_sk#136] -(135) BroadcastHashJoin [codegen id : 4] -Left keys [1]: [cs_sold_date_sk#136] -Right keys [1]: [d_date_sk#137] +(132) BroadcastHashJoin [codegen id : 4] +Left keys [1]: [cs_sold_date_sk#135] +Right keys [1]: [d_date_sk#136] Join condition: None -(136) Project [codegen id : 4] -Output [2]: [cs_quantity#134 AS quantity#138, cs_list_price#135 AS list_price#139] -Input [4]: [cs_quantity#134, cs_list_price#135, cs_sold_date_sk#136, d_date_sk#137] +(133) Project [codegen id : 4] +Output [2]: [cs_quantity#133 AS quantity#137, cs_list_price#134 AS list_price#138] +Input [4]: [cs_quantity#133, cs_list_price#134, cs_sold_date_sk#135, d_date_sk#136] -(137) Scan parquet default.web_sales -Output [3]: [ws_quantity#140, ws_list_price#141, ws_sold_date_sk#142] +(134) Scan parquet default.web_sales +Output [3]: [ws_quantity#139, ws_list_price#140, ws_sold_date_sk#141] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ws_sold_date_sk#142), dynamicpruningexpression(ws_sold_date_sk#142 IN dynamicpruning#13)] +PartitionFilters: [isnotnull(ws_sold_date_sk#141), dynamicpruningexpression(ws_sold_date_sk#141 IN dynamicpruning#13)] ReadSchema: struct -(138) ColumnarToRow [codegen id : 6] -Input [3]: [ws_quantity#140, ws_list_price#141, ws_sold_date_sk#142] +(135) ColumnarToRow [codegen id : 6] +Input [3]: [ws_quantity#139, ws_list_price#140, ws_sold_date_sk#141] -(139) ReusedExchange [Reuses operator id: 155] -Output [1]: [d_date_sk#143] +(136) ReusedExchange [Reuses operator id: 152] +Output [1]: [d_date_sk#142] -(140) BroadcastHashJoin [codegen id : 6] -Left keys [1]: [ws_sold_date_sk#142] -Right keys [1]: [d_date_sk#143] +(137) BroadcastHashJoin [codegen id : 6] +Left keys [1]: [ws_sold_date_sk#141] +Right keys [1]: [d_date_sk#142] Join condition: None -(141) Project [codegen id : 6] -Output [2]: [ws_quantity#140 AS quantity#144, ws_list_price#141 AS list_price#145] -Input [4]: [ws_quantity#140, ws_list_price#141, ws_sold_date_sk#142, d_date_sk#143] +(138) Project [codegen id : 6] +Output [2]: [ws_quantity#139 AS quantity#143, ws_list_price#140 AS list_price#144] +Input [4]: [ws_quantity#139, ws_list_price#140, ws_sold_date_sk#141, d_date_sk#142] -(142) Union +(139) Union -(143) HashAggregate [codegen id : 7] -Input [2]: [quantity#132, list_price#133] +(140) HashAggregate [codegen id : 7] +Input [2]: [quantity#131, list_price#132] Keys: [] -Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(cast(quantity#132 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#133 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#146, count#147] -Results [2]: [sum#148, count#149] +Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(quantity#131 as decimal(12,2))) * promote_precision(cast(list_price#132 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#145, count#146] +Results [2]: [sum#147, count#148] -(144) Exchange -Input [2]: [sum#148, count#149] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#150] +(141) 
Exchange +Input [2]: [sum#147, count#148] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#149] -(145) HashAggregate [codegen id : 8] -Input [2]: [sum#148, count#149] +(142) HashAggregate [codegen id : 8] +Input [2]: [sum#147, count#148] Keys: [] -Functions [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#132 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#133 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#132 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#133 as decimal(12,2)))), DecimalType(18,2), true))#151] -Results [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#132 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#133 as decimal(12,2)))), DecimalType(18,2), true))#151 AS average_sales#152] +Functions [1]: [avg(CheckOverflow((promote_precision(cast(quantity#131 as decimal(12,2))) * promote_precision(cast(list_price#132 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(quantity#131 as decimal(12,2))) * promote_precision(cast(list_price#132 as decimal(12,2)))), DecimalType(18,2)))#150] +Results [1]: [avg(CheckOverflow((promote_precision(cast(quantity#131 as decimal(12,2))) * promote_precision(cast(list_price#132 as decimal(12,2)))), DecimalType(18,2)))#150 AS average_sales#151] -Subquery:2 Hosting operator id = 127 Hosting Expression = ss_sold_date_sk#130 IN dynamicpruning#13 +Subquery:2 Hosting operator id = 124 Hosting Expression = ss_sold_date_sk#129 IN dynamicpruning#13 -Subquery:3 Hosting operator id = 132 Hosting Expression = cs_sold_date_sk#136 IN dynamicpruning#13 +Subquery:3 Hosting operator id = 129 Hosting Expression = cs_sold_date_sk#135 IN dynamicpruning#13 -Subquery:4 Hosting operator id = 137 Hosting Expression = ws_sold_date_sk#142 IN dynamicpruning#13 +Subquery:4 Hosting operator id = 134 Hosting Expression = ws_sold_date_sk#141 IN dynamicpruning#13 Subquery:5 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (150) -+- * Project (149) - +- * Filter (148) - +- * ColumnarToRow (147) - +- Scan parquet default.date_dim (146) +BroadcastExchange (147) ++- * Project (146) + +- * Filter (145) + +- * ColumnarToRow (144) + +- Scan parquet default.date_dim (143) -(146) Scan parquet default.date_dim -Output [3]: [d_date_sk#47, d_year#153, d_moy#154] +(143) Scan parquet default.date_dim +Output [3]: [d_date_sk#46, d_year#152, d_moy#153] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2001), EqualTo(d_moy,11), IsNotNull(d_date_sk)] ReadSchema: struct -(147) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#47, d_year#153, d_moy#154] +(144) ColumnarToRow [codegen id : 1] +Input [3]: [d_date_sk#46, d_year#152, d_moy#153] -(148) Filter [codegen id : 1] -Input [3]: [d_date_sk#47, d_year#153, d_moy#154] -Condition : ((((isnotnull(d_year#153) AND isnotnull(d_moy#154)) AND (d_year#153 = 2001)) AND (d_moy#154 = 11)) AND isnotnull(d_date_sk#47)) +(145) Filter [codegen id : 1] +Input [3]: [d_date_sk#46, d_year#152, d_moy#153] +Condition : ((((isnotnull(d_year#152) AND isnotnull(d_moy#153)) AND (d_year#152 = 2001)) AND (d_moy#153 = 11)) AND isnotnull(d_date_sk#46)) -(149) Project [codegen id : 1] -Output [1]: [d_date_sk#47] -Input [3]: [d_date_sk#47, d_year#153, d_moy#154] +(146) Project 
[codegen id : 1] +Output [1]: [d_date_sk#46] +Input [3]: [d_date_sk#46, d_year#152, d_moy#153] -(150) BroadcastExchange -Input [1]: [d_date_sk#47] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#155] +(147) BroadcastExchange +Input [1]: [d_date_sk#46] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#154] Subquery:6 Hosting operator id = 9 Hosting Expression = ss_sold_date_sk#12 IN dynamicpruning#13 -BroadcastExchange (155) -+- * Project (154) - +- * Filter (153) - +- * ColumnarToRow (152) - +- Scan parquet default.date_dim (151) +BroadcastExchange (152) ++- * Project (151) + +- * Filter (150) + +- * ColumnarToRow (149) + +- Scan parquet default.date_dim (148) -(151) Scan parquet default.date_dim -Output [2]: [d_date_sk#14, d_year#156] +(148) Scan parquet default.date_dim +Output [2]: [d_date_sk#14, d_year#155] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), GreaterThanOrEqual(d_year,1999), LessThanOrEqual(d_year,2001), IsNotNull(d_date_sk)] ReadSchema: struct -(152) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#14, d_year#156] +(149) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#14, d_year#155] -(153) Filter [codegen id : 1] -Input [2]: [d_date_sk#14, d_year#156] -Condition : (((isnotnull(d_year#156) AND (d_year#156 >= 1999)) AND (d_year#156 <= 2001)) AND isnotnull(d_date_sk#14)) +(150) Filter [codegen id : 1] +Input [2]: [d_date_sk#14, d_year#155] +Condition : (((isnotnull(d_year#155) AND (d_year#155 >= 1999)) AND (d_year#155 <= 2001)) AND isnotnull(d_date_sk#14)) -(154) Project [codegen id : 1] +(151) Project [codegen id : 1] Output [1]: [d_date_sk#14] -Input [2]: [d_date_sk#14, d_year#156] +Input [2]: [d_date_sk#14, d_year#155] -(155) BroadcastExchange +(152) BroadcastExchange Input [1]: [d_date_sk#14] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#157] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#156] Subquery:7 Hosting operator id = 20 Hosting Expression = cs_sold_date_sk#21 IN dynamicpruning#13 Subquery:8 Hosting operator id = 43 Hosting Expression = ws_sold_date_sk#36 IN dynamicpruning#13 -Subquery:9 Hosting operator id = 100 Hosting Expression = ReusedSubquery Subquery scalar-subquery#65, [id=#66] +Subquery:9 Hosting operator id = 97 Hosting Expression = ReusedSubquery Subquery scalar-subquery#64, [id=#65] -Subquery:10 Hosting operator id = 83 Hosting Expression = cs_sold_date_sk#71 IN dynamicpruning#5 +Subquery:10 Hosting operator id = 80 Hosting Expression = cs_sold_date_sk#70 IN dynamicpruning#5 -Subquery:11 Hosting operator id = 119 Hosting Expression = ReusedSubquery Subquery scalar-subquery#65, [id=#66] +Subquery:11 Hosting operator id = 116 Hosting Expression = ReusedSubquery Subquery scalar-subquery#64, [id=#65] -Subquery:12 Hosting operator id = 102 Hosting Expression = ws_sold_date_sk#93 IN dynamicpruning#5 +Subquery:12 Hosting operator id = 99 Hosting Expression = ws_sold_date_sk#92 IN dynamicpruning#5 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/simplified.txt index 35a8f0d31afc7..f445a370581af 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/simplified.txt +++ 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/simplified.txt @@ -1,21 +1,21 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),sum(number_sales)] - WholeStageCodegen (140) + WholeStageCodegen (134) HashAggregate [channel,i_brand_id,i_class_id,i_category_id,spark_grouping_id,sum,isEmpty,sum] [sum(sales),sum(number_salesL),sum(sales),sum(number_sales),sum,isEmpty,sum] InputAdapter Exchange [channel,i_brand_id,i_class_id,i_category_id,spark_grouping_id] #1 - WholeStageCodegen (139) + WholeStageCodegen (133) HashAggregate [channel,i_brand_id,i_class_id,i_category_id,spark_grouping_id,sales,number_sales] [sum,isEmpty,sum,sum,isEmpty,sum] Expand [sales,number_sales,channel,i_brand_id,i_class_id,i_category_id] InputAdapter Union - WholeStageCodegen (46) + WholeStageCodegen (44) Project [sales,number_sales,i_brand_id,i_class_id,i_category_id] Filter [sales] Subquery #3 WholeStageCodegen (8) - HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(cast(quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2), true)),average_sales,sum,count] + HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(quantity as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2))),average_sales,sum,count] InputAdapter - Exchange #18 + Exchange #17 WholeStageCodegen (7) HashAggregate [quantity,list_price] [sum,count,sum,count] InputAdapter @@ -28,7 +28,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su Scan parquet default.store_sales [ss_quantity,ss_list_price,ss_sold_date_sk] ReusedSubquery [d_date_sk] #2 InputAdapter - ReusedExchange [d_date_sk] #10 + ReusedExchange [d_date_sk] #9 WholeStageCodegen (4) Project [cs_quantity,cs_list_price] BroadcastHashJoin [cs_sold_date_sk,d_date_sk] @@ -37,7 +37,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su Scan parquet default.catalog_sales [cs_quantity,cs_list_price,cs_sold_date_sk] ReusedSubquery [d_date_sk] #2 InputAdapter - ReusedExchange [d_date_sk] #10 + ReusedExchange [d_date_sk] #9 WholeStageCodegen (6) Project [ws_quantity,ws_list_price] BroadcastHashJoin [ws_sold_date_sk,d_date_sk] @@ -46,11 +46,11 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su Scan parquet default.web_sales [ws_quantity,ws_list_price,ws_sold_date_sk] ReusedSubquery [d_date_sk] #2 InputAdapter - ReusedExchange [d_date_sk] #10 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),sales,number_sales,sum,isEmpty,count] + ReusedExchange [d_date_sk] #9 + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #2 - WholeStageCodegen (45) + WholeStageCodegen (43) HashAggregate [i_brand_id,i_class_id,i_category_id,ss_quantity,ss_list_price] [sum,isEmpty,count,sum,isEmpty,count] Project [ss_quantity,ss_list_price,i_brand_id,i_class_id,i_category_id] BroadcastHashJoin [ss_item_sk,i_item_sk] @@ -76,11 +76,11 @@ TakeOrderedAndProject 
[channel,i_brand_id,i_class_id,i_category_id,sum(sales),su InputAdapter Scan parquet default.date_dim [d_date_sk,d_year,d_moy] InputAdapter - WholeStageCodegen (21) + WholeStageCodegen (20) Sort [ss_item_sk] InputAdapter Exchange [ss_item_sk] #5 - WholeStageCodegen (20) + WholeStageCodegen (19) Project [i_item_sk] BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,brand_id,class_id,category_id] Filter [i_brand_id,i_class_id,i_category_id] @@ -89,128 +89,123 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] InputAdapter BroadcastExchange #6 - WholeStageCodegen (19) - HashAggregate [brand_id,class_id,category_id] + WholeStageCodegen (18) + SortMergeJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] InputAdapter - Exchange [brand_id,class_id,category_id] #7 - WholeStageCodegen (18) - HashAggregate [brand_id,class_id,category_id] - SortMergeJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (13) - Sort [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #8 - WholeStageCodegen (12) - HashAggregate [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #9 - WholeStageCodegen (11) - HashAggregate [brand_id,class_id,category_id] - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_item_sk,i_item_sk] - Project [ss_item_sk] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #2 - BroadcastExchange #10 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] + WholeStageCodegen (13) + Sort [brand_id,class_id,category_id] + InputAdapter + Exchange [brand_id,class_id,category_id] #7 + WholeStageCodegen (12) + HashAggregate [brand_id,class_id,category_id] + InputAdapter + Exchange [brand_id,class_id,category_id] #8 + WholeStageCodegen (11) + HashAggregate [brand_id,class_id,category_id] + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ss_item_sk,i_item_sk] + Project [ss_item_sk] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #2 + BroadcastExchange #9 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] + InputAdapter + ReusedExchange [d_date_sk] #9 + InputAdapter + BroadcastExchange #10 + WholeStageCodegen (10) + SortMergeJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (5) + Sort [i_brand_id,i_class_id,i_category_id] InputAdapter - ReusedExchange [d_date_sk] #10 - InputAdapter - BroadcastExchange #11 - WholeStageCodegen (10) - SortMergeJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (5) - Sort [i_brand_id,i_class_id,i_category_id] + Exchange [i_brand_id,i_class_id,i_category_id] #11 + WholeStageCodegen (4) + Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + ColumnarToRow InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #12 - WholeStageCodegen (4) - Filter 
[i_item_sk,i_brand_id,i_class_id,i_category_id] + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (9) + Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + Exchange [i_brand_id,i_class_id,i_category_id] #12 + WholeStageCodegen (8) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_item_sk,i_item_sk] + Project [cs_item_sk] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Filter [cs_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] + ReusedSubquery [d_date_sk] #2 + InputAdapter + ReusedExchange [d_date_sk] #9 + InputAdapter + BroadcastExchange #13 + WholeStageCodegen (7) + Filter [i_item_sk] ColumnarToRow InputAdapter Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (9) - Sort [i_brand_id,i_class_id,i_category_id] - InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #13 - WholeStageCodegen (8) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_item_sk,i_item_sk] - Project [cs_item_sk] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Filter [cs_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #2 - InputAdapter - ReusedExchange [d_date_sk] #10 - InputAdapter - BroadcastExchange #14 - WholeStageCodegen (7) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (17) - Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (17) + Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + Exchange [i_brand_id,i_class_id,i_category_id] #14 + WholeStageCodegen (16) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_item_sk,i_item_sk] + Project [ws_item_sk] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Filter [ws_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #2 + InputAdapter + ReusedExchange [d_date_sk] #9 InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #15 - WholeStageCodegen (16) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_item_sk,i_item_sk] - Project [ws_item_sk] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Filter [ws_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] - ReusedSubquery [d_date_sk] #2 - InputAdapter - ReusedExchange [d_date_sk] #10 - InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #14 + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #13 InputAdapter ReusedExchange [d_date_sk] #4 InputAdapter - BroadcastExchange #16 - WholeStageCodegen (44) + BroadcastExchange #15 + WholeStageCodegen (42) SortMergeJoin [i_item_sk,ss_item_sk] InputAdapter - WholeStageCodegen (24) + WholeStageCodegen (23) Sort [i_item_sk] InputAdapter - Exchange [i_item_sk] #17 - WholeStageCodegen (23) + Exchange [i_item_sk] #16 + WholeStageCodegen (22) Filter [i_item_sk] ColumnarToRow InputAdapter Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] InputAdapter - WholeStageCodegen (43) + WholeStageCodegen (41) Sort [ss_item_sk] InputAdapter ReusedExchange [ss_item_sk] #5 - WholeStageCodegen (92) + WholeStageCodegen (88) Project [sales,number_sales,i_brand_id,i_class_id,i_category_id] Filter 
[sales] ReusedSubquery [average_sales] #3 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cs_quantity as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),sales,number_sales,sum,isEmpty,count] InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #19 - WholeStageCodegen (91) + Exchange [i_brand_id,i_class_id,i_category_id] #18 + WholeStageCodegen (87) HashAggregate [i_brand_id,i_class_id,i_category_id,cs_quantity,cs_list_price] [sum,isEmpty,count,sum,isEmpty,count] Project [cs_quantity,cs_list_price,i_brand_id,i_class_id,i_category_id] BroadcastHashJoin [cs_item_sk,i_item_sk] @@ -218,33 +213,33 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su BroadcastHashJoin [cs_sold_date_sk,d_date_sk] SortMergeJoin [cs_item_sk,ss_item_sk] InputAdapter - WholeStageCodegen (48) + WholeStageCodegen (46) Sort [cs_item_sk] InputAdapter - Exchange [cs_item_sk] #20 - WholeStageCodegen (47) + Exchange [cs_item_sk] #19 + WholeStageCodegen (45) Filter [cs_item_sk] ColumnarToRow InputAdapter Scan parquet default.catalog_sales [cs_item_sk,cs_quantity,cs_list_price,cs_sold_date_sk] ReusedSubquery [d_date_sk] #1 InputAdapter - WholeStageCodegen (67) + WholeStageCodegen (64) Sort [ss_item_sk] InputAdapter ReusedExchange [ss_item_sk] #5 InputAdapter ReusedExchange [d_date_sk] #4 InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #16 - WholeStageCodegen (138) + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #15 + WholeStageCodegen (132) Project [sales,number_sales,i_brand_id,i_class_id,i_category_id] Filter [sales] ReusedSubquery [average_sales] #3 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ws_quantity as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),sales,number_sales,sum,isEmpty,count] InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #21 - WholeStageCodegen (137) + Exchange [i_brand_id,i_class_id,i_category_id] #20 + WholeStageCodegen (131) HashAggregate [i_brand_id,i_class_id,i_category_id,ws_quantity,ws_list_price] [sum,isEmpty,count,sum,isEmpty,count] Project [ws_quantity,ws_list_price,i_brand_id,i_class_id,i_category_id] BroadcastHashJoin [ws_item_sk,i_item_sk] @@ -252,22 +247,22 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su BroadcastHashJoin [ws_sold_date_sk,d_date_sk] SortMergeJoin [ws_item_sk,ss_item_sk] InputAdapter - WholeStageCodegen (94) + WholeStageCodegen (90) Sort [ws_item_sk] InputAdapter - Exchange [ws_item_sk] #22 - WholeStageCodegen (93) + Exchange [ws_item_sk] #21 + WholeStageCodegen (89) Filter [ws_item_sk] ColumnarToRow InputAdapter Scan parquet default.web_sales [ws_item_sk,ws_quantity,ws_list_price,ws_sold_date_sk] 
ReusedSubquery [d_date_sk] #1 InputAdapter - WholeStageCodegen (113) + WholeStageCodegen (108) Sort [ss_item_sk] InputAdapter ReusedExchange [ss_item_sk] #5 InputAdapter ReusedExchange [d_date_sk] #4 InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #16 + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #15 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt index cf4bb6501bd92..300cfd7ccbb21 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt @@ -1,111 +1,109 @@ == Physical Plan == -TakeOrderedAndProject (107) -+- * HashAggregate (106) - +- Exchange (105) - +- * HashAggregate (104) - +- * Expand (103) - +- Union (102) - :- * Project (69) - : +- * Filter (68) - : +- * HashAggregate (67) - : +- Exchange (66) - : +- * HashAggregate (65) - : +- * Project (64) - : +- * BroadcastHashJoin Inner BuildRight (63) - : :- * Project (61) - : : +- * BroadcastHashJoin Inner BuildRight (60) - : : :- * BroadcastHashJoin LeftSemi BuildRight (53) +TakeOrderedAndProject (105) ++- * HashAggregate (104) + +- Exchange (103) + +- * HashAggregate (102) + +- * Expand (101) + +- Union (100) + :- * Project (67) + : +- * Filter (66) + : +- * HashAggregate (65) + : +- Exchange (64) + : +- * HashAggregate (63) + : +- * Project (62) + : +- * BroadcastHashJoin Inner BuildRight (61) + : :- * Project (59) + : : +- * BroadcastHashJoin Inner BuildRight (58) + : : :- * BroadcastHashJoin LeftSemi BuildRight (51) : : : :- * Filter (3) : : : : +- * ColumnarToRow (2) : : : : +- Scan parquet default.store_sales (1) - : : : +- BroadcastExchange (52) - : : : +- * Project (51) - : : : +- * BroadcastHashJoin Inner BuildRight (50) + : : : +- BroadcastExchange (50) + : : : +- * Project (49) + : : : +- * BroadcastHashJoin Inner BuildRight (48) : : : :- * Filter (6) : : : : +- * ColumnarToRow (5) : : : : +- Scan parquet default.item (4) - : : : +- BroadcastExchange (49) - : : : +- * HashAggregate (48) - : : : +- * HashAggregate (47) - : : : +- * BroadcastHashJoin LeftSemi BuildRight (46) - : : : :- * HashAggregate (35) - : : : : +- Exchange (34) - : : : : +- * HashAggregate (33) - : : : : +- * Project (32) - : : : : +- * BroadcastHashJoin Inner BuildRight (31) - : : : : :- * Project (29) - : : : : : +- * BroadcastHashJoin Inner BuildRight (28) - : : : : : :- * Filter (9) - : : : : : : +- * ColumnarToRow (8) - : : : : : : +- Scan parquet default.store_sales (7) - : : : : : +- BroadcastExchange (27) - : : : : : +- * BroadcastHashJoin LeftSemi BuildRight (26) - : : : : : :- * Filter (12) - : : : : : : +- * ColumnarToRow (11) - : : : : : : +- Scan parquet default.item (10) - : : : : : +- BroadcastExchange (25) - : : : : : +- * Project (24) - : : : : : +- * BroadcastHashJoin Inner BuildRight (23) - : : : : : :- * Project (21) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (20) - : : : : : : :- * Filter (15) - : : : : : : : +- * ColumnarToRow (14) - : : : : : : : +- Scan parquet default.catalog_sales (13) - : : : : : : +- BroadcastExchange (19) - : : : : : : +- * Filter (18) - : : : : : : +- * ColumnarToRow (17) - : : : : : : +- Scan parquet default.item (16) - : : : : : +- ReusedExchange (22) - : : : : +- ReusedExchange (30) - : : : +- BroadcastExchange (45) - : : : +- * Project (44) - : : : +- * BroadcastHashJoin Inner 
BuildRight (43) - : : : :- * Project (41) - : : : : +- * BroadcastHashJoin Inner BuildRight (40) - : : : : :- * Filter (38) - : : : : : +- * ColumnarToRow (37) - : : : : : +- Scan parquet default.web_sales (36) - : : : : +- ReusedExchange (39) - : : : +- ReusedExchange (42) - : : +- BroadcastExchange (59) - : : +- * BroadcastHashJoin LeftSemi BuildRight (58) - : : :- * Filter (56) - : : : +- * ColumnarToRow (55) - : : : +- Scan parquet default.item (54) - : : +- ReusedExchange (57) - : +- ReusedExchange (62) - :- * Project (85) - : +- * Filter (84) - : +- * HashAggregate (83) - : +- Exchange (82) - : +- * HashAggregate (81) - : +- * Project (80) - : +- * BroadcastHashJoin Inner BuildRight (79) - : :- * Project (77) - : : +- * BroadcastHashJoin Inner BuildRight (76) - : : :- * BroadcastHashJoin LeftSemi BuildRight (74) - : : : :- * Filter (72) - : : : : +- * ColumnarToRow (71) - : : : : +- Scan parquet default.catalog_sales (70) - : : : +- ReusedExchange (73) - : : +- ReusedExchange (75) - : +- ReusedExchange (78) - +- * Project (101) - +- * Filter (100) - +- * HashAggregate (99) - +- Exchange (98) - +- * HashAggregate (97) - +- * Project (96) - +- * BroadcastHashJoin Inner BuildRight (95) - :- * Project (93) - : +- * BroadcastHashJoin Inner BuildRight (92) - : :- * BroadcastHashJoin LeftSemi BuildRight (90) - : : :- * Filter (88) - : : : +- * ColumnarToRow (87) - : : : +- Scan parquet default.web_sales (86) - : : +- ReusedExchange (89) - : +- ReusedExchange (91) - +- ReusedExchange (94) + : : : +- BroadcastExchange (47) + : : : +- * BroadcastHashJoin LeftSemi BuildRight (46) + : : : :- * HashAggregate (35) + : : : : +- Exchange (34) + : : : : +- * HashAggregate (33) + : : : : +- * Project (32) + : : : : +- * BroadcastHashJoin Inner BuildRight (31) + : : : : :- * Project (29) + : : : : : +- * BroadcastHashJoin Inner BuildRight (28) + : : : : : :- * Filter (9) + : : : : : : +- * ColumnarToRow (8) + : : : : : : +- Scan parquet default.store_sales (7) + : : : : : +- BroadcastExchange (27) + : : : : : +- * BroadcastHashJoin LeftSemi BuildRight (26) + : : : : : :- * Filter (12) + : : : : : : +- * ColumnarToRow (11) + : : : : : : +- Scan parquet default.item (10) + : : : : : +- BroadcastExchange (25) + : : : : : +- * Project (24) + : : : : : +- * BroadcastHashJoin Inner BuildRight (23) + : : : : : :- * Project (21) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (20) + : : : : : : :- * Filter (15) + : : : : : : : +- * ColumnarToRow (14) + : : : : : : : +- Scan parquet default.catalog_sales (13) + : : : : : : +- BroadcastExchange (19) + : : : : : : +- * Filter (18) + : : : : : : +- * ColumnarToRow (17) + : : : : : : +- Scan parquet default.item (16) + : : : : : +- ReusedExchange (22) + : : : : +- ReusedExchange (30) + : : : +- BroadcastExchange (45) + : : : +- * Project (44) + : : : +- * BroadcastHashJoin Inner BuildRight (43) + : : : :- * Project (41) + : : : : +- * BroadcastHashJoin Inner BuildRight (40) + : : : : :- * Filter (38) + : : : : : +- * ColumnarToRow (37) + : : : : : +- Scan parquet default.web_sales (36) + : : : : +- ReusedExchange (39) + : : : +- ReusedExchange (42) + : : +- BroadcastExchange (57) + : : +- * BroadcastHashJoin LeftSemi BuildRight (56) + : : :- * Filter (54) + : : : +- * ColumnarToRow (53) + : : : +- Scan parquet default.item (52) + : : +- ReusedExchange (55) + : +- ReusedExchange (60) + :- * Project (83) + : +- * Filter (82) + : +- * HashAggregate (81) + : +- Exchange (80) + : +- * HashAggregate (79) + : +- * Project (78) + : +- * BroadcastHashJoin Inner 
BuildRight (77) + : :- * Project (75) + : : +- * BroadcastHashJoin Inner BuildRight (74) + : : :- * BroadcastHashJoin LeftSemi BuildRight (72) + : : : :- * Filter (70) + : : : : +- * ColumnarToRow (69) + : : : : +- Scan parquet default.catalog_sales (68) + : : : +- ReusedExchange (71) + : : +- ReusedExchange (73) + : +- ReusedExchange (76) + +- * Project (99) + +- * Filter (98) + +- * HashAggregate (97) + +- Exchange (96) + +- * HashAggregate (95) + +- * Project (94) + +- * BroadcastHashJoin Inner BuildRight (93) + :- * Project (91) + : +- * BroadcastHashJoin Inner BuildRight (90) + : :- * BroadcastHashJoin LeftSemi BuildRight (88) + : : :- * Filter (86) + : : : +- * ColumnarToRow (85) + : : : +- Scan parquet default.web_sales (84) + : : +- ReusedExchange (87) + : +- ReusedExchange (89) + +- ReusedExchange (92) (1) Scan parquet default.store_sales @@ -208,7 +206,7 @@ Join condition: None Output [4]: [cs_sold_date_sk#18, i_brand_id#20, i_class_id#21, i_category_id#22] Input [6]: [cs_item_sk#17, cs_sold_date_sk#18, i_item_sk#19, i_brand_id#20, i_class_id#21, i_category_id#22] -(22) ReusedExchange [Reuses operator id: 136] +(22) ReusedExchange [Reuses operator id: 134] Output [1]: [d_date_sk#24] (23) BroadcastHashJoin [codegen id : 3] @@ -242,7 +240,7 @@ Join condition: None Output [4]: [ss_sold_date_sk#11, i_brand_id#14, i_class_id#15, i_category_id#16] Input [6]: [ss_item_sk#10, ss_sold_date_sk#11, i_item_sk#13, i_brand_id#14, i_class_id#15, i_category_id#16] -(30) ReusedExchange [Reuses operator id: 136] +(30) ReusedExchange [Reuses operator id: 134] Output [1]: [d_date_sk#27] (31) BroadcastHashJoin [codegen id : 6] @@ -299,7 +297,7 @@ Join condition: None Output [4]: [ws_sold_date_sk#33, i_brand_id#35, i_class_id#36, i_category_id#37] Input [6]: [ws_item_sk#32, ws_sold_date_sk#33, i_item_sk#34, i_brand_id#35, i_class_id#36, i_category_id#37] -(42) ReusedExchange [Reuses operator id: 136] +(42) ReusedExchange [Reuses operator id: 134] Output [1]: [d_date_sk#38] (43) BroadcastHashJoin [codegen id : 9] @@ -320,116 +318,102 @@ Left keys [6]: [coalesce(brand_id#28, 0), isnull(brand_id#28), coalesce(class_id Right keys [6]: [coalesce(i_brand_id#35, 0), isnull(i_brand_id#35), coalesce(i_class_id#36, 0), isnull(i_class_id#36), coalesce(i_category_id#37, 0), isnull(i_category_id#37)] Join condition: None -(47) HashAggregate [codegen id : 10] -Input [3]: [brand_id#28, class_id#29, category_id#30] -Keys [3]: [brand_id#28, class_id#29, category_id#30] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#28, class_id#29, category_id#30] - -(48) HashAggregate [codegen id : 10] -Input [3]: [brand_id#28, class_id#29, category_id#30] -Keys [3]: [brand_id#28, class_id#29, category_id#30] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#28, class_id#29, category_id#30] - -(49) BroadcastExchange +(47) BroadcastExchange Input [3]: [brand_id#28, class_id#29, category_id#30] Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#40] -(50) BroadcastHashJoin [codegen id : 11] +(48) BroadcastHashJoin [codegen id : 11] Left keys [3]: [i_brand_id#7, i_class_id#8, i_category_id#9] Right keys [3]: [brand_id#28, class_id#29, category_id#30] Join condition: None -(51) Project [codegen id : 11] +(49) Project [codegen id : 11] Output [1]: [i_item_sk#6 AS ss_item_sk#41] Input [7]: [i_item_sk#6, i_brand_id#7, i_class_id#8, i_category_id#9, brand_id#28, class_id#29, category_id#30] -(52) BroadcastExchange +(50) BroadcastExchange 
Input [1]: [ss_item_sk#41] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#42] -(53) BroadcastHashJoin [codegen id : 25] +(51) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_item_sk#1] Right keys [1]: [ss_item_sk#41] Join condition: None -(54) Scan parquet default.item +(52) Scan parquet default.item Output [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Batched: true Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_item_sk)] ReadSchema: struct -(55) ColumnarToRow [codegen id : 23] +(53) ColumnarToRow [codegen id : 23] Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] -(56) Filter [codegen id : 23] +(54) Filter [codegen id : 23] Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Condition : isnotnull(i_item_sk#43) -(57) ReusedExchange [Reuses operator id: 52] +(55) ReusedExchange [Reuses operator id: 50] Output [1]: [ss_item_sk#41] -(58) BroadcastHashJoin [codegen id : 23] +(56) BroadcastHashJoin [codegen id : 23] Left keys [1]: [i_item_sk#43] Right keys [1]: [ss_item_sk#41] Join condition: None -(59) BroadcastExchange +(57) BroadcastExchange Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#47] -(60) BroadcastHashJoin [codegen id : 25] +(58) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_item_sk#1] Right keys [1]: [i_item_sk#43] Join condition: None -(61) Project [codegen id : 25] +(59) Project [codegen id : 25] Output [6]: [ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_brand_id#44, i_class_id#45, i_category_id#46] Input [8]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] -(62) ReusedExchange [Reuses operator id: 131] +(60) ReusedExchange [Reuses operator id: 129] Output [1]: [d_date_sk#48] -(63) BroadcastHashJoin [codegen id : 25] +(61) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_sold_date_sk#4] Right keys [1]: [d_date_sk#48] Join condition: None -(64) Project [codegen id : 25] +(62) Project [codegen id : 25] Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#44, i_class_id#45, i_category_id#46] Input [7]: [ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_brand_id#44, i_class_id#45, i_category_id#46, d_date_sk#48] -(65) HashAggregate [codegen id : 25] +(63) HashAggregate [codegen id : 25] Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#44, i_class_id#45, i_category_id#46] Keys [3]: [i_brand_id#44, i_class_id#45, i_category_id#46] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] Aggregate Attributes [3]: [sum#49, isEmpty#50, count#51] Results [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] -(66) Exchange +(64) Exchange Input [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] Arguments: hashpartitioning(i_brand_id#44, i_class_id#45, i_category_id#46, 5), ENSURE_REQUIREMENTS, [id=#55] -(67) HashAggregate [codegen id : 26] +(65) HashAggregate [codegen id 
: 26] Input [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] Keys [3]: [i_brand_id#44, i_class_id#45, i_category_id#46] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#56, count(1)#57] -Results [5]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#56 AS sales#58, count(1)#57 AS number_sales#59] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#56, count(1)#57] +Results [5]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#56 AS sales#58, count(1)#57 AS number_sales#59] -(68) Filter [codegen id : 26] +(66) Filter [codegen id : 26] Input [5]: [i_brand_id#44, i_class_id#45, i_category_id#46, sales#58, number_sales#59] Condition : (isnotnull(sales#58) AND (cast(sales#58 as decimal(32,6)) > cast(Subquery scalar-subquery#60, [id=#61] as decimal(32,6)))) -(69) Project [codegen id : 26] +(67) Project [codegen id : 26] Output [6]: [sales#58, number_sales#59, store AS channel#62, i_brand_id#44, i_class_id#45, i_category_id#46] Input [5]: [i_brand_id#44, i_class_id#45, i_category_id#46, sales#58, number_sales#59] -(70) Scan parquet default.catalog_sales +(68) Scan parquet default.catalog_sales Output [4]: [cs_item_sk#63, cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66] Batched: true Location: InMemoryFileIndex [] @@ -437,72 +421,72 @@ PartitionFilters: [isnotnull(cs_sold_date_sk#66), dynamicpruningexpression(cs_so PushedFilters: [IsNotNull(cs_item_sk)] ReadSchema: struct -(71) ColumnarToRow [codegen id : 51] +(69) ColumnarToRow [codegen id : 51] Input [4]: [cs_item_sk#63, cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66] -(72) Filter [codegen id : 51] +(70) Filter [codegen id : 51] Input [4]: [cs_item_sk#63, cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66] Condition : isnotnull(cs_item_sk#63) -(73) ReusedExchange [Reuses operator id: 52] +(71) ReusedExchange [Reuses operator id: 50] Output [1]: [ss_item_sk#41] -(74) BroadcastHashJoin [codegen id : 51] +(72) BroadcastHashJoin [codegen id : 51] Left keys [1]: [cs_item_sk#63] Right keys [1]: [ss_item_sk#41] Join condition: None -(75) ReusedExchange [Reuses operator id: 59] +(73) ReusedExchange [Reuses operator id: 57] Output [4]: [i_item_sk#67, i_brand_id#68, i_class_id#69, i_category_id#70] -(76) BroadcastHashJoin [codegen id : 51] +(74) BroadcastHashJoin [codegen id : 51] Left keys [1]: [cs_item_sk#63] Right keys [1]: [i_item_sk#67] Join condition: None -(77) Project [codegen id : 51] +(75) Project [codegen id : 51] Output [6]: [cs_quantity#64, cs_list_price#65, 
cs_sold_date_sk#66, i_brand_id#68, i_class_id#69, i_category_id#70] Input [8]: [cs_item_sk#63, cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66, i_item_sk#67, i_brand_id#68, i_class_id#69, i_category_id#70] -(78) ReusedExchange [Reuses operator id: 131] +(76) ReusedExchange [Reuses operator id: 129] Output [1]: [d_date_sk#71] -(79) BroadcastHashJoin [codegen id : 51] +(77) BroadcastHashJoin [codegen id : 51] Left keys [1]: [cs_sold_date_sk#66] Right keys [1]: [d_date_sk#71] Join condition: None -(80) Project [codegen id : 51] +(78) Project [codegen id : 51] Output [5]: [cs_quantity#64, cs_list_price#65, i_brand_id#68, i_class_id#69, i_category_id#70] Input [7]: [cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66, i_brand_id#68, i_class_id#69, i_category_id#70, d_date_sk#71] -(81) HashAggregate [codegen id : 51] +(79) HashAggregate [codegen id : 51] Input [5]: [cs_quantity#64, cs_list_price#65, i_brand_id#68, i_class_id#69, i_category_id#70] Keys [3]: [i_brand_id#68, i_class_id#69, i_category_id#70] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cs_quantity#64 as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] Aggregate Attributes [3]: [sum#72, isEmpty#73, count#74] Results [6]: [i_brand_id#68, i_class_id#69, i_category_id#70, sum#75, isEmpty#76, count#77] -(82) Exchange +(80) Exchange Input [6]: [i_brand_id#68, i_class_id#69, i_category_id#70, sum#75, isEmpty#76, count#77] Arguments: hashpartitioning(i_brand_id#68, i_class_id#69, i_category_id#70, 5), ENSURE_REQUIREMENTS, [id=#78] -(83) HashAggregate [codegen id : 52] +(81) HashAggregate [codegen id : 52] Input [6]: [i_brand_id#68, i_class_id#69, i_category_id#70, sum#75, isEmpty#76, count#77] Keys [3]: [i_brand_id#68, i_class_id#69, i_category_id#70] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2), true))#79, count(1)#80] -Results [5]: [i_brand_id#68, i_class_id#69, i_category_id#70, sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2), true))#79 AS sales#81, count(1)#80 AS number_sales#82] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#64 as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#64 as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2)))#79, count(1)#80] +Results [5]: [i_brand_id#68, i_class_id#69, i_category_id#70, sum(CheckOverflow((promote_precision(cast(cs_quantity#64 as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2)))#79 AS sales#81, count(1)#80 AS number_sales#82] -(84) Filter [codegen id : 52] +(82) Filter [codegen id : 52] Input [5]: 
[i_brand_id#68, i_class_id#69, i_category_id#70, sales#81, number_sales#82] Condition : (isnotnull(sales#81) AND (cast(sales#81 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#60, [id=#61] as decimal(32,6)))) -(85) Project [codegen id : 52] +(83) Project [codegen id : 52] Output [6]: [sales#81, number_sales#82, catalog AS channel#83, i_brand_id#68, i_class_id#69, i_category_id#70] Input [5]: [i_brand_id#68, i_class_id#69, i_category_id#70, sales#81, number_sales#82] -(86) Scan parquet default.web_sales +(84) Scan parquet default.web_sales Output [4]: [ws_item_sk#84, ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87] Batched: true Location: InMemoryFileIndex [] @@ -510,272 +494,272 @@ PartitionFilters: [isnotnull(ws_sold_date_sk#87), dynamicpruningexpression(ws_so PushedFilters: [IsNotNull(ws_item_sk)] ReadSchema: struct -(87) ColumnarToRow [codegen id : 77] +(85) ColumnarToRow [codegen id : 77] Input [4]: [ws_item_sk#84, ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87] -(88) Filter [codegen id : 77] +(86) Filter [codegen id : 77] Input [4]: [ws_item_sk#84, ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87] Condition : isnotnull(ws_item_sk#84) -(89) ReusedExchange [Reuses operator id: 52] +(87) ReusedExchange [Reuses operator id: 50] Output [1]: [ss_item_sk#41] -(90) BroadcastHashJoin [codegen id : 77] +(88) BroadcastHashJoin [codegen id : 77] Left keys [1]: [ws_item_sk#84] Right keys [1]: [ss_item_sk#41] Join condition: None -(91) ReusedExchange [Reuses operator id: 59] +(89) ReusedExchange [Reuses operator id: 57] Output [4]: [i_item_sk#88, i_brand_id#89, i_class_id#90, i_category_id#91] -(92) BroadcastHashJoin [codegen id : 77] +(90) BroadcastHashJoin [codegen id : 77] Left keys [1]: [ws_item_sk#84] Right keys [1]: [i_item_sk#88] Join condition: None -(93) Project [codegen id : 77] +(91) Project [codegen id : 77] Output [6]: [ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87, i_brand_id#89, i_class_id#90, i_category_id#91] Input [8]: [ws_item_sk#84, ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87, i_item_sk#88, i_brand_id#89, i_class_id#90, i_category_id#91] -(94) ReusedExchange [Reuses operator id: 131] +(92) ReusedExchange [Reuses operator id: 129] Output [1]: [d_date_sk#92] -(95) BroadcastHashJoin [codegen id : 77] +(93) BroadcastHashJoin [codegen id : 77] Left keys [1]: [ws_sold_date_sk#87] Right keys [1]: [d_date_sk#92] Join condition: None -(96) Project [codegen id : 77] +(94) Project [codegen id : 77] Output [5]: [ws_quantity#85, ws_list_price#86, i_brand_id#89, i_class_id#90, i_category_id#91] Input [7]: [ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87, i_brand_id#89, i_class_id#90, i_category_id#91, d_date_sk#92] -(97) HashAggregate [codegen id : 77] +(95) HashAggregate [codegen id : 77] Input [5]: [ws_quantity#85, ws_list_price#86, i_brand_id#89, i_class_id#90, i_category_id#91] Keys [3]: [i_brand_id#89, i_class_id#90, i_category_id#91] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#85 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ws_quantity#85 as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] Aggregate Attributes [3]: [sum#93, isEmpty#94, count#95] Results [6]: [i_brand_id#89, i_class_id#90, i_category_id#91, sum#96, isEmpty#97, count#98] -(98) Exchange +(96) 
Exchange Input [6]: [i_brand_id#89, i_class_id#90, i_category_id#91, sum#96, isEmpty#97, count#98] Arguments: hashpartitioning(i_brand_id#89, i_class_id#90, i_category_id#91, 5), ENSURE_REQUIREMENTS, [id=#99] -(99) HashAggregate [codegen id : 78] +(97) HashAggregate [codegen id : 78] Input [6]: [i_brand_id#89, i_class_id#90, i_category_id#91, sum#96, isEmpty#97, count#98] Keys [3]: [i_brand_id#89, i_class_id#90, i_category_id#91] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#85 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#85 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2), true))#100, count(1)#101] -Results [5]: [i_brand_id#89, i_class_id#90, i_category_id#91, sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#85 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2), true))#100 AS sales#102, count(1)#101 AS number_sales#103] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#85 as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#85 as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2)))#100, count(1)#101] +Results [5]: [i_brand_id#89, i_class_id#90, i_category_id#91, sum(CheckOverflow((promote_precision(cast(ws_quantity#85 as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2)))#100 AS sales#102, count(1)#101 AS number_sales#103] -(100) Filter [codegen id : 78] +(98) Filter [codegen id : 78] Input [5]: [i_brand_id#89, i_class_id#90, i_category_id#91, sales#102, number_sales#103] Condition : (isnotnull(sales#102) AND (cast(sales#102 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#60, [id=#61] as decimal(32,6)))) -(101) Project [codegen id : 78] +(99) Project [codegen id : 78] Output [6]: [sales#102, number_sales#103, web AS channel#104, i_brand_id#89, i_class_id#90, i_category_id#91] Input [5]: [i_brand_id#89, i_class_id#90, i_category_id#91, sales#102, number_sales#103] -(102) Union +(100) Union -(103) Expand [codegen id : 79] +(101) Expand [codegen id : 79] Input [6]: [sales#58, number_sales#59, channel#62, i_brand_id#44, i_class_id#45, i_category_id#46] Arguments: [[sales#58, number_sales#59, channel#62, i_brand_id#44, i_class_id#45, i_category_id#46, 0], [sales#58, number_sales#59, channel#62, i_brand_id#44, i_class_id#45, null, 1], [sales#58, number_sales#59, channel#62, i_brand_id#44, null, null, 3], [sales#58, number_sales#59, channel#62, null, null, null, 7], [sales#58, number_sales#59, null, null, null, null, 15]], [sales#58, number_sales#59, channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, spark_grouping_id#109] -(104) HashAggregate [codegen id : 79] +(102) HashAggregate [codegen id : 79] Input [7]: [sales#58, number_sales#59, channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, spark_grouping_id#109] Keys [5]: [channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, spark_grouping_id#109] Functions [2]: [partial_sum(sales#58), partial_sum(number_sales#59)] Aggregate Attributes [3]: [sum#110, isEmpty#111, sum#112] Results [8]: 
[channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, spark_grouping_id#109, sum#113, isEmpty#114, sum#115] -(105) Exchange +(103) Exchange Input [8]: [channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, spark_grouping_id#109, sum#113, isEmpty#114, sum#115] Arguments: hashpartitioning(channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, spark_grouping_id#109, 5), ENSURE_REQUIREMENTS, [id=#116] -(106) HashAggregate [codegen id : 80] +(104) HashAggregate [codegen id : 80] Input [8]: [channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, spark_grouping_id#109, sum#113, isEmpty#114, sum#115] Keys [5]: [channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, spark_grouping_id#109] Functions [2]: [sum(sales#58), sum(number_sales#59)] Aggregate Attributes [2]: [sum(sales#58)#117, sum(number_sales#59)#118] Results [6]: [channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, sum(sales#58)#117 AS sum(sales)#119, sum(number_sales#59)#118 AS sum(number_sales)#120] -(107) TakeOrderedAndProject +(105) TakeOrderedAndProject Input [6]: [channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, sum(sales)#119, sum(number_sales)#120] Arguments: 100, [channel#105 ASC NULLS FIRST, i_brand_id#106 ASC NULLS FIRST, i_class_id#107 ASC NULLS FIRST, i_category_id#108 ASC NULLS FIRST], [channel#105, i_brand_id#106, i_class_id#107, i_category_id#108, sum(sales)#119, sum(number_sales)#120] ===== Subqueries ===== -Subquery:1 Hosting operator id = 68 Hosting Expression = Subquery scalar-subquery#60, [id=#61] -* HashAggregate (126) -+- Exchange (125) - +- * HashAggregate (124) - +- Union (123) - :- * Project (112) - : +- * BroadcastHashJoin Inner BuildRight (111) - : :- * ColumnarToRow (109) - : : +- Scan parquet default.store_sales (108) - : +- ReusedExchange (110) - :- * Project (117) - : +- * BroadcastHashJoin Inner BuildRight (116) - : :- * ColumnarToRow (114) - : : +- Scan parquet default.catalog_sales (113) - : +- ReusedExchange (115) - +- * Project (122) - +- * BroadcastHashJoin Inner BuildRight (121) - :- * ColumnarToRow (119) - : +- Scan parquet default.web_sales (118) - +- ReusedExchange (120) - - -(108) Scan parquet default.store_sales +Subquery:1 Hosting operator id = 66 Hosting Expression = Subquery scalar-subquery#60, [id=#61] +* HashAggregate (124) ++- Exchange (123) + +- * HashAggregate (122) + +- Union (121) + :- * Project (110) + : +- * BroadcastHashJoin Inner BuildRight (109) + : :- * ColumnarToRow (107) + : : +- Scan parquet default.store_sales (106) + : +- ReusedExchange (108) + :- * Project (115) + : +- * BroadcastHashJoin Inner BuildRight (114) + : :- * ColumnarToRow (112) + : : +- Scan parquet default.catalog_sales (111) + : +- ReusedExchange (113) + +- * Project (120) + +- * BroadcastHashJoin Inner BuildRight (119) + :- * ColumnarToRow (117) + : +- Scan parquet default.web_sales (116) + +- ReusedExchange (118) + + +(106) Scan parquet default.store_sales Output [3]: [ss_quantity#121, ss_list_price#122, ss_sold_date_sk#123] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(ss_sold_date_sk#123), dynamicpruningexpression(ss_sold_date_sk#123 IN dynamicpruning#12)] ReadSchema: struct -(109) ColumnarToRow [codegen id : 2] +(107) ColumnarToRow [codegen id : 2] Input [3]: [ss_quantity#121, ss_list_price#122, ss_sold_date_sk#123] -(110) ReusedExchange [Reuses operator id: 136] +(108) ReusedExchange [Reuses operator id: 134] Output [1]: [d_date_sk#124] -(111) BroadcastHashJoin [codegen id : 2] +(109) BroadcastHashJoin 
[codegen id : 2] Left keys [1]: [ss_sold_date_sk#123] Right keys [1]: [d_date_sk#124] Join condition: None -(112) Project [codegen id : 2] +(110) Project [codegen id : 2] Output [2]: [ss_quantity#121 AS quantity#125, ss_list_price#122 AS list_price#126] Input [4]: [ss_quantity#121, ss_list_price#122, ss_sold_date_sk#123, d_date_sk#124] -(113) Scan parquet default.catalog_sales +(111) Scan parquet default.catalog_sales Output [3]: [cs_quantity#127, cs_list_price#128, cs_sold_date_sk#129] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(cs_sold_date_sk#129), dynamicpruningexpression(cs_sold_date_sk#129 IN dynamicpruning#12)] ReadSchema: struct -(114) ColumnarToRow [codegen id : 4] +(112) ColumnarToRow [codegen id : 4] Input [3]: [cs_quantity#127, cs_list_price#128, cs_sold_date_sk#129] -(115) ReusedExchange [Reuses operator id: 136] +(113) ReusedExchange [Reuses operator id: 134] Output [1]: [d_date_sk#130] -(116) BroadcastHashJoin [codegen id : 4] +(114) BroadcastHashJoin [codegen id : 4] Left keys [1]: [cs_sold_date_sk#129] Right keys [1]: [d_date_sk#130] Join condition: None -(117) Project [codegen id : 4] +(115) Project [codegen id : 4] Output [2]: [cs_quantity#127 AS quantity#131, cs_list_price#128 AS list_price#132] Input [4]: [cs_quantity#127, cs_list_price#128, cs_sold_date_sk#129, d_date_sk#130] -(118) Scan parquet default.web_sales +(116) Scan parquet default.web_sales Output [3]: [ws_quantity#133, ws_list_price#134, ws_sold_date_sk#135] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(ws_sold_date_sk#135), dynamicpruningexpression(ws_sold_date_sk#135 IN dynamicpruning#12)] ReadSchema: struct -(119) ColumnarToRow [codegen id : 6] +(117) ColumnarToRow [codegen id : 6] Input [3]: [ws_quantity#133, ws_list_price#134, ws_sold_date_sk#135] -(120) ReusedExchange [Reuses operator id: 136] +(118) ReusedExchange [Reuses operator id: 134] Output [1]: [d_date_sk#136] -(121) BroadcastHashJoin [codegen id : 6] +(119) BroadcastHashJoin [codegen id : 6] Left keys [1]: [ws_sold_date_sk#135] Right keys [1]: [d_date_sk#136] Join condition: None -(122) Project [codegen id : 6] +(120) Project [codegen id : 6] Output [2]: [ws_quantity#133 AS quantity#137, ws_list_price#134 AS list_price#138] Input [4]: [ws_quantity#133, ws_list_price#134, ws_sold_date_sk#135, d_date_sk#136] -(123) Union +(121) Union -(124) HashAggregate [codegen id : 7] +(122) HashAggregate [codegen id : 7] Input [2]: [quantity#125, list_price#126] Keys: [] -Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(cast(quantity#125 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#126 as decimal(12,2)))), DecimalType(18,2), true))] +Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(quantity#125 as decimal(12,2))) * promote_precision(cast(list_price#126 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [2]: [sum#139, count#140] Results [2]: [sum#141, count#142] -(125) Exchange +(123) Exchange Input [2]: [sum#141, count#142] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#143] -(126) HashAggregate [codegen id : 8] +(124) HashAggregate [codegen id : 8] Input [2]: [sum#141, count#142] Keys: [] -Functions [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#125 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#126 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#125 as decimal(10,0)) as decimal(12,2))) * 
promote_precision(cast(list_price#126 as decimal(12,2)))), DecimalType(18,2), true))#144] -Results [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#125 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#126 as decimal(12,2)))), DecimalType(18,2), true))#144 AS average_sales#145] +Functions [1]: [avg(CheckOverflow((promote_precision(cast(quantity#125 as decimal(12,2))) * promote_precision(cast(list_price#126 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(quantity#125 as decimal(12,2))) * promote_precision(cast(list_price#126 as decimal(12,2)))), DecimalType(18,2)))#144] +Results [1]: [avg(CheckOverflow((promote_precision(cast(quantity#125 as decimal(12,2))) * promote_precision(cast(list_price#126 as decimal(12,2)))), DecimalType(18,2)))#144 AS average_sales#145] -Subquery:2 Hosting operator id = 108 Hosting Expression = ss_sold_date_sk#123 IN dynamicpruning#12 +Subquery:2 Hosting operator id = 106 Hosting Expression = ss_sold_date_sk#123 IN dynamicpruning#12 -Subquery:3 Hosting operator id = 113 Hosting Expression = cs_sold_date_sk#129 IN dynamicpruning#12 +Subquery:3 Hosting operator id = 111 Hosting Expression = cs_sold_date_sk#129 IN dynamicpruning#12 -Subquery:4 Hosting operator id = 118 Hosting Expression = ws_sold_date_sk#135 IN dynamicpruning#12 +Subquery:4 Hosting operator id = 116 Hosting Expression = ws_sold_date_sk#135 IN dynamicpruning#12 Subquery:5 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (131) -+- * Project (130) - +- * Filter (129) - +- * ColumnarToRow (128) - +- Scan parquet default.date_dim (127) +BroadcastExchange (129) ++- * Project (128) + +- * Filter (127) + +- * ColumnarToRow (126) + +- Scan parquet default.date_dim (125) -(127) Scan parquet default.date_dim +(125) Scan parquet default.date_dim Output [3]: [d_date_sk#48, d_year#146, d_moy#147] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2001), EqualTo(d_moy,11), IsNotNull(d_date_sk)] ReadSchema: struct -(128) ColumnarToRow [codegen id : 1] +(126) ColumnarToRow [codegen id : 1] Input [3]: [d_date_sk#48, d_year#146, d_moy#147] -(129) Filter [codegen id : 1] +(127) Filter [codegen id : 1] Input [3]: [d_date_sk#48, d_year#146, d_moy#147] Condition : ((((isnotnull(d_year#146) AND isnotnull(d_moy#147)) AND (d_year#146 = 2001)) AND (d_moy#147 = 11)) AND isnotnull(d_date_sk#48)) -(130) Project [codegen id : 1] +(128) Project [codegen id : 1] Output [1]: [d_date_sk#48] Input [3]: [d_date_sk#48, d_year#146, d_moy#147] -(131) BroadcastExchange +(129) BroadcastExchange Input [1]: [d_date_sk#48] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#148] Subquery:6 Hosting operator id = 7 Hosting Expression = ss_sold_date_sk#11 IN dynamicpruning#12 -BroadcastExchange (136) -+- * Project (135) - +- * Filter (134) - +- * ColumnarToRow (133) - +- Scan parquet default.date_dim (132) +BroadcastExchange (134) ++- * Project (133) + +- * Filter (132) + +- * ColumnarToRow (131) + +- Scan parquet default.date_dim (130) -(132) Scan parquet default.date_dim +(130) Scan parquet default.date_dim Output [2]: [d_date_sk#27, d_year#149] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), GreaterThanOrEqual(d_year,1999), LessThanOrEqual(d_year,2001), IsNotNull(d_date_sk)] ReadSchema: struct 
-(133) ColumnarToRow [codegen id : 1] +(131) ColumnarToRow [codegen id : 1] Input [2]: [d_date_sk#27, d_year#149] -(134) Filter [codegen id : 1] +(132) Filter [codegen id : 1] Input [2]: [d_date_sk#27, d_year#149] Condition : (((isnotnull(d_year#149) AND (d_year#149 >= 1999)) AND (d_year#149 <= 2001)) AND isnotnull(d_date_sk#27)) -(135) Project [codegen id : 1] +(133) Project [codegen id : 1] Output [1]: [d_date_sk#27] Input [2]: [d_date_sk#27, d_year#149] -(136) BroadcastExchange +(134) BroadcastExchange Input [1]: [d_date_sk#27] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#150] @@ -783,12 +767,12 @@ Subquery:7 Hosting operator id = 13 Hosting Expression = cs_sold_date_sk#18 IN d Subquery:8 Hosting operator id = 36 Hosting Expression = ws_sold_date_sk#33 IN dynamicpruning#12 -Subquery:9 Hosting operator id = 84 Hosting Expression = ReusedSubquery Subquery scalar-subquery#60, [id=#61] +Subquery:9 Hosting operator id = 82 Hosting Expression = ReusedSubquery Subquery scalar-subquery#60, [id=#61] -Subquery:10 Hosting operator id = 70 Hosting Expression = cs_sold_date_sk#66 IN dynamicpruning#5 +Subquery:10 Hosting operator id = 68 Hosting Expression = cs_sold_date_sk#66 IN dynamicpruning#5 -Subquery:11 Hosting operator id = 100 Hosting Expression = ReusedSubquery Subquery scalar-subquery#60, [id=#61] +Subquery:11 Hosting operator id = 98 Hosting Expression = ReusedSubquery Subquery scalar-subquery#60, [id=#61] -Subquery:12 Hosting operator id = 86 Hosting Expression = ws_sold_date_sk#87 IN dynamicpruning#5 +Subquery:12 Hosting operator id = 84 Hosting Expression = ws_sold_date_sk#87 IN dynamicpruning#5 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/simplified.txt index 34d892c264062..b8125b2af8e92 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/simplified.txt @@ -13,7 +13,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su Filter [sales] Subquery #3 WholeStageCodegen (8) - HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(cast(quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2), true)),average_sales,sum,count] + HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(quantity as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2))),average_sales,sum,count] InputAdapter Exchange #13 WholeStageCodegen (7) @@ -47,7 +47,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su ReusedSubquery [d_date_sk] #2 InputAdapter ReusedExchange [d_date_sk] #7 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #2 WholeStageCodegen 
(25) @@ -81,77 +81,75 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su InputAdapter BroadcastExchange #5 WholeStageCodegen (10) - HashAggregate [brand_id,class_id,category_id] + BroadcastHashJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] HashAggregate [brand_id,class_id,category_id] - BroadcastHashJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] - HashAggregate [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #6 - WholeStageCodegen (6) - HashAggregate [brand_id,class_id,category_id] - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Project [ss_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_item_sk,i_item_sk] - Filter [ss_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #2 - BroadcastExchange #7 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] - InputAdapter - BroadcastExchange #8 - WholeStageCodegen (4) - BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] - Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - BroadcastExchange #9 - WholeStageCodegen (3) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Project [cs_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_item_sk,i_item_sk] - Filter [cs_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #2 - InputAdapter - BroadcastExchange #10 - WholeStageCodegen (1) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - ReusedExchange [d_date_sk] #7 - InputAdapter - ReusedExchange [d_date_sk] #7 - InputAdapter - BroadcastExchange #11 - WholeStageCodegen (9) + InputAdapter + Exchange [brand_id,class_id,category_id] #6 + WholeStageCodegen (6) + HashAggregate [brand_id,class_id,category_id] Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Project [ws_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_item_sk,i_item_sk] - Filter [ws_item_sk] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Project [ss_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ss_item_sk,i_item_sk] + Filter [ss_item_sk] ColumnarToRow InputAdapter - Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] - ReusedSubquery [d_date_sk] #2 + Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #2 + BroadcastExchange #7 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #10 + BroadcastExchange #8 + WholeStageCodegen (4) + BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] + Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + 
InputAdapter + BroadcastExchange #9 + WholeStageCodegen (3) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Project [cs_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_item_sk,i_item_sk] + Filter [cs_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] + ReusedSubquery [d_date_sk] #2 + InputAdapter + BroadcastExchange #10 + WholeStageCodegen (1) + Filter [i_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + ReusedExchange [d_date_sk] #7 InputAdapter ReusedExchange [d_date_sk] #7 + InputAdapter + BroadcastExchange #11 + WholeStageCodegen (9) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Project [ws_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_item_sk,i_item_sk] + Filter [ws_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #2 + InputAdapter + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #10 + InputAdapter + ReusedExchange [d_date_sk] #7 InputAdapter BroadcastExchange #12 WholeStageCodegen (23) @@ -168,7 +166,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su Project [sales,number_sales,i_brand_id,i_class_id,i_category_id] Filter [sales] ReusedSubquery [average_sales] #3 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cs_quantity as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #14 WholeStageCodegen (51) @@ -193,7 +191,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum(sales),su Project [sales,number_sales,i_brand_id,i_class_id,i_category_id] Filter [sales] ReusedSubquery [average_sales] #3 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ws_quantity as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #15 WholeStageCodegen (77) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt index 3a62afcce3e31..3f0acc0ea73be 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt @@ -1,106 +1,103 @@ == Physical Plan == -TakeOrderedAndProject 
(102) -+- * BroadcastHashJoin Inner BuildRight (101) - :- * Filter (81) - : +- * HashAggregate (80) - : +- Exchange (79) - : +- * HashAggregate (78) - : +- * Project (77) - : +- * BroadcastHashJoin Inner BuildRight (76) - : :- * Project (66) - : : +- * BroadcastHashJoin Inner BuildRight (65) - : : :- * SortMergeJoin LeftSemi (63) +TakeOrderedAndProject (99) ++- * BroadcastHashJoin Inner BuildRight (98) + :- * Filter (78) + : +- * HashAggregate (77) + : +- Exchange (76) + : +- * HashAggregate (75) + : +- * Project (74) + : +- * BroadcastHashJoin Inner BuildRight (73) + : :- * Project (63) + : : +- * BroadcastHashJoin Inner BuildRight (62) + : : :- * SortMergeJoin LeftSemi (60) : : : :- * Sort (5) : : : : +- Exchange (4) : : : : +- * Filter (3) : : : : +- * ColumnarToRow (2) : : : : +- Scan parquet default.store_sales (1) - : : : +- * Sort (62) - : : : +- Exchange (61) - : : : +- * Project (60) - : : : +- * BroadcastHashJoin Inner BuildRight (59) + : : : +- * Sort (59) + : : : +- Exchange (58) + : : : +- * Project (57) + : : : +- * BroadcastHashJoin Inner BuildRight (56) : : : :- * Filter (8) : : : : +- * ColumnarToRow (7) : : : : +- Scan parquet default.item (6) - : : : +- BroadcastExchange (58) - : : : +- * HashAggregate (57) - : : : +- Exchange (56) - : : : +- * HashAggregate (55) - : : : +- * SortMergeJoin LeftSemi (54) - : : : :- * Sort (42) - : : : : +- Exchange (41) - : : : : +- * HashAggregate (40) - : : : : +- Exchange (39) - : : : : +- * HashAggregate (38) - : : : : +- * Project (37) - : : : : +- * BroadcastHashJoin Inner BuildRight (36) - : : : : :- * Project (14) - : : : : : +- * BroadcastHashJoin Inner BuildRight (13) - : : : : : :- * Filter (11) - : : : : : : +- * ColumnarToRow (10) - : : : : : : +- Scan parquet default.store_sales (9) - : : : : : +- ReusedExchange (12) - : : : : +- BroadcastExchange (35) - : : : : +- * SortMergeJoin LeftSemi (34) - : : : : :- * Sort (19) - : : : : : +- Exchange (18) - : : : : : +- * Filter (17) - : : : : : +- * ColumnarToRow (16) - : : : : : +- Scan parquet default.item (15) - : : : : +- * Sort (33) - : : : : +- Exchange (32) - : : : : +- * Project (31) - : : : : +- * BroadcastHashJoin Inner BuildRight (30) - : : : : :- * Project (25) - : : : : : +- * BroadcastHashJoin Inner BuildRight (24) - : : : : : :- * Filter (22) - : : : : : : +- * ColumnarToRow (21) - : : : : : : +- Scan parquet default.catalog_sales (20) - : : : : : +- ReusedExchange (23) - : : : : +- BroadcastExchange (29) - : : : : +- * Filter (28) - : : : : +- * ColumnarToRow (27) - : : : : +- Scan parquet default.item (26) - : : : +- * Sort (53) - : : : +- Exchange (52) - : : : +- * Project (51) - : : : +- * BroadcastHashJoin Inner BuildRight (50) - : : : :- * Project (48) - : : : : +- * BroadcastHashJoin Inner BuildRight (47) - : : : : :- * Filter (45) - : : : : : +- * ColumnarToRow (44) - : : : : : +- Scan parquet default.web_sales (43) - : : : : +- ReusedExchange (46) - : : : +- ReusedExchange (49) - : : +- ReusedExchange (64) - : +- BroadcastExchange (75) - : +- * SortMergeJoin LeftSemi (74) - : :- * Sort (71) - : : +- Exchange (70) - : : +- * Filter (69) - : : +- * ColumnarToRow (68) - : : +- Scan parquet default.item (67) - : +- * Sort (73) - : +- ReusedExchange (72) - +- BroadcastExchange (100) - +- * Filter (99) - +- * HashAggregate (98) - +- Exchange (97) - +- * HashAggregate (96) - +- * Project (95) - +- * BroadcastHashJoin Inner BuildRight (94) - :- * Project (92) - : +- * BroadcastHashJoin Inner BuildRight (91) - : :- * SortMergeJoin LeftSemi (89) - : : :- * Sort (86) - 
: : : +- Exchange (85) - : : : +- * Filter (84) - : : : +- * ColumnarToRow (83) - : : : +- Scan parquet default.store_sales (82) - : : +- * Sort (88) - : : +- ReusedExchange (87) - : +- ReusedExchange (90) - +- ReusedExchange (93) + : : : +- BroadcastExchange (55) + : : : +- * SortMergeJoin LeftSemi (54) + : : : :- * Sort (42) + : : : : +- Exchange (41) + : : : : +- * HashAggregate (40) + : : : : +- Exchange (39) + : : : : +- * HashAggregate (38) + : : : : +- * Project (37) + : : : : +- * BroadcastHashJoin Inner BuildRight (36) + : : : : :- * Project (14) + : : : : : +- * BroadcastHashJoin Inner BuildRight (13) + : : : : : :- * Filter (11) + : : : : : : +- * ColumnarToRow (10) + : : : : : : +- Scan parquet default.store_sales (9) + : : : : : +- ReusedExchange (12) + : : : : +- BroadcastExchange (35) + : : : : +- * SortMergeJoin LeftSemi (34) + : : : : :- * Sort (19) + : : : : : +- Exchange (18) + : : : : : +- * Filter (17) + : : : : : +- * ColumnarToRow (16) + : : : : : +- Scan parquet default.item (15) + : : : : +- * Sort (33) + : : : : +- Exchange (32) + : : : : +- * Project (31) + : : : : +- * BroadcastHashJoin Inner BuildRight (30) + : : : : :- * Project (25) + : : : : : +- * BroadcastHashJoin Inner BuildRight (24) + : : : : : :- * Filter (22) + : : : : : : +- * ColumnarToRow (21) + : : : : : : +- Scan parquet default.catalog_sales (20) + : : : : : +- ReusedExchange (23) + : : : : +- BroadcastExchange (29) + : : : : +- * Filter (28) + : : : : +- * ColumnarToRow (27) + : : : : +- Scan parquet default.item (26) + : : : +- * Sort (53) + : : : +- Exchange (52) + : : : +- * Project (51) + : : : +- * BroadcastHashJoin Inner BuildRight (50) + : : : :- * Project (48) + : : : : +- * BroadcastHashJoin Inner BuildRight (47) + : : : : :- * Filter (45) + : : : : : +- * ColumnarToRow (44) + : : : : : +- Scan parquet default.web_sales (43) + : : : : +- ReusedExchange (46) + : : : +- ReusedExchange (49) + : : +- ReusedExchange (61) + : +- BroadcastExchange (72) + : +- * SortMergeJoin LeftSemi (71) + : :- * Sort (68) + : : +- Exchange (67) + : : +- * Filter (66) + : : +- * ColumnarToRow (65) + : : +- Scan parquet default.item (64) + : +- * Sort (70) + : +- ReusedExchange (69) + +- BroadcastExchange (97) + +- * Filter (96) + +- * HashAggregate (95) + +- Exchange (94) + +- * HashAggregate (93) + +- * Project (92) + +- * BroadcastHashJoin Inner BuildRight (91) + :- * Project (89) + : +- * BroadcastHashJoin Inner BuildRight (88) + : :- * SortMergeJoin LeftSemi (86) + : : :- * Sort (83) + : : : +- Exchange (82) + : : : +- * Filter (81) + : : : +- * ColumnarToRow (80) + : : : +- Scan parquet default.store_sales (79) + : : +- * Sort (85) + : : +- ReusedExchange (84) + : +- ReusedExchange (87) + +- ReusedExchange (90) (1) Scan parquet default.store_sales @@ -133,10 +130,10 @@ Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_brand_id), IsNotNull(i_class_id), IsNotNull(i_category_id)] ReadSchema: struct -(7) ColumnarToRow [codegen id : 20] +(7) ColumnarToRow [codegen id : 19] Input [4]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10] -(8) Filter [codegen id : 20] +(8) Filter [codegen id : 19] Input [4]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10] Condition : ((isnotnull(i_brand_id#8) AND isnotnull(i_class_id#9)) AND isnotnull(i_category_id#10)) @@ -155,7 +152,7 @@ Input [2]: [ss_item_sk#11, ss_sold_date_sk#12] Input [2]: [ss_item_sk#11, ss_sold_date_sk#12] Condition : isnotnull(ss_item_sk#11) -(12) ReusedExchange [Reuses operator id: 135] +(12) 
ReusedExchange [Reuses operator id: 132] Output [1]: [d_date_sk#14] (13) BroadcastHashJoin [codegen id : 11] @@ -204,7 +201,7 @@ Input [2]: [cs_item_sk#20, cs_sold_date_sk#21] Input [2]: [cs_item_sk#20, cs_sold_date_sk#21] Condition : isnotnull(cs_item_sk#20) -(23) ReusedExchange [Reuses operator id: 135] +(23) ReusedExchange [Reuses operator id: 132] Output [1]: [d_date_sk#22] (24) BroadcastHashJoin [codegen id : 8] @@ -310,7 +307,7 @@ Input [2]: [ws_item_sk#35, ws_sold_date_sk#36] Input [2]: [ws_item_sk#35, ws_sold_date_sk#36] Condition : isnotnull(ws_item_sk#35) -(46) ReusedExchange [Reuses operator id: 135] +(46) ReusedExchange [Reuses operator id: 132] Output [1]: [d_date_sk#37] (47) BroadcastHashJoin [codegen id : 16] @@ -347,485 +344,467 @@ Left keys [6]: [coalesce(brand_id#30, 0), isnull(brand_id#30), coalesce(class_id Right keys [6]: [coalesce(i_brand_id#39, 0), isnull(i_brand_id#39), coalesce(i_class_id#40, 0), isnull(i_class_id#40), coalesce(i_category_id#41, 0), isnull(i_category_id#41)] Join condition: None -(55) HashAggregate [codegen id : 18] +(55) BroadcastExchange Input [3]: [brand_id#30, class_id#31, category_id#32] -Keys [3]: [brand_id#30, class_id#31, category_id#32] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#30, class_id#31, category_id#32] - -(56) Exchange -Input [3]: [brand_id#30, class_id#31, category_id#32] -Arguments: hashpartitioning(brand_id#30, class_id#31, category_id#32, 5), ENSURE_REQUIREMENTS, [id=#43] - -(57) HashAggregate [codegen id : 19] -Input [3]: [brand_id#30, class_id#31, category_id#32] -Keys [3]: [brand_id#30, class_id#31, category_id#32] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#30, class_id#31, category_id#32] - -(58) BroadcastExchange -Input [3]: [brand_id#30, class_id#31, category_id#32] -Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#44] +Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#43] -(59) BroadcastHashJoin [codegen id : 20] +(56) BroadcastHashJoin [codegen id : 19] Left keys [3]: [i_brand_id#8, i_class_id#9, i_category_id#10] Right keys [3]: [brand_id#30, class_id#31, category_id#32] Join condition: None -(60) Project [codegen id : 20] -Output [1]: [i_item_sk#7 AS ss_item_sk#45] +(57) Project [codegen id : 19] +Output [1]: [i_item_sk#7 AS ss_item_sk#44] Input [7]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10, brand_id#30, class_id#31, category_id#32] -(61) Exchange -Input [1]: [ss_item_sk#45] -Arguments: hashpartitioning(ss_item_sk#45, 5), ENSURE_REQUIREMENTS, [id=#46] +(58) Exchange +Input [1]: [ss_item_sk#44] +Arguments: hashpartitioning(ss_item_sk#44, 5), ENSURE_REQUIREMENTS, [id=#45] -(62) Sort [codegen id : 21] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(59) Sort [codegen id : 20] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(63) SortMergeJoin [codegen id : 45] +(60) SortMergeJoin [codegen id : 43] Left keys [1]: [ss_item_sk#1] -Right keys [1]: [ss_item_sk#45] +Right keys [1]: [ss_item_sk#44] Join condition: None -(64) ReusedExchange [Reuses operator id: 126] -Output [1]: [d_date_sk#47] +(61) ReusedExchange [Reuses operator id: 123] +Output [1]: [d_date_sk#46] -(65) BroadcastHashJoin [codegen id : 45] +(62) BroadcastHashJoin [codegen id : 43] Left keys [1]: [ss_sold_date_sk#4] -Right keys [1]: [d_date_sk#47] +Right keys [1]: [d_date_sk#46] Join 
condition: None -(66) Project [codegen id : 45] +(63) Project [codegen id : 43] Output [3]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3] -Input [5]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, d_date_sk#47] +Input [5]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, d_date_sk#46] -(67) Scan parquet default.item -Output [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] +(64) Scan parquet default.item +Output [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] Batched: true Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_item_sk), IsNotNull(i_brand_id), IsNotNull(i_class_id), IsNotNull(i_category_id)] ReadSchema: struct -(68) ColumnarToRow [codegen id : 23] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] +(65) ColumnarToRow [codegen id : 22] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] -(69) Filter [codegen id : 23] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Condition : (((isnotnull(i_item_sk#48) AND isnotnull(i_brand_id#49)) AND isnotnull(i_class_id#50)) AND isnotnull(i_category_id#51)) +(66) Filter [codegen id : 22] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Condition : (((isnotnull(i_item_sk#47) AND isnotnull(i_brand_id#48)) AND isnotnull(i_class_id#49)) AND isnotnull(i_category_id#50)) -(70) Exchange -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: hashpartitioning(i_item_sk#48, 5), ENSURE_REQUIREMENTS, [id=#52] +(67) Exchange +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: hashpartitioning(i_item_sk#47, 5), ENSURE_REQUIREMENTS, [id=#51] -(71) Sort [codegen id : 24] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: [i_item_sk#48 ASC NULLS FIRST], false, 0 +(68) Sort [codegen id : 23] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: [i_item_sk#47 ASC NULLS FIRST], false, 0 -(72) ReusedExchange [Reuses operator id: 61] -Output [1]: [ss_item_sk#45] +(69) ReusedExchange [Reuses operator id: 58] +Output [1]: [ss_item_sk#44] -(73) Sort [codegen id : 43] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(70) Sort [codegen id : 41] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(74) SortMergeJoin [codegen id : 44] -Left keys [1]: [i_item_sk#48] -Right keys [1]: [ss_item_sk#45] +(71) SortMergeJoin [codegen id : 42] +Left keys [1]: [i_item_sk#47] +Right keys [1]: [ss_item_sk#44] Join condition: None -(75) BroadcastExchange -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#53] +(72) BroadcastExchange +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#52] -(76) BroadcastHashJoin [codegen id : 45] +(73) BroadcastHashJoin [codegen id : 43] Left keys [1]: [ss_item_sk#1] -Right keys [1]: [i_item_sk#48] +Right keys [1]: [i_item_sk#47] Join condition: None -(77) Project [codegen id : 45] -Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#49, i_class_id#50, i_category_id#51] -Input [7]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] - -(78) 
HashAggregate [codegen id : 45] -Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#49, i_class_id#50, i_category_id#51] -Keys [3]: [i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] -Aggregate Attributes [3]: [sum#54, isEmpty#55, count#56] -Results [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] - -(79) Exchange -Input [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] -Arguments: hashpartitioning(i_brand_id#49, i_class_id#50, i_category_id#51, 5), ENSURE_REQUIREMENTS, [id=#60] - -(80) HashAggregate [codegen id : 92] -Input [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] -Keys [3]: [i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#61, count(1)#62] -Results [6]: [store AS channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#61 AS sales#64, count(1)#62 AS number_sales#65] - -(81) Filter [codegen id : 92] -Input [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sales#64, number_sales#65] -Condition : (isnotnull(sales#64) AND (cast(sales#64 as decimal(32,6)) > cast(Subquery scalar-subquery#66, [id=#67] as decimal(32,6)))) - -(82) Scan parquet default.store_sales -Output [4]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, ss_sold_date_sk#71] +(74) Project [codegen id : 43] +Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#48, i_class_id#49, i_category_id#50] +Input [7]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] + +(75) HashAggregate [codegen id : 43] +Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#48, i_class_id#49, i_category_id#50] +Keys [3]: [i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] +Aggregate Attributes [3]: [sum#53, isEmpty#54, count#55] +Results [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] + +(76) Exchange +Input [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] +Arguments: hashpartitioning(i_brand_id#48, i_class_id#49, i_category_id#50, 5), ENSURE_REQUIREMENTS, [id=#59] + +(77) HashAggregate [codegen id : 88] +Input [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] +Keys [3]: [i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: 
[sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#60, count(1)#61] +Results [6]: [store AS channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#60 AS sales#63, count(1)#61 AS number_sales#64] + +(78) Filter [codegen id : 88] +Input [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sales#63, number_sales#64] +Condition : (isnotnull(sales#63) AND (cast(sales#63 as decimal(32,6)) > cast(Subquery scalar-subquery#65, [id=#66] as decimal(32,6)))) + +(79) Scan parquet default.store_sales +Output [4]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#71), dynamicpruningexpression(ss_sold_date_sk#71 IN dynamicpruning#72)] +PartitionFilters: [isnotnull(ss_sold_date_sk#70), dynamicpruningexpression(ss_sold_date_sk#70 IN dynamicpruning#71)] PushedFilters: [IsNotNull(ss_item_sk)] ReadSchema: struct -(83) ColumnarToRow [codegen id : 46] -Input [4]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, ss_sold_date_sk#71] +(80) ColumnarToRow [codegen id : 44] +Input [4]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70] -(84) Filter [codegen id : 46] -Input [4]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, ss_sold_date_sk#71] -Condition : isnotnull(ss_item_sk#68) +(81) Filter [codegen id : 44] +Input [4]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70] +Condition : isnotnull(ss_item_sk#67) -(85) Exchange -Input [4]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, ss_sold_date_sk#71] -Arguments: hashpartitioning(ss_item_sk#68, 5), ENSURE_REQUIREMENTS, [id=#73] +(82) Exchange +Input [4]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70] +Arguments: hashpartitioning(ss_item_sk#67, 5), ENSURE_REQUIREMENTS, [id=#72] -(86) Sort [codegen id : 47] -Input [4]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, ss_sold_date_sk#71] -Arguments: [ss_item_sk#68 ASC NULLS FIRST], false, 0 +(83) Sort [codegen id : 45] +Input [4]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70] +Arguments: [ss_item_sk#67 ASC NULLS FIRST], false, 0 -(87) ReusedExchange [Reuses operator id: 61] -Output [1]: [ss_item_sk#45] +(84) ReusedExchange [Reuses operator id: 58] +Output [1]: [ss_item_sk#44] -(88) Sort [codegen id : 66] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(85) Sort [codegen id : 63] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(89) SortMergeJoin [codegen id : 90] -Left keys [1]: [ss_item_sk#68] -Right keys [1]: [ss_item_sk#45] +(86) SortMergeJoin [codegen id : 86] +Left keys [1]: [ss_item_sk#67] +Right keys [1]: [ss_item_sk#44] Join condition: None -(90) ReusedExchange [Reuses operator id: 140] -Output [1]: [d_date_sk#74] +(87) ReusedExchange [Reuses operator id: 137] +Output [1]: [d_date_sk#73] -(91) BroadcastHashJoin [codegen id : 90] -Left keys [1]: [ss_sold_date_sk#71] -Right keys [1]: [d_date_sk#74] +(88) BroadcastHashJoin [codegen id : 86] +Left keys [1]: [ss_sold_date_sk#70] +Right keys [1]: [d_date_sk#73] Join condition: None -(92) Project [codegen id : 90] -Output [3]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70] -Input [5]: [ss_item_sk#68, ss_quantity#69, 
ss_list_price#70, ss_sold_date_sk#71, d_date_sk#74] +(89) Project [codegen id : 86] +Output [3]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69] +Input [5]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70, d_date_sk#73] -(93) ReusedExchange [Reuses operator id: 75] -Output [4]: [i_item_sk#75, i_brand_id#76, i_class_id#77, i_category_id#78] +(90) ReusedExchange [Reuses operator id: 72] +Output [4]: [i_item_sk#74, i_brand_id#75, i_class_id#76, i_category_id#77] -(94) BroadcastHashJoin [codegen id : 90] -Left keys [1]: [ss_item_sk#68] -Right keys [1]: [i_item_sk#75] +(91) BroadcastHashJoin [codegen id : 86] +Left keys [1]: [ss_item_sk#67] +Right keys [1]: [i_item_sk#74] Join condition: None -(95) Project [codegen id : 90] -Output [5]: [ss_quantity#69, ss_list_price#70, i_brand_id#76, i_class_id#77, i_category_id#78] -Input [7]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, i_item_sk#75, i_brand_id#76, i_class_id#77, i_category_id#78] - -(96) HashAggregate [codegen id : 90] -Input [5]: [ss_quantity#69, ss_list_price#70, i_brand_id#76, i_class_id#77, i_category_id#78] -Keys [3]: [i_brand_id#76, i_class_id#77, i_category_id#78] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#70 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] -Aggregate Attributes [3]: [sum#79, isEmpty#80, count#81] -Results [6]: [i_brand_id#76, i_class_id#77, i_category_id#78, sum#82, isEmpty#83, count#84] - -(97) Exchange -Input [6]: [i_brand_id#76, i_class_id#77, i_category_id#78, sum#82, isEmpty#83, count#84] -Arguments: hashpartitioning(i_brand_id#76, i_class_id#77, i_category_id#78, 5), ENSURE_REQUIREMENTS, [id=#85] - -(98) HashAggregate [codegen id : 91] -Input [6]: [i_brand_id#76, i_class_id#77, i_category_id#78, sum#82, isEmpty#83, count#84] -Keys [3]: [i_brand_id#76, i_class_id#77, i_category_id#78] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#70 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#70 as decimal(12,2)))), DecimalType(18,2), true))#86, count(1)#87] -Results [6]: [store AS channel#88, i_brand_id#76, i_class_id#77, i_category_id#78, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#70 as decimal(12,2)))), DecimalType(18,2), true))#86 AS sales#89, count(1)#87 AS number_sales#90] - -(99) Filter [codegen id : 91] -Input [6]: [channel#88, i_brand_id#76, i_class_id#77, i_category_id#78, sales#89, number_sales#90] -Condition : (isnotnull(sales#89) AND (cast(sales#89 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#66, [id=#67] as decimal(32,6)))) - -(100) BroadcastExchange -Input [6]: [channel#88, i_brand_id#76, i_class_id#77, i_category_id#78, sales#89, number_sales#90] -Arguments: HashedRelationBroadcastMode(List(input[1, int, true], input[2, int, true], input[3, int, true]),false), [id=#91] - -(101) BroadcastHashJoin [codegen id : 92] -Left keys [3]: [i_brand_id#49, i_class_id#50, i_category_id#51] -Right keys [3]: [i_brand_id#76, i_class_id#77, i_category_id#78] +(92) Project [codegen id : 86] +Output [5]: [ss_quantity#68, ss_list_price#69, i_brand_id#75, i_class_id#76, 
i_category_id#77] +Input [7]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, i_item_sk#74, i_brand_id#75, i_class_id#76, i_category_id#77] + +(93) HashAggregate [codegen id : 86] +Input [5]: [ss_quantity#68, ss_list_price#69, i_brand_id#75, i_class_id#76, i_category_id#77] +Keys [3]: [i_brand_id#75, i_class_id#76, i_category_id#77] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_list_price#69 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] +Aggregate Attributes [3]: [sum#78, isEmpty#79, count#80] +Results [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] + +(94) Exchange +Input [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] +Arguments: hashpartitioning(i_brand_id#75, i_class_id#76, i_category_id#77, 5), ENSURE_REQUIREMENTS, [id=#84] + +(95) HashAggregate [codegen id : 87] +Input [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] +Keys [3]: [i_brand_id#75, i_class_id#76, i_category_id#77] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_list_price#69 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_list_price#69 as decimal(12,2)))), DecimalType(18,2)))#85, count(1)#86] +Results [6]: [store AS channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_list_price#69 as decimal(12,2)))), DecimalType(18,2)))#85 AS sales#88, count(1)#86 AS number_sales#89] + +(96) Filter [codegen id : 87] +Input [6]: [channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sales#88, number_sales#89] +Condition : (isnotnull(sales#88) AND (cast(sales#88 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#65, [id=#66] as decimal(32,6)))) + +(97) BroadcastExchange +Input [6]: [channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sales#88, number_sales#89] +Arguments: HashedRelationBroadcastMode(List(input[1, int, true], input[2, int, true], input[3, int, true]),false), [id=#90] + +(98) BroadcastHashJoin [codegen id : 88] +Left keys [3]: [i_brand_id#48, i_class_id#49, i_category_id#50] +Right keys [3]: [i_brand_id#75, i_class_id#76, i_category_id#77] Join condition: None -(102) TakeOrderedAndProject -Input [12]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sales#64, number_sales#65, channel#88, i_brand_id#76, i_class_id#77, i_category_id#78, sales#89, number_sales#90] -Arguments: 100, [i_brand_id#49 ASC NULLS FIRST, i_class_id#50 ASC NULLS FIRST, i_category_id#51 ASC NULLS FIRST], [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sales#64, number_sales#65, channel#88, i_brand_id#76, i_class_id#77, i_category_id#78, sales#89, number_sales#90] +(99) TakeOrderedAndProject +Input [12]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sales#63, number_sales#64, channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sales#88, number_sales#89] +Arguments: 100, [i_brand_id#48 ASC NULLS FIRST, i_class_id#49 ASC NULLS FIRST, i_category_id#50 ASC NULLS FIRST], [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sales#63, number_sales#64, channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sales#88, number_sales#89] ===== Subqueries ===== 
-Subquery:1 Hosting operator id = 81 Hosting Expression = Subquery scalar-subquery#66, [id=#67] -* HashAggregate (121) -+- Exchange (120) - +- * HashAggregate (119) - +- Union (118) - :- * Project (107) - : +- * BroadcastHashJoin Inner BuildRight (106) - : :- * ColumnarToRow (104) - : : +- Scan parquet default.store_sales (103) - : +- ReusedExchange (105) - :- * Project (112) - : +- * BroadcastHashJoin Inner BuildRight (111) - : :- * ColumnarToRow (109) - : : +- Scan parquet default.catalog_sales (108) - : +- ReusedExchange (110) - +- * Project (117) - +- * BroadcastHashJoin Inner BuildRight (116) - :- * ColumnarToRow (114) - : +- Scan parquet default.web_sales (113) - +- ReusedExchange (115) - - -(103) Scan parquet default.store_sales -Output [3]: [ss_quantity#92, ss_list_price#93, ss_sold_date_sk#94] +Subquery:1 Hosting operator id = 78 Hosting Expression = Subquery scalar-subquery#65, [id=#66] +* HashAggregate (118) ++- Exchange (117) + +- * HashAggregate (116) + +- Union (115) + :- * Project (104) + : +- * BroadcastHashJoin Inner BuildRight (103) + : :- * ColumnarToRow (101) + : : +- Scan parquet default.store_sales (100) + : +- ReusedExchange (102) + :- * Project (109) + : +- * BroadcastHashJoin Inner BuildRight (108) + : :- * ColumnarToRow (106) + : : +- Scan parquet default.catalog_sales (105) + : +- ReusedExchange (107) + +- * Project (114) + +- * BroadcastHashJoin Inner BuildRight (113) + :- * ColumnarToRow (111) + : +- Scan parquet default.web_sales (110) + +- ReusedExchange (112) + + +(100) Scan parquet default.store_sales +Output [3]: [ss_quantity#91, ss_list_price#92, ss_sold_date_sk#93] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#94), dynamicpruningexpression(ss_sold_date_sk#94 IN dynamicpruning#13)] +PartitionFilters: [isnotnull(ss_sold_date_sk#93), dynamicpruningexpression(ss_sold_date_sk#93 IN dynamicpruning#13)] ReadSchema: struct -(104) ColumnarToRow [codegen id : 2] -Input [3]: [ss_quantity#92, ss_list_price#93, ss_sold_date_sk#94] +(101) ColumnarToRow [codegen id : 2] +Input [3]: [ss_quantity#91, ss_list_price#92, ss_sold_date_sk#93] -(105) ReusedExchange [Reuses operator id: 135] -Output [1]: [d_date_sk#95] +(102) ReusedExchange [Reuses operator id: 132] +Output [1]: [d_date_sk#94] -(106) BroadcastHashJoin [codegen id : 2] -Left keys [1]: [ss_sold_date_sk#94] -Right keys [1]: [d_date_sk#95] +(103) BroadcastHashJoin [codegen id : 2] +Left keys [1]: [ss_sold_date_sk#93] +Right keys [1]: [d_date_sk#94] Join condition: None -(107) Project [codegen id : 2] -Output [2]: [ss_quantity#92 AS quantity#96, ss_list_price#93 AS list_price#97] -Input [4]: [ss_quantity#92, ss_list_price#93, ss_sold_date_sk#94, d_date_sk#95] +(104) Project [codegen id : 2] +Output [2]: [ss_quantity#91 AS quantity#95, ss_list_price#92 AS list_price#96] +Input [4]: [ss_quantity#91, ss_list_price#92, ss_sold_date_sk#93, d_date_sk#94] -(108) Scan parquet default.catalog_sales -Output [3]: [cs_quantity#98, cs_list_price#99, cs_sold_date_sk#100] +(105) Scan parquet default.catalog_sales +Output [3]: [cs_quantity#97, cs_list_price#98, cs_sold_date_sk#99] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(cs_sold_date_sk#100), dynamicpruningexpression(cs_sold_date_sk#100 IN dynamicpruning#13)] +PartitionFilters: [isnotnull(cs_sold_date_sk#99), dynamicpruningexpression(cs_sold_date_sk#99 IN dynamicpruning#13)] ReadSchema: struct -(109) ColumnarToRow [codegen id : 4] -Input [3]: [cs_quantity#98, cs_list_price#99, cs_sold_date_sk#100] +(106) 
ColumnarToRow [codegen id : 4] +Input [3]: [cs_quantity#97, cs_list_price#98, cs_sold_date_sk#99] -(110) ReusedExchange [Reuses operator id: 135] -Output [1]: [d_date_sk#101] +(107) ReusedExchange [Reuses operator id: 132] +Output [1]: [d_date_sk#100] -(111) BroadcastHashJoin [codegen id : 4] -Left keys [1]: [cs_sold_date_sk#100] -Right keys [1]: [d_date_sk#101] +(108) BroadcastHashJoin [codegen id : 4] +Left keys [1]: [cs_sold_date_sk#99] +Right keys [1]: [d_date_sk#100] Join condition: None -(112) Project [codegen id : 4] -Output [2]: [cs_quantity#98 AS quantity#102, cs_list_price#99 AS list_price#103] -Input [4]: [cs_quantity#98, cs_list_price#99, cs_sold_date_sk#100, d_date_sk#101] +(109) Project [codegen id : 4] +Output [2]: [cs_quantity#97 AS quantity#101, cs_list_price#98 AS list_price#102] +Input [4]: [cs_quantity#97, cs_list_price#98, cs_sold_date_sk#99, d_date_sk#100] -(113) Scan parquet default.web_sales -Output [3]: [ws_quantity#104, ws_list_price#105, ws_sold_date_sk#106] +(110) Scan parquet default.web_sales +Output [3]: [ws_quantity#103, ws_list_price#104, ws_sold_date_sk#105] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ws_sold_date_sk#106), dynamicpruningexpression(ws_sold_date_sk#106 IN dynamicpruning#13)] +PartitionFilters: [isnotnull(ws_sold_date_sk#105), dynamicpruningexpression(ws_sold_date_sk#105 IN dynamicpruning#13)] ReadSchema: struct -(114) ColumnarToRow [codegen id : 6] -Input [3]: [ws_quantity#104, ws_list_price#105, ws_sold_date_sk#106] +(111) ColumnarToRow [codegen id : 6] +Input [3]: [ws_quantity#103, ws_list_price#104, ws_sold_date_sk#105] -(115) ReusedExchange [Reuses operator id: 135] -Output [1]: [d_date_sk#107] +(112) ReusedExchange [Reuses operator id: 132] +Output [1]: [d_date_sk#106] -(116) BroadcastHashJoin [codegen id : 6] -Left keys [1]: [ws_sold_date_sk#106] -Right keys [1]: [d_date_sk#107] +(113) BroadcastHashJoin [codegen id : 6] +Left keys [1]: [ws_sold_date_sk#105] +Right keys [1]: [d_date_sk#106] Join condition: None -(117) Project [codegen id : 6] -Output [2]: [ws_quantity#104 AS quantity#108, ws_list_price#105 AS list_price#109] -Input [4]: [ws_quantity#104, ws_list_price#105, ws_sold_date_sk#106, d_date_sk#107] +(114) Project [codegen id : 6] +Output [2]: [ws_quantity#103 AS quantity#107, ws_list_price#104 AS list_price#108] +Input [4]: [ws_quantity#103, ws_list_price#104, ws_sold_date_sk#105, d_date_sk#106] -(118) Union +(115) Union -(119) HashAggregate [codegen id : 7] -Input [2]: [quantity#96, list_price#97] +(116) HashAggregate [codegen id : 7] +Input [2]: [quantity#95, list_price#96] Keys: [] -Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(cast(quantity#96 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#97 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#110, count#111] -Results [2]: [sum#112, count#113] +Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(quantity#95 as decimal(12,2))) * promote_precision(cast(list_price#96 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#109, count#110] +Results [2]: [sum#111, count#112] -(120) Exchange -Input [2]: [sum#112, count#113] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#114] +(117) Exchange +Input [2]: [sum#111, count#112] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#113] -(121) HashAggregate [codegen id : 8] -Input [2]: [sum#112, count#113] +(118) HashAggregate [codegen id : 8] +Input [2]: [sum#111, count#112] Keys: [] 
-Functions [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#96 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#97 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#96 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#97 as decimal(12,2)))), DecimalType(18,2), true))#115] -Results [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#96 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#97 as decimal(12,2)))), DecimalType(18,2), true))#115 AS average_sales#116] +Functions [1]: [avg(CheckOverflow((promote_precision(cast(quantity#95 as decimal(12,2))) * promote_precision(cast(list_price#96 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(quantity#95 as decimal(12,2))) * promote_precision(cast(list_price#96 as decimal(12,2)))), DecimalType(18,2)))#114] +Results [1]: [avg(CheckOverflow((promote_precision(cast(quantity#95 as decimal(12,2))) * promote_precision(cast(list_price#96 as decimal(12,2)))), DecimalType(18,2)))#114 AS average_sales#115] -Subquery:2 Hosting operator id = 103 Hosting Expression = ss_sold_date_sk#94 IN dynamicpruning#13 +Subquery:2 Hosting operator id = 100 Hosting Expression = ss_sold_date_sk#93 IN dynamicpruning#13 -Subquery:3 Hosting operator id = 108 Hosting Expression = cs_sold_date_sk#100 IN dynamicpruning#13 +Subquery:3 Hosting operator id = 105 Hosting Expression = cs_sold_date_sk#99 IN dynamicpruning#13 -Subquery:4 Hosting operator id = 113 Hosting Expression = ws_sold_date_sk#106 IN dynamicpruning#13 +Subquery:4 Hosting operator id = 110 Hosting Expression = ws_sold_date_sk#105 IN dynamicpruning#13 Subquery:5 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (126) -+- * Project (125) - +- * Filter (124) - +- * ColumnarToRow (123) - +- Scan parquet default.date_dim (122) +BroadcastExchange (123) ++- * Project (122) + +- * Filter (121) + +- * ColumnarToRow (120) + +- Scan parquet default.date_dim (119) -(122) Scan parquet default.date_dim -Output [2]: [d_date_sk#47, d_week_seq#117] +(119) Scan parquet default.date_dim +Output [2]: [d_date_sk#46, d_week_seq#116] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] ReadSchema: struct -(123) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#47, d_week_seq#117] +(120) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#46, d_week_seq#116] -(124) Filter [codegen id : 1] -Input [2]: [d_date_sk#47, d_week_seq#117] -Condition : ((isnotnull(d_week_seq#117) AND (d_week_seq#117 = Subquery scalar-subquery#118, [id=#119])) AND isnotnull(d_date_sk#47)) +(121) Filter [codegen id : 1] +Input [2]: [d_date_sk#46, d_week_seq#116] +Condition : ((isnotnull(d_week_seq#116) AND (d_week_seq#116 = Subquery scalar-subquery#117, [id=#118])) AND isnotnull(d_date_sk#46)) -(125) Project [codegen id : 1] -Output [1]: [d_date_sk#47] -Input [2]: [d_date_sk#47, d_week_seq#117] +(122) Project [codegen id : 1] +Output [1]: [d_date_sk#46] +Input [2]: [d_date_sk#46, d_week_seq#116] -(126) BroadcastExchange -Input [1]: [d_date_sk#47] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#120] +(123) BroadcastExchange +Input [1]: [d_date_sk#46] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#119] 
-Subquery:6 Hosting operator id = 124 Hosting Expression = Subquery scalar-subquery#118, [id=#119] -* Project (130) -+- * Filter (129) - +- * ColumnarToRow (128) - +- Scan parquet default.date_dim (127) +Subquery:6 Hosting operator id = 121 Hosting Expression = Subquery scalar-subquery#117, [id=#118] +* Project (127) ++- * Filter (126) + +- * ColumnarToRow (125) + +- Scan parquet default.date_dim (124) -(127) Scan parquet default.date_dim -Output [4]: [d_week_seq#121, d_year#122, d_moy#123, d_dom#124] +(124) Scan parquet default.date_dim +Output [4]: [d_week_seq#120, d_year#121, d_moy#122, d_dom#123] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), IsNotNull(d_dom), EqualTo(d_year,2000), EqualTo(d_moy,12), EqualTo(d_dom,11)] ReadSchema: struct -(128) ColumnarToRow [codegen id : 1] -Input [4]: [d_week_seq#121, d_year#122, d_moy#123, d_dom#124] +(125) ColumnarToRow [codegen id : 1] +Input [4]: [d_week_seq#120, d_year#121, d_moy#122, d_dom#123] -(129) Filter [codegen id : 1] -Input [4]: [d_week_seq#121, d_year#122, d_moy#123, d_dom#124] -Condition : (((((isnotnull(d_year#122) AND isnotnull(d_moy#123)) AND isnotnull(d_dom#124)) AND (d_year#122 = 2000)) AND (d_moy#123 = 12)) AND (d_dom#124 = 11)) +(126) Filter [codegen id : 1] +Input [4]: [d_week_seq#120, d_year#121, d_moy#122, d_dom#123] +Condition : (((((isnotnull(d_year#121) AND isnotnull(d_moy#122)) AND isnotnull(d_dom#123)) AND (d_year#121 = 2000)) AND (d_moy#122 = 12)) AND (d_dom#123 = 11)) -(130) Project [codegen id : 1] -Output [1]: [d_week_seq#121] -Input [4]: [d_week_seq#121, d_year#122, d_moy#123, d_dom#124] +(127) Project [codegen id : 1] +Output [1]: [d_week_seq#120] +Input [4]: [d_week_seq#120, d_year#121, d_moy#122, d_dom#123] Subquery:7 Hosting operator id = 9 Hosting Expression = ss_sold_date_sk#12 IN dynamicpruning#13 -BroadcastExchange (135) -+- * Project (134) - +- * Filter (133) - +- * ColumnarToRow (132) - +- Scan parquet default.date_dim (131) +BroadcastExchange (132) ++- * Project (131) + +- * Filter (130) + +- * ColumnarToRow (129) + +- Scan parquet default.date_dim (128) -(131) Scan parquet default.date_dim -Output [2]: [d_date_sk#14, d_year#125] +(128) Scan parquet default.date_dim +Output [2]: [d_date_sk#14, d_year#124] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), GreaterThanOrEqual(d_year,1999), LessThanOrEqual(d_year,2001), IsNotNull(d_date_sk)] ReadSchema: struct -(132) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#14, d_year#125] +(129) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#14, d_year#124] -(133) Filter [codegen id : 1] -Input [2]: [d_date_sk#14, d_year#125] -Condition : (((isnotnull(d_year#125) AND (d_year#125 >= 1999)) AND (d_year#125 <= 2001)) AND isnotnull(d_date_sk#14)) +(130) Filter [codegen id : 1] +Input [2]: [d_date_sk#14, d_year#124] +Condition : (((isnotnull(d_year#124) AND (d_year#124 >= 1999)) AND (d_year#124 <= 2001)) AND isnotnull(d_date_sk#14)) -(134) Project [codegen id : 1] +(131) Project [codegen id : 1] Output [1]: [d_date_sk#14] -Input [2]: [d_date_sk#14, d_year#125] +Input [2]: [d_date_sk#14, d_year#124] -(135) BroadcastExchange +(132) BroadcastExchange Input [1]: [d_date_sk#14] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#126] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#125] Subquery:8 Hosting operator id = 20 Hosting 
Expression = cs_sold_date_sk#21 IN dynamicpruning#13 Subquery:9 Hosting operator id = 43 Hosting Expression = ws_sold_date_sk#36 IN dynamicpruning#13 -Subquery:10 Hosting operator id = 99 Hosting Expression = ReusedSubquery Subquery scalar-subquery#66, [id=#67] +Subquery:10 Hosting operator id = 96 Hosting Expression = ReusedSubquery Subquery scalar-subquery#65, [id=#66] -Subquery:11 Hosting operator id = 82 Hosting Expression = ss_sold_date_sk#71 IN dynamicpruning#72 -BroadcastExchange (140) -+- * Project (139) - +- * Filter (138) - +- * ColumnarToRow (137) - +- Scan parquet default.date_dim (136) +Subquery:11 Hosting operator id = 79 Hosting Expression = ss_sold_date_sk#70 IN dynamicpruning#71 +BroadcastExchange (137) ++- * Project (136) + +- * Filter (135) + +- * ColumnarToRow (134) + +- Scan parquet default.date_dim (133) -(136) Scan parquet default.date_dim -Output [2]: [d_date_sk#74, d_week_seq#127] +(133) Scan parquet default.date_dim +Output [2]: [d_date_sk#73, d_week_seq#126] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] ReadSchema: struct -(137) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#74, d_week_seq#127] +(134) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#73, d_week_seq#126] -(138) Filter [codegen id : 1] -Input [2]: [d_date_sk#74, d_week_seq#127] -Condition : ((isnotnull(d_week_seq#127) AND (d_week_seq#127 = Subquery scalar-subquery#128, [id=#129])) AND isnotnull(d_date_sk#74)) +(135) Filter [codegen id : 1] +Input [2]: [d_date_sk#73, d_week_seq#126] +Condition : ((isnotnull(d_week_seq#126) AND (d_week_seq#126 = Subquery scalar-subquery#127, [id=#128])) AND isnotnull(d_date_sk#73)) -(139) Project [codegen id : 1] -Output [1]: [d_date_sk#74] -Input [2]: [d_date_sk#74, d_week_seq#127] +(136) Project [codegen id : 1] +Output [1]: [d_date_sk#73] +Input [2]: [d_date_sk#73, d_week_seq#126] -(140) BroadcastExchange -Input [1]: [d_date_sk#74] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#130] +(137) BroadcastExchange +Input [1]: [d_date_sk#73] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#129] -Subquery:12 Hosting operator id = 138 Hosting Expression = Subquery scalar-subquery#128, [id=#129] -* Project (144) -+- * Filter (143) - +- * ColumnarToRow (142) - +- Scan parquet default.date_dim (141) +Subquery:12 Hosting operator id = 135 Hosting Expression = Subquery scalar-subquery#127, [id=#128] +* Project (141) ++- * Filter (140) + +- * ColumnarToRow (139) + +- Scan parquet default.date_dim (138) -(141) Scan parquet default.date_dim -Output [4]: [d_week_seq#131, d_year#132, d_moy#133, d_dom#134] +(138) Scan parquet default.date_dim +Output [4]: [d_week_seq#130, d_year#131, d_moy#132, d_dom#133] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), IsNotNull(d_dom), EqualTo(d_year,1999), EqualTo(d_moy,12), EqualTo(d_dom,11)] ReadSchema: struct -(142) ColumnarToRow [codegen id : 1] -Input [4]: [d_week_seq#131, d_year#132, d_moy#133, d_dom#134] +(139) ColumnarToRow [codegen id : 1] +Input [4]: [d_week_seq#130, d_year#131, d_moy#132, d_dom#133] -(143) Filter [codegen id : 1] -Input [4]: [d_week_seq#131, d_year#132, d_moy#133, d_dom#134] -Condition : (((((isnotnull(d_year#132) AND isnotnull(d_moy#133)) AND isnotnull(d_dom#134)) AND (d_year#132 = 1999)) AND (d_moy#133 = 12)) AND (d_dom#134 = 11)) +(140) Filter 
[codegen id : 1] +Input [4]: [d_week_seq#130, d_year#131, d_moy#132, d_dom#133] +Condition : (((((isnotnull(d_year#131) AND isnotnull(d_moy#132)) AND isnotnull(d_dom#133)) AND (d_year#131 = 1999)) AND (d_moy#132 = 12)) AND (d_dom#133 = 11)) -(144) Project [codegen id : 1] -Output [1]: [d_week_seq#131] -Input [4]: [d_week_seq#131, d_year#132, d_moy#133, d_dom#134] +(141) Project [codegen id : 1] +Output [1]: [d_week_seq#130] +Input [4]: [d_week_seq#130, d_year#131, d_moy#132, d_dom#133] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/simplified.txt index 695a7c13381d8..82e338515f431 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/simplified.txt @@ -1,12 +1,12 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_sales,channel,i_brand_id,i_class_id,i_category_id,sales,number_sales] - WholeStageCodegen (92) + WholeStageCodegen (88) BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] Filter [sales] Subquery #4 WholeStageCodegen (8) - HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(cast(quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2), true)),average_sales,sum,count] + HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(quantity as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2))),average_sales,sum,count] InputAdapter - Exchange #17 + Exchange #16 WholeStageCodegen (7) HashAggregate [quantity,list_price] [sum,count,sum,count] InputAdapter @@ -19,7 +19,7 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ Scan parquet default.store_sales [ss_quantity,ss_list_price,ss_sold_date_sk] ReusedSubquery [d_date_sk] #3 InputAdapter - ReusedExchange [d_date_sk] #9 + ReusedExchange [d_date_sk] #8 WholeStageCodegen (4) Project [cs_quantity,cs_list_price] BroadcastHashJoin [cs_sold_date_sk,d_date_sk] @@ -28,7 +28,7 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ Scan parquet default.catalog_sales [cs_quantity,cs_list_price,cs_sold_date_sk] ReusedSubquery [d_date_sk] #3 InputAdapter - ReusedExchange [d_date_sk] #9 + ReusedExchange [d_date_sk] #8 WholeStageCodegen (6) Project [ws_quantity,ws_list_price] BroadcastHashJoin [ws_sold_date_sk,d_date_sk] @@ -37,11 +37,11 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ Scan parquet default.web_sales [ws_quantity,ws_list_price,ws_sold_date_sk] ReusedSubquery [d_date_sk] #3 InputAdapter - ReusedExchange [d_date_sk] #9 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + ReusedExchange [d_date_sk] #8 + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter Exchange 
[i_brand_id,i_class_id,i_category_id] #1 - WholeStageCodegen (45) + WholeStageCodegen (43) HashAggregate [i_brand_id,i_class_id,i_category_id,ss_quantity,ss_list_price] [sum,isEmpty,count,sum,isEmpty,count] Project [ss_quantity,ss_list_price,i_brand_id,i_class_id,i_category_id] BroadcastHashJoin [ss_item_sk,i_item_sk] @@ -74,11 +74,11 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ InputAdapter Scan parquet default.date_dim [d_date_sk,d_week_seq] InputAdapter - WholeStageCodegen (21) + WholeStageCodegen (20) Sort [ss_item_sk] InputAdapter Exchange [ss_item_sk] #4 - WholeStageCodegen (20) + WholeStageCodegen (19) Project [i_item_sk] BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,brand_id,class_id,category_id] Filter [i_brand_id,i_class_id,i_category_id] @@ -87,129 +87,124 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] InputAdapter BroadcastExchange #5 - WholeStageCodegen (19) - HashAggregate [brand_id,class_id,category_id] + WholeStageCodegen (18) + SortMergeJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] InputAdapter - Exchange [brand_id,class_id,category_id] #6 - WholeStageCodegen (18) - HashAggregate [brand_id,class_id,category_id] - SortMergeJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (13) - Sort [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #7 - WholeStageCodegen (12) - HashAggregate [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #8 - WholeStageCodegen (11) - HashAggregate [brand_id,class_id,category_id] - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_item_sk,i_item_sk] - Project [ss_item_sk] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #3 - BroadcastExchange #9 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] + WholeStageCodegen (13) + Sort [brand_id,class_id,category_id] + InputAdapter + Exchange [brand_id,class_id,category_id] #6 + WholeStageCodegen (12) + HashAggregate [brand_id,class_id,category_id] + InputAdapter + Exchange [brand_id,class_id,category_id] #7 + WholeStageCodegen (11) + HashAggregate [brand_id,class_id,category_id] + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ss_item_sk,i_item_sk] + Project [ss_item_sk] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #3 + BroadcastExchange #8 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] + InputAdapter + ReusedExchange [d_date_sk] #8 + InputAdapter + BroadcastExchange #9 + WholeStageCodegen (10) + SortMergeJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (5) + Sort [i_brand_id,i_class_id,i_category_id] InputAdapter - ReusedExchange [d_date_sk] #9 - InputAdapter - BroadcastExchange #10 - WholeStageCodegen (10) - SortMergeJoin 
[i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (5) - Sort [i_brand_id,i_class_id,i_category_id] + Exchange [i_brand_id,i_class_id,i_category_id] #10 + WholeStageCodegen (4) + Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + ColumnarToRow InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #11 - WholeStageCodegen (4) - Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (9) + Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + Exchange [i_brand_id,i_class_id,i_category_id] #11 + WholeStageCodegen (8) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_item_sk,i_item_sk] + Project [cs_item_sk] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Filter [cs_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] + ReusedSubquery [d_date_sk] #3 + InputAdapter + ReusedExchange [d_date_sk] #8 + InputAdapter + BroadcastExchange #12 + WholeStageCodegen (7) + Filter [i_item_sk] ColumnarToRow InputAdapter Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (9) - Sort [i_brand_id,i_class_id,i_category_id] - InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #12 - WholeStageCodegen (8) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_item_sk,i_item_sk] - Project [cs_item_sk] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Filter [cs_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #3 - InputAdapter - ReusedExchange [d_date_sk] #9 - InputAdapter - BroadcastExchange #13 - WholeStageCodegen (7) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (17) - Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (17) + Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + Exchange [i_brand_id,i_class_id,i_category_id] #13 + WholeStageCodegen (16) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_item_sk,i_item_sk] + Project [ws_item_sk] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Filter [ws_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #3 + InputAdapter + ReusedExchange [d_date_sk] #8 InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #14 - WholeStageCodegen (16) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_item_sk,i_item_sk] - Project [ws_item_sk] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Filter [ws_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] - ReusedSubquery [d_date_sk] #3 - InputAdapter - ReusedExchange [d_date_sk] #9 - InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #13 + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #12 InputAdapter ReusedExchange [d_date_sk] #3 InputAdapter - BroadcastExchange #15 - WholeStageCodegen (44) + BroadcastExchange #14 + WholeStageCodegen (42) SortMergeJoin [i_item_sk,ss_item_sk] InputAdapter - WholeStageCodegen (24) + WholeStageCodegen (23) Sort [i_item_sk] InputAdapter - Exchange [i_item_sk] #16 - WholeStageCodegen (23) + Exchange 
[i_item_sk] #15 + WholeStageCodegen (22) Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] ColumnarToRow InputAdapter Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] InputAdapter - WholeStageCodegen (43) + WholeStageCodegen (41) Sort [ss_item_sk] InputAdapter ReusedExchange [ss_item_sk] #4 InputAdapter - BroadcastExchange #18 - WholeStageCodegen (91) + BroadcastExchange #17 + WholeStageCodegen (87) Filter [sales] ReusedSubquery [average_sales] #4 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #19 - WholeStageCodegen (90) + Exchange [i_brand_id,i_class_id,i_category_id] #18 + WholeStageCodegen (86) HashAggregate [i_brand_id,i_class_id,i_category_id,ss_quantity,ss_list_price] [sum,isEmpty,count,sum,isEmpty,count] Project [ss_quantity,ss_list_price,i_brand_id,i_class_id,i_category_id] BroadcastHashJoin [ss_item_sk,i_item_sk] @@ -217,17 +212,17 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ BroadcastHashJoin [ss_sold_date_sk,d_date_sk] SortMergeJoin [ss_item_sk,ss_item_sk] InputAdapter - WholeStageCodegen (47) + WholeStageCodegen (45) Sort [ss_item_sk] InputAdapter - Exchange [ss_item_sk] #20 - WholeStageCodegen (46) + Exchange [ss_item_sk] #19 + WholeStageCodegen (44) Filter [ss_item_sk] ColumnarToRow InputAdapter Scan parquet default.store_sales [ss_item_sk,ss_quantity,ss_list_price,ss_sold_date_sk] SubqueryBroadcast [d_date_sk] #5 - BroadcastExchange #21 + BroadcastExchange #20 WholeStageCodegen (1) Project [d_date_sk] Filter [d_week_seq,d_date_sk] @@ -242,11 +237,11 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ InputAdapter Scan parquet default.date_dim [d_date_sk,d_week_seq] InputAdapter - WholeStageCodegen (66) + WholeStageCodegen (63) Sort [ss_item_sk] InputAdapter ReusedExchange [ss_item_sk] #4 InputAdapter - ReusedExchange [d_date_sk] #21 + ReusedExchange [d_date_sk] #20 InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #15 + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #14 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt index ae5cf49cbb21b..69be776d2ac28 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt @@ -1,90 +1,88 @@ == Physical Plan == -TakeOrderedAndProject (86) -+- * BroadcastHashJoin Inner BuildRight (85) - :- * Filter (68) - : +- * HashAggregate (67) - : +- Exchange (66) - : +- * HashAggregate (65) - : +- * Project (64) - : +- * BroadcastHashJoin Inner BuildRight (63) - : :- * Project (61) - : : +- * BroadcastHashJoin Inner BuildRight (60) - : : :- * BroadcastHashJoin LeftSemi BuildRight (53) +TakeOrderedAndProject (84) ++- * BroadcastHashJoin Inner 
BuildRight (83) + :- * Filter (66) + : +- * HashAggregate (65) + : +- Exchange (64) + : +- * HashAggregate (63) + : +- * Project (62) + : +- * BroadcastHashJoin Inner BuildRight (61) + : :- * Project (59) + : : +- * BroadcastHashJoin Inner BuildRight (58) + : : :- * BroadcastHashJoin LeftSemi BuildRight (51) : : : :- * Filter (3) : : : : +- * ColumnarToRow (2) : : : : +- Scan parquet default.store_sales (1) - : : : +- BroadcastExchange (52) - : : : +- * Project (51) - : : : +- * BroadcastHashJoin Inner BuildRight (50) + : : : +- BroadcastExchange (50) + : : : +- * Project (49) + : : : +- * BroadcastHashJoin Inner BuildRight (48) : : : :- * Filter (6) : : : : +- * ColumnarToRow (5) : : : : +- Scan parquet default.item (4) - : : : +- BroadcastExchange (49) - : : : +- * HashAggregate (48) - : : : +- * HashAggregate (47) - : : : +- * BroadcastHashJoin LeftSemi BuildRight (46) - : : : :- * HashAggregate (35) - : : : : +- Exchange (34) - : : : : +- * HashAggregate (33) - : : : : +- * Project (32) - : : : : +- * BroadcastHashJoin Inner BuildRight (31) - : : : : :- * Project (29) - : : : : : +- * BroadcastHashJoin Inner BuildRight (28) - : : : : : :- * Filter (9) - : : : : : : +- * ColumnarToRow (8) - : : : : : : +- Scan parquet default.store_sales (7) - : : : : : +- BroadcastExchange (27) - : : : : : +- * BroadcastHashJoin LeftSemi BuildRight (26) - : : : : : :- * Filter (12) - : : : : : : +- * ColumnarToRow (11) - : : : : : : +- Scan parquet default.item (10) - : : : : : +- BroadcastExchange (25) - : : : : : +- * Project (24) - : : : : : +- * BroadcastHashJoin Inner BuildRight (23) - : : : : : :- * Project (21) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (20) - : : : : : : :- * Filter (15) - : : : : : : : +- * ColumnarToRow (14) - : : : : : : : +- Scan parquet default.catalog_sales (13) - : : : : : : +- BroadcastExchange (19) - : : : : : : +- * Filter (18) - : : : : : : +- * ColumnarToRow (17) - : : : : : : +- Scan parquet default.item (16) - : : : : : +- ReusedExchange (22) - : : : : +- ReusedExchange (30) - : : : +- BroadcastExchange (45) - : : : +- * Project (44) - : : : +- * BroadcastHashJoin Inner BuildRight (43) - : : : :- * Project (41) - : : : : +- * BroadcastHashJoin Inner BuildRight (40) - : : : : :- * Filter (38) - : : : : : +- * ColumnarToRow (37) - : : : : : +- Scan parquet default.web_sales (36) - : : : : +- ReusedExchange (39) - : : : +- ReusedExchange (42) - : : +- BroadcastExchange (59) - : : +- * BroadcastHashJoin LeftSemi BuildRight (58) - : : :- * Filter (56) - : : : +- * ColumnarToRow (55) - : : : +- Scan parquet default.item (54) - : : +- ReusedExchange (57) - : +- ReusedExchange (62) - +- BroadcastExchange (84) - +- * Filter (83) - +- * HashAggregate (82) - +- Exchange (81) - +- * HashAggregate (80) - +- * Project (79) - +- * BroadcastHashJoin Inner BuildRight (78) - :- * Project (76) - : +- * BroadcastHashJoin Inner BuildRight (75) - : :- * BroadcastHashJoin LeftSemi BuildRight (73) - : : :- * Filter (71) - : : : +- * ColumnarToRow (70) - : : : +- Scan parquet default.store_sales (69) - : : +- ReusedExchange (72) - : +- ReusedExchange (74) - +- ReusedExchange (77) + : : : +- BroadcastExchange (47) + : : : +- * BroadcastHashJoin LeftSemi BuildRight (46) + : : : :- * HashAggregate (35) + : : : : +- Exchange (34) + : : : : +- * HashAggregate (33) + : : : : +- * Project (32) + : : : : +- * BroadcastHashJoin Inner BuildRight (31) + : : : : :- * Project (29) + : : : : : +- * BroadcastHashJoin Inner BuildRight (28) + : : : : : :- * Filter (9) + : : : : : : +- * 
ColumnarToRow (8) + : : : : : : +- Scan parquet default.store_sales (7) + : : : : : +- BroadcastExchange (27) + : : : : : +- * BroadcastHashJoin LeftSemi BuildRight (26) + : : : : : :- * Filter (12) + : : : : : : +- * ColumnarToRow (11) + : : : : : : +- Scan parquet default.item (10) + : : : : : +- BroadcastExchange (25) + : : : : : +- * Project (24) + : : : : : +- * BroadcastHashJoin Inner BuildRight (23) + : : : : : :- * Project (21) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (20) + : : : : : : :- * Filter (15) + : : : : : : : +- * ColumnarToRow (14) + : : : : : : : +- Scan parquet default.catalog_sales (13) + : : : : : : +- BroadcastExchange (19) + : : : : : : +- * Filter (18) + : : : : : : +- * ColumnarToRow (17) + : : : : : : +- Scan parquet default.item (16) + : : : : : +- ReusedExchange (22) + : : : : +- ReusedExchange (30) + : : : +- BroadcastExchange (45) + : : : +- * Project (44) + : : : +- * BroadcastHashJoin Inner BuildRight (43) + : : : :- * Project (41) + : : : : +- * BroadcastHashJoin Inner BuildRight (40) + : : : : :- * Filter (38) + : : : : : +- * ColumnarToRow (37) + : : : : : +- Scan parquet default.web_sales (36) + : : : : +- ReusedExchange (39) + : : : +- ReusedExchange (42) + : : +- BroadcastExchange (57) + : : +- * BroadcastHashJoin LeftSemi BuildRight (56) + : : :- * Filter (54) + : : : +- * ColumnarToRow (53) + : : : +- Scan parquet default.item (52) + : : +- ReusedExchange (55) + : +- ReusedExchange (60) + +- BroadcastExchange (82) + +- * Filter (81) + +- * HashAggregate (80) + +- Exchange (79) + +- * HashAggregate (78) + +- * Project (77) + +- * BroadcastHashJoin Inner BuildRight (76) + :- * Project (74) + : +- * BroadcastHashJoin Inner BuildRight (73) + : :- * BroadcastHashJoin LeftSemi BuildRight (71) + : : :- * Filter (69) + : : : +- * ColumnarToRow (68) + : : : +- Scan parquet default.store_sales (67) + : : +- ReusedExchange (70) + : +- ReusedExchange (72) + +- ReusedExchange (75) (1) Scan parquet default.store_sales @@ -187,7 +185,7 @@ Join condition: None Output [4]: [cs_sold_date_sk#18, i_brand_id#20, i_class_id#21, i_category_id#22] Input [6]: [cs_item_sk#17, cs_sold_date_sk#18, i_item_sk#19, i_brand_id#20, i_class_id#21, i_category_id#22] -(22) ReusedExchange [Reuses operator id: 119] +(22) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#24] (23) BroadcastHashJoin [codegen id : 3] @@ -221,7 +219,7 @@ Join condition: None Output [4]: [ss_sold_date_sk#11, i_brand_id#14, i_class_id#15, i_category_id#16] Input [6]: [ss_item_sk#10, ss_sold_date_sk#11, i_item_sk#13, i_brand_id#14, i_class_id#15, i_category_id#16] -(30) ReusedExchange [Reuses operator id: 119] +(30) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#27] (31) BroadcastHashJoin [codegen id : 6] @@ -278,7 +276,7 @@ Join condition: None Output [4]: [ws_sold_date_sk#33, i_brand_id#35, i_class_id#36, i_category_id#37] Input [6]: [ws_item_sk#32, ws_sold_date_sk#33, i_item_sk#34, i_brand_id#35, i_class_id#36, i_category_id#37] -(42) ReusedExchange [Reuses operator id: 119] +(42) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#38] (43) BroadcastHashJoin [codegen id : 9] @@ -299,112 +297,98 @@ Left keys [6]: [coalesce(brand_id#28, 0), isnull(brand_id#28), coalesce(class_id Right keys [6]: [coalesce(i_brand_id#35, 0), isnull(i_brand_id#35), coalesce(i_class_id#36, 0), isnull(i_class_id#36), coalesce(i_category_id#37, 0), isnull(i_category_id#37)] Join condition: None -(47) HashAggregate [codegen id : 10] -Input [3]: [brand_id#28, class_id#29, 
category_id#30] -Keys [3]: [brand_id#28, class_id#29, category_id#30] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#28, class_id#29, category_id#30] - -(48) HashAggregate [codegen id : 10] -Input [3]: [brand_id#28, class_id#29, category_id#30] -Keys [3]: [brand_id#28, class_id#29, category_id#30] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#28, class_id#29, category_id#30] - -(49) BroadcastExchange +(47) BroadcastExchange Input [3]: [brand_id#28, class_id#29, category_id#30] Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#40] -(50) BroadcastHashJoin [codegen id : 11] +(48) BroadcastHashJoin [codegen id : 11] Left keys [3]: [i_brand_id#7, i_class_id#8, i_category_id#9] Right keys [3]: [brand_id#28, class_id#29, category_id#30] Join condition: None -(51) Project [codegen id : 11] +(49) Project [codegen id : 11] Output [1]: [i_item_sk#6 AS ss_item_sk#41] Input [7]: [i_item_sk#6, i_brand_id#7, i_class_id#8, i_category_id#9, brand_id#28, class_id#29, category_id#30] -(52) BroadcastExchange +(50) BroadcastExchange Input [1]: [ss_item_sk#41] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#42] -(53) BroadcastHashJoin [codegen id : 25] +(51) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_item_sk#1] Right keys [1]: [ss_item_sk#41] Join condition: None -(54) Scan parquet default.item +(52) Scan parquet default.item Output [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Batched: true Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_item_sk), IsNotNull(i_brand_id), IsNotNull(i_class_id), IsNotNull(i_category_id)] ReadSchema: struct -(55) ColumnarToRow [codegen id : 23] +(53) ColumnarToRow [codegen id : 23] Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] -(56) Filter [codegen id : 23] +(54) Filter [codegen id : 23] Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Condition : (((isnotnull(i_item_sk#43) AND isnotnull(i_brand_id#44)) AND isnotnull(i_class_id#45)) AND isnotnull(i_category_id#46)) -(57) ReusedExchange [Reuses operator id: 52] +(55) ReusedExchange [Reuses operator id: 50] Output [1]: [ss_item_sk#41] -(58) BroadcastHashJoin [codegen id : 23] +(56) BroadcastHashJoin [codegen id : 23] Left keys [1]: [i_item_sk#43] Right keys [1]: [ss_item_sk#41] Join condition: None -(59) BroadcastExchange +(57) BroadcastExchange Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#47] -(60) BroadcastHashJoin [codegen id : 25] +(58) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_item_sk#1] Right keys [1]: [i_item_sk#43] Join condition: None -(61) Project [codegen id : 25] +(59) Project [codegen id : 25] Output [6]: [ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_brand_id#44, i_class_id#45, i_category_id#46] Input [8]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] -(62) ReusedExchange [Reuses operator id: 110] +(60) ReusedExchange [Reuses operator id: 108] Output [1]: [d_date_sk#48] -(63) BroadcastHashJoin [codegen id : 25] +(61) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_sold_date_sk#4] Right keys [1]: [d_date_sk#48] Join condition: None -(64) Project [codegen id : 25] +(62) Project [codegen id : 25] Output [5]: 
[ss_quantity#2, ss_list_price#3, i_brand_id#44, i_class_id#45, i_category_id#46] Input [7]: [ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_brand_id#44, i_class_id#45, i_category_id#46, d_date_sk#48] -(65) HashAggregate [codegen id : 25] +(63) HashAggregate [codegen id : 25] Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#44, i_class_id#45, i_category_id#46] Keys [3]: [i_brand_id#44, i_class_id#45, i_category_id#46] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] Aggregate Attributes [3]: [sum#49, isEmpty#50, count#51] Results [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] -(66) Exchange +(64) Exchange Input [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] Arguments: hashpartitioning(i_brand_id#44, i_class_id#45, i_category_id#46, 5), ENSURE_REQUIREMENTS, [id=#55] -(67) HashAggregate [codegen id : 52] +(65) HashAggregate [codegen id : 52] Input [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] Keys [3]: [i_brand_id#44, i_class_id#45, i_category_id#46] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#56, count(1)#57] -Results [6]: [store AS channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#56 AS sales#59, count(1)#57 AS number_sales#60] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#56, count(1)#57] +Results [6]: [store AS channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#56 AS sales#59, count(1)#57 AS number_sales#60] -(68) Filter [codegen id : 52] +(66) Filter [codegen id : 52] Input [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sales#59, number_sales#60] Condition : (isnotnull(sales#59) AND (cast(sales#59 as decimal(32,6)) > cast(Subquery scalar-subquery#61, [id=#62] as decimal(32,6)))) -(69) Scan parquet default.store_sales +(67) Scan parquet default.store_sales Output [4]: [ss_item_sk#63, ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66] Batched: true Location: InMemoryFileIndex [] @@ -412,278 +396,278 @@ PartitionFilters: [isnotnull(ss_sold_date_sk#66), dynamicpruningexpression(ss_so PushedFilters: [IsNotNull(ss_item_sk)] 
ReadSchema: struct -(70) ColumnarToRow [codegen id : 50] +(68) ColumnarToRow [codegen id : 50] Input [4]: [ss_item_sk#63, ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66] -(71) Filter [codegen id : 50] +(69) Filter [codegen id : 50] Input [4]: [ss_item_sk#63, ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66] Condition : isnotnull(ss_item_sk#63) -(72) ReusedExchange [Reuses operator id: 52] +(70) ReusedExchange [Reuses operator id: 50] Output [1]: [ss_item_sk#41] -(73) BroadcastHashJoin [codegen id : 50] +(71) BroadcastHashJoin [codegen id : 50] Left keys [1]: [ss_item_sk#63] Right keys [1]: [ss_item_sk#41] Join condition: None -(74) ReusedExchange [Reuses operator id: 59] +(72) ReusedExchange [Reuses operator id: 57] Output [4]: [i_item_sk#68, i_brand_id#69, i_class_id#70, i_category_id#71] -(75) BroadcastHashJoin [codegen id : 50] +(73) BroadcastHashJoin [codegen id : 50] Left keys [1]: [ss_item_sk#63] Right keys [1]: [i_item_sk#68] Join condition: None -(76) Project [codegen id : 50] +(74) Project [codegen id : 50] Output [6]: [ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66, i_brand_id#69, i_class_id#70, i_category_id#71] Input [8]: [ss_item_sk#63, ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66, i_item_sk#68, i_brand_id#69, i_class_id#70, i_category_id#71] -(77) ReusedExchange [Reuses operator id: 124] +(75) ReusedExchange [Reuses operator id: 122] Output [1]: [d_date_sk#72] -(78) BroadcastHashJoin [codegen id : 50] +(76) BroadcastHashJoin [codegen id : 50] Left keys [1]: [ss_sold_date_sk#66] Right keys [1]: [d_date_sk#72] Join condition: None -(79) Project [codegen id : 50] +(77) Project [codegen id : 50] Output [5]: [ss_quantity#64, ss_list_price#65, i_brand_id#69, i_class_id#70, i_category_id#71] Input [7]: [ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66, i_brand_id#69, i_class_id#70, i_category_id#71, d_date_sk#72] -(80) HashAggregate [codegen id : 50] +(78) HashAggregate [codegen id : 50] Input [5]: [ss_quantity#64, ss_list_price#65, i_brand_id#69, i_class_id#70, i_category_id#71] Keys [3]: [i_brand_id#69, i_class_id#70, i_category_id#71] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] Aggregate Attributes [3]: [sum#73, isEmpty#74, count#75] Results [6]: [i_brand_id#69, i_class_id#70, i_category_id#71, sum#76, isEmpty#77, count#78] -(81) Exchange +(79) Exchange Input [6]: [i_brand_id#69, i_class_id#70, i_category_id#71, sum#76, isEmpty#77, count#78] Arguments: hashpartitioning(i_brand_id#69, i_class_id#70, i_category_id#71, 5), ENSURE_REQUIREMENTS, [id=#79] -(82) HashAggregate [codegen id : 51] +(80) HashAggregate [codegen id : 51] Input [6]: [i_brand_id#69, i_class_id#70, i_category_id#71, sum#76, isEmpty#77, count#78] Keys [3]: [i_brand_id#69, i_class_id#70, i_category_id#71] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), 
DecimalType(18,2), true))#80, count(1)#81] -Results [6]: [store AS channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2), true))#80 AS sales#83, count(1)#81 AS number_sales#84] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2)))#80, count(1)#81] +Results [6]: [store AS channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2)))#80 AS sales#83, count(1)#81 AS number_sales#84] -(83) Filter [codegen id : 51] +(81) Filter [codegen id : 51] Input [6]: [channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sales#83, number_sales#84] Condition : (isnotnull(sales#83) AND (cast(sales#83 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#61, [id=#62] as decimal(32,6)))) -(84) BroadcastExchange +(82) BroadcastExchange Input [6]: [channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sales#83, number_sales#84] Arguments: HashedRelationBroadcastMode(List(input[1, int, true], input[2, int, true], input[3, int, true]),false), [id=#85] -(85) BroadcastHashJoin [codegen id : 52] +(83) BroadcastHashJoin [codegen id : 52] Left keys [3]: [i_brand_id#44, i_class_id#45, i_category_id#46] Right keys [3]: [i_brand_id#69, i_class_id#70, i_category_id#71] Join condition: None -(86) TakeOrderedAndProject +(84) TakeOrderedAndProject Input [12]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sales#59, number_sales#60, channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sales#83, number_sales#84] Arguments: 100, [i_brand_id#44 ASC NULLS FIRST, i_class_id#45 ASC NULLS FIRST, i_category_id#46 ASC NULLS FIRST], [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sales#59, number_sales#60, channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sales#83, number_sales#84] ===== Subqueries ===== -Subquery:1 Hosting operator id = 68 Hosting Expression = Subquery scalar-subquery#61, [id=#62] -* HashAggregate (105) -+- Exchange (104) - +- * HashAggregate (103) - +- Union (102) - :- * Project (91) - : +- * BroadcastHashJoin Inner BuildRight (90) - : :- * ColumnarToRow (88) - : : +- Scan parquet default.store_sales (87) - : +- ReusedExchange (89) - :- * Project (96) - : +- * BroadcastHashJoin Inner BuildRight (95) - : :- * ColumnarToRow (93) - : : +- Scan parquet default.catalog_sales (92) - : +- ReusedExchange (94) - +- * Project (101) - +- * BroadcastHashJoin Inner BuildRight (100) - :- * ColumnarToRow (98) - : +- Scan parquet default.web_sales (97) - +- ReusedExchange (99) - - -(87) Scan parquet default.store_sales +Subquery:1 Hosting operator id = 66 Hosting Expression = Subquery scalar-subquery#61, [id=#62] +* HashAggregate (103) ++- Exchange (102) + +- * HashAggregate (101) + +- Union (100) + :- * Project (89) + : +- * BroadcastHashJoin Inner BuildRight (88) + : :- * ColumnarToRow (86) + : : +- Scan parquet default.store_sales (85) + : +- ReusedExchange (87) + :- * Project (94) + : +- * BroadcastHashJoin Inner 
BuildRight (93) + : :- * ColumnarToRow (91) + : : +- Scan parquet default.catalog_sales (90) + : +- ReusedExchange (92) + +- * Project (99) + +- * BroadcastHashJoin Inner BuildRight (98) + :- * ColumnarToRow (96) + : +- Scan parquet default.web_sales (95) + +- ReusedExchange (97) + + +(85) Scan parquet default.store_sales Output [3]: [ss_quantity#86, ss_list_price#87, ss_sold_date_sk#88] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(ss_sold_date_sk#88), dynamicpruningexpression(ss_sold_date_sk#88 IN dynamicpruning#12)] ReadSchema: struct -(88) ColumnarToRow [codegen id : 2] +(86) ColumnarToRow [codegen id : 2] Input [3]: [ss_quantity#86, ss_list_price#87, ss_sold_date_sk#88] -(89) ReusedExchange [Reuses operator id: 119] +(87) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#89] -(90) BroadcastHashJoin [codegen id : 2] +(88) BroadcastHashJoin [codegen id : 2] Left keys [1]: [ss_sold_date_sk#88] Right keys [1]: [d_date_sk#89] Join condition: None -(91) Project [codegen id : 2] +(89) Project [codegen id : 2] Output [2]: [ss_quantity#86 AS quantity#90, ss_list_price#87 AS list_price#91] Input [4]: [ss_quantity#86, ss_list_price#87, ss_sold_date_sk#88, d_date_sk#89] -(92) Scan parquet default.catalog_sales +(90) Scan parquet default.catalog_sales Output [3]: [cs_quantity#92, cs_list_price#93, cs_sold_date_sk#94] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(cs_sold_date_sk#94), dynamicpruningexpression(cs_sold_date_sk#94 IN dynamicpruning#12)] ReadSchema: struct -(93) ColumnarToRow [codegen id : 4] +(91) ColumnarToRow [codegen id : 4] Input [3]: [cs_quantity#92, cs_list_price#93, cs_sold_date_sk#94] -(94) ReusedExchange [Reuses operator id: 119] +(92) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#95] -(95) BroadcastHashJoin [codegen id : 4] +(93) BroadcastHashJoin [codegen id : 4] Left keys [1]: [cs_sold_date_sk#94] Right keys [1]: [d_date_sk#95] Join condition: None -(96) Project [codegen id : 4] +(94) Project [codegen id : 4] Output [2]: [cs_quantity#92 AS quantity#96, cs_list_price#93 AS list_price#97] Input [4]: [cs_quantity#92, cs_list_price#93, cs_sold_date_sk#94, d_date_sk#95] -(97) Scan parquet default.web_sales +(95) Scan parquet default.web_sales Output [3]: [ws_quantity#98, ws_list_price#99, ws_sold_date_sk#100] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(ws_sold_date_sk#100), dynamicpruningexpression(ws_sold_date_sk#100 IN dynamicpruning#12)] ReadSchema: struct -(98) ColumnarToRow [codegen id : 6] +(96) ColumnarToRow [codegen id : 6] Input [3]: [ws_quantity#98, ws_list_price#99, ws_sold_date_sk#100] -(99) ReusedExchange [Reuses operator id: 119] +(97) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#101] -(100) BroadcastHashJoin [codegen id : 6] +(98) BroadcastHashJoin [codegen id : 6] Left keys [1]: [ws_sold_date_sk#100] Right keys [1]: [d_date_sk#101] Join condition: None -(101) Project [codegen id : 6] +(99) Project [codegen id : 6] Output [2]: [ws_quantity#98 AS quantity#102, ws_list_price#99 AS list_price#103] Input [4]: [ws_quantity#98, ws_list_price#99, ws_sold_date_sk#100, d_date_sk#101] -(102) Union +(100) Union -(103) HashAggregate [codegen id : 7] +(101) HashAggregate [codegen id : 7] Input [2]: [quantity#90, list_price#91] Keys: [] -Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(cast(quantity#90 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2), true))] 
+Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(quantity#90 as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [2]: [sum#104, count#105] Results [2]: [sum#106, count#107] -(104) Exchange +(102) Exchange Input [2]: [sum#106, count#107] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#108] -(105) HashAggregate [codegen id : 8] +(103) HashAggregate [codegen id : 8] Input [2]: [sum#106, count#107] Keys: [] -Functions [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#90 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#90 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2), true))#109] -Results [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#90 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2), true))#109 AS average_sales#110] +Functions [1]: [avg(CheckOverflow((promote_precision(cast(quantity#90 as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(quantity#90 as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2)))#109] +Results [1]: [avg(CheckOverflow((promote_precision(cast(quantity#90 as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2)))#109 AS average_sales#110] -Subquery:2 Hosting operator id = 87 Hosting Expression = ss_sold_date_sk#88 IN dynamicpruning#12 +Subquery:2 Hosting operator id = 85 Hosting Expression = ss_sold_date_sk#88 IN dynamicpruning#12 -Subquery:3 Hosting operator id = 92 Hosting Expression = cs_sold_date_sk#94 IN dynamicpruning#12 +Subquery:3 Hosting operator id = 90 Hosting Expression = cs_sold_date_sk#94 IN dynamicpruning#12 -Subquery:4 Hosting operator id = 97 Hosting Expression = ws_sold_date_sk#100 IN dynamicpruning#12 +Subquery:4 Hosting operator id = 95 Hosting Expression = ws_sold_date_sk#100 IN dynamicpruning#12 Subquery:5 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (110) -+- * Project (109) - +- * Filter (108) - +- * ColumnarToRow (107) - +- Scan parquet default.date_dim (106) +BroadcastExchange (108) ++- * Project (107) + +- * Filter (106) + +- * ColumnarToRow (105) + +- Scan parquet default.date_dim (104) -(106) Scan parquet default.date_dim +(104) Scan parquet default.date_dim Output [2]: [d_date_sk#48, d_week_seq#111] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] ReadSchema: struct -(107) ColumnarToRow [codegen id : 1] +(105) ColumnarToRow [codegen id : 1] Input [2]: [d_date_sk#48, d_week_seq#111] -(108) Filter [codegen id : 1] +(106) Filter [codegen id : 1] Input [2]: [d_date_sk#48, d_week_seq#111] Condition : ((isnotnull(d_week_seq#111) AND (d_week_seq#111 = Subquery scalar-subquery#112, [id=#113])) AND isnotnull(d_date_sk#48)) -(109) Project [codegen id : 1] +(107) Project [codegen id : 1] Output [1]: [d_date_sk#48] Input [2]: [d_date_sk#48, d_week_seq#111] -(110) BroadcastExchange +(108) BroadcastExchange Input [1]: [d_date_sk#48] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#114] -Subquery:6 Hosting operator id = 108 Hosting Expression = Subquery scalar-subquery#112, [id=#113] -* Project (114) -+- * Filter (113) - +- * ColumnarToRow (112) - +- Scan parquet default.date_dim (111) +Subquery:6 Hosting operator id = 106 Hosting Expression = Subquery scalar-subquery#112, [id=#113] +* Project (112) ++- * Filter (111) + +- * ColumnarToRow (110) + +- Scan parquet default.date_dim (109) -(111) Scan parquet default.date_dim +(109) Scan parquet default.date_dim Output [4]: [d_week_seq#115, d_year#116, d_moy#117, d_dom#118] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), IsNotNull(d_dom), EqualTo(d_year,2000), EqualTo(d_moy,12), EqualTo(d_dom,11)] ReadSchema: struct -(112) ColumnarToRow [codegen id : 1] +(110) ColumnarToRow [codegen id : 1] Input [4]: [d_week_seq#115, d_year#116, d_moy#117, d_dom#118] -(113) Filter [codegen id : 1] +(111) Filter [codegen id : 1] Input [4]: [d_week_seq#115, d_year#116, d_moy#117, d_dom#118] Condition : (((((isnotnull(d_year#116) AND isnotnull(d_moy#117)) AND isnotnull(d_dom#118)) AND (d_year#116 = 2000)) AND (d_moy#117 = 12)) AND (d_dom#118 = 11)) -(114) Project [codegen id : 1] +(112) Project [codegen id : 1] Output [1]: [d_week_seq#115] Input [4]: [d_week_seq#115, d_year#116, d_moy#117, d_dom#118] Subquery:7 Hosting operator id = 7 Hosting Expression = ss_sold_date_sk#11 IN dynamicpruning#12 -BroadcastExchange (119) -+- * Project (118) - +- * Filter (117) - +- * ColumnarToRow (116) - +- Scan parquet default.date_dim (115) +BroadcastExchange (117) ++- * Project (116) + +- * Filter (115) + +- * ColumnarToRow (114) + +- Scan parquet default.date_dim (113) -(115) Scan parquet default.date_dim +(113) Scan parquet default.date_dim Output [2]: [d_date_sk#27, d_year#119] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), GreaterThanOrEqual(d_year,1999), LessThanOrEqual(d_year,2001), IsNotNull(d_date_sk)] ReadSchema: struct -(116) ColumnarToRow [codegen id : 1] +(114) ColumnarToRow [codegen id : 1] Input [2]: [d_date_sk#27, d_year#119] -(117) Filter [codegen id : 1] +(115) Filter [codegen id : 1] Input [2]: [d_date_sk#27, d_year#119] Condition : (((isnotnull(d_year#119) AND (d_year#119 >= 1999)) AND (d_year#119 <= 2001)) AND isnotnull(d_date_sk#27)) -(118) Project [codegen id : 1] +(116) Project [codegen id : 1] Output [1]: [d_date_sk#27] Input [2]: [d_date_sk#27, d_year#119] -(119) BroadcastExchange +(117) BroadcastExchange Input [1]: [d_date_sk#27] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#120] @@ -691,60 +675,60 @@ Subquery:8 Hosting operator id = 13 Hosting Expression = cs_sold_date_sk#18 IN d Subquery:9 Hosting operator id = 36 Hosting Expression = ws_sold_date_sk#33 IN dynamicpruning#12 -Subquery:10 Hosting operator id = 83 Hosting Expression = ReusedSubquery Subquery scalar-subquery#61, [id=#62] +Subquery:10 Hosting operator id = 81 Hosting Expression = ReusedSubquery Subquery scalar-subquery#61, [id=#62] -Subquery:11 Hosting operator id = 69 Hosting Expression = ss_sold_date_sk#66 IN dynamicpruning#67 -BroadcastExchange (124) -+- * Project (123) - +- * Filter (122) - +- * ColumnarToRow (121) - +- Scan parquet default.date_dim (120) +Subquery:11 Hosting operator id = 67 Hosting Expression = ss_sold_date_sk#66 IN dynamicpruning#67 +BroadcastExchange (122) ++- * Project (121) + +- * Filter (120) + +- * ColumnarToRow (119) + +- 
Scan parquet default.date_dim (118) -(120) Scan parquet default.date_dim +(118) Scan parquet default.date_dim Output [2]: [d_date_sk#72, d_week_seq#121] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] ReadSchema: struct -(121) ColumnarToRow [codegen id : 1] +(119) ColumnarToRow [codegen id : 1] Input [2]: [d_date_sk#72, d_week_seq#121] -(122) Filter [codegen id : 1] +(120) Filter [codegen id : 1] Input [2]: [d_date_sk#72, d_week_seq#121] Condition : ((isnotnull(d_week_seq#121) AND (d_week_seq#121 = Subquery scalar-subquery#122, [id=#123])) AND isnotnull(d_date_sk#72)) -(123) Project [codegen id : 1] +(121) Project [codegen id : 1] Output [1]: [d_date_sk#72] Input [2]: [d_date_sk#72, d_week_seq#121] -(124) BroadcastExchange +(122) BroadcastExchange Input [1]: [d_date_sk#72] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#124] -Subquery:12 Hosting operator id = 122 Hosting Expression = Subquery scalar-subquery#122, [id=#123] -* Project (128) -+- * Filter (127) - +- * ColumnarToRow (126) - +- Scan parquet default.date_dim (125) +Subquery:12 Hosting operator id = 120 Hosting Expression = Subquery scalar-subquery#122, [id=#123] +* Project (126) ++- * Filter (125) + +- * ColumnarToRow (124) + +- Scan parquet default.date_dim (123) -(125) Scan parquet default.date_dim +(123) Scan parquet default.date_dim Output [4]: [d_week_seq#125, d_year#126, d_moy#127, d_dom#128] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), IsNotNull(d_dom), EqualTo(d_year,1999), EqualTo(d_moy,12), EqualTo(d_dom,11)] ReadSchema: struct -(126) ColumnarToRow [codegen id : 1] +(124) ColumnarToRow [codegen id : 1] Input [4]: [d_week_seq#125, d_year#126, d_moy#127, d_dom#128] -(127) Filter [codegen id : 1] +(125) Filter [codegen id : 1] Input [4]: [d_week_seq#125, d_year#126, d_moy#127, d_dom#128] Condition : (((((isnotnull(d_year#126) AND isnotnull(d_moy#127)) AND isnotnull(d_dom#128)) AND (d_year#126 = 1999)) AND (d_moy#127 = 12)) AND (d_dom#128 = 11)) -(128) Project [codegen id : 1] +(126) Project [codegen id : 1] Output [1]: [d_week_seq#125] Input [4]: [d_week_seq#125, d_year#126, d_moy#127, d_dom#128] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/simplified.txt index 2df0810ddba28..259178d0e432f 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/simplified.txt @@ -4,7 +4,7 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ Filter [sales] Subquery #4 WholeStageCodegen (8) - HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(cast(quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2), true)),average_sales,sum,count] + HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(quantity as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2))),average_sales,sum,count] InputAdapter Exchange #12 WholeStageCodegen (7) @@ -38,7 +38,7 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ ReusedSubquery [d_date_sk] #3 InputAdapter ReusedExchange [d_date_sk] #6 - 
HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #1 WholeStageCodegen (25) @@ -79,77 +79,75 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ InputAdapter BroadcastExchange #4 WholeStageCodegen (10) - HashAggregate [brand_id,class_id,category_id] + BroadcastHashJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] HashAggregate [brand_id,class_id,category_id] - BroadcastHashJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] - HashAggregate [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #5 - WholeStageCodegen (6) - HashAggregate [brand_id,class_id,category_id] - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Project [ss_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_item_sk,i_item_sk] - Filter [ss_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #3 - BroadcastExchange #6 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] - InputAdapter - BroadcastExchange #7 - WholeStageCodegen (4) - BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] - Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - BroadcastExchange #8 - WholeStageCodegen (3) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Project [cs_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_item_sk,i_item_sk] - Filter [cs_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #3 - InputAdapter - BroadcastExchange #9 - WholeStageCodegen (1) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - ReusedExchange [d_date_sk] #6 - InputAdapter - ReusedExchange [d_date_sk] #6 - InputAdapter - BroadcastExchange #10 - WholeStageCodegen (9) + InputAdapter + Exchange [brand_id,class_id,category_id] #5 + WholeStageCodegen (6) + HashAggregate [brand_id,class_id,category_id] Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Project [ws_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_item_sk,i_item_sk] - Filter [ws_item_sk] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Project [ss_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ss_item_sk,i_item_sk] + Filter [ss_item_sk] ColumnarToRow InputAdapter - Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] - ReusedSubquery [d_date_sk] #3 
+ Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #3 + BroadcastExchange #6 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #9 + BroadcastExchange #7 + WholeStageCodegen (4) + BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] + Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + BroadcastExchange #8 + WholeStageCodegen (3) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Project [cs_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_item_sk,i_item_sk] + Filter [cs_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] + ReusedSubquery [d_date_sk] #3 + InputAdapter + BroadcastExchange #9 + WholeStageCodegen (1) + Filter [i_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + ReusedExchange [d_date_sk] #6 InputAdapter ReusedExchange [d_date_sk] #6 + InputAdapter + BroadcastExchange #10 + WholeStageCodegen (9) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Project [ws_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_item_sk,i_item_sk] + Filter [ws_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #3 + InputAdapter + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #9 + InputAdapter + ReusedExchange [d_date_sk] #6 InputAdapter BroadcastExchange #11 WholeStageCodegen (23) @@ -167,7 +165,7 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ WholeStageCodegen (51) Filter [sales] ReusedSubquery [average_sales] #4 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #14 WholeStageCodegen (50) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17.sf100/explain.txt index 16afa38901107..d61798f6ad06e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17.sf100/explain.txt @@ -1,50 +1,53 @@ == Physical Plan == -TakeOrderedAndProject (46) -+- * HashAggregate (45) - +- Exchange (44) - +- * HashAggregate (43) - +- * Project (42) - +- * SortMergeJoin Inner (41) - :- * Project (32) - : +- * SortMergeJoin Inner (31) - : :- * Sort (22) - : : +- * Project (21) - : : +- * SortMergeJoin 
Inner (20) - : : :- * Sort (14) - : : : +- Exchange (13) - : : : +- * Project (12) - : : : +- * BroadcastHashJoin Inner BuildRight (11) - : : : :- * Project (6) - : : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : : :- * Filter (3) - : : : : : +- * ColumnarToRow (2) - : : : : : +- Scan parquet default.store_sales (1) - : : : : +- ReusedExchange (4) - : : : +- BroadcastExchange (10) - : : : +- * Filter (9) - : : : +- * ColumnarToRow (8) - : : : +- Scan parquet default.store (7) - : : +- * Sort (19) - : : +- Exchange (18) - : : +- * Filter (17) - : : +- * ColumnarToRow (16) - : : +- Scan parquet default.item (15) - : +- * Sort (30) - : +- Exchange (29) - : +- * Project (28) - : +- * BroadcastHashJoin Inner BuildRight (27) - : :- * Filter (25) - : : +- * ColumnarToRow (24) - : : +- Scan parquet default.store_returns (23) - : +- ReusedExchange (26) - +- * Sort (40) - +- Exchange (39) - +- * Project (38) - +- * BroadcastHashJoin Inner BuildRight (37) - :- * Filter (35) - : +- * ColumnarToRow (34) - : +- Scan parquet default.catalog_sales (33) - +- ReusedExchange (36) +TakeOrderedAndProject (49) ++- * HashAggregate (48) + +- Exchange (47) + +- * HashAggregate (46) + +- * Project (45) + +- * SortMergeJoin Inner (44) + :- * Sort (35) + : +- Exchange (34) + : +- * Project (33) + : +- * SortMergeJoin Inner (32) + : :- * Sort (23) + : : +- Exchange (22) + : : +- * Project (21) + : : +- * SortMergeJoin Inner (20) + : : :- * Sort (14) + : : : +- Exchange (13) + : : : +- * Project (12) + : : : +- * BroadcastHashJoin Inner BuildRight (11) + : : : :- * Project (6) + : : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : : :- * Filter (3) + : : : : : +- * ColumnarToRow (2) + : : : : : +- Scan parquet default.store_sales (1) + : : : : +- ReusedExchange (4) + : : : +- BroadcastExchange (10) + : : : +- * Filter (9) + : : : +- * ColumnarToRow (8) + : : : +- Scan parquet default.store (7) + : : +- * Sort (19) + : : +- Exchange (18) + : : +- * Filter (17) + : : +- * ColumnarToRow (16) + : : +- Scan parquet default.item (15) + : +- * Sort (31) + : +- Exchange (30) + : +- * Project (29) + : +- * BroadcastHashJoin Inner BuildRight (28) + : :- * Filter (26) + : : +- * ColumnarToRow (25) + : : +- Scan parquet default.store_returns (24) + : +- ReusedExchange (27) + +- * Sort (43) + +- Exchange (42) + +- * Project (41) + +- * BroadcastHashJoin Inner BuildRight (40) + :- * Filter (38) + : +- * ColumnarToRow (37) + : +- Scan parquet default.catalog_sales (36) + +- ReusedExchange (39) (1) Scan parquet default.store_sales @@ -62,7 +65,7 @@ Input [6]: [ss_item_sk#1, ss_customer_sk#2, ss_store_sk#3, ss_ticket_number#4, s Input [6]: [ss_item_sk#1, ss_customer_sk#2, ss_store_sk#3, ss_ticket_number#4, ss_quantity#5, ss_sold_date_sk#6] Condition : (((isnotnull(ss_customer_sk#2) AND isnotnull(ss_item_sk#1)) AND isnotnull(ss_ticket_number#4)) AND isnotnull(ss_store_sk#3)) -(4) ReusedExchange [Reuses operator id: 51] +(4) ReusedExchange [Reuses operator id: 54] Output [1]: [d_date_sk#8] (5) BroadcastHashJoin [codegen id : 3] @@ -140,182 +143,194 @@ Join condition: None Output [7]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15] Input [8]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_state#10, i_item_sk#13, i_item_id#14, i_item_desc#15] -(22) Sort [codegen id : 7] +(22) Exchange +Input [7]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15] +Arguments: 
hashpartitioning(ss_customer_sk#2, ss_item_sk#1, ss_ticket_number#4, 5), ENSURE_REQUIREMENTS, [id=#17] + +(23) Sort [codegen id : 8] Input [7]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15] Arguments: [ss_customer_sk#2 ASC NULLS FIRST, ss_item_sk#1 ASC NULLS FIRST, ss_ticket_number#4 ASC NULLS FIRST], false, 0 -(23) Scan parquet default.store_returns -Output [5]: [sr_item_sk#17, sr_customer_sk#18, sr_ticket_number#19, sr_return_quantity#20, sr_returned_date_sk#21] +(24) Scan parquet default.store_returns +Output [5]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21, sr_returned_date_sk#22] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(sr_returned_date_sk#21), dynamicpruningexpression(sr_returned_date_sk#21 IN dynamicpruning#22)] +PartitionFilters: [isnotnull(sr_returned_date_sk#22), dynamicpruningexpression(sr_returned_date_sk#22 IN dynamicpruning#23)] PushedFilters: [IsNotNull(sr_customer_sk), IsNotNull(sr_item_sk), IsNotNull(sr_ticket_number)] ReadSchema: struct -(24) ColumnarToRow [codegen id : 9] -Input [5]: [sr_item_sk#17, sr_customer_sk#18, sr_ticket_number#19, sr_return_quantity#20, sr_returned_date_sk#21] +(25) ColumnarToRow [codegen id : 10] +Input [5]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21, sr_returned_date_sk#22] -(25) Filter [codegen id : 9] -Input [5]: [sr_item_sk#17, sr_customer_sk#18, sr_ticket_number#19, sr_return_quantity#20, sr_returned_date_sk#21] -Condition : ((isnotnull(sr_customer_sk#18) AND isnotnull(sr_item_sk#17)) AND isnotnull(sr_ticket_number#19)) +(26) Filter [codegen id : 10] +Input [5]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21, sr_returned_date_sk#22] +Condition : ((isnotnull(sr_customer_sk#19) AND isnotnull(sr_item_sk#18)) AND isnotnull(sr_ticket_number#20)) -(26) ReusedExchange [Reuses operator id: 56] -Output [1]: [d_date_sk#23] +(27) ReusedExchange [Reuses operator id: 59] +Output [1]: [d_date_sk#24] -(27) BroadcastHashJoin [codegen id : 9] -Left keys [1]: [sr_returned_date_sk#21] -Right keys [1]: [d_date_sk#23] +(28) BroadcastHashJoin [codegen id : 10] +Left keys [1]: [sr_returned_date_sk#22] +Right keys [1]: [d_date_sk#24] Join condition: None -(28) Project [codegen id : 9] -Output [4]: [sr_item_sk#17, sr_customer_sk#18, sr_ticket_number#19, sr_return_quantity#20] -Input [6]: [sr_item_sk#17, sr_customer_sk#18, sr_ticket_number#19, sr_return_quantity#20, sr_returned_date_sk#21, d_date_sk#23] +(29) Project [codegen id : 10] +Output [4]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21] +Input [6]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21, sr_returned_date_sk#22, d_date_sk#24] -(29) Exchange -Input [4]: [sr_item_sk#17, sr_customer_sk#18, sr_ticket_number#19, sr_return_quantity#20] -Arguments: hashpartitioning(sr_item_sk#17, 5), ENSURE_REQUIREMENTS, [id=#24] +(30) Exchange +Input [4]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21] +Arguments: hashpartitioning(sr_customer_sk#19, sr_item_sk#18, sr_ticket_number#20, 5), ENSURE_REQUIREMENTS, [id=#25] -(30) Sort [codegen id : 10] -Input [4]: [sr_item_sk#17, sr_customer_sk#18, sr_ticket_number#19, sr_return_quantity#20] -Arguments: [sr_customer_sk#18 ASC NULLS FIRST, sr_item_sk#17 ASC NULLS FIRST, sr_ticket_number#19 ASC NULLS FIRST], false, 0 +(31) Sort [codegen id : 11] +Input [4]: [sr_item_sk#18, sr_customer_sk#19, 
sr_ticket_number#20, sr_return_quantity#21] +Arguments: [sr_customer_sk#19 ASC NULLS FIRST, sr_item_sk#18 ASC NULLS FIRST, sr_ticket_number#20 ASC NULLS FIRST], false, 0 -(31) SortMergeJoin [codegen id : 11] +(32) SortMergeJoin [codegen id : 12] Left keys [3]: [ss_customer_sk#2, ss_item_sk#1, ss_ticket_number#4] -Right keys [3]: [sr_customer_sk#18, sr_item_sk#17, sr_ticket_number#19] +Right keys [3]: [sr_customer_sk#19, sr_item_sk#18, sr_ticket_number#20] Join condition: None -(32) Project [codegen id : 11] -Output [7]: [ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15, sr_item_sk#17, sr_customer_sk#18, sr_return_quantity#20] -Input [11]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15, sr_item_sk#17, sr_customer_sk#18, sr_ticket_number#19, sr_return_quantity#20] +(33) Project [codegen id : 12] +Output [7]: [ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15, sr_item_sk#18, sr_customer_sk#19, sr_return_quantity#21] +Input [11]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15, sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21] + +(34) Exchange +Input [7]: [ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15, sr_item_sk#18, sr_customer_sk#19, sr_return_quantity#21] +Arguments: hashpartitioning(sr_customer_sk#19, sr_item_sk#18, 5), ENSURE_REQUIREMENTS, [id=#26] -(33) Scan parquet default.catalog_sales -Output [4]: [cs_bill_customer_sk#25, cs_item_sk#26, cs_quantity#27, cs_sold_date_sk#28] +(35) Sort [codegen id : 13] +Input [7]: [ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15, sr_item_sk#18, sr_customer_sk#19, sr_return_quantity#21] +Arguments: [sr_customer_sk#19 ASC NULLS FIRST, sr_item_sk#18 ASC NULLS FIRST], false, 0 + +(36) Scan parquet default.catalog_sales +Output [4]: [cs_bill_customer_sk#27, cs_item_sk#28, cs_quantity#29, cs_sold_date_sk#30] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(cs_sold_date_sk#28), dynamicpruningexpression(cs_sold_date_sk#28 IN dynamicpruning#22)] +PartitionFilters: [isnotnull(cs_sold_date_sk#30), dynamicpruningexpression(cs_sold_date_sk#30 IN dynamicpruning#23)] PushedFilters: [IsNotNull(cs_bill_customer_sk), IsNotNull(cs_item_sk)] ReadSchema: struct -(34) ColumnarToRow [codegen id : 13] -Input [4]: [cs_bill_customer_sk#25, cs_item_sk#26, cs_quantity#27, cs_sold_date_sk#28] +(37) ColumnarToRow [codegen id : 15] +Input [4]: [cs_bill_customer_sk#27, cs_item_sk#28, cs_quantity#29, cs_sold_date_sk#30] -(35) Filter [codegen id : 13] -Input [4]: [cs_bill_customer_sk#25, cs_item_sk#26, cs_quantity#27, cs_sold_date_sk#28] -Condition : (isnotnull(cs_bill_customer_sk#25) AND isnotnull(cs_item_sk#26)) +(38) Filter [codegen id : 15] +Input [4]: [cs_bill_customer_sk#27, cs_item_sk#28, cs_quantity#29, cs_sold_date_sk#30] +Condition : (isnotnull(cs_bill_customer_sk#27) AND isnotnull(cs_item_sk#28)) -(36) ReusedExchange [Reuses operator id: 56] -Output [1]: [d_date_sk#29] +(39) ReusedExchange [Reuses operator id: 59] +Output [1]: [d_date_sk#31] -(37) BroadcastHashJoin [codegen id : 13] -Left keys [1]: [cs_sold_date_sk#28] -Right keys [1]: [d_date_sk#29] +(40) BroadcastHashJoin [codegen id : 15] +Left keys [1]: [cs_sold_date_sk#30] +Right keys [1]: [d_date_sk#31] Join condition: None -(38) Project [codegen id : 13] -Output [3]: [cs_bill_customer_sk#25, cs_item_sk#26, cs_quantity#27] -Input [5]: [cs_bill_customer_sk#25, cs_item_sk#26, cs_quantity#27, cs_sold_date_sk#28, 
d_date_sk#29] +(41) Project [codegen id : 15] +Output [3]: [cs_bill_customer_sk#27, cs_item_sk#28, cs_quantity#29] +Input [5]: [cs_bill_customer_sk#27, cs_item_sk#28, cs_quantity#29, cs_sold_date_sk#30, d_date_sk#31] -(39) Exchange -Input [3]: [cs_bill_customer_sk#25, cs_item_sk#26, cs_quantity#27] -Arguments: hashpartitioning(cs_item_sk#26, 5), ENSURE_REQUIREMENTS, [id=#30] +(42) Exchange +Input [3]: [cs_bill_customer_sk#27, cs_item_sk#28, cs_quantity#29] +Arguments: hashpartitioning(cs_bill_customer_sk#27, cs_item_sk#28, 5), ENSURE_REQUIREMENTS, [id=#32] -(40) Sort [codegen id : 14] -Input [3]: [cs_bill_customer_sk#25, cs_item_sk#26, cs_quantity#27] -Arguments: [cs_bill_customer_sk#25 ASC NULLS FIRST, cs_item_sk#26 ASC NULLS FIRST], false, 0 +(43) Sort [codegen id : 16] +Input [3]: [cs_bill_customer_sk#27, cs_item_sk#28, cs_quantity#29] +Arguments: [cs_bill_customer_sk#27 ASC NULLS FIRST, cs_item_sk#28 ASC NULLS FIRST], false, 0 -(41) SortMergeJoin [codegen id : 15] -Left keys [2]: [sr_customer_sk#18, sr_item_sk#17] -Right keys [2]: [cs_bill_customer_sk#25, cs_item_sk#26] +(44) SortMergeJoin [codegen id : 17] +Left keys [2]: [sr_customer_sk#19, sr_item_sk#18] +Right keys [2]: [cs_bill_customer_sk#27, cs_item_sk#28] Join condition: None -(42) Project [codegen id : 15] -Output [6]: [ss_quantity#5, sr_return_quantity#20, cs_quantity#27, s_state#10, i_item_id#14, i_item_desc#15] -Input [10]: [ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15, sr_item_sk#17, sr_customer_sk#18, sr_return_quantity#20, cs_bill_customer_sk#25, cs_item_sk#26, cs_quantity#27] +(45) Project [codegen id : 17] +Output [6]: [ss_quantity#5, sr_return_quantity#21, cs_quantity#29, s_state#10, i_item_id#14, i_item_desc#15] +Input [10]: [ss_quantity#5, s_state#10, i_item_id#14, i_item_desc#15, sr_item_sk#18, sr_customer_sk#19, sr_return_quantity#21, cs_bill_customer_sk#27, cs_item_sk#28, cs_quantity#29] -(43) HashAggregate [codegen id : 15] -Input [6]: [ss_quantity#5, sr_return_quantity#20, cs_quantity#27, s_state#10, i_item_id#14, i_item_desc#15] +(46) HashAggregate [codegen id : 17] +Input [6]: [ss_quantity#5, sr_return_quantity#21, cs_quantity#29, s_state#10, i_item_id#14, i_item_desc#15] Keys [3]: [i_item_id#14, i_item_desc#15, s_state#10] -Functions [9]: [partial_count(ss_quantity#5), partial_avg(ss_quantity#5), partial_stddev_samp(cast(ss_quantity#5 as double)), partial_count(sr_return_quantity#20), partial_avg(sr_return_quantity#20), partial_stddev_samp(cast(sr_return_quantity#20 as double)), partial_count(cs_quantity#27), partial_avg(cs_quantity#27), partial_stddev_samp(cast(cs_quantity#27 as double))] -Aggregate Attributes [18]: [count#31, sum#32, count#33, n#34, avg#35, m2#36, count#37, sum#38, count#39, n#40, avg#41, m2#42, count#43, sum#44, count#45, n#46, avg#47, m2#48] -Results [21]: [i_item_id#14, i_item_desc#15, s_state#10, count#49, sum#50, count#51, n#52, avg#53, m2#54, count#55, sum#56, count#57, n#58, avg#59, m2#60, count#61, sum#62, count#63, n#64, avg#65, m2#66] +Functions [9]: [partial_count(ss_quantity#5), partial_avg(ss_quantity#5), partial_stddev_samp(cast(ss_quantity#5 as double)), partial_count(sr_return_quantity#21), partial_avg(sr_return_quantity#21), partial_stddev_samp(cast(sr_return_quantity#21 as double)), partial_count(cs_quantity#29), partial_avg(cs_quantity#29), partial_stddev_samp(cast(cs_quantity#29 as double))] +Aggregate Attributes [18]: [count#33, sum#34, count#35, n#36, avg#37, m2#38, count#39, sum#40, count#41, n#42, avg#43, m2#44, count#45, sum#46, count#47, n#48, avg#49, 
m2#50] +Results [21]: [i_item_id#14, i_item_desc#15, s_state#10, count#51, sum#52, count#53, n#54, avg#55, m2#56, count#57, sum#58, count#59, n#60, avg#61, m2#62, count#63, sum#64, count#65, n#66, avg#67, m2#68] -(44) Exchange -Input [21]: [i_item_id#14, i_item_desc#15, s_state#10, count#49, sum#50, count#51, n#52, avg#53, m2#54, count#55, sum#56, count#57, n#58, avg#59, m2#60, count#61, sum#62, count#63, n#64, avg#65, m2#66] -Arguments: hashpartitioning(i_item_id#14, i_item_desc#15, s_state#10, 5), ENSURE_REQUIREMENTS, [id=#67] +(47) Exchange +Input [21]: [i_item_id#14, i_item_desc#15, s_state#10, count#51, sum#52, count#53, n#54, avg#55, m2#56, count#57, sum#58, count#59, n#60, avg#61, m2#62, count#63, sum#64, count#65, n#66, avg#67, m2#68] +Arguments: hashpartitioning(i_item_id#14, i_item_desc#15, s_state#10, 5), ENSURE_REQUIREMENTS, [id=#69] -(45) HashAggregate [codegen id : 16] -Input [21]: [i_item_id#14, i_item_desc#15, s_state#10, count#49, sum#50, count#51, n#52, avg#53, m2#54, count#55, sum#56, count#57, n#58, avg#59, m2#60, count#61, sum#62, count#63, n#64, avg#65, m2#66] +(48) HashAggregate [codegen id : 18] +Input [21]: [i_item_id#14, i_item_desc#15, s_state#10, count#51, sum#52, count#53, n#54, avg#55, m2#56, count#57, sum#58, count#59, n#60, avg#61, m2#62, count#63, sum#64, count#65, n#66, avg#67, m2#68] Keys [3]: [i_item_id#14, i_item_desc#15, s_state#10] -Functions [9]: [count(ss_quantity#5), avg(ss_quantity#5), stddev_samp(cast(ss_quantity#5 as double)), count(sr_return_quantity#20), avg(sr_return_quantity#20), stddev_samp(cast(sr_return_quantity#20 as double)), count(cs_quantity#27), avg(cs_quantity#27), stddev_samp(cast(cs_quantity#27 as double))] -Aggregate Attributes [9]: [count(ss_quantity#5)#68, avg(ss_quantity#5)#69, stddev_samp(cast(ss_quantity#5 as double))#70, count(sr_return_quantity#20)#71, avg(sr_return_quantity#20)#72, stddev_samp(cast(sr_return_quantity#20 as double))#73, count(cs_quantity#27)#74, avg(cs_quantity#27)#75, stddev_samp(cast(cs_quantity#27 as double))#76] -Results [15]: [i_item_id#14, i_item_desc#15, s_state#10, count(ss_quantity#5)#68 AS store_sales_quantitycount#77, avg(ss_quantity#5)#69 AS store_sales_quantityave#78, stddev_samp(cast(ss_quantity#5 as double))#70 AS store_sales_quantitystdev#79, (stddev_samp(cast(ss_quantity#5 as double))#70 / avg(ss_quantity#5)#69) AS store_sales_quantitycov#80, count(sr_return_quantity#20)#71 AS as_store_returns_quantitycount#81, avg(sr_return_quantity#20)#72 AS as_store_returns_quantityave#82, stddev_samp(cast(sr_return_quantity#20 as double))#73 AS as_store_returns_quantitystdev#83, (stddev_samp(cast(sr_return_quantity#20 as double))#73 / avg(sr_return_quantity#20)#72) AS store_returns_quantitycov#84, count(cs_quantity#27)#74 AS catalog_sales_quantitycount#85, avg(cs_quantity#27)#75 AS catalog_sales_quantityave#86, (stddev_samp(cast(cs_quantity#27 as double))#76 / avg(cs_quantity#27)#75) AS catalog_sales_quantitystdev#87, (stddev_samp(cast(cs_quantity#27 as double))#76 / avg(cs_quantity#27)#75) AS catalog_sales_quantitycov#88] +Functions [9]: [count(ss_quantity#5), avg(ss_quantity#5), stddev_samp(cast(ss_quantity#5 as double)), count(sr_return_quantity#21), avg(sr_return_quantity#21), stddev_samp(cast(sr_return_quantity#21 as double)), count(cs_quantity#29), avg(cs_quantity#29), stddev_samp(cast(cs_quantity#29 as double))] +Aggregate Attributes [9]: [count(ss_quantity#5)#70, avg(ss_quantity#5)#71, stddev_samp(cast(ss_quantity#5 as double))#72, count(sr_return_quantity#21)#73, avg(sr_return_quantity#21)#74, 
stddev_samp(cast(sr_return_quantity#21 as double))#75, count(cs_quantity#29)#76, avg(cs_quantity#29)#77, stddev_samp(cast(cs_quantity#29 as double))#78] +Results [15]: [i_item_id#14, i_item_desc#15, s_state#10, count(ss_quantity#5)#70 AS store_sales_quantitycount#79, avg(ss_quantity#5)#71 AS store_sales_quantityave#80, stddev_samp(cast(ss_quantity#5 as double))#72 AS store_sales_quantitystdev#81, (stddev_samp(cast(ss_quantity#5 as double))#72 / avg(ss_quantity#5)#71) AS store_sales_quantitycov#82, count(sr_return_quantity#21)#73 AS as_store_returns_quantitycount#83, avg(sr_return_quantity#21)#74 AS as_store_returns_quantityave#84, stddev_samp(cast(sr_return_quantity#21 as double))#75 AS as_store_returns_quantitystdev#85, (stddev_samp(cast(sr_return_quantity#21 as double))#75 / avg(sr_return_quantity#21)#74) AS store_returns_quantitycov#86, count(cs_quantity#29)#76 AS catalog_sales_quantitycount#87, avg(cs_quantity#29)#77 AS catalog_sales_quantityave#88, (stddev_samp(cast(cs_quantity#29 as double))#78 / avg(cs_quantity#29)#77) AS catalog_sales_quantitystdev#89, (stddev_samp(cast(cs_quantity#29 as double))#78 / avg(cs_quantity#29)#77) AS catalog_sales_quantitycov#90] -(46) TakeOrderedAndProject -Input [15]: [i_item_id#14, i_item_desc#15, s_state#10, store_sales_quantitycount#77, store_sales_quantityave#78, store_sales_quantitystdev#79, store_sales_quantitycov#80, as_store_returns_quantitycount#81, as_store_returns_quantityave#82, as_store_returns_quantitystdev#83, store_returns_quantitycov#84, catalog_sales_quantitycount#85, catalog_sales_quantityave#86, catalog_sales_quantitystdev#87, catalog_sales_quantitycov#88] -Arguments: 100, [i_item_id#14 ASC NULLS FIRST, i_item_desc#15 ASC NULLS FIRST, s_state#10 ASC NULLS FIRST], [i_item_id#14, i_item_desc#15, s_state#10, store_sales_quantitycount#77, store_sales_quantityave#78, store_sales_quantitystdev#79, store_sales_quantitycov#80, as_store_returns_quantitycount#81, as_store_returns_quantityave#82, as_store_returns_quantitystdev#83, store_returns_quantitycov#84, catalog_sales_quantitycount#85, catalog_sales_quantityave#86, catalog_sales_quantitystdev#87, catalog_sales_quantitycov#88] +(49) TakeOrderedAndProject +Input [15]: [i_item_id#14, i_item_desc#15, s_state#10, store_sales_quantitycount#79, store_sales_quantityave#80, store_sales_quantitystdev#81, store_sales_quantitycov#82, as_store_returns_quantitycount#83, as_store_returns_quantityave#84, as_store_returns_quantitystdev#85, store_returns_quantitycov#86, catalog_sales_quantitycount#87, catalog_sales_quantityave#88, catalog_sales_quantitystdev#89, catalog_sales_quantitycov#90] +Arguments: 100, [i_item_id#14 ASC NULLS FIRST, i_item_desc#15 ASC NULLS FIRST, s_state#10 ASC NULLS FIRST], [i_item_id#14, i_item_desc#15, s_state#10, store_sales_quantitycount#79, store_sales_quantityave#80, store_sales_quantitystdev#81, store_sales_quantitycov#82, as_store_returns_quantitycount#83, as_store_returns_quantityave#84, as_store_returns_quantitystdev#85, store_returns_quantitycov#86, catalog_sales_quantitycount#87, catalog_sales_quantityave#88, catalog_sales_quantitystdev#89, catalog_sales_quantitycov#90] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#6 IN dynamicpruning#7 -BroadcastExchange (51) -+- * Project (50) - +- * Filter (49) - +- * ColumnarToRow (48) - +- Scan parquet default.date_dim (47) +BroadcastExchange (54) ++- * Project (53) + +- * Filter (52) + +- * ColumnarToRow (51) + +- Scan parquet default.date_dim (50) -(47) Scan parquet 
default.date_dim -Output [2]: [d_date_sk#8, d_quarter_name#89] +(50) Scan parquet default.date_dim +Output [2]: [d_date_sk#8, d_quarter_name#91] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_quarter_name), EqualTo(d_quarter_name,2001Q1), IsNotNull(d_date_sk)] ReadSchema: struct -(48) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#8, d_quarter_name#89] +(51) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#8, d_quarter_name#91] -(49) Filter [codegen id : 1] -Input [2]: [d_date_sk#8, d_quarter_name#89] -Condition : ((isnotnull(d_quarter_name#89) AND (d_quarter_name#89 = 2001Q1)) AND isnotnull(d_date_sk#8)) +(52) Filter [codegen id : 1] +Input [2]: [d_date_sk#8, d_quarter_name#91] +Condition : ((isnotnull(d_quarter_name#91) AND (d_quarter_name#91 = 2001Q1)) AND isnotnull(d_date_sk#8)) -(50) Project [codegen id : 1] +(53) Project [codegen id : 1] Output [1]: [d_date_sk#8] -Input [2]: [d_date_sk#8, d_quarter_name#89] +Input [2]: [d_date_sk#8, d_quarter_name#91] -(51) BroadcastExchange +(54) BroadcastExchange Input [1]: [d_date_sk#8] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#90] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#92] -Subquery:2 Hosting operator id = 23 Hosting Expression = sr_returned_date_sk#21 IN dynamicpruning#22 -BroadcastExchange (56) -+- * Project (55) - +- * Filter (54) - +- * ColumnarToRow (53) - +- Scan parquet default.date_dim (52) +Subquery:2 Hosting operator id = 24 Hosting Expression = sr_returned_date_sk#22 IN dynamicpruning#23 +BroadcastExchange (59) ++- * Project (58) + +- * Filter (57) + +- * ColumnarToRow (56) + +- Scan parquet default.date_dim (55) -(52) Scan parquet default.date_dim -Output [2]: [d_date_sk#23, d_quarter_name#91] +(55) Scan parquet default.date_dim +Output [2]: [d_date_sk#24, d_quarter_name#93] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_quarter_name, [2001Q1,2001Q2,2001Q3]), IsNotNull(d_date_sk)] ReadSchema: struct -(53) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#23, d_quarter_name#91] +(56) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#24, d_quarter_name#93] -(54) Filter [codegen id : 1] -Input [2]: [d_date_sk#23, d_quarter_name#91] -Condition : (d_quarter_name#91 IN (2001Q1,2001Q2,2001Q3) AND isnotnull(d_date_sk#23)) +(57) Filter [codegen id : 1] +Input [2]: [d_date_sk#24, d_quarter_name#93] +Condition : (d_quarter_name#93 IN (2001Q1,2001Q2,2001Q3) AND isnotnull(d_date_sk#24)) -(55) Project [codegen id : 1] -Output [1]: [d_date_sk#23] -Input [2]: [d_date_sk#23, d_quarter_name#91] +(58) Project [codegen id : 1] +Output [1]: [d_date_sk#24] +Input [2]: [d_date_sk#24, d_quarter_name#93] -(56) BroadcastExchange -Input [1]: [d_date_sk#23] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#92] +(59) BroadcastExchange +Input [1]: [d_date_sk#24] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#94] -Subquery:3 Hosting operator id = 33 Hosting Expression = cs_sold_date_sk#28 IN dynamicpruning#22 +Subquery:3 Hosting operator id = 36 Hosting Expression = cs_sold_date_sk#30 IN dynamicpruning#23 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17.sf100/simplified.txt index b00c5da2ef7d0..06c8f7b3912e5 
100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17.sf100/simplified.txt @@ -1,90 +1,97 @@ TakeOrderedAndProject [i_item_id,i_item_desc,s_state,store_sales_quantitycount,store_sales_quantityave,store_sales_quantitystdev,store_sales_quantitycov,as_store_returns_quantitycount,as_store_returns_quantityave,as_store_returns_quantitystdev,store_returns_quantitycov,catalog_sales_quantitycount,catalog_sales_quantityave,catalog_sales_quantitystdev,catalog_sales_quantitycov] - WholeStageCodegen (16) + WholeStageCodegen (18) HashAggregate [i_item_id,i_item_desc,s_state,count,sum,count,n,avg,m2,count,sum,count,n,avg,m2,count,sum,count,n,avg,m2] [count(ss_quantity),avg(ss_quantity),stddev_samp(cast(ss_quantity as double)),count(sr_return_quantity),avg(sr_return_quantity),stddev_samp(cast(sr_return_quantity as double)),count(cs_quantity),avg(cs_quantity),stddev_samp(cast(cs_quantity as double)),store_sales_quantitycount,store_sales_quantityave,store_sales_quantitystdev,store_sales_quantitycov,as_store_returns_quantitycount,as_store_returns_quantityave,as_store_returns_quantitystdev,store_returns_quantitycov,catalog_sales_quantitycount,catalog_sales_quantityave,catalog_sales_quantitystdev,catalog_sales_quantitycov,count,sum,count,n,avg,m2,count,sum,count,n,avg,m2,count,sum,count,n,avg,m2] InputAdapter Exchange [i_item_id,i_item_desc,s_state] #1 - WholeStageCodegen (15) + WholeStageCodegen (17) HashAggregate [i_item_id,i_item_desc,s_state,ss_quantity,sr_return_quantity,cs_quantity] [count,sum,count,n,avg,m2,count,sum,count,n,avg,m2,count,sum,count,n,avg,m2,count,sum,count,n,avg,m2,count,sum,count,n,avg,m2,count,sum,count,n,avg,m2] Project [ss_quantity,sr_return_quantity,cs_quantity,s_state,i_item_id,i_item_desc] SortMergeJoin [sr_customer_sk,sr_item_sk,cs_bill_customer_sk,cs_item_sk] InputAdapter - WholeStageCodegen (11) - Project [ss_quantity,s_state,i_item_id,i_item_desc,sr_item_sk,sr_customer_sk,sr_return_quantity] - SortMergeJoin [ss_customer_sk,ss_item_sk,ss_ticket_number,sr_customer_sk,sr_item_sk,sr_ticket_number] - InputAdapter - WholeStageCodegen (7) - Sort [ss_customer_sk,ss_item_sk,ss_ticket_number] - Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_quantity,s_state,i_item_id,i_item_desc] - SortMergeJoin [ss_item_sk,i_item_sk] - InputAdapter - WholeStageCodegen (4) - Sort [ss_item_sk] - InputAdapter - Exchange [ss_item_sk] #2 - WholeStageCodegen (3) - Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_quantity,s_state] - BroadcastHashJoin [ss_store_sk,s_store_sk] - Project [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_quantity] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_customer_sk,ss_item_sk,ss_ticket_number,ss_store_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_quantity,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #3 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_quarter_name,d_date_sk] + WholeStageCodegen (13) + Sort [sr_customer_sk,sr_item_sk] + InputAdapter + Exchange [sr_customer_sk,sr_item_sk] #2 + WholeStageCodegen (12) + Project [ss_quantity,s_state,i_item_id,i_item_desc,sr_item_sk,sr_customer_sk,sr_return_quantity] + SortMergeJoin [ss_customer_sk,ss_item_sk,ss_ticket_number,sr_customer_sk,sr_item_sk,sr_ticket_number] + InputAdapter + WholeStageCodegen (8) + Sort 
[ss_customer_sk,ss_item_sk,ss_ticket_number] + InputAdapter + Exchange [ss_customer_sk,ss_item_sk,ss_ticket_number] #3 + WholeStageCodegen (7) + Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_quantity,s_state,i_item_id,i_item_desc] + SortMergeJoin [ss_item_sk,i_item_sk] + InputAdapter + WholeStageCodegen (4) + Sort [ss_item_sk] + InputAdapter + Exchange [ss_item_sk] #4 + WholeStageCodegen (3) + Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_quantity,s_state] + BroadcastHashJoin [ss_store_sk,s_store_sk] + Project [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_quantity] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_customer_sk,ss_item_sk,ss_ticket_number,ss_store_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_quantity,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_quarter_name,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_quarter_name] + InputAdapter + ReusedExchange [d_date_sk] #5 + InputAdapter + BroadcastExchange #6 + WholeStageCodegen (2) + Filter [s_store_sk] ColumnarToRow InputAdapter - Scan parquet default.date_dim [d_date_sk,d_quarter_name] + Scan parquet default.store [s_store_sk,s_state] + InputAdapter + WholeStageCodegen (6) + Sort [i_item_sk] InputAdapter - ReusedExchange [d_date_sk] #3 - InputAdapter - BroadcastExchange #4 - WholeStageCodegen (2) - Filter [s_store_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store [s_store_sk,s_state] - InputAdapter - WholeStageCodegen (6) - Sort [i_item_sk] - InputAdapter - Exchange [i_item_sk] #5 - WholeStageCodegen (5) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_item_id,i_item_desc] - InputAdapter - WholeStageCodegen (10) - Sort [sr_customer_sk,sr_item_sk,sr_ticket_number] - InputAdapter - Exchange [sr_item_sk] #6 - WholeStageCodegen (9) - Project [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_return_quantity] - BroadcastHashJoin [sr_returned_date_sk,d_date_sk] - Filter [sr_customer_sk,sr_item_sk,sr_ticket_number] - ColumnarToRow - InputAdapter - Scan parquet default.store_returns [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_return_quantity,sr_returned_date_sk] - SubqueryBroadcast [d_date_sk] #2 - BroadcastExchange #7 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_quarter_name,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_quarter_name] - InputAdapter - ReusedExchange [d_date_sk] #7 + Exchange [i_item_sk] #7 + WholeStageCodegen (5) + Filter [i_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_item_id,i_item_desc] + InputAdapter + WholeStageCodegen (11) + Sort [sr_customer_sk,sr_item_sk,sr_ticket_number] + InputAdapter + Exchange [sr_customer_sk,sr_item_sk,sr_ticket_number] #8 + WholeStageCodegen (10) + Project [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_return_quantity] + BroadcastHashJoin [sr_returned_date_sk,d_date_sk] + Filter [sr_customer_sk,sr_item_sk,sr_ticket_number] + ColumnarToRow + InputAdapter + Scan parquet default.store_returns [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_return_quantity,sr_returned_date_sk] + SubqueryBroadcast [d_date_sk] #2 + BroadcastExchange #9 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_quarter_name,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_quarter_name] 
+ InputAdapter + ReusedExchange [d_date_sk] #9 InputAdapter - WholeStageCodegen (14) + WholeStageCodegen (16) Sort [cs_bill_customer_sk,cs_item_sk] InputAdapter - Exchange [cs_item_sk] #8 - WholeStageCodegen (13) + Exchange [cs_bill_customer_sk,cs_item_sk] #10 + WholeStageCodegen (15) Project [cs_bill_customer_sk,cs_item_sk,cs_quantity] BroadcastHashJoin [cs_sold_date_sk,d_date_sk] Filter [cs_bill_customer_sk,cs_item_sk] @@ -93,4 +100,4 @@ TakeOrderedAndProject [i_item_id,i_item_desc,s_state,store_sales_quantitycount,s Scan parquet default.catalog_sales [cs_bill_customer_sk,cs_item_sk,cs_quantity,cs_sold_date_sk] ReusedSubquery [d_date_sk] #2 InputAdapter - ReusedExchange [d_date_sk] #7 + ReusedExchange [d_date_sk] #9 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q2.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q2.sf100/explain.txt index 33f6c01b4b69b..8f188db553004 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q2.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q2.sf100/explain.txt @@ -195,7 +195,7 @@ Right keys [1]: [(d_week_seq2#63 - 53)] Join condition: None (35) Project [codegen id : 12] -Output [8]: [d_week_seq1#45, round(CheckOverflow((promote_precision(sun_sales1#46) / promote_precision(sun_sales2#64)), DecimalType(37,20), true), 2) AS round((sun_sales1 / sun_sales2), 2)#72, round(CheckOverflow((promote_precision(mon_sales1#47) / promote_precision(mon_sales2#65)), DecimalType(37,20), true), 2) AS round((mon_sales1 / mon_sales2), 2)#73, round(CheckOverflow((promote_precision(tue_sales1#48) / promote_precision(tue_sales2#66)), DecimalType(37,20), true), 2) AS round((tue_sales1 / tue_sales2), 2)#74, round(CheckOverflow((promote_precision(wed_sales1#49) / promote_precision(wed_sales2#67)), DecimalType(37,20), true), 2) AS round((wed_sales1 / wed_sales2), 2)#75, round(CheckOverflow((promote_precision(thu_sales1#50) / promote_precision(thu_sales2#68)), DecimalType(37,20), true), 2) AS round((thu_sales1 / thu_sales2), 2)#76, round(CheckOverflow((promote_precision(fri_sales1#51) / promote_precision(fri_sales2#69)), DecimalType(37,20), true), 2) AS round((fri_sales1 / fri_sales2), 2)#77, round(CheckOverflow((promote_precision(sat_sales1#52) / promote_precision(sat_sales2#70)), DecimalType(37,20), true), 2) AS round((sat_sales1 / sat_sales2), 2)#78] +Output [8]: [d_week_seq1#45, round(CheckOverflow((promote_precision(sun_sales1#46) / promote_precision(sun_sales2#64)), DecimalType(37,20)), 2) AS round((sun_sales1 / sun_sales2), 2)#72, round(CheckOverflow((promote_precision(mon_sales1#47) / promote_precision(mon_sales2#65)), DecimalType(37,20)), 2) AS round((mon_sales1 / mon_sales2), 2)#73, round(CheckOverflow((promote_precision(tue_sales1#48) / promote_precision(tue_sales2#66)), DecimalType(37,20)), 2) AS round((tue_sales1 / tue_sales2), 2)#74, round(CheckOverflow((promote_precision(wed_sales1#49) / promote_precision(wed_sales2#67)), DecimalType(37,20)), 2) AS round((wed_sales1 / wed_sales2), 2)#75, round(CheckOverflow((promote_precision(thu_sales1#50) / promote_precision(thu_sales2#68)), DecimalType(37,20)), 2) AS round((thu_sales1 / thu_sales2), 2)#76, round(CheckOverflow((promote_precision(fri_sales1#51) / promote_precision(fri_sales2#69)), DecimalType(37,20)), 2) AS round((fri_sales1 / fri_sales2), 2)#77, round(CheckOverflow((promote_precision(sat_sales1#52) / promote_precision(sat_sales2#70)), DecimalType(37,20)), 2) AS 
round((sat_sales1 / sat_sales2), 2)#78] Input [16]: [d_week_seq1#45, sun_sales1#46, mon_sales1#47, tue_sales1#48, wed_sales1#49, thu_sales1#50, fri_sales1#51, sat_sales1#52, d_week_seq2#63, sun_sales2#64, mon_sales2#65, tue_sales2#66, wed_sales2#67, thu_sales2#68, fri_sales2#69, sat_sales2#70] (36) Exchange diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q2/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q2/explain.txt index 33f6c01b4b69b..8f188db553004 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q2/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q2/explain.txt @@ -195,7 +195,7 @@ Right keys [1]: [(d_week_seq2#63 - 53)] Join condition: None (35) Project [codegen id : 12] -Output [8]: [d_week_seq1#45, round(CheckOverflow((promote_precision(sun_sales1#46) / promote_precision(sun_sales2#64)), DecimalType(37,20), true), 2) AS round((sun_sales1 / sun_sales2), 2)#72, round(CheckOverflow((promote_precision(mon_sales1#47) / promote_precision(mon_sales2#65)), DecimalType(37,20), true), 2) AS round((mon_sales1 / mon_sales2), 2)#73, round(CheckOverflow((promote_precision(tue_sales1#48) / promote_precision(tue_sales2#66)), DecimalType(37,20), true), 2) AS round((tue_sales1 / tue_sales2), 2)#74, round(CheckOverflow((promote_precision(wed_sales1#49) / promote_precision(wed_sales2#67)), DecimalType(37,20), true), 2) AS round((wed_sales1 / wed_sales2), 2)#75, round(CheckOverflow((promote_precision(thu_sales1#50) / promote_precision(thu_sales2#68)), DecimalType(37,20), true), 2) AS round((thu_sales1 / thu_sales2), 2)#76, round(CheckOverflow((promote_precision(fri_sales1#51) / promote_precision(fri_sales2#69)), DecimalType(37,20), true), 2) AS round((fri_sales1 / fri_sales2), 2)#77, round(CheckOverflow((promote_precision(sat_sales1#52) / promote_precision(sat_sales2#70)), DecimalType(37,20), true), 2) AS round((sat_sales1 / sat_sales2), 2)#78] +Output [8]: [d_week_seq1#45, round(CheckOverflow((promote_precision(sun_sales1#46) / promote_precision(sun_sales2#64)), DecimalType(37,20)), 2) AS round((sun_sales1 / sun_sales2), 2)#72, round(CheckOverflow((promote_precision(mon_sales1#47) / promote_precision(mon_sales2#65)), DecimalType(37,20)), 2) AS round((mon_sales1 / mon_sales2), 2)#73, round(CheckOverflow((promote_precision(tue_sales1#48) / promote_precision(tue_sales2#66)), DecimalType(37,20)), 2) AS round((tue_sales1 / tue_sales2), 2)#74, round(CheckOverflow((promote_precision(wed_sales1#49) / promote_precision(wed_sales2#67)), DecimalType(37,20)), 2) AS round((wed_sales1 / wed_sales2), 2)#75, round(CheckOverflow((promote_precision(thu_sales1#50) / promote_precision(thu_sales2#68)), DecimalType(37,20)), 2) AS round((thu_sales1 / thu_sales2), 2)#76, round(CheckOverflow((promote_precision(fri_sales1#51) / promote_precision(fri_sales2#69)), DecimalType(37,20)), 2) AS round((fri_sales1 / fri_sales2), 2)#77, round(CheckOverflow((promote_precision(sat_sales1#52) / promote_precision(sat_sales2#70)), DecimalType(37,20)), 2) AS round((sat_sales1 / sat_sales2), 2)#78] Input [16]: [d_week_seq1#45, sun_sales1#46, mon_sales1#47, tue_sales1#48, wed_sales1#49, thu_sales1#50, fri_sales1#51, sat_sales1#52, d_week_seq2#63, sun_sales2#64, mon_sales2#65, tue_sales2#66, wed_sales2#67, thu_sales2#68, fri_sales2#69, sat_sales2#70] (36) Exchange diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q20.sf100/explain.txt 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q20.sf100/explain.txt index d50622c2464ea..09e4cd2a57054 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q20.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q20.sf100/explain.txt @@ -121,7 +121,7 @@ Input [8]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrev Arguments: [sum(_w1#20) windowspecdefinition(i_class#10, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#22], [i_class#10] (22) Project [codegen id : 9] -Output [7]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17), true) AS revenueratio#23, i_item_id#7] +Output [7]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17)) AS revenueratio#23, i_item_id#7] Input [9]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, _w0#19, _w1#20, i_item_id#7, _we0#22] (23) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q20/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q20/explain.txt index b54c704b66c3f..8b9d47316f293 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q20/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q20/explain.txt @@ -106,7 +106,7 @@ Input [8]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemreve Arguments: [sum(_w1#19) windowspecdefinition(i_class#9, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#21], [i_class#9] (19) Project [codegen id : 6] -Output [7]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17), true) AS revenueratio#22, i_item_id#6] +Output [7]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17)) AS revenueratio#22, i_item_id#6] Input [9]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, _w0#18, _w1#19, i_item_id#6, _we0#21] (20) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt index be706fee66776..5bf5193487b07 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt @@ -278,20 +278,20 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# (42) HashAggregate [codegen id : 15] Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29] Keys [1]: [c_customer_sk#29] -Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [2]: [sum#31, isEmpty#32] Results [3]: [c_customer_sk#29, sum#33, isEmpty#34] (43) HashAggregate [codegen id : 15] Input [3]: [c_customer_sk#29, sum#33, isEmpty#34] Keys [1]: [c_customer_sk#29] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35] -Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (44) Filter [codegen id : 15] Input [2]: [c_customer_sk#29, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (45) Project [codegen id : 15] Output [1]: [c_customer_sk#29] @@ -319,7 +319,7 @@ Right keys [1]: [d_date_sk#39] Join condition: None (51) Project [codegen id : 17] -Output [1]: [CheckOverflow((promote_precision(cast(cast(cs_quantity#3 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2), true) AS sales#40] +Output [1]: [CheckOverflow((promote_precision(cast(cs_quantity#3 as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2)) AS sales#40] Input [4]: [cs_quantity#3, cs_list_price#4, cs_sold_date_sk#5, d_date_sk#39] (52) Scan parquet default.web_sales @@ -432,20 +432,20 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# (77) HashAggregate [codegen id : 32] Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29] Keys [1]: [c_customer_sk#29] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#31, isEmpty#32] -Results [3]: [c_customer_sk#29, sum#33, isEmpty#34] 
+Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#49, isEmpty#50] +Results [3]: [c_customer_sk#29, sum#51, isEmpty#52] (78) HashAggregate [codegen id : 32] -Input [3]: [c_customer_sk#29, sum#33, isEmpty#34] +Input [3]: [c_customer_sk#29, sum#51, isEmpty#52] Keys [1]: [c_customer_sk#29] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35] -Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (79) Filter [codegen id : 32] Input [2]: [c_customer_sk#29, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (80) Project [codegen id : 32] Output [1]: [c_customer_sk#29] @@ -465,16 +465,16 @@ Output [3]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45] Input [4]: [ws_bill_customer_sk#42, ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45] (84) ReusedExchange [Reuses operator id: 95] -Output [1]: [d_date_sk#49] +Output [1]: [d_date_sk#53] (85) BroadcastHashJoin [codegen id : 34] Left keys [1]: [ws_sold_date_sk#45] -Right keys [1]: [d_date_sk#49] +Right keys [1]: [d_date_sk#53] Join condition: None (86) Project [codegen id : 34] -Output [1]: [CheckOverflow((promote_precision(cast(cast(ws_quantity#43 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#44 as decimal(12,2)))), DecimalType(18,2), true) AS sales#50] -Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#49] +Output [1]: [CheckOverflow((promote_precision(cast(ws_quantity#43 as decimal(12,2))) * promote_precision(cast(ws_list_price#44 as decimal(12,2)))), DecimalType(18,2)) AS sales#54] +Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#53] (87) Union @@ -482,19 +482,19 @@ Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#49] Input [1]: [sales#40] Keys: [] Functions [1]: [partial_sum(sales#40)] -Aggregate Attributes [2]: [sum#51, isEmpty#52] -Results [2]: 
[sum#53, isEmpty#54] +Aggregate Attributes [2]: [sum#55, isEmpty#56] +Results [2]: [sum#57, isEmpty#58] (89) Exchange -Input [2]: [sum#53, isEmpty#54] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#55] +Input [2]: [sum#57, isEmpty#58] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#59] (90) HashAggregate [codegen id : 36] -Input [2]: [sum#53, isEmpty#54] +Input [2]: [sum#57, isEmpty#58] Keys: [] Functions [1]: [sum(sales#40)] -Aggregate Attributes [1]: [sum(sales#40)#56] -Results [1]: [sum(sales#40)#56 AS sum(sales)#57] +Aggregate Attributes [1]: [sum(sales#40)#60] +Results [1]: [sum(sales#40)#60 AS sum(sales)#61] ===== Subqueries ===== @@ -507,26 +507,26 @@ BroadcastExchange (95) (91) Scan parquet default.date_dim -Output [3]: [d_date_sk#39, d_year#58, d_moy#59] +Output [3]: [d_date_sk#39, d_year#62, d_moy#63] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2000), EqualTo(d_moy,2), IsNotNull(d_date_sk)] ReadSchema: struct (92) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#39, d_year#58, d_moy#59] +Input [3]: [d_date_sk#39, d_year#62, d_moy#63] (93) Filter [codegen id : 1] -Input [3]: [d_date_sk#39, d_year#58, d_moy#59] -Condition : ((((isnotnull(d_year#58) AND isnotnull(d_moy#59)) AND (d_year#58 = 2000)) AND (d_moy#59 = 2)) AND isnotnull(d_date_sk#39)) +Input [3]: [d_date_sk#39, d_year#62, d_moy#63] +Condition : ((((isnotnull(d_year#62) AND isnotnull(d_moy#63)) AND (d_year#62 = 2000)) AND (d_moy#63 = 2)) AND isnotnull(d_date_sk#39)) (94) Project [codegen id : 1] Output [1]: [d_date_sk#39] -Input [3]: [d_date_sk#39, d_year#58, d_moy#59] +Input [3]: [d_date_sk#39, d_year#62, d_moy#63] (95) BroadcastExchange Input [1]: [d_date_sk#39] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#60] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#64] Subquery:2 Hosting operator id = 5 Hosting Expression = ss_sold_date_sk#9 IN dynamicpruning#10 BroadcastExchange (100) @@ -537,26 +537,26 @@ BroadcastExchange (100) (96) Scan parquet default.date_dim -Output [3]: [d_date_sk#11, d_date#12, d_year#61] +Output [3]: [d_date_sk#11, d_date#12, d_year#65] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (97) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#11, d_date#12, d_year#61] +Input [3]: [d_date_sk#11, d_date#12, d_year#65] (98) Filter [codegen id : 1] -Input [3]: [d_date_sk#11, d_date#12, d_year#61] -Condition : (d_year#61 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11)) +Input [3]: [d_date_sk#11, d_date#12, d_year#65] +Condition : (d_year#65 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11)) (99) Project [codegen id : 1] Output [2]: [d_date_sk#11, d_date#12] -Input [3]: [d_date_sk#11, d_date#12, d_year#61] +Input [3]: [d_date_sk#11, d_date#12, d_year#65] (100) BroadcastExchange Input [2]: [d_date_sk#11, d_date#12] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#62] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#66] Subquery:3 Hosting operator id = 44 Hosting Expression = Subquery scalar-subquery#37, [id=#38] * HashAggregate (117) @@ -579,89 +579,89 @@ Subquery:3 Hosting operator id = 44 Hosting Expression = Subquery scalar-subquer (101) Scan parquet default.store_sales 
-Output [4]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65, ss_sold_date_sk#66] +Output [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, ss_sold_date_sk#70] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#66), dynamicpruningexpression(ss_sold_date_sk#66 IN dynamicpruning#67)] +PartitionFilters: [isnotnull(ss_sold_date_sk#70), dynamicpruningexpression(ss_sold_date_sk#70 IN dynamicpruning#71)] PushedFilters: [IsNotNull(ss_customer_sk)] ReadSchema: struct (102) ColumnarToRow [codegen id : 2] -Input [4]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65, ss_sold_date_sk#66] +Input [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, ss_sold_date_sk#70] (103) Filter [codegen id : 2] -Input [4]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65, ss_sold_date_sk#66] -Condition : isnotnull(ss_customer_sk#63) +Input [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, ss_sold_date_sk#70] +Condition : isnotnull(ss_customer_sk#67) (104) ReusedExchange [Reuses operator id: 122] -Output [1]: [d_date_sk#68] +Output [1]: [d_date_sk#72] (105) BroadcastHashJoin [codegen id : 2] -Left keys [1]: [ss_sold_date_sk#66] -Right keys [1]: [d_date_sk#68] +Left keys [1]: [ss_sold_date_sk#70] +Right keys [1]: [d_date_sk#72] Join condition: None (106) Project [codegen id : 2] -Output [3]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65] -Input [5]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65, ss_sold_date_sk#66, d_date_sk#68] +Output [3]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69] +Input [5]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, ss_sold_date_sk#70, d_date_sk#72] (107) Exchange -Input [3]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65] -Arguments: hashpartitioning(ss_customer_sk#63, 5), ENSURE_REQUIREMENTS, [id=#69] +Input [3]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69] +Arguments: hashpartitioning(ss_customer_sk#67, 5), ENSURE_REQUIREMENTS, [id=#73] (108) Sort [codegen id : 3] -Input [3]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65] -Arguments: [ss_customer_sk#63 ASC NULLS FIRST], false, 0 +Input [3]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69] +Arguments: [ss_customer_sk#67 ASC NULLS FIRST], false, 0 (109) ReusedExchange [Reuses operator id: 38] -Output [1]: [c_customer_sk#70] +Output [1]: [c_customer_sk#74] (110) Sort [codegen id : 5] -Input [1]: [c_customer_sk#70] -Arguments: [c_customer_sk#70 ASC NULLS FIRST], false, 0 +Input [1]: [c_customer_sk#74] +Arguments: [c_customer_sk#74 ASC NULLS FIRST], false, 0 (111) SortMergeJoin [codegen id : 6] -Left keys [1]: [ss_customer_sk#63] -Right keys [1]: [c_customer_sk#70] +Left keys [1]: [ss_customer_sk#67] +Right keys [1]: [c_customer_sk#74] Join condition: None (112) Project [codegen id : 6] -Output [3]: [ss_quantity#64, ss_sales_price#65, c_customer_sk#70] -Input [4]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65, c_customer_sk#70] +Output [3]: [ss_quantity#68, ss_sales_price#69, c_customer_sk#74] +Input [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, c_customer_sk#74] (113) HashAggregate [codegen id : 6] -Input [3]: [ss_quantity#64, ss_sales_price#65, c_customer_sk#70] -Keys [1]: [c_customer_sk#70] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#65 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#71, isEmpty#72] -Results [3]: 
[c_customer_sk#70, sum#73, isEmpty#74] +Input [3]: [ss_quantity#68, ss_sales_price#69, c_customer_sk#74] +Keys [1]: [c_customer_sk#74] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#75, isEmpty#76] +Results [3]: [c_customer_sk#74, sum#77, isEmpty#78] (114) HashAggregate [codegen id : 6] -Input [3]: [c_customer_sk#70, sum#73, isEmpty#74] -Keys [1]: [c_customer_sk#70] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#65 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#65 as decimal(12,2)))), DecimalType(18,2), true))#75] -Results [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#65 as decimal(12,2)))), DecimalType(18,2), true))#75 AS csales#76] +Input [3]: [c_customer_sk#74, sum#77, isEmpty#78] +Keys [1]: [c_customer_sk#74] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), DecimalType(18,2)))#79] +Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), DecimalType(18,2)))#79 AS csales#80] (115) HashAggregate [codegen id : 6] -Input [1]: [csales#76] +Input [1]: [csales#80] Keys: [] -Functions [1]: [partial_max(csales#76)] -Aggregate Attributes [1]: [max#77] -Results [1]: [max#78] +Functions [1]: [partial_max(csales#80)] +Aggregate Attributes [1]: [max#81] +Results [1]: [max#82] (116) Exchange -Input [1]: [max#78] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#79] +Input [1]: [max#82] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#83] (117) HashAggregate [codegen id : 7] -Input [1]: [max#78] +Input [1]: [max#82] Keys: [] -Functions [1]: [max(csales#76)] -Aggregate Attributes [1]: [max(csales#76)#80] -Results [1]: [max(csales#76)#80 AS tpcds_cmax#81] +Functions [1]: [max(csales#80)] +Aggregate Attributes [1]: [max(csales#80)#84] +Results [1]: [max(csales#80)#84 AS tpcds_cmax#85] -Subquery:4 Hosting operator id = 101 Hosting Expression = ss_sold_date_sk#66 IN dynamicpruning#67 +Subquery:4 Hosting operator id = 101 Hosting Expression = ss_sold_date_sk#70 IN dynamicpruning#71 BroadcastExchange (122) +- * Project (121) +- * Filter (120) @@ -670,26 +670,26 @@ BroadcastExchange (122) (118) Scan parquet default.date_dim -Output [2]: [d_date_sk#68, d_year#82] +Output [2]: [d_date_sk#72, d_year#86] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (119) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#68, d_year#82] +Input [2]: [d_date_sk#72, d_year#86] (120) Filter [codegen id : 1] -Input [2]: [d_date_sk#68, d_year#82] -Condition : (d_year#82 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#68)) +Input [2]: [d_date_sk#72, d_year#86] +Condition : 
(d_year#86 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#72)) (121) Project [codegen id : 1] -Output [1]: [d_date_sk#68] -Input [2]: [d_date_sk#68, d_year#82] +Output [1]: [d_date_sk#72] +Input [2]: [d_date_sk#72, d_year#86] (122) BroadcastExchange -Input [1]: [d_date_sk#68] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#83] +Input [1]: [d_date_sk#72] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#87] Subquery:5 Hosting operator id = 52 Hosting Expression = ws_sold_date_sk#45 IN dynamicpruning#6 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/simplified.txt index 17377b91326fd..0683b263ea290 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/simplified.txt @@ -89,7 +89,7 @@ WholeStageCodegen (36) Exchange #10 WholeStageCodegen (6) HashAggregate [csales] [max,max] - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),csales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),csales,sum,isEmpty] HashAggregate [c_customer_sk,ss_quantity,ss_sales_price] [sum,isEmpty,sum,isEmpty] Project [ss_quantity,ss_sales_price,c_customer_sk] SortMergeJoin [ss_customer_sk,c_customer_sk] @@ -120,7 +120,7 @@ WholeStageCodegen (36) Sort [c_customer_sk] InputAdapter ReusedExchange [c_customer_sk] #9 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] HashAggregate [c_customer_sk,ss_quantity,ss_sales_price] [sum,isEmpty,sum,isEmpty] Project [ss_quantity,ss_sales_price,c_customer_sk] SortMergeJoin [ss_customer_sk,c_customer_sk] @@ -195,7 +195,7 @@ WholeStageCodegen (36) Project [c_customer_sk] Filter [ssales] ReusedSubquery [tpcds_cmax] #3 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] HashAggregate [c_customer_sk,ss_quantity,ss_sales_price] [sum,isEmpty,sum,isEmpty] Project [ss_quantity,ss_sales_price,c_customer_sk] SortMergeJoin [ss_customer_sk,c_customer_sk] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a/explain.txt index 1de23e1f4d2ab..58d6c22f3fd05 100644 --- 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a/explain.txt @@ -226,7 +226,7 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# (35) HashAggregate [codegen id : 8] Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#28] Keys [1]: [c_customer_sk#28] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [2]: [sum#30, isEmpty#31] Results [3]: [c_customer_sk#28, sum#32, isEmpty#33] @@ -237,13 +237,13 @@ Arguments: hashpartitioning(c_customer_sk#28, 5), ENSURE_REQUIREMENTS, [id=#34] (37) HashAggregate [codegen id : 9] Input [3]: [c_customer_sk#28, sum#32, isEmpty#33] Keys [1]: [c_customer_sk#28] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35] -Results [2]: [c_customer_sk#28, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#28, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (38) Filter [codegen id : 9] Input [2]: [c_customer_sk#28, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (39) Project [codegen id : 9] Output [1]: [c_customer_sk#28] @@ -271,7 +271,7 @@ Right keys [1]: [d_date_sk#39] Join condition: None (45) Project [codegen id : 11] -Output [1]: [CheckOverflow((promote_precision(cast(cast(cs_quantity#3 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2), true) AS sales#40] +Output [1]: [CheckOverflow((promote_precision(cast(cs_quantity#3 as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2)) AS sales#40] Input [4]: [cs_quantity#3, cs_list_price#4, cs_sold_date_sk#5, d_date_sk#39] (46) Scan parquet default.web_sales @@ -305,18 +305,18 @@ Input 
[4]: [ws_bill_customer_sk#42, ws_quantity#43, ws_list_price#44, ws_sold_da Arguments: [ws_bill_customer_sk#42 ASC NULLS FIRST], false, 0 (53) ReusedExchange [Reuses operator id: 36] -Output [3]: [c_customer_sk#28, sum#32, isEmpty#33] +Output [3]: [c_customer_sk#28, sum#47, isEmpty#48] (54) HashAggregate [codegen id : 20] -Input [3]: [c_customer_sk#28, sum#32, isEmpty#33] +Input [3]: [c_customer_sk#28, sum#47, isEmpty#48] Keys [1]: [c_customer_sk#28] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35] -Results [2]: [c_customer_sk#28, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#28, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (55) Filter [codegen id : 20] Input [2]: [c_customer_sk#28, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (56) Project [codegen id : 20] Output [1]: [c_customer_sk#28] @@ -336,16 +336,16 @@ Output [3]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45] Input [4]: [ws_bill_customer_sk#42, ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45] (60) ReusedExchange [Reuses operator id: 71] -Output [1]: [d_date_sk#47] +Output [1]: [d_date_sk#49] (61) BroadcastHashJoin [codegen id : 22] Left keys [1]: [ws_sold_date_sk#45] -Right keys [1]: [d_date_sk#47] +Right keys [1]: [d_date_sk#49] Join condition: None (62) Project [codegen id : 22] -Output [1]: [CheckOverflow((promote_precision(cast(cast(ws_quantity#43 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#44 as decimal(12,2)))), DecimalType(18,2), true) AS sales#48] -Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#47] +Output [1]: [CheckOverflow((promote_precision(cast(ws_quantity#43 as decimal(12,2))) * promote_precision(cast(ws_list_price#44 as decimal(12,2)))), DecimalType(18,2)) AS sales#50] +Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#49] (63) Union @@ -353,19 +353,19 @@ Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#47] Input [1]: [sales#40] Keys: [] Functions [1]: [partial_sum(sales#40)] -Aggregate Attributes [2]: [sum#49, isEmpty#50] -Results [2]: 
[sum#51, isEmpty#52] +Aggregate Attributes [2]: [sum#51, isEmpty#52] +Results [2]: [sum#53, isEmpty#54] (65) Exchange -Input [2]: [sum#51, isEmpty#52] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#53] +Input [2]: [sum#53, isEmpty#54] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#55] (66) HashAggregate [codegen id : 24] -Input [2]: [sum#51, isEmpty#52] +Input [2]: [sum#53, isEmpty#54] Keys: [] Functions [1]: [sum(sales#40)] -Aggregate Attributes [1]: [sum(sales#40)#54] -Results [1]: [sum(sales#40)#54 AS sum(sales)#55] +Aggregate Attributes [1]: [sum(sales#40)#56] +Results [1]: [sum(sales#40)#56 AS sum(sales)#57] ===== Subqueries ===== @@ -378,26 +378,26 @@ BroadcastExchange (71) (67) Scan parquet default.date_dim -Output [3]: [d_date_sk#39, d_year#56, d_moy#57] +Output [3]: [d_date_sk#39, d_year#58, d_moy#59] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2000), EqualTo(d_moy,2), IsNotNull(d_date_sk)] ReadSchema: struct (68) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#39, d_year#56, d_moy#57] +Input [3]: [d_date_sk#39, d_year#58, d_moy#59] (69) Filter [codegen id : 1] -Input [3]: [d_date_sk#39, d_year#56, d_moy#57] -Condition : ((((isnotnull(d_year#56) AND isnotnull(d_moy#57)) AND (d_year#56 = 2000)) AND (d_moy#57 = 2)) AND isnotnull(d_date_sk#39)) +Input [3]: [d_date_sk#39, d_year#58, d_moy#59] +Condition : ((((isnotnull(d_year#58) AND isnotnull(d_moy#59)) AND (d_year#58 = 2000)) AND (d_moy#59 = 2)) AND isnotnull(d_date_sk#39)) (70) Project [codegen id : 1] Output [1]: [d_date_sk#39] -Input [3]: [d_date_sk#39, d_year#56, d_moy#57] +Input [3]: [d_date_sk#39, d_year#58, d_moy#59] (71) BroadcastExchange Input [1]: [d_date_sk#39] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#58] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#60] Subquery:2 Hosting operator id = 3 Hosting Expression = ss_sold_date_sk#8 IN dynamicpruning#9 BroadcastExchange (76) @@ -408,26 +408,26 @@ BroadcastExchange (76) (72) Scan parquet default.date_dim -Output [3]: [d_date_sk#10, d_date#11, d_year#59] +Output [3]: [d_date_sk#10, d_date#11, d_year#61] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (73) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#10, d_date#11, d_year#59] +Input [3]: [d_date_sk#10, d_date#11, d_year#61] (74) Filter [codegen id : 1] -Input [3]: [d_date_sk#10, d_date#11, d_year#59] -Condition : (d_year#59 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#10)) +Input [3]: [d_date_sk#10, d_date#11, d_year#61] +Condition : (d_year#61 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#10)) (75) Project [codegen id : 1] Output [2]: [d_date_sk#10, d_date#11] -Input [3]: [d_date_sk#10, d_date#11, d_year#59] +Input [3]: [d_date_sk#10, d_date#11, d_year#61] (76) BroadcastExchange Input [2]: [d_date_sk#10, d_date#11] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#60] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#62] Subquery:3 Hosting operator id = 38 Hosting Expression = Subquery scalar-subquery#37, [id=#38] * HashAggregate (91) @@ -448,81 +448,81 @@ Subquery:3 Hosting operator id = 38 Hosting Expression = Subquery scalar-subquer (77) Scan parquet default.store_sales -Output 
[4]: [ss_customer_sk#61, ss_quantity#62, ss_sales_price#63, ss_sold_date_sk#64] +Output [4]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65, ss_sold_date_sk#66] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#64), dynamicpruningexpression(ss_sold_date_sk#64 IN dynamicpruning#65)] +PartitionFilters: [isnotnull(ss_sold_date_sk#66), dynamicpruningexpression(ss_sold_date_sk#66 IN dynamicpruning#67)] PushedFilters: [IsNotNull(ss_customer_sk)] ReadSchema: struct (78) ColumnarToRow [codegen id : 3] -Input [4]: [ss_customer_sk#61, ss_quantity#62, ss_sales_price#63, ss_sold_date_sk#64] +Input [4]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65, ss_sold_date_sk#66] (79) Filter [codegen id : 3] -Input [4]: [ss_customer_sk#61, ss_quantity#62, ss_sales_price#63, ss_sold_date_sk#64] -Condition : isnotnull(ss_customer_sk#61) +Input [4]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65, ss_sold_date_sk#66] +Condition : isnotnull(ss_customer_sk#63) (80) ReusedExchange [Reuses operator id: 32] -Output [1]: [c_customer_sk#66] +Output [1]: [c_customer_sk#68] (81) BroadcastHashJoin [codegen id : 3] -Left keys [1]: [ss_customer_sk#61] -Right keys [1]: [c_customer_sk#66] +Left keys [1]: [ss_customer_sk#63] +Right keys [1]: [c_customer_sk#68] Join condition: None (82) Project [codegen id : 3] -Output [4]: [ss_quantity#62, ss_sales_price#63, ss_sold_date_sk#64, c_customer_sk#66] -Input [5]: [ss_customer_sk#61, ss_quantity#62, ss_sales_price#63, ss_sold_date_sk#64, c_customer_sk#66] +Output [4]: [ss_quantity#64, ss_sales_price#65, ss_sold_date_sk#66, c_customer_sk#68] +Input [5]: [ss_customer_sk#63, ss_quantity#64, ss_sales_price#65, ss_sold_date_sk#66, c_customer_sk#68] (83) ReusedExchange [Reuses operator id: 96] -Output [1]: [d_date_sk#67] +Output [1]: [d_date_sk#69] (84) BroadcastHashJoin [codegen id : 3] -Left keys [1]: [ss_sold_date_sk#64] -Right keys [1]: [d_date_sk#67] +Left keys [1]: [ss_sold_date_sk#66] +Right keys [1]: [d_date_sk#69] Join condition: None (85) Project [codegen id : 3] -Output [3]: [ss_quantity#62, ss_sales_price#63, c_customer_sk#66] -Input [5]: [ss_quantity#62, ss_sales_price#63, ss_sold_date_sk#64, c_customer_sk#66, d_date_sk#67] +Output [3]: [ss_quantity#64, ss_sales_price#65, c_customer_sk#68] +Input [5]: [ss_quantity#64, ss_sales_price#65, ss_sold_date_sk#66, c_customer_sk#68, d_date_sk#69] (86) HashAggregate [codegen id : 3] -Input [3]: [ss_quantity#62, ss_sales_price#63, c_customer_sk#66] -Keys [1]: [c_customer_sk#66] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#62 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#63 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#68, isEmpty#69] -Results [3]: [c_customer_sk#66, sum#70, isEmpty#71] +Input [3]: [ss_quantity#64, ss_sales_price#65, c_customer_sk#68] +Keys [1]: [c_customer_sk#68] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_sales_price#65 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#70, isEmpty#71] +Results [3]: [c_customer_sk#68, sum#72, isEmpty#73] (87) Exchange -Input [3]: [c_customer_sk#66, sum#70, isEmpty#71] -Arguments: hashpartitioning(c_customer_sk#66, 5), ENSURE_REQUIREMENTS, [id=#72] +Input [3]: [c_customer_sk#68, sum#72, isEmpty#73] +Arguments: hashpartitioning(c_customer_sk#68, 5), ENSURE_REQUIREMENTS, [id=#74] (88) HashAggregate [codegen id : 4] -Input 
[3]: [c_customer_sk#66, sum#70, isEmpty#71] -Keys [1]: [c_customer_sk#66] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#62 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#63 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#62 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#63 as decimal(12,2)))), DecimalType(18,2), true))#73] -Results [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#62 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#63 as decimal(12,2)))), DecimalType(18,2), true))#73 AS csales#74] +Input [3]: [c_customer_sk#68, sum#72, isEmpty#73] +Keys [1]: [c_customer_sk#68] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_sales_price#65 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_sales_price#65 as decimal(12,2)))), DecimalType(18,2)))#75] +Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_sales_price#65 as decimal(12,2)))), DecimalType(18,2)))#75 AS csales#76] (89) HashAggregate [codegen id : 4] -Input [1]: [csales#74] +Input [1]: [csales#76] Keys: [] -Functions [1]: [partial_max(csales#74)] -Aggregate Attributes [1]: [max#75] -Results [1]: [max#76] +Functions [1]: [partial_max(csales#76)] +Aggregate Attributes [1]: [max#77] +Results [1]: [max#78] (90) Exchange -Input [1]: [max#76] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#77] +Input [1]: [max#78] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#79] (91) HashAggregate [codegen id : 5] -Input [1]: [max#76] +Input [1]: [max#78] Keys: [] -Functions [1]: [max(csales#74)] -Aggregate Attributes [1]: [max(csales#74)#78] -Results [1]: [max(csales#74)#78 AS tpcds_cmax#79] +Functions [1]: [max(csales#76)] +Aggregate Attributes [1]: [max(csales#76)#80] +Results [1]: [max(csales#76)#80 AS tpcds_cmax#81] -Subquery:4 Hosting operator id = 77 Hosting Expression = ss_sold_date_sk#64 IN dynamicpruning#65 +Subquery:4 Hosting operator id = 77 Hosting Expression = ss_sold_date_sk#66 IN dynamicpruning#67 BroadcastExchange (96) +- * Project (95) +- * Filter (94) @@ -531,26 +531,26 @@ BroadcastExchange (96) (92) Scan parquet default.date_dim -Output [2]: [d_date_sk#67, d_year#80] +Output [2]: [d_date_sk#69, d_year#82] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (93) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#67, d_year#80] +Input [2]: [d_date_sk#69, d_year#82] (94) Filter [codegen id : 1] -Input [2]: [d_date_sk#67, d_year#80] -Condition : (d_year#80 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#67)) +Input [2]: [d_date_sk#69, d_year#82] +Condition : (d_year#82 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#69)) (95) Project [codegen id : 1] -Output [1]: [d_date_sk#67] -Input [2]: [d_date_sk#67, d_year#80] +Output [1]: [d_date_sk#69] +Input [2]: [d_date_sk#69, d_year#82] (96) BroadcastExchange -Input [1]: [d_date_sk#67] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#81] +Input [1]: [d_date_sk#69] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#83] Subquery:5 Hosting operator id = 46 Hosting Expression = ws_sold_date_sk#45 IN dynamicpruning#6 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a/simplified.txt index 5c5a8a7fe425f..d38e147d305c7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a/simplified.txt @@ -77,7 +77,7 @@ WholeStageCodegen (24) Exchange #10 WholeStageCodegen (4) HashAggregate [csales] [max,max] - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),csales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),csales,sum,isEmpty] InputAdapter Exchange [c_customer_sk] #11 WholeStageCodegen (3) @@ -102,7 +102,7 @@ WholeStageCodegen (24) ReusedExchange [c_customer_sk] #9 InputAdapter ReusedExchange [d_date_sk] #12 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] InputAdapter Exchange [c_customer_sk] #8 WholeStageCodegen (8) @@ -148,7 +148,7 @@ WholeStageCodegen (24) Project [c_customer_sk] Filter [ssales] ReusedSubquery [tpcds_cmax] #3 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] InputAdapter ReusedExchange [c_customer_sk,sum,isEmpty] #8 InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt index 638f5ec3ded62..3de1f24613451 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt @@ -322,20 +322,20 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# (43) HashAggregate [codegen id : 15] Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29] Keys [1]: [c_customer_sk#29] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), 
DecimalType(18,2)))] Aggregate Attributes [2]: [sum#31, isEmpty#32] Results [3]: [c_customer_sk#29, sum#33, isEmpty#34] (44) HashAggregate [codegen id : 15] Input [3]: [c_customer_sk#29, sum#33, isEmpty#34] Keys [1]: [c_customer_sk#29] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35] -Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (45) Filter [codegen id : 15] Input [2]: [c_customer_sk#29, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (46) Project [codegen id : 15] Output [1]: [c_customer_sk#29] @@ -410,20 +410,20 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# (63) HashAggregate [codegen id : 24] Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29] Keys [1]: [c_customer_sk#29] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [2]: [sum#31, isEmpty#32] Results [3]: [c_customer_sk#29, sum#33, isEmpty#34] (64) HashAggregate [codegen id : 24] Input [3]: [c_customer_sk#29, sum#33, isEmpty#34] Keys [1]: [c_customer_sk#29] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35] -Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] 
+Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (65) Filter [codegen id : 24] Input [2]: [c_customer_sk#29, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (66) Project [codegen id : 24] Output [1]: [c_customer_sk#29] @@ -450,7 +450,7 @@ Input [6]: [cs_bill_customer_sk#1, cs_quantity#3, cs_list_price#4, c_customer_sk (71) HashAggregate [codegen id : 26] Input [4]: [cs_quantity#3, cs_list_price#4, c_first_name#41, c_last_name#42] Keys [2]: [c_last_name#42, c_first_name#41] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#3 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cs_quantity#3 as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [2]: [sum#44, isEmpty#45] Results [4]: [c_last_name#42, c_first_name#41, sum#46, isEmpty#47] @@ -461,9 +461,9 @@ Arguments: hashpartitioning(c_last_name#42, c_first_name#41, 5), ENSURE_REQUIREM (73) HashAggregate [codegen id : 27] Input [4]: [c_last_name#42, c_first_name#41, sum#46, isEmpty#47] Keys [2]: [c_last_name#42, c_first_name#41] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#3 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#3 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2), true))#49] -Results [3]: [c_last_name#42, c_first_name#41, sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#3 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2), true))#49 AS sales#50] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#3 as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#3 as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2)))#49] +Results [3]: [c_last_name#42, c_first_name#41, sum(CheckOverflow((promote_precision(cast(cs_quantity#3 as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2)))#49 AS sales#50] (74) Scan parquet default.web_sales Output [5]: [ws_item_sk#51, ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, 
ws_sold_date_sk#55] @@ -580,20 +580,20 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# (100) HashAggregate [codegen id : 42] Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29] Keys [1]: [c_customer_sk#29] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#31, isEmpty#32] -Results [3]: [c_customer_sk#29, sum#33, isEmpty#34] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#59, isEmpty#60] +Results [3]: [c_customer_sk#29, sum#61, isEmpty#62] (101) HashAggregate [codegen id : 42] -Input [3]: [c_customer_sk#29, sum#33, isEmpty#34] +Input [3]: [c_customer_sk#29, sum#61, isEmpty#62] Keys [1]: [c_customer_sk#29] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35] -Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (102) Filter [codegen id : 42] Input [2]: [c_customer_sk#29, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (103) Project [codegen id : 42] Output [1]: [c_customer_sk#29] @@ -609,23 +609,23 @@ Right keys [1]: [c_customer_sk#29] Join condition: None (106) ReusedExchange [Reuses operator id: 134] -Output [1]: [d_date_sk#59] +Output [1]: [d_date_sk#63] (107) BroadcastHashJoin [codegen id : 44] Left keys [1]: [ws_sold_date_sk#55] -Right keys [1]: [d_date_sk#59] +Right keys [1]: [d_date_sk#63] Join condition: None (108) Project [codegen id : 44] Output [3]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54] -Input [5]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, ws_sold_date_sk#55, d_date_sk#59] +Input [5]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, ws_sold_date_sk#55, d_date_sk#63] (109) ReusedExchange [Reuses operator id: 
55] -Output [3]: [c_customer_sk#60, c_first_name#61, c_last_name#62] +Output [3]: [c_customer_sk#64, c_first_name#65, c_last_name#66] (110) Sort [codegen id : 46] -Input [3]: [c_customer_sk#60, c_first_name#61, c_last_name#62] -Arguments: [c_customer_sk#60 ASC NULLS FIRST], false, 0 +Input [3]: [c_customer_sk#64, c_first_name#65, c_last_name#66] +Arguments: [c_customer_sk#64 ASC NULLS FIRST], false, 0 (111) ReusedExchange [Reuses operator id: 34] Output [3]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26] @@ -653,20 +653,20 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# (117) HashAggregate [codegen id : 51] Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29] Keys [1]: [c_customer_sk#29] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#31, isEmpty#32] -Results [3]: [c_customer_sk#29, sum#33, isEmpty#34] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#59, isEmpty#60] +Results [3]: [c_customer_sk#29, sum#61, isEmpty#62] (118) HashAggregate [codegen id : 51] -Input [3]: [c_customer_sk#29, sum#33, isEmpty#34] +Input [3]: [c_customer_sk#29, sum#61, isEmpty#62] Keys [1]: [c_customer_sk#29] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35] -Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#29, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (119) Filter [codegen id : 51] Input [2]: [c_customer_sk#29, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (120) Project [codegen id : 51] Output [1]: [c_customer_sk#29] @@ -677,36 +677,36 @@ Input [1]: [c_customer_sk#29] Arguments: [c_customer_sk#29 ASC NULLS FIRST], false, 0 (122) SortMergeJoin [codegen id : 52] -Left keys [1]: [c_customer_sk#60] +Left keys [1]: 
[c_customer_sk#64] Right keys [1]: [c_customer_sk#29] Join condition: None (123) SortMergeJoin [codegen id : 53] Left keys [1]: [ws_bill_customer_sk#52] -Right keys [1]: [c_customer_sk#60] +Right keys [1]: [c_customer_sk#64] Join condition: None (124) Project [codegen id : 53] -Output [4]: [ws_quantity#53, ws_list_price#54, c_first_name#61, c_last_name#62] -Input [6]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, c_customer_sk#60, c_first_name#61, c_last_name#62] +Output [4]: [ws_quantity#53, ws_list_price#54, c_first_name#65, c_last_name#66] +Input [6]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, c_customer_sk#64, c_first_name#65, c_last_name#66] (125) HashAggregate [codegen id : 53] -Input [4]: [ws_quantity#53, ws_list_price#54, c_first_name#61, c_last_name#62] -Keys [2]: [c_last_name#62, c_first_name#61] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#53 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#63, isEmpty#64] -Results [4]: [c_last_name#62, c_first_name#61, sum#65, isEmpty#66] +Input [4]: [ws_quantity#53, ws_list_price#54, c_first_name#65, c_last_name#66] +Keys [2]: [c_last_name#66, c_first_name#65] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#67, isEmpty#68] +Results [4]: [c_last_name#66, c_first_name#65, sum#69, isEmpty#70] (126) Exchange -Input [4]: [c_last_name#62, c_first_name#61, sum#65, isEmpty#66] -Arguments: hashpartitioning(c_last_name#62, c_first_name#61, 5), ENSURE_REQUIREMENTS, [id=#67] +Input [4]: [c_last_name#66, c_first_name#65, sum#69, isEmpty#70] +Arguments: hashpartitioning(c_last_name#66, c_first_name#65, 5), ENSURE_REQUIREMENTS, [id=#71] (127) HashAggregate [codegen id : 54] -Input [4]: [c_last_name#62, c_first_name#61, sum#65, isEmpty#66] -Keys [2]: [c_last_name#62, c_first_name#61] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#53 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#53 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2), true))#68] -Results [3]: [c_last_name#62, c_first_name#61, sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#53 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2), true))#68 AS sales#69] +Input [4]: [c_last_name#66, c_first_name#65, sum#69, isEmpty#70] +Keys [2]: [c_last_name#66, c_first_name#65] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2)))#72] +Results [3]: [c_last_name#66, c_first_name#65, sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2)))#72 AS sales#73] (128) Union @@ -725,26 +725,26 @@ BroadcastExchange (134) (130) Scan parquet 
default.date_dim -Output [3]: [d_date_sk#39, d_year#70, d_moy#71] +Output [3]: [d_date_sk#39, d_year#74, d_moy#75] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2000), EqualTo(d_moy,2), IsNotNull(d_date_sk)] ReadSchema: struct (131) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#39, d_year#70, d_moy#71] +Input [3]: [d_date_sk#39, d_year#74, d_moy#75] (132) Filter [codegen id : 1] -Input [3]: [d_date_sk#39, d_year#70, d_moy#71] -Condition : ((((isnotnull(d_year#70) AND isnotnull(d_moy#71)) AND (d_year#70 = 2000)) AND (d_moy#71 = 2)) AND isnotnull(d_date_sk#39)) +Input [3]: [d_date_sk#39, d_year#74, d_moy#75] +Condition : ((((isnotnull(d_year#74) AND isnotnull(d_moy#75)) AND (d_year#74 = 2000)) AND (d_moy#75 = 2)) AND isnotnull(d_date_sk#39)) (133) Project [codegen id : 1] Output [1]: [d_date_sk#39] -Input [3]: [d_date_sk#39, d_year#70, d_moy#71] +Input [3]: [d_date_sk#39, d_year#74, d_moy#75] (134) BroadcastExchange Input [1]: [d_date_sk#39] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#72] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#76] Subquery:2 Hosting operator id = 6 Hosting Expression = ss_sold_date_sk#9 IN dynamicpruning#10 BroadcastExchange (139) @@ -755,26 +755,26 @@ BroadcastExchange (139) (135) Scan parquet default.date_dim -Output [3]: [d_date_sk#11, d_date#12, d_year#73] +Output [3]: [d_date_sk#11, d_date#12, d_year#77] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (136) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#11, d_date#12, d_year#73] +Input [3]: [d_date_sk#11, d_date#12, d_year#77] (137) Filter [codegen id : 1] -Input [3]: [d_date_sk#11, d_date#12, d_year#73] -Condition : (d_year#73 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11)) +Input [3]: [d_date_sk#11, d_date#12, d_year#77] +Condition : (d_year#77 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11)) (138) Project [codegen id : 1] Output [2]: [d_date_sk#11, d_date#12] -Input [3]: [d_date_sk#11, d_date#12, d_year#73] +Input [3]: [d_date_sk#11, d_date#12, d_year#77] (139) BroadcastExchange Input [2]: [d_date_sk#11, d_date#12] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#74] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#78] Subquery:3 Hosting operator id = 45 Hosting Expression = Subquery scalar-subquery#37, [id=#38] * HashAggregate (156) @@ -797,89 +797,89 @@ Subquery:3 Hosting operator id = 45 Hosting Expression = Subquery scalar-subquer (140) Scan parquet default.store_sales -Output [4]: [ss_customer_sk#75, ss_quantity#76, ss_sales_price#77, ss_sold_date_sk#78] +Output [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, ss_sold_date_sk#82] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#78), dynamicpruningexpression(ss_sold_date_sk#78 IN dynamicpruning#79)] +PartitionFilters: [isnotnull(ss_sold_date_sk#82), dynamicpruningexpression(ss_sold_date_sk#82 IN dynamicpruning#83)] PushedFilters: [IsNotNull(ss_customer_sk)] ReadSchema: struct (141) ColumnarToRow [codegen id : 2] -Input [4]: [ss_customer_sk#75, ss_quantity#76, ss_sales_price#77, ss_sold_date_sk#78] +Input [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, 
ss_sold_date_sk#82] (142) Filter [codegen id : 2] -Input [4]: [ss_customer_sk#75, ss_quantity#76, ss_sales_price#77, ss_sold_date_sk#78] -Condition : isnotnull(ss_customer_sk#75) +Input [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, ss_sold_date_sk#82] +Condition : isnotnull(ss_customer_sk#79) (143) ReusedExchange [Reuses operator id: 161] -Output [1]: [d_date_sk#80] +Output [1]: [d_date_sk#84] (144) BroadcastHashJoin [codegen id : 2] -Left keys [1]: [ss_sold_date_sk#78] -Right keys [1]: [d_date_sk#80] +Left keys [1]: [ss_sold_date_sk#82] +Right keys [1]: [d_date_sk#84] Join condition: None (145) Project [codegen id : 2] -Output [3]: [ss_customer_sk#75, ss_quantity#76, ss_sales_price#77] -Input [5]: [ss_customer_sk#75, ss_quantity#76, ss_sales_price#77, ss_sold_date_sk#78, d_date_sk#80] +Output [3]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81] +Input [5]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, ss_sold_date_sk#82, d_date_sk#84] (146) Exchange -Input [3]: [ss_customer_sk#75, ss_quantity#76, ss_sales_price#77] -Arguments: hashpartitioning(ss_customer_sk#75, 5), ENSURE_REQUIREMENTS, [id=#81] +Input [3]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81] +Arguments: hashpartitioning(ss_customer_sk#79, 5), ENSURE_REQUIREMENTS, [id=#85] (147) Sort [codegen id : 3] -Input [3]: [ss_customer_sk#75, ss_quantity#76, ss_sales_price#77] -Arguments: [ss_customer_sk#75 ASC NULLS FIRST], false, 0 +Input [3]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81] +Arguments: [ss_customer_sk#79 ASC NULLS FIRST], false, 0 (148) ReusedExchange [Reuses operator id: 39] -Output [1]: [c_customer_sk#82] +Output [1]: [c_customer_sk#86] (149) Sort [codegen id : 5] -Input [1]: [c_customer_sk#82] -Arguments: [c_customer_sk#82 ASC NULLS FIRST], false, 0 +Input [1]: [c_customer_sk#86] +Arguments: [c_customer_sk#86 ASC NULLS FIRST], false, 0 (150) SortMergeJoin [codegen id : 6] -Left keys [1]: [ss_customer_sk#75] -Right keys [1]: [c_customer_sk#82] +Left keys [1]: [ss_customer_sk#79] +Right keys [1]: [c_customer_sk#86] Join condition: None (151) Project [codegen id : 6] -Output [3]: [ss_quantity#76, ss_sales_price#77, c_customer_sk#82] -Input [4]: [ss_customer_sk#75, ss_quantity#76, ss_sales_price#77, c_customer_sk#82] +Output [3]: [ss_quantity#80, ss_sales_price#81, c_customer_sk#86] +Input [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, c_customer_sk#86] (152) HashAggregate [codegen id : 6] -Input [3]: [ss_quantity#76, ss_sales_price#77, c_customer_sk#82] -Keys [1]: [c_customer_sk#82] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#76 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#77 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#83, isEmpty#84] -Results [3]: [c_customer_sk#82, sum#85, isEmpty#86] +Input [3]: [ss_quantity#80, ss_sales_price#81, c_customer_sk#86] +Keys [1]: [c_customer_sk#86] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#87, isEmpty#88] +Results [3]: [c_customer_sk#86, sum#89, isEmpty#90] (153) HashAggregate [codegen id : 6] -Input [3]: [c_customer_sk#82, sum#85, isEmpty#86] -Keys [1]: [c_customer_sk#82] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#76 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#77 as decimal(12,2)))), 
DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#76 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#77 as decimal(12,2)))), DecimalType(18,2), true))#87] -Results [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#76 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#77 as decimal(12,2)))), DecimalType(18,2), true))#87 AS csales#88] +Input [3]: [c_customer_sk#86, sum#89, isEmpty#90] +Keys [1]: [c_customer_sk#86] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), DecimalType(18,2)))#91] +Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), DecimalType(18,2)))#91 AS csales#92] (154) HashAggregate [codegen id : 6] -Input [1]: [csales#88] +Input [1]: [csales#92] Keys: [] -Functions [1]: [partial_max(csales#88)] -Aggregate Attributes [1]: [max#89] -Results [1]: [max#90] +Functions [1]: [partial_max(csales#92)] +Aggregate Attributes [1]: [max#93] +Results [1]: [max#94] (155) Exchange -Input [1]: [max#90] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#91] +Input [1]: [max#94] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#95] (156) HashAggregate [codegen id : 7] -Input [1]: [max#90] +Input [1]: [max#94] Keys: [] -Functions [1]: [max(csales#88)] -Aggregate Attributes [1]: [max(csales#88)#92] -Results [1]: [max(csales#88)#92 AS tpcds_cmax#93] +Functions [1]: [max(csales#92)] +Aggregate Attributes [1]: [max(csales#92)#96] +Results [1]: [max(csales#92)#96 AS tpcds_cmax#97] -Subquery:4 Hosting operator id = 140 Hosting Expression = ss_sold_date_sk#78 IN dynamicpruning#79 +Subquery:4 Hosting operator id = 140 Hosting Expression = ss_sold_date_sk#82 IN dynamicpruning#83 BroadcastExchange (161) +- * Project (160) +- * Filter (159) @@ -888,26 +888,26 @@ BroadcastExchange (161) (157) Scan parquet default.date_dim -Output [2]: [d_date_sk#80, d_year#94] +Output [2]: [d_date_sk#84, d_year#98] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (158) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#80, d_year#94] +Input [2]: [d_date_sk#84, d_year#98] (159) Filter [codegen id : 1] -Input [2]: [d_date_sk#80, d_year#94] -Condition : (d_year#94 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#80)) +Input [2]: [d_date_sk#84, d_year#98] +Condition : (d_year#98 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#84)) (160) Project [codegen id : 1] -Output [1]: [d_date_sk#80] -Input [2]: [d_date_sk#80, d_year#94] +Output [1]: [d_date_sk#84] +Input [2]: [d_date_sk#84, d_year#98] (161) BroadcastExchange -Input [1]: [d_date_sk#80] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#95] +Input [1]: [d_date_sk#84] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#99] Subquery:5 Hosting operator id = 65 Hosting Expression = ReusedSubquery Subquery scalar-subquery#37, [id=#38] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/simplified.txt 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/simplified.txt index 1cdf12e0cc261..6561fbeddef1d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/simplified.txt @@ -1,7 +1,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] Union WholeStageCodegen (27) - HashAggregate [c_last_name,c_first_name,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2), true)),sales,sum,isEmpty] + HashAggregate [c_last_name,c_first_name,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cs_quantity as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2))),sales,sum,isEmpty] InputAdapter Exchange [c_last_name,c_first_name] #1 WholeStageCodegen (26) @@ -92,7 +92,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] Exchange #10 WholeStageCodegen (6) HashAggregate [csales] [max,max] - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),csales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),csales,sum,isEmpty] HashAggregate [c_customer_sk,ss_quantity,ss_sales_price] [sum,isEmpty,sum,isEmpty] Project [ss_quantity,ss_sales_price,c_customer_sk] SortMergeJoin [ss_customer_sk,c_customer_sk] @@ -123,7 +123,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] Sort [c_customer_sk] InputAdapter ReusedExchange [c_customer_sk] #9 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] HashAggregate [c_customer_sk,ss_quantity,ss_sales_price] [sum,isEmpty,sum,isEmpty] Project [ss_quantity,ss_sales_price,c_customer_sk] SortMergeJoin [ss_customer_sk,c_customer_sk] @@ -169,7 +169,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] Project [c_customer_sk] Filter [ssales] ReusedSubquery [tpcds_cmax] #3 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] HashAggregate [c_customer_sk,ss_quantity,ss_sales_price] [sum,isEmpty,sum,isEmpty] Project [ss_quantity,ss_sales_price,c_customer_sk] SortMergeJoin [ss_customer_sk,c_customer_sk] @@ -184,7 +184,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] InputAdapter ReusedExchange [c_customer_sk] #9 WholeStageCodegen (54) - HashAggregate 
[c_last_name,c_first_name,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2), true)),sales,sum,isEmpty] + HashAggregate [c_last_name,c_first_name,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ws_quantity as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2))),sales,sum,isEmpty] InputAdapter Exchange [c_last_name,c_first_name] #14 WholeStageCodegen (53) @@ -240,7 +240,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] Project [c_customer_sk] Filter [ssales] ReusedSubquery [tpcds_cmax] #3 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] HashAggregate [c_customer_sk,ss_quantity,ss_sales_price] [sum,isEmpty,sum,isEmpty] Project [ss_quantity,ss_sales_price,c_customer_sk] SortMergeJoin [ss_customer_sk,c_customer_sk] @@ -270,7 +270,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] Project [c_customer_sk] Filter [ssales] ReusedSubquery [tpcds_cmax] #3 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] HashAggregate [c_customer_sk,ss_quantity,ss_sales_price] [sum,isEmpty,sum,isEmpty] Project [ss_quantity,ss_sales_price,c_customer_sk] SortMergeJoin [ss_customer_sk,c_customer_sk] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b/explain.txt index 371f34bc14b4b..bea457e24dca9 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b/explain.txt @@ -252,7 +252,7 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# (36) HashAggregate [codegen id : 8] Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#28] Keys [1]: [c_customer_sk#28] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [2]: [sum#30, isEmpty#31] Results [3]: [c_customer_sk#28, sum#32, isEmpty#33] @@ -263,13 +263,13 @@ Arguments: hashpartitioning(c_customer_sk#28, 5), ENSURE_REQUIREMENTS, [id=#34] (38) HashAggregate [codegen id : 9] Input [3]: [c_customer_sk#28, sum#32, isEmpty#33] Keys [1]: [c_customer_sk#28] -Functions [1]: 
[sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35] -Results [2]: [c_customer_sk#28, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#28, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (39) Filter [codegen id : 9] Input [2]: [c_customer_sk#28, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (40) Project [codegen id : 9] Output [1]: [c_customer_sk#28] @@ -312,13 +312,13 @@ Output [3]: [c_customer_sk#28, sum#32, isEmpty#33] (49) HashAggregate [codegen id : 14] Input [3]: [c_customer_sk#28, sum#32, isEmpty#33] Keys [1]: [c_customer_sk#28] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35] -Results [2]: [c_customer_sk#28, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#28, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (50) Filter [codegen id : 14] Input [2]: [c_customer_sk#28, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND 
(cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (51) Project [codegen id : 14] Output [1]: [c_customer_sk#28] @@ -361,7 +361,7 @@ Input [6]: [cs_quantity#3, cs_list_price#4, cs_sold_date_sk#5, c_first_name#40, (60) HashAggregate [codegen id : 17] Input [4]: [cs_quantity#3, cs_list_price#4, c_first_name#40, c_last_name#41] Keys [2]: [c_last_name#41, c_first_name#40] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#3 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cs_quantity#3 as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [2]: [sum#45, isEmpty#46] Results [4]: [c_last_name#41, c_first_name#40, sum#47, isEmpty#48] @@ -372,9 +372,9 @@ Arguments: hashpartitioning(c_last_name#41, c_first_name#40, 5), ENSURE_REQUIREM (62) HashAggregate [codegen id : 18] Input [4]: [c_last_name#41, c_first_name#40, sum#47, isEmpty#48] Keys [2]: [c_last_name#41, c_first_name#40] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#3 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#3 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2), true))#50] -Results [3]: [c_last_name#41, c_first_name#40, sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#3 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2), true))#50 AS sales#51] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#3 as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#3 as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2)))#50] +Results [3]: [c_last_name#41, c_first_name#40, sum(CheckOverflow((promote_precision(cast(cs_quantity#3 as decimal(12,2))) * promote_precision(cast(cs_list_price#4 as decimal(12,2)))), DecimalType(18,2)))#50 AS sales#51] (63) Scan parquet default.web_sales Output [5]: [ws_item_sk#52, ws_bill_customer_sk#53, ws_quantity#54, ws_list_price#55, ws_sold_date_sk#56] @@ -412,18 +412,18 @@ Input [4]: [ws_bill_customer_sk#53, ws_quantity#54, ws_list_price#55, ws_sold_da Arguments: [ws_bill_customer_sk#53 ASC NULLS FIRST], false, 0 (71) ReusedExchange [Reuses operator id: 37] -Output [3]: [c_customer_sk#28, sum#32, isEmpty#33] +Output [3]: [c_customer_sk#28, sum#58, isEmpty#59] (72) HashAggregate [codegen id : 27] -Input [3]: [c_customer_sk#28, sum#32, isEmpty#33] +Input [3]: [c_customer_sk#28, sum#58, isEmpty#59] Keys [1]: [c_customer_sk#28] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), 
true))#35] -Results [2]: [c_customer_sk#28, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#25 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2), true))#35 AS ssales#36] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] +Results [2]: [c_customer_sk#28, sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35 AS ssales#36] (73) Filter [codegen id : 27] Input [2]: [c_customer_sk#28, ssales#36] -Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8), true))) +Condition : (isnotnull(ssales#36) AND (cast(ssales#36 as decimal(38,8)) > CheckOverflow((0.500000 * promote_precision(cast(ReusedSubquery Subquery scalar-subquery#37, [id=#38] as decimal(32,6)))), DecimalType(38,8)))) (74) Project [codegen id : 27] Output [1]: [c_customer_sk#28] @@ -439,46 +439,46 @@ Right keys [1]: [c_customer_sk#28] Join condition: None (77) ReusedExchange [Reuses operator id: 54] -Output [3]: [c_customer_sk#58, c_first_name#59, c_last_name#60] +Output [3]: [c_customer_sk#60, c_first_name#61, c_last_name#62] (78) BroadcastHashJoin [codegen id : 35] Left keys [1]: [ws_bill_customer_sk#53] -Right keys [1]: [c_customer_sk#58] +Right keys [1]: [c_customer_sk#60] Join condition: None (79) Project [codegen id : 35] -Output [5]: [ws_quantity#54, ws_list_price#55, ws_sold_date_sk#56, c_first_name#59, c_last_name#60] -Input [7]: [ws_bill_customer_sk#53, ws_quantity#54, ws_list_price#55, ws_sold_date_sk#56, c_customer_sk#58, c_first_name#59, c_last_name#60] +Output [5]: [ws_quantity#54, ws_list_price#55, ws_sold_date_sk#56, c_first_name#61, c_last_name#62] +Input [7]: [ws_bill_customer_sk#53, ws_quantity#54, ws_list_price#55, ws_sold_date_sk#56, c_customer_sk#60, c_first_name#61, c_last_name#62] (80) ReusedExchange [Reuses operator id: 92] -Output [1]: [d_date_sk#61] +Output [1]: [d_date_sk#63] (81) BroadcastHashJoin [codegen id : 35] Left keys [1]: [ws_sold_date_sk#56] -Right keys [1]: [d_date_sk#61] +Right keys [1]: [d_date_sk#63] Join condition: None (82) Project [codegen id : 35] -Output [4]: [ws_quantity#54, ws_list_price#55, c_first_name#59, c_last_name#60] -Input [6]: [ws_quantity#54, ws_list_price#55, ws_sold_date_sk#56, c_first_name#59, c_last_name#60, d_date_sk#61] +Output [4]: [ws_quantity#54, ws_list_price#55, c_first_name#61, c_last_name#62] +Input [6]: [ws_quantity#54, ws_list_price#55, ws_sold_date_sk#56, c_first_name#61, c_last_name#62, d_date_sk#63] (83) HashAggregate [codegen id : 35] -Input [4]: [ws_quantity#54, ws_list_price#55, c_first_name#59, c_last_name#60] -Keys [2]: [c_last_name#60, c_first_name#59] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#54 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#55 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#62, isEmpty#63] -Results [4]: [c_last_name#60, c_first_name#59, sum#64, isEmpty#65] +Input [4]: 
[ws_quantity#54, ws_list_price#55, c_first_name#61, c_last_name#62] +Keys [2]: [c_last_name#62, c_first_name#61] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ws_quantity#54 as decimal(12,2))) * promote_precision(cast(ws_list_price#55 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#64, isEmpty#65] +Results [4]: [c_last_name#62, c_first_name#61, sum#66, isEmpty#67] (84) Exchange -Input [4]: [c_last_name#60, c_first_name#59, sum#64, isEmpty#65] -Arguments: hashpartitioning(c_last_name#60, c_first_name#59, 5), ENSURE_REQUIREMENTS, [id=#66] +Input [4]: [c_last_name#62, c_first_name#61, sum#66, isEmpty#67] +Arguments: hashpartitioning(c_last_name#62, c_first_name#61, 5), ENSURE_REQUIREMENTS, [id=#68] (85) HashAggregate [codegen id : 36] -Input [4]: [c_last_name#60, c_first_name#59, sum#64, isEmpty#65] -Keys [2]: [c_last_name#60, c_first_name#59] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#54 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#55 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#54 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#55 as decimal(12,2)))), DecimalType(18,2), true))#67] -Results [3]: [c_last_name#60, c_first_name#59, sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#54 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#55 as decimal(12,2)))), DecimalType(18,2), true))#67 AS sales#68] +Input [4]: [c_last_name#62, c_first_name#61, sum#66, isEmpty#67] +Keys [2]: [c_last_name#62, c_first_name#61] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#54 as decimal(12,2))) * promote_precision(cast(ws_list_price#55 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#54 as decimal(12,2))) * promote_precision(cast(ws_list_price#55 as decimal(12,2)))), DecimalType(18,2)))#69] +Results [3]: [c_last_name#62, c_first_name#61, sum(CheckOverflow((promote_precision(cast(ws_quantity#54 as decimal(12,2))) * promote_precision(cast(ws_list_price#55 as decimal(12,2)))), DecimalType(18,2)))#69 AS sales#70] (86) Union @@ -497,26 +497,26 @@ BroadcastExchange (92) (88) Scan parquet default.date_dim -Output [3]: [d_date_sk#44, d_year#69, d_moy#70] +Output [3]: [d_date_sk#44, d_year#71, d_moy#72] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2000), EqualTo(d_moy,2), IsNotNull(d_date_sk)] ReadSchema: struct (89) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#44, d_year#69, d_moy#70] +Input [3]: [d_date_sk#44, d_year#71, d_moy#72] (90) Filter [codegen id : 1] -Input [3]: [d_date_sk#44, d_year#69, d_moy#70] -Condition : ((((isnotnull(d_year#69) AND isnotnull(d_moy#70)) AND (d_year#69 = 2000)) AND (d_moy#70 = 2)) AND isnotnull(d_date_sk#44)) +Input [3]: [d_date_sk#44, d_year#71, d_moy#72] +Condition : ((((isnotnull(d_year#71) AND isnotnull(d_moy#72)) AND (d_year#71 = 2000)) AND (d_moy#72 = 2)) AND isnotnull(d_date_sk#44)) (91) Project [codegen id : 1] Output [1]: [d_date_sk#44] -Input [3]: [d_date_sk#44, d_year#69, d_moy#70] +Input [3]: [d_date_sk#44, d_year#71, d_moy#72] (92) BroadcastExchange Input [1]: [d_date_sk#44] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#71] +Arguments: 
HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#73] Subquery:2 Hosting operator id = 4 Hosting Expression = ss_sold_date_sk#8 IN dynamicpruning#9 BroadcastExchange (97) @@ -527,26 +527,26 @@ BroadcastExchange (97) (93) Scan parquet default.date_dim -Output [3]: [d_date_sk#10, d_date#11, d_year#72] +Output [3]: [d_date_sk#10, d_date#11, d_year#74] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (94) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#10, d_date#11, d_year#72] +Input [3]: [d_date_sk#10, d_date#11, d_year#74] (95) Filter [codegen id : 1] -Input [3]: [d_date_sk#10, d_date#11, d_year#72] -Condition : (d_year#72 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#10)) +Input [3]: [d_date_sk#10, d_date#11, d_year#74] +Condition : (d_year#74 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#10)) (96) Project [codegen id : 1] Output [2]: [d_date_sk#10, d_date#11] -Input [3]: [d_date_sk#10, d_date#11, d_year#72] +Input [3]: [d_date_sk#10, d_date#11, d_year#74] (97) BroadcastExchange Input [2]: [d_date_sk#10, d_date#11] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#73] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#75] Subquery:3 Hosting operator id = 39 Hosting Expression = Subquery scalar-subquery#37, [id=#38] * HashAggregate (112) @@ -567,81 +567,81 @@ Subquery:3 Hosting operator id = 39 Hosting Expression = Subquery scalar-subquer (98) Scan parquet default.store_sales -Output [4]: [ss_customer_sk#74, ss_quantity#75, ss_sales_price#76, ss_sold_date_sk#77] +Output [4]: [ss_customer_sk#76, ss_quantity#77, ss_sales_price#78, ss_sold_date_sk#79] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#77), dynamicpruningexpression(ss_sold_date_sk#77 IN dynamicpruning#78)] +PartitionFilters: [isnotnull(ss_sold_date_sk#79), dynamicpruningexpression(ss_sold_date_sk#79 IN dynamicpruning#80)] PushedFilters: [IsNotNull(ss_customer_sk)] ReadSchema: struct (99) ColumnarToRow [codegen id : 3] -Input [4]: [ss_customer_sk#74, ss_quantity#75, ss_sales_price#76, ss_sold_date_sk#77] +Input [4]: [ss_customer_sk#76, ss_quantity#77, ss_sales_price#78, ss_sold_date_sk#79] (100) Filter [codegen id : 3] -Input [4]: [ss_customer_sk#74, ss_quantity#75, ss_sales_price#76, ss_sold_date_sk#77] -Condition : isnotnull(ss_customer_sk#74) +Input [4]: [ss_customer_sk#76, ss_quantity#77, ss_sales_price#78, ss_sold_date_sk#79] +Condition : isnotnull(ss_customer_sk#76) (101) ReusedExchange [Reuses operator id: 33] -Output [1]: [c_customer_sk#79] +Output [1]: [c_customer_sk#81] (102) BroadcastHashJoin [codegen id : 3] -Left keys [1]: [ss_customer_sk#74] -Right keys [1]: [c_customer_sk#79] +Left keys [1]: [ss_customer_sk#76] +Right keys [1]: [c_customer_sk#81] Join condition: None (103) Project [codegen id : 3] -Output [4]: [ss_quantity#75, ss_sales_price#76, ss_sold_date_sk#77, c_customer_sk#79] -Input [5]: [ss_customer_sk#74, ss_quantity#75, ss_sales_price#76, ss_sold_date_sk#77, c_customer_sk#79] +Output [4]: [ss_quantity#77, ss_sales_price#78, ss_sold_date_sk#79, c_customer_sk#81] +Input [5]: [ss_customer_sk#76, ss_quantity#77, ss_sales_price#78, ss_sold_date_sk#79, c_customer_sk#81] (104) ReusedExchange [Reuses operator id: 117] -Output [1]: [d_date_sk#80] +Output [1]: [d_date_sk#82] (105) BroadcastHashJoin [codegen id : 3] -Left 
keys [1]: [ss_sold_date_sk#77] -Right keys [1]: [d_date_sk#80] +Left keys [1]: [ss_sold_date_sk#79] +Right keys [1]: [d_date_sk#82] Join condition: None (106) Project [codegen id : 3] -Output [3]: [ss_quantity#75, ss_sales_price#76, c_customer_sk#79] -Input [5]: [ss_quantity#75, ss_sales_price#76, ss_sold_date_sk#77, c_customer_sk#79, d_date_sk#80] +Output [3]: [ss_quantity#77, ss_sales_price#78, c_customer_sk#81] +Input [5]: [ss_quantity#77, ss_sales_price#78, ss_sold_date_sk#79, c_customer_sk#81, d_date_sk#82] (107) HashAggregate [codegen id : 3] -Input [3]: [ss_quantity#75, ss_sales_price#76, c_customer_sk#79] -Keys [1]: [c_customer_sk#79] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#75 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#76 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#81, isEmpty#82] -Results [3]: [c_customer_sk#79, sum#83, isEmpty#84] +Input [3]: [ss_quantity#77, ss_sales_price#78, c_customer_sk#81] +Keys [1]: [c_customer_sk#81] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#77 as decimal(12,2))) * promote_precision(cast(ss_sales_price#78 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#83, isEmpty#84] +Results [3]: [c_customer_sk#81, sum#85, isEmpty#86] (108) Exchange -Input [3]: [c_customer_sk#79, sum#83, isEmpty#84] -Arguments: hashpartitioning(c_customer_sk#79, 5), ENSURE_REQUIREMENTS, [id=#85] +Input [3]: [c_customer_sk#81, sum#85, isEmpty#86] +Arguments: hashpartitioning(c_customer_sk#81, 5), ENSURE_REQUIREMENTS, [id=#87] (109) HashAggregate [codegen id : 4] -Input [3]: [c_customer_sk#79, sum#83, isEmpty#84] -Keys [1]: [c_customer_sk#79] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#75 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#76 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#75 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#76 as decimal(12,2)))), DecimalType(18,2), true))#86] -Results [1]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#75 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#76 as decimal(12,2)))), DecimalType(18,2), true))#86 AS csales#87] +Input [3]: [c_customer_sk#81, sum#85, isEmpty#86] +Keys [1]: [c_customer_sk#81] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#77 as decimal(12,2))) * promote_precision(cast(ss_sales_price#78 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#77 as decimal(12,2))) * promote_precision(cast(ss_sales_price#78 as decimal(12,2)))), DecimalType(18,2)))#88] +Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#77 as decimal(12,2))) * promote_precision(cast(ss_sales_price#78 as decimal(12,2)))), DecimalType(18,2)))#88 AS csales#89] (110) HashAggregate [codegen id : 4] -Input [1]: [csales#87] +Input [1]: [csales#89] Keys: [] -Functions [1]: [partial_max(csales#87)] -Aggregate Attributes [1]: [max#88] -Results [1]: [max#89] +Functions [1]: [partial_max(csales#89)] +Aggregate Attributes [1]: [max#90] +Results [1]: [max#91] (111) Exchange -Input [1]: [max#89] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#90] +Input [1]: [max#91] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#92] (112) HashAggregate [codegen id : 5] -Input 
[1]: [max#89] +Input [1]: [max#91] Keys: [] -Functions [1]: [max(csales#87)] -Aggregate Attributes [1]: [max(csales#87)#91] -Results [1]: [max(csales#87)#91 AS tpcds_cmax#92] +Functions [1]: [max(csales#89)] +Aggregate Attributes [1]: [max(csales#89)#93] +Results [1]: [max(csales#89)#93 AS tpcds_cmax#94] -Subquery:4 Hosting operator id = 98 Hosting Expression = ss_sold_date_sk#77 IN dynamicpruning#78 +Subquery:4 Hosting operator id = 98 Hosting Expression = ss_sold_date_sk#79 IN dynamicpruning#80 BroadcastExchange (117) +- * Project (116) +- * Filter (115) @@ -650,26 +650,26 @@ BroadcastExchange (117) (113) Scan parquet default.date_dim -Output [2]: [d_date_sk#80, d_year#93] +Output [2]: [d_date_sk#82, d_year#95] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (114) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#80, d_year#93] +Input [2]: [d_date_sk#82, d_year#95] (115) Filter [codegen id : 1] -Input [2]: [d_date_sk#80, d_year#93] -Condition : (d_year#93 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#80)) +Input [2]: [d_date_sk#82, d_year#95] +Condition : (d_year#95 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#82)) (116) Project [codegen id : 1] -Output [1]: [d_date_sk#80] -Input [2]: [d_date_sk#80, d_year#93] +Output [1]: [d_date_sk#82] +Input [2]: [d_date_sk#82, d_year#95] (117) BroadcastExchange -Input [1]: [d_date_sk#80] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#94] +Input [1]: [d_date_sk#82] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#96] Subquery:5 Hosting operator id = 50 Hosting Expression = ReusedSubquery Subquery scalar-subquery#37, [id=#38] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b/simplified.txt index 8a43f5cdae750..19f5b95dce994 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b/simplified.txt @@ -1,7 +1,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] Union WholeStageCodegen (18) - HashAggregate [c_last_name,c_first_name,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2), true)),sales,sum,isEmpty] + HashAggregate [c_last_name,c_first_name,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cs_quantity as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2))),sales,sum,isEmpty] InputAdapter Exchange [c_last_name,c_first_name] #1 WholeStageCodegen (17) @@ -78,7 +78,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] Exchange #10 WholeStageCodegen (4) HashAggregate [csales] [max,max] - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),csales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),csales,sum,isEmpty] InputAdapter Exchange [c_customer_sk] #11 
WholeStageCodegen (3) @@ -103,7 +103,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] ReusedExchange [c_customer_sk] #9 InputAdapter ReusedExchange [d_date_sk] #12 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] InputAdapter Exchange [c_customer_sk] #8 WholeStageCodegen (8) @@ -142,13 +142,13 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] Project [c_customer_sk] Filter [ssales] ReusedSubquery [tpcds_cmax] #3 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] InputAdapter ReusedExchange [c_customer_sk,sum,isEmpty] #8 InputAdapter ReusedExchange [d_date_sk] #3 WholeStageCodegen (36) - HashAggregate [c_last_name,c_first_name,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2), true)),sales,sum,isEmpty] + HashAggregate [c_last_name,c_first_name,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ws_quantity as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2))),sales,sum,isEmpty] InputAdapter Exchange [c_last_name,c_first_name] #15 WholeStageCodegen (35) @@ -179,7 +179,7 @@ TakeOrderedAndProject [c_last_name,c_first_name,sales] Project [c_customer_sk] Filter [ssales] ReusedSubquery [tpcds_cmax] #3 - HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2), true)),ssales,sum,isEmpty] + HashAggregate [c_customer_sk,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_sales_price as decimal(12,2)))), DecimalType(18,2))),ssales,sum,isEmpty] InputAdapter ReusedExchange [c_customer_sk,sum,isEmpty] #8 InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24a.sf100/explain.txt index 2ecb115faf87d..7b82aed515f39 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24a.sf100/explain.txt @@ -536,6 +536,6 @@ Input [2]: [sum#61, count#62] Keys: [] Functions [1]: [avg(netpaid#39)] Aggregate Attributes [1]: [avg(netpaid#39)#64] -Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#39)#64)), DecimalType(24,8), true) AS (0.05 * avg(netpaid))#65] +Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#39)#64)), DecimalType(24,8)) AS (0.05 * avg(netpaid))#65] diff --git 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24a/explain.txt index 0ad7d96f8f777..d1fa0bd182199 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24a/explain.txt @@ -412,6 +412,6 @@ Input [2]: [sum#54, count#55] Keys: [] Functions [1]: [avg(netpaid#38)] Aggregate Attributes [1]: [avg(netpaid#38)#57] -Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#38)#57)), DecimalType(24,8), true) AS (0.05 * avg(netpaid))#58] +Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#38)#57)), DecimalType(24,8)) AS (0.05 * avg(netpaid))#58] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24b.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24b.sf100/explain.txt index 9e4e27f2c6726..fa921b7f2b622 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24b.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24b.sf100/explain.txt @@ -536,6 +536,6 @@ Input [2]: [sum#61, count#62] Keys: [] Functions [1]: [avg(netpaid#39)] Aggregate Attributes [1]: [avg(netpaid#39)#64] -Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#39)#64)), DecimalType(24,8), true) AS (0.05 * avg(netpaid))#65] +Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#39)#64)), DecimalType(24,8)) AS (0.05 * avg(netpaid))#65] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24b/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24b/explain.txt index 78371d380114e..e1a6c33699efd 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24b/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q24b/explain.txt @@ -412,6 +412,6 @@ Input [2]: [sum#54, count#55] Keys: [] Functions [1]: [avg(netpaid#38)] Aggregate Attributes [1]: [avg(netpaid#38)#57] -Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#38)#57)), DecimalType(24,8), true) AS (0.05 * avg(netpaid))#58] +Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#38)#57)), DecimalType(24,8)) AS (0.05 * avg(netpaid))#58] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25.sf100/explain.txt index cbbf3da55739d..fc55789fab16a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25.sf100/explain.txt @@ -1,50 +1,53 @@ == Physical Plan == -TakeOrderedAndProject (46) -+- * HashAggregate (45) - +- Exchange (44) - +- * HashAggregate (43) - +- * Project (42) - +- * SortMergeJoin Inner (41) - :- * Project (32) - : +- * SortMergeJoin Inner (31) - : :- * Sort (22) - : : +- * Project (21) - : : +- * SortMergeJoin Inner (20) - : : :- * Sort (14) - : : : +- Exchange (13) - : : : +- * Project (12) - : : : +- * BroadcastHashJoin Inner BuildRight (11) - : : : :- * Project (6) - : : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : : :- * Filter (3) - : : : : : +- * ColumnarToRow (2) - : : : : : +- Scan parquet default.store_sales (1) - : : : : +- 
ReusedExchange (4) - : : : +- BroadcastExchange (10) - : : : +- * Filter (9) - : : : +- * ColumnarToRow (8) - : : : +- Scan parquet default.store (7) - : : +- * Sort (19) - : : +- Exchange (18) - : : +- * Filter (17) - : : +- * ColumnarToRow (16) - : : +- Scan parquet default.item (15) - : +- * Sort (30) - : +- Exchange (29) - : +- * Project (28) - : +- * BroadcastHashJoin Inner BuildRight (27) - : :- * Filter (25) - : : +- * ColumnarToRow (24) - : : +- Scan parquet default.store_returns (23) - : +- ReusedExchange (26) - +- * Sort (40) - +- Exchange (39) - +- * Project (38) - +- * BroadcastHashJoin Inner BuildRight (37) - :- * Filter (35) - : +- * ColumnarToRow (34) - : +- Scan parquet default.catalog_sales (33) - +- ReusedExchange (36) +TakeOrderedAndProject (49) ++- * HashAggregate (48) + +- Exchange (47) + +- * HashAggregate (46) + +- * Project (45) + +- * SortMergeJoin Inner (44) + :- * Sort (35) + : +- Exchange (34) + : +- * Project (33) + : +- * SortMergeJoin Inner (32) + : :- * Sort (23) + : : +- Exchange (22) + : : +- * Project (21) + : : +- * SortMergeJoin Inner (20) + : : :- * Sort (14) + : : : +- Exchange (13) + : : : +- * Project (12) + : : : +- * BroadcastHashJoin Inner BuildRight (11) + : : : :- * Project (6) + : : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : : :- * Filter (3) + : : : : : +- * ColumnarToRow (2) + : : : : : +- Scan parquet default.store_sales (1) + : : : : +- ReusedExchange (4) + : : : +- BroadcastExchange (10) + : : : +- * Filter (9) + : : : +- * ColumnarToRow (8) + : : : +- Scan parquet default.store (7) + : : +- * Sort (19) + : : +- Exchange (18) + : : +- * Filter (17) + : : +- * ColumnarToRow (16) + : : +- Scan parquet default.item (15) + : +- * Sort (31) + : +- Exchange (30) + : +- * Project (29) + : +- * BroadcastHashJoin Inner BuildRight (28) + : :- * Filter (26) + : : +- * ColumnarToRow (25) + : : +- Scan parquet default.store_returns (24) + : +- ReusedExchange (27) + +- * Sort (43) + +- Exchange (42) + +- * Project (41) + +- * BroadcastHashJoin Inner BuildRight (40) + :- * Filter (38) + : +- * ColumnarToRow (37) + : +- Scan parquet default.catalog_sales (36) + +- ReusedExchange (39) (1) Scan parquet default.store_sales @@ -62,7 +65,7 @@ Input [6]: [ss_item_sk#1, ss_customer_sk#2, ss_store_sk#3, ss_ticket_number#4, s Input [6]: [ss_item_sk#1, ss_customer_sk#2, ss_store_sk#3, ss_ticket_number#4, ss_net_profit#5, ss_sold_date_sk#6] Condition : (((isnotnull(ss_customer_sk#2) AND isnotnull(ss_item_sk#1)) AND isnotnull(ss_ticket_number#4)) AND isnotnull(ss_store_sk#3)) -(4) ReusedExchange [Reuses operator id: 51] +(4) ReusedExchange [Reuses operator id: 54] Output [1]: [d_date_sk#8] (5) BroadcastHashJoin [codegen id : 3] @@ -140,182 +143,194 @@ Join condition: None Output [8]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] Input [9]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_sk#14, i_item_id#15, i_item_desc#16] -(22) Sort [codegen id : 7] +(22) Exchange +Input [8]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] +Arguments: hashpartitioning(ss_customer_sk#2, ss_item_sk#1, ss_ticket_number#4, 5), ENSURE_REQUIREMENTS, [id=#18] + +(23) Sort [codegen id : 8] Input [8]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] Arguments: 
[ss_customer_sk#2 ASC NULLS FIRST, ss_item_sk#1 ASC NULLS FIRST, ss_ticket_number#4 ASC NULLS FIRST], false, 0 -(23) Scan parquet default.store_returns -Output [5]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_net_loss#21, sr_returned_date_sk#22] +(24) Scan parquet default.store_returns +Output [5]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_net_loss#22, sr_returned_date_sk#23] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(sr_returned_date_sk#22), dynamicpruningexpression(sr_returned_date_sk#22 IN dynamicpruning#23)] +PartitionFilters: [isnotnull(sr_returned_date_sk#23), dynamicpruningexpression(sr_returned_date_sk#23 IN dynamicpruning#24)] PushedFilters: [IsNotNull(sr_customer_sk), IsNotNull(sr_item_sk), IsNotNull(sr_ticket_number)] ReadSchema: struct -(24) ColumnarToRow [codegen id : 9] -Input [5]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_net_loss#21, sr_returned_date_sk#22] +(25) ColumnarToRow [codegen id : 10] +Input [5]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_net_loss#22, sr_returned_date_sk#23] -(25) Filter [codegen id : 9] -Input [5]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_net_loss#21, sr_returned_date_sk#22] -Condition : ((isnotnull(sr_customer_sk#19) AND isnotnull(sr_item_sk#18)) AND isnotnull(sr_ticket_number#20)) +(26) Filter [codegen id : 10] +Input [5]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_net_loss#22, sr_returned_date_sk#23] +Condition : ((isnotnull(sr_customer_sk#20) AND isnotnull(sr_item_sk#19)) AND isnotnull(sr_ticket_number#21)) -(26) ReusedExchange [Reuses operator id: 56] -Output [1]: [d_date_sk#24] +(27) ReusedExchange [Reuses operator id: 59] +Output [1]: [d_date_sk#25] -(27) BroadcastHashJoin [codegen id : 9] -Left keys [1]: [sr_returned_date_sk#22] -Right keys [1]: [d_date_sk#24] +(28) BroadcastHashJoin [codegen id : 10] +Left keys [1]: [sr_returned_date_sk#23] +Right keys [1]: [d_date_sk#25] Join condition: None -(28) Project [codegen id : 9] -Output [4]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_net_loss#21] -Input [6]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_net_loss#21, sr_returned_date_sk#22, d_date_sk#24] +(29) Project [codegen id : 10] +Output [4]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_net_loss#22] +Input [6]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_net_loss#22, sr_returned_date_sk#23, d_date_sk#25] -(29) Exchange -Input [4]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_net_loss#21] -Arguments: hashpartitioning(sr_item_sk#18, 5), ENSURE_REQUIREMENTS, [id=#25] +(30) Exchange +Input [4]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_net_loss#22] +Arguments: hashpartitioning(sr_customer_sk#20, sr_item_sk#19, sr_ticket_number#21, 5), ENSURE_REQUIREMENTS, [id=#26] -(30) Sort [codegen id : 10] -Input [4]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_net_loss#21] -Arguments: [sr_customer_sk#19 ASC NULLS FIRST, sr_item_sk#18 ASC NULLS FIRST, sr_ticket_number#20 ASC NULLS FIRST], false, 0 +(31) Sort [codegen id : 11] +Input [4]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_net_loss#22] +Arguments: [sr_customer_sk#20 ASC NULLS FIRST, sr_item_sk#19 ASC NULLS FIRST, sr_ticket_number#21 ASC NULLS FIRST], false, 0 -(31) SortMergeJoin [codegen id : 11] +(32) SortMergeJoin [codegen id : 12] Left keys [3]: [ss_customer_sk#2, ss_item_sk#1, ss_ticket_number#4] -Right keys [3]: [sr_customer_sk#19, sr_item_sk#18, 
sr_ticket_number#20] +Right keys [3]: [sr_customer_sk#20, sr_item_sk#19, sr_ticket_number#21] Join condition: None -(32) Project [codegen id : 11] -Output [8]: [ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#18, sr_customer_sk#19, sr_net_loss#21] -Input [12]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_net_loss#21] +(33) Project [codegen id : 12] +Output [8]: [ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#19, sr_customer_sk#20, sr_net_loss#22] +Input [12]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_net_loss#22] + +(34) Exchange +Input [8]: [ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#19, sr_customer_sk#20, sr_net_loss#22] +Arguments: hashpartitioning(sr_customer_sk#20, sr_item_sk#19, 5), ENSURE_REQUIREMENTS, [id=#27] + +(35) Sort [codegen id : 13] +Input [8]: [ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#19, sr_customer_sk#20, sr_net_loss#22] +Arguments: [sr_customer_sk#20 ASC NULLS FIRST, sr_item_sk#19 ASC NULLS FIRST], false, 0 -(33) Scan parquet default.catalog_sales -Output [4]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_net_profit#28, cs_sold_date_sk#29] +(36) Scan parquet default.catalog_sales +Output [4]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_net_profit#30, cs_sold_date_sk#31] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(cs_sold_date_sk#29), dynamicpruningexpression(cs_sold_date_sk#29 IN dynamicpruning#23)] +PartitionFilters: [isnotnull(cs_sold_date_sk#31), dynamicpruningexpression(cs_sold_date_sk#31 IN dynamicpruning#24)] PushedFilters: [IsNotNull(cs_bill_customer_sk), IsNotNull(cs_item_sk)] ReadSchema: struct -(34) ColumnarToRow [codegen id : 13] -Input [4]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_net_profit#28, cs_sold_date_sk#29] +(37) ColumnarToRow [codegen id : 15] +Input [4]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_net_profit#30, cs_sold_date_sk#31] -(35) Filter [codegen id : 13] -Input [4]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_net_profit#28, cs_sold_date_sk#29] -Condition : (isnotnull(cs_bill_customer_sk#26) AND isnotnull(cs_item_sk#27)) +(38) Filter [codegen id : 15] +Input [4]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_net_profit#30, cs_sold_date_sk#31] +Condition : (isnotnull(cs_bill_customer_sk#28) AND isnotnull(cs_item_sk#29)) -(36) ReusedExchange [Reuses operator id: 56] -Output [1]: [d_date_sk#30] +(39) ReusedExchange [Reuses operator id: 59] +Output [1]: [d_date_sk#32] -(37) BroadcastHashJoin [codegen id : 13] -Left keys [1]: [cs_sold_date_sk#29] -Right keys [1]: [d_date_sk#30] +(40) BroadcastHashJoin [codegen id : 15] +Left keys [1]: [cs_sold_date_sk#31] +Right keys [1]: [d_date_sk#32] Join condition: None -(38) Project [codegen id : 13] -Output [3]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_net_profit#28] -Input [5]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_net_profit#28, cs_sold_date_sk#29, d_date_sk#30] +(41) Project [codegen id : 15] +Output [3]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_net_profit#30] +Input [5]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_net_profit#30, cs_sold_date_sk#31, d_date_sk#32] -(39) Exchange -Input [3]: 
[cs_bill_customer_sk#26, cs_item_sk#27, cs_net_profit#28] -Arguments: hashpartitioning(cs_item_sk#27, 5), ENSURE_REQUIREMENTS, [id=#31] +(42) Exchange +Input [3]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_net_profit#30] +Arguments: hashpartitioning(cs_bill_customer_sk#28, cs_item_sk#29, 5), ENSURE_REQUIREMENTS, [id=#33] -(40) Sort [codegen id : 14] -Input [3]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_net_profit#28] -Arguments: [cs_bill_customer_sk#26 ASC NULLS FIRST, cs_item_sk#27 ASC NULLS FIRST], false, 0 +(43) Sort [codegen id : 16] +Input [3]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_net_profit#30] +Arguments: [cs_bill_customer_sk#28 ASC NULLS FIRST, cs_item_sk#29 ASC NULLS FIRST], false, 0 -(41) SortMergeJoin [codegen id : 15] -Left keys [2]: [sr_customer_sk#19, sr_item_sk#18] -Right keys [2]: [cs_bill_customer_sk#26, cs_item_sk#27] +(44) SortMergeJoin [codegen id : 17] +Left keys [2]: [sr_customer_sk#20, sr_item_sk#19] +Right keys [2]: [cs_bill_customer_sk#28, cs_item_sk#29] Join condition: None -(42) Project [codegen id : 15] -Output [7]: [ss_net_profit#5, sr_net_loss#21, cs_net_profit#28, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] -Input [11]: [ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#18, sr_customer_sk#19, sr_net_loss#21, cs_bill_customer_sk#26, cs_item_sk#27, cs_net_profit#28] +(45) Project [codegen id : 17] +Output [7]: [ss_net_profit#5, sr_net_loss#22, cs_net_profit#30, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] +Input [11]: [ss_net_profit#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#19, sr_customer_sk#20, sr_net_loss#22, cs_bill_customer_sk#28, cs_item_sk#29, cs_net_profit#30] -(43) HashAggregate [codegen id : 15] -Input [7]: [ss_net_profit#5, sr_net_loss#21, cs_net_profit#28, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] +(46) HashAggregate [codegen id : 17] +Input [7]: [ss_net_profit#5, sr_net_loss#22, cs_net_profit#30, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] Keys [4]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11] -Functions [3]: [partial_sum(UnscaledValue(ss_net_profit#5)), partial_sum(UnscaledValue(sr_net_loss#21)), partial_sum(UnscaledValue(cs_net_profit#28))] -Aggregate Attributes [3]: [sum#32, sum#33, sum#34] -Results [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#35, sum#36, sum#37] +Functions [3]: [partial_sum(UnscaledValue(ss_net_profit#5)), partial_sum(UnscaledValue(sr_net_loss#22)), partial_sum(UnscaledValue(cs_net_profit#30))] +Aggregate Attributes [3]: [sum#34, sum#35, sum#36] +Results [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#37, sum#38, sum#39] -(44) Exchange -Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#35, sum#36, sum#37] -Arguments: hashpartitioning(i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, 5), ENSURE_REQUIREMENTS, [id=#38] +(47) Exchange +Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#37, sum#38, sum#39] +Arguments: hashpartitioning(i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, 5), ENSURE_REQUIREMENTS, [id=#40] -(45) HashAggregate [codegen id : 16] -Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#35, sum#36, sum#37] +(48) HashAggregate [codegen id : 18] +Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#37, sum#38, sum#39] Keys [4]: [i_item_id#15, i_item_desc#16, 
s_store_id#10, s_store_name#11] -Functions [3]: [sum(UnscaledValue(ss_net_profit#5)), sum(UnscaledValue(sr_net_loss#21)), sum(UnscaledValue(cs_net_profit#28))] -Aggregate Attributes [3]: [sum(UnscaledValue(ss_net_profit#5))#39, sum(UnscaledValue(sr_net_loss#21))#40, sum(UnscaledValue(cs_net_profit#28))#41] -Results [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, MakeDecimal(sum(UnscaledValue(ss_net_profit#5))#39,17,2) AS store_sales_profit#42, MakeDecimal(sum(UnscaledValue(sr_net_loss#21))#40,17,2) AS store_returns_loss#43, MakeDecimal(sum(UnscaledValue(cs_net_profit#28))#41,17,2) AS catalog_sales_profit#44] +Functions [3]: [sum(UnscaledValue(ss_net_profit#5)), sum(UnscaledValue(sr_net_loss#22)), sum(UnscaledValue(cs_net_profit#30))] +Aggregate Attributes [3]: [sum(UnscaledValue(ss_net_profit#5))#41, sum(UnscaledValue(sr_net_loss#22))#42, sum(UnscaledValue(cs_net_profit#30))#43] +Results [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, MakeDecimal(sum(UnscaledValue(ss_net_profit#5))#41,17,2) AS store_sales_profit#44, MakeDecimal(sum(UnscaledValue(sr_net_loss#22))#42,17,2) AS store_returns_loss#45, MakeDecimal(sum(UnscaledValue(cs_net_profit#30))#43,17,2) AS catalog_sales_profit#46] -(46) TakeOrderedAndProject -Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, store_sales_profit#42, store_returns_loss#43, catalog_sales_profit#44] -Arguments: 100, [i_item_id#15 ASC NULLS FIRST, i_item_desc#16 ASC NULLS FIRST, s_store_id#10 ASC NULLS FIRST, s_store_name#11 ASC NULLS FIRST], [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, store_sales_profit#42, store_returns_loss#43, catalog_sales_profit#44] +(49) TakeOrderedAndProject +Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, store_sales_profit#44, store_returns_loss#45, catalog_sales_profit#46] +Arguments: 100, [i_item_id#15 ASC NULLS FIRST, i_item_desc#16 ASC NULLS FIRST, s_store_id#10 ASC NULLS FIRST, s_store_name#11 ASC NULLS FIRST], [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, store_sales_profit#44, store_returns_loss#45, catalog_sales_profit#46] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#6 IN dynamicpruning#7 -BroadcastExchange (51) -+- * Project (50) - +- * Filter (49) - +- * ColumnarToRow (48) - +- Scan parquet default.date_dim (47) +BroadcastExchange (54) ++- * Project (53) + +- * Filter (52) + +- * ColumnarToRow (51) + +- Scan parquet default.date_dim (50) -(47) Scan parquet default.date_dim -Output [3]: [d_date_sk#8, d_year#45, d_moy#46] +(50) Scan parquet default.date_dim +Output [3]: [d_date_sk#8, d_year#47, d_moy#48] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_moy), IsNotNull(d_year), EqualTo(d_moy,4), EqualTo(d_year,2001), IsNotNull(d_date_sk)] ReadSchema: struct -(48) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#8, d_year#45, d_moy#46] +(51) ColumnarToRow [codegen id : 1] +Input [3]: [d_date_sk#8, d_year#47, d_moy#48] -(49) Filter [codegen id : 1] -Input [3]: [d_date_sk#8, d_year#45, d_moy#46] -Condition : ((((isnotnull(d_moy#46) AND isnotnull(d_year#45)) AND (d_moy#46 = 4)) AND (d_year#45 = 2001)) AND isnotnull(d_date_sk#8)) +(52) Filter [codegen id : 1] +Input [3]: [d_date_sk#8, d_year#47, d_moy#48] +Condition : ((((isnotnull(d_moy#48) AND isnotnull(d_year#47)) AND (d_moy#48 = 4)) AND (d_year#47 = 2001)) AND isnotnull(d_date_sk#8)) -(50) Project [codegen id : 1] +(53) Project [codegen id : 1] 
Output [1]: [d_date_sk#8] -Input [3]: [d_date_sk#8, d_year#45, d_moy#46] +Input [3]: [d_date_sk#8, d_year#47, d_moy#48] -(51) BroadcastExchange +(54) BroadcastExchange Input [1]: [d_date_sk#8] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#47] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#49] -Subquery:2 Hosting operator id = 23 Hosting Expression = sr_returned_date_sk#22 IN dynamicpruning#23 -BroadcastExchange (56) -+- * Project (55) - +- * Filter (54) - +- * ColumnarToRow (53) - +- Scan parquet default.date_dim (52) +Subquery:2 Hosting operator id = 24 Hosting Expression = sr_returned_date_sk#23 IN dynamicpruning#24 +BroadcastExchange (59) ++- * Project (58) + +- * Filter (57) + +- * ColumnarToRow (56) + +- Scan parquet default.date_dim (55) -(52) Scan parquet default.date_dim -Output [3]: [d_date_sk#24, d_year#48, d_moy#49] +(55) Scan parquet default.date_dim +Output [3]: [d_date_sk#25, d_year#50, d_moy#51] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_moy), IsNotNull(d_year), GreaterThanOrEqual(d_moy,4), LessThanOrEqual(d_moy,10), EqualTo(d_year,2001), IsNotNull(d_date_sk)] ReadSchema: struct -(53) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#24, d_year#48, d_moy#49] +(56) ColumnarToRow [codegen id : 1] +Input [3]: [d_date_sk#25, d_year#50, d_moy#51] -(54) Filter [codegen id : 1] -Input [3]: [d_date_sk#24, d_year#48, d_moy#49] -Condition : (((((isnotnull(d_moy#49) AND isnotnull(d_year#48)) AND (d_moy#49 >= 4)) AND (d_moy#49 <= 10)) AND (d_year#48 = 2001)) AND isnotnull(d_date_sk#24)) +(57) Filter [codegen id : 1] +Input [3]: [d_date_sk#25, d_year#50, d_moy#51] +Condition : (((((isnotnull(d_moy#51) AND isnotnull(d_year#50)) AND (d_moy#51 >= 4)) AND (d_moy#51 <= 10)) AND (d_year#50 = 2001)) AND isnotnull(d_date_sk#25)) -(55) Project [codegen id : 1] -Output [1]: [d_date_sk#24] -Input [3]: [d_date_sk#24, d_year#48, d_moy#49] +(58) Project [codegen id : 1] +Output [1]: [d_date_sk#25] +Input [3]: [d_date_sk#25, d_year#50, d_moy#51] -(56) BroadcastExchange -Input [1]: [d_date_sk#24] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#50] +(59) BroadcastExchange +Input [1]: [d_date_sk#25] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#52] -Subquery:3 Hosting operator id = 33 Hosting Expression = cs_sold_date_sk#29 IN dynamicpruning#23 +Subquery:3 Hosting operator id = 36 Hosting Expression = cs_sold_date_sk#31 IN dynamicpruning#24 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25.sf100/simplified.txt index 0b106ced5504d..23d7e84027b2e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25.sf100/simplified.txt @@ -1,90 +1,97 @@ TakeOrderedAndProject [i_item_id,i_item_desc,s_store_id,s_store_name,store_sales_profit,store_returns_loss,catalog_sales_profit] - WholeStageCodegen (16) + WholeStageCodegen (18) HashAggregate [i_item_id,i_item_desc,s_store_id,s_store_name,sum,sum,sum] [sum(UnscaledValue(ss_net_profit)),sum(UnscaledValue(sr_net_loss)),sum(UnscaledValue(cs_net_profit)),store_sales_profit,store_returns_loss,catalog_sales_profit,sum,sum,sum] InputAdapter Exchange 
[i_item_id,i_item_desc,s_store_id,s_store_name] #1 - WholeStageCodegen (15) + WholeStageCodegen (17) HashAggregate [i_item_id,i_item_desc,s_store_id,s_store_name,ss_net_profit,sr_net_loss,cs_net_profit] [sum,sum,sum,sum,sum,sum] Project [ss_net_profit,sr_net_loss,cs_net_profit,s_store_id,s_store_name,i_item_id,i_item_desc] SortMergeJoin [sr_customer_sk,sr_item_sk,cs_bill_customer_sk,cs_item_sk] InputAdapter - WholeStageCodegen (11) - Project [ss_net_profit,s_store_id,s_store_name,i_item_id,i_item_desc,sr_item_sk,sr_customer_sk,sr_net_loss] - SortMergeJoin [ss_customer_sk,ss_item_sk,ss_ticket_number,sr_customer_sk,sr_item_sk,sr_ticket_number] - InputAdapter - WholeStageCodegen (7) - Sort [ss_customer_sk,ss_item_sk,ss_ticket_number] - Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_net_profit,s_store_id,s_store_name,i_item_id,i_item_desc] - SortMergeJoin [ss_item_sk,i_item_sk] - InputAdapter - WholeStageCodegen (4) - Sort [ss_item_sk] - InputAdapter - Exchange [ss_item_sk] #2 - WholeStageCodegen (3) - Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_net_profit,s_store_id,s_store_name] - BroadcastHashJoin [ss_store_sk,s_store_sk] - Project [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_net_profit] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_customer_sk,ss_item_sk,ss_ticket_number,ss_store_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_net_profit,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #3 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_moy,d_year,d_date_sk] + WholeStageCodegen (13) + Sort [sr_customer_sk,sr_item_sk] + InputAdapter + Exchange [sr_customer_sk,sr_item_sk] #2 + WholeStageCodegen (12) + Project [ss_net_profit,s_store_id,s_store_name,i_item_id,i_item_desc,sr_item_sk,sr_customer_sk,sr_net_loss] + SortMergeJoin [ss_customer_sk,ss_item_sk,ss_ticket_number,sr_customer_sk,sr_item_sk,sr_ticket_number] + InputAdapter + WholeStageCodegen (8) + Sort [ss_customer_sk,ss_item_sk,ss_ticket_number] + InputAdapter + Exchange [ss_customer_sk,ss_item_sk,ss_ticket_number] #3 + WholeStageCodegen (7) + Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_net_profit,s_store_id,s_store_name,i_item_id,i_item_desc] + SortMergeJoin [ss_item_sk,i_item_sk] + InputAdapter + WholeStageCodegen (4) + Sort [ss_item_sk] + InputAdapter + Exchange [ss_item_sk] #4 + WholeStageCodegen (3) + Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_net_profit,s_store_id,s_store_name] + BroadcastHashJoin [ss_store_sk,s_store_sk] + Project [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_net_profit] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_customer_sk,ss_item_sk,ss_ticket_number,ss_store_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_net_profit,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_moy,d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year,d_moy] + InputAdapter + ReusedExchange [d_date_sk] #5 + InputAdapter + BroadcastExchange #6 + WholeStageCodegen (2) + Filter [s_store_sk] ColumnarToRow InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year,d_moy] + Scan parquet default.store [s_store_sk,s_store_id,s_store_name] + InputAdapter + WholeStageCodegen (6) + Sort [i_item_sk] InputAdapter - ReusedExchange 
[d_date_sk] #3 - InputAdapter - BroadcastExchange #4 - WholeStageCodegen (2) - Filter [s_store_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store [s_store_sk,s_store_id,s_store_name] - InputAdapter - WholeStageCodegen (6) - Sort [i_item_sk] - InputAdapter - Exchange [i_item_sk] #5 - WholeStageCodegen (5) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_item_id,i_item_desc] - InputAdapter - WholeStageCodegen (10) - Sort [sr_customer_sk,sr_item_sk,sr_ticket_number] - InputAdapter - Exchange [sr_item_sk] #6 - WholeStageCodegen (9) - Project [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_net_loss] - BroadcastHashJoin [sr_returned_date_sk,d_date_sk] - Filter [sr_customer_sk,sr_item_sk,sr_ticket_number] - ColumnarToRow - InputAdapter - Scan parquet default.store_returns [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_net_loss,sr_returned_date_sk] - SubqueryBroadcast [d_date_sk] #2 - BroadcastExchange #7 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_moy,d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year,d_moy] - InputAdapter - ReusedExchange [d_date_sk] #7 + Exchange [i_item_sk] #7 + WholeStageCodegen (5) + Filter [i_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_item_id,i_item_desc] + InputAdapter + WholeStageCodegen (11) + Sort [sr_customer_sk,sr_item_sk,sr_ticket_number] + InputAdapter + Exchange [sr_customer_sk,sr_item_sk,sr_ticket_number] #8 + WholeStageCodegen (10) + Project [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_net_loss] + BroadcastHashJoin [sr_returned_date_sk,d_date_sk] + Filter [sr_customer_sk,sr_item_sk,sr_ticket_number] + ColumnarToRow + InputAdapter + Scan parquet default.store_returns [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_net_loss,sr_returned_date_sk] + SubqueryBroadcast [d_date_sk] #2 + BroadcastExchange #9 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_moy,d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year,d_moy] + InputAdapter + ReusedExchange [d_date_sk] #9 InputAdapter - WholeStageCodegen (14) + WholeStageCodegen (16) Sort [cs_bill_customer_sk,cs_item_sk] InputAdapter - Exchange [cs_item_sk] #8 - WholeStageCodegen (13) + Exchange [cs_bill_customer_sk,cs_item_sk] #10 + WholeStageCodegen (15) Project [cs_bill_customer_sk,cs_item_sk,cs_net_profit] BroadcastHashJoin [cs_sold_date_sk,d_date_sk] Filter [cs_bill_customer_sk,cs_item_sk] @@ -93,4 +100,4 @@ TakeOrderedAndProject [i_item_id,i_item_desc,s_store_id,s_store_name,store_sales Scan parquet default.catalog_sales [cs_bill_customer_sk,cs_item_sk,cs_net_profit,cs_sold_date_sk] ReusedSubquery [d_date_sk] #2 InputAdapter - ReusedExchange [d_date_sk] #7 + ReusedExchange [d_date_sk] #9 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29.sf100/explain.txt index e9857b76bc9e8..221439075d24d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29.sf100/explain.txt @@ -1,50 +1,53 @@ == Physical Plan == -TakeOrderedAndProject (46) -+- * HashAggregate (45) - +- Exchange (44) - +- * HashAggregate (43) - +- * Project (42) - +- * SortMergeJoin Inner (41) - :- * Project (32) - : +- * SortMergeJoin Inner (31) - : :- * Sort (22) - : : +- * Project (21) - : : +- * 
SortMergeJoin Inner (20) - : : :- * Sort (14) - : : : +- Exchange (13) - : : : +- * Project (12) - : : : +- * BroadcastHashJoin Inner BuildRight (11) - : : : :- * Project (6) - : : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : : :- * Filter (3) - : : : : : +- * ColumnarToRow (2) - : : : : : +- Scan parquet default.store_sales (1) - : : : : +- ReusedExchange (4) - : : : +- BroadcastExchange (10) - : : : +- * Filter (9) - : : : +- * ColumnarToRow (8) - : : : +- Scan parquet default.store (7) - : : +- * Sort (19) - : : +- Exchange (18) - : : +- * Filter (17) - : : +- * ColumnarToRow (16) - : : +- Scan parquet default.item (15) - : +- * Sort (30) - : +- Exchange (29) - : +- * Project (28) - : +- * BroadcastHashJoin Inner BuildRight (27) - : :- * Filter (25) - : : +- * ColumnarToRow (24) - : : +- Scan parquet default.store_returns (23) - : +- ReusedExchange (26) - +- * Sort (40) - +- Exchange (39) - +- * Project (38) - +- * BroadcastHashJoin Inner BuildRight (37) - :- * Filter (35) - : +- * ColumnarToRow (34) - : +- Scan parquet default.catalog_sales (33) - +- ReusedExchange (36) +TakeOrderedAndProject (49) ++- * HashAggregate (48) + +- Exchange (47) + +- * HashAggregate (46) + +- * Project (45) + +- * SortMergeJoin Inner (44) + :- * Sort (35) + : +- Exchange (34) + : +- * Project (33) + : +- * SortMergeJoin Inner (32) + : :- * Sort (23) + : : +- Exchange (22) + : : +- * Project (21) + : : +- * SortMergeJoin Inner (20) + : : :- * Sort (14) + : : : +- Exchange (13) + : : : +- * Project (12) + : : : +- * BroadcastHashJoin Inner BuildRight (11) + : : : :- * Project (6) + : : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : : :- * Filter (3) + : : : : : +- * ColumnarToRow (2) + : : : : : +- Scan parquet default.store_sales (1) + : : : : +- ReusedExchange (4) + : : : +- BroadcastExchange (10) + : : : +- * Filter (9) + : : : +- * ColumnarToRow (8) + : : : +- Scan parquet default.store (7) + : : +- * Sort (19) + : : +- Exchange (18) + : : +- * Filter (17) + : : +- * ColumnarToRow (16) + : : +- Scan parquet default.item (15) + : +- * Sort (31) + : +- Exchange (30) + : +- * Project (29) + : +- * BroadcastHashJoin Inner BuildRight (28) + : :- * Filter (26) + : : +- * ColumnarToRow (25) + : : +- Scan parquet default.store_returns (24) + : +- ReusedExchange (27) + +- * Sort (43) + +- Exchange (42) + +- * Project (41) + +- * BroadcastHashJoin Inner BuildRight (40) + :- * Filter (38) + : +- * ColumnarToRow (37) + : +- Scan parquet default.catalog_sales (36) + +- ReusedExchange (39) (1) Scan parquet default.store_sales @@ -62,7 +65,7 @@ Input [6]: [ss_item_sk#1, ss_customer_sk#2, ss_store_sk#3, ss_ticket_number#4, s Input [6]: [ss_item_sk#1, ss_customer_sk#2, ss_store_sk#3, ss_ticket_number#4, ss_quantity#5, ss_sold_date_sk#6] Condition : (((isnotnull(ss_customer_sk#2) AND isnotnull(ss_item_sk#1)) AND isnotnull(ss_ticket_number#4)) AND isnotnull(ss_store_sk#3)) -(4) ReusedExchange [Reuses operator id: 51] +(4) ReusedExchange [Reuses operator id: 54] Output [1]: [d_date_sk#8] (5) BroadcastHashJoin [codegen id : 3] @@ -140,210 +143,222 @@ Join condition: None Output [8]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] Input [9]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_store_id#10, s_store_name#11, i_item_sk#14, i_item_id#15, i_item_desc#16] -(22) Sort [codegen id : 7] +(22) Exchange +Input [8]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_store_id#10, 
s_store_name#11, i_item_id#15, i_item_desc#16] +Arguments: hashpartitioning(ss_customer_sk#2, ss_item_sk#1, ss_ticket_number#4, 5), ENSURE_REQUIREMENTS, [id=#18] + +(23) Sort [codegen id : 8] Input [8]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] Arguments: [ss_customer_sk#2 ASC NULLS FIRST, ss_item_sk#1 ASC NULLS FIRST, ss_ticket_number#4 ASC NULLS FIRST], false, 0 -(23) Scan parquet default.store_returns -Output [5]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21, sr_returned_date_sk#22] +(24) Scan parquet default.store_returns +Output [5]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_return_quantity#22, sr_returned_date_sk#23] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(sr_returned_date_sk#22), dynamicpruningexpression(sr_returned_date_sk#22 IN dynamicpruning#23)] +PartitionFilters: [isnotnull(sr_returned_date_sk#23), dynamicpruningexpression(sr_returned_date_sk#23 IN dynamicpruning#24)] PushedFilters: [IsNotNull(sr_customer_sk), IsNotNull(sr_item_sk), IsNotNull(sr_ticket_number)] ReadSchema: struct -(24) ColumnarToRow [codegen id : 9] -Input [5]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21, sr_returned_date_sk#22] +(25) ColumnarToRow [codegen id : 10] +Input [5]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_return_quantity#22, sr_returned_date_sk#23] -(25) Filter [codegen id : 9] -Input [5]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21, sr_returned_date_sk#22] -Condition : ((isnotnull(sr_customer_sk#19) AND isnotnull(sr_item_sk#18)) AND isnotnull(sr_ticket_number#20)) +(26) Filter [codegen id : 10] +Input [5]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_return_quantity#22, sr_returned_date_sk#23] +Condition : ((isnotnull(sr_customer_sk#20) AND isnotnull(sr_item_sk#19)) AND isnotnull(sr_ticket_number#21)) -(26) ReusedExchange [Reuses operator id: 56] -Output [1]: [d_date_sk#24] +(27) ReusedExchange [Reuses operator id: 59] +Output [1]: [d_date_sk#25] -(27) BroadcastHashJoin [codegen id : 9] -Left keys [1]: [sr_returned_date_sk#22] -Right keys [1]: [d_date_sk#24] +(28) BroadcastHashJoin [codegen id : 10] +Left keys [1]: [sr_returned_date_sk#23] +Right keys [1]: [d_date_sk#25] Join condition: None -(28) Project [codegen id : 9] -Output [4]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21] -Input [6]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21, sr_returned_date_sk#22, d_date_sk#24] +(29) Project [codegen id : 10] +Output [4]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_return_quantity#22] +Input [6]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_return_quantity#22, sr_returned_date_sk#23, d_date_sk#25] -(29) Exchange -Input [4]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21] -Arguments: hashpartitioning(sr_item_sk#18, 5), ENSURE_REQUIREMENTS, [id=#25] +(30) Exchange +Input [4]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_return_quantity#22] +Arguments: hashpartitioning(sr_customer_sk#20, sr_item_sk#19, sr_ticket_number#21, 5), ENSURE_REQUIREMENTS, [id=#26] -(30) Sort [codegen id : 10] -Input [4]: [sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21] -Arguments: [sr_customer_sk#19 ASC NULLS FIRST, sr_item_sk#18 ASC NULLS FIRST, sr_ticket_number#20 ASC NULLS FIRST], false, 0 
+(31) Sort [codegen id : 11] +Input [4]: [sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_return_quantity#22] +Arguments: [sr_customer_sk#20 ASC NULLS FIRST, sr_item_sk#19 ASC NULLS FIRST, sr_ticket_number#21 ASC NULLS FIRST], false, 0 -(31) SortMergeJoin [codegen id : 11] +(32) SortMergeJoin [codegen id : 12] Left keys [3]: [ss_customer_sk#2, ss_item_sk#1, ss_ticket_number#4] -Right keys [3]: [sr_customer_sk#19, sr_item_sk#18, sr_ticket_number#20] +Right keys [3]: [sr_customer_sk#20, sr_item_sk#19, sr_ticket_number#21] Join condition: None -(32) Project [codegen id : 11] -Output [8]: [ss_quantity#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#18, sr_customer_sk#19, sr_return_quantity#21] -Input [12]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#18, sr_customer_sk#19, sr_ticket_number#20, sr_return_quantity#21] +(33) Project [codegen id : 12] +Output [8]: [ss_quantity#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#19, sr_customer_sk#20, sr_return_quantity#22] +Input [12]: [ss_item_sk#1, ss_customer_sk#2, ss_ticket_number#4, ss_quantity#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#19, sr_customer_sk#20, sr_ticket_number#21, sr_return_quantity#22] + +(34) Exchange +Input [8]: [ss_quantity#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#19, sr_customer_sk#20, sr_return_quantity#22] +Arguments: hashpartitioning(sr_customer_sk#20, sr_item_sk#19, 5), ENSURE_REQUIREMENTS, [id=#27] -(33) Scan parquet default.catalog_sales -Output [4]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_quantity#28, cs_sold_date_sk#29] +(35) Sort [codegen id : 13] +Input [8]: [ss_quantity#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#19, sr_customer_sk#20, sr_return_quantity#22] +Arguments: [sr_customer_sk#20 ASC NULLS FIRST, sr_item_sk#19 ASC NULLS FIRST], false, 0 + +(36) Scan parquet default.catalog_sales +Output [4]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_quantity#30, cs_sold_date_sk#31] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(cs_sold_date_sk#29), dynamicpruningexpression(cs_sold_date_sk#29 IN dynamicpruning#30)] +PartitionFilters: [isnotnull(cs_sold_date_sk#31), dynamicpruningexpression(cs_sold_date_sk#31 IN dynamicpruning#32)] PushedFilters: [IsNotNull(cs_bill_customer_sk), IsNotNull(cs_item_sk)] ReadSchema: struct -(34) ColumnarToRow [codegen id : 13] -Input [4]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_quantity#28, cs_sold_date_sk#29] +(37) ColumnarToRow [codegen id : 15] +Input [4]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_quantity#30, cs_sold_date_sk#31] -(35) Filter [codegen id : 13] -Input [4]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_quantity#28, cs_sold_date_sk#29] -Condition : (isnotnull(cs_bill_customer_sk#26) AND isnotnull(cs_item_sk#27)) +(38) Filter [codegen id : 15] +Input [4]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_quantity#30, cs_sold_date_sk#31] +Condition : (isnotnull(cs_bill_customer_sk#28) AND isnotnull(cs_item_sk#29)) -(36) ReusedExchange [Reuses operator id: 61] -Output [1]: [d_date_sk#31] +(39) ReusedExchange [Reuses operator id: 64] +Output [1]: [d_date_sk#33] -(37) BroadcastHashJoin [codegen id : 13] -Left keys [1]: [cs_sold_date_sk#29] -Right keys [1]: [d_date_sk#31] +(40) BroadcastHashJoin [codegen id : 15] +Left keys [1]: [cs_sold_date_sk#31] +Right keys [1]: [d_date_sk#33] Join condition: None 
-(38) Project [codegen id : 13] -Output [3]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_quantity#28] -Input [5]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_quantity#28, cs_sold_date_sk#29, d_date_sk#31] +(41) Project [codegen id : 15] +Output [3]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_quantity#30] +Input [5]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_quantity#30, cs_sold_date_sk#31, d_date_sk#33] -(39) Exchange -Input [3]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_quantity#28] -Arguments: hashpartitioning(cs_item_sk#27, 5), ENSURE_REQUIREMENTS, [id=#32] +(42) Exchange +Input [3]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_quantity#30] +Arguments: hashpartitioning(cs_bill_customer_sk#28, cs_item_sk#29, 5), ENSURE_REQUIREMENTS, [id=#34] -(40) Sort [codegen id : 14] -Input [3]: [cs_bill_customer_sk#26, cs_item_sk#27, cs_quantity#28] -Arguments: [cs_bill_customer_sk#26 ASC NULLS FIRST, cs_item_sk#27 ASC NULLS FIRST], false, 0 +(43) Sort [codegen id : 16] +Input [3]: [cs_bill_customer_sk#28, cs_item_sk#29, cs_quantity#30] +Arguments: [cs_bill_customer_sk#28 ASC NULLS FIRST, cs_item_sk#29 ASC NULLS FIRST], false, 0 -(41) SortMergeJoin [codegen id : 15] -Left keys [2]: [sr_customer_sk#19, sr_item_sk#18] -Right keys [2]: [cs_bill_customer_sk#26, cs_item_sk#27] +(44) SortMergeJoin [codegen id : 17] +Left keys [2]: [sr_customer_sk#20, sr_item_sk#19] +Right keys [2]: [cs_bill_customer_sk#28, cs_item_sk#29] Join condition: None -(42) Project [codegen id : 15] -Output [7]: [ss_quantity#5, sr_return_quantity#21, cs_quantity#28, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] -Input [11]: [ss_quantity#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#18, sr_customer_sk#19, sr_return_quantity#21, cs_bill_customer_sk#26, cs_item_sk#27, cs_quantity#28] +(45) Project [codegen id : 17] +Output [7]: [ss_quantity#5, sr_return_quantity#22, cs_quantity#30, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] +Input [11]: [ss_quantity#5, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16, sr_item_sk#19, sr_customer_sk#20, sr_return_quantity#22, cs_bill_customer_sk#28, cs_item_sk#29, cs_quantity#30] -(43) HashAggregate [codegen id : 15] -Input [7]: [ss_quantity#5, sr_return_quantity#21, cs_quantity#28, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] +(46) HashAggregate [codegen id : 17] +Input [7]: [ss_quantity#5, sr_return_quantity#22, cs_quantity#30, s_store_id#10, s_store_name#11, i_item_id#15, i_item_desc#16] Keys [4]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11] -Functions [3]: [partial_sum(ss_quantity#5), partial_sum(sr_return_quantity#21), partial_sum(cs_quantity#28)] -Aggregate Attributes [3]: [sum#33, sum#34, sum#35] -Results [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#36, sum#37, sum#38] +Functions [3]: [partial_sum(ss_quantity#5), partial_sum(sr_return_quantity#22), partial_sum(cs_quantity#30)] +Aggregate Attributes [3]: [sum#35, sum#36, sum#37] +Results [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#38, sum#39, sum#40] -(44) Exchange -Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#36, sum#37, sum#38] -Arguments: hashpartitioning(i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, 5), ENSURE_REQUIREMENTS, [id=#39] +(47) Exchange +Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#38, sum#39, sum#40] +Arguments: hashpartitioning(i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, 
5), ENSURE_REQUIREMENTS, [id=#41] -(45) HashAggregate [codegen id : 16] -Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#36, sum#37, sum#38] +(48) HashAggregate [codegen id : 18] +Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum#38, sum#39, sum#40] Keys [4]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11] -Functions [3]: [sum(ss_quantity#5), sum(sr_return_quantity#21), sum(cs_quantity#28)] -Aggregate Attributes [3]: [sum(ss_quantity#5)#40, sum(sr_return_quantity#21)#41, sum(cs_quantity#28)#42] -Results [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum(ss_quantity#5)#40 AS store_sales_quantity#43, sum(sr_return_quantity#21)#41 AS store_returns_quantity#44, sum(cs_quantity#28)#42 AS catalog_sales_quantity#45] +Functions [3]: [sum(ss_quantity#5), sum(sr_return_quantity#22), sum(cs_quantity#30)] +Aggregate Attributes [3]: [sum(ss_quantity#5)#42, sum(sr_return_quantity#22)#43, sum(cs_quantity#30)#44] +Results [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, sum(ss_quantity#5)#42 AS store_sales_quantity#45, sum(sr_return_quantity#22)#43 AS store_returns_quantity#46, sum(cs_quantity#30)#44 AS catalog_sales_quantity#47] -(46) TakeOrderedAndProject -Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, store_sales_quantity#43, store_returns_quantity#44, catalog_sales_quantity#45] -Arguments: 100, [i_item_id#15 ASC NULLS FIRST, i_item_desc#16 ASC NULLS FIRST, s_store_id#10 ASC NULLS FIRST, s_store_name#11 ASC NULLS FIRST], [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, store_sales_quantity#43, store_returns_quantity#44, catalog_sales_quantity#45] +(49) TakeOrderedAndProject +Input [7]: [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, store_sales_quantity#45, store_returns_quantity#46, catalog_sales_quantity#47] +Arguments: 100, [i_item_id#15 ASC NULLS FIRST, i_item_desc#16 ASC NULLS FIRST, s_store_id#10 ASC NULLS FIRST, s_store_name#11 ASC NULLS FIRST], [i_item_id#15, i_item_desc#16, s_store_id#10, s_store_name#11, store_sales_quantity#45, store_returns_quantity#46, catalog_sales_quantity#47] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#6 IN dynamicpruning#7 -BroadcastExchange (51) -+- * Project (50) - +- * Filter (49) - +- * ColumnarToRow (48) - +- Scan parquet default.date_dim (47) +BroadcastExchange (54) ++- * Project (53) + +- * Filter (52) + +- * ColumnarToRow (51) + +- Scan parquet default.date_dim (50) -(47) Scan parquet default.date_dim -Output [3]: [d_date_sk#8, d_year#46, d_moy#47] +(50) Scan parquet default.date_dim +Output [3]: [d_date_sk#8, d_year#48, d_moy#49] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_moy), IsNotNull(d_year), EqualTo(d_moy,9), EqualTo(d_year,1999), IsNotNull(d_date_sk)] ReadSchema: struct -(48) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#8, d_year#46, d_moy#47] +(51) ColumnarToRow [codegen id : 1] +Input [3]: [d_date_sk#8, d_year#48, d_moy#49] -(49) Filter [codegen id : 1] -Input [3]: [d_date_sk#8, d_year#46, d_moy#47] -Condition : ((((isnotnull(d_moy#47) AND isnotnull(d_year#46)) AND (d_moy#47 = 9)) AND (d_year#46 = 1999)) AND isnotnull(d_date_sk#8)) +(52) Filter [codegen id : 1] +Input [3]: [d_date_sk#8, d_year#48, d_moy#49] +Condition : ((((isnotnull(d_moy#49) AND isnotnull(d_year#48)) AND (d_moy#49 = 9)) AND (d_year#48 = 1999)) AND isnotnull(d_date_sk#8)) -(50) Project [codegen id 
: 1] +(53) Project [codegen id : 1] Output [1]: [d_date_sk#8] -Input [3]: [d_date_sk#8, d_year#46, d_moy#47] +Input [3]: [d_date_sk#8, d_year#48, d_moy#49] -(51) BroadcastExchange +(54) BroadcastExchange Input [1]: [d_date_sk#8] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#48] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#50] -Subquery:2 Hosting operator id = 23 Hosting Expression = sr_returned_date_sk#22 IN dynamicpruning#23 -BroadcastExchange (56) -+- * Project (55) - +- * Filter (54) - +- * ColumnarToRow (53) - +- Scan parquet default.date_dim (52) +Subquery:2 Hosting operator id = 24 Hosting Expression = sr_returned_date_sk#23 IN dynamicpruning#24 +BroadcastExchange (59) ++- * Project (58) + +- * Filter (57) + +- * ColumnarToRow (56) + +- Scan parquet default.date_dim (55) -(52) Scan parquet default.date_dim -Output [3]: [d_date_sk#24, d_year#49, d_moy#50] +(55) Scan parquet default.date_dim +Output [3]: [d_date_sk#25, d_year#51, d_moy#52] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_moy), IsNotNull(d_year), GreaterThanOrEqual(d_moy,9), LessThanOrEqual(d_moy,12), EqualTo(d_year,1999), IsNotNull(d_date_sk)] ReadSchema: struct -(53) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#24, d_year#49, d_moy#50] +(56) ColumnarToRow [codegen id : 1] +Input [3]: [d_date_sk#25, d_year#51, d_moy#52] -(54) Filter [codegen id : 1] -Input [3]: [d_date_sk#24, d_year#49, d_moy#50] -Condition : (((((isnotnull(d_moy#50) AND isnotnull(d_year#49)) AND (d_moy#50 >= 9)) AND (d_moy#50 <= 12)) AND (d_year#49 = 1999)) AND isnotnull(d_date_sk#24)) +(57) Filter [codegen id : 1] +Input [3]: [d_date_sk#25, d_year#51, d_moy#52] +Condition : (((((isnotnull(d_moy#52) AND isnotnull(d_year#51)) AND (d_moy#52 >= 9)) AND (d_moy#52 <= 12)) AND (d_year#51 = 1999)) AND isnotnull(d_date_sk#25)) -(55) Project [codegen id : 1] -Output [1]: [d_date_sk#24] -Input [3]: [d_date_sk#24, d_year#49, d_moy#50] +(58) Project [codegen id : 1] +Output [1]: [d_date_sk#25] +Input [3]: [d_date_sk#25, d_year#51, d_moy#52] -(56) BroadcastExchange -Input [1]: [d_date_sk#24] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#51] +(59) BroadcastExchange +Input [1]: [d_date_sk#25] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#53] -Subquery:3 Hosting operator id = 33 Hosting Expression = cs_sold_date_sk#29 IN dynamicpruning#30 -BroadcastExchange (61) -+- * Project (60) - +- * Filter (59) - +- * ColumnarToRow (58) - +- Scan parquet default.date_dim (57) +Subquery:3 Hosting operator id = 36 Hosting Expression = cs_sold_date_sk#31 IN dynamicpruning#32 +BroadcastExchange (64) ++- * Project (63) + +- * Filter (62) + +- * ColumnarToRow (61) + +- Scan parquet default.date_dim (60) -(57) Scan parquet default.date_dim -Output [2]: [d_date_sk#31, d_year#52] +(60) Scan parquet default.date_dim +Output [2]: [d_date_sk#33, d_year#54] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [1999,2000,2001]), IsNotNull(d_date_sk)] ReadSchema: struct -(58) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#31, d_year#52] +(61) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#33, d_year#54] -(59) Filter [codegen id : 1] -Input [2]: [d_date_sk#31, d_year#52] -Condition : (d_year#52 IN (1999,2000,2001) AND isnotnull(d_date_sk#31)) +(62) Filter [codegen id : 1] 
+Input [2]: [d_date_sk#33, d_year#54] +Condition : (d_year#54 IN (1999,2000,2001) AND isnotnull(d_date_sk#33)) -(60) Project [codegen id : 1] -Output [1]: [d_date_sk#31] -Input [2]: [d_date_sk#31, d_year#52] +(63) Project [codegen id : 1] +Output [1]: [d_date_sk#33] +Input [2]: [d_date_sk#33, d_year#54] -(61) BroadcastExchange -Input [1]: [d_date_sk#31] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#53] +(64) BroadcastExchange +Input [1]: [d_date_sk#33] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#55] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29.sf100/simplified.txt index 0db54fe759962..5463f3f0a8fd4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29.sf100/simplified.txt @@ -1,90 +1,97 @@ TakeOrderedAndProject [i_item_id,i_item_desc,s_store_id,s_store_name,store_sales_quantity,store_returns_quantity,catalog_sales_quantity] - WholeStageCodegen (16) + WholeStageCodegen (18) HashAggregate [i_item_id,i_item_desc,s_store_id,s_store_name,sum,sum,sum] [sum(ss_quantity),sum(sr_return_quantity),sum(cs_quantity),store_sales_quantity,store_returns_quantity,catalog_sales_quantity,sum,sum,sum] InputAdapter Exchange [i_item_id,i_item_desc,s_store_id,s_store_name] #1 - WholeStageCodegen (15) + WholeStageCodegen (17) HashAggregate [i_item_id,i_item_desc,s_store_id,s_store_name,ss_quantity,sr_return_quantity,cs_quantity] [sum,sum,sum,sum,sum,sum] Project [ss_quantity,sr_return_quantity,cs_quantity,s_store_id,s_store_name,i_item_id,i_item_desc] SortMergeJoin [sr_customer_sk,sr_item_sk,cs_bill_customer_sk,cs_item_sk] InputAdapter - WholeStageCodegen (11) - Project [ss_quantity,s_store_id,s_store_name,i_item_id,i_item_desc,sr_item_sk,sr_customer_sk,sr_return_quantity] - SortMergeJoin [ss_customer_sk,ss_item_sk,ss_ticket_number,sr_customer_sk,sr_item_sk,sr_ticket_number] - InputAdapter - WholeStageCodegen (7) - Sort [ss_customer_sk,ss_item_sk,ss_ticket_number] - Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_quantity,s_store_id,s_store_name,i_item_id,i_item_desc] - SortMergeJoin [ss_item_sk,i_item_sk] - InputAdapter - WholeStageCodegen (4) - Sort [ss_item_sk] - InputAdapter - Exchange [ss_item_sk] #2 - WholeStageCodegen (3) - Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_quantity,s_store_id,s_store_name] - BroadcastHashJoin [ss_store_sk,s_store_sk] - Project [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_quantity] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_customer_sk,ss_item_sk,ss_ticket_number,ss_store_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_quantity,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #3 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_moy,d_year,d_date_sk] + WholeStageCodegen (13) + Sort [sr_customer_sk,sr_item_sk] + InputAdapter + Exchange [sr_customer_sk,sr_item_sk] #2 + WholeStageCodegen (12) + Project [ss_quantity,s_store_id,s_store_name,i_item_id,i_item_desc,sr_item_sk,sr_customer_sk,sr_return_quantity] + SortMergeJoin [ss_customer_sk,ss_item_sk,ss_ticket_number,sr_customer_sk,sr_item_sk,sr_ticket_number] + InputAdapter + WholeStageCodegen (8) + Sort 
[ss_customer_sk,ss_item_sk,ss_ticket_number] + InputAdapter + Exchange [ss_customer_sk,ss_item_sk,ss_ticket_number] #3 + WholeStageCodegen (7) + Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_quantity,s_store_id,s_store_name,i_item_id,i_item_desc] + SortMergeJoin [ss_item_sk,i_item_sk] + InputAdapter + WholeStageCodegen (4) + Sort [ss_item_sk] + InputAdapter + Exchange [ss_item_sk] #4 + WholeStageCodegen (3) + Project [ss_item_sk,ss_customer_sk,ss_ticket_number,ss_quantity,s_store_id,s_store_name] + BroadcastHashJoin [ss_store_sk,s_store_sk] + Project [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_quantity] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_customer_sk,ss_item_sk,ss_ticket_number,ss_store_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_quantity,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_moy,d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year,d_moy] + InputAdapter + ReusedExchange [d_date_sk] #5 + InputAdapter + BroadcastExchange #6 + WholeStageCodegen (2) + Filter [s_store_sk] ColumnarToRow InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year,d_moy] + Scan parquet default.store [s_store_sk,s_store_id,s_store_name] + InputAdapter + WholeStageCodegen (6) + Sort [i_item_sk] InputAdapter - ReusedExchange [d_date_sk] #3 - InputAdapter - BroadcastExchange #4 - WholeStageCodegen (2) - Filter [s_store_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store [s_store_sk,s_store_id,s_store_name] - InputAdapter - WholeStageCodegen (6) - Sort [i_item_sk] - InputAdapter - Exchange [i_item_sk] #5 - WholeStageCodegen (5) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_item_id,i_item_desc] - InputAdapter - WholeStageCodegen (10) - Sort [sr_customer_sk,sr_item_sk,sr_ticket_number] - InputAdapter - Exchange [sr_item_sk] #6 - WholeStageCodegen (9) - Project [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_return_quantity] - BroadcastHashJoin [sr_returned_date_sk,d_date_sk] - Filter [sr_customer_sk,sr_item_sk,sr_ticket_number] - ColumnarToRow - InputAdapter - Scan parquet default.store_returns [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_return_quantity,sr_returned_date_sk] - SubqueryBroadcast [d_date_sk] #2 - BroadcastExchange #7 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_moy,d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year,d_moy] - InputAdapter - ReusedExchange [d_date_sk] #7 + Exchange [i_item_sk] #7 + WholeStageCodegen (5) + Filter [i_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_item_id,i_item_desc] + InputAdapter + WholeStageCodegen (11) + Sort [sr_customer_sk,sr_item_sk,sr_ticket_number] + InputAdapter + Exchange [sr_customer_sk,sr_item_sk,sr_ticket_number] #8 + WholeStageCodegen (10) + Project [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_return_quantity] + BroadcastHashJoin [sr_returned_date_sk,d_date_sk] + Filter [sr_customer_sk,sr_item_sk,sr_ticket_number] + ColumnarToRow + InputAdapter + Scan parquet default.store_returns [sr_item_sk,sr_customer_sk,sr_ticket_number,sr_return_quantity,sr_returned_date_sk] + SubqueryBroadcast [d_date_sk] #2 + BroadcastExchange #9 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_moy,d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan 
parquet default.date_dim [d_date_sk,d_year,d_moy] + InputAdapter + ReusedExchange [d_date_sk] #9 InputAdapter - WholeStageCodegen (14) + WholeStageCodegen (16) Sort [cs_bill_customer_sk,cs_item_sk] InputAdapter - Exchange [cs_item_sk] #8 - WholeStageCodegen (13) + Exchange [cs_bill_customer_sk,cs_item_sk] #10 + WholeStageCodegen (15) Project [cs_bill_customer_sk,cs_item_sk,cs_quantity] BroadcastHashJoin [cs_sold_date_sk,d_date_sk] Filter [cs_bill_customer_sk,cs_item_sk] @@ -92,7 +99,7 @@ TakeOrderedAndProject [i_item_id,i_item_desc,s_store_id,s_store_name,store_sales InputAdapter Scan parquet default.catalog_sales [cs_bill_customer_sk,cs_item_sk,cs_quantity,cs_sold_date_sk] SubqueryBroadcast [d_date_sk] #3 - BroadcastExchange #9 + BroadcastExchange #11 WholeStageCodegen (1) Project [d_date_sk] Filter [d_year,d_date_sk] @@ -100,4 +107,4 @@ TakeOrderedAndProject [i_item_id,i_item_desc,s_store_id,s_store_name,store_sales InputAdapter Scan parquet default.date_dim [d_date_sk,d_year] InputAdapter - ReusedExchange [d_date_sk] #9 + ReusedExchange [d_date_sk] #11 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q30.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q30.sf100/explain.txt index 35b9877c4fd09..b2d52de3cae98 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q30.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q30.sf100/explain.txt @@ -287,7 +287,7 @@ Input [3]: [ctr_state#34, sum#42, count#43] Keys [1]: [ctr_state#34] Functions [1]: [avg(ctr_total_return#35)] Aggregate Attributes [1]: [avg(ctr_total_return#35)#45] -Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#35)#45) * 1.200000), DecimalType(24,7), true) AS (avg(ctr_total_return) * 1.2)#46, ctr_state#34 AS ctr_state#34#47] +Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#35)#45) * 1.200000), DecimalType(24,7)) AS (avg(ctr_total_return) * 1.2)#46, ctr_state#34 AS ctr_state#34#47] (51) Filter [codegen id : 16] Input [2]: [(avg(ctr_total_return) * 1.2)#46, ctr_state#34#47] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q30/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q30/explain.txt index fdf276c01e19a..333930275bbd1 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q30/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q30/explain.txt @@ -199,7 +199,7 @@ Input [3]: [ctr_state#15, sum#22, count#23] Keys [1]: [ctr_state#15] Functions [1]: [avg(ctr_total_return#16)] Aggregate Attributes [1]: [avg(ctr_total_return#16)#25] -Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#16)#25) * 1.200000), DecimalType(24,7), true) AS (avg(ctr_total_return) * 1.2)#26, ctr_state#15 AS ctr_state#15#27] +Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#16)#25) * 1.200000), DecimalType(24,7)) AS (avg(ctr_total_return) * 1.2)#26, ctr_state#15 AS ctr_state#15#27] (32) Filter [codegen id : 8] Input [2]: [(avg(ctr_total_return) * 1.2)#26, ctr_state#15#27] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q31.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q31.sf100/explain.txt index 807506df80411..c1bff1a691dc7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q31.sf100/explain.txt +++ 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q31.sf100/explain.txt @@ -599,10 +599,10 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (107) BroadcastHashJoin [codegen id : 42] Left keys [1]: [ca_county#41] Right keys [1]: [ca_county#55] -Join condition: ((CASE WHEN (web_sales#60 > 0.00) THEN CheckOverflow((promote_precision(web_sales#73) / promote_precision(web_sales#60)), DecimalType(37,20), true) END > CASE WHEN (store_sales#45 > 0.00) THEN CheckOverflow((promote_precision(store_sales#16) / promote_precision(store_sales#45)), DecimalType(37,20), true) END) AND (CASE WHEN (web_sales#73 > 0.00) THEN CheckOverflow((promote_precision(web_sales#87) / promote_precision(web_sales#73)), DecimalType(37,20), true) END > CASE WHEN (store_sales#16 > 0.00) THEN CheckOverflow((promote_precision(store_sales#30) / promote_precision(store_sales#16)), DecimalType(37,20), true) END)) +Join condition: ((CASE WHEN (web_sales#60 > 0.00) THEN CheckOverflow((promote_precision(web_sales#73) / promote_precision(web_sales#60)), DecimalType(37,20)) END > CASE WHEN (store_sales#45 > 0.00) THEN CheckOverflow((promote_precision(store_sales#16) / promote_precision(store_sales#45)), DecimalType(37,20)) END) AND (CASE WHEN (web_sales#73 > 0.00) THEN CheckOverflow((promote_precision(web_sales#87) / promote_precision(web_sales#73)), DecimalType(37,20)) END > CASE WHEN (store_sales#16 > 0.00) THEN CheckOverflow((promote_precision(store_sales#30) / promote_precision(store_sales#16)), DecimalType(37,20)) END)) (108) Project [codegen id : 42] -Output [6]: [ca_county#41, d_year#37, CheckOverflow((promote_precision(web_sales#73) / promote_precision(web_sales#60)), DecimalType(37,20), true) AS web_q1_q2_increase#90, CheckOverflow((promote_precision(store_sales#16) / promote_precision(store_sales#45)), DecimalType(37,20), true) AS store_q1_q2_increase#91, CheckOverflow((promote_precision(web_sales#87) / promote_precision(web_sales#73)), DecimalType(37,20), true) AS web_q2_q3_increase#92, CheckOverflow((promote_precision(store_sales#30) / promote_precision(store_sales#16)), DecimalType(37,20), true) AS store_q2_q3_increase#93] +Output [6]: [ca_county#41, d_year#37, CheckOverflow((promote_precision(web_sales#73) / promote_precision(web_sales#60)), DecimalType(37,20)) AS web_q1_q2_increase#90, CheckOverflow((promote_precision(store_sales#16) / promote_precision(store_sales#45)), DecimalType(37,20)) AS store_q1_q2_increase#91, CheckOverflow((promote_precision(web_sales#87) / promote_precision(web_sales#73)), DecimalType(37,20)) AS web_q2_q3_increase#92, CheckOverflow((promote_precision(store_sales#30) / promote_precision(store_sales#16)), DecimalType(37,20)) AS store_q2_q3_increase#93] Input [9]: [store_sales#16, store_sales#30, ca_county#41, d_year#37, store_sales#45, ca_county#55, web_sales#60, web_sales#73, web_sales#87] (109) Exchange diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q31/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q31/explain.txt index 124f6d2dacf17..d5c2cc3377a7e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q31/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q31/explain.txt @@ -429,7 +429,7 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (72) BroadcastHashJoin [codegen id : 24] Left keys [1]: [ca_county#51] Right keys [1]: [ca_county#65] -Join condition: (CASE WHEN 
(web_sales#56 > 0.00) THEN CheckOverflow((promote_precision(web_sales#69) / promote_precision(web_sales#56)), DecimalType(37,20), true) END > CASE WHEN (store_sales#15 > 0.00) THEN CheckOverflow((promote_precision(store_sales#28) / promote_precision(store_sales#15)), DecimalType(37,20), true) END) +Join condition: (CASE WHEN (web_sales#56 > 0.00) THEN CheckOverflow((promote_precision(web_sales#69) / promote_precision(web_sales#56)), DecimalType(37,20)) END > CASE WHEN (store_sales#15 > 0.00) THEN CheckOverflow((promote_precision(store_sales#28) / promote_precision(store_sales#15)), DecimalType(37,20)) END) (73) Project [codegen id : 24] Output [8]: [ca_county#9, d_year#6, store_sales#15, store_sales#28, store_sales#42, ca_county#51, web_sales#56, web_sales#69] @@ -499,10 +499,10 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (87) BroadcastHashJoin [codegen id : 24] Left keys [1]: [ca_county#51] Right keys [1]: [ca_county#78] -Join condition: (CASE WHEN (web_sales#69 > 0.00) THEN CheckOverflow((promote_precision(web_sales#82) / promote_precision(web_sales#69)), DecimalType(37,20), true) END > CASE WHEN (store_sales#28 > 0.00) THEN CheckOverflow((promote_precision(store_sales#42) / promote_precision(store_sales#28)), DecimalType(37,20), true) END) +Join condition: (CASE WHEN (web_sales#69 > 0.00) THEN CheckOverflow((promote_precision(web_sales#82) / promote_precision(web_sales#69)), DecimalType(37,20)) END > CASE WHEN (store_sales#28 > 0.00) THEN CheckOverflow((promote_precision(store_sales#42) / promote_precision(store_sales#28)), DecimalType(37,20)) END) (88) Project [codegen id : 24] -Output [6]: [ca_county#9, d_year#6, CheckOverflow((promote_precision(web_sales#69) / promote_precision(web_sales#56)), DecimalType(37,20), true) AS web_q1_q2_increase#84, CheckOverflow((promote_precision(store_sales#28) / promote_precision(store_sales#15)), DecimalType(37,20), true) AS store_q1_q2_increase#85, CheckOverflow((promote_precision(web_sales#82) / promote_precision(web_sales#69)), DecimalType(37,20), true) AS web_q2_q3_increase#86, CheckOverflow((promote_precision(store_sales#42) / promote_precision(store_sales#28)), DecimalType(37,20), true) AS store_q2_q3_increase#87] +Output [6]: [ca_county#9, d_year#6, CheckOverflow((promote_precision(web_sales#69) / promote_precision(web_sales#56)), DecimalType(37,20)) AS web_q1_q2_increase#84, CheckOverflow((promote_precision(store_sales#28) / promote_precision(store_sales#15)), DecimalType(37,20)) AS store_q1_q2_increase#85, CheckOverflow((promote_precision(web_sales#82) / promote_precision(web_sales#69)), DecimalType(37,20)) AS web_q2_q3_increase#86, CheckOverflow((promote_precision(store_sales#42) / promote_precision(store_sales#28)), DecimalType(37,20)) AS store_q2_q3_increase#87] Input [10]: [ca_county#9, d_year#6, store_sales#15, store_sales#28, store_sales#42, ca_county#51, web_sales#56, web_sales#69, ca_county#78, web_sales#82] (89) Exchange diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q32.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q32.sf100/explain.txt index 1ace9e7f294aa..92ba279df59fe 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q32.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q32.sf100/explain.txt @@ -93,7 +93,7 @@ Input [3]: [cs_item_sk#4, sum#11, count#12] Keys [1]: [cs_item_sk#4] Functions [1]: [avg(UnscaledValue(cs_ext_discount_amt#5))] 
Aggregate Attributes [1]: [avg(UnscaledValue(cs_ext_discount_amt#5))#14] -Results [2]: [CheckOverflow((1.300000 * promote_precision(cast((avg(UnscaledValue(cs_ext_discount_amt#5))#14 / 100.0) as decimal(11,6)))), DecimalType(14,7), true) AS (1.3 * avg(cs_ext_discount_amt))#15, cs_item_sk#4] +Results [2]: [CheckOverflow((1.300000 * promote_precision(cast((avg(UnscaledValue(cs_ext_discount_amt#5))#14 / 100.0) as decimal(11,6)))), DecimalType(14,7)) AS (1.3 * avg(cs_ext_discount_amt))#15, cs_item_sk#4] (15) Filter Input [2]: [(1.3 * avg(cs_ext_discount_amt))#15, cs_item_sk#4] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q32/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q32/explain.txt index f6c9b9ed7dcef..e221defe867c1 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q32/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q32/explain.txt @@ -117,7 +117,7 @@ Input [3]: [cs_item_sk#8, sum#14, count#15] Keys [1]: [cs_item_sk#8] Functions [1]: [avg(UnscaledValue(cs_ext_discount_amt#9))] Aggregate Attributes [1]: [avg(UnscaledValue(cs_ext_discount_amt#9))#17] -Results [2]: [CheckOverflow((1.300000 * promote_precision(cast((avg(UnscaledValue(cs_ext_discount_amt#9))#17 / 100.0) as decimal(11,6)))), DecimalType(14,7), true) AS (1.3 * avg(cs_ext_discount_amt))#18, cs_item_sk#8] +Results [2]: [CheckOverflow((1.300000 * promote_precision(cast((avg(UnscaledValue(cs_ext_discount_amt#9))#17 / 100.0) as decimal(11,6)))), DecimalType(14,7)) AS (1.3 * avg(cs_ext_discount_amt))#18, cs_item_sk#8] (20) Filter [codegen id : 4] Input [2]: [(1.3 * avg(cs_ext_discount_amt))#18, cs_item_sk#8] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt index 6924f13d615bf..81050cfbb4475 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt @@ -134,7 +134,7 @@ Input [5]: [i_category#15, i_class#16, spark_grouping_id#17, sum#20, sum#21] Keys [3]: [i_category#15, i_class#16, spark_grouping_id#17] Functions [2]: [sum(UnscaledValue(ss_net_profit#4)), sum(UnscaledValue(ss_ext_sales_price#3))] Aggregate Attributes [2]: [sum(UnscaledValue(ss_net_profit#4))#23, sum(UnscaledValue(ss_ext_sales_price#3))#24] -Results [7]: [CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#23,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#24,17,2))), DecimalType(37,20), true) AS gross_margin#25, i_category#15, i_class#16, (cast((shiftright(spark_grouping_id#17, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint)) AS lochierarchy#26, (cast((shiftright(spark_grouping_id#17, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint)) AS _w1#27, CASE WHEN (cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint) = 0) THEN i_category#15 END AS _w2#28, CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#23,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#24,17,2))), DecimalType(37,20), true) AS _w3#29] +Results [7]: [CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#23,17,2)) / 
promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#24,17,2))), DecimalType(37,20)) AS gross_margin#25, i_category#15, i_class#16, (cast((shiftright(spark_grouping_id#17, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint)) AS lochierarchy#26, (cast((shiftright(spark_grouping_id#17, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint)) AS _w1#27, CASE WHEN (cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint) = 0) THEN i_category#15 END AS _w2#28, CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#23,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#24,17,2))), DecimalType(37,20)) AS _w3#29] (24) Exchange Input [7]: [gross_margin#25, i_category#15, i_class#16, lochierarchy#26, _w1#27, _w2#28, _w3#29] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt index a9cad5df37b9b..7ef898a59a2c1 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt @@ -134,7 +134,7 @@ Input [5]: [i_category#15, i_class#16, spark_grouping_id#17, sum#20, sum#21] Keys [3]: [i_category#15, i_class#16, spark_grouping_id#17] Functions [2]: [sum(UnscaledValue(ss_net_profit#4)), sum(UnscaledValue(ss_ext_sales_price#3))] Aggregate Attributes [2]: [sum(UnscaledValue(ss_net_profit#4))#23, sum(UnscaledValue(ss_ext_sales_price#3))#24] -Results [7]: [CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#23,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#24,17,2))), DecimalType(37,20), true) AS gross_margin#25, i_category#15, i_class#16, (cast((shiftright(spark_grouping_id#17, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint)) AS lochierarchy#26, (cast((shiftright(spark_grouping_id#17, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint)) AS _w1#27, CASE WHEN (cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint) = 0) THEN i_category#15 END AS _w2#28, CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#23,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#24,17,2))), DecimalType(37,20), true) AS _w3#29] +Results [7]: [CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#23,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#24,17,2))), DecimalType(37,20)) AS gross_margin#25, i_category#15, i_class#16, (cast((shiftright(spark_grouping_id#17, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint)) AS lochierarchy#26, (cast((shiftright(spark_grouping_id#17, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint)) AS _w1#27, CASE WHEN (cast((shiftright(spark_grouping_id#17, 0) & 1) as tinyint) = 0) THEN i_category#15 END AS _w2#28, CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#23,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#24,17,2))), DecimalType(37,20)) AS _w3#29] (24) Exchange Input [7]: [gross_margin#25, i_category#15, i_class#16, lochierarchy#26, _w1#27, _w2#28, _w3#29] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38.sf100/explain.txt 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38.sf100/explain.txt index 6011410caced0..3d266ee2c01c7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38.sf100/explain.txt @@ -1,71 +1,64 @@ == Physical Plan == -* HashAggregate (67) -+- Exchange (66) - +- * HashAggregate (65) - +- * HashAggregate (64) - +- Exchange (63) - +- * HashAggregate (62) - +- * SortMergeJoin LeftSemi (61) - :- * Sort (43) - : +- Exchange (42) - : +- * HashAggregate (41) - : +- Exchange (40) - : +- * HashAggregate (39) - : +- * SortMergeJoin LeftSemi (38) - : :- * Sort (20) - : : +- Exchange (19) - : : +- * HashAggregate (18) - : : +- Exchange (17) - : : +- * HashAggregate (16) - : : +- * Project (15) - : : +- * SortMergeJoin Inner (14) - : : :- * Sort (8) - : : : +- Exchange (7) - : : : +- * Project (6) - : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : :- * Filter (3) - : : : : +- * ColumnarToRow (2) - : : : : +- Scan parquet default.store_sales (1) - : : : +- ReusedExchange (4) - : : +- * Sort (13) - : : +- Exchange (12) - : : +- * Filter (11) - : : +- * ColumnarToRow (10) - : : +- Scan parquet default.customer (9) - : +- * Sort (37) - : +- Exchange (36) - : +- * HashAggregate (35) - : +- Exchange (34) - : +- * HashAggregate (33) - : +- * Project (32) - : +- * SortMergeJoin Inner (31) - : :- * Sort (28) - : : +- Exchange (27) - : : +- * Project (26) - : : +- * BroadcastHashJoin Inner BuildRight (25) - : : :- * Filter (23) - : : : +- * ColumnarToRow (22) - : : : +- Scan parquet default.catalog_sales (21) - : : +- ReusedExchange (24) - : +- * Sort (30) - : +- ReusedExchange (29) - +- * Sort (60) - +- Exchange (59) - +- * HashAggregate (58) - +- Exchange (57) - +- * HashAggregate (56) - +- * Project (55) - +- * SortMergeJoin Inner (54) - :- * Sort (51) - : +- Exchange (50) - : +- * Project (49) - : +- * BroadcastHashJoin Inner BuildRight (48) - : :- * Filter (46) - : : +- * ColumnarToRow (45) - : : +- Scan parquet default.web_sales (44) - : +- ReusedExchange (47) - +- * Sort (53) - +- ReusedExchange (52) +* HashAggregate (60) ++- Exchange (59) + +- * HashAggregate (58) + +- * Project (57) + +- * SortMergeJoin LeftSemi (56) + :- * SortMergeJoin LeftSemi (38) + : :- * Sort (20) + : : +- Exchange (19) + : : +- * HashAggregate (18) + : : +- Exchange (17) + : : +- * HashAggregate (16) + : : +- * Project (15) + : : +- * SortMergeJoin Inner (14) + : : :- * Sort (8) + : : : +- Exchange (7) + : : : +- * Project (6) + : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : :- * Filter (3) + : : : : +- * ColumnarToRow (2) + : : : : +- Scan parquet default.store_sales (1) + : : : +- ReusedExchange (4) + : : +- * Sort (13) + : : +- Exchange (12) + : : +- * Filter (11) + : : +- * ColumnarToRow (10) + : : +- Scan parquet default.customer (9) + : +- * Sort (37) + : +- Exchange (36) + : +- * HashAggregate (35) + : +- Exchange (34) + : +- * HashAggregate (33) + : +- * Project (32) + : +- * SortMergeJoin Inner (31) + : :- * Sort (28) + : : +- Exchange (27) + : : +- * Project (26) + : : +- * BroadcastHashJoin Inner BuildRight (25) + : : :- * Filter (23) + : : : +- * ColumnarToRow (22) + : : : +- Scan parquet default.catalog_sales (21) + : : +- ReusedExchange (24) + : +- * Sort (30) + : +- ReusedExchange (29) + +- * Sort (55) + +- Exchange (54) + +- * HashAggregate (53) + +- Exchange (52) + +- * HashAggregate (51) + +- * Project (50) + +- * SortMergeJoin Inner (49) + :- * 
Sort (46) + : +- Exchange (45) + : +- * Project (44) + : +- * BroadcastHashJoin Inner BuildRight (43) + : :- * Filter (41) + : : +- * ColumnarToRow (40) + : : +- Scan parquet default.web_sales (39) + : +- ReusedExchange (42) + +- * Sort (48) + +- ReusedExchange (47) (1) Scan parquet default.store_sales @@ -83,7 +76,7 @@ Input [2]: [ss_customer_sk#1, ss_sold_date_sk#2] Input [2]: [ss_customer_sk#1, ss_sold_date_sk#2] Condition : isnotnull(ss_customer_sk#1) -(4) ReusedExchange [Reuses operator id: 72] +(4) ReusedExchange [Reuses operator id: 65] Output [2]: [d_date_sk#4, d_date#5] (5) BroadcastHashJoin [codegen id : 2] @@ -175,7 +168,7 @@ Input [2]: [cs_bill_customer_sk#13, cs_sold_date_sk#14] Input [2]: [cs_bill_customer_sk#13, cs_sold_date_sk#14] Condition : isnotnull(cs_bill_customer_sk#13) -(24) ReusedExchange [Reuses operator id: 72] +(24) ReusedExchange [Reuses operator id: 65] Output [2]: [d_date_sk#15, d_date#16] (25) BroadcastHashJoin [codegen id : 10] @@ -242,184 +235,144 @@ Left keys [6]: [coalesce(c_last_name#9, ), isnull(c_last_name#9), coalesce(c_fir Right keys [6]: [coalesce(c_last_name#20, ), isnull(c_last_name#20), coalesce(c_first_name#19, ), isnull(c_first_name#19), coalesce(d_date#16, 1970-01-01), isnull(d_date#16)] Join condition: None -(39) HashAggregate [codegen id : 17] -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Keys [3]: [c_last_name#9, c_first_name#8, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#9, c_first_name#8, d_date#5] - -(40) Exchange -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Arguments: hashpartitioning(c_last_name#9, c_first_name#8, d_date#5, 5), ENSURE_REQUIREMENTS, [id=#23] - -(41) HashAggregate [codegen id : 18] -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Keys [3]: [c_last_name#9, c_first_name#8, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#9, c_first_name#8, d_date#5] - -(42) Exchange -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Arguments: hashpartitioning(coalesce(c_last_name#9, ), isnull(c_last_name#9), coalesce(c_first_name#8, ), isnull(c_first_name#8), coalesce(d_date#5, 1970-01-01), isnull(d_date#5), 5), ENSURE_REQUIREMENTS, [id=#24] - -(43) Sort [codegen id : 19] -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Arguments: [coalesce(c_last_name#9, ) ASC NULLS FIRST, isnull(c_last_name#9) ASC NULLS FIRST, coalesce(c_first_name#8, ) ASC NULLS FIRST, isnull(c_first_name#8) ASC NULLS FIRST, coalesce(d_date#5, 1970-01-01) ASC NULLS FIRST, isnull(d_date#5) ASC NULLS FIRST], false, 0 - -(44) Scan parquet default.web_sales -Output [2]: [ws_bill_customer_sk#25, ws_sold_date_sk#26] +(39) Scan parquet default.web_sales +Output [2]: [ws_bill_customer_sk#23, ws_sold_date_sk#24] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ws_sold_date_sk#26), dynamicpruningexpression(ws_sold_date_sk#26 IN dynamicpruning#3)] +PartitionFilters: [isnotnull(ws_sold_date_sk#24), dynamicpruningexpression(ws_sold_date_sk#24 IN dynamicpruning#3)] PushedFilters: [IsNotNull(ws_bill_customer_sk)] ReadSchema: struct -(45) ColumnarToRow [codegen id : 21] -Input [2]: [ws_bill_customer_sk#25, ws_sold_date_sk#26] +(40) ColumnarToRow [codegen id : 19] +Input [2]: [ws_bill_customer_sk#23, ws_sold_date_sk#24] -(46) Filter [codegen id : 21] -Input [2]: [ws_bill_customer_sk#25, ws_sold_date_sk#26] -Condition : isnotnull(ws_bill_customer_sk#25) +(41) Filter [codegen id : 19] +Input [2]: [ws_bill_customer_sk#23, ws_sold_date_sk#24] +Condition : 
isnotnull(ws_bill_customer_sk#23) -(47) ReusedExchange [Reuses operator id: 72] -Output [2]: [d_date_sk#27, d_date#28] +(42) ReusedExchange [Reuses operator id: 65] +Output [2]: [d_date_sk#25, d_date#26] -(48) BroadcastHashJoin [codegen id : 21] -Left keys [1]: [ws_sold_date_sk#26] -Right keys [1]: [d_date_sk#27] +(43) BroadcastHashJoin [codegen id : 19] +Left keys [1]: [ws_sold_date_sk#24] +Right keys [1]: [d_date_sk#25] Join condition: None -(49) Project [codegen id : 21] -Output [2]: [ws_bill_customer_sk#25, d_date#28] -Input [4]: [ws_bill_customer_sk#25, ws_sold_date_sk#26, d_date_sk#27, d_date#28] +(44) Project [codegen id : 19] +Output [2]: [ws_bill_customer_sk#23, d_date#26] +Input [4]: [ws_bill_customer_sk#23, ws_sold_date_sk#24, d_date_sk#25, d_date#26] -(50) Exchange -Input [2]: [ws_bill_customer_sk#25, d_date#28] -Arguments: hashpartitioning(ws_bill_customer_sk#25, 5), ENSURE_REQUIREMENTS, [id=#29] +(45) Exchange +Input [2]: [ws_bill_customer_sk#23, d_date#26] +Arguments: hashpartitioning(ws_bill_customer_sk#23, 5), ENSURE_REQUIREMENTS, [id=#27] -(51) Sort [codegen id : 22] -Input [2]: [ws_bill_customer_sk#25, d_date#28] -Arguments: [ws_bill_customer_sk#25 ASC NULLS FIRST], false, 0 +(46) Sort [codegen id : 20] +Input [2]: [ws_bill_customer_sk#23, d_date#26] +Arguments: [ws_bill_customer_sk#23 ASC NULLS FIRST], false, 0 -(52) ReusedExchange [Reuses operator id: 12] -Output [3]: [c_customer_sk#30, c_first_name#31, c_last_name#32] +(47) ReusedExchange [Reuses operator id: 12] +Output [3]: [c_customer_sk#28, c_first_name#29, c_last_name#30] -(53) Sort [codegen id : 24] -Input [3]: [c_customer_sk#30, c_first_name#31, c_last_name#32] -Arguments: [c_customer_sk#30 ASC NULLS FIRST], false, 0 +(48) Sort [codegen id : 22] +Input [3]: [c_customer_sk#28, c_first_name#29, c_last_name#30] +Arguments: [c_customer_sk#28 ASC NULLS FIRST], false, 0 -(54) SortMergeJoin [codegen id : 25] -Left keys [1]: [ws_bill_customer_sk#25] -Right keys [1]: [c_customer_sk#30] +(49) SortMergeJoin [codegen id : 23] +Left keys [1]: [ws_bill_customer_sk#23] +Right keys [1]: [c_customer_sk#28] Join condition: None -(55) Project [codegen id : 25] -Output [3]: [c_last_name#32, c_first_name#31, d_date#28] -Input [5]: [ws_bill_customer_sk#25, d_date#28, c_customer_sk#30, c_first_name#31, c_last_name#32] +(50) Project [codegen id : 23] +Output [3]: [c_last_name#30, c_first_name#29, d_date#26] +Input [5]: [ws_bill_customer_sk#23, d_date#26, c_customer_sk#28, c_first_name#29, c_last_name#30] -(56) HashAggregate [codegen id : 25] -Input [3]: [c_last_name#32, c_first_name#31, d_date#28] -Keys [3]: [c_last_name#32, c_first_name#31, d_date#28] +(51) HashAggregate [codegen id : 23] +Input [3]: [c_last_name#30, c_first_name#29, d_date#26] +Keys [3]: [c_last_name#30, c_first_name#29, d_date#26] Functions: [] Aggregate Attributes: [] -Results [3]: [c_last_name#32, c_first_name#31, d_date#28] +Results [3]: [c_last_name#30, c_first_name#29, d_date#26] -(57) Exchange -Input [3]: [c_last_name#32, c_first_name#31, d_date#28] -Arguments: hashpartitioning(c_last_name#32, c_first_name#31, d_date#28, 5), ENSURE_REQUIREMENTS, [id=#33] +(52) Exchange +Input [3]: [c_last_name#30, c_first_name#29, d_date#26] +Arguments: hashpartitioning(c_last_name#30, c_first_name#29, d_date#26, 5), ENSURE_REQUIREMENTS, [id=#31] -(58) HashAggregate [codegen id : 26] -Input [3]: [c_last_name#32, c_first_name#31, d_date#28] -Keys [3]: [c_last_name#32, c_first_name#31, d_date#28] +(53) HashAggregate [codegen id : 24] +Input [3]: [c_last_name#30, 
c_first_name#29, d_date#26] +Keys [3]: [c_last_name#30, c_first_name#29, d_date#26] Functions: [] Aggregate Attributes: [] -Results [3]: [c_last_name#32, c_first_name#31, d_date#28] +Results [3]: [c_last_name#30, c_first_name#29, d_date#26] -(59) Exchange -Input [3]: [c_last_name#32, c_first_name#31, d_date#28] -Arguments: hashpartitioning(coalesce(c_last_name#32, ), isnull(c_last_name#32), coalesce(c_first_name#31, ), isnull(c_first_name#31), coalesce(d_date#28, 1970-01-01), isnull(d_date#28), 5), ENSURE_REQUIREMENTS, [id=#34] +(54) Exchange +Input [3]: [c_last_name#30, c_first_name#29, d_date#26] +Arguments: hashpartitioning(coalesce(c_last_name#30, ), isnull(c_last_name#30), coalesce(c_first_name#29, ), isnull(c_first_name#29), coalesce(d_date#26, 1970-01-01), isnull(d_date#26), 5), ENSURE_REQUIREMENTS, [id=#32] -(60) Sort [codegen id : 27] -Input [3]: [c_last_name#32, c_first_name#31, d_date#28] -Arguments: [coalesce(c_last_name#32, ) ASC NULLS FIRST, isnull(c_last_name#32) ASC NULLS FIRST, coalesce(c_first_name#31, ) ASC NULLS FIRST, isnull(c_first_name#31) ASC NULLS FIRST, coalesce(d_date#28, 1970-01-01) ASC NULLS FIRST, isnull(d_date#28) ASC NULLS FIRST], false, 0 +(55) Sort [codegen id : 25] +Input [3]: [c_last_name#30, c_first_name#29, d_date#26] +Arguments: [coalesce(c_last_name#30, ) ASC NULLS FIRST, isnull(c_last_name#30) ASC NULLS FIRST, coalesce(c_first_name#29, ) ASC NULLS FIRST, isnull(c_first_name#29) ASC NULLS FIRST, coalesce(d_date#26, 1970-01-01) ASC NULLS FIRST, isnull(d_date#26) ASC NULLS FIRST], false, 0 -(61) SortMergeJoin [codegen id : 28] +(56) SortMergeJoin [codegen id : 26] Left keys [6]: [coalesce(c_last_name#9, ), isnull(c_last_name#9), coalesce(c_first_name#8, ), isnull(c_first_name#8), coalesce(d_date#5, 1970-01-01), isnull(d_date#5)] -Right keys [6]: [coalesce(c_last_name#32, ), isnull(c_last_name#32), coalesce(c_first_name#31, ), isnull(c_first_name#31), coalesce(d_date#28, 1970-01-01), isnull(d_date#28)] +Right keys [6]: [coalesce(c_last_name#30, ), isnull(c_last_name#30), coalesce(c_first_name#29, ), isnull(c_first_name#29), coalesce(d_date#26, 1970-01-01), isnull(d_date#26)] Join condition: None -(62) HashAggregate [codegen id : 28] +(57) Project [codegen id : 26] +Output: [] Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Keys [3]: [c_last_name#9, c_first_name#8, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#9, c_first_name#8, d_date#5] -(63) Exchange -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Arguments: hashpartitioning(c_last_name#9, c_first_name#8, d_date#5, 5), ENSURE_REQUIREMENTS, [id=#35] - -(64) HashAggregate [codegen id : 29] -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Keys [3]: [c_last_name#9, c_first_name#8, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results: [] - -(65) HashAggregate [codegen id : 29] +(58) HashAggregate [codegen id : 26] Input: [] Keys: [] Functions [1]: [partial_count(1)] -Aggregate Attributes [1]: [count#36] -Results [1]: [count#37] +Aggregate Attributes [1]: [count#33] +Results [1]: [count#34] -(66) Exchange -Input [1]: [count#37] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#38] +(59) Exchange +Input [1]: [count#34] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#35] -(67) HashAggregate [codegen id : 30] -Input [1]: [count#37] +(60) HashAggregate [codegen id : 27] +Input [1]: [count#34] Keys: [] Functions [1]: [count(1)] -Aggregate Attributes [1]: [count(1)#39] -Results [1]: [count(1)#39 AS count(1)#40] +Aggregate Attributes [1]: 
[count(1)#36] +Results [1]: [count(1)#36 AS count(1)#37] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#2 IN dynamicpruning#3 -BroadcastExchange (72) -+- * Project (71) - +- * Filter (70) - +- * ColumnarToRow (69) - +- Scan parquet default.date_dim (68) +BroadcastExchange (65) ++- * Project (64) + +- * Filter (63) + +- * ColumnarToRow (62) + +- Scan parquet default.date_dim (61) -(68) Scan parquet default.date_dim -Output [3]: [d_date_sk#4, d_date#5, d_month_seq#41] +(61) Scan parquet default.date_dim +Output [3]: [d_date_sk#4, d_date#5, d_month_seq#38] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1200), LessThanOrEqual(d_month_seq,1211), IsNotNull(d_date_sk)] ReadSchema: struct -(69) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#4, d_date#5, d_month_seq#41] +(62) ColumnarToRow [codegen id : 1] +Input [3]: [d_date_sk#4, d_date#5, d_month_seq#38] -(70) Filter [codegen id : 1] -Input [3]: [d_date_sk#4, d_date#5, d_month_seq#41] -Condition : (((isnotnull(d_month_seq#41) AND (d_month_seq#41 >= 1200)) AND (d_month_seq#41 <= 1211)) AND isnotnull(d_date_sk#4)) +(63) Filter [codegen id : 1] +Input [3]: [d_date_sk#4, d_date#5, d_month_seq#38] +Condition : (((isnotnull(d_month_seq#38) AND (d_month_seq#38 >= 1200)) AND (d_month_seq#38 <= 1211)) AND isnotnull(d_date_sk#4)) -(71) Project [codegen id : 1] +(64) Project [codegen id : 1] Output [2]: [d_date_sk#4, d_date#5] -Input [3]: [d_date_sk#4, d_date#5, d_month_seq#41] +Input [3]: [d_date_sk#4, d_date#5, d_month_seq#38] -(72) BroadcastExchange +(65) BroadcastExchange Input [2]: [d_date_sk#4, d_date#5] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#42] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#39] Subquery:2 Hosting operator id = 21 Hosting Expression = cs_sold_date_sk#14 IN dynamicpruning#3 -Subquery:3 Hosting operator id = 44 Hosting Expression = ws_sold_date_sk#26 IN dynamicpruning#3 +Subquery:3 Hosting operator id = 39 Hosting Expression = ws_sold_date_sk#24 IN dynamicpruning#3 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38.sf100/simplified.txt index eda0d4b03f483..cc66a0040ef9a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38.sf100/simplified.txt @@ -1,135 +1,122 @@ -WholeStageCodegen (30) +WholeStageCodegen (27) HashAggregate [count] [count(1),count(1),count] InputAdapter Exchange #1 - WholeStageCodegen (29) + WholeStageCodegen (26) HashAggregate [count,count] - HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #2 - WholeStageCodegen (28) - HashAggregate [c_last_name,c_first_name,d_date] - SortMergeJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] - InputAdapter - WholeStageCodegen (19) - Sort [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #3 - WholeStageCodegen (18) - HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #4 - WholeStageCodegen (17) - HashAggregate [c_last_name,c_first_name,d_date] - SortMergeJoin 
[c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] + Project + SortMergeJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] + InputAdapter + WholeStageCodegen (17) + SortMergeJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] + InputAdapter + WholeStageCodegen (8) + Sort [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #2 + WholeStageCodegen (7) + HashAggregate [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #3 + WholeStageCodegen (6) + HashAggregate [c_last_name,c_first_name,d_date] + Project [c_last_name,c_first_name,d_date] + SortMergeJoin [ss_customer_sk,c_customer_sk] InputAdapter - WholeStageCodegen (8) - Sort [c_last_name,c_first_name,d_date] + WholeStageCodegen (3) + Sort [ss_customer_sk] InputAdapter - Exchange [c_last_name,c_first_name,d_date] #5 - WholeStageCodegen (7) - HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #6 - WholeStageCodegen (6) - HashAggregate [c_last_name,c_first_name,d_date] - Project [c_last_name,c_first_name,d_date] - SortMergeJoin [ss_customer_sk,c_customer_sk] - InputAdapter - WholeStageCodegen (3) - Sort [ss_customer_sk] - InputAdapter - Exchange [ss_customer_sk] #7 - WholeStageCodegen (2) - Project [ss_customer_sk,d_date] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_customer_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #8 - WholeStageCodegen (1) - Project [d_date_sk,d_date] - Filter [d_month_seq,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] - InputAdapter - ReusedExchange [d_date_sk,d_date] #8 - InputAdapter - WholeStageCodegen (5) - Sort [c_customer_sk] - InputAdapter - Exchange [c_customer_sk] #9 - WholeStageCodegen (4) - Filter [c_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.customer [c_customer_sk,c_first_name,c_last_name] + Exchange [ss_customer_sk] #4 + WholeStageCodegen (2) + Project [ss_customer_sk,d_date] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_customer_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (1) + Project [d_date_sk,d_date] + Filter [d_month_seq,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] + InputAdapter + ReusedExchange [d_date_sk,d_date] #5 InputAdapter - WholeStageCodegen (16) - Sort [c_last_name,c_first_name,d_date] + WholeStageCodegen (5) + Sort [c_customer_sk] InputAdapter - Exchange [c_last_name,c_first_name,d_date] #10 - WholeStageCodegen (15) - HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #11 - WholeStageCodegen (14) - HashAggregate [c_last_name,c_first_name,d_date] - Project [c_last_name,c_first_name,d_date] - SortMergeJoin [cs_bill_customer_sk,c_customer_sk] - InputAdapter - WholeStageCodegen (11) - Sort [cs_bill_customer_sk] - InputAdapter - Exchange [cs_bill_customer_sk] #12 - WholeStageCodegen (10) - Project [cs_bill_customer_sk,d_date] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Filter [cs_bill_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_bill_customer_sk,cs_sold_date_sk] - ReusedSubquery 
[d_date_sk] #1 - InputAdapter - ReusedExchange [d_date_sk,d_date] #8 - InputAdapter - WholeStageCodegen (13) - Sort [c_customer_sk] - InputAdapter - ReusedExchange [c_customer_sk,c_first_name,c_last_name] #9 - InputAdapter - WholeStageCodegen (27) - Sort [c_last_name,c_first_name,d_date] + Exchange [c_customer_sk] #6 + WholeStageCodegen (4) + Filter [c_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.customer [c_customer_sk,c_first_name,c_last_name] + InputAdapter + WholeStageCodegen (16) + Sort [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #7 + WholeStageCodegen (15) + HashAggregate [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #8 + WholeStageCodegen (14) + HashAggregate [c_last_name,c_first_name,d_date] + Project [c_last_name,c_first_name,d_date] + SortMergeJoin [cs_bill_customer_sk,c_customer_sk] + InputAdapter + WholeStageCodegen (11) + Sort [cs_bill_customer_sk] + InputAdapter + Exchange [cs_bill_customer_sk] #9 + WholeStageCodegen (10) + Project [cs_bill_customer_sk,d_date] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Filter [cs_bill_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_bill_customer_sk,cs_sold_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [d_date_sk,d_date] #5 + InputAdapter + WholeStageCodegen (13) + Sort [c_customer_sk] + InputAdapter + ReusedExchange [c_customer_sk,c_first_name,c_last_name] #6 + InputAdapter + WholeStageCodegen (25) + Sort [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #10 + WholeStageCodegen (24) + HashAggregate [c_last_name,c_first_name,d_date] InputAdapter - Exchange [c_last_name,c_first_name,d_date] #13 - WholeStageCodegen (26) + Exchange [c_last_name,c_first_name,d_date] #11 + WholeStageCodegen (23) HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #14 - WholeStageCodegen (25) - HashAggregate [c_last_name,c_first_name,d_date] - Project [c_last_name,c_first_name,d_date] - SortMergeJoin [ws_bill_customer_sk,c_customer_sk] - InputAdapter - WholeStageCodegen (22) - Sort [ws_bill_customer_sk] - InputAdapter - Exchange [ws_bill_customer_sk] #15 - WholeStageCodegen (21) - Project [ws_bill_customer_sk,d_date] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Filter [ws_bill_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.web_sales [ws_bill_customer_sk,ws_sold_date_sk] - ReusedSubquery [d_date_sk] #1 - InputAdapter - ReusedExchange [d_date_sk,d_date] #8 - InputAdapter - WholeStageCodegen (24) - Sort [c_customer_sk] - InputAdapter - ReusedExchange [c_customer_sk,c_first_name,c_last_name] #9 + Project [c_last_name,c_first_name,d_date] + SortMergeJoin [ws_bill_customer_sk,c_customer_sk] + InputAdapter + WholeStageCodegen (20) + Sort [ws_bill_customer_sk] + InputAdapter + Exchange [ws_bill_customer_sk] #12 + WholeStageCodegen (19) + Project [ws_bill_customer_sk,d_date] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Filter [ws_bill_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_bill_customer_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [d_date_sk,d_date] #5 + InputAdapter + WholeStageCodegen (22) + Sort [c_customer_sk] + InputAdapter + ReusedExchange [c_customer_sk,c_first_name,c_last_name] #6 diff --git 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38/explain.txt index ca4a34d7b6087..60190c9f39e43 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38/explain.txt @@ -1,54 +1,51 @@ == Physical Plan == -* HashAggregate (50) -+- Exchange (49) - +- * HashAggregate (48) - +- * HashAggregate (47) - +- * HashAggregate (46) - +- * BroadcastHashJoin LeftSemi BuildRight (45) - :- * HashAggregate (31) - : +- * HashAggregate (30) - : +- * BroadcastHashJoin LeftSemi BuildRight (29) - : :- * HashAggregate (15) - : : +- Exchange (14) - : : +- * HashAggregate (13) - : : +- * Project (12) - : : +- * BroadcastHashJoin Inner BuildRight (11) - : : :- * Project (6) - : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : :- * Filter (3) - : : : : +- * ColumnarToRow (2) - : : : : +- Scan parquet default.store_sales (1) - : : : +- ReusedExchange (4) - : : +- BroadcastExchange (10) - : : +- * Filter (9) - : : +- * ColumnarToRow (8) - : : +- Scan parquet default.customer (7) - : +- BroadcastExchange (28) - : +- * HashAggregate (27) - : +- Exchange (26) - : +- * HashAggregate (25) - : +- * Project (24) - : +- * BroadcastHashJoin Inner BuildRight (23) - : :- * Project (21) - : : +- * BroadcastHashJoin Inner BuildRight (20) - : : :- * Filter (18) - : : : +- * ColumnarToRow (17) - : : : +- Scan parquet default.catalog_sales (16) - : : +- ReusedExchange (19) - : +- ReusedExchange (22) - +- BroadcastExchange (44) - +- * HashAggregate (43) - +- Exchange (42) - +- * HashAggregate (41) - +- * Project (40) - +- * BroadcastHashJoin Inner BuildRight (39) - :- * Project (37) - : +- * BroadcastHashJoin Inner BuildRight (36) - : :- * Filter (34) - : : +- * ColumnarToRow (33) - : : +- Scan parquet default.web_sales (32) - : +- ReusedExchange (35) - +- ReusedExchange (38) +* HashAggregate (47) ++- Exchange (46) + +- * HashAggregate (45) + +- * Project (44) + +- * BroadcastHashJoin LeftSemi BuildRight (43) + :- * BroadcastHashJoin LeftSemi BuildRight (29) + : :- * HashAggregate (15) + : : +- Exchange (14) + : : +- * HashAggregate (13) + : : +- * Project (12) + : : +- * BroadcastHashJoin Inner BuildRight (11) + : : :- * Project (6) + : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : :- * Filter (3) + : : : : +- * ColumnarToRow (2) + : : : : +- Scan parquet default.store_sales (1) + : : : +- ReusedExchange (4) + : : +- BroadcastExchange (10) + : : +- * Filter (9) + : : +- * ColumnarToRow (8) + : : +- Scan parquet default.customer (7) + : +- BroadcastExchange (28) + : +- * HashAggregate (27) + : +- Exchange (26) + : +- * HashAggregate (25) + : +- * Project (24) + : +- * BroadcastHashJoin Inner BuildRight (23) + : :- * Project (21) + : : +- * BroadcastHashJoin Inner BuildRight (20) + : : :- * Filter (18) + : : : +- * ColumnarToRow (17) + : : : +- Scan parquet default.catalog_sales (16) + : : +- ReusedExchange (19) + : +- ReusedExchange (22) + +- BroadcastExchange (42) + +- * HashAggregate (41) + +- Exchange (40) + +- * HashAggregate (39) + +- * Project (38) + +- * BroadcastHashJoin Inner BuildRight (37) + :- * Project (35) + : +- * BroadcastHashJoin Inner BuildRight (34) + : :- * Filter (32) + : : +- * ColumnarToRow (31) + : : +- Scan parquet default.web_sales (30) + : +- ReusedExchange (33) + +- ReusedExchange (36) (1) Scan parquet default.store_sales @@ -66,7 +63,7 @@ Input [2]: 
[ss_customer_sk#1, ss_sold_date_sk#2] Input [2]: [ss_customer_sk#1, ss_sold_date_sk#2] Condition : isnotnull(ss_customer_sk#1) -(4) ReusedExchange [Reuses operator id: 55] +(4) ReusedExchange [Reuses operator id: 52] Output [2]: [d_date_sk#4, d_date#5] (5) BroadcastHashJoin [codegen id : 3] @@ -138,7 +135,7 @@ Input [2]: [cs_bill_customer_sk#11, cs_sold_date_sk#12] Input [2]: [cs_bill_customer_sk#11, cs_sold_date_sk#12] Condition : isnotnull(cs_bill_customer_sk#11) -(19) ReusedExchange [Reuses operator id: 55] +(19) ReusedExchange [Reuses operator id: 52] Output [2]: [d_date_sk#13, d_date#14] (20) BroadcastHashJoin [codegen id : 6] @@ -189,21 +186,7 @@ Left keys [6]: [coalesce(c_last_name#8, ), isnull(c_last_name#8), coalesce(c_fir Right keys [6]: [coalesce(c_last_name#17, ), isnull(c_last_name#17), coalesce(c_first_name#16, ), isnull(c_first_name#16), coalesce(d_date#14, 1970-01-01), isnull(d_date#14)] Join condition: None -(30) HashAggregate [codegen id : 12] -Input [3]: [c_last_name#8, c_first_name#7, d_date#5] -Keys [3]: [c_last_name#8, c_first_name#7, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#8, c_first_name#7, d_date#5] - -(31) HashAggregate [codegen id : 12] -Input [3]: [c_last_name#8, c_first_name#7, d_date#5] -Keys [3]: [c_last_name#8, c_first_name#7, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#8, c_first_name#7, d_date#5] - -(32) Scan parquet default.web_sales +(30) Scan parquet default.web_sales Output [2]: [ws_bill_customer_sk#20, ws_sold_date_sk#21] Batched: true Location: InMemoryFileIndex [] @@ -211,90 +194,80 @@ PartitionFilters: [isnotnull(ws_sold_date_sk#21), dynamicpruningexpression(ws_so PushedFilters: [IsNotNull(ws_bill_customer_sk)] ReadSchema: struct -(33) ColumnarToRow [codegen id : 10] +(31) ColumnarToRow [codegen id : 10] Input [2]: [ws_bill_customer_sk#20, ws_sold_date_sk#21] -(34) Filter [codegen id : 10] +(32) Filter [codegen id : 10] Input [2]: [ws_bill_customer_sk#20, ws_sold_date_sk#21] Condition : isnotnull(ws_bill_customer_sk#20) -(35) ReusedExchange [Reuses operator id: 55] +(33) ReusedExchange [Reuses operator id: 52] Output [2]: [d_date_sk#22, d_date#23] -(36) BroadcastHashJoin [codegen id : 10] +(34) BroadcastHashJoin [codegen id : 10] Left keys [1]: [ws_sold_date_sk#21] Right keys [1]: [d_date_sk#22] Join condition: None -(37) Project [codegen id : 10] +(35) Project [codegen id : 10] Output [2]: [ws_bill_customer_sk#20, d_date#23] Input [4]: [ws_bill_customer_sk#20, ws_sold_date_sk#21, d_date_sk#22, d_date#23] -(38) ReusedExchange [Reuses operator id: 10] +(36) ReusedExchange [Reuses operator id: 10] Output [3]: [c_customer_sk#24, c_first_name#25, c_last_name#26] -(39) BroadcastHashJoin [codegen id : 10] +(37) BroadcastHashJoin [codegen id : 10] Left keys [1]: [ws_bill_customer_sk#20] Right keys [1]: [c_customer_sk#24] Join condition: None -(40) Project [codegen id : 10] +(38) Project [codegen id : 10] Output [3]: [c_last_name#26, c_first_name#25, d_date#23] Input [5]: [ws_bill_customer_sk#20, d_date#23, c_customer_sk#24, c_first_name#25, c_last_name#26] -(41) HashAggregate [codegen id : 10] +(39) HashAggregate [codegen id : 10] Input [3]: [c_last_name#26, c_first_name#25, d_date#23] Keys [3]: [c_last_name#26, c_first_name#25, d_date#23] Functions: [] Aggregate Attributes: [] Results [3]: [c_last_name#26, c_first_name#25, d_date#23] -(42) Exchange +(40) Exchange Input [3]: [c_last_name#26, c_first_name#25, d_date#23] Arguments: hashpartitioning(c_last_name#26, c_first_name#25, 
d_date#23, 5), ENSURE_REQUIREMENTS, [id=#27] -(43) HashAggregate [codegen id : 11] +(41) HashAggregate [codegen id : 11] Input [3]: [c_last_name#26, c_first_name#25, d_date#23] Keys [3]: [c_last_name#26, c_first_name#25, d_date#23] Functions: [] Aggregate Attributes: [] Results [3]: [c_last_name#26, c_first_name#25, d_date#23] -(44) BroadcastExchange +(42) BroadcastExchange Input [3]: [c_last_name#26, c_first_name#25, d_date#23] Arguments: HashedRelationBroadcastMode(List(coalesce(input[0, string, true], ), isnull(input[0, string, true]), coalesce(input[1, string, true], ), isnull(input[1, string, true]), coalesce(input[2, date, true], 1970-01-01), isnull(input[2, date, true])),false), [id=#28] -(45) BroadcastHashJoin [codegen id : 12] +(43) BroadcastHashJoin [codegen id : 12] Left keys [6]: [coalesce(c_last_name#8, ), isnull(c_last_name#8), coalesce(c_first_name#7, ), isnull(c_first_name#7), coalesce(d_date#5, 1970-01-01), isnull(d_date#5)] Right keys [6]: [coalesce(c_last_name#26, ), isnull(c_last_name#26), coalesce(c_first_name#25, ), isnull(c_first_name#25), coalesce(d_date#23, 1970-01-01), isnull(d_date#23)] Join condition: None -(46) HashAggregate [codegen id : 12] -Input [3]: [c_last_name#8, c_first_name#7, d_date#5] -Keys [3]: [c_last_name#8, c_first_name#7, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#8, c_first_name#7, d_date#5] - -(47) HashAggregate [codegen id : 12] +(44) Project [codegen id : 12] +Output: [] Input [3]: [c_last_name#8, c_first_name#7, d_date#5] -Keys [3]: [c_last_name#8, c_first_name#7, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results: [] -(48) HashAggregate [codegen id : 12] +(45) HashAggregate [codegen id : 12] Input: [] Keys: [] Functions [1]: [partial_count(1)] Aggregate Attributes [1]: [count#29] Results [1]: [count#30] -(49) Exchange +(46) Exchange Input [1]: [count#30] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#31] -(50) HashAggregate [codegen id : 13] +(47) HashAggregate [codegen id : 13] Input [1]: [count#30] Keys: [] Functions [1]: [count(1)] @@ -304,37 +277,37 @@ Results [1]: [count(1)#32 AS count(1)#33] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#2 IN dynamicpruning#3 -BroadcastExchange (55) -+- * Project (54) - +- * Filter (53) - +- * ColumnarToRow (52) - +- Scan parquet default.date_dim (51) +BroadcastExchange (52) ++- * Project (51) + +- * Filter (50) + +- * ColumnarToRow (49) + +- Scan parquet default.date_dim (48) -(51) Scan parquet default.date_dim +(48) Scan parquet default.date_dim Output [3]: [d_date_sk#4, d_date#5, d_month_seq#34] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1200), LessThanOrEqual(d_month_seq,1211), IsNotNull(d_date_sk)] ReadSchema: struct -(52) ColumnarToRow [codegen id : 1] +(49) ColumnarToRow [codegen id : 1] Input [3]: [d_date_sk#4, d_date#5, d_month_seq#34] -(53) Filter [codegen id : 1] +(50) Filter [codegen id : 1] Input [3]: [d_date_sk#4, d_date#5, d_month_seq#34] Condition : (((isnotnull(d_month_seq#34) AND (d_month_seq#34 >= 1200)) AND (d_month_seq#34 <= 1211)) AND isnotnull(d_date_sk#4)) -(54) Project [codegen id : 1] +(51) Project [codegen id : 1] Output [2]: [d_date_sk#4, d_date#5] Input [3]: [d_date_sk#4, d_date#5, d_month_seq#34] -(55) BroadcastExchange +(52) BroadcastExchange Input [2]: [d_date_sk#4, d_date#5] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), 
[id=#35] Subquery:2 Hosting operator id = 16 Hosting Expression = cs_sold_date_sk#12 IN dynamicpruning#3 -Subquery:3 Hosting operator id = 32 Hosting Expression = ws_sold_date_sk#21 IN dynamicpruning#3 +Subquery:3 Hosting operator id = 30 Hosting Expression = ws_sold_date_sk#21 IN dynamicpruning#3 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38/simplified.txt index 7f96f5657836a..34d46c5671774 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q38/simplified.txt @@ -4,81 +4,78 @@ WholeStageCodegen (13) Exchange #1 WholeStageCodegen (12) HashAggregate [count,count] - HashAggregate [c_last_name,c_first_name,d_date] - HashAggregate [c_last_name,c_first_name,d_date] + Project + BroadcastHashJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] BroadcastHashJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] HashAggregate [c_last_name,c_first_name,d_date] - HashAggregate [c_last_name,c_first_name,d_date] - BroadcastHashJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] - HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #2 - WholeStageCodegen (3) - HashAggregate [c_last_name,c_first_name,d_date] - Project [c_last_name,c_first_name,d_date] - BroadcastHashJoin [ss_customer_sk,c_customer_sk] - Project [ss_customer_sk,d_date] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_customer_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #3 - WholeStageCodegen (1) - Project [d_date_sk,d_date] - Filter [d_month_seq,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] - InputAdapter - ReusedExchange [d_date_sk,d_date] #3 - InputAdapter - BroadcastExchange #4 - WholeStageCodegen (2) - Filter [c_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.customer [c_customer_sk,c_first_name,c_last_name] - InputAdapter - BroadcastExchange #5 - WholeStageCodegen (7) - HashAggregate [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #2 + WholeStageCodegen (3) + HashAggregate [c_last_name,c_first_name,d_date] + Project [c_last_name,c_first_name,d_date] + BroadcastHashJoin [ss_customer_sk,c_customer_sk] + Project [ss_customer_sk,d_date] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_customer_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #3 + WholeStageCodegen (1) + Project [d_date_sk,d_date] + Filter [d_month_seq,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] + InputAdapter + ReusedExchange [d_date_sk,d_date] #3 InputAdapter - Exchange [c_last_name,c_first_name,d_date] #6 - WholeStageCodegen (6) - HashAggregate [c_last_name,c_first_name,d_date] - Project [c_last_name,c_first_name,d_date] - BroadcastHashJoin [cs_bill_customer_sk,c_customer_sk] - Project [cs_bill_customer_sk,d_date] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Filter [cs_bill_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales 
[cs_bill_customer_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #1 - InputAdapter - ReusedExchange [d_date_sk,d_date] #3 - InputAdapter - ReusedExchange [c_customer_sk,c_first_name,c_last_name] #4 + BroadcastExchange #4 + WholeStageCodegen (2) + Filter [c_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.customer [c_customer_sk,c_first_name,c_last_name] InputAdapter - BroadcastExchange #7 - WholeStageCodegen (11) + BroadcastExchange #5 + WholeStageCodegen (7) HashAggregate [c_last_name,c_first_name,d_date] InputAdapter - Exchange [c_last_name,c_first_name,d_date] #8 - WholeStageCodegen (10) + Exchange [c_last_name,c_first_name,d_date] #6 + WholeStageCodegen (6) HashAggregate [c_last_name,c_first_name,d_date] Project [c_last_name,c_first_name,d_date] - BroadcastHashJoin [ws_bill_customer_sk,c_customer_sk] - Project [ws_bill_customer_sk,d_date] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Filter [ws_bill_customer_sk] + BroadcastHashJoin [cs_bill_customer_sk,c_customer_sk] + Project [cs_bill_customer_sk,d_date] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Filter [cs_bill_customer_sk] ColumnarToRow InputAdapter - Scan parquet default.web_sales [ws_bill_customer_sk,ws_sold_date_sk] + Scan parquet default.catalog_sales [cs_bill_customer_sk,cs_sold_date_sk] ReusedSubquery [d_date_sk] #1 InputAdapter ReusedExchange [d_date_sk,d_date] #3 InputAdapter ReusedExchange [c_customer_sk,c_first_name,c_last_name] #4 + InputAdapter + BroadcastExchange #7 + WholeStageCodegen (11) + HashAggregate [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #8 + WholeStageCodegen (10) + HashAggregate [c_last_name,c_first_name,d_date] + Project [c_last_name,c_first_name,d_date] + BroadcastHashJoin [ws_bill_customer_sk,c_customer_sk] + Project [ws_bill_customer_sk,d_date] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Filter [ws_bill_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_bill_customer_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [d_date_sk,d_date] #3 + InputAdapter + ReusedExchange [c_customer_sk,c_first_name,c_last_name] #4 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4.sf100/explain.txt index 40deb5feb0b4b..7ebe44763c25a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4.sf100/explain.txt @@ -188,7 +188,7 @@ Input [14]: [ss_customer_sk#1, ss_ext_discount_amt#2, ss_ext_sales_price#3, ss_e (16) HashAggregate [codegen id : 6] Input [12]: [c_customer_id#12, c_first_name#13, c_last_name#14, c_preferred_cust_flag#15, c_birth_country#16, c_login#17, c_email_address#18, ss_ext_discount_amt#2, ss_ext_sales_price#3, ss_ext_wholesale_cost#4, ss_ext_list_price#5, d_year#9] Keys [8]: [c_customer_id#12, c_first_name#13, c_last_name#14, c_preferred_cust_flag#15, c_birth_country#16, c_login#17, c_email_address#18, d_year#9] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#5 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#4 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(9,2)))), DecimalType(9,2), 
true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#3 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#5 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#4 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#3 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#20, isEmpty#21] Results [10]: [c_customer_id#12, c_first_name#13, c_last_name#14, c_preferred_cust_flag#15, c_birth_country#16, c_login#17, c_email_address#18, d_year#9, sum#22, isEmpty#23] @@ -199,9 +199,9 @@ Arguments: hashpartitioning(c_customer_id#12, c_first_name#13, c_last_name#14, c (18) HashAggregate [codegen id : 7] Input [10]: [c_customer_id#12, c_first_name#13, c_last_name#14, c_preferred_cust_flag#15, c_birth_country#16, c_login#17, c_email_address#18, d_year#9, sum#22, isEmpty#23] Keys [8]: [c_customer_id#12, c_first_name#13, c_last_name#14, c_preferred_cust_flag#15, c_birth_country#16, c_login#17, c_email_address#18, d_year#9] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#5 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#4 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#3 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#5 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#4 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#3 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#25] -Results [2]: [c_customer_id#12 AS customer_id#26, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#5 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#4 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#3 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#25 AS year_total#27] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#5 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#4 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + 
promote_precision(cast(ss_ext_sales_price#3 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#5 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#4 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#3 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#25] +Results [2]: [c_customer_id#12 AS customer_id#26, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#5 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#4 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#3 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#25 AS year_total#27] (19) Filter [codegen id : 7] Input [2]: [customer_id#26, year_total#27] @@ -269,7 +269,7 @@ Input [14]: [ss_customer_sk#29, ss_ext_discount_amt#30, ss_ext_sales_price#31, s (34) HashAggregate [codegen id : 14] Input [12]: [c_customer_id#40, c_first_name#41, c_last_name#42, c_preferred_cust_flag#43, c_birth_country#44, c_login#45, c_email_address#46, ss_ext_discount_amt#30, ss_ext_sales_price#31, ss_ext_wholesale_cost#32, ss_ext_list_price#33, d_year#37] Keys [8]: [c_customer_id#40, c_first_name#41, c_last_name#42, c_preferred_cust_flag#43, c_birth_country#44, c_login#45, c_email_address#46, d_year#37] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#32 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#30 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#31 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#32 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#30 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#31 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#47, isEmpty#48] Results [10]: [c_customer_id#40, c_first_name#41, c_last_name#42, c_preferred_cust_flag#43, c_birth_country#44, c_login#45, c_email_address#46, d_year#37, sum#49, isEmpty#50] @@ -280,9 +280,9 @@ Arguments: hashpartitioning(c_customer_id#40, c_first_name#41, c_last_name#42, c (36) HashAggregate [codegen id : 15] Input [10]: [c_customer_id#40, c_first_name#41, c_last_name#42, c_preferred_cust_flag#43, c_birth_country#44, c_login#45, c_email_address#46, d_year#37, sum#49, isEmpty#50] Keys [8]: [c_customer_id#40, c_first_name#41, 
c_last_name#42, c_preferred_cust_flag#43, c_birth_country#44, c_login#45, c_email_address#46, d_year#37] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#32 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#30 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#31 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#32 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#30 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#31 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#25] -Results [8]: [c_customer_id#40 AS customer_id#52, c_first_name#41 AS customer_first_name#53, c_last_name#42 AS customer_last_name#54, c_preferred_cust_flag#43 AS customer_preferred_cust_flag#55, c_birth_country#44 AS customer_birth_country#56, c_login#45 AS customer_login#57, c_email_address#46 AS customer_email_address#58, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#32 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#30 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#31 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#25 AS year_total#59] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#32 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#30 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#31 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#32 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#30 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#31 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#25] +Results [8]: [c_customer_id#40 AS customer_id#52, c_first_name#41 AS customer_first_name#53, c_last_name#42 AS customer_last_name#54, c_preferred_cust_flag#43 AS customer_preferred_cust_flag#55, c_birth_country#44 AS customer_birth_country#56, c_login#45 AS customer_login#57, c_email_address#46 AS customer_email_address#58, 
sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#32 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#30 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#31 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#25 AS year_total#59] (37) Exchange Input [8]: [customer_id#52, customer_first_name#53, customer_last_name#54, customer_preferred_cust_flag#55, customer_birth_country#56, customer_login#57, customer_email_address#58, year_total#59] @@ -351,7 +351,7 @@ Input [14]: [cs_bill_customer_sk#61, cs_ext_discount_amt#62, cs_ext_sales_price# (52) HashAggregate [codegen id : 23] Input [12]: [c_customer_id#71, c_first_name#72, c_last_name#73, c_preferred_cust_flag#74, c_birth_country#75, c_login#76, c_email_address#77, cs_ext_discount_amt#62, cs_ext_sales_price#63, cs_ext_wholesale_cost#64, cs_ext_list_price#65, d_year#68] Keys [8]: [c_customer_id#71, c_first_name#72, c_last_name#73, c_preferred_cust_flag#74, c_birth_country#75, c_login#76, c_email_address#77, d_year#68] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#65 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#64 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#62 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#63 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#65 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#64 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#62 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#63 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#78, isEmpty#79] Results [10]: [c_customer_id#71, c_first_name#72, c_last_name#73, c_preferred_cust_flag#74, c_birth_country#75, c_login#76, c_email_address#77, d_year#68, sum#80, isEmpty#81] @@ -362,9 +362,9 @@ Arguments: hashpartitioning(c_customer_id#71, c_first_name#72, c_last_name#73, c (54) HashAggregate [codegen id : 24] Input [10]: [c_customer_id#71, c_first_name#72, c_last_name#73, c_preferred_cust_flag#74, c_birth_country#75, c_login#76, c_email_address#77, d_year#68, sum#80, isEmpty#81] Keys [8]: [c_customer_id#71, c_first_name#72, c_last_name#73, c_preferred_cust_flag#74, c_birth_country#75, c_login#76, c_email_address#77, d_year#68] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#65 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#64 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#62 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + 
promote_precision(cast(cs_ext_sales_price#63 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#65 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#64 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#62 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#63 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#83] -Results [2]: [c_customer_id#71 AS customer_id#84, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#65 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#64 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#62 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#63 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#83 AS year_total#85] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#65 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#64 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#62 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#63 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#65 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#64 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#62 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#63 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#83] +Results [2]: [c_customer_id#71 AS customer_id#84, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#65 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#64 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#62 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#63 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#83 AS year_total#85] (55) Filter [codegen id : 24] Input [2]: [customer_id#84, year_total#85] @@ -441,7 +441,7 @@ Input [14]: [cs_bill_customer_sk#87, cs_ext_discount_amt#88, cs_ext_sales_price# (72) HashAggregate [codegen id : 32] Input [12]: [c_customer_id#97, c_first_name#98, c_last_name#99, c_preferred_cust_flag#100, c_birth_country#101, c_login#102, c_email_address#103, cs_ext_discount_amt#88, cs_ext_sales_price#89, cs_ext_wholesale_cost#90, cs_ext_list_price#91, d_year#94] Keys [8]: [c_customer_id#97, c_first_name#98, c_last_name#99, c_preferred_cust_flag#100, 
c_birth_country#101, c_login#102, c_email_address#103, d_year#94] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#91 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#90 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#88 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#89 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#91 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#90 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#88 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#89 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#104, isEmpty#105] Results [10]: [c_customer_id#97, c_first_name#98, c_last_name#99, c_preferred_cust_flag#100, c_birth_country#101, c_login#102, c_email_address#103, d_year#94, sum#106, isEmpty#107] @@ -452,9 +452,9 @@ Arguments: hashpartitioning(c_customer_id#97, c_first_name#98, c_last_name#99, c (74) HashAggregate [codegen id : 33] Input [10]: [c_customer_id#97, c_first_name#98, c_last_name#99, c_preferred_cust_flag#100, c_birth_country#101, c_login#102, c_email_address#103, d_year#94, sum#106, isEmpty#107] Keys [8]: [c_customer_id#97, c_first_name#98, c_last_name#99, c_preferred_cust_flag#100, c_birth_country#101, c_login#102, c_email_address#103, d_year#94] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#91 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#90 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#88 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#89 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#91 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#90 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#88 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#89 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#83] -Results [2]: [c_customer_id#97 AS customer_id#109, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#91 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#90 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#88 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#89 as decimal(10,2)))), 
DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#83 AS year_total#110] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#91 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#90 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#88 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#89 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#91 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#90 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#88 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#89 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#83] +Results [2]: [c_customer_id#97 AS customer_id#109, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#91 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#90 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#88 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#89 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#83 AS year_total#110] (75) Exchange Input [2]: [customer_id#109, year_total#110] @@ -467,7 +467,7 @@ Arguments: [customer_id#109 ASC NULLS FIRST], false, 0 (77) SortMergeJoin [codegen id : 35] Left keys [1]: [customer_id#26] Right keys [1]: [customer_id#109] -Join condition: (CASE WHEN (year_total#85 > 0.000000) THEN CheckOverflow((promote_precision(year_total#110) / promote_precision(year_total#85)), DecimalType(38,14), true) END > CASE WHEN (year_total#27 > 0.000000) THEN CheckOverflow((promote_precision(year_total#59) / promote_precision(year_total#27)), DecimalType(38,14), true) END) +Join condition: (CASE WHEN (year_total#85 > 0.000000) THEN CheckOverflow((promote_precision(year_total#110) / promote_precision(year_total#85)), DecimalType(38,14)) END > CASE WHEN (year_total#27 > 0.000000) THEN CheckOverflow((promote_precision(year_total#59) / promote_precision(year_total#27)), DecimalType(38,14)) END) (78) Project [codegen id : 35] Output [10]: [customer_id#26, customer_id#52, customer_first_name#53, customer_last_name#54, customer_preferred_cust_flag#55, customer_birth_country#56, customer_login#57, customer_email_address#58, year_total#85, year_total#110] @@ -527,7 +527,7 @@ Input [14]: [ws_bill_customer_sk#112, ws_ext_discount_amt#113, ws_ext_sales_pric (91) HashAggregate [codegen id : 41] Input [12]: [c_customer_id#122, c_first_name#123, c_last_name#124, c_preferred_cust_flag#125, c_birth_country#126, c_login#127, c_email_address#128, ws_ext_discount_amt#113, ws_ext_sales_price#114, ws_ext_wholesale_cost#115, ws_ext_list_price#116, d_year#119] Keys [8]: [c_customer_id#122, c_first_name#123, c_last_name#124, c_preferred_cust_flag#125, c_birth_country#126, c_login#127, c_email_address#128, d_year#119] -Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#116 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#115 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#113 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#114 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#116 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#115 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#113 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#114 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#129, isEmpty#130] Results [10]: [c_customer_id#122, c_first_name#123, c_last_name#124, c_preferred_cust_flag#125, c_birth_country#126, c_login#127, c_email_address#128, d_year#119, sum#131, isEmpty#132] @@ -538,9 +538,9 @@ Arguments: hashpartitioning(c_customer_id#122, c_first_name#123, c_last_name#124 (93) HashAggregate [codegen id : 42] Input [10]: [c_customer_id#122, c_first_name#123, c_last_name#124, c_preferred_cust_flag#125, c_birth_country#126, c_login#127, c_email_address#128, d_year#119, sum#131, isEmpty#132] Keys [8]: [c_customer_id#122, c_first_name#123, c_last_name#124, c_preferred_cust_flag#125, c_birth_country#126, c_login#127, c_email_address#128, d_year#119] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#116 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#115 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#113 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#114 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#116 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#115 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#113 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#114 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#134] -Results [2]: [c_customer_id#122 AS customer_id#135, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#116 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#115 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#113 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#114 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), 
true))#134 AS year_total#136] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#116 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#115 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#113 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#114 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#116 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#115 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#113 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#114 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#134] +Results [2]: [c_customer_id#122 AS customer_id#135, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#116 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#115 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#113 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#114 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#134 AS year_total#136] (94) Filter [codegen id : 42] Input [2]: [customer_id#135, year_total#136] @@ -617,7 +617,7 @@ Input [14]: [ws_bill_customer_sk#138, ws_ext_discount_amt#139, ws_ext_sales_pric (111) HashAggregate [codegen id : 50] Input [12]: [c_customer_id#148, c_first_name#149, c_last_name#150, c_preferred_cust_flag#151, c_birth_country#152, c_login#153, c_email_address#154, ws_ext_discount_amt#139, ws_ext_sales_price#140, ws_ext_wholesale_cost#141, ws_ext_list_price#142, d_year#145] Keys [8]: [c_customer_id#148, c_first_name#149, c_last_name#150, c_preferred_cust_flag#151, c_birth_country#152, c_login#153, c_email_address#154, d_year#145] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#142 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#141 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#139 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#140 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#142 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#141 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#139 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#140 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#155, isEmpty#156] Results [10]: 
[c_customer_id#148, c_first_name#149, c_last_name#150, c_preferred_cust_flag#151, c_birth_country#152, c_login#153, c_email_address#154, d_year#145, sum#157, isEmpty#158] @@ -628,9 +628,9 @@ Arguments: hashpartitioning(c_customer_id#148, c_first_name#149, c_last_name#150 (113) HashAggregate [codegen id : 51] Input [10]: [c_customer_id#148, c_first_name#149, c_last_name#150, c_preferred_cust_flag#151, c_birth_country#152, c_login#153, c_email_address#154, d_year#145, sum#157, isEmpty#158] Keys [8]: [c_customer_id#148, c_first_name#149, c_last_name#150, c_preferred_cust_flag#151, c_birth_country#152, c_login#153, c_email_address#154, d_year#145] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#142 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#141 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#139 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#140 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#142 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#141 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#139 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#140 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#134] -Results [2]: [c_customer_id#148 AS customer_id#160, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#142 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#141 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#139 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#140 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#134 AS year_total#161] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#142 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#141 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#139 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#140 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#142 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#141 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#139 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#140 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#134] +Results [2]: [c_customer_id#148 AS 
customer_id#160, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#142 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#141 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#139 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#140 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#134 AS year_total#161] (114) Exchange Input [2]: [customer_id#160, year_total#161] @@ -643,7 +643,7 @@ Arguments: [customer_id#160 ASC NULLS FIRST], false, 0 (116) SortMergeJoin [codegen id : 53] Left keys [1]: [customer_id#26] Right keys [1]: [customer_id#160] -Join condition: (CASE WHEN (year_total#85 > 0.000000) THEN CheckOverflow((promote_precision(year_total#110) / promote_precision(year_total#85)), DecimalType(38,14), true) END > CASE WHEN (year_total#136 > 0.000000) THEN CheckOverflow((promote_precision(year_total#161) / promote_precision(year_total#136)), DecimalType(38,14), true) END) +Join condition: (CASE WHEN (year_total#85 > 0.000000) THEN CheckOverflow((promote_precision(year_total#110) / promote_precision(year_total#85)), DecimalType(38,14)) END > CASE WHEN (year_total#136 > 0.000000) THEN CheckOverflow((promote_precision(year_total#161) / promote_precision(year_total#136)), DecimalType(38,14)) END) (117) Project [codegen id : 53] Output [7]: [customer_id#52, customer_first_name#53, customer_last_name#54, customer_preferred_cust_flag#55, customer_birth_country#56, customer_login#57, customer_email_address#58] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4.sf100/simplified.txt index cb2e3432e4ab2..e8e55fe575720 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4.sf100/simplified.txt @@ -24,7 +24,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom Exchange [customer_id] #1 WholeStageCodegen (7) Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + 
promote_precision(cast(ss_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #2 WholeStageCodegen (6) @@ -68,7 +68,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom InputAdapter Exchange [customer_id] #6 WholeStageCodegen (15) - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,customer_first_name,customer_last_name,customer_preferred_cust_flag,customer_birth_country,customer_login,customer_email_address,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,customer_first_name,customer_last_name,customer_preferred_cust_flag,customer_birth_country,customer_login,customer_email_address,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #7 WholeStageCodegen (14) @@ -108,7 +108,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom Exchange [customer_id] #10 WholeStageCodegen (24) Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) 
as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #11 WholeStageCodegen (23) @@ -141,7 +141,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom InputAdapter Exchange [customer_id] #13 WholeStageCodegen (33) - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #14 WholeStageCodegen (32) @@ -175,7 +175,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom Exchange [customer_id] #16 WholeStageCodegen (42) Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + 
promote_precision(cast(ws_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #17 WholeStageCodegen (41) @@ -208,7 +208,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom InputAdapter Exchange [customer_id] #19 WholeStageCodegen (51) - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #20 WholeStageCodegen (50) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4/explain.txt index 9dbbacae2047e..b0af6fb5e1627 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4/explain.txt @@ -166,7 +166,7 @@ Input [14]: [c_customer_id#2, c_first_name#3, c_last_name#4, c_preferred_cust_fl (13) HashAggregate [codegen id : 3] Input [12]: [c_customer_id#2, c_first_name#3, c_last_name#4, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, ss_ext_discount_amt#10, ss_ext_sales_price#11, ss_ext_wholesale_cost#12, ss_ext_list_price#13, d_year#18] Keys [8]: [c_customer_id#2, c_first_name#3, c_last_name#4, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, d_year#18] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#13 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#12 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#11 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#13 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#12 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#11 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#19, isEmpty#20] Results [10]: [c_customer_id#2, c_first_name#3, c_last_name#4, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, d_year#18, sum#21, isEmpty#22] @@ -177,9 +177,9 @@ Arguments: hashpartitioning(c_customer_id#2, c_first_name#3, c_last_name#4, c_pr (15) HashAggregate [codegen id : 24] Input [10]: [c_customer_id#2, c_first_name#3, c_last_name#4, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, d_year#18, sum#21, isEmpty#22] Keys [8]: [c_customer_id#2, c_first_name#3, c_last_name#4, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, d_year#18] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#13 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#12 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#11 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#13 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#12 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#11 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#24] -Results [2]: [c_customer_id#2 AS customer_id#25, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#13 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#12 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#11 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#24 AS year_total#26] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#13 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#12 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#11 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: 
[sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#13 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#12 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#11 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#24] +Results [2]: [c_customer_id#2 AS customer_id#25, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#13 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#12 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#11 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#24 AS year_total#26] (16) Filter [codegen id : 24] Input [2]: [customer_id#25, year_total#26] @@ -242,7 +242,7 @@ Input [14]: [c_customer_id#28, c_first_name#29, c_last_name#30, c_preferred_cust (29) HashAggregate [codegen id : 6] Input [12]: [c_customer_id#28, c_first_name#29, c_last_name#30, c_preferred_cust_flag#31, c_birth_country#32, c_login#33, c_email_address#34, ss_ext_discount_amt#36, ss_ext_sales_price#37, ss_ext_wholesale_cost#38, ss_ext_list_price#39, d_year#44] Keys [8]: [c_customer_id#28, c_first_name#29, c_last_name#30, c_preferred_cust_flag#31, c_birth_country#32, c_login#33, c_email_address#34, d_year#44] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#39 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#38 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#36 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#37 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#39 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#38 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#36 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#37 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#45, isEmpty#46] Results [10]: [c_customer_id#28, c_first_name#29, c_last_name#30, c_preferred_cust_flag#31, c_birth_country#32, c_login#33, c_email_address#34, d_year#44, sum#47, isEmpty#48] @@ -253,9 +253,9 @@ Arguments: hashpartitioning(c_customer_id#28, c_first_name#29, c_last_name#30, c (31) HashAggregate [codegen id : 7] Input [10]: [c_customer_id#28, c_first_name#29, c_last_name#30, c_preferred_cust_flag#31, c_birth_country#32, c_login#33, c_email_address#34, d_year#44, sum#47, isEmpty#48] Keys [8]: [c_customer_id#28, c_first_name#29, c_last_name#30, c_preferred_cust_flag#31, c_birth_country#32, c_login#33, c_email_address#34, d_year#44] -Functions [1]: 
[sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#39 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#38 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#36 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#37 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#39 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#38 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#36 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#37 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#24] -Results [8]: [c_customer_id#28 AS customer_id#50, c_first_name#29 AS customer_first_name#51, c_last_name#30 AS customer_last_name#52, c_preferred_cust_flag#31 AS customer_preferred_cust_flag#53, c_birth_country#32 AS customer_birth_country#54, c_login#33 AS customer_login#55, c_email_address#34 AS customer_email_address#56, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#39 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#38 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#36 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#37 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#24 AS year_total#57] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#39 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#38 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#36 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#37 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#39 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#38 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#36 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#37 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#24] +Results [8]: [c_customer_id#28 AS customer_id#50, c_first_name#29 AS customer_first_name#51, c_last_name#30 AS customer_last_name#52, c_preferred_cust_flag#31 AS customer_preferred_cust_flag#53, c_birth_country#32 AS customer_birth_country#54, c_login#33 AS customer_login#55, c_email_address#34 AS customer_email_address#56, 
sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price#39 as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost#38 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt#36 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price#37 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#24 AS year_total#57] (32) BroadcastExchange Input [8]: [customer_id#50, customer_first_name#51, customer_last_name#52, customer_preferred_cust_flag#53, customer_birth_country#54, customer_login#55, customer_email_address#56, year_total#57] @@ -323,7 +323,7 @@ Input [14]: [c_customer_id#60, c_first_name#61, c_last_name#62, c_preferred_cust (46) HashAggregate [codegen id : 10] Input [12]: [c_customer_id#60, c_first_name#61, c_last_name#62, c_preferred_cust_flag#63, c_birth_country#64, c_login#65, c_email_address#66, cs_ext_discount_amt#68, cs_ext_sales_price#69, cs_ext_wholesale_cost#70, cs_ext_list_price#71, d_year#75] Keys [8]: [c_customer_id#60, c_first_name#61, c_last_name#62, c_preferred_cust_flag#63, c_birth_country#64, c_login#65, c_email_address#66, d_year#75] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#71 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#70 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#68 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#69 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#71 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#70 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#68 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#69 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#76, isEmpty#77] Results [10]: [c_customer_id#60, c_first_name#61, c_last_name#62, c_preferred_cust_flag#63, c_birth_country#64, c_login#65, c_email_address#66, d_year#75, sum#78, isEmpty#79] @@ -334,9 +334,9 @@ Arguments: hashpartitioning(c_customer_id#60, c_first_name#61, c_last_name#62, c (48) HashAggregate [codegen id : 11] Input [10]: [c_customer_id#60, c_first_name#61, c_last_name#62, c_preferred_cust_flag#63, c_birth_country#64, c_login#65, c_email_address#66, d_year#75, sum#78, isEmpty#79] Keys [8]: [c_customer_id#60, c_first_name#61, c_last_name#62, c_preferred_cust_flag#63, c_birth_country#64, c_login#65, c_email_address#66, d_year#75] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#71 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#70 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#68 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + 
promote_precision(cast(cs_ext_sales_price#69 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#71 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#70 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#68 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#69 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#81] -Results [2]: [c_customer_id#60 AS customer_id#82, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#71 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#70 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#68 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#69 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#81 AS year_total#83] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#71 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#70 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#68 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#69 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#71 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#70 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#68 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#69 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#81] +Results [2]: [c_customer_id#60 AS customer_id#82, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#71 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#70 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#68 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#69 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#81 AS year_total#83] (49) Filter [codegen id : 11] Input [2]: [customer_id#82, year_total#83] @@ -412,7 +412,7 @@ Input [14]: [c_customer_id#86, c_first_name#87, c_last_name#88, c_preferred_cust (65) HashAggregate [codegen id : 14] Input [12]: [c_customer_id#86, c_first_name#87, c_last_name#88, c_preferred_cust_flag#89, c_birth_country#90, c_login#91, c_email_address#92, cs_ext_discount_amt#94, cs_ext_sales_price#95, cs_ext_wholesale_cost#96, cs_ext_list_price#97, d_year#101] Keys [8]: [c_customer_id#86, c_first_name#87, c_last_name#88, c_preferred_cust_flag#89, 
c_birth_country#90, c_login#91, c_email_address#92, d_year#101] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#97 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#96 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#94 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#95 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#97 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#96 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#94 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#95 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#102, isEmpty#103] Results [10]: [c_customer_id#86, c_first_name#87, c_last_name#88, c_preferred_cust_flag#89, c_birth_country#90, c_login#91, c_email_address#92, d_year#101, sum#104, isEmpty#105] @@ -423,9 +423,9 @@ Arguments: hashpartitioning(c_customer_id#86, c_first_name#87, c_last_name#88, c (67) HashAggregate [codegen id : 15] Input [10]: [c_customer_id#86, c_first_name#87, c_last_name#88, c_preferred_cust_flag#89, c_birth_country#90, c_login#91, c_email_address#92, d_year#101, sum#104, isEmpty#105] Keys [8]: [c_customer_id#86, c_first_name#87, c_last_name#88, c_preferred_cust_flag#89, c_birth_country#90, c_login#91, c_email_address#92, d_year#101] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#97 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#96 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#94 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#95 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#97 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#96 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#94 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#95 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#81] -Results [2]: [c_customer_id#86 AS customer_id#107, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#97 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#96 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#94 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#95 as decimal(10,2)))), DecimalType(10,2), 
true)) / 2.00), DecimalType(14,6), true))#81 AS year_total#108] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#97 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#96 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#94 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#95 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#97 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#96 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#94 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#95 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#81] +Results [2]: [c_customer_id#86 AS customer_id#107, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price#97 as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost#96 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt#94 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price#95 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#81 AS year_total#108] (68) BroadcastExchange Input [2]: [customer_id#107, year_total#108] @@ -434,7 +434,7 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (69) BroadcastHashJoin [codegen id : 24] Left keys [1]: [customer_id#25] Right keys [1]: [customer_id#107] -Join condition: (CASE WHEN (year_total#83 > 0.000000) THEN CheckOverflow((promote_precision(year_total#108) / promote_precision(year_total#83)), DecimalType(38,14), true) END > CASE WHEN (year_total#26 > 0.000000) THEN CheckOverflow((promote_precision(year_total#57) / promote_precision(year_total#26)), DecimalType(38,14), true) END) +Join condition: (CASE WHEN (year_total#83 > 0.000000) THEN CheckOverflow((promote_precision(year_total#108) / promote_precision(year_total#83)), DecimalType(38,14)) END > CASE WHEN (year_total#26 > 0.000000) THEN CheckOverflow((promote_precision(year_total#57) / promote_precision(year_total#26)), DecimalType(38,14)) END) (70) Project [codegen id : 24] Output [10]: [customer_id#25, customer_id#50, customer_first_name#51, customer_last_name#52, customer_preferred_cust_flag#53, customer_birth_country#54, customer_login#55, customer_email_address#56, year_total#83, year_total#108] @@ -497,7 +497,7 @@ Input [14]: [c_customer_id#111, c_first_name#112, c_last_name#113, c_preferred_c (83) HashAggregate [codegen id : 18] Input [12]: [c_customer_id#111, c_first_name#112, c_last_name#113, c_preferred_cust_flag#114, c_birth_country#115, c_login#116, c_email_address#117, ws_ext_discount_amt#119, ws_ext_sales_price#120, ws_ext_wholesale_cost#121, ws_ext_list_price#122, d_year#126] Keys [8]: [c_customer_id#111, c_first_name#112, c_last_name#113, c_preferred_cust_flag#114, c_birth_country#115, c_login#116, c_email_address#117, d_year#126] -Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#122 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#121 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#119 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#120 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#122 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#121 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#119 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#120 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#127, isEmpty#128] Results [10]: [c_customer_id#111, c_first_name#112, c_last_name#113, c_preferred_cust_flag#114, c_birth_country#115, c_login#116, c_email_address#117, d_year#126, sum#129, isEmpty#130] @@ -508,9 +508,9 @@ Arguments: hashpartitioning(c_customer_id#111, c_first_name#112, c_last_name#113 (85) HashAggregate [codegen id : 19] Input [10]: [c_customer_id#111, c_first_name#112, c_last_name#113, c_preferred_cust_flag#114, c_birth_country#115, c_login#116, c_email_address#117, d_year#126, sum#129, isEmpty#130] Keys [8]: [c_customer_id#111, c_first_name#112, c_last_name#113, c_preferred_cust_flag#114, c_birth_country#115, c_login#116, c_email_address#117, d_year#126] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#122 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#121 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#119 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#120 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#122 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#121 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#119 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#120 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#132] -Results [2]: [c_customer_id#111 AS customer_id#133, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#122 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#121 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#119 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#120 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), 
true))#132 AS year_total#134] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#122 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#121 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#119 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#120 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#122 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#121 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#119 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#120 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#132] +Results [2]: [c_customer_id#111 AS customer_id#133, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#122 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#121 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#119 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#120 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#132 AS year_total#134] (86) Filter [codegen id : 19] Input [2]: [customer_id#133, year_total#134] @@ -586,7 +586,7 @@ Input [14]: [c_customer_id#137, c_first_name#138, c_last_name#139, c_preferred_c (102) HashAggregate [codegen id : 22] Input [12]: [c_customer_id#137, c_first_name#138, c_last_name#139, c_preferred_cust_flag#140, c_birth_country#141, c_login#142, c_email_address#143, ws_ext_discount_amt#145, ws_ext_sales_price#146, ws_ext_wholesale_cost#147, ws_ext_list_price#148, d_year#152] Keys [8]: [c_customer_id#137, c_first_name#138, c_last_name#139, c_preferred_cust_flag#140, c_birth_country#141, c_login#142, c_email_address#143, d_year#152] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#148 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#147 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#145 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#146 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#148 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#147 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#145 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#146 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] Aggregate Attributes [2]: [sum#153, isEmpty#154] Results [10]: 
[c_customer_id#137, c_first_name#138, c_last_name#139, c_preferred_cust_flag#140, c_birth_country#141, c_login#142, c_email_address#143, d_year#152, sum#155, isEmpty#156] @@ -597,9 +597,9 @@ Arguments: hashpartitioning(c_customer_id#137, c_first_name#138, c_last_name#139 (104) HashAggregate [codegen id : 23] Input [10]: [c_customer_id#137, c_first_name#138, c_last_name#139, c_preferred_cust_flag#140, c_birth_country#141, c_login#142, c_email_address#143, d_year#152, sum#155, isEmpty#156] Keys [8]: [c_customer_id#137, c_first_name#138, c_last_name#139, c_preferred_cust_flag#140, c_birth_country#141, c_login#142, c_email_address#143, d_year#152] -Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#148 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#147 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#145 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#146 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#148 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#147 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#145 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#146 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#132] -Results [2]: [c_customer_id#137 AS customer_id#158, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#148 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#147 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#145 as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#146 as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true))#132 AS year_total#159] +Functions [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#148 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#147 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#145 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#146 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#148 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#147 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#145 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#146 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#132] +Results [2]: [c_customer_id#137 AS 
customer_id#158, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price#148 as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost#147 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt#145 as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price#146 as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6)))#132 AS year_total#159] (105) BroadcastExchange Input [2]: [customer_id#158, year_total#159] @@ -608,7 +608,7 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (106) BroadcastHashJoin [codegen id : 24] Left keys [1]: [customer_id#25] Right keys [1]: [customer_id#158] -Join condition: (CASE WHEN (year_total#83 > 0.000000) THEN CheckOverflow((promote_precision(year_total#108) / promote_precision(year_total#83)), DecimalType(38,14), true) END > CASE WHEN (year_total#134 > 0.000000) THEN CheckOverflow((promote_precision(year_total#159) / promote_precision(year_total#134)), DecimalType(38,14), true) END) +Join condition: (CASE WHEN (year_total#83 > 0.000000) THEN CheckOverflow((promote_precision(year_total#108) / promote_precision(year_total#83)), DecimalType(38,14)) END > CASE WHEN (year_total#134 > 0.000000) THEN CheckOverflow((promote_precision(year_total#159) / promote_precision(year_total#134)), DecimalType(38,14)) END) (107) Project [codegen id : 24] Output [7]: [customer_id#50, customer_first_name#51, customer_last_name#52, customer_preferred_cust_flag#53, customer_birth_country#54, customer_login#55, customer_email_address#56] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4/simplified.txt index 68d4f3219238a..67afe29952d88 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q4/simplified.txt @@ -10,7 +10,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom BroadcastHashJoin [customer_id,customer_id] BroadcastHashJoin [customer_id,customer_id] Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(9,2)))), 
DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #1 WholeStageCodegen (3) @@ -42,7 +42,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom InputAdapter BroadcastExchange #4 WholeStageCodegen (7) - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,customer_first_name,customer_last_name,customer_preferred_cust_flag,customer_birth_country,customer_login,customer_email_address,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ss_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,customer_first_name,customer_last_name,customer_preferred_cust_flag,customer_birth_country,customer_login,customer_email_address,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #5 WholeStageCodegen (6) @@ -75,7 +75,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom BroadcastExchange #8 WholeStageCodegen (11) Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost as 
decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #9 WholeStageCodegen (10) @@ -101,7 +101,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom InputAdapter BroadcastExchange #11 WholeStageCodegen (15) - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cs_ext_list_price as decimal(8,2))) - promote_precision(cast(cs_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(cs_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(cs_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #12 WholeStageCodegen (14) @@ -128,7 +128,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom BroadcastExchange #14 WholeStageCodegen (19) Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as 
decimal(10,2))) + promote_precision(cast(ws_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #15 WholeStageCodegen (18) @@ -154,7 +154,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom InputAdapter BroadcastExchange #17 WholeStageCodegen (23) - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(9,2)))), DecimalType(9,2), true) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price as decimal(10,2)))), DecimalType(10,2), true)) / 2.00), DecimalType(14,6), true)),customer_id,year_total,sum,isEmpty] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum,isEmpty] [sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_wholesale_cost as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(9,2)))), DecimalType(9,2)) as decimal(10,2))) + promote_precision(cast(ws_ext_sales_price as decimal(10,2)))), DecimalType(10,2))) / 2.00), DecimalType(14,6))),customer_id,year_total,sum,isEmpty] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #18 WholeStageCodegen (22) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40.sf100/explain.txt index 0da152eaf66a8..32d76db8cdf3a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40.sf100/explain.txt @@ -165,7 +165,7 @@ Input [7]: [cs_warehouse_sk#1, cs_sales_price#4, cr_refunded_cash#10, i_item_id# (30) HashAggregate [codegen id : 8] Input [5]: [cs_sales_price#4, cr_refunded_cash#10, w_state#20, i_item_id#14, d_date#18] Keys [2]: [w_state#20, i_item_id#14] -Functions [2]: [partial_sum(CASE WHEN (d_date#18 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_date#18 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)] +Functions [2]: [partial_sum(CASE WHEN (d_date#18 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) 
ELSE 0.00 END), partial_sum(CASE WHEN (d_date#18 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)] Aggregate Attributes [4]: [sum#22, isEmpty#23, sum#24, isEmpty#25] Results [6]: [w_state#20, i_item_id#14, sum#26, isEmpty#27, sum#28, isEmpty#29] @@ -176,9 +176,9 @@ Arguments: hashpartitioning(w_state#20, i_item_id#14, 5), ENSURE_REQUIREMENTS, [ (32) HashAggregate [codegen id : 9] Input [6]: [w_state#20, i_item_id#14, sum#26, isEmpty#27, sum#28, isEmpty#29] Keys [2]: [w_state#20, i_item_id#14] -Functions [2]: [sum(CASE WHEN (d_date#18 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END), sum(CASE WHEN (d_date#18 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)] -Aggregate Attributes [2]: [sum(CASE WHEN (d_date#18 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)#31, sum(CASE WHEN (d_date#18 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)#32] -Results [4]: [w_state#20, i_item_id#14, sum(CASE WHEN (d_date#18 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)#31 AS sales_before#33, sum(CASE WHEN (d_date#18 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)#32 AS sales_after#34] +Functions [2]: [sum(CASE WHEN (d_date#18 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END), sum(CASE WHEN (d_date#18 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)] +Aggregate Attributes [2]: [sum(CASE WHEN (d_date#18 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)#31, sum(CASE WHEN (d_date#18 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)#32] +Results [4]: [w_state#20, i_item_id#14, sum(CASE WHEN (d_date#18 
< 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)#31 AS sales_before#33, sum(CASE WHEN (d_date#18 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)#32 AS sales_after#34] (33) TakeOrderedAndProject Input [4]: [w_state#20, i_item_id#14, sales_before#33, sales_after#34] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40.sf100/simplified.txt index 296e9186f9fd9..5854dc101f305 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40.sf100/simplified.txt @@ -1,6 +1,6 @@ TakeOrderedAndProject [w_state,i_item_id,sales_before,sales_after] WholeStageCodegen (9) - HashAggregate [w_state,i_item_id,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_date < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END),sum(CASE WHEN (d_date >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END),sales_before,sales_after,sum,isEmpty,sum,isEmpty] + HashAggregate [w_state,i_item_id,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_date < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END),sum(CASE WHEN (d_date >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END),sales_before,sales_after,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [w_state,i_item_id] #1 WholeStageCodegen (8) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40/explain.txt index 7678a91036fd6..f1a79d04f36bc 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40/explain.txt @@ -165,7 +165,7 @@ Input [7]: [cs_sales_price#4, cs_sold_date_sk#5, cr_refunded_cash#10, w_state#14 (30) HashAggregate [codegen id : 8] Input [5]: [cs_sales_price#4, cr_refunded_cash#10, w_state#14, i_item_id#17, d_date#21] Keys [2]: [w_state#14, i_item_id#17] -Functions [2]: [partial_sum(CASE WHEN (d_date#21 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_date#21 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as 
decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)] +Functions [2]: [partial_sum(CASE WHEN (d_date#21 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_date#21 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)] Aggregate Attributes [4]: [sum#22, isEmpty#23, sum#24, isEmpty#25] Results [6]: [w_state#14, i_item_id#17, sum#26, isEmpty#27, sum#28, isEmpty#29] @@ -176,9 +176,9 @@ Arguments: hashpartitioning(w_state#14, i_item_id#17, 5), ENSURE_REQUIREMENTS, [ (32) HashAggregate [codegen id : 9] Input [6]: [w_state#14, i_item_id#17, sum#26, isEmpty#27, sum#28, isEmpty#29] Keys [2]: [w_state#14, i_item_id#17] -Functions [2]: [sum(CASE WHEN (d_date#21 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END), sum(CASE WHEN (d_date#21 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)] -Aggregate Attributes [2]: [sum(CASE WHEN (d_date#21 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)#31, sum(CASE WHEN (d_date#21 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)#32] -Results [4]: [w_state#14, i_item_id#17, sum(CASE WHEN (d_date#21 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)#31 AS sales_before#33, sum(CASE WHEN (d_date#21 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END)#32 AS sales_after#34] +Functions [2]: [sum(CASE WHEN (d_date#21 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END), sum(CASE WHEN (d_date#21 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)] +Aggregate Attributes [2]: [sum(CASE WHEN (d_date#21 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - 
promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)#31, sum(CASE WHEN (d_date#21 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)#32] +Results [4]: [w_state#14, i_item_id#17, sum(CASE WHEN (d_date#21 < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)#31 AS sales_before#33, sum(CASE WHEN (d_date#21 >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#4 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash#10 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END)#32 AS sales_after#34] (33) TakeOrderedAndProject Input [4]: [w_state#14, i_item_id#17, sales_before#33, sales_after#34] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40/simplified.txt index c691a23f64bf9..206317e8a5210 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q40/simplified.txt @@ -1,6 +1,6 @@ TakeOrderedAndProject [w_state,i_item_id,sales_before,sales_after] WholeStageCodegen (9) - HashAggregate [w_state,i_item_id,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_date < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END),sum(CASE WHEN (d_date >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true) ELSE 0.00 END),sales_before,sales_after,sum,isEmpty,sum,isEmpty] + HashAggregate [w_state,i_item_id,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_date < 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END),sum(CASE WHEN (d_date >= 2000-03-11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_refunded_cash as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)) ELSE 0.00 END),sales_before,sales_after,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [w_state,i_item_id] #1 WholeStageCodegen (8) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q44.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q44.sf100/explain.txt index 8fa5abffaa52f..0d7aa6dbdfbb8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q44.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q44.sf100/explain.txt @@ -70,7 +70,7 @@ Results [2]: [ss_item_sk#1 AS item_sk#11, cast((avg(UnscaledValue(ss_net_profit# (8) Filter [codegen id : 2] Input [2]: [item_sk#11, rank_col#12] -Condition : 
(isnotnull(rank_col#12) AND (cast(rank_col#12 as decimal(13,7)) > CheckOverflow((0.900000 * promote_precision(Subquery scalar-subquery#13, [id=#14])), DecimalType(13,7), true))) +Condition : (isnotnull(rank_col#12) AND (cast(rank_col#12 as decimal(13,7)) > CheckOverflow((0.900000 * promote_precision(Subquery scalar-subquery#13, [id=#14])), DecimalType(13,7)))) (9) Exchange Input [2]: [item_sk#11, rank_col#12] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q44/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q44/explain.txt index b3d0081f5d22e..5783d8b49b6a0 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q44/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q44/explain.txt @@ -71,7 +71,7 @@ Results [2]: [ss_item_sk#1 AS item_sk#11, cast((avg(UnscaledValue(ss_net_profit# (8) Filter [codegen id : 2] Input [2]: [item_sk#11, rank_col#12] -Condition : (isnotnull(rank_col#12) AND (cast(rank_col#12 as decimal(13,7)) > CheckOverflow((0.900000 * promote_precision(Subquery scalar-subquery#13, [id=#14])), DecimalType(13,7), true))) +Condition : (isnotnull(rank_col#12) AND (cast(rank_col#12 as decimal(13,7)) > CheckOverflow((0.900000 * promote_precision(Subquery scalar-subquery#13, [id=#14])), DecimalType(13,7)))) (9) Exchange Input [2]: [item_sk#11, rank_col#12] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47.sf100/explain.txt index 44a956471b61e..23dfbecdbca9d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47.sf100/explain.txt @@ -1,53 +1,56 @@ == Physical Plan == -TakeOrderedAndProject (49) -+- * Project (48) - +- * SortMergeJoin Inner (47) - :- * Project (41) - : +- * SortMergeJoin Inner (40) - : :- * Sort (32) - : : +- * Project (31) - : : +- * Filter (30) - : : +- Window (29) - : : +- * Filter (28) - : : +- Window (27) - : : +- * Sort (26) - : : +- Exchange (25) - : : +- * HashAggregate (24) - : : +- Exchange (23) - : : +- * HashAggregate (22) - : : +- * Project (21) - : : +- * SortMergeJoin Inner (20) - : : :- * Sort (14) - : : : +- Exchange (13) - : : : +- * Project (12) - : : : +- * BroadcastHashJoin Inner BuildRight (11) - : : : :- * Project (6) - : : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : : :- * Filter (3) - : : : : : +- * ColumnarToRow (2) - : : : : : +- Scan parquet default.store_sales (1) - : : : : +- ReusedExchange (4) - : : : +- BroadcastExchange (10) - : : : +- * Filter (9) - : : : +- * ColumnarToRow (8) - : : : +- Scan parquet default.store (7) - : : +- * Sort (19) - : : +- Exchange (18) - : : +- * Filter (17) - : : +- * ColumnarToRow (16) - : : +- Scan parquet default.item (15) - : +- * Sort (39) - : +- * Project (38) - : +- Window (37) - : +- * Sort (36) - : +- Exchange (35) - : +- * HashAggregate (34) - : +- ReusedExchange (33) - +- * Sort (46) - +- * Project (45) - +- Window (44) - +- * Sort (43) - +- ReusedExchange (42) +TakeOrderedAndProject (52) ++- * Project (51) + +- * SortMergeJoin Inner (50) + :- * Project (43) + : +- * SortMergeJoin Inner (42) + : :- * Sort (33) + : : +- Exchange (32) + : : +- * Project (31) + : : +- * Filter (30) + : : +- Window (29) + : : +- * Filter (28) + : : +- Window (27) + : : +- * Sort (26) + : : +- Exchange (25) + : : +- * 
HashAggregate (24) + : : +- Exchange (23) + : : +- * HashAggregate (22) + : : +- * Project (21) + : : +- * SortMergeJoin Inner (20) + : : :- * Sort (14) + : : : +- Exchange (13) + : : : +- * Project (12) + : : : +- * BroadcastHashJoin Inner BuildRight (11) + : : : :- * Project (6) + : : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : : :- * Filter (3) + : : : : : +- * ColumnarToRow (2) + : : : : : +- Scan parquet default.store_sales (1) + : : : : +- ReusedExchange (4) + : : : +- BroadcastExchange (10) + : : : +- * Filter (9) + : : : +- * ColumnarToRow (8) + : : : +- Scan parquet default.store (7) + : : +- * Sort (19) + : : +- Exchange (18) + : : +- * Filter (17) + : : +- * ColumnarToRow (16) + : : +- Scan parquet default.item (15) + : +- * Sort (41) + : +- Exchange (40) + : +- * Project (39) + : +- Window (38) + : +- * Sort (37) + : +- Exchange (36) + : +- * HashAggregate (35) + : +- ReusedExchange (34) + +- * Sort (49) + +- Exchange (48) + +- * Project (47) + +- Window (46) + +- * Sort (45) + +- ReusedExchange (44) (1) Scan parquet default.store_sales @@ -65,7 +68,7 @@ Input [4]: [ss_item_sk#1, ss_store_sk#2, ss_sales_price#3, ss_sold_date_sk#4] Input [4]: [ss_item_sk#1, ss_store_sk#2, ss_sales_price#3, ss_sold_date_sk#4] Condition : (isnotnull(ss_item_sk#1) AND isnotnull(ss_store_sk#2)) -(4) ReusedExchange [Reuses operator id: 53] +(4) ReusedExchange [Reuses operator id: 56] Output [3]: [d_date_sk#6, d_year#7, d_moy#8] (5) BroadcastHashJoin [codegen id : 3] @@ -183,112 +186,124 @@ Arguments: [avg(_w0#23) windowspecdefinition(i_category#16, i_brand#15, s_store_ (30) Filter [codegen id : 11] Input [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, _w0#23, rn#25, avg_monthly_sales#26] -Condition : ((isnotnull(avg_monthly_sales#26) AND (avg_monthly_sales#26 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#22 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000)) +Condition : ((isnotnull(avg_monthly_sales#26) AND (avg_monthly_sales#26 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#22 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000)) (31) Project [codegen id : 11] Output [9]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25] Input [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, _w0#23, rn#25, avg_monthly_sales#26] -(32) Sort [codegen id : 11] +(32) Exchange +Input [9]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25] +Arguments: hashpartitioning(i_category#16, i_brand#15, s_store_name#10, s_company_name#11, rn#25, 5), ENSURE_REQUIREMENTS, [id=#27] + +(33) Sort [codegen id : 12] Input [9]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25] Arguments: [i_category#16 ASC NULLS FIRST, i_brand#15 ASC NULLS FIRST, s_store_name#10 ASC NULLS FIRST, s_company_name#11 ASC NULLS FIRST, rn#25 ASC NULLS FIRST], false, 0 
-(33) ReusedExchange [Reuses operator id: 23] -Output [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum#33] +(34) ReusedExchange [Reuses operator id: 23] +Output [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum#34] -(34) HashAggregate [codegen id : 19] -Input [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum#33] -Keys [6]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32] -Functions [1]: [sum(UnscaledValue(ss_sales_price#34))] -Aggregate Attributes [1]: [sum(UnscaledValue(ss_sales_price#34))#21] -Results [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, MakeDecimal(sum(UnscaledValue(ss_sales_price#34))#21,17,2) AS sum_sales#22] +(35) HashAggregate [codegen id : 20] +Input [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum#34] +Keys [6]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33] +Functions [1]: [sum(UnscaledValue(ss_sales_price#35))] +Aggregate Attributes [1]: [sum(UnscaledValue(ss_sales_price#35))#21] +Results [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, MakeDecimal(sum(UnscaledValue(ss_sales_price#35))#21,17,2) AS sum_sales#22] -(35) Exchange -Input [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum_sales#22] -Arguments: hashpartitioning(i_category#27, i_brand#28, s_store_name#29, s_company_name#30, 5), ENSURE_REQUIREMENTS, [id=#35] +(36) Exchange +Input [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum_sales#22] +Arguments: hashpartitioning(i_category#28, i_brand#29, s_store_name#30, s_company_name#31, 5), ENSURE_REQUIREMENTS, [id=#36] -(36) Sort [codegen id : 20] -Input [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum_sales#22] -Arguments: [i_category#27 ASC NULLS FIRST, i_brand#28 ASC NULLS FIRST, s_store_name#29 ASC NULLS FIRST, s_company_name#30 ASC NULLS FIRST, d_year#31 ASC NULLS FIRST, d_moy#32 ASC NULLS FIRST], false, 0 +(37) Sort [codegen id : 21] +Input [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum_sales#22] +Arguments: [i_category#28 ASC NULLS FIRST, i_brand#29 ASC NULLS FIRST, s_store_name#30 ASC NULLS FIRST, s_company_name#31 ASC NULLS FIRST, d_year#32 ASC NULLS FIRST, d_moy#33 ASC NULLS FIRST], false, 0 -(37) Window -Input [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum_sales#22] -Arguments: [rank(d_year#31, d_moy#32) windowspecdefinition(i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31 ASC NULLS FIRST, d_moy#32 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#36], [i_category#27, i_brand#28, s_store_name#29, s_company_name#30], [d_year#31 ASC NULLS FIRST, d_moy#32 ASC NULLS FIRST] +(38) Window +Input [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum_sales#22] +Arguments: [rank(d_year#32, d_moy#33) windowspecdefinition(i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32 ASC NULLS FIRST, d_moy#33 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#37], [i_category#28, i_brand#29, s_store_name#30, s_company_name#31], [d_year#32 ASC NULLS FIRST, 
d_moy#33 ASC NULLS FIRST] -(38) Project [codegen id : 21] -Output [6]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, sum_sales#22 AS sum_sales#37, rn#36] -Input [8]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum_sales#22, rn#36] +(39) Project [codegen id : 22] +Output [6]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, sum_sales#22 AS sum_sales#38, rn#37] +Input [8]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum_sales#22, rn#37] -(39) Sort [codegen id : 21] -Input [6]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, sum_sales#37, rn#36] -Arguments: [i_category#27 ASC NULLS FIRST, i_brand#28 ASC NULLS FIRST, s_store_name#29 ASC NULLS FIRST, s_company_name#30 ASC NULLS FIRST, (rn#36 + 1) ASC NULLS FIRST], false, 0 +(40) Exchange +Input [6]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, sum_sales#38, rn#37] +Arguments: hashpartitioning(i_category#28, i_brand#29, s_store_name#30, s_company_name#31, (rn#37 + 1), 5), ENSURE_REQUIREMENTS, [id=#39] -(40) SortMergeJoin [codegen id : 22] +(41) Sort [codegen id : 23] +Input [6]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, sum_sales#38, rn#37] +Arguments: [i_category#28 ASC NULLS FIRST, i_brand#29 ASC NULLS FIRST, s_store_name#30 ASC NULLS FIRST, s_company_name#31 ASC NULLS FIRST, (rn#37 + 1) ASC NULLS FIRST], false, 0 + +(42) SortMergeJoin [codegen id : 24] Left keys [5]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, rn#25] -Right keys [5]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, (rn#36 + 1)] +Right keys [5]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, (rn#37 + 1)] Join condition: None -(41) Project [codegen id : 22] -Output [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, sum_sales#37] -Input [15]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, i_category#27, i_brand#28, s_store_name#29, s_company_name#30, sum_sales#37, rn#36] +(43) Project [codegen id : 24] +Output [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, sum_sales#38] +Input [15]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, i_category#28, i_brand#29, s_store_name#30, s_company_name#31, sum_sales#38, rn#37] + +(44) ReusedExchange [Reuses operator id: 36] +Output [7]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, d_year#44, d_moy#45, sum_sales#22] -(42) ReusedExchange [Reuses operator id: 35] -Output [7]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, d_year#42, d_moy#43, sum_sales#22] +(45) Sort [codegen id : 33] +Input [7]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, d_year#44, d_moy#45, sum_sales#22] +Arguments: [i_category#40 ASC NULLS FIRST, i_brand#41 ASC NULLS FIRST, s_store_name#42 ASC NULLS FIRST, s_company_name#43 ASC NULLS FIRST, d_year#44 ASC NULLS FIRST, d_moy#45 ASC NULLS FIRST], false, 0 -(43) Sort [codegen id : 31] -Input [7]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, d_year#42, d_moy#43, sum_sales#22] -Arguments: [i_category#38 ASC NULLS FIRST, i_brand#39 ASC NULLS FIRST, s_store_name#40 ASC NULLS FIRST, s_company_name#41 ASC NULLS FIRST, 
d_year#42 ASC NULLS FIRST, d_moy#43 ASC NULLS FIRST], false, 0 +(46) Window +Input [7]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, d_year#44, d_moy#45, sum_sales#22] +Arguments: [rank(d_year#44, d_moy#45) windowspecdefinition(i_category#40, i_brand#41, s_store_name#42, s_company_name#43, d_year#44 ASC NULLS FIRST, d_moy#45 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#46], [i_category#40, i_brand#41, s_store_name#42, s_company_name#43], [d_year#44 ASC NULLS FIRST, d_moy#45 ASC NULLS FIRST] -(44) Window -Input [7]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, d_year#42, d_moy#43, sum_sales#22] -Arguments: [rank(d_year#42, d_moy#43) windowspecdefinition(i_category#38, i_brand#39, s_store_name#40, s_company_name#41, d_year#42 ASC NULLS FIRST, d_moy#43 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#44], [i_category#38, i_brand#39, s_store_name#40, s_company_name#41], [d_year#42 ASC NULLS FIRST, d_moy#43 ASC NULLS FIRST] +(47) Project [codegen id : 34] +Output [6]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, sum_sales#22 AS sum_sales#47, rn#46] +Input [8]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, d_year#44, d_moy#45, sum_sales#22, rn#46] -(45) Project [codegen id : 32] -Output [6]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, sum_sales#22 AS sum_sales#45, rn#44] -Input [8]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, d_year#42, d_moy#43, sum_sales#22, rn#44] +(48) Exchange +Input [6]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, sum_sales#47, rn#46] +Arguments: hashpartitioning(i_category#40, i_brand#41, s_store_name#42, s_company_name#43, (rn#46 - 1), 5), ENSURE_REQUIREMENTS, [id=#48] -(46) Sort [codegen id : 32] -Input [6]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, sum_sales#45, rn#44] -Arguments: [i_category#38 ASC NULLS FIRST, i_brand#39 ASC NULLS FIRST, s_store_name#40 ASC NULLS FIRST, s_company_name#41 ASC NULLS FIRST, (rn#44 - 1) ASC NULLS FIRST], false, 0 +(49) Sort [codegen id : 35] +Input [6]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, sum_sales#47, rn#46] +Arguments: [i_category#40 ASC NULLS FIRST, i_brand#41 ASC NULLS FIRST, s_store_name#42 ASC NULLS FIRST, s_company_name#43 ASC NULLS FIRST, (rn#46 - 1) ASC NULLS FIRST], false, 0 -(47) SortMergeJoin [codegen id : 33] +(50) SortMergeJoin [codegen id : 36] Left keys [5]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, rn#25] -Right keys [5]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, (rn#44 - 1)] +Right keys [5]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, (rn#46 - 1)] Join condition: None -(48) Project [codegen id : 33] -Output [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, sum_sales#37 AS psum#46, sum_sales#45 AS nsum#47] -Input [16]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, sum_sales#37, i_category#38, i_brand#39, s_store_name#40, s_company_name#41, sum_sales#45, rn#44] +(51) Project [codegen id : 36] +Output [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, sum_sales#38 AS psum#49, sum_sales#47 AS nsum#50] +Input [16]: [i_category#16, i_brand#15, s_store_name#10, 
s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, sum_sales#38, i_category#40, i_brand#41, s_store_name#42, s_company_name#43, sum_sales#47, rn#46] -(49) TakeOrderedAndProject -Input [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, psum#46, nsum#47] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#22 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, s_store_name#10 ASC NULLS FIRST], [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, psum#46, nsum#47] +(52) TakeOrderedAndProject +Input [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, psum#49, nsum#50] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#22 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, s_store_name#10 ASC NULLS FIRST], [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, psum#49, nsum#50] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (53) -+- * Filter (52) - +- * ColumnarToRow (51) - +- Scan parquet default.date_dim (50) +BroadcastExchange (56) ++- * Filter (55) + +- * ColumnarToRow (54) + +- Scan parquet default.date_dim (53) -(50) Scan parquet default.date_dim +(53) Scan parquet default.date_dim Output [3]: [d_date_sk#6, d_year#7, d_moy#8] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [Or(Or(EqualTo(d_year,1999),And(EqualTo(d_year,1998),EqualTo(d_moy,12))),And(EqualTo(d_year,2000),EqualTo(d_moy,1))), IsNotNull(d_date_sk)] ReadSchema: struct -(51) ColumnarToRow [codegen id : 1] +(54) ColumnarToRow [codegen id : 1] Input [3]: [d_date_sk#6, d_year#7, d_moy#8] -(52) Filter [codegen id : 1] +(55) Filter [codegen id : 1] Input [3]: [d_date_sk#6, d_year#7, d_moy#8] Condition : ((((d_year#7 = 1999) OR ((d_year#7 = 1998) AND (d_moy#8 = 12))) OR ((d_year#7 = 2000) AND (d_moy#8 = 1))) AND isnotnull(d_date_sk#6)) -(53) BroadcastExchange +(56) BroadcastExchange Input [3]: [d_date_sk#6, d_year#7, d_moy#8] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#48] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#51] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47.sf100/simplified.txt index aa2346cacaf2d..07c75d91ca3cf 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47.sf100/simplified.txt @@ -1,95 +1,104 @@ TakeOrderedAndProject [sum_sales,avg_monthly_sales,s_store_name,i_category,i_brand,s_company_name,d_year,d_moy,psum,nsum] - WholeStageCodegen (33) + WholeStageCodegen (36) Project [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,avg_monthly_sales,sum_sales,sum_sales,sum_sales] SortMergeJoin [i_category,i_brand,s_store_name,s_company_name,rn,i_category,i_brand,s_store_name,s_company_name,rn] InputAdapter - WholeStageCodegen (22) + WholeStageCodegen 
(24) Project [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn,sum_sales] SortMergeJoin [i_category,i_brand,s_store_name,s_company_name,rn,i_category,i_brand,s_store_name,s_company_name,rn] InputAdapter - WholeStageCodegen (11) + WholeStageCodegen (12) Sort [i_category,i_brand,s_store_name,s_company_name,rn] - Project [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn] - Filter [avg_monthly_sales,sum_sales] - InputAdapter - Window [_w0,i_category,i_brand,s_store_name,s_company_name,d_year] - WholeStageCodegen (10) - Filter [d_year] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] - WholeStageCodegen (9) - Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] - InputAdapter - Exchange [i_category,i_brand,s_store_name,s_company_name] #1 - WholeStageCodegen (8) - HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] [sum(UnscaledValue(ss_sales_price)),sum_sales,_w0,sum] - InputAdapter - Exchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] #2 - WholeStageCodegen (7) - HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,ss_sales_price] [sum,sum] - Project [i_brand,i_category,ss_sales_price,d_year,d_moy,s_store_name,s_company_name] - SortMergeJoin [ss_item_sk,i_item_sk] - InputAdapter - WholeStageCodegen (4) - Sort [ss_item_sk] + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name,rn] #1 + WholeStageCodegen (11) + Project [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn] + Filter [avg_monthly_sales,sum_sales] + InputAdapter + Window [_w0,i_category,i_brand,s_store_name,s_company_name,d_year] + WholeStageCodegen (10) + Filter [d_year] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] + WholeStageCodegen (9) + Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name] #2 + WholeStageCodegen (8) + HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] [sum(UnscaledValue(ss_sales_price)),sum_sales,_w0,sum] + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] #3 + WholeStageCodegen (7) + HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,ss_sales_price] [sum,sum] + Project [i_brand,i_category,ss_sales_price,d_year,d_moy,s_store_name,s_company_name] + SortMergeJoin [ss_item_sk,i_item_sk] InputAdapter - Exchange [ss_item_sk] #3 - WholeStageCodegen (3) - Project [ss_item_sk,ss_sales_price,d_year,d_moy,s_store_name,s_company_name] - BroadcastHashJoin [ss_store_sk,s_store_sk] - Project [ss_item_sk,ss_store_sk,ss_sales_price,d_year,d_moy] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_item_sk,ss_store_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_store_sk,ss_sales_price,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #4 - WholeStageCodegen (1) - Filter [d_year,d_moy,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year,d_moy] - InputAdapter - ReusedExchange [d_date_sk,d_year,d_moy] #4 - InputAdapter - BroadcastExchange #5 - WholeStageCodegen (2) - Filter [s_store_sk,s_store_name,s_company_name] - ColumnarToRow + WholeStageCodegen (4) + Sort [ss_item_sk] + InputAdapter + Exchange [ss_item_sk] #4 + WholeStageCodegen (3) + Project 
[ss_item_sk,ss_sales_price,d_year,d_moy,s_store_name,s_company_name] + BroadcastHashJoin [ss_store_sk,s_store_sk] + Project [ss_item_sk,ss_store_sk,ss_sales_price,d_year,d_moy] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_item_sk,ss_store_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_store_sk,ss_sales_price,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (1) + Filter [d_year,d_moy,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year,d_moy] InputAdapter - Scan parquet default.store [s_store_sk,s_store_name,s_company_name] - InputAdapter - WholeStageCodegen (6) - Sort [i_item_sk] + ReusedExchange [d_date_sk,d_year,d_moy] #5 + InputAdapter + BroadcastExchange #6 + WholeStageCodegen (2) + Filter [s_store_sk,s_store_name,s_company_name] + ColumnarToRow + InputAdapter + Scan parquet default.store [s_store_sk,s_store_name,s_company_name] InputAdapter - Exchange [i_item_sk] #6 - WholeStageCodegen (5) - Filter [i_item_sk,i_category,i_brand] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand,i_category] + WholeStageCodegen (6) + Sort [i_item_sk] + InputAdapter + Exchange [i_item_sk] #7 + WholeStageCodegen (5) + Filter [i_item_sk,i_category,i_brand] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand,i_category] InputAdapter - WholeStageCodegen (21) + WholeStageCodegen (23) Sort [i_category,i_brand,s_store_name,s_company_name,rn] - Project [i_category,i_brand,s_store_name,s_company_name,sum_sales,rn] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] - WholeStageCodegen (20) - Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] - InputAdapter - Exchange [i_category,i_brand,s_store_name,s_company_name] #7 - WholeStageCodegen (19) - HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] [sum(UnscaledValue(ss_sales_price)),sum_sales,sum] - InputAdapter - ReusedExchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] #2 + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name,rn] #8 + WholeStageCodegen (22) + Project [i_category,i_brand,s_store_name,s_company_name,sum_sales,rn] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] + WholeStageCodegen (21) + Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name] #9 + WholeStageCodegen (20) + HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] [sum(UnscaledValue(ss_sales_price)),sum_sales,sum] + InputAdapter + ReusedExchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] #3 InputAdapter - WholeStageCodegen (32) + WholeStageCodegen (35) Sort [i_category,i_brand,s_store_name,s_company_name,rn] - Project [i_category,i_brand,s_store_name,s_company_name,sum_sales,rn] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] - WholeStageCodegen (31) - Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] - InputAdapter - ReusedExchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum_sales] #7 + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name,rn] #10 + WholeStageCodegen (34) + Project [i_category,i_brand,s_store_name,s_company_name,sum_sales,rn] + InputAdapter + Window 
[d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] + WholeStageCodegen (33) + Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] + InputAdapter + ReusedExchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum_sales] #9 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47/explain.txt index 4f69eb1367b8b..e7faf392ad879 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q47/explain.txt @@ -167,7 +167,7 @@ Arguments: [avg(_w0#22) windowspecdefinition(i_category#3, i_brand#2, s_store_na (27) Filter [codegen id : 22] Input [10]: [i_category#3, i_brand#2, s_store_name#14, s_company_name#15, d_year#11, d_moy#12, sum_sales#21, _w0#22, rn#24, avg_monthly_sales#25] -Condition : ((isnotnull(avg_monthly_sales#25) AND (avg_monthly_sales#25 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000)) +Condition : ((isnotnull(avg_monthly_sales#25) AND (avg_monthly_sales#25 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000)) (28) Project [codegen id : 22] Output [9]: [i_category#3, i_brand#2, s_store_name#14, s_company_name#15, d_year#11, d_moy#12, sum_sales#21, avg_monthly_sales#25, rn#24] @@ -242,7 +242,7 @@ Input [16]: [i_category#3, i_brand#2, s_store_name#14, s_company_name#15, d_year (45) TakeOrderedAndProject Input [10]: [i_category#3, i_brand#2, s_store_name#14, s_company_name#15, d_year#11, d_moy#12, avg_monthly_sales#25, sum_sales#21, psum#47, nsum#48] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, s_store_name#14 ASC NULLS FIRST], [i_category#3, i_brand#2, s_store_name#14, s_company_name#15, d_year#11, d_moy#12, avg_monthly_sales#25, sum_sales#21, psum#47, nsum#48] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, s_store_name#14 ASC NULLS FIRST], [i_category#3, i_brand#2, s_store_name#14, s_company_name#15, d_year#11, d_moy#12, avg_monthly_sales#25, sum_sales#21, psum#47, nsum#48] ===== Subqueries ===== diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49.sf100/explain.txt index 889ada3f2bd24..65606c025adc4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49.sf100/explain.txt @@ -177,7 +177,7 @@ Input [7]: [ws_item_sk#1, sum#22, sum#23, sum#24, isEmpty#25, sum#26, isEmpty#27 Keys [1]: [ws_item_sk#1] Functions [4]: 
[sum(coalesce(wr_return_quantity#12, 0)), sum(coalesce(ws_quantity#3, 0)), sum(coalesce(cast(wr_return_amt#13 as decimal(12,2)), 0.00)), sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(wr_return_quantity#12, 0))#29, sum(coalesce(ws_quantity#3, 0))#30, sum(coalesce(cast(wr_return_amt#13 as decimal(12,2)), 0.00))#31, sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#32] -Results [3]: [ws_item_sk#1 AS item#33, CheckOverflow((promote_precision(cast(sum(coalesce(wr_return_quantity#12, 0))#29 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ws_quantity#3, 0))#30 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#34, CheckOverflow((promote_precision(cast(sum(coalesce(cast(wr_return_amt#13 as decimal(12,2)), 0.00))#31 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#32 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#35] +Results [3]: [ws_item_sk#1 AS item#33, CheckOverflow((promote_precision(cast(sum(coalesce(wr_return_quantity#12, 0))#29 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ws_quantity#3, 0))#30 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#34, CheckOverflow((promote_precision(cast(sum(coalesce(cast(wr_return_amt#13 as decimal(12,2)), 0.00))#31 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#32 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#35] (21) Exchange Input [3]: [item#33, return_ratio#34, currency_ratio#35] @@ -297,7 +297,7 @@ Input [7]: [cs_item_sk#40, sum#60, sum#61, sum#62, isEmpty#63, sum#64, isEmpty#6 Keys [1]: [cs_item_sk#40] Functions [4]: [sum(coalesce(cr_return_quantity#50, 0)), sum(coalesce(cs_quantity#42, 0)), sum(coalesce(cast(cr_return_amount#51 as decimal(12,2)), 0.00)), sum(coalesce(cast(cs_net_paid#43 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(cr_return_quantity#50, 0))#67, sum(coalesce(cs_quantity#42, 0))#68, sum(coalesce(cast(cr_return_amount#51 as decimal(12,2)), 0.00))#69, sum(coalesce(cast(cs_net_paid#43 as decimal(12,2)), 0.00))#70] -Results [3]: [cs_item_sk#40 AS item#71, CheckOverflow((promote_precision(cast(sum(coalesce(cr_return_quantity#50, 0))#67 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cs_quantity#42, 0))#68 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#72, CheckOverflow((promote_precision(cast(sum(coalesce(cast(cr_return_amount#51 as decimal(12,2)), 0.00))#69 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(cs_net_paid#43 as decimal(12,2)), 0.00))#70 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#73] +Results [3]: [cs_item_sk#40 AS item#71, CheckOverflow((promote_precision(cast(sum(coalesce(cr_return_quantity#50, 0))#67 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cs_quantity#42, 0))#68 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#72, CheckOverflow((promote_precision(cast(sum(coalesce(cast(cr_return_amount#51 as decimal(12,2)), 0.00))#69 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(cs_net_paid#43 as decimal(12,2)), 0.00))#70 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#73] (48) Exchange Input [3]: [item#71, return_ratio#72, currency_ratio#73] @@ -417,7 +417,7 @@ Input [7]: [ss_item_sk#78, sum#98, sum#99, sum#100, isEmpty#101, sum#102, isEmpt Keys [1]: [ss_item_sk#78] Functions [4]: [sum(coalesce(sr_return_quantity#88, 0)), sum(coalesce(ss_quantity#80, 0)), 
sum(coalesce(cast(sr_return_amt#89 as decimal(12,2)), 0.00)), sum(coalesce(cast(ss_net_paid#81 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(sr_return_quantity#88, 0))#105, sum(coalesce(ss_quantity#80, 0))#106, sum(coalesce(cast(sr_return_amt#89 as decimal(12,2)), 0.00))#107, sum(coalesce(cast(ss_net_paid#81 as decimal(12,2)), 0.00))#108] -Results [3]: [ss_item_sk#78 AS item#109, CheckOverflow((promote_precision(cast(sum(coalesce(sr_return_quantity#88, 0))#105 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ss_quantity#80, 0))#106 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#110, CheckOverflow((promote_precision(cast(sum(coalesce(cast(sr_return_amt#89 as decimal(12,2)), 0.00))#107 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ss_net_paid#81 as decimal(12,2)), 0.00))#108 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#111] +Results [3]: [ss_item_sk#78 AS item#109, CheckOverflow((promote_precision(cast(sum(coalesce(sr_return_quantity#88, 0))#105 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ss_quantity#80, 0))#106 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#110, CheckOverflow((promote_precision(cast(sum(coalesce(cast(sr_return_amt#89 as decimal(12,2)), 0.00))#107 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ss_net_paid#81 as decimal(12,2)), 0.00))#108 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#111] (75) Exchange Input [3]: [item#109, return_ratio#110, currency_ratio#111] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt index 399ab59cd7a71..ac64de5188462 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt @@ -156,7 +156,7 @@ Input [7]: [ws_item_sk#1, sum#21, sum#22, sum#23, isEmpty#24, sum#25, isEmpty#26 Keys [1]: [ws_item_sk#1] Functions [4]: [sum(coalesce(wr_return_quantity#11, 0)), sum(coalesce(ws_quantity#3, 0)), sum(coalesce(cast(wr_return_amt#12 as decimal(12,2)), 0.00)), sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(wr_return_quantity#11, 0))#28, sum(coalesce(ws_quantity#3, 0))#29, sum(coalesce(cast(wr_return_amt#12 as decimal(12,2)), 0.00))#30, sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#31] -Results [3]: [ws_item_sk#1 AS item#32, CheckOverflow((promote_precision(cast(sum(coalesce(wr_return_quantity#11, 0))#28 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ws_quantity#3, 0))#29 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#33, CheckOverflow((promote_precision(cast(sum(coalesce(cast(wr_return_amt#12 as decimal(12,2)), 0.00))#30 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#31 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#34] +Results [3]: [ws_item_sk#1 AS item#32, CheckOverflow((promote_precision(cast(sum(coalesce(wr_return_quantity#11, 0))#28 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ws_quantity#3, 0))#29 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#33, CheckOverflow((promote_precision(cast(sum(coalesce(cast(wr_return_amt#12 as decimal(12,2)), 0.00))#30 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#31 as decimal(15,4)))), 
DecimalType(35,20)) AS currency_ratio#34] (18) Exchange Input [3]: [item#32, return_ratio#33, currency_ratio#34] @@ -264,7 +264,7 @@ Input [7]: [cs_item_sk#39, sum#58, sum#59, sum#60, isEmpty#61, sum#62, isEmpty#6 Keys [1]: [cs_item_sk#39] Functions [4]: [sum(coalesce(cr_return_quantity#48, 0)), sum(coalesce(cs_quantity#41, 0)), sum(coalesce(cast(cr_return_amount#49 as decimal(12,2)), 0.00)), sum(coalesce(cast(cs_net_paid#42 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(cr_return_quantity#48, 0))#65, sum(coalesce(cs_quantity#41, 0))#66, sum(coalesce(cast(cr_return_amount#49 as decimal(12,2)), 0.00))#67, sum(coalesce(cast(cs_net_paid#42 as decimal(12,2)), 0.00))#68] -Results [3]: [cs_item_sk#39 AS item#69, CheckOverflow((promote_precision(cast(sum(coalesce(cr_return_quantity#48, 0))#65 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cs_quantity#41, 0))#66 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#70, CheckOverflow((promote_precision(cast(sum(coalesce(cast(cr_return_amount#49 as decimal(12,2)), 0.00))#67 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(cs_net_paid#42 as decimal(12,2)), 0.00))#68 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#71] +Results [3]: [cs_item_sk#39 AS item#69, CheckOverflow((promote_precision(cast(sum(coalesce(cr_return_quantity#48, 0))#65 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cs_quantity#41, 0))#66 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#70, CheckOverflow((promote_precision(cast(sum(coalesce(cast(cr_return_amount#49 as decimal(12,2)), 0.00))#67 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(cs_net_paid#42 as decimal(12,2)), 0.00))#68 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#71] (42) Exchange Input [3]: [item#69, return_ratio#70, currency_ratio#71] @@ -372,7 +372,7 @@ Input [7]: [ss_item_sk#76, sum#95, sum#96, sum#97, isEmpty#98, sum#99, isEmpty#1 Keys [1]: [ss_item_sk#76] Functions [4]: [sum(coalesce(sr_return_quantity#85, 0)), sum(coalesce(ss_quantity#78, 0)), sum(coalesce(cast(sr_return_amt#86 as decimal(12,2)), 0.00)), sum(coalesce(cast(ss_net_paid#79 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(sr_return_quantity#85, 0))#102, sum(coalesce(ss_quantity#78, 0))#103, sum(coalesce(cast(sr_return_amt#86 as decimal(12,2)), 0.00))#104, sum(coalesce(cast(ss_net_paid#79 as decimal(12,2)), 0.00))#105] -Results [3]: [ss_item_sk#76 AS item#106, CheckOverflow((promote_precision(cast(sum(coalesce(sr_return_quantity#85, 0))#102 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ss_quantity#78, 0))#103 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#107, CheckOverflow((promote_precision(cast(sum(coalesce(cast(sr_return_amt#86 as decimal(12,2)), 0.00))#104 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ss_net_paid#79 as decimal(12,2)), 0.00))#105 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#108] +Results [3]: [ss_item_sk#76 AS item#106, CheckOverflow((promote_precision(cast(sum(coalesce(sr_return_quantity#85, 0))#102 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ss_quantity#78, 0))#103 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#107, CheckOverflow((promote_precision(cast(sum(coalesce(cast(sr_return_amt#86 as decimal(12,2)), 0.00))#104 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ss_net_paid#79 as decimal(12,2)), 0.00))#105 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#108] (66) Exchange Input [3]: 
[item#106, return_ratio#107, currency_ratio#108] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt index 0690c363a98e7..29a88fbab1b3c 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt @@ -173,7 +173,7 @@ Input [5]: [s_store_id#23, sum#30, sum#31, sum#32, sum#33] Keys [1]: [s_store_id#23] Functions [4]: [sum(UnscaledValue(sales_price#8)), sum(UnscaledValue(return_amt#10)), sum(UnscaledValue(profit#9)), sum(UnscaledValue(net_loss#11))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#8))#35, sum(UnscaledValue(return_amt#10))#36, sum(UnscaledValue(profit#9))#37, sum(UnscaledValue(net_loss#11))#38] -Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#8))#35,17,2) AS sales#39, MakeDecimal(sum(UnscaledValue(return_amt#10))#36,17,2) AS returns#40, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#9))#37,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#11))#38,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#41, store channel AS channel#42, concat(store, s_store_id#23) AS id#43] +Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#8))#35,17,2) AS sales#39, MakeDecimal(sum(UnscaledValue(return_amt#10))#36,17,2) AS returns#40, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#9))#37,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#11))#38,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#41, store channel AS channel#42, concat(store, s_store_id#23) AS id#43] (22) Scan parquet default.catalog_sales Output [4]: [cs_catalog_page_sk#44, cs_ext_sales_price#45, cs_net_profit#46, cs_sold_date_sk#47] @@ -270,7 +270,7 @@ Input [5]: [cp_catalog_page_id#65, sum#72, sum#73, sum#74, sum#75] Keys [1]: [cp_catalog_page_id#65] Functions [4]: [sum(UnscaledValue(sales_price#50)), sum(UnscaledValue(return_amt#52)), sum(UnscaledValue(profit#51)), sum(UnscaledValue(net_loss#53))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#50))#77, sum(UnscaledValue(return_amt#52))#78, sum(UnscaledValue(profit#51))#79, sum(UnscaledValue(net_loss#53))#80] -Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#50))#77,17,2) AS sales#81, MakeDecimal(sum(UnscaledValue(return_amt#52))#78,17,2) AS returns#82, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#51))#79,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#53))#80,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#83, catalog channel AS channel#84, concat(catalog_page, cp_catalog_page_id#65) AS id#85] +Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#50))#77,17,2) AS sales#81, MakeDecimal(sum(UnscaledValue(return_amt#52))#78,17,2) AS returns#82, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#51))#79,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#53))#80,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#83, catalog channel AS channel#84, concat(catalog_page, cp_catalog_page_id#65) AS id#85] (43) Scan parquet default.web_sales Output [4]: [ws_web_site_sk#86, ws_ext_sales_price#87, ws_net_profit#88, ws_sold_date_sk#89] @@ -401,7 +401,7 @@ Input [5]: [web_site_id#114, 
sum#121, sum#122, sum#123, sum#124] Keys [1]: [web_site_id#114] Functions [4]: [sum(UnscaledValue(sales_price#92)), sum(UnscaledValue(return_amt#94)), sum(UnscaledValue(profit#93)), sum(UnscaledValue(net_loss#95))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#92))#126, sum(UnscaledValue(return_amt#94))#127, sum(UnscaledValue(profit#93))#128, sum(UnscaledValue(net_loss#95))#129] -Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#92))#126,17,2) AS sales#130, MakeDecimal(sum(UnscaledValue(return_amt#94))#127,17,2) AS returns#131, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#93))#128,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#95))#129,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#132, web channel AS channel#133, concat(web_site, web_site_id#114) AS id#134] +Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#92))#126,17,2) AS sales#130, MakeDecimal(sum(UnscaledValue(return_amt#94))#127,17,2) AS returns#131, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#93))#128,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#95))#129,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#132, web channel AS channel#133, concat(web_site, web_site_id#114) AS id#134] (72) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt index 693a853440d32..a9e5929f70b54 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt @@ -170,7 +170,7 @@ Input [5]: [s_store_id#24, sum#30, sum#31, sum#32, sum#33] Keys [1]: [s_store_id#24] Functions [4]: [sum(UnscaledValue(sales_price#8)), sum(UnscaledValue(return_amt#10)), sum(UnscaledValue(profit#9)), sum(UnscaledValue(net_loss#11))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#8))#35, sum(UnscaledValue(return_amt#10))#36, sum(UnscaledValue(profit#9))#37, sum(UnscaledValue(net_loss#11))#38] -Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#8))#35,17,2) AS sales#39, MakeDecimal(sum(UnscaledValue(return_amt#10))#36,17,2) AS returns#40, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#9))#37,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#11))#38,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#41, store channel AS channel#42, concat(store, s_store_id#24) AS id#43] +Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#8))#35,17,2) AS sales#39, MakeDecimal(sum(UnscaledValue(return_amt#10))#36,17,2) AS returns#40, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#9))#37,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#11))#38,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#41, store channel AS channel#42, concat(store, s_store_id#24) AS id#43] (22) Scan parquet default.catalog_sales Output [4]: [cs_catalog_page_sk#44, cs_ext_sales_price#45, cs_net_profit#46, cs_sold_date_sk#47] @@ -267,7 +267,7 @@ Input [5]: [cp_catalog_page_id#66, sum#72, sum#73, sum#74, sum#75] Keys [1]: [cp_catalog_page_id#66] Functions [4]: [sum(UnscaledValue(sales_price#50)), sum(UnscaledValue(return_amt#52)), sum(UnscaledValue(profit#51)), sum(UnscaledValue(net_loss#53))] Aggregate Attributes [4]: 
[sum(UnscaledValue(sales_price#50))#77, sum(UnscaledValue(return_amt#52))#78, sum(UnscaledValue(profit#51))#79, sum(UnscaledValue(net_loss#53))#80] -Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#50))#77,17,2) AS sales#81, MakeDecimal(sum(UnscaledValue(return_amt#52))#78,17,2) AS returns#82, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#51))#79,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#53))#80,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#83, catalog channel AS channel#84, concat(catalog_page, cp_catalog_page_id#66) AS id#85] +Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#50))#77,17,2) AS sales#81, MakeDecimal(sum(UnscaledValue(return_amt#52))#78,17,2) AS returns#82, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#51))#79,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#53))#80,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#83, catalog channel AS channel#84, concat(catalog_page, cp_catalog_page_id#66) AS id#85] (43) Scan parquet default.web_sales Output [4]: [ws_web_site_sk#86, ws_ext_sales_price#87, ws_net_profit#88, ws_sold_date_sk#89] @@ -386,7 +386,7 @@ Input [5]: [web_site_id#114, sum#120, sum#121, sum#122, sum#123] Keys [1]: [web_site_id#114] Functions [4]: [sum(UnscaledValue(sales_price#92)), sum(UnscaledValue(return_amt#94)), sum(UnscaledValue(profit#93)), sum(UnscaledValue(net_loss#95))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#92))#125, sum(UnscaledValue(return_amt#94))#126, sum(UnscaledValue(profit#93))#127, sum(UnscaledValue(net_loss#95))#128] -Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#92))#125,17,2) AS sales#129, MakeDecimal(sum(UnscaledValue(return_amt#94))#126,17,2) AS returns#130, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#93))#127,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#95))#128,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#131, web channel AS channel#132, concat(web_site, web_site_id#114) AS id#133] +Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#92))#125,17,2) AS sales#129, MakeDecimal(sum(UnscaledValue(return_amt#94))#126,17,2) AS returns#130, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#93))#127,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#95))#128,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#131, web channel AS channel#132, concat(web_site, web_site_id#114) AS id#133] (69) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q53.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q53.sf100/explain.txt index ea800b099f46a..694852c3ed6b0 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q53.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q53.sf100/explain.txt @@ -146,7 +146,7 @@ Arguments: [avg(_w0#25) windowspecdefinition(i_manufact_id#5, specifiedwindowfra (26) Filter [codegen id : 7] Input [4]: [i_manufact_id#5, sum_sales#24, _w0#25, avg_quarterly_sales#27] -Condition : (isnotnull(avg_quarterly_sales#27) AND ((avg_quarterly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_quarterly_sales#27 as 
decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_quarterly_sales#27) AND ((avg_quarterly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (27) Project [codegen id : 7] Output [3]: [i_manufact_id#5, sum_sales#24, avg_quarterly_sales#27] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q53/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q53/explain.txt index a2c5cba8b3548..91364dcce16e4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q53/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q53/explain.txt @@ -146,7 +146,7 @@ Arguments: [avg(_w0#25) windowspecdefinition(i_manufact_id#5, specifiedwindowfra (26) Filter [codegen id : 7] Input [4]: [i_manufact_id#5, sum_sales#24, _w0#25, avg_quarterly_sales#27] -Condition : (isnotnull(avg_quarterly_sales#27) AND ((avg_quarterly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_quarterly_sales#27) AND ((avg_quarterly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_quarterly_sales#27 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (27) Project [codegen id : 7] Output [3]: [i_manufact_id#5, sum_sales#24, avg_quarterly_sales#27] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt index b15ae61d824d4..543281ef9100e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt @@ -310,7 +310,7 @@ Input [2]: [c_customer_sk#27, sum#37] Keys [1]: [c_customer_sk#27] Functions [1]: [sum(UnscaledValue(ss_ext_sales_price#31))] Aggregate Attributes [1]: [sum(UnscaledValue(ss_ext_sales_price#31))#38] -Results [1]: [cast(CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#31))#38,17,2)) / 50.00), DecimalType(21,6), true) as int) AS segment#39] +Results [1]: [cast(CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#31))#38,17,2)) / 50.00), DecimalType(21,6)) as int) AS segment#39] (56) HashAggregate [codegen id : 15] Input [1]: [segment#39] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt index ed5cd21140cad..4c65587bee530 100644 --- 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt @@ -295,7 +295,7 @@ Input [2]: [c_customer_sk#19, sum#37] Keys [1]: [c_customer_sk#19] Functions [1]: [sum(UnscaledValue(ss_ext_sales_price#24))] Aggregate Attributes [1]: [sum(UnscaledValue(ss_ext_sales_price#24))#39] -Results [1]: [cast(CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#24))#39,17,2)) / 50.00), DecimalType(21,6), true) as int) AS segment#40] +Results [1]: [cast(CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#24))#39,17,2)) / 50.00), DecimalType(21,6)) as int) AS segment#40] (53) HashAggregate [codegen id : 12] Input [1]: [segment#40] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57.sf100/explain.txt index ad356d44af668..0b933a733f888 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57.sf100/explain.txt @@ -1,53 +1,56 @@ == Physical Plan == -TakeOrderedAndProject (49) -+- * Project (48) - +- * SortMergeJoin Inner (47) - :- * Project (41) - : +- * SortMergeJoin Inner (40) - : :- * Sort (32) - : : +- * Project (31) - : : +- * Filter (30) - : : +- Window (29) - : : +- * Filter (28) - : : +- Window (27) - : : +- * Sort (26) - : : +- Exchange (25) - : : +- * HashAggregate (24) - : : +- Exchange (23) - : : +- * HashAggregate (22) - : : +- * Project (21) - : : +- * SortMergeJoin Inner (20) - : : :- * Sort (14) - : : : +- Exchange (13) - : : : +- * Project (12) - : : : +- * BroadcastHashJoin Inner BuildRight (11) - : : : :- * Project (6) - : : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : : :- * Filter (3) - : : : : : +- * ColumnarToRow (2) - : : : : : +- Scan parquet default.catalog_sales (1) - : : : : +- ReusedExchange (4) - : : : +- BroadcastExchange (10) - : : : +- * Filter (9) - : : : +- * ColumnarToRow (8) - : : : +- Scan parquet default.call_center (7) - : : +- * Sort (19) - : : +- Exchange (18) - : : +- * Filter (17) - : : +- * ColumnarToRow (16) - : : +- Scan parquet default.item (15) - : +- * Sort (39) - : +- * Project (38) - : +- Window (37) - : +- * Sort (36) - : +- Exchange (35) - : +- * HashAggregate (34) - : +- ReusedExchange (33) - +- * Sort (46) - +- * Project (45) - +- Window (44) - +- * Sort (43) - +- ReusedExchange (42) +TakeOrderedAndProject (52) ++- * Project (51) + +- * SortMergeJoin Inner (50) + :- * Project (43) + : +- * SortMergeJoin Inner (42) + : :- * Sort (33) + : : +- Exchange (32) + : : +- * Project (31) + : : +- * Filter (30) + : : +- Window (29) + : : +- * Filter (28) + : : +- Window (27) + : : +- * Sort (26) + : : +- Exchange (25) + : : +- * HashAggregate (24) + : : +- Exchange (23) + : : +- * HashAggregate (22) + : : +- * Project (21) + : : +- * SortMergeJoin Inner (20) + : : :- * Sort (14) + : : : +- Exchange (13) + : : : +- * Project (12) + : : : +- * BroadcastHashJoin Inner BuildRight (11) + : : : :- * Project (6) + : : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : : :- * Filter (3) + : : : : : +- * ColumnarToRow (2) + : : : : : +- Scan parquet default.catalog_sales (1) + : : : : +- ReusedExchange (4) + : : : +- BroadcastExchange (10) + : : : +- * Filter (9) + : : : +- * ColumnarToRow (8) + : : : +- Scan parquet 
default.call_center (7) + : : +- * Sort (19) + : : +- Exchange (18) + : : +- * Filter (17) + : : +- * ColumnarToRow (16) + : : +- Scan parquet default.item (15) + : +- * Sort (41) + : +- Exchange (40) + : +- * Project (39) + : +- Window (38) + : +- * Sort (37) + : +- Exchange (36) + : +- * HashAggregate (35) + : +- ReusedExchange (34) + +- * Sort (49) + +- Exchange (48) + +- * Project (47) + +- Window (46) + +- * Sort (45) + +- ReusedExchange (44) (1) Scan parquet default.catalog_sales @@ -65,7 +68,7 @@ Input [4]: [cs_call_center_sk#1, cs_item_sk#2, cs_sales_price#3, cs_sold_date_sk Input [4]: [cs_call_center_sk#1, cs_item_sk#2, cs_sales_price#3, cs_sold_date_sk#4] Condition : (isnotnull(cs_item_sk#2) AND isnotnull(cs_call_center_sk#1)) -(4) ReusedExchange [Reuses operator id: 53] +(4) ReusedExchange [Reuses operator id: 56] Output [3]: [d_date_sk#6, d_year#7, d_moy#8] (5) BroadcastHashJoin [codegen id : 3] @@ -183,112 +186,124 @@ Arguments: [avg(_w0#22) windowspecdefinition(i_category#15, i_brand#14, cc_name# (30) Filter [codegen id : 11] Input [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, _w0#22, rn#24, avg_monthly_sales#25] -Condition : ((isnotnull(avg_monthly_sales#25) AND (avg_monthly_sales#25 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000)) +Condition : ((isnotnull(avg_monthly_sales#25) AND (avg_monthly_sales#25 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000)) (31) Project [codegen id : 11] Output [8]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24] Input [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, _w0#22, rn#24, avg_monthly_sales#25] -(32) Sort [codegen id : 11] +(32) Exchange +Input [8]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24] +Arguments: hashpartitioning(i_category#15, i_brand#14, cc_name#10, rn#24, 5), ENSURE_REQUIREMENTS, [id=#26] + +(33) Sort [codegen id : 12] Input [8]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24] Arguments: [i_category#15 ASC NULLS FIRST, i_brand#14 ASC NULLS FIRST, cc_name#10 ASC NULLS FIRST, rn#24 ASC NULLS FIRST], false, 0 -(33) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum#31] +(34) ReusedExchange [Reuses operator id: 23] +Output [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum#32] -(34) HashAggregate [codegen id : 19] -Input [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum#31] -Keys [5]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30] -Functions [1]: [sum(UnscaledValue(cs_sales_price#32))] -Aggregate Attributes [1]: [sum(UnscaledValue(cs_sales_price#32))#20] -Results [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, MakeDecimal(sum(UnscaledValue(cs_sales_price#32))#20,17,2) AS sum_sales#21] +(35) HashAggregate [codegen id : 
20] +Input [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum#32] +Keys [5]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31] +Functions [1]: [sum(UnscaledValue(cs_sales_price#33))] +Aggregate Attributes [1]: [sum(UnscaledValue(cs_sales_price#33))#20] +Results [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, MakeDecimal(sum(UnscaledValue(cs_sales_price#33))#20,17,2) AS sum_sales#21] -(35) Exchange -Input [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum_sales#21] -Arguments: hashpartitioning(i_category#26, i_brand#27, cc_name#28, 5), ENSURE_REQUIREMENTS, [id=#33] +(36) Exchange +Input [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum_sales#21] +Arguments: hashpartitioning(i_category#27, i_brand#28, cc_name#29, 5), ENSURE_REQUIREMENTS, [id=#34] -(36) Sort [codegen id : 20] -Input [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum_sales#21] -Arguments: [i_category#26 ASC NULLS FIRST, i_brand#27 ASC NULLS FIRST, cc_name#28 ASC NULLS FIRST, d_year#29 ASC NULLS FIRST, d_moy#30 ASC NULLS FIRST], false, 0 +(37) Sort [codegen id : 21] +Input [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum_sales#21] +Arguments: [i_category#27 ASC NULLS FIRST, i_brand#28 ASC NULLS FIRST, cc_name#29 ASC NULLS FIRST, d_year#30 ASC NULLS FIRST, d_moy#31 ASC NULLS FIRST], false, 0 -(37) Window -Input [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum_sales#21] -Arguments: [rank(d_year#29, d_moy#30) windowspecdefinition(i_category#26, i_brand#27, cc_name#28, d_year#29 ASC NULLS FIRST, d_moy#30 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#34], [i_category#26, i_brand#27, cc_name#28], [d_year#29 ASC NULLS FIRST, d_moy#30 ASC NULLS FIRST] +(38) Window +Input [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum_sales#21] +Arguments: [rank(d_year#30, d_moy#31) windowspecdefinition(i_category#27, i_brand#28, cc_name#29, d_year#30 ASC NULLS FIRST, d_moy#31 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#35], [i_category#27, i_brand#28, cc_name#29], [d_year#30 ASC NULLS FIRST, d_moy#31 ASC NULLS FIRST] -(38) Project [codegen id : 21] -Output [5]: [i_category#26, i_brand#27, cc_name#28, sum_sales#21 AS sum_sales#35, rn#34] -Input [7]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum_sales#21, rn#34] +(39) Project [codegen id : 22] +Output [5]: [i_category#27, i_brand#28, cc_name#29, sum_sales#21 AS sum_sales#36, rn#35] +Input [7]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum_sales#21, rn#35] -(39) Sort [codegen id : 21] -Input [5]: [i_category#26, i_brand#27, cc_name#28, sum_sales#35, rn#34] -Arguments: [i_category#26 ASC NULLS FIRST, i_brand#27 ASC NULLS FIRST, cc_name#28 ASC NULLS FIRST, (rn#34 + 1) ASC NULLS FIRST], false, 0 +(40) Exchange +Input [5]: [i_category#27, i_brand#28, cc_name#29, sum_sales#36, rn#35] +Arguments: hashpartitioning(i_category#27, i_brand#28, cc_name#29, (rn#35 + 1), 5), ENSURE_REQUIREMENTS, [id=#37] -(40) SortMergeJoin [codegen id : 22] +(41) Sort [codegen id : 23] +Input [5]: [i_category#27, i_brand#28, cc_name#29, sum_sales#36, rn#35] +Arguments: [i_category#27 ASC NULLS FIRST, i_brand#28 ASC NULLS FIRST, cc_name#29 ASC NULLS FIRST, (rn#35 + 1) ASC NULLS FIRST], false, 0 + +(42) SortMergeJoin [codegen id : 24] Left keys [4]: [i_category#15, i_brand#14, cc_name#10, rn#24] -Right keys [4]: [i_category#26, 
i_brand#27, cc_name#28, (rn#34 + 1)] +Right keys [4]: [i_category#27, i_brand#28, cc_name#29, (rn#35 + 1)] Join condition: None -(41) Project [codegen id : 22] -Output [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, sum_sales#35] -Input [13]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, i_category#26, i_brand#27, cc_name#28, sum_sales#35, rn#34] +(43) Project [codegen id : 24] +Output [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, sum_sales#36] +Input [13]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, i_category#27, i_brand#28, cc_name#29, sum_sales#36, rn#35] + +(44) ReusedExchange [Reuses operator id: 36] +Output [6]: [i_category#38, i_brand#39, cc_name#40, d_year#41, d_moy#42, sum_sales#21] -(42) ReusedExchange [Reuses operator id: 35] -Output [6]: [i_category#36, i_brand#37, cc_name#38, d_year#39, d_moy#40, sum_sales#21] +(45) Sort [codegen id : 33] +Input [6]: [i_category#38, i_brand#39, cc_name#40, d_year#41, d_moy#42, sum_sales#21] +Arguments: [i_category#38 ASC NULLS FIRST, i_brand#39 ASC NULLS FIRST, cc_name#40 ASC NULLS FIRST, d_year#41 ASC NULLS FIRST, d_moy#42 ASC NULLS FIRST], false, 0 -(43) Sort [codegen id : 31] -Input [6]: [i_category#36, i_brand#37, cc_name#38, d_year#39, d_moy#40, sum_sales#21] -Arguments: [i_category#36 ASC NULLS FIRST, i_brand#37 ASC NULLS FIRST, cc_name#38 ASC NULLS FIRST, d_year#39 ASC NULLS FIRST, d_moy#40 ASC NULLS FIRST], false, 0 +(46) Window +Input [6]: [i_category#38, i_brand#39, cc_name#40, d_year#41, d_moy#42, sum_sales#21] +Arguments: [rank(d_year#41, d_moy#42) windowspecdefinition(i_category#38, i_brand#39, cc_name#40, d_year#41 ASC NULLS FIRST, d_moy#42 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#43], [i_category#38, i_brand#39, cc_name#40], [d_year#41 ASC NULLS FIRST, d_moy#42 ASC NULLS FIRST] -(44) Window -Input [6]: [i_category#36, i_brand#37, cc_name#38, d_year#39, d_moy#40, sum_sales#21] -Arguments: [rank(d_year#39, d_moy#40) windowspecdefinition(i_category#36, i_brand#37, cc_name#38, d_year#39 ASC NULLS FIRST, d_moy#40 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#41], [i_category#36, i_brand#37, cc_name#38], [d_year#39 ASC NULLS FIRST, d_moy#40 ASC NULLS FIRST] +(47) Project [codegen id : 34] +Output [5]: [i_category#38, i_brand#39, cc_name#40, sum_sales#21 AS sum_sales#44, rn#43] +Input [7]: [i_category#38, i_brand#39, cc_name#40, d_year#41, d_moy#42, sum_sales#21, rn#43] -(45) Project [codegen id : 32] -Output [5]: [i_category#36, i_brand#37, cc_name#38, sum_sales#21 AS sum_sales#42, rn#41] -Input [7]: [i_category#36, i_brand#37, cc_name#38, d_year#39, d_moy#40, sum_sales#21, rn#41] +(48) Exchange +Input [5]: [i_category#38, i_brand#39, cc_name#40, sum_sales#44, rn#43] +Arguments: hashpartitioning(i_category#38, i_brand#39, cc_name#40, (rn#43 - 1), 5), ENSURE_REQUIREMENTS, [id=#45] -(46) Sort [codegen id : 32] -Input [5]: [i_category#36, i_brand#37, cc_name#38, sum_sales#42, rn#41] -Arguments: [i_category#36 ASC NULLS FIRST, i_brand#37 ASC NULLS FIRST, cc_name#38 ASC NULLS FIRST, (rn#41 - 1) ASC NULLS FIRST], false, 0 +(49) Sort [codegen id : 35] +Input [5]: [i_category#38, i_brand#39, cc_name#40, sum_sales#44, rn#43] +Arguments: [i_category#38 ASC NULLS FIRST, i_brand#39 ASC NULLS FIRST, cc_name#40 ASC NULLS 
FIRST, (rn#43 - 1) ASC NULLS FIRST], false, 0 -(47) SortMergeJoin [codegen id : 33] +(50) SortMergeJoin [codegen id : 36] Left keys [4]: [i_category#15, i_brand#14, cc_name#10, rn#24] -Right keys [4]: [i_category#36, i_brand#37, cc_name#38, (rn#41 - 1)] +Right keys [4]: [i_category#38, i_brand#39, cc_name#40, (rn#43 - 1)] Join condition: None -(48) Project [codegen id : 33] -Output [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, sum_sales#35 AS psum#43, sum_sales#42 AS nsum#44] -Input [14]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, sum_sales#35, i_category#36, i_brand#37, cc_name#38, sum_sales#42, rn#41] +(51) Project [codegen id : 36] +Output [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, sum_sales#36 AS psum#46, sum_sales#44 AS nsum#47] +Input [14]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, sum_sales#36, i_category#38, i_brand#39, cc_name#40, sum_sales#44, rn#43] -(49) TakeOrderedAndProject -Input [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, psum#43, nsum#44] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, cc_name#10 ASC NULLS FIRST], [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, psum#43, nsum#44] +(52) TakeOrderedAndProject +Input [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, psum#46, nsum#47] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, cc_name#10 ASC NULLS FIRST], [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, psum#46, nsum#47] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = cs_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (53) -+- * Filter (52) - +- * ColumnarToRow (51) - +- Scan parquet default.date_dim (50) +BroadcastExchange (56) ++- * Filter (55) + +- * ColumnarToRow (54) + +- Scan parquet default.date_dim (53) -(50) Scan parquet default.date_dim +(53) Scan parquet default.date_dim Output [3]: [d_date_sk#6, d_year#7, d_moy#8] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [Or(Or(EqualTo(d_year,1999),And(EqualTo(d_year,1998),EqualTo(d_moy,12))),And(EqualTo(d_year,2000),EqualTo(d_moy,1))), IsNotNull(d_date_sk)] ReadSchema: struct -(51) ColumnarToRow [codegen id : 1] +(54) ColumnarToRow [codegen id : 1] Input [3]: [d_date_sk#6, d_year#7, d_moy#8] -(52) Filter [codegen id : 1] +(55) Filter [codegen id : 1] Input [3]: [d_date_sk#6, d_year#7, d_moy#8] Condition : ((((d_year#7 = 1999) OR ((d_year#7 = 1998) AND (d_moy#8 = 12))) OR ((d_year#7 = 2000) AND (d_moy#8 = 1))) AND isnotnull(d_date_sk#6)) -(53) BroadcastExchange +(56) BroadcastExchange Input [3]: [d_date_sk#6, d_year#7, d_moy#8] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#45] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#48] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57.sf100/simplified.txt 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57.sf100/simplified.txt index b488806fe9a07..3bf10f82e6a88 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57.sf100/simplified.txt @@ -1,95 +1,104 @@ TakeOrderedAndProject [sum_sales,avg_monthly_sales,cc_name,i_category,i_brand,d_year,d_moy,psum,nsum] - WholeStageCodegen (33) + WholeStageCodegen (36) Project [i_category,i_brand,cc_name,d_year,d_moy,avg_monthly_sales,sum_sales,sum_sales,sum_sales] SortMergeJoin [i_category,i_brand,cc_name,rn,i_category,i_brand,cc_name,rn] InputAdapter - WholeStageCodegen (22) + WholeStageCodegen (24) Project [i_category,i_brand,cc_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn,sum_sales] SortMergeJoin [i_category,i_brand,cc_name,rn,i_category,i_brand,cc_name,rn] InputAdapter - WholeStageCodegen (11) + WholeStageCodegen (12) Sort [i_category,i_brand,cc_name,rn] - Project [i_category,i_brand,cc_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn] - Filter [avg_monthly_sales,sum_sales] - InputAdapter - Window [_w0,i_category,i_brand,cc_name,d_year] - WholeStageCodegen (10) - Filter [d_year] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,cc_name] - WholeStageCodegen (9) - Sort [i_category,i_brand,cc_name,d_year,d_moy] - InputAdapter - Exchange [i_category,i_brand,cc_name] #1 - WholeStageCodegen (8) - HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,sum] [sum(UnscaledValue(cs_sales_price)),sum_sales,_w0,sum] - InputAdapter - Exchange [i_category,i_brand,cc_name,d_year,d_moy] #2 - WholeStageCodegen (7) - HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,cs_sales_price] [sum,sum] - Project [i_brand,i_category,cs_sales_price,d_year,d_moy,cc_name] - SortMergeJoin [cs_item_sk,i_item_sk] - InputAdapter - WholeStageCodegen (4) - Sort [cs_item_sk] + InputAdapter + Exchange [i_category,i_brand,cc_name,rn] #1 + WholeStageCodegen (11) + Project [i_category,i_brand,cc_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn] + Filter [avg_monthly_sales,sum_sales] + InputAdapter + Window [_w0,i_category,i_brand,cc_name,d_year] + WholeStageCodegen (10) + Filter [d_year] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,cc_name] + WholeStageCodegen (9) + Sort [i_category,i_brand,cc_name,d_year,d_moy] + InputAdapter + Exchange [i_category,i_brand,cc_name] #2 + WholeStageCodegen (8) + HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,sum] [sum(UnscaledValue(cs_sales_price)),sum_sales,_w0,sum] + InputAdapter + Exchange [i_category,i_brand,cc_name,d_year,d_moy] #3 + WholeStageCodegen (7) + HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,cs_sales_price] [sum,sum] + Project [i_brand,i_category,cs_sales_price,d_year,d_moy,cc_name] + SortMergeJoin [cs_item_sk,i_item_sk] InputAdapter - Exchange [cs_item_sk] #3 - WholeStageCodegen (3) - Project [cs_item_sk,cs_sales_price,d_year,d_moy,cc_name] - BroadcastHashJoin [cs_call_center_sk,cc_call_center_sk] - Project [cs_call_center_sk,cs_item_sk,cs_sales_price,d_year,d_moy] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Filter [cs_item_sk,cs_call_center_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_call_center_sk,cs_item_sk,cs_sales_price,cs_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #4 - WholeStageCodegen (1) - Filter [d_year,d_moy,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year,d_moy] - InputAdapter - 
ReusedExchange [d_date_sk,d_year,d_moy] #4 - InputAdapter - BroadcastExchange #5 - WholeStageCodegen (2) - Filter [cc_call_center_sk,cc_name] - ColumnarToRow + WholeStageCodegen (4) + Sort [cs_item_sk] + InputAdapter + Exchange [cs_item_sk] #4 + WholeStageCodegen (3) + Project [cs_item_sk,cs_sales_price,d_year,d_moy,cc_name] + BroadcastHashJoin [cs_call_center_sk,cc_call_center_sk] + Project [cs_call_center_sk,cs_item_sk,cs_sales_price,d_year,d_moy] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Filter [cs_item_sk,cs_call_center_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_call_center_sk,cs_item_sk,cs_sales_price,cs_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (1) + Filter [d_year,d_moy,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year,d_moy] InputAdapter - Scan parquet default.call_center [cc_call_center_sk,cc_name] - InputAdapter - WholeStageCodegen (6) - Sort [i_item_sk] + ReusedExchange [d_date_sk,d_year,d_moy] #5 + InputAdapter + BroadcastExchange #6 + WholeStageCodegen (2) + Filter [cc_call_center_sk,cc_name] + ColumnarToRow + InputAdapter + Scan parquet default.call_center [cc_call_center_sk,cc_name] InputAdapter - Exchange [i_item_sk] #6 - WholeStageCodegen (5) - Filter [i_item_sk,i_category,i_brand] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand,i_category] + WholeStageCodegen (6) + Sort [i_item_sk] + InputAdapter + Exchange [i_item_sk] #7 + WholeStageCodegen (5) + Filter [i_item_sk,i_category,i_brand] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand,i_category] InputAdapter - WholeStageCodegen (21) + WholeStageCodegen (23) Sort [i_category,i_brand,cc_name,rn] - Project [i_category,i_brand,cc_name,sum_sales,rn] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,cc_name] - WholeStageCodegen (20) - Sort [i_category,i_brand,cc_name,d_year,d_moy] - InputAdapter - Exchange [i_category,i_brand,cc_name] #7 - WholeStageCodegen (19) - HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,sum] [sum(UnscaledValue(cs_sales_price)),sum_sales,sum] - InputAdapter - ReusedExchange [i_category,i_brand,cc_name,d_year,d_moy,sum] #2 + InputAdapter + Exchange [i_category,i_brand,cc_name,rn] #8 + WholeStageCodegen (22) + Project [i_category,i_brand,cc_name,sum_sales,rn] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,cc_name] + WholeStageCodegen (21) + Sort [i_category,i_brand,cc_name,d_year,d_moy] + InputAdapter + Exchange [i_category,i_brand,cc_name] #9 + WholeStageCodegen (20) + HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,sum] [sum(UnscaledValue(cs_sales_price)),sum_sales,sum] + InputAdapter + ReusedExchange [i_category,i_brand,cc_name,d_year,d_moy,sum] #3 InputAdapter - WholeStageCodegen (32) + WholeStageCodegen (35) Sort [i_category,i_brand,cc_name,rn] - Project [i_category,i_brand,cc_name,sum_sales,rn] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,cc_name] - WholeStageCodegen (31) - Sort [i_category,i_brand,cc_name,d_year,d_moy] - InputAdapter - ReusedExchange [i_category,i_brand,cc_name,d_year,d_moy,sum_sales] #7 + InputAdapter + Exchange [i_category,i_brand,cc_name,rn] #10 + WholeStageCodegen (34) + Project [i_category,i_brand,cc_name,sum_sales,rn] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,cc_name] + WholeStageCodegen (33) + Sort [i_category,i_brand,cc_name,d_year,d_moy] + InputAdapter + ReusedExchange [i_category,i_brand,cc_name,d_year,d_moy,sum_sales] #9 diff 
--git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57/explain.txt index a3b9279528ba9..6b2736ef4008f 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q57/explain.txt @@ -167,7 +167,7 @@ Arguments: [avg(_w0#21) windowspecdefinition(i_category#3, i_brand#2, cc_name#14 (27) Filter [codegen id : 22] Input [9]: [i_category#3, i_brand#2, cc_name#14, d_year#11, d_moy#12, sum_sales#20, _w0#21, rn#23, avg_monthly_sales#24] -Condition : ((isnotnull(avg_monthly_sales#24) AND (avg_monthly_sales#24 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#20 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000)) +Condition : ((isnotnull(avg_monthly_sales#24) AND (avg_monthly_sales#24 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#20 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000)) (28) Project [codegen id : 22] Output [8]: [i_category#3, i_brand#2, cc_name#14, d_year#11, d_moy#12, sum_sales#20, avg_monthly_sales#24, rn#23] @@ -242,7 +242,7 @@ Input [14]: [i_category#3, i_brand#2, cc_name#14, d_year#11, d_moy#12, sum_sales (45) TakeOrderedAndProject Input [9]: [i_category#3, i_brand#2, cc_name#14, d_year#11, d_moy#12, avg_monthly_sales#24, sum_sales#20, psum#44, nsum#45] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#20 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, cc_name#14 ASC NULLS FIRST], [i_category#3, i_brand#2, cc_name#14, d_year#11, d_moy#12, avg_monthly_sales#24, sum_sales#20, psum#44, nsum#45] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#20 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, cc_name#14 ASC NULLS FIRST], [i_category#3, i_brand#2, cc_name#14, d_year#11, d_moy#12, avg_monthly_sales#24, sum_sales#20, psum#44, nsum#45] ===== Subqueries ===== diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt index 8e969096c5239..abbd29292b260 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt @@ -194,7 +194,7 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (31) BroadcastHashJoin [codegen id : 15] Left keys [1]: [item_id#13] Right keys [1]: [item_id#25] -Join condition: ((((cast(ss_item_rev#14 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(cs_item_rev#26)), DecimalType(19,3), true)) AND (cast(ss_item_rev#14 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(cs_item_rev#26)), DecimalType(20,3), true))) AND (cast(cs_item_rev#26 as decimal(19,3)) >= CheckOverflow((0.90 * 
promote_precision(ss_item_rev#14)), DecimalType(19,3), true))) AND (cast(cs_item_rev#26 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ss_item_rev#14)), DecimalType(20,3), true))) +Join condition: ((((cast(ss_item_rev#14 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(cs_item_rev#26)), DecimalType(19,3))) AND (cast(ss_item_rev#14 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(cs_item_rev#26)), DecimalType(20,3)))) AND (cast(cs_item_rev#26 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ss_item_rev#14)), DecimalType(19,3)))) AND (cast(cs_item_rev#26 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ss_item_rev#14)), DecimalType(20,3)))) (32) Project [codegen id : 15] Output [3]: [item_id#13, ss_item_rev#14, cs_item_rev#26] @@ -268,10 +268,10 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (47) BroadcastHashJoin [codegen id : 15] Left keys [1]: [item_id#13] Right keys [1]: [item_id#38] -Join condition: ((((((((cast(ss_item_rev#14 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ws_item_rev#39)), DecimalType(19,3), true)) AND (cast(ss_item_rev#14 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ws_item_rev#39)), DecimalType(20,3), true))) AND (cast(cs_item_rev#26 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ws_item_rev#39)), DecimalType(19,3), true))) AND (cast(cs_item_rev#26 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ws_item_rev#39)), DecimalType(20,3), true))) AND (cast(ws_item_rev#39 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ss_item_rev#14)), DecimalType(19,3), true))) AND (cast(ws_item_rev#39 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ss_item_rev#14)), DecimalType(20,3), true))) AND (cast(ws_item_rev#39 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(cs_item_rev#26)), DecimalType(19,3), true))) AND (cast(ws_item_rev#39 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(cs_item_rev#26)), DecimalType(20,3), true))) +Join condition: ((((((((cast(ss_item_rev#14 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ws_item_rev#39)), DecimalType(19,3))) AND (cast(ss_item_rev#14 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ws_item_rev#39)), DecimalType(20,3)))) AND (cast(cs_item_rev#26 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ws_item_rev#39)), DecimalType(19,3)))) AND (cast(cs_item_rev#26 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ws_item_rev#39)), DecimalType(20,3)))) AND (cast(ws_item_rev#39 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ss_item_rev#14)), DecimalType(19,3)))) AND (cast(ws_item_rev#39 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ss_item_rev#14)), DecimalType(20,3)))) AND (cast(ws_item_rev#39 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(cs_item_rev#26)), DecimalType(19,3)))) AND (cast(ws_item_rev#39 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(cs_item_rev#26)), DecimalType(20,3)))) (48) Project [codegen id : 15] -Output [8]: [item_id#13, ss_item_rev#14, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2), true) as decimal(19,2))) + 
promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2), true))), DecimalType(38,21), true)) / 3.000000000000000000000), DecimalType(38,21), true)) * 100.000000000000000000000), DecimalType(38,17), true) AS ss_dev#41, cs_item_rev#26, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(cs_item_rev#26 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2), true) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2), true))), DecimalType(38,21), true)) / 3.000000000000000000000), DecimalType(38,21), true)) * 100.000000000000000000000), DecimalType(38,17), true) AS cs_dev#42, ws_item_rev#39, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(ws_item_rev#39 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2), true) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2), true))), DecimalType(38,21), true)) / 3.000000000000000000000), DecimalType(38,21), true)) * 100.000000000000000000000), DecimalType(38,17), true) AS ws_dev#43, CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2), true) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2), true)) / 3.00), DecimalType(23,6), true) AS average#44] +Output [8]: [item_id#13, ss_item_rev#14, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2)) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2)))), DecimalType(38,21))) / 3.000000000000000000000), DecimalType(38,21))) * 100.000000000000000000000), DecimalType(38,17)) AS ss_dev#41, cs_item_rev#26, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(cs_item_rev#26 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2)) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2)))), DecimalType(38,21))) / 3.000000000000000000000), DecimalType(38,21))) * 100.000000000000000000000), DecimalType(38,17)) AS cs_dev#42, ws_item_rev#39, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(ws_item_rev#39 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2)) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), 
DecimalType(19,2)))), DecimalType(38,21))) / 3.000000000000000000000), DecimalType(38,21))) * 100.000000000000000000000), DecimalType(38,17)) AS ws_dev#43, CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2)) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2))) / 3.00), DecimalType(23,6)) AS average#44] Input [5]: [item_id#13, ss_item_rev#14, cs_item_rev#26, item_id#38, ws_item_rev#39] (49) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt index 67f19d31e3946..47651c0f92dca 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt @@ -194,7 +194,7 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (31) BroadcastHashJoin [codegen id : 15] Left keys [1]: [item_id#13] Right keys [1]: [item_id#25] -Join condition: ((((cast(ss_item_rev#14 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(cs_item_rev#26)), DecimalType(19,3), true)) AND (cast(ss_item_rev#14 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(cs_item_rev#26)), DecimalType(20,3), true))) AND (cast(cs_item_rev#26 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ss_item_rev#14)), DecimalType(19,3), true))) AND (cast(cs_item_rev#26 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ss_item_rev#14)), DecimalType(20,3), true))) +Join condition: ((((cast(ss_item_rev#14 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(cs_item_rev#26)), DecimalType(19,3))) AND (cast(ss_item_rev#14 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(cs_item_rev#26)), DecimalType(20,3)))) AND (cast(cs_item_rev#26 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ss_item_rev#14)), DecimalType(19,3)))) AND (cast(cs_item_rev#26 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ss_item_rev#14)), DecimalType(20,3)))) (32) Project [codegen id : 15] Output [3]: [item_id#13, ss_item_rev#14, cs_item_rev#26] @@ -268,10 +268,10 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (47) BroadcastHashJoin [codegen id : 15] Left keys [1]: [item_id#13] Right keys [1]: [item_id#38] -Join condition: ((((((((cast(ss_item_rev#14 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ws_item_rev#39)), DecimalType(19,3), true)) AND (cast(ss_item_rev#14 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ws_item_rev#39)), DecimalType(20,3), true))) AND (cast(cs_item_rev#26 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ws_item_rev#39)), DecimalType(19,3), true))) AND (cast(cs_item_rev#26 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ws_item_rev#39)), DecimalType(20,3), true))) AND (cast(ws_item_rev#39 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ss_item_rev#14)), DecimalType(19,3), true))) AND (cast(ws_item_rev#39 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ss_item_rev#14)), DecimalType(20,3), true))) AND (cast(ws_item_rev#39 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(cs_item_rev#26)), DecimalType(19,3), true))) AND (cast(ws_item_rev#39 as 
decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(cs_item_rev#26)), DecimalType(20,3), true))) +Join condition: ((((((((cast(ss_item_rev#14 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ws_item_rev#39)), DecimalType(19,3))) AND (cast(ss_item_rev#14 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ws_item_rev#39)), DecimalType(20,3)))) AND (cast(cs_item_rev#26 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ws_item_rev#39)), DecimalType(19,3)))) AND (cast(cs_item_rev#26 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ws_item_rev#39)), DecimalType(20,3)))) AND (cast(ws_item_rev#39 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(ss_item_rev#14)), DecimalType(19,3)))) AND (cast(ws_item_rev#39 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(ss_item_rev#14)), DecimalType(20,3)))) AND (cast(ws_item_rev#39 as decimal(19,3)) >= CheckOverflow((0.90 * promote_precision(cs_item_rev#26)), DecimalType(19,3)))) AND (cast(ws_item_rev#39 as decimal(20,3)) <= CheckOverflow((1.10 * promote_precision(cs_item_rev#26)), DecimalType(20,3)))) (48) Project [codegen id : 15] -Output [8]: [item_id#13, ss_item_rev#14, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2), true) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2), true))), DecimalType(38,21), true)) / 3.000000000000000000000), DecimalType(38,21), true)) * 100.000000000000000000000), DecimalType(38,17), true) AS ss_dev#41, cs_item_rev#26, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(cs_item_rev#26 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2), true) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2), true))), DecimalType(38,21), true)) / 3.000000000000000000000), DecimalType(38,21), true)) * 100.000000000000000000000), DecimalType(38,17), true) AS cs_dev#42, ws_item_rev#39, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(ws_item_rev#39 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2), true) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2), true))), DecimalType(38,21), true)) / 3.000000000000000000000), DecimalType(38,21), true)) * 100.000000000000000000000), DecimalType(38,17), true) AS ws_dev#43, CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2), true) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2), true)) / 3.00), DecimalType(23,6), true) AS average#44] +Output [8]: [item_id#13, ss_item_rev#14, 
CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2)) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2)))), DecimalType(38,21))) / 3.000000000000000000000), DecimalType(38,21))) * 100.000000000000000000000), DecimalType(38,17)) AS ss_dev#41, cs_item_rev#26, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(cs_item_rev#26 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2)) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2)))), DecimalType(38,21))) / 3.000000000000000000000), DecimalType(38,21))) * 100.000000000000000000000), DecimalType(38,17)) AS cs_dev#42, ws_item_rev#39, CheckOverflow((promote_precision(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(ws_item_rev#39 as decimal(19,2))) / promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2)) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2)))), DecimalType(38,21))) / 3.000000000000000000000), DecimalType(38,21))) * 100.000000000000000000000), DecimalType(38,17)) AS ws_dev#43, CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(ss_item_rev#14 as decimal(18,2))) + promote_precision(cast(cs_item_rev#26 as decimal(18,2)))), DecimalType(18,2)) as decimal(19,2))) + promote_precision(cast(ws_item_rev#39 as decimal(19,2)))), DecimalType(19,2))) / 3.00), DecimalType(23,6)) AS average#44] Input [5]: [item_id#13, ss_item_rev#14, cs_item_rev#26, item_id#38, ws_item_rev#39] (49) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q59.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q59.sf100/explain.txt index 201ba377a7f79..1e9c240705bd8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q59.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q59.sf100/explain.txt @@ -241,7 +241,7 @@ Right keys [2]: [s_store_id2#68, (d_week_seq2#67 - 52)] Join condition: None (43) Project [codegen id : 10] -Output [10]: [s_store_name1#44, s_store_id1#46, d_week_seq1#45, CheckOverflow((promote_precision(sun_sales1#47) / promote_precision(sun_sales2#69)), DecimalType(37,20), true) AS (sun_sales1 / sun_sales2)#77, CheckOverflow((promote_precision(mon_sales1#48) / promote_precision(mon_sales2#70)), DecimalType(37,20), true) AS (mon_sales1 / mon_sales2)#78, CheckOverflow((promote_precision(tue_sales1#49) / promote_precision(tue_sales2#71)), DecimalType(37,20), true) AS (tue_sales1 / tue_sales2)#79, CheckOverflow((promote_precision(wed_sales1#50) / promote_precision(wed_sales2#72)), DecimalType(37,20), true) AS (wed_sales1 / wed_sales2)#80, CheckOverflow((promote_precision(thu_sales1#51) / promote_precision(thu_sales2#73)), 
DecimalType(37,20), true) AS (thu_sales1 / thu_sales2)#81, CheckOverflow((promote_precision(fri_sales1#52) / promote_precision(fri_sales2#74)), DecimalType(37,20), true) AS (fri_sales1 / fri_sales2)#82, CheckOverflow((promote_precision(sat_sales1#53) / promote_precision(sat_sales2#75)), DecimalType(37,20), true) AS (sat_sales1 / sat_sales2)#83] +Output [10]: [s_store_name1#44, s_store_id1#46, d_week_seq1#45, CheckOverflow((promote_precision(sun_sales1#47) / promote_precision(sun_sales2#69)), DecimalType(37,20)) AS (sun_sales1 / sun_sales2)#77, CheckOverflow((promote_precision(mon_sales1#48) / promote_precision(mon_sales2#70)), DecimalType(37,20)) AS (mon_sales1 / mon_sales2)#78, CheckOverflow((promote_precision(tue_sales1#49) / promote_precision(tue_sales2#71)), DecimalType(37,20)) AS (tue_sales1 / tue_sales2)#79, CheckOverflow((promote_precision(wed_sales1#50) / promote_precision(wed_sales2#72)), DecimalType(37,20)) AS (wed_sales1 / wed_sales2)#80, CheckOverflow((promote_precision(thu_sales1#51) / promote_precision(thu_sales2#73)), DecimalType(37,20)) AS (thu_sales1 / thu_sales2)#81, CheckOverflow((promote_precision(fri_sales1#52) / promote_precision(fri_sales2#74)), DecimalType(37,20)) AS (fri_sales1 / fri_sales2)#82, CheckOverflow((promote_precision(sat_sales1#53) / promote_precision(sat_sales2#75)), DecimalType(37,20)) AS (sat_sales1 / sat_sales2)#83] Input [19]: [s_store_name1#44, d_week_seq1#45, s_store_id1#46, sun_sales1#47, mon_sales1#48, tue_sales1#49, wed_sales1#50, thu_sales1#51, fri_sales1#52, sat_sales1#53, d_week_seq2#67, s_store_id2#68, sun_sales2#69, mon_sales2#70, tue_sales2#71, wed_sales2#72, thu_sales2#73, fri_sales2#74, sat_sales2#75] (44) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q59/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q59/explain.txt index 201ba377a7f79..1e9c240705bd8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q59/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q59/explain.txt @@ -241,7 +241,7 @@ Right keys [2]: [s_store_id2#68, (d_week_seq2#67 - 52)] Join condition: None (43) Project [codegen id : 10] -Output [10]: [s_store_name1#44, s_store_id1#46, d_week_seq1#45, CheckOverflow((promote_precision(sun_sales1#47) / promote_precision(sun_sales2#69)), DecimalType(37,20), true) AS (sun_sales1 / sun_sales2)#77, CheckOverflow((promote_precision(mon_sales1#48) / promote_precision(mon_sales2#70)), DecimalType(37,20), true) AS (mon_sales1 / mon_sales2)#78, CheckOverflow((promote_precision(tue_sales1#49) / promote_precision(tue_sales2#71)), DecimalType(37,20), true) AS (tue_sales1 / tue_sales2)#79, CheckOverflow((promote_precision(wed_sales1#50) / promote_precision(wed_sales2#72)), DecimalType(37,20), true) AS (wed_sales1 / wed_sales2)#80, CheckOverflow((promote_precision(thu_sales1#51) / promote_precision(thu_sales2#73)), DecimalType(37,20), true) AS (thu_sales1 / thu_sales2)#81, CheckOverflow((promote_precision(fri_sales1#52) / promote_precision(fri_sales2#74)), DecimalType(37,20), true) AS (fri_sales1 / fri_sales2)#82, CheckOverflow((promote_precision(sat_sales1#53) / promote_precision(sat_sales2#75)), DecimalType(37,20), true) AS (sat_sales1 / sat_sales2)#83] +Output [10]: [s_store_name1#44, s_store_id1#46, d_week_seq1#45, CheckOverflow((promote_precision(sun_sales1#47) / promote_precision(sun_sales2#69)), DecimalType(37,20)) AS (sun_sales1 / sun_sales2)#77, 
CheckOverflow((promote_precision(mon_sales1#48) / promote_precision(mon_sales2#70)), DecimalType(37,20)) AS (mon_sales1 / mon_sales2)#78, CheckOverflow((promote_precision(tue_sales1#49) / promote_precision(tue_sales2#71)), DecimalType(37,20)) AS (tue_sales1 / tue_sales2)#79, CheckOverflow((promote_precision(wed_sales1#50) / promote_precision(wed_sales2#72)), DecimalType(37,20)) AS (wed_sales1 / wed_sales2)#80, CheckOverflow((promote_precision(thu_sales1#51) / promote_precision(thu_sales2#73)), DecimalType(37,20)) AS (thu_sales1 / thu_sales2)#81, CheckOverflow((promote_precision(fri_sales1#52) / promote_precision(fri_sales2#74)), DecimalType(37,20)) AS (fri_sales1 / fri_sales2)#82, CheckOverflow((promote_precision(sat_sales1#53) / promote_precision(sat_sales2#75)), DecimalType(37,20)) AS (sat_sales1 / sat_sales2)#83] Input [19]: [s_store_name1#44, d_week_seq1#45, s_store_id1#46, sun_sales1#47, mon_sales1#48, tue_sales1#49, wed_sales1#50, thu_sales1#51, fri_sales1#52, sat_sales1#53, d_week_seq2#67, s_store_id2#68, sun_sales2#69, mon_sales2#70, tue_sales2#71, wed_sales2#72, thu_sales2#73, fri_sales2#74, sat_sales2#75] (44) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q61.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q61.sf100/explain.txt index 70ea372f1eb5b..e83c4be6f7e5a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q61.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q61.sf100/explain.txt @@ -350,7 +350,7 @@ Arguments: IdentityBroadcastMode, [id=#45] Join condition: None (64) Project [codegen id : 15] -Output [3]: [promotions#30, total#44, CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(promotions#30 as decimal(15,4))) / promote_precision(cast(total#44 as decimal(15,4)))), DecimalType(35,20), true)) * 100.00000000000000000000), DecimalType(38,19), true) AS ((CAST(promotions AS DECIMAL(15,4)) / CAST(total AS DECIMAL(15,4))) * 100)#46] +Output [3]: [promotions#30, total#44, CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(promotions#30 as decimal(15,4))) / promote_precision(cast(total#44 as decimal(15,4)))), DecimalType(35,20))) * 100.00000000000000000000), DecimalType(38,19)) AS ((CAST(promotions AS DECIMAL(15,4)) / CAST(total AS DECIMAL(15,4))) * 100)#46] Input [2]: [promotions#30, total#44] ===== Subqueries ===== diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q61/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q61/explain.txt index 7e1ce65ee7236..ebf1161c7a1f0 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q61/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q61/explain.txt @@ -365,7 +365,7 @@ Arguments: IdentityBroadcastMode, [id=#47] Join condition: None (67) Project [codegen id : 15] -Output [3]: [promotions#30, total#46, CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(promotions#30 as decimal(15,4))) / promote_precision(cast(total#46 as decimal(15,4)))), DecimalType(35,20), true)) * 100.00000000000000000000), DecimalType(38,19), true) AS ((CAST(promotions AS DECIMAL(15,4)) / CAST(total AS DECIMAL(15,4))) * 100)#48] +Output [3]: [promotions#30, total#46, CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(promotions#30 as decimal(15,4))) / promote_precision(cast(total#46 as 
decimal(15,4)))), DecimalType(35,20))) * 100.00000000000000000000), DecimalType(38,19)) AS ((CAST(promotions AS DECIMAL(15,4)) / CAST(total AS DECIMAL(15,4))) * 100)#48] Input [2]: [promotions#30, total#46] ===== Subqueries ===== diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q63.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q63.sf100/explain.txt index 9dd05765ecd2d..fe91e93a55aba 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q63.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q63.sf100/explain.txt @@ -146,7 +146,7 @@ Arguments: [avg(_w0#25) windowspecdefinition(i_manager_id#5, specifiedwindowfram (26) Filter [codegen id : 7] Input [4]: [i_manager_id#5, sum_sales#24, _w0#25, avg_monthly_sales#27] -Condition : (isnotnull(avg_monthly_sales#27) AND ((avg_monthly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_monthly_sales#27) AND ((avg_monthly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (27) Project [codegen id : 7] Output [3]: [i_manager_id#5, sum_sales#24, avg_monthly_sales#27] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q63/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q63/explain.txt index b49e25109080e..ad0ca3ea63d42 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q63/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q63/explain.txt @@ -146,7 +146,7 @@ Arguments: [avg(_w0#25) windowspecdefinition(i_manager_id#5, specifiedwindowfram (26) Filter [codegen id : 7] Input [4]: [i_manager_id#5, sum_sales#24, _w0#25, avg_monthly_sales#27] -Condition : (isnotnull(avg_monthly_sales#27) AND ((avg_monthly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_monthly_sales#27) AND ((avg_monthly_sales#27 > 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#24 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#27 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (27) Project [codegen id : 7] Output [3]: [i_manager_id#5, sum_sales#24, avg_monthly_sales#27] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q65.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q65.sf100/explain.txt index e4baf3b296016..474967b54286a 100644 --- 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q65.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q65.sf100/explain.txt @@ -161,7 +161,7 @@ Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)) (24) BroadcastHashJoin [codegen id : 8] Left keys [1]: [ss_store_sk#2] Right keys [1]: [ss_store_sk#13] -Join condition: (cast(revenue#11 as decimal(23,7)) <= CheckOverflow((0.100000 * promote_precision(ave#28)), DecimalType(23,7), true)) +Join condition: (cast(revenue#11 as decimal(23,7)) <= CheckOverflow((0.100000 * promote_precision(ave#28)), DecimalType(23,7))) (25) Project [codegen id : 8] Output [3]: [ss_store_sk#2, ss_item_sk#1, revenue#11] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q65/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q65/explain.txt index 49cc9f75956a2..c7967bfa915b8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q65/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q65/explain.txt @@ -212,7 +212,7 @@ Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)) (36) BroadcastHashJoin [codegen id : 9] Left keys [1]: [ss_store_sk#4] Right keys [1]: [ss_store_sk#22] -Join condition: (cast(revenue#13 as decimal(23,7)) <= CheckOverflow((0.100000 * promote_precision(ave#37)), DecimalType(23,7), true)) +Join condition: (cast(revenue#13 as decimal(23,7)) <= CheckOverflow((0.100000 * promote_precision(ave#37)), DecimalType(23,7))) (37) Project [codegen id : 9] Output [6]: [s_store_name#2, i_item_desc#16, revenue#13, i_current_price#17, i_wholesale_cost#18, i_brand#19] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66.sf100/explain.txt index b59df1b7d5777..85aa68cbedd88 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66.sf100/explain.txt @@ -172,7 +172,7 @@ Input [13]: [ws_warehouse_sk#3, ws_quantity#4, ws_ext_sales_price#5, ws_net_paid (27) HashAggregate [codegen id : 5] Input [11]: [ws_quantity#4, ws_ext_sales_price#5, ws_net_paid#6, w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, d_year#16, d_moy#17] Keys [7]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, d_year#16] -Functions [24]: [partial_sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as 
decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as 
decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)] +Functions [24]: [partial_sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 11) THEN 
CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)] Aggregate Attributes [48]: [sum#26, isEmpty#27, sum#28, isEmpty#29, sum#30, isEmpty#31, sum#32, isEmpty#33, sum#34, isEmpty#35, sum#36, isEmpty#37, sum#38, isEmpty#39, sum#40, isEmpty#41, sum#42, isEmpty#43, sum#44, isEmpty#45, sum#46, isEmpty#47, sum#48, isEmpty#49, sum#50, isEmpty#51, sum#52, isEmpty#53, sum#54, isEmpty#55, sum#56, isEmpty#57, sum#58, isEmpty#59, sum#60, isEmpty#61, sum#62, isEmpty#63, sum#64, isEmpty#65, sum#66, isEmpty#67, sum#68, isEmpty#69, sum#70, isEmpty#71, sum#72, isEmpty#73] Results [55]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, d_year#16, sum#74, isEmpty#75, sum#76, isEmpty#77, 
sum#78, isEmpty#79, sum#80, isEmpty#81, sum#82, isEmpty#83, sum#84, isEmpty#85, sum#86, isEmpty#87, sum#88, isEmpty#89, sum#90, isEmpty#91, sum#92, isEmpty#93, sum#94, isEmpty#95, sum#96, isEmpty#97, sum#98, isEmpty#99, sum#100, isEmpty#101, sum#102, isEmpty#103, sum#104, isEmpty#105, sum#106, isEmpty#107, sum#108, isEmpty#109, sum#110, isEmpty#111, sum#112, isEmpty#113, sum#114, isEmpty#115, sum#116, isEmpty#117, sum#118, isEmpty#119, sum#120, isEmpty#121] @@ -183,9 +183,9 @@ Arguments: hashpartitioning(w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21 (29) HashAggregate [codegen id : 6] Input [55]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, d_year#16, sum#74, isEmpty#75, sum#76, isEmpty#77, sum#78, isEmpty#79, sum#80, isEmpty#81, sum#82, isEmpty#83, sum#84, isEmpty#85, sum#86, isEmpty#87, sum#88, isEmpty#89, sum#90, isEmpty#91, sum#92, isEmpty#93, sum#94, isEmpty#95, sum#96, isEmpty#97, sum#98, isEmpty#99, sum#100, isEmpty#101, sum#102, isEmpty#103, sum#104, isEmpty#105, sum#106, isEmpty#107, sum#108, isEmpty#109, sum#110, isEmpty#111, sum#112, isEmpty#113, sum#114, isEmpty#115, sum#116, isEmpty#117, sum#118, isEmpty#119, sum#120, isEmpty#121] Keys [7]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, d_year#16] -Functions [24]: [sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 10) THEN 
CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)] -Aggregate Attributes [24]: [sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as 
decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#123, sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#124, sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#125, sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#126, sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#127, sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#128, sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#129, sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#130, sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#131, sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#132, sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#133, sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#134, sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#135, sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#136, sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#137, sum(CASE WHEN (d_moy#17 = 4) THEN 
CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#138, sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#139, sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#140, sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#141, sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#142, sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#143, sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#144, sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#145, sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#146] -Results [32]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, DHL,BARIAN AS ship_carriers#147, d_year#16 AS year#148, sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#123 AS jan_sales#149, sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#124 AS feb_sales#150, sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#125 AS mar_sales#151, sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#126 AS apr_sales#152, sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#127 AS may_sales#153, sum(CASE WHEN (d_moy#17 = 6) THEN 
CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#128 AS jun_sales#154, sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#129 AS jul_sales#155, sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#130 AS aug_sales#156, sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#131 AS sep_sales#157, sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#132 AS oct_sales#158, sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#133 AS nov_sales#159, sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#134 AS dec_sales#160, sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#135 AS jan_net#161, sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#136 AS feb_net#162, sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#137 AS mar_net#163, sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#138 AS apr_net#164, sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#139 AS may_net#165, sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#140 AS jun_net#166, sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#141 AS jul_net#167, sum(CASE WHEN (d_moy#17 = 8) THEN 
CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#142 AS aug_net#168, sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#143 AS sep_net#169, sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#144 AS oct_net#170, sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#145 AS nov_net#171, sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#146 AS dec_net#172] +Functions [24]: [sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 12) THEN 
CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)] +Aggregate Attributes [24]: [sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#123, sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#124, sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#125, sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#126, sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as 
decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#127, sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#128, sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#129, sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#130, sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#131, sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#132, sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#133, sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#134, sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#135, sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#136, sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#137, sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#138, sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#139, sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#140, sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#141, sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#142, sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#143, sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * 
promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#144, sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#145, sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#146] +Results [32]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, DHL,BARIAN AS ship_carriers#147, d_year#16 AS year#148, sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#123 AS jan_sales#149, sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#124 AS feb_sales#150, sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#125 AS mar_sales#151, sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#126 AS apr_sales#152, sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#127 AS may_sales#153, sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#128 AS jun_sales#154, sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#129 AS jul_sales#155, sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#130 AS aug_sales#156, sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#131 AS sep_sales#157, sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#132 AS oct_sales#158, sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#133 AS nov_sales#159, sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#134 AS dec_sales#160, sum(CASE WHEN (d_moy#17 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 
as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#135 AS jan_net#161, sum(CASE WHEN (d_moy#17 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#136 AS feb_net#162, sum(CASE WHEN (d_moy#17 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#137 AS mar_net#163, sum(CASE WHEN (d_moy#17 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#138 AS apr_net#164, sum(CASE WHEN (d_moy#17 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#139 AS may_net#165, sum(CASE WHEN (d_moy#17 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#140 AS jun_net#166, sum(CASE WHEN (d_moy#17 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#141 AS jul_net#167, sum(CASE WHEN (d_moy#17 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#142 AS aug_net#168, sum(CASE WHEN (d_moy#17 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#143 AS sep_net#169, sum(CASE WHEN (d_moy#17 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#144 AS oct_net#170, sum(CASE WHEN (d_moy#17 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#145 AS nov_net#171, sum(CASE WHEN (d_moy#17 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#146 AS dec_net#172] (30) Scan parquet default.catalog_sales Output [7]: [cs_sold_time_sk#173, cs_ship_mode_sk#174, cs_warehouse_sk#175, cs_quantity#176, cs_sales_price#177, cs_net_paid_inc_tax#178, cs_sold_date_sk#179] @@ -253,7 +253,7 @@ Input [13]: [cs_warehouse_sk#175, cs_quantity#176, cs_sales_price#177, cs_net_pa (45) HashAggregate [codegen id : 11] Input [11]: [cs_quantity#176, cs_sales_price#177, cs_net_paid_inc_tax#178, w_warehouse_name#186, w_warehouse_sq_ft#187, w_city#188, w_county#189, w_state#190, w_country#191, d_year#183, d_moy#184] Keys [7]: [w_warehouse_name#186, w_warehouse_sq_ft#187, w_city#188, w_county#189, w_state#190, w_country#191, d_year#183] -Functions [24]: [partial_sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * 
promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), 
true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)] +Functions [24]: [partial_sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), 
partial_sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#184 = 12) 
THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)] Aggregate Attributes [48]: [sum#192, isEmpty#193, sum#194, isEmpty#195, sum#196, isEmpty#197, sum#198, isEmpty#199, sum#200, isEmpty#201, sum#202, isEmpty#203, sum#204, isEmpty#205, sum#206, isEmpty#207, sum#208, isEmpty#209, sum#210, isEmpty#211, sum#212, isEmpty#213, sum#214, isEmpty#215, sum#216, isEmpty#217, sum#218, isEmpty#219, sum#220, isEmpty#221, sum#222, isEmpty#223, sum#224, isEmpty#225, sum#226, isEmpty#227, sum#228, isEmpty#229, sum#230, isEmpty#231, sum#232, isEmpty#233, sum#234, isEmpty#235, sum#236, isEmpty#237, sum#238, isEmpty#239] Results [55]: [w_warehouse_name#186, w_warehouse_sq_ft#187, w_city#188, w_county#189, w_state#190, w_country#191, d_year#183, sum#240, isEmpty#241, sum#242, isEmpty#243, sum#244, isEmpty#245, sum#246, isEmpty#247, sum#248, isEmpty#249, sum#250, isEmpty#251, sum#252, isEmpty#253, sum#254, isEmpty#255, sum#256, isEmpty#257, sum#258, isEmpty#259, sum#260, isEmpty#261, sum#262, isEmpty#263, sum#264, isEmpty#265, sum#266, isEmpty#267, sum#268, isEmpty#269, sum#270, isEmpty#271, sum#272, isEmpty#273, sum#274, isEmpty#275, sum#276, isEmpty#277, sum#278, isEmpty#279, sum#280, isEmpty#281, sum#282, isEmpty#283, sum#284, isEmpty#285, sum#286, isEmpty#287] @@ -264,16 +264,16 @@ Arguments: hashpartitioning(w_warehouse_name#186, w_warehouse_sq_ft#187, w_city# (47) HashAggregate [codegen id : 12] Input [55]: [w_warehouse_name#186, w_warehouse_sq_ft#187, w_city#188, w_county#189, w_state#190, w_country#191, d_year#183, sum#240, isEmpty#241, sum#242, isEmpty#243, sum#244, isEmpty#245, sum#246, isEmpty#247, sum#248, isEmpty#249, sum#250, isEmpty#251, sum#252, isEmpty#253, sum#254, isEmpty#255, sum#256, isEmpty#257, sum#258, isEmpty#259, sum#260, isEmpty#261, sum#262, isEmpty#263, sum#264, isEmpty#265, sum#266, isEmpty#267, sum#268, isEmpty#269, sum#270, isEmpty#271, sum#272, isEmpty#273, sum#274, isEmpty#275, sum#276, isEmpty#277, sum#278, isEmpty#279, sum#280, isEmpty#281, sum#282, isEmpty#283, sum#284, isEmpty#285, sum#286, isEmpty#287] Keys [7]: [w_warehouse_name#186, w_warehouse_sq_ft#187, w_city#188, w_county#189, w_state#190, w_country#191, d_year#183] -Functions [24]: [sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 6) THEN 
CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 
0.00 END), sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)] -Aggregate Attributes [24]: [sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#289, sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#290, sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#291, sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#292, sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#293, sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#294, sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#295, sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#296, sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#297, sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#298, sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * 
promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#299, sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#300, sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#301, sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#302, sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#303, sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#304, sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#305, sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#306, sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#307, sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#308, sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#309, sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#310, sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#311, sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#312] -Results [32]: [w_warehouse_name#186, w_warehouse_sq_ft#187, w_city#188, w_county#189, w_state#190, w_country#191, DHL,BARIAN AS ship_carriers#313, d_year#183 AS year#314, sum(CASE WHEN (d_moy#184 = 1) THEN 
CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#289 AS jan_sales#315, sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#290 AS feb_sales#316, sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#291 AS mar_sales#317, sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#292 AS apr_sales#318, sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#293 AS may_sales#319, sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#294 AS jun_sales#320, sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#295 AS jul_sales#321, sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#296 AS aug_sales#322, sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#297 AS sep_sales#323, sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#298 AS oct_sales#324, sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#299 AS nov_sales#325, sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#300 AS dec_sales#326, sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#301 AS jan_net#327, sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) 
ELSE 0.00 END)#302 AS feb_net#328, sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#303 AS mar_net#329, sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#304 AS apr_net#330, sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#305 AS may_net#331, sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#306 AS jun_net#332, sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#307 AS jul_net#333, sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#308 AS aug_net#334, sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#309 AS sep_net#335, sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#310 AS oct_net#336, sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#311 AS nov_net#337, sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#312 AS dec_net#338] +Functions [24]: [sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 5) THEN 
CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), 
sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)] +Aggregate Attributes [24]: [sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#289, sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#290, sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#291, sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#292, sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#293, sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#294, sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#295, sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#296, sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#297, sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#298, sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#299, sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#300, sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#301, sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * 
promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#302, sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#303, sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#304, sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#305, sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#306, sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#307, sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#308, sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#309, sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#310, sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#311, sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#312] +Results [32]: [w_warehouse_name#186, w_warehouse_sq_ft#187, w_city#188, w_county#189, w_state#190, w_country#191, DHL,BARIAN AS ship_carriers#313, d_year#183 AS year#314, sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#289 AS jan_sales#315, sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#290 AS feb_sales#316, sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#291 AS mar_sales#317, sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#292 AS apr_sales#318, sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) 
ELSE 0.00 END)#293 AS may_sales#319, sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#294 AS jun_sales#320, sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#295 AS jul_sales#321, sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#296 AS aug_sales#322, sum(CASE WHEN (d_moy#184 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#297 AS sep_sales#323, sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#298 AS oct_sales#324, sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#299 AS nov_sales#325, sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#300 AS dec_sales#326, sum(CASE WHEN (d_moy#184 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#301 AS jan_net#327, sum(CASE WHEN (d_moy#184 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#302 AS feb_net#328, sum(CASE WHEN (d_moy#184 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#303 AS mar_net#329, sum(CASE WHEN (d_moy#184 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#304 AS apr_net#330, sum(CASE WHEN (d_moy#184 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#305 AS may_net#331, sum(CASE WHEN (d_moy#184 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#306 AS jun_net#332, sum(CASE WHEN (d_moy#184 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#307 AS jul_net#333, sum(CASE WHEN (d_moy#184 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#308 AS aug_net#334, sum(CASE WHEN (d_moy#184 = 9) THEN 
CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#309 AS sep_net#335, sum(CASE WHEN (d_moy#184 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#310 AS oct_net#336, sum(CASE WHEN (d_moy#184 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#311 AS nov_net#337, sum(CASE WHEN (d_moy#184 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#312 AS dec_net#338] (48) Union (49) HashAggregate [codegen id : 13] Input [32]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, ship_carriers#147, year#148, jan_sales#149, feb_sales#150, mar_sales#151, apr_sales#152, may_sales#153, jun_sales#154, jul_sales#155, aug_sales#156, sep_sales#157, oct_sales#158, nov_sales#159, dec_sales#160, jan_net#161, feb_net#162, mar_net#163, apr_net#164, may_net#165, jun_net#166, jul_net#167, aug_net#168, sep_net#169, oct_net#170, nov_net#171, dec_net#172] Keys [8]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, ship_carriers#147, year#148] -Functions [36]: [partial_sum(jan_sales#149), partial_sum(feb_sales#150), partial_sum(mar_sales#151), partial_sum(apr_sales#152), partial_sum(may_sales#153), partial_sum(jun_sales#154), partial_sum(jul_sales#155), partial_sum(aug_sales#156), partial_sum(sep_sales#157), partial_sum(oct_sales#158), partial_sum(nov_sales#159), partial_sum(dec_sales#160), partial_sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), 
partial_sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(jan_net#161), partial_sum(feb_net#162), partial_sum(mar_net#163), partial_sum(apr_net#164), partial_sum(may_net#165), partial_sum(jun_net#166), partial_sum(jul_net#167), partial_sum(aug_net#168), partial_sum(sep_net#169), partial_sum(oct_net#170), partial_sum(nov_net#171), partial_sum(dec_net#172)] +Functions [36]: [partial_sum(jan_sales#149), partial_sum(feb_sales#150), partial_sum(mar_sales#151), partial_sum(apr_sales#152), partial_sum(may_sales#153), partial_sum(jun_sales#154), partial_sum(jul_sales#155), partial_sum(aug_sales#156), partial_sum(sep_sales#157), partial_sum(oct_sales#158), partial_sum(nov_sales#159), partial_sum(dec_sales#160), partial_sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), partial_sum(jan_net#161), partial_sum(feb_net#162), partial_sum(mar_net#163), partial_sum(apr_net#164), partial_sum(may_net#165), partial_sum(jun_net#166), partial_sum(jul_net#167), partial_sum(aug_net#168), partial_sum(sep_net#169), partial_sum(oct_net#170), partial_sum(nov_net#171), partial_sum(dec_net#172)] Aggregate Attributes [72]: [sum#339, isEmpty#340, sum#341, isEmpty#342, sum#343, isEmpty#344, sum#345, isEmpty#346, sum#347, isEmpty#348, sum#349, isEmpty#350, sum#351, isEmpty#352, sum#353, isEmpty#354, sum#355, isEmpty#356, sum#357, isEmpty#358, 
sum#359, isEmpty#360, sum#361, isEmpty#362, sum#363, isEmpty#364, sum#365, isEmpty#366, sum#367, isEmpty#368, sum#369, isEmpty#370, sum#371, isEmpty#372, sum#373, isEmpty#374, sum#375, isEmpty#376, sum#377, isEmpty#378, sum#379, isEmpty#380, sum#381, isEmpty#382, sum#383, isEmpty#384, sum#385, isEmpty#386, sum#387, isEmpty#388, sum#389, isEmpty#390, sum#391, isEmpty#392, sum#393, isEmpty#394, sum#395, isEmpty#396, sum#397, isEmpty#398, sum#399, isEmpty#400, sum#401, isEmpty#402, sum#403, isEmpty#404, sum#405, isEmpty#406, sum#407, isEmpty#408, sum#409, isEmpty#410] Results [80]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, ship_carriers#147, year#148, sum#411, isEmpty#412, sum#413, isEmpty#414, sum#415, isEmpty#416, sum#417, isEmpty#418, sum#419, isEmpty#420, sum#421, isEmpty#422, sum#423, isEmpty#424, sum#425, isEmpty#426, sum#427, isEmpty#428, sum#429, isEmpty#430, sum#431, isEmpty#432, sum#433, isEmpty#434, sum#435, isEmpty#436, sum#437, isEmpty#438, sum#439, isEmpty#440, sum#441, isEmpty#442, sum#443, isEmpty#444, sum#445, isEmpty#446, sum#447, isEmpty#448, sum#449, isEmpty#450, sum#451, isEmpty#452, sum#453, isEmpty#454, sum#455, isEmpty#456, sum#457, isEmpty#458, sum#459, isEmpty#460, sum#461, isEmpty#462, sum#463, isEmpty#464, sum#465, isEmpty#466, sum#467, isEmpty#468, sum#469, isEmpty#470, sum#471, isEmpty#472, sum#473, isEmpty#474, sum#475, isEmpty#476, sum#477, isEmpty#478, sum#479, isEmpty#480, sum#481, isEmpty#482] @@ -284,9 +284,9 @@ Arguments: hashpartitioning(w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21 (51) HashAggregate [codegen id : 14] Input [80]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, ship_carriers#147, year#148, sum#411, isEmpty#412, sum#413, isEmpty#414, sum#415, isEmpty#416, sum#417, isEmpty#418, sum#419, isEmpty#420, sum#421, isEmpty#422, sum#423, isEmpty#424, sum#425, isEmpty#426, sum#427, isEmpty#428, sum#429, isEmpty#430, sum#431, isEmpty#432, sum#433, isEmpty#434, sum#435, isEmpty#436, sum#437, isEmpty#438, sum#439, isEmpty#440, sum#441, isEmpty#442, sum#443, isEmpty#444, sum#445, isEmpty#446, sum#447, isEmpty#448, sum#449, isEmpty#450, sum#451, isEmpty#452, sum#453, isEmpty#454, sum#455, isEmpty#456, sum#457, isEmpty#458, sum#459, isEmpty#460, sum#461, isEmpty#462, sum#463, isEmpty#464, sum#465, isEmpty#466, sum#467, isEmpty#468, sum#469, isEmpty#470, sum#471, isEmpty#472, sum#473, isEmpty#474, sum#475, isEmpty#476, sum#477, isEmpty#478, sum#479, isEmpty#480, sum#481, isEmpty#482] Keys [8]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, ship_carriers#147, year#148] -Functions [36]: [sum(jan_sales#149), sum(feb_sales#150), sum(mar_sales#151), sum(apr_sales#152), sum(may_sales#153), sum(jun_sales#154), sum(jul_sales#155), sum(aug_sales#156), sum(sep_sales#157), sum(oct_sales#158), sum(nov_sales#159), sum(dec_sales#160), sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(apr_sales#152) / 
promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(jan_net#161), sum(feb_net#162), sum(mar_net#163), sum(apr_net#164), sum(may_net#165), sum(jun_net#166), sum(jul_net#167), sum(aug_net#168), sum(sep_net#169), sum(oct_net#170), sum(nov_net#171), sum(dec_net#172)] -Aggregate Attributes [36]: [sum(jan_sales#149)#484, sum(feb_sales#150)#485, sum(mar_sales#151)#486, sum(apr_sales#152)#487, sum(may_sales#153)#488, sum(jun_sales#154)#489, sum(jul_sales#155)#490, sum(aug_sales#156)#491, sum(sep_sales#157)#492, sum(oct_sales#158)#493, sum(nov_sales#159)#494, sum(dec_sales#160)#495, sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#496, sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#497, sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#498, sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#499, sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#500, sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#501, sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#502, sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#503, sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), 
DecimalType(38,12), true))#504, sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#505, sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#506, sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#507, sum(jan_net#161)#508, sum(feb_net#162)#509, sum(mar_net#163)#510, sum(apr_net#164)#511, sum(may_net#165)#512, sum(jun_net#166)#513, sum(jul_net#167)#514, sum(aug_net#168)#515, sum(sep_net#169)#516, sum(oct_net#170)#517, sum(nov_net#171)#518, sum(dec_net#172)#519] -Results [44]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, ship_carriers#147, year#148, sum(jan_sales#149)#484 AS jan_sales#520, sum(feb_sales#150)#485 AS feb_sales#521, sum(mar_sales#151)#486 AS mar_sales#522, sum(apr_sales#152)#487 AS apr_sales#523, sum(may_sales#153)#488 AS may_sales#524, sum(jun_sales#154)#489 AS jun_sales#525, sum(jul_sales#155)#490 AS jul_sales#526, sum(aug_sales#156)#491 AS aug_sales#527, sum(sep_sales#157)#492 AS sep_sales#528, sum(oct_sales#158)#493 AS oct_sales#529, sum(nov_sales#159)#494 AS nov_sales#530, sum(dec_sales#160)#495 AS dec_sales#531, sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#496 AS jan_sales_per_sq_foot#532, sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#497 AS feb_sales_per_sq_foot#533, sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#498 AS mar_sales_per_sq_foot#534, sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#499 AS apr_sales_per_sq_foot#535, sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#500 AS may_sales_per_sq_foot#536, sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#501 AS jun_sales_per_sq_foot#537, sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#502 AS jul_sales_per_sq_foot#538, sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#503 AS aug_sales_per_sq_foot#539, sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#504 AS sep_sales_per_sq_foot#540, sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#505 AS oct_sales_per_sq_foot#541, sum(CheckOverflow((promote_precision(nov_sales#159) / 
promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#506 AS nov_sales_per_sq_foot#542, sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(cast(w_warehouse_sq_ft#20 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#507 AS dec_sales_per_sq_foot#543, sum(jan_net#161)#508 AS jan_net#544, sum(feb_net#162)#509 AS feb_net#545, sum(mar_net#163)#510 AS mar_net#546, sum(apr_net#164)#511 AS apr_net#547, sum(may_net#165)#512 AS may_net#548, sum(jun_net#166)#513 AS jun_net#549, sum(jul_net#167)#514 AS jul_net#550, sum(aug_net#168)#515 AS aug_net#551, sum(sep_net#169)#516 AS sep_net#552, sum(oct_net#170)#517 AS oct_net#553, sum(nov_net#171)#518 AS nov_net#554, sum(dec_net#172)#519 AS dec_net#555] +Functions [36]: [sum(jan_sales#149), sum(feb_sales#150), sum(mar_sales#151), sum(apr_sales#152), sum(may_sales#153), sum(jun_sales#154), sum(jul_sales#155), sum(aug_sales#156), sum(sep_sales#157), sum(oct_sales#158), sum(nov_sales#159), sum(dec_sales#160), sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12))), sum(jan_net#161), sum(feb_net#162), sum(mar_net#163), sum(apr_net#164), sum(may_net#165), sum(jun_net#166), sum(jul_net#167), sum(aug_net#168), sum(sep_net#169), sum(oct_net#170), sum(nov_net#171), sum(dec_net#172)] +Aggregate Attributes [36]: [sum(jan_sales#149)#484, sum(feb_sales#150)#485, sum(mar_sales#151)#486, sum(apr_sales#152)#487, sum(may_sales#153)#488, sum(jun_sales#154)#489, sum(jul_sales#155)#490, sum(aug_sales#156)#491, sum(sep_sales#157)#492, sum(oct_sales#158)#493, sum(nov_sales#159)#494, sum(dec_sales#160)#495, sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#496, sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), 
DecimalType(38,12)))#497, sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#498, sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#499, sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#500, sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#501, sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#502, sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#503, sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#504, sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#505, sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#506, sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#507, sum(jan_net#161)#508, sum(feb_net#162)#509, sum(mar_net#163)#510, sum(apr_net#164)#511, sum(may_net#165)#512, sum(jun_net#166)#513, sum(jul_net#167)#514, sum(aug_net#168)#515, sum(sep_net#169)#516, sum(oct_net#170)#517, sum(nov_net#171)#518, sum(dec_net#172)#519] +Results [44]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, ship_carriers#147, year#148, sum(jan_sales#149)#484 AS jan_sales#520, sum(feb_sales#150)#485 AS feb_sales#521, sum(mar_sales#151)#486 AS mar_sales#522, sum(apr_sales#152)#487 AS apr_sales#523, sum(may_sales#153)#488 AS may_sales#524, sum(jun_sales#154)#489 AS jun_sales#525, sum(jul_sales#155)#490 AS jul_sales#526, sum(aug_sales#156)#491 AS aug_sales#527, sum(sep_sales#157)#492 AS sep_sales#528, sum(oct_sales#158)#493 AS oct_sales#529, sum(nov_sales#159)#494 AS nov_sales#530, sum(dec_sales#160)#495 AS dec_sales#531, sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#496 AS jan_sales_per_sq_foot#532, sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#497 AS feb_sales_per_sq_foot#533, sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#498 AS mar_sales_per_sq_foot#534, sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#499 AS apr_sales_per_sq_foot#535, sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#500 AS may_sales_per_sq_foot#536, sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#501 AS jun_sales_per_sq_foot#537, sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), 
DecimalType(38,12)))#502 AS jul_sales_per_sq_foot#538, sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#503 AS aug_sales_per_sq_foot#539, sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#504 AS sep_sales_per_sq_foot#540, sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#505 AS oct_sales_per_sq_foot#541, sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#506 AS nov_sales_per_sq_foot#542, sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(w_warehouse_sq_ft#20 as decimal(28,2)))), DecimalType(38,12)))#507 AS dec_sales_per_sq_foot#543, sum(jan_net#161)#508 AS jan_net#544, sum(feb_net#162)#509 AS feb_net#545, sum(mar_net#163)#510 AS mar_net#546, sum(apr_net#164)#511 AS apr_net#547, sum(may_net#165)#512 AS may_net#548, sum(jun_net#166)#513 AS jun_net#549, sum(jul_net#167)#514 AS jul_net#550, sum(aug_net#168)#515 AS aug_net#551, sum(sep_net#169)#516 AS sep_net#552, sum(oct_net#170)#517 AS oct_net#553, sum(nov_net#171)#518 AS nov_net#554, sum(dec_net#172)#519 AS dec_net#555] (52) TakeOrderedAndProject Input [44]: [w_warehouse_name#19, w_warehouse_sq_ft#20, w_city#21, w_county#22, w_state#23, w_country#24, ship_carriers#147, year#148, jan_sales#520, feb_sales#521, mar_sales#522, apr_sales#523, may_sales#524, jun_sales#525, jul_sales#526, aug_sales#527, sep_sales#528, oct_sales#529, nov_sales#530, dec_sales#531, jan_sales_per_sq_foot#532, feb_sales_per_sq_foot#533, mar_sales_per_sq_foot#534, apr_sales_per_sq_foot#535, may_sales_per_sq_foot#536, jun_sales_per_sq_foot#537, jul_sales_per_sq_foot#538, aug_sales_per_sq_foot#539, sep_sales_per_sq_foot#540, oct_sales_per_sq_foot#541, nov_sales_per_sq_foot#542, dec_sales_per_sq_foot#543, jan_net#544, feb_net#545, mar_net#546, apr_net#547, may_net#548, jun_net#549, jul_net#550, aug_net#551, sep_net#552, oct_net#553, nov_net#554, dec_net#555] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66.sf100/simplified.txt index 86c73b1f44bfe..d9ac8f54234f7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66.sf100/simplified.txt @@ -1,6 +1,6 @@ TakeOrderedAndProject [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,ship_carriers,year,jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_sales_per_sq_foot,feb_sales_per_sq_foot,mar_sales_per_sq_foot,apr_sales_per_sq_foot,may_sales_per_sq_foot,jun_sales_per_sq_foot,jul_sales_per_sq_foot,aug_sales_per_sq_foot,sep_sales_per_sq_foot,oct_sales_per_sq_foot,nov_sales_per_sq_foot,dec_sales_per_sq_foot,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net] WholeStageCodegen (14) - HashAggregate 
[w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,ship_carriers,year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(jan_sales),sum(feb_sales),sum(mar_sales),sum(apr_sales),sum(may_sales),sum(jun_sales),sum(jul_sales),sum(aug_sales),sum(sep_sales),sum(oct_sales),sum(nov_sales),sum(dec_sales),sum(CheckOverflow((promote_precision(jan_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(feb_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(mar_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(apr_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(may_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(jun_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(jul_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(aug_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(sep_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(oct_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(nov_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(dec_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), 
true)),sum(jan_net),sum(feb_net),sum(mar_net),sum(apr_net),sum(may_net),sum(jun_net),sum(jul_net),sum(aug_net),sum(sep_net),sum(oct_net),sum(nov_net),sum(dec_net),jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_sales_per_sq_foot,feb_sales_per_sq_foot,mar_sales_per_sq_foot,apr_sales_per_sq_foot,may_sales_per_sq_foot,jun_sales_per_sq_foot,jul_sales_per_sq_foot,aug_sales_per_sq_foot,sep_sales_per_sq_foot,oct_sales_per_sq_foot,nov_sales_per_sq_foot,dec_sales_per_sq_foot,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] + HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,ship_carriers,year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(jan_sales),sum(feb_sales),sum(mar_sales),sum(apr_sales),sum(may_sales),sum(jun_sales),sum(jul_sales),sum(aug_sales),sum(sep_sales),sum(oct_sales),sum(nov_sales),sum(dec_sales),sum(CheckOverflow((promote_precision(jan_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(feb_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(mar_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(apr_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(may_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(jun_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(jul_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(aug_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(sep_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(oct_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(nov_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(dec_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), 
DecimalType(38,12))),sum(jan_net),sum(feb_net),sum(mar_net),sum(apr_net),sum(may_net),sum(jun_net),sum(jul_net),sum(aug_net),sum(sep_net),sum(oct_net),sum(nov_net),sum(dec_net),jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_sales_per_sq_foot,feb_sales_per_sq_foot,mar_sales_per_sq_foot,apr_sales_per_sq_foot,may_sales_per_sq_foot,jun_sales_per_sq_foot,jul_sales_per_sq_foot,aug_sales_per_sq_foot,sep_sales_per_sq_foot,oct_sales_per_sq_foot,nov_sales_per_sq_foot,dec_sales_per_sq_foot,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,ship_carriers,year] #1 WholeStageCodegen (13) @@ -8,7 +8,7 @@ TakeOrderedAndProject [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_stat InputAdapter Union WholeStageCodegen (6) - HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), 
true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), 
DecimalType(18,2), true) ELSE 0.00 END),ship_carriers,year,jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] + HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 
2) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),ship_carriers,year,jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year] #2 WholeStageCodegen (5) @@ -58,7 +58,7 @@ TakeOrderedAndProject [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_stat InputAdapter Scan parquet default.warehouse [w_warehouse_sk,w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country] WholeStageCodegen (12) - HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as 
decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * 
promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),ship_carriers,year,jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] + HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as 
decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * 
promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),ship_carriers,year,jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year] #7 WholeStageCodegen (11) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66/explain.txt index defc9caffa7c2..f0b239a262c26 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66/explain.txt @@ -172,7 +172,7 @@ Input [13]: [ws_ship_mode_sk#2, ws_quantity#4, ws_ext_sales_price#5, ws_net_paid (27) HashAggregate [codegen id : 5] Input [11]: [ws_quantity#4, ws_ext_sales_price#5, ws_net_paid#6, w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, d_year#18, d_moy#19] Keys [7]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, d_year#18] -Functions [24]: [partial_sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 8) THEN 
CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), 
partial_sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)] +Functions [24]: [partial_sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 3) THEN 
CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)] Aggregate Attributes [48]: [sum#26, isEmpty#27, sum#28, isEmpty#29, sum#30, isEmpty#31, sum#32, isEmpty#33, sum#34, isEmpty#35, sum#36, isEmpty#37, sum#38, isEmpty#39, sum#40, isEmpty#41, sum#42, isEmpty#43, sum#44, isEmpty#45, sum#46, isEmpty#47, sum#48, isEmpty#49, sum#50, isEmpty#51, sum#52, isEmpty#53, sum#54, isEmpty#55, sum#56, isEmpty#57, sum#58, isEmpty#59, sum#60, isEmpty#61, sum#62, isEmpty#63, sum#64, isEmpty#65, sum#66, isEmpty#67, sum#68, isEmpty#69, sum#70, isEmpty#71, sum#72, isEmpty#73] Results [55]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, d_year#18, sum#74, isEmpty#75, sum#76, isEmpty#77, sum#78, isEmpty#79, sum#80, isEmpty#81, sum#82, isEmpty#83, sum#84, isEmpty#85, sum#86, isEmpty#87, sum#88, isEmpty#89, sum#90, isEmpty#91, sum#92, isEmpty#93, sum#94, isEmpty#95, sum#96, isEmpty#97, sum#98, isEmpty#99, sum#100, isEmpty#101, sum#102, isEmpty#103, sum#104, isEmpty#105, sum#106, isEmpty#107, sum#108, isEmpty#109, sum#110, isEmpty#111, sum#112, isEmpty#113, sum#114, isEmpty#115, sum#116, isEmpty#117, sum#118, isEmpty#119, sum#120, isEmpty#121] @@ -183,9 +183,9 @@ Arguments: hashpartitioning(w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12 (29) HashAggregate [codegen id : 6] Input [55]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, d_year#18, sum#74, isEmpty#75, sum#76, isEmpty#77, sum#78, isEmpty#79, sum#80, isEmpty#81, sum#82, isEmpty#83, sum#84, isEmpty#85, sum#86, 
isEmpty#87, sum#88, isEmpty#89, sum#90, isEmpty#91, sum#92, isEmpty#93, sum#94, isEmpty#95, sum#96, isEmpty#97, sum#98, isEmpty#99, sum#100, isEmpty#101, sum#102, isEmpty#103, sum#104, isEmpty#105, sum#106, isEmpty#107, sum#108, isEmpty#109, sum#110, isEmpty#111, sum#112, isEmpty#113, sum#114, isEmpty#115, sum#116, isEmpty#117, sum#118, isEmpty#119, sum#120, isEmpty#121] Keys [7]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, d_year#18] -Functions [24]: [sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) 
ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)] -Aggregate Attributes [24]: [sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#123, sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#124, sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#125, sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#126, sum(CASE WHEN (d_moy#19 = 
5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#127, sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#128, sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#129, sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#130, sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#131, sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#132, sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#133, sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#134, sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#135, sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#136, sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#137, sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#138, sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#139, sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#140, sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#141, sum(CASE WHEN (d_moy#19 = 
8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#142, sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#143, sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#144, sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#145, sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#146] -Results [32]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, DHL,BARIAN AS ship_carriers#147, d_year#18 AS year#148, sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#123 AS jan_sales#149, sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#124 AS feb_sales#150, sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#125 AS mar_sales#151, sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#126 AS apr_sales#152, sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#127 AS may_sales#153, sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#128 AS jun_sales#154, sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#129 AS jul_sales#155, sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#130 AS aug_sales#156, sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as 
decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#131 AS sep_sales#157, sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#132 AS oct_sales#158, sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#133 AS nov_sales#159, sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#134 AS dec_sales#160, sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#135 AS jan_net#161, sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#136 AS feb_net#162, sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#137 AS mar_net#163, sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#138 AS apr_net#164, sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#139 AS may_net#165, sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#140 AS jun_net#166, sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#141 AS jul_net#167, sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#142 AS aug_net#168, sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#143 AS sep_net#169, sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#144 AS oct_net#170, sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 
0.00 END)#145 AS nov_net#171, sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(cast(ws_quantity#4 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#146 AS dec_net#172] +Functions [24]: [sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), 
DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)] +Aggregate Attributes [24]: [sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#123, sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#124, sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#125, sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#126, sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#127, sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#128, sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#129, sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#130, sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 
0.00 END)#131, sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#132, sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#133, sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#134, sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#135, sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#136, sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#137, sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#138, sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#139, sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#140, sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#141, sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#142, sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#143, sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#144, sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#145, sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#146] +Results [32]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, DHL,BARIAN AS ship_carriers#147, d_year#18 AS year#148, sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#123 AS jan_sales#149, sum(CASE WHEN (d_moy#19 = 2) THEN 
CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#124 AS feb_sales#150, sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#125 AS mar_sales#151, sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#126 AS apr_sales#152, sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#127 AS may_sales#153, sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#128 AS jun_sales#154, sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#129 AS jul_sales#155, sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#130 AS aug_sales#156, sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#131 AS sep_sales#157, sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#132 AS oct_sales#158, sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#133 AS nov_sales#159, sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price#5 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#134 AS dec_sales#160, sum(CASE WHEN (d_moy#19 = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#135 AS jan_net#161, sum(CASE WHEN (d_moy#19 = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#136 AS feb_net#162, sum(CASE WHEN (d_moy#19 = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#137 AS mar_net#163, sum(CASE WHEN (d_moy#19 = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#138 AS apr_net#164, sum(CASE WHEN (d_moy#19 = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 
END)#139 AS may_net#165, sum(CASE WHEN (d_moy#19 = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#140 AS jun_net#166, sum(CASE WHEN (d_moy#19 = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#141 AS jul_net#167, sum(CASE WHEN (d_moy#19 = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#142 AS aug_net#168, sum(CASE WHEN (d_moy#19 = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#143 AS sep_net#169, sum(CASE WHEN (d_moy#19 = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#144 AS oct_net#170, sum(CASE WHEN (d_moy#19 = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#145 AS nov_net#171, sum(CASE WHEN (d_moy#19 = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid#6 as decimal(12,2))) * promote_precision(cast(ws_quantity#4 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#146 AS dec_net#172] (30) Scan parquet default.catalog_sales Output [7]: [cs_sold_time_sk#173, cs_ship_mode_sk#174, cs_warehouse_sk#175, cs_quantity#176, cs_sales_price#177, cs_net_paid_inc_tax#178, cs_sold_date_sk#179] @@ -253,7 +253,7 @@ Input [13]: [cs_ship_mode_sk#174, cs_quantity#176, cs_sales_price#177, cs_net_pa (45) HashAggregate [codegen id : 11] Input [11]: [cs_quantity#176, cs_sales_price#177, cs_net_paid_inc_tax#178, w_warehouse_name#181, w_warehouse_sq_ft#182, w_city#183, w_county#184, w_state#185, w_country#186, d_year#188, d_moy#189] Keys [7]: [w_warehouse_name#181, w_warehouse_sq_ft#182, w_city#183, w_county#184, w_state#185, w_country#186, d_year#188] -Functions [24]: [partial_sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 6) THEN 
CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * 
promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)] +Functions [24]: [partial_sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN 
(d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), partial_sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)] Aggregate Attributes [48]: [sum#192, isEmpty#193, sum#194, isEmpty#195, sum#196, isEmpty#197, sum#198, isEmpty#199, sum#200, isEmpty#201, sum#202, isEmpty#203, sum#204, isEmpty#205, sum#206, isEmpty#207, sum#208, isEmpty#209, sum#210, isEmpty#211, sum#212, isEmpty#213, sum#214, isEmpty#215, sum#216, isEmpty#217, sum#218, isEmpty#219, sum#220, isEmpty#221, sum#222, isEmpty#223, sum#224, isEmpty#225, sum#226, isEmpty#227, sum#228, isEmpty#229, sum#230, isEmpty#231, sum#232, isEmpty#233, sum#234, isEmpty#235, sum#236, isEmpty#237, sum#238, isEmpty#239] Results [55]: [w_warehouse_name#181, w_warehouse_sq_ft#182, w_city#183, w_county#184, w_state#185, w_country#186, d_year#188, sum#240, isEmpty#241, 
sum#242, isEmpty#243, sum#244, isEmpty#245, sum#246, isEmpty#247, sum#248, isEmpty#249, sum#250, isEmpty#251, sum#252, isEmpty#253, sum#254, isEmpty#255, sum#256, isEmpty#257, sum#258, isEmpty#259, sum#260, isEmpty#261, sum#262, isEmpty#263, sum#264, isEmpty#265, sum#266, isEmpty#267, sum#268, isEmpty#269, sum#270, isEmpty#271, sum#272, isEmpty#273, sum#274, isEmpty#275, sum#276, isEmpty#277, sum#278, isEmpty#279, sum#280, isEmpty#281, sum#282, isEmpty#283, sum#284, isEmpty#285, sum#286, isEmpty#287] @@ -264,16 +264,16 @@ Arguments: hashpartitioning(w_warehouse_name#181, w_warehouse_sq_ft#182, w_city# (47) HashAggregate [codegen id : 12] Input [55]: [w_warehouse_name#181, w_warehouse_sq_ft#182, w_city#183, w_county#184, w_state#185, w_country#186, d_year#188, sum#240, isEmpty#241, sum#242, isEmpty#243, sum#244, isEmpty#245, sum#246, isEmpty#247, sum#248, isEmpty#249, sum#250, isEmpty#251, sum#252, isEmpty#253, sum#254, isEmpty#255, sum#256, isEmpty#257, sum#258, isEmpty#259, sum#260, isEmpty#261, sum#262, isEmpty#263, sum#264, isEmpty#265, sum#266, isEmpty#267, sum#268, isEmpty#269, sum#270, isEmpty#271, sum#272, isEmpty#273, sum#274, isEmpty#275, sum#276, isEmpty#277, sum#278, isEmpty#279, sum#280, isEmpty#281, sum#282, isEmpty#283, sum#284, isEmpty#285, sum#286, isEmpty#287] Keys [7]: [w_warehouse_name#181, w_warehouse_sq_ft#182, w_city#183, w_county#184, w_state#185, w_country#186, d_year#188] -Functions [24]: [sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) 
ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * 
promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)] -Aggregate Attributes [24]: [sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#289, sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#290, sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#291, sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#292, sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#293, sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#294, sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#295, sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#296, sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#297, sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#298, sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#299, sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#300, sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#301, sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#302, sum(CASE WHEN 
(d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#303, sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#304, sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#305, sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#306, sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#307, sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#308, sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#309, sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#310, sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#311, sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#312] -Results [32]: [w_warehouse_name#181, w_warehouse_sq_ft#182, w_city#183, w_county#184, w_state#185, w_country#186, DHL,BARIAN AS ship_carriers#313, d_year#188 AS year#314, sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#289 AS jan_sales#315, sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#290 AS feb_sales#316, sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#291 AS mar_sales#317, sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as 
decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#292 AS apr_sales#318, sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#293 AS may_sales#319, sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#294 AS jun_sales#320, sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#295 AS jul_sales#321, sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#296 AS aug_sales#322, sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#297 AS sep_sales#323, sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#298 AS oct_sales#324, sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#299 AS nov_sales#325, sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#300 AS dec_sales#326, sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#301 AS jan_net#327, sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#302 AS feb_net#328, sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#303 AS mar_net#329, sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#304 AS apr_net#330, sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#305 AS may_net#331, sum(CASE WHEN (d_moy#189 = 6) THEN 
CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#306 AS jun_net#332, sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#307 AS jul_net#333, sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#308 AS aug_net#334, sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#309 AS sep_net#335, sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#310 AS oct_net#336, sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#311 AS nov_net#337, sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cast(cs_quantity#176 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END)#312 AS dec_net#338] +Functions [24]: [sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 9) THEN 
CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END), sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)] +Aggregate Attributes [24]: [sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as 
decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#289, sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#290, sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#291, sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#292, sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#293, sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#294, sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#295, sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#296, sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#297, sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#298, sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#299, sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#300, sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#301, sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#302, sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#303, sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#304, sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#305, sum(CASE WHEN (d_moy#189 = 6) THEN 
CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#306, sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#307, sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#308, sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#309, sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#310, sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#311, sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#312] +Results [32]: [w_warehouse_name#181, w_warehouse_sq_ft#182, w_city#183, w_county#184, w_state#185, w_country#186, DHL,BARIAN AS ship_carriers#313, d_year#188 AS year#314, sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#289 AS jan_sales#315, sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#290 AS feb_sales#316, sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#291 AS mar_sales#317, sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#292 AS apr_sales#318, sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#293 AS may_sales#319, sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#294 AS jun_sales#320, sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#295 AS jul_sales#321, sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#296 AS aug_sales#322, sum(CASE WHEN (d_moy#189 = 9) THEN 
CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#297 AS sep_sales#323, sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#298 AS oct_sales#324, sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#299 AS nov_sales#325, sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price#177 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#300 AS dec_sales#326, sum(CASE WHEN (d_moy#189 = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#301 AS jan_net#327, sum(CASE WHEN (d_moy#189 = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#302 AS feb_net#328, sum(CASE WHEN (d_moy#189 = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#303 AS mar_net#329, sum(CASE WHEN (d_moy#189 = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#304 AS apr_net#330, sum(CASE WHEN (d_moy#189 = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#305 AS may_net#331, sum(CASE WHEN (d_moy#189 = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#306 AS jun_net#332, sum(CASE WHEN (d_moy#189 = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#307 AS jul_net#333, sum(CASE WHEN (d_moy#189 = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#308 AS aug_net#334, sum(CASE WHEN (d_moy#189 = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#309 AS sep_net#335, sum(CASE WHEN (d_moy#189 = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#310 AS oct_net#336, sum(CASE WHEN (d_moy#189 = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) * promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#311 AS nov_net#337, sum(CASE WHEN (d_moy#189 = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax#178 as decimal(12,2))) 
* promote_precision(cast(cs_quantity#176 as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END)#312 AS dec_net#338] (48) Union (49) HashAggregate [codegen id : 13] Input [32]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, ship_carriers#147, year#148, jan_sales#149, feb_sales#150, mar_sales#151, apr_sales#152, may_sales#153, jun_sales#154, jul_sales#155, aug_sales#156, sep_sales#157, oct_sales#158, nov_sales#159, dec_sales#160, jan_net#161, feb_net#162, mar_net#163, apr_net#164, may_net#165, jun_net#166, jul_net#167, aug_net#168, sep_net#169, oct_net#170, nov_net#171, dec_net#172] Keys [8]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, ship_carriers#147, year#148] -Functions [36]: [partial_sum(jan_sales#149), partial_sum(feb_sales#150), partial_sum(mar_sales#151), partial_sum(apr_sales#152), partial_sum(may_sales#153), partial_sum(jun_sales#154), partial_sum(jul_sales#155), partial_sum(aug_sales#156), partial_sum(sep_sales#157), partial_sum(oct_sales#158), partial_sum(nov_sales#159), partial_sum(dec_sales#160), partial_sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), partial_sum(jan_net#161), partial_sum(feb_net#162), partial_sum(mar_net#163), partial_sum(apr_net#164), partial_sum(may_net#165), partial_sum(jun_net#166), partial_sum(jul_net#167), partial_sum(aug_net#168), partial_sum(sep_net#169), partial_sum(oct_net#170), partial_sum(nov_net#171), 
partial_sum(dec_net#172)] +Functions [36]: [partial_sum(jan_sales#149), partial_sum(feb_sales#150), partial_sum(mar_sales#151), partial_sum(apr_sales#152), partial_sum(may_sales#153), partial_sum(jun_sales#154), partial_sum(jul_sales#155), partial_sum(aug_sales#156), partial_sum(sep_sales#157), partial_sum(oct_sales#158), partial_sum(nov_sales#159), partial_sum(dec_sales#160), partial_sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), partial_sum(jan_net#161), partial_sum(feb_net#162), partial_sum(mar_net#163), partial_sum(apr_net#164), partial_sum(may_net#165), partial_sum(jun_net#166), partial_sum(jul_net#167), partial_sum(aug_net#168), partial_sum(sep_net#169), partial_sum(oct_net#170), partial_sum(nov_net#171), partial_sum(dec_net#172)] Aggregate Attributes [72]: [sum#339, isEmpty#340, sum#341, isEmpty#342, sum#343, isEmpty#344, sum#345, isEmpty#346, sum#347, isEmpty#348, sum#349, isEmpty#350, sum#351, isEmpty#352, sum#353, isEmpty#354, sum#355, isEmpty#356, sum#357, isEmpty#358, sum#359, isEmpty#360, sum#361, isEmpty#362, sum#363, isEmpty#364, sum#365, isEmpty#366, sum#367, isEmpty#368, sum#369, isEmpty#370, sum#371, isEmpty#372, sum#373, isEmpty#374, sum#375, isEmpty#376, sum#377, isEmpty#378, sum#379, isEmpty#380, sum#381, isEmpty#382, sum#383, isEmpty#384, sum#385, isEmpty#386, sum#387, isEmpty#388, sum#389, isEmpty#390, sum#391, isEmpty#392, sum#393, isEmpty#394, sum#395, isEmpty#396, sum#397, isEmpty#398, sum#399, isEmpty#400, sum#401, isEmpty#402, sum#403, isEmpty#404, sum#405, isEmpty#406, sum#407, isEmpty#408, sum#409, isEmpty#410] Results [80]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, ship_carriers#147, year#148, sum#411, isEmpty#412, sum#413, isEmpty#414, sum#415, isEmpty#416, sum#417, isEmpty#418, sum#419, isEmpty#420, 
sum#421, isEmpty#422, sum#423, isEmpty#424, sum#425, isEmpty#426, sum#427, isEmpty#428, sum#429, isEmpty#430, sum#431, isEmpty#432, sum#433, isEmpty#434, sum#435, isEmpty#436, sum#437, isEmpty#438, sum#439, isEmpty#440, sum#441, isEmpty#442, sum#443, isEmpty#444, sum#445, isEmpty#446, sum#447, isEmpty#448, sum#449, isEmpty#450, sum#451, isEmpty#452, sum#453, isEmpty#454, sum#455, isEmpty#456, sum#457, isEmpty#458, sum#459, isEmpty#460, sum#461, isEmpty#462, sum#463, isEmpty#464, sum#465, isEmpty#466, sum#467, isEmpty#468, sum#469, isEmpty#470, sum#471, isEmpty#472, sum#473, isEmpty#474, sum#475, isEmpty#476, sum#477, isEmpty#478, sum#479, isEmpty#480, sum#481, isEmpty#482] @@ -284,9 +284,9 @@ Arguments: hashpartitioning(w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12 (51) HashAggregate [codegen id : 14] Input [80]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, ship_carriers#147, year#148, sum#411, isEmpty#412, sum#413, isEmpty#414, sum#415, isEmpty#416, sum#417, isEmpty#418, sum#419, isEmpty#420, sum#421, isEmpty#422, sum#423, isEmpty#424, sum#425, isEmpty#426, sum#427, isEmpty#428, sum#429, isEmpty#430, sum#431, isEmpty#432, sum#433, isEmpty#434, sum#435, isEmpty#436, sum#437, isEmpty#438, sum#439, isEmpty#440, sum#441, isEmpty#442, sum#443, isEmpty#444, sum#445, isEmpty#446, sum#447, isEmpty#448, sum#449, isEmpty#450, sum#451, isEmpty#452, sum#453, isEmpty#454, sum#455, isEmpty#456, sum#457, isEmpty#458, sum#459, isEmpty#460, sum#461, isEmpty#462, sum#463, isEmpty#464, sum#465, isEmpty#466, sum#467, isEmpty#468, sum#469, isEmpty#470, sum#471, isEmpty#472, sum#473, isEmpty#474, sum#475, isEmpty#476, sum#477, isEmpty#478, sum#479, isEmpty#480, sum#481, isEmpty#482] Keys [8]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, ship_carriers#147, year#148] -Functions [36]: [sum(jan_sales#149), sum(feb_sales#150), sum(mar_sales#151), sum(apr_sales#152), sum(may_sales#153), sum(jun_sales#154), sum(jul_sales#155), sum(aug_sales#156), sum(sep_sales#157), sum(oct_sales#158), sum(nov_sales#159), sum(dec_sales#160), sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(sep_sales#157) / 
promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)), sum(jan_net#161), sum(feb_net#162), sum(mar_net#163), sum(apr_net#164), sum(may_net#165), sum(jun_net#166), sum(jul_net#167), sum(aug_net#168), sum(sep_net#169), sum(oct_net#170), sum(nov_net#171), sum(dec_net#172)] -Aggregate Attributes [36]: [sum(jan_sales#149)#484, sum(feb_sales#150)#485, sum(mar_sales#151)#486, sum(apr_sales#152)#487, sum(may_sales#153)#488, sum(jun_sales#154)#489, sum(jul_sales#155)#490, sum(aug_sales#156)#491, sum(sep_sales#157)#492, sum(oct_sales#158)#493, sum(nov_sales#159)#494, sum(dec_sales#160)#495, sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#496, sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#497, sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#498, sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#499, sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#500, sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#501, sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#502, sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#503, sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#504, sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#505, sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#506, sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#507, sum(jan_net#161)#508, sum(feb_net#162)#509, sum(mar_net#163)#510, sum(apr_net#164)#511, sum(may_net#165)#512, sum(jun_net#166)#513, sum(jul_net#167)#514, sum(aug_net#168)#515, sum(sep_net#169)#516, sum(oct_net#170)#517, sum(nov_net#171)#518, sum(dec_net#172)#519] -Results [44]: [w_warehouse_name#10, 
w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, ship_carriers#147, year#148, sum(jan_sales#149)#484 AS jan_sales#520, sum(feb_sales#150)#485 AS feb_sales#521, sum(mar_sales#151)#486 AS mar_sales#522, sum(apr_sales#152)#487 AS apr_sales#523, sum(may_sales#153)#488 AS may_sales#524, sum(jun_sales#154)#489 AS jun_sales#525, sum(jul_sales#155)#490 AS jul_sales#526, sum(aug_sales#156)#491 AS aug_sales#527, sum(sep_sales#157)#492 AS sep_sales#528, sum(oct_sales#158)#493 AS oct_sales#529, sum(nov_sales#159)#494 AS nov_sales#530, sum(dec_sales#160)#495 AS dec_sales#531, sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#496 AS jan_sales_per_sq_foot#532, sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#497 AS feb_sales_per_sq_foot#533, sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#498 AS mar_sales_per_sq_foot#534, sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#499 AS apr_sales_per_sq_foot#535, sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#500 AS may_sales_per_sq_foot#536, sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#501 AS jun_sales_per_sq_foot#537, sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#502 AS jul_sales_per_sq_foot#538, sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#503 AS aug_sales_per_sq_foot#539, sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#504 AS sep_sales_per_sq_foot#540, sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#505 AS oct_sales_per_sq_foot#541, sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#506 AS nov_sales_per_sq_foot#542, sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(cast(w_warehouse_sq_ft#11 as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true))#507 AS dec_sales_per_sq_foot#543, sum(jan_net#161)#508 AS jan_net#544, sum(feb_net#162)#509 AS feb_net#545, sum(mar_net#163)#510 AS mar_net#546, sum(apr_net#164)#511 AS apr_net#547, sum(may_net#165)#512 AS may_net#548, sum(jun_net#166)#513 AS jun_net#549, sum(jul_net#167)#514 AS jul_net#550, sum(aug_net#168)#515 AS aug_net#551, sum(sep_net#169)#516 AS sep_net#552, sum(oct_net#170)#517 AS oct_net#553, sum(nov_net#171)#518 AS nov_net#554, sum(dec_net#172)#519 AS dec_net#555] +Functions [36]: [sum(jan_sales#149), sum(feb_sales#150), sum(mar_sales#151), 
sum(apr_sales#152), sum(may_sales#153), sum(jun_sales#154), sum(jul_sales#155), sum(aug_sales#156), sum(sep_sales#157), sum(oct_sales#158), sum(nov_sales#159), sum(dec_sales#160), sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12))), sum(jan_net#161), sum(feb_net#162), sum(mar_net#163), sum(apr_net#164), sum(may_net#165), sum(jun_net#166), sum(jul_net#167), sum(aug_net#168), sum(sep_net#169), sum(oct_net#170), sum(nov_net#171), sum(dec_net#172)] +Aggregate Attributes [36]: [sum(jan_sales#149)#484, sum(feb_sales#150)#485, sum(mar_sales#151)#486, sum(apr_sales#152)#487, sum(may_sales#153)#488, sum(jun_sales#154)#489, sum(jul_sales#155)#490, sum(aug_sales#156)#491, sum(sep_sales#157)#492, sum(oct_sales#158)#493, sum(nov_sales#159)#494, sum(dec_sales#160)#495, sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#496, sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#497, sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#498, sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#499, sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#500, sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#501, sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#502, sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), 
DecimalType(38,12)))#503, sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#504, sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#505, sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#506, sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#507, sum(jan_net#161)#508, sum(feb_net#162)#509, sum(mar_net#163)#510, sum(apr_net#164)#511, sum(may_net#165)#512, sum(jun_net#166)#513, sum(jul_net#167)#514, sum(aug_net#168)#515, sum(sep_net#169)#516, sum(oct_net#170)#517, sum(nov_net#171)#518, sum(dec_net#172)#519] +Results [44]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, ship_carriers#147, year#148, sum(jan_sales#149)#484 AS jan_sales#520, sum(feb_sales#150)#485 AS feb_sales#521, sum(mar_sales#151)#486 AS mar_sales#522, sum(apr_sales#152)#487 AS apr_sales#523, sum(may_sales#153)#488 AS may_sales#524, sum(jun_sales#154)#489 AS jun_sales#525, sum(jul_sales#155)#490 AS jul_sales#526, sum(aug_sales#156)#491 AS aug_sales#527, sum(sep_sales#157)#492 AS sep_sales#528, sum(oct_sales#158)#493 AS oct_sales#529, sum(nov_sales#159)#494 AS nov_sales#530, sum(dec_sales#160)#495 AS dec_sales#531, sum(CheckOverflow((promote_precision(jan_sales#149) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#496 AS jan_sales_per_sq_foot#532, sum(CheckOverflow((promote_precision(feb_sales#150) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#497 AS feb_sales_per_sq_foot#533, sum(CheckOverflow((promote_precision(mar_sales#151) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#498 AS mar_sales_per_sq_foot#534, sum(CheckOverflow((promote_precision(apr_sales#152) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#499 AS apr_sales_per_sq_foot#535, sum(CheckOverflow((promote_precision(may_sales#153) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#500 AS may_sales_per_sq_foot#536, sum(CheckOverflow((promote_precision(jun_sales#154) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#501 AS jun_sales_per_sq_foot#537, sum(CheckOverflow((promote_precision(jul_sales#155) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#502 AS jul_sales_per_sq_foot#538, sum(CheckOverflow((promote_precision(aug_sales#156) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#503 AS aug_sales_per_sq_foot#539, sum(CheckOverflow((promote_precision(sep_sales#157) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#504 AS sep_sales_per_sq_foot#540, sum(CheckOverflow((promote_precision(oct_sales#158) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#505 AS oct_sales_per_sq_foot#541, sum(CheckOverflow((promote_precision(nov_sales#159) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), DecimalType(38,12)))#506 AS nov_sales_per_sq_foot#542, sum(CheckOverflow((promote_precision(dec_sales#160) / promote_precision(cast(w_warehouse_sq_ft#11 as decimal(28,2)))), 
DecimalType(38,12)))#507 AS dec_sales_per_sq_foot#543, sum(jan_net#161)#508 AS jan_net#544, sum(feb_net#162)#509 AS feb_net#545, sum(mar_net#163)#510 AS mar_net#546, sum(apr_net#164)#511 AS apr_net#547, sum(may_net#165)#512 AS may_net#548, sum(jun_net#166)#513 AS jun_net#549, sum(jul_net#167)#514 AS jul_net#550, sum(aug_net#168)#515 AS aug_net#551, sum(sep_net#169)#516 AS sep_net#552, sum(oct_net#170)#517 AS oct_net#553, sum(nov_net#171)#518 AS nov_net#554, sum(dec_net#172)#519 AS dec_net#555] (52) TakeOrderedAndProject Input [44]: [w_warehouse_name#10, w_warehouse_sq_ft#11, w_city#12, w_county#13, w_state#14, w_country#15, ship_carriers#147, year#148, jan_sales#520, feb_sales#521, mar_sales#522, apr_sales#523, may_sales#524, jun_sales#525, jul_sales#526, aug_sales#527, sep_sales#528, oct_sales#529, nov_sales#530, dec_sales#531, jan_sales_per_sq_foot#532, feb_sales_per_sq_foot#533, mar_sales_per_sq_foot#534, apr_sales_per_sq_foot#535, may_sales_per_sq_foot#536, jun_sales_per_sq_foot#537, jul_sales_per_sq_foot#538, aug_sales_per_sq_foot#539, sep_sales_per_sq_foot#540, oct_sales_per_sq_foot#541, nov_sales_per_sq_foot#542, dec_sales_per_sq_foot#543, jan_net#544, feb_net#545, mar_net#546, apr_net#547, may_net#548, jun_net#549, jul_net#550, aug_net#551, sep_net#552, oct_net#553, nov_net#554, dec_net#555] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66/simplified.txt index 46e0418b4fabe..17037cfe02c2a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q66/simplified.txt @@ -1,6 +1,6 @@ TakeOrderedAndProject [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,ship_carriers,year,jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_sales_per_sq_foot,feb_sales_per_sq_foot,mar_sales_per_sq_foot,apr_sales_per_sq_foot,may_sales_per_sq_foot,jun_sales_per_sq_foot,jul_sales_per_sq_foot,aug_sales_per_sq_foot,sep_sales_per_sq_foot,oct_sales_per_sq_foot,nov_sales_per_sq_foot,dec_sales_per_sq_foot,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net] WholeStageCodegen (14) - HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,ship_carriers,year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(jan_sales),sum(feb_sales),sum(mar_sales),sum(apr_sales),sum(may_sales),sum(jun_sales),sum(jul_sales),sum(aug_sales),sum(sep_sales),sum(oct_sales),sum(nov_sales),sum(dec_sales),sum(CheckOverflow((promote_precision(jan_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(feb_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(mar_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), 
DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(apr_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(may_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(jun_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(jul_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(aug_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(sep_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(oct_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(nov_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(CheckOverflow((promote_precision(dec_sales) / promote_precision(cast(cast(w_warehouse_sq_ft as decimal(10,0)) as decimal(28,2)))), DecimalType(38,12), true)),sum(jan_net),sum(feb_net),sum(mar_net),sum(apr_net),sum(may_net),sum(jun_net),sum(jul_net),sum(aug_net),sum(sep_net),sum(oct_net),sum(nov_net),sum(dec_net),jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_sales_per_sq_foot,feb_sales_per_sq_foot,mar_sales_per_sq_foot,apr_sales_per_sq_foot,may_sales_per_sq_foot,jun_sales_per_sq_foot,jul_sales_per_sq_foot,aug_sales_per_sq_foot,sep_sales_per_sq_foot,oct_sales_per_sq_foot,nov_sales_per_sq_foot,dec_sales_per_sq_foot,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] + HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,ship_carriers,year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(jan_sales),sum(feb_sales),sum(mar_sales),sum(apr_sales),sum(may_sales),sum(jun_sales),sum(jul_sales),sum(aug_sales),sum(sep_sales),sum(oct_sales),sum(nov_sales),sum(dec_sales),sum(CheckOverflow((promote_precision(jan_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(feb_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), 
DecimalType(38,12))),sum(CheckOverflow((promote_precision(mar_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(apr_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(may_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(jun_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(jul_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(aug_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(sep_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(oct_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(nov_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(CheckOverflow((promote_precision(dec_sales) / promote_precision(cast(w_warehouse_sq_ft as decimal(28,2)))), DecimalType(38,12))),sum(jan_net),sum(feb_net),sum(mar_net),sum(apr_net),sum(may_net),sum(jun_net),sum(jul_net),sum(aug_net),sum(sep_net),sum(oct_net),sum(nov_net),sum(dec_net),jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_sales_per_sq_foot,feb_sales_per_sq_foot,mar_sales_per_sq_foot,apr_sales_per_sq_foot,may_sales_per_sq_foot,jun_sales_per_sq_foot,jul_sales_per_sq_foot,aug_sales_per_sq_foot,sep_sales_per_sq_foot,oct_sales_per_sq_foot,nov_sales_per_sq_foot,dec_sales_per_sq_foot,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,ship_carriers,year] #1 WholeStageCodegen (13) @@ -8,7 +8,7 @@ TakeOrderedAndProject [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_stat InputAdapter Union WholeStageCodegen (6) - HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN 
(d_moy = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), 
DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),ship_carriers,year,jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] + HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN 
CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(ws_ext_sales_price as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(ws_net_paid as decimal(12,2))) * promote_precision(cast(ws_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 
END),ship_carriers,year,jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year] #2 WholeStageCodegen (5) @@ -58,7 +58,7 @@ TakeOrderedAndProject [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_stat InputAdapter Scan parquet default.ship_mode [sm_ship_mode_sk,sm_carrier] WholeStageCodegen (12) - HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN 
CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true) ELSE 0.00 
END),ship_carriers,year,jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] + HashAggregate [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] [sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(cs_sales_price as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 1) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 2) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as 
decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 3) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 4) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 5) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 6) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 7) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 8) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 9) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 10) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 11) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),sum(CASE WHEN (d_moy = 12) THEN CheckOverflow((promote_precision(cast(cs_net_paid_inc_tax as decimal(12,2))) * promote_precision(cast(cs_quantity as decimal(12,2)))), DecimalType(18,2)) ELSE 0.00 END),ship_carriers,year,jan_sales,feb_sales,mar_sales,apr_sales,may_sales,jun_sales,jul_sales,aug_sales,sep_sales,oct_sales,nov_sales,dec_sales,jan_net,feb_net,mar_net,apr_net,may_net,jun_net,jul_net,aug_net,sep_net,oct_net,nov_net,dec_net,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [w_warehouse_name,w_warehouse_sq_ft,w_city,w_county,w_state,w_country,d_year] #7 WholeStageCodegen (11) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67.sf100/explain.txt index d74fb5b4bfb61..5a6c73dbe6a98 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67.sf100/explain.txt @@ -131,7 +131,7 @@ Arguments: [[ss_quantity#3, ss_sales_price#4, i_category#18, i_class#17, i_brand (23) HashAggregate [codegen id : 7] Input [11]: [ss_quantity#3, ss_sales_price#4, i_category#21, i_class#22, i_brand#23, i_product_name#24, d_year#25, d_qoy#26, d_moy#27, s_store_id#28, spark_grouping_id#29] Keys [9]: [i_category#21, i_class#22, i_brand#23, 
i_product_name#24, d_year#25, d_qoy#26, d_moy#27, s_store_id#28, spark_grouping_id#29] -Functions [1]: [partial_sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] +Functions [1]: [partial_sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] Aggregate Attributes [2]: [sum#30, isEmpty#31] Results [11]: [i_category#21, i_class#22, i_brand#23, i_product_name#24, d_year#25, d_qoy#26, d_moy#27, s_store_id#28, spark_grouping_id#29, sum#32, isEmpty#33] @@ -142,9 +142,9 @@ Arguments: hashpartitioning(i_category#21, i_class#22, i_brand#23, i_product_nam (25) HashAggregate [codegen id : 8] Input [11]: [i_category#21, i_class#22, i_brand#23, i_product_name#24, d_year#25, d_qoy#26, d_moy#27, s_store_id#28, spark_grouping_id#29, sum#32, isEmpty#33] Keys [9]: [i_category#21, i_class#22, i_brand#23, i_product_name#24, d_year#25, d_qoy#26, d_moy#27, s_store_id#28, spark_grouping_id#29] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#35] -Results [9]: [i_category#21, i_class#22, i_brand#23, i_product_name#24, d_year#25, d_qoy#26, d_moy#27, s_store_id#28, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#35 AS sumsales#36] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#35] +Results [9]: [i_category#21, i_class#22, i_brand#23, i_product_name#24, d_year#25, d_qoy#26, d_moy#27, s_store_id#28, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#35 AS sumsales#36] (26) Exchange Input [9]: [i_category#21, i_class#22, i_brand#23, i_product_name#24, d_year#25, d_qoy#26, d_moy#27, s_store_id#28, sumsales#36] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67.sf100/simplified.txt index e6c26f61d2832..55953a73ff11d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67.sf100/simplified.txt @@ -8,7 +8,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ InputAdapter Exchange [i_category] #1 WholeStageCodegen (8) - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,spark_grouping_id,sum,isEmpty] 
[sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,spark_grouping_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,spark_grouping_id] #2 WholeStageCodegen (7) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67/explain.txt index a9efff6eba561..53f71a188fcb5 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67/explain.txt @@ -116,7 +116,7 @@ Arguments: [[ss_quantity#3, ss_sales_price#4, i_category#17, i_class#16, i_brand (20) HashAggregate [codegen id : 4] Input [11]: [ss_quantity#3, ss_sales_price#4, i_category#20, i_class#21, i_brand#22, i_product_name#23, d_year#24, d_qoy#25, d_moy#26, s_store_id#27, spark_grouping_id#28] Keys [9]: [i_category#20, i_class#21, i_brand#22, i_product_name#23, d_year#24, d_qoy#25, d_moy#26, s_store_id#27, spark_grouping_id#28] -Functions [1]: [partial_sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] +Functions [1]: [partial_sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] Aggregate Attributes [2]: [sum#29, isEmpty#30] Results [11]: [i_category#20, i_class#21, i_brand#22, i_product_name#23, d_year#24, d_qoy#25, d_moy#26, s_store_id#27, spark_grouping_id#28, sum#31, isEmpty#32] @@ -127,9 +127,9 @@ Arguments: hashpartitioning(i_category#20, i_class#21, i_brand#22, i_product_nam (22) HashAggregate [codegen id : 5] Input [11]: [i_category#20, i_class#21, i_brand#22, i_product_name#23, d_year#24, d_qoy#25, d_moy#26, s_store_id#27, spark_grouping_id#28, sum#31, isEmpty#32] Keys [9]: [i_category#20, i_class#21, i_brand#22, i_product_name#23, d_year#24, d_qoy#25, d_moy#26, s_store_id#27, spark_grouping_id#28] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#34] -Results [9]: [i_category#20, i_class#21, i_brand#22, i_product_name#23, d_year#24, d_qoy#25, d_moy#26, s_store_id#27, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#34 AS sumsales#35] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as 
decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#34] +Results [9]: [i_category#20, i_class#21, i_brand#22, i_product_name#23, d_year#24, d_qoy#25, d_moy#26, s_store_id#27, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#34 AS sumsales#35] (23) Exchange Input [9]: [i_category#20, i_class#21, i_brand#22, i_product_name#23, d_year#24, d_qoy#25, d_moy#26, s_store_id#27, sumsales#35] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67/simplified.txt index 5b7d1595c0398..3cb879f7019b5 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q67/simplified.txt @@ -8,7 +8,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ InputAdapter Exchange [i_category] #1 WholeStageCodegen (5) - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,spark_grouping_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,spark_grouping_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,spark_grouping_id] #2 WholeStageCodegen (4) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72.sf100/explain.txt index d4ecd7a94c66a..c6971f3ea904b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72.sf100/explain.txt @@ -1,72 +1,74 @@ == Physical Plan == -TakeOrderedAndProject (68) -+- * HashAggregate (67) - +- Exchange (66) - +- * HashAggregate (65) - +- * Project (64) - +- * SortMergeJoin LeftOuter (63) - :- * Sort (56) - : +- * Project (55) - : +- * BroadcastHashJoin LeftOuter BuildRight (54) - : :- * Project (49) - : : +- * SortMergeJoin Inner (48) - : : :- * Sort (36) - : : : +- * Project (35) - : : : +- * BroadcastHashJoin Inner BuildRight (34) - : : : :- * Project (32) - : : : : +- * SortMergeJoin Inner (31) - : : : : :- * Sort (25) - : : : : : +- Exchange (24) - : : : : : +- * Project (23) - : : : : : +- * BroadcastHashJoin Inner BuildRight (22) - : : : : : :- * Project (17) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (16) - : : : : : : :- * Project (10) - : : : : : : : +- * BroadcastHashJoin Inner BuildRight (9) - : : : : : : : :- * Filter (3) - : : : : : : : : +- * ColumnarToRow (2) - : : : : : : : : +- Scan parquet default.catalog_sales (1) - : : : : : : : +- BroadcastExchange (8) - : : : : : : : +- * Project (7) - : : 
: : : : : +- * Filter (6) - : : : : : : : +- * ColumnarToRow (5) - : : : : : : : +- Scan parquet default.household_demographics (4) - : : : : : : +- BroadcastExchange (15) - : : : : : : +- * Project (14) - : : : : : : +- * Filter (13) - : : : : : : +- * ColumnarToRow (12) - : : : : : : +- Scan parquet default.customer_demographics (11) - : : : : : +- BroadcastExchange (21) - : : : : : +- * Filter (20) - : : : : : +- * ColumnarToRow (19) - : : : : : +- Scan parquet default.date_dim (18) - : : : : +- * Sort (30) - : : : : +- Exchange (29) - : : : : +- * Filter (28) - : : : : +- * ColumnarToRow (27) - : : : : +- Scan parquet default.item (26) - : : : +- ReusedExchange (33) - : : +- * Sort (47) - : : +- Exchange (46) - : : +- * Project (45) - : : +- * BroadcastHashJoin Inner BuildRight (44) - : : :- * Filter (39) - : : : +- * ColumnarToRow (38) - : : : +- Scan parquet default.inventory (37) - : : +- BroadcastExchange (43) - : : +- * Filter (42) - : : +- * ColumnarToRow (41) - : : +- Scan parquet default.warehouse (40) - : +- BroadcastExchange (53) - : +- * Filter (52) - : +- * ColumnarToRow (51) - : +- Scan parquet default.promotion (50) - +- * Sort (62) - +- Exchange (61) - +- * Project (60) - +- * Filter (59) - +- * ColumnarToRow (58) - +- Scan parquet default.catalog_returns (57) +TakeOrderedAndProject (70) ++- * HashAggregate (69) + +- Exchange (68) + +- * HashAggregate (67) + +- * Project (66) + +- * SortMergeJoin LeftOuter (65) + :- * Sort (58) + : +- Exchange (57) + : +- * Project (56) + : +- * BroadcastHashJoin LeftOuter BuildRight (55) + : :- * Project (50) + : : +- * SortMergeJoin Inner (49) + : : :- * Sort (37) + : : : +- Exchange (36) + : : : +- * Project (35) + : : : +- * BroadcastHashJoin Inner BuildRight (34) + : : : :- * Project (32) + : : : : +- * SortMergeJoin Inner (31) + : : : : :- * Sort (25) + : : : : : +- Exchange (24) + : : : : : +- * Project (23) + : : : : : +- * BroadcastHashJoin Inner BuildRight (22) + : : : : : :- * Project (17) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (16) + : : : : : : :- * Project (10) + : : : : : : : +- * BroadcastHashJoin Inner BuildRight (9) + : : : : : : : :- * Filter (3) + : : : : : : : : +- * ColumnarToRow (2) + : : : : : : : : +- Scan parquet default.catalog_sales (1) + : : : : : : : +- BroadcastExchange (8) + : : : : : : : +- * Project (7) + : : : : : : : +- * Filter (6) + : : : : : : : +- * ColumnarToRow (5) + : : : : : : : +- Scan parquet default.household_demographics (4) + : : : : : : +- BroadcastExchange (15) + : : : : : : +- * Project (14) + : : : : : : +- * Filter (13) + : : : : : : +- * ColumnarToRow (12) + : : : : : : +- Scan parquet default.customer_demographics (11) + : : : : : +- BroadcastExchange (21) + : : : : : +- * Filter (20) + : : : : : +- * ColumnarToRow (19) + : : : : : +- Scan parquet default.date_dim (18) + : : : : +- * Sort (30) + : : : : +- Exchange (29) + : : : : +- * Filter (28) + : : : : +- * ColumnarToRow (27) + : : : : +- Scan parquet default.item (26) + : : : +- ReusedExchange (33) + : : +- * Sort (48) + : : +- Exchange (47) + : : +- * Project (46) + : : +- * BroadcastHashJoin Inner BuildRight (45) + : : :- * Filter (40) + : : : +- * ColumnarToRow (39) + : : : +- Scan parquet default.inventory (38) + : : +- BroadcastExchange (44) + : : +- * Filter (43) + : : +- * ColumnarToRow (42) + : : +- Scan parquet default.warehouse (41) + : +- BroadcastExchange (54) + : +- * Filter (53) + : +- * ColumnarToRow (52) + : +- Scan parquet default.promotion (51) + +- * Sort (64) + +- Exchange (63) + +- * Project 
(62) + +- * Filter (61) + +- * ColumnarToRow (60) + +- Scan parquet default.catalog_returns (59) (1) Scan parquet default.catalog_sales @@ -212,7 +214,7 @@ Join condition: None Output [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, cs_sold_date_sk#8, d_date#17, i_item_desc#21] Input [8]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, cs_sold_date_sk#8, d_date#17, i_item_sk#20, i_item_desc#21] -(33) ReusedExchange [Reuses operator id: 79] +(33) ReusedExchange [Reuses operator id: 81] Output [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26] (34) BroadcastHashJoin [codegen id : 10] @@ -224,220 +226,228 @@ Join condition: (d_date#17 > date_add(d_date#24, 5)) Output [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, i_item_desc#21, d_week_seq#25, d_date_sk#26] Input [11]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, cs_sold_date_sk#8, d_date#17, i_item_desc#21, d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26] -(36) Sort [codegen id : 10] +(36) Exchange +Input [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, i_item_desc#21, d_week_seq#25, d_date_sk#26] +Arguments: hashpartitioning(cs_item_sk#4, d_date_sk#26, 5), ENSURE_REQUIREMENTS, [id=#27] + +(37) Sort [codegen id : 11] Input [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, i_item_desc#21, d_week_seq#25, d_date_sk#26] Arguments: [cs_item_sk#4 ASC NULLS FIRST, d_date_sk#26 ASC NULLS FIRST], false, 0 -(37) Scan parquet default.inventory -Output [4]: [inv_item_sk#27, inv_warehouse_sk#28, inv_quantity_on_hand#29, inv_date_sk#30] +(38) Scan parquet default.inventory +Output [4]: [inv_item_sk#28, inv_warehouse_sk#29, inv_quantity_on_hand#30, inv_date_sk#31] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(inv_date_sk#30), dynamicpruningexpression(true)] +PartitionFilters: [isnotnull(inv_date_sk#31), dynamicpruningexpression(true)] PushedFilters: [IsNotNull(inv_quantity_on_hand), IsNotNull(inv_item_sk), IsNotNull(inv_warehouse_sk)] ReadSchema: struct -(38) ColumnarToRow [codegen id : 12] -Input [4]: [inv_item_sk#27, inv_warehouse_sk#28, inv_quantity_on_hand#29, inv_date_sk#30] +(39) ColumnarToRow [codegen id : 13] +Input [4]: [inv_item_sk#28, inv_warehouse_sk#29, inv_quantity_on_hand#30, inv_date_sk#31] -(39) Filter [codegen id : 12] -Input [4]: [inv_item_sk#27, inv_warehouse_sk#28, inv_quantity_on_hand#29, inv_date_sk#30] -Condition : ((isnotnull(inv_quantity_on_hand#29) AND isnotnull(inv_item_sk#27)) AND isnotnull(inv_warehouse_sk#28)) +(40) Filter [codegen id : 13] +Input [4]: [inv_item_sk#28, inv_warehouse_sk#29, inv_quantity_on_hand#30, inv_date_sk#31] +Condition : ((isnotnull(inv_quantity_on_hand#30) AND isnotnull(inv_item_sk#28)) AND isnotnull(inv_warehouse_sk#29)) -(40) Scan parquet default.warehouse -Output [2]: [w_warehouse_sk#31, w_warehouse_name#32] +(41) Scan parquet default.warehouse +Output [2]: [w_warehouse_sk#32, w_warehouse_name#33] Batched: true Location [not included in comparison]/{warehouse_dir}/warehouse] PushedFilters: [IsNotNull(w_warehouse_sk)] ReadSchema: struct -(41) ColumnarToRow [codegen id : 11] -Input [2]: [w_warehouse_sk#31, w_warehouse_name#32] +(42) ColumnarToRow [codegen id : 12] +Input [2]: [w_warehouse_sk#32, w_warehouse_name#33] -(42) Filter [codegen id : 11] -Input [2]: [w_warehouse_sk#31, w_warehouse_name#32] -Condition : isnotnull(w_warehouse_sk#31) +(43) Filter [codegen id : 12] +Input [2]: [w_warehouse_sk#32, w_warehouse_name#33] +Condition : 
isnotnull(w_warehouse_sk#32) -(43) BroadcastExchange -Input [2]: [w_warehouse_sk#31, w_warehouse_name#32] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#33] +(44) BroadcastExchange +Input [2]: [w_warehouse_sk#32, w_warehouse_name#33] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#34] -(44) BroadcastHashJoin [codegen id : 12] -Left keys [1]: [inv_warehouse_sk#28] -Right keys [1]: [w_warehouse_sk#31] +(45) BroadcastHashJoin [codegen id : 13] +Left keys [1]: [inv_warehouse_sk#29] +Right keys [1]: [w_warehouse_sk#32] Join condition: None -(45) Project [codegen id : 12] -Output [4]: [inv_item_sk#27, inv_quantity_on_hand#29, inv_date_sk#30, w_warehouse_name#32] -Input [6]: [inv_item_sk#27, inv_warehouse_sk#28, inv_quantity_on_hand#29, inv_date_sk#30, w_warehouse_sk#31, w_warehouse_name#32] +(46) Project [codegen id : 13] +Output [4]: [inv_item_sk#28, inv_quantity_on_hand#30, inv_date_sk#31, w_warehouse_name#33] +Input [6]: [inv_item_sk#28, inv_warehouse_sk#29, inv_quantity_on_hand#30, inv_date_sk#31, w_warehouse_sk#32, w_warehouse_name#33] -(46) Exchange -Input [4]: [inv_item_sk#27, inv_quantity_on_hand#29, inv_date_sk#30, w_warehouse_name#32] -Arguments: hashpartitioning(inv_item_sk#27, 5), ENSURE_REQUIREMENTS, [id=#34] +(47) Exchange +Input [4]: [inv_item_sk#28, inv_quantity_on_hand#30, inv_date_sk#31, w_warehouse_name#33] +Arguments: hashpartitioning(inv_item_sk#28, inv_date_sk#31, 5), ENSURE_REQUIREMENTS, [id=#35] -(47) Sort [codegen id : 13] -Input [4]: [inv_item_sk#27, inv_quantity_on_hand#29, inv_date_sk#30, w_warehouse_name#32] -Arguments: [inv_item_sk#27 ASC NULLS FIRST, inv_date_sk#30 ASC NULLS FIRST], false, 0 +(48) Sort [codegen id : 14] +Input [4]: [inv_item_sk#28, inv_quantity_on_hand#30, inv_date_sk#31, w_warehouse_name#33] +Arguments: [inv_item_sk#28 ASC NULLS FIRST, inv_date_sk#31 ASC NULLS FIRST], false, 0 -(48) SortMergeJoin [codegen id : 15] +(49) SortMergeJoin [codegen id : 16] Left keys [2]: [cs_item_sk#4, d_date_sk#26] -Right keys [2]: [inv_item_sk#27, inv_date_sk#30] -Join condition: (inv_quantity_on_hand#29 < cs_quantity#7) +Right keys [2]: [inv_item_sk#28, inv_date_sk#31] +Join condition: (inv_quantity_on_hand#30 < cs_quantity#7) -(49) Project [codegen id : 15] -Output [6]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, w_warehouse_name#32, i_item_desc#21, d_week_seq#25] -Input [11]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, i_item_desc#21, d_week_seq#25, d_date_sk#26, inv_item_sk#27, inv_quantity_on_hand#29, inv_date_sk#30, w_warehouse_name#32] +(50) Project [codegen id : 16] +Output [6]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25] +Input [11]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, i_item_desc#21, d_week_seq#25, d_date_sk#26, inv_item_sk#28, inv_quantity_on_hand#30, inv_date_sk#31, w_warehouse_name#33] -(50) Scan parquet default.promotion -Output [1]: [p_promo_sk#35] +(51) Scan parquet default.promotion +Output [1]: [p_promo_sk#36] Batched: true Location [not included in comparison]/{warehouse_dir}/promotion] PushedFilters: [IsNotNull(p_promo_sk)] ReadSchema: struct -(51) ColumnarToRow [codegen id : 14] -Input [1]: [p_promo_sk#35] +(52) ColumnarToRow [codegen id : 15] +Input [1]: [p_promo_sk#36] -(52) Filter [codegen id : 14] -Input [1]: [p_promo_sk#35] -Condition : isnotnull(p_promo_sk#35) +(53) Filter [codegen id : 15] +Input [1]: [p_promo_sk#36] +Condition : 
isnotnull(p_promo_sk#36) -(53) BroadcastExchange -Input [1]: [p_promo_sk#35] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#36] +(54) BroadcastExchange +Input [1]: [p_promo_sk#36] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#37] -(54) BroadcastHashJoin [codegen id : 15] +(55) BroadcastHashJoin [codegen id : 16] Left keys [1]: [cs_promo_sk#5] -Right keys [1]: [p_promo_sk#35] +Right keys [1]: [p_promo_sk#36] Join condition: None -(55) Project [codegen id : 15] -Output [5]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#32, i_item_desc#21, d_week_seq#25] -Input [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, w_warehouse_name#32, i_item_desc#21, d_week_seq#25, p_promo_sk#35] +(56) Project [codegen id : 16] +Output [5]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25] +Input [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25, p_promo_sk#36] + +(57) Exchange +Input [5]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25] +Arguments: hashpartitioning(cs_item_sk#4, cs_order_number#6, 5), ENSURE_REQUIREMENTS, [id=#38] -(56) Sort [codegen id : 15] -Input [5]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#32, i_item_desc#21, d_week_seq#25] +(58) Sort [codegen id : 17] +Input [5]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25] Arguments: [cs_item_sk#4 ASC NULLS FIRST, cs_order_number#6 ASC NULLS FIRST], false, 0 -(57) Scan parquet default.catalog_returns -Output [3]: [cr_item_sk#37, cr_order_number#38, cr_returned_date_sk#39] +(59) Scan parquet default.catalog_returns +Output [3]: [cr_item_sk#39, cr_order_number#40, cr_returned_date_sk#41] Batched: true Location [not included in comparison]/{warehouse_dir}/catalog_returns] PushedFilters: [IsNotNull(cr_item_sk), IsNotNull(cr_order_number)] ReadSchema: struct -(58) ColumnarToRow [codegen id : 16] -Input [3]: [cr_item_sk#37, cr_order_number#38, cr_returned_date_sk#39] +(60) ColumnarToRow [codegen id : 18] +Input [3]: [cr_item_sk#39, cr_order_number#40, cr_returned_date_sk#41] -(59) Filter [codegen id : 16] -Input [3]: [cr_item_sk#37, cr_order_number#38, cr_returned_date_sk#39] -Condition : (isnotnull(cr_item_sk#37) AND isnotnull(cr_order_number#38)) +(61) Filter [codegen id : 18] +Input [3]: [cr_item_sk#39, cr_order_number#40, cr_returned_date_sk#41] +Condition : (isnotnull(cr_item_sk#39) AND isnotnull(cr_order_number#40)) -(60) Project [codegen id : 16] -Output [2]: [cr_item_sk#37, cr_order_number#38] -Input [3]: [cr_item_sk#37, cr_order_number#38, cr_returned_date_sk#39] +(62) Project [codegen id : 18] +Output [2]: [cr_item_sk#39, cr_order_number#40] +Input [3]: [cr_item_sk#39, cr_order_number#40, cr_returned_date_sk#41] -(61) Exchange -Input [2]: [cr_item_sk#37, cr_order_number#38] -Arguments: hashpartitioning(cr_item_sk#37, 5), ENSURE_REQUIREMENTS, [id=#40] +(63) Exchange +Input [2]: [cr_item_sk#39, cr_order_number#40] +Arguments: hashpartitioning(cr_item_sk#39, cr_order_number#40, 5), ENSURE_REQUIREMENTS, [id=#42] -(62) Sort [codegen id : 17] -Input [2]: [cr_item_sk#37, cr_order_number#38] -Arguments: [cr_item_sk#37 ASC NULLS FIRST, cr_order_number#38 ASC NULLS FIRST], false, 0 +(64) Sort [codegen id : 19] +Input [2]: [cr_item_sk#39, cr_order_number#40] +Arguments: [cr_item_sk#39 ASC NULLS FIRST, cr_order_number#40 ASC NULLS FIRST], false, 0 -(63) SortMergeJoin [codegen id : 
18] +(65) SortMergeJoin [codegen id : 20] Left keys [2]: [cs_item_sk#4, cs_order_number#6] -Right keys [2]: [cr_item_sk#37, cr_order_number#38] +Right keys [2]: [cr_item_sk#39, cr_order_number#40] Join condition: None -(64) Project [codegen id : 18] -Output [3]: [w_warehouse_name#32, i_item_desc#21, d_week_seq#25] -Input [7]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#32, i_item_desc#21, d_week_seq#25, cr_item_sk#37, cr_order_number#38] +(66) Project [codegen id : 20] +Output [3]: [w_warehouse_name#33, i_item_desc#21, d_week_seq#25] +Input [7]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25, cr_item_sk#39, cr_order_number#40] -(65) HashAggregate [codegen id : 18] -Input [3]: [w_warehouse_name#32, i_item_desc#21, d_week_seq#25] -Keys [3]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25] +(67) HashAggregate [codegen id : 20] +Input [3]: [w_warehouse_name#33, i_item_desc#21, d_week_seq#25] +Keys [3]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25] Functions [1]: [partial_count(1)] -Aggregate Attributes [1]: [count#41] -Results [4]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, count#42] +Aggregate Attributes [1]: [count#43] +Results [4]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, count#44] -(66) Exchange -Input [4]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, count#42] -Arguments: hashpartitioning(i_item_desc#21, w_warehouse_name#32, d_week_seq#25, 5), ENSURE_REQUIREMENTS, [id=#43] +(68) Exchange +Input [4]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, count#44] +Arguments: hashpartitioning(i_item_desc#21, w_warehouse_name#33, d_week_seq#25, 5), ENSURE_REQUIREMENTS, [id=#45] -(67) HashAggregate [codegen id : 19] -Input [4]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, count#42] -Keys [3]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25] +(69) HashAggregate [codegen id : 21] +Input [4]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, count#44] +Keys [3]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25] Functions [1]: [count(1)] -Aggregate Attributes [1]: [count(1)#44] -Results [6]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, count(1)#44 AS no_promo#45, count(1)#44 AS promo#46, count(1)#44 AS total_cnt#47] +Aggregate Attributes [1]: [count(1)#46] +Results [6]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, count(1)#46 AS no_promo#47, count(1)#46 AS promo#48, count(1)#46 AS total_cnt#49] -(68) TakeOrderedAndProject -Input [6]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, no_promo#45, promo#46, total_cnt#47] -Arguments: 100, [total_cnt#47 DESC NULLS LAST, i_item_desc#21 ASC NULLS FIRST, w_warehouse_name#32 ASC NULLS FIRST, d_week_seq#25 ASC NULLS FIRST], [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, no_promo#45, promo#46, total_cnt#47] +(70) TakeOrderedAndProject +Input [6]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, no_promo#47, promo#48, total_cnt#49] +Arguments: 100, [total_cnt#49 DESC NULLS LAST, i_item_desc#21 ASC NULLS FIRST, w_warehouse_name#33 ASC NULLS FIRST, d_week_seq#25 ASC NULLS FIRST], [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, no_promo#47, promo#48, total_cnt#49] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = cs_sold_date_sk#8 IN dynamicpruning#9 -BroadcastExchange (79) -+- * Project (78) - +- * BroadcastHashJoin Inner BuildLeft (77) - :- BroadcastExchange (73) - : +- * Project (72) - : +- * Filter (71) - : +- * ColumnarToRow (70) - : +- Scan parquet default.date_dim (69) - +- * Filter (76) - 
+- * ColumnarToRow (75) - +- Scan parquet default.date_dim (74) - - -(69) Scan parquet default.date_dim -Output [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#48] +BroadcastExchange (81) ++- * Project (80) + +- * BroadcastHashJoin Inner BuildLeft (79) + :- BroadcastExchange (75) + : +- * Project (74) + : +- * Filter (73) + : +- * ColumnarToRow (72) + : +- Scan parquet default.date_dim (71) + +- * Filter (78) + +- * ColumnarToRow (77) + +- Scan parquet default.date_dim (76) + + +(71) Scan parquet default.date_dim +Output [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#50] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), EqualTo(d_year,1999), IsNotNull(d_date_sk), IsNotNull(d_week_seq), IsNotNull(d_date)] ReadSchema: struct -(70) ColumnarToRow [codegen id : 1] -Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#48] +(72) ColumnarToRow [codegen id : 1] +Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#50] -(71) Filter [codegen id : 1] -Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#48] -Condition : ((((isnotnull(d_year#48) AND (d_year#48 = 1999)) AND isnotnull(d_date_sk#23)) AND isnotnull(d_week_seq#25)) AND isnotnull(d_date#24)) +(73) Filter [codegen id : 1] +Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#50] +Condition : ((((isnotnull(d_year#50) AND (d_year#50 = 1999)) AND isnotnull(d_date_sk#23)) AND isnotnull(d_week_seq#25)) AND isnotnull(d_date#24)) -(72) Project [codegen id : 1] +(74) Project [codegen id : 1] Output [3]: [d_date_sk#23, d_date#24, d_week_seq#25] -Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#48] +Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#50] -(73) BroadcastExchange +(75) BroadcastExchange Input [3]: [d_date_sk#23, d_date#24, d_week_seq#25] -Arguments: HashedRelationBroadcastMode(List(cast(input[2, int, true] as bigint)),false), [id=#49] +Arguments: HashedRelationBroadcastMode(List(cast(input[2, int, true] as bigint)),false), [id=#51] -(74) Scan parquet default.date_dim -Output [2]: [d_date_sk#26, d_week_seq#50] +(76) Scan parquet default.date_dim +Output [2]: [d_date_sk#26, d_week_seq#52] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] ReadSchema: struct -(75) ColumnarToRow -Input [2]: [d_date_sk#26, d_week_seq#50] +(77) ColumnarToRow +Input [2]: [d_date_sk#26, d_week_seq#52] -(76) Filter -Input [2]: [d_date_sk#26, d_week_seq#50] -Condition : (isnotnull(d_week_seq#50) AND isnotnull(d_date_sk#26)) +(78) Filter +Input [2]: [d_date_sk#26, d_week_seq#52] +Condition : (isnotnull(d_week_seq#52) AND isnotnull(d_date_sk#26)) -(77) BroadcastHashJoin [codegen id : 2] +(79) BroadcastHashJoin [codegen id : 2] Left keys [1]: [d_week_seq#25] -Right keys [1]: [d_week_seq#50] +Right keys [1]: [d_week_seq#52] Join condition: None -(78) Project [codegen id : 2] +(80) Project [codegen id : 2] Output [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26] -Input [5]: [d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26, d_week_seq#50] +Input [5]: [d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26, d_week_seq#52] -(79) BroadcastExchange +(81) BroadcastExchange Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#51] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#53] diff --git 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72.sf100/simplified.txt index d84393b2ff106..e838025a71db8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72.sf100/simplified.txt @@ -1,126 +1,132 @@ TakeOrderedAndProject [total_cnt,i_item_desc,w_warehouse_name,d_week_seq,no_promo,promo] - WholeStageCodegen (19) + WholeStageCodegen (21) HashAggregate [i_item_desc,w_warehouse_name,d_week_seq,count] [count(1),no_promo,promo,total_cnt,count] InputAdapter Exchange [i_item_desc,w_warehouse_name,d_week_seq] #1 - WholeStageCodegen (18) + WholeStageCodegen (20) HashAggregate [i_item_desc,w_warehouse_name,d_week_seq] [count,count] Project [w_warehouse_name,i_item_desc,d_week_seq] SortMergeJoin [cs_item_sk,cs_order_number,cr_item_sk,cr_order_number] InputAdapter - WholeStageCodegen (15) + WholeStageCodegen (17) Sort [cs_item_sk,cs_order_number] - Project [cs_item_sk,cs_order_number,w_warehouse_name,i_item_desc,d_week_seq] - BroadcastHashJoin [cs_promo_sk,p_promo_sk] - Project [cs_item_sk,cs_promo_sk,cs_order_number,w_warehouse_name,i_item_desc,d_week_seq] - SortMergeJoin [cs_item_sk,d_date_sk,inv_item_sk,inv_date_sk,inv_quantity_on_hand,cs_quantity] - InputAdapter - WholeStageCodegen (10) - Sort [cs_item_sk,d_date_sk] - Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,i_item_desc,d_week_seq,d_date_sk] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk,d_date,d_date] - Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk,d_date,i_item_desc] - SortMergeJoin [cs_item_sk,i_item_sk] - InputAdapter - WholeStageCodegen (5) - Sort [cs_item_sk] - InputAdapter - Exchange [cs_item_sk] #2 - WholeStageCodegen (4) - Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk,d_date] - BroadcastHashJoin [cs_ship_date_sk,d_date_sk] - Project [cs_ship_date_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] - BroadcastHashJoin [cs_bill_cdemo_sk,cd_demo_sk] - Project [cs_ship_date_sk,cs_bill_cdemo_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] - BroadcastHashJoin [cs_bill_hdemo_sk,hd_demo_sk] - Filter [cs_quantity,cs_item_sk,cs_bill_cdemo_sk,cs_bill_hdemo_sk,cs_ship_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_ship_date_sk,cs_bill_cdemo_sk,cs_bill_hdemo_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #3 - WholeStageCodegen (2) - Project [d_date_sk,d_date,d_week_seq,d_date_sk] - BroadcastHashJoin [d_week_seq,d_week_seq] - InputAdapter - BroadcastExchange #4 - WholeStageCodegen (1) - Project [d_date_sk,d_date,d_week_seq] - Filter [d_year,d_date_sk,d_week_seq,d_date] - ColumnarToRow + InputAdapter + Exchange [cs_item_sk,cs_order_number] #2 + WholeStageCodegen (16) + Project [cs_item_sk,cs_order_number,w_warehouse_name,i_item_desc,d_week_seq] + BroadcastHashJoin [cs_promo_sk,p_promo_sk] + Project [cs_item_sk,cs_promo_sk,cs_order_number,w_warehouse_name,i_item_desc,d_week_seq] + SortMergeJoin [cs_item_sk,d_date_sk,inv_item_sk,inv_date_sk,inv_quantity_on_hand,cs_quantity] + InputAdapter + WholeStageCodegen (11) + Sort [cs_item_sk,d_date_sk] + InputAdapter + Exchange [cs_item_sk,d_date_sk] #3 + WholeStageCodegen (10) + Project 
[cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,i_item_desc,d_week_seq,d_date_sk] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk,d_date,d_date] + Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk,d_date,i_item_desc] + SortMergeJoin [cs_item_sk,i_item_sk] + InputAdapter + WholeStageCodegen (5) + Sort [cs_item_sk] + InputAdapter + Exchange [cs_item_sk] #4 + WholeStageCodegen (4) + Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk,d_date] + BroadcastHashJoin [cs_ship_date_sk,d_date_sk] + Project [cs_ship_date_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] + BroadcastHashJoin [cs_bill_cdemo_sk,cd_demo_sk] + Project [cs_ship_date_sk,cs_bill_cdemo_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] + BroadcastHashJoin [cs_bill_hdemo_sk,hd_demo_sk] + Filter [cs_quantity,cs_item_sk,cs_bill_cdemo_sk,cs_bill_hdemo_sk,cs_ship_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_ship_date_sk,cs_bill_cdemo_sk,cs_bill_hdemo_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (2) + Project [d_date_sk,d_date,d_week_seq,d_date_sk] + BroadcastHashJoin [d_week_seq,d_week_seq] InputAdapter - Scan parquet default.date_dim [d_date_sk,d_date,d_week_seq,d_year] - Filter [d_week_seq,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_week_seq] - InputAdapter - BroadcastExchange #5 - WholeStageCodegen (1) - Project [hd_demo_sk] - Filter [hd_buy_potential,hd_demo_sk] - ColumnarToRow + BroadcastExchange #6 + WholeStageCodegen (1) + Project [d_date_sk,d_date,d_week_seq] + Filter [d_year,d_date_sk,d_week_seq,d_date] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date,d_week_seq,d_year] + Filter [d_week_seq,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_week_seq] InputAdapter - Scan parquet default.household_demographics [hd_demo_sk,hd_buy_potential] - InputAdapter - BroadcastExchange #6 - WholeStageCodegen (2) - Project [cd_demo_sk] - Filter [cd_marital_status,cd_demo_sk] - ColumnarToRow + BroadcastExchange #7 + WholeStageCodegen (1) + Project [hd_demo_sk] + Filter [hd_buy_potential,hd_demo_sk] + ColumnarToRow + InputAdapter + Scan parquet default.household_demographics [hd_demo_sk,hd_buy_potential] InputAdapter - Scan parquet default.customer_demographics [cd_demo_sk,cd_marital_status] - InputAdapter - BroadcastExchange #7 - WholeStageCodegen (3) - Filter [d_date,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_date] - InputAdapter - WholeStageCodegen (7) - Sort [i_item_sk] - InputAdapter - Exchange [i_item_sk] #8 - WholeStageCodegen (6) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_item_desc] - InputAdapter - ReusedExchange [d_date_sk,d_date,d_week_seq,d_date_sk] #3 - InputAdapter - WholeStageCodegen (13) - Sort [inv_item_sk,inv_date_sk] + BroadcastExchange #8 + WholeStageCodegen (2) + Project [cd_demo_sk] + Filter [cd_marital_status,cd_demo_sk] + ColumnarToRow + InputAdapter + Scan parquet default.customer_demographics [cd_demo_sk,cd_marital_status] + InputAdapter + BroadcastExchange #9 + WholeStageCodegen (3) + Filter [d_date,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date] + InputAdapter + WholeStageCodegen (7) + Sort [i_item_sk] + InputAdapter + Exchange 
[i_item_sk] #10 + WholeStageCodegen (6) + Filter [i_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_item_desc] + InputAdapter + ReusedExchange [d_date_sk,d_date,d_week_seq,d_date_sk] #5 InputAdapter - Exchange [inv_item_sk] #9 - WholeStageCodegen (12) - Project [inv_item_sk,inv_quantity_on_hand,inv_date_sk,w_warehouse_name] - BroadcastHashJoin [inv_warehouse_sk,w_warehouse_sk] - Filter [inv_quantity_on_hand,inv_item_sk,inv_warehouse_sk] - ColumnarToRow - InputAdapter - Scan parquet default.inventory [inv_item_sk,inv_warehouse_sk,inv_quantity_on_hand,inv_date_sk] - InputAdapter - BroadcastExchange #10 - WholeStageCodegen (11) - Filter [w_warehouse_sk] + WholeStageCodegen (14) + Sort [inv_item_sk,inv_date_sk] + InputAdapter + Exchange [inv_item_sk,inv_date_sk] #11 + WholeStageCodegen (13) + Project [inv_item_sk,inv_quantity_on_hand,inv_date_sk,w_warehouse_name] + BroadcastHashJoin [inv_warehouse_sk,w_warehouse_sk] + Filter [inv_quantity_on_hand,inv_item_sk,inv_warehouse_sk] ColumnarToRow InputAdapter - Scan parquet default.warehouse [w_warehouse_sk,w_warehouse_name] - InputAdapter - BroadcastExchange #11 - WholeStageCodegen (14) - Filter [p_promo_sk] - ColumnarToRow - InputAdapter - Scan parquet default.promotion [p_promo_sk] + Scan parquet default.inventory [inv_item_sk,inv_warehouse_sk,inv_quantity_on_hand,inv_date_sk] + InputAdapter + BroadcastExchange #12 + WholeStageCodegen (12) + Filter [w_warehouse_sk] + ColumnarToRow + InputAdapter + Scan parquet default.warehouse [w_warehouse_sk,w_warehouse_name] + InputAdapter + BroadcastExchange #13 + WholeStageCodegen (15) + Filter [p_promo_sk] + ColumnarToRow + InputAdapter + Scan parquet default.promotion [p_promo_sk] InputAdapter - WholeStageCodegen (17) + WholeStageCodegen (19) Sort [cr_item_sk,cr_order_number] InputAdapter - Exchange [cr_item_sk] #12 - WholeStageCodegen (16) + Exchange [cr_item_sk,cr_order_number] #14 + WholeStageCodegen (18) Project [cr_item_sk,cr_order_number] Filter [cr_item_sk,cr_order_number] ColumnarToRow diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt index a00880bad3116..04a0ca4cd3027 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt @@ -225,7 +225,7 @@ Right keys [1]: [s_store_sk#23] Join condition: None (30) Project [codegen id : 8] -Output [5]: [sales#16, coalesce(returns#31, 0.00) AS returns#34, CheckOverflow((promote_precision(cast(profit#17 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#32, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS profit#35, store channel AS channel#36, s_store_sk#7 AS id#37] +Output [5]: [sales#16, coalesce(returns#31, 0.00) AS returns#34, CheckOverflow((promote_precision(cast(profit#17 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#32, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS profit#35, store channel AS channel#36, s_store_sk#7 AS id#37] Input [6]: [s_store_sk#7, sales#16, profit#17, s_store_sk#23, returns#31, profit_loss#32] (31) Scan parquet default.catalog_sales @@ -316,7 +316,7 @@ Arguments: IdentityBroadcastMode, [id=#65] Join condition: None (49) Project [codegen id : 14] -Output [5]: [sales#50, returns#63, CheckOverflow((promote_precision(cast(profit#51 as decimal(18,2))) - 
promote_precision(cast(profit_loss#64 as decimal(18,2)))), DecimalType(18,2), true) AS profit#66, catalog channel AS channel#67, cs_call_center_sk#38 AS id#68] +Output [5]: [sales#50, returns#63, CheckOverflow((promote_precision(cast(profit#51 as decimal(18,2))) - promote_precision(cast(profit_loss#64 as decimal(18,2)))), DecimalType(18,2)) AS profit#66, catalog channel AS channel#67, cs_call_center_sk#38 AS id#68] Input [5]: [cs_call_center_sk#38, sales#50, profit#51, returns#63, profit_loss#64] (50) Scan parquet default.web_sales @@ -458,7 +458,7 @@ Right keys [1]: [wp_web_page_sk#90] Join condition: None (79) Project [codegen id : 22] -Output [5]: [sales#83, coalesce(returns#98, 0.00) AS returns#101, CheckOverflow((promote_precision(cast(profit#84 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#99, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS profit#102, web channel AS channel#103, wp_web_page_sk#74 AS id#104] +Output [5]: [sales#83, coalesce(returns#98, 0.00) AS returns#101, CheckOverflow((promote_precision(cast(profit#84 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#99, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS profit#102, web channel AS channel#103, wp_web_page_sk#74 AS id#104] Input [6]: [wp_web_page_sk#74, sales#83, profit#84, wp_web_page_sk#90, returns#98, profit_loss#99] (80) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt index 0d7bfa462ef4c..c3cd748f43775 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt @@ -225,7 +225,7 @@ Right keys [1]: [s_store_sk#23] Join condition: None (30) Project [codegen id : 8] -Output [5]: [sales#16, coalesce(returns#31, 0.00) AS returns#34, CheckOverflow((promote_precision(cast(profit#17 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#32, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS profit#35, store channel AS channel#36, s_store_sk#7 AS id#37] +Output [5]: [sales#16, coalesce(returns#31, 0.00) AS returns#34, CheckOverflow((promote_precision(cast(profit#17 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#32, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS profit#35, store channel AS channel#36, s_store_sk#7 AS id#37] Input [6]: [s_store_sk#7, sales#16, profit#17, s_store_sk#23, returns#31, profit_loss#32] (31) Scan parquet default.catalog_sales @@ -316,7 +316,7 @@ Results [2]: [MakeDecimal(sum(UnscaledValue(cr_return_amount#53))#62,17,2) AS re Join condition: None (49) Project [codegen id : 14] -Output [5]: [sales#50, returns#64, CheckOverflow((promote_precision(cast(profit#51 as decimal(18,2))) - promote_precision(cast(profit_loss#65 as decimal(18,2)))), DecimalType(18,2), true) AS profit#66, catalog channel AS channel#67, cs_call_center_sk#38 AS id#68] +Output [5]: [sales#50, returns#64, CheckOverflow((promote_precision(cast(profit#51 as decimal(18,2))) - promote_precision(cast(profit_loss#65 as decimal(18,2)))), DecimalType(18,2)) AS profit#66, catalog channel AS channel#67, cs_call_center_sk#38 AS id#68] Input [5]: [cs_call_center_sk#38, sales#50, profit#51, returns#64, profit_loss#65] (50) Scan parquet default.web_sales @@ -458,7 +458,7 @@ Right keys [1]: [wp_web_page_sk#90] Join condition: None (79) Project [codegen id : 22] -Output [5]: [sales#83, coalesce(returns#98, 
0.00) AS returns#101, CheckOverflow((promote_precision(cast(profit#84 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#99, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS profit#102, web channel AS channel#103, wp_web_page_sk#74 AS id#104] +Output [5]: [sales#83, coalesce(returns#98, 0.00) AS returns#101, CheckOverflow((promote_precision(cast(profit#84 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#99, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS profit#102, web channel AS channel#103, wp_web_page_sk#74 AS id#104] Input [6]: [wp_web_page_sk#74, sales#83, profit#84, wp_web_page_sk#90, returns#98, profit_loss#99] (80) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt index cfbaa2e8b48d2..9cc78e12028ff 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt @@ -270,7 +270,7 @@ Input [7]: [ss_store_sk#2, ss_ext_sales_price#5, ss_net_profit#6, sr_return_amt# (37) HashAggregate [codegen id : 9] Input [5]: [ss_ext_sales_price#5, ss_net_profit#6, sr_return_amt#12, sr_net_loss#13, s_store_id#24] Keys [1]: [s_store_id#24] -Functions [3]: [partial_sum(UnscaledValue(ss_ext_sales_price#5)), partial_sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(ss_ext_sales_price#5)), partial_sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#26, sum#27, isEmpty#28, sum#29, isEmpty#30] Results [6]: [s_store_id#24, sum#31, sum#32, isEmpty#33, sum#34, isEmpty#35] @@ -281,9 +281,9 @@ Arguments: hashpartitioning(s_store_id#24, 5), ENSURE_REQUIREMENTS, [id=#36] (39) HashAggregate [codegen id : 10] Input [6]: [s_store_id#24, sum#31, sum#32, isEmpty#33, sum#34, isEmpty#35] Keys [1]: [s_store_id#24] -Functions [3]: [sum(UnscaledValue(ss_ext_sales_price#5)), sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(ss_ext_sales_price#5))#37, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#39] -Results [5]: [MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#5))#37,17,2) AS sales#40, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38 AS returns#41, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#39 AS profit#42, store channel AS channel#43, concat(store, 
s_store_id#24) AS id#44] +Functions [3]: [sum(UnscaledValue(ss_ext_sales_price#5)), sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(ss_ext_sales_price#5))#37, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#39] +Results [5]: [MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#5))#37,17,2) AS sales#40, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38 AS returns#41, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#39 AS profit#42, store channel AS channel#43, concat(store, s_store_id#24) AS id#44] (40) Scan parquet default.catalog_sales Output [7]: [cs_catalog_page_sk#45, cs_item_sk#46, cs_promo_sk#47, cs_order_number#48, cs_ext_sales_price#49, cs_net_profit#50, cs_sold_date_sk#51] @@ -409,7 +409,7 @@ Input [7]: [cs_catalog_page_sk#45, cs_ext_sales_price#49, cs_net_profit#50, cr_r (68) HashAggregate [codegen id : 19] Input [5]: [cs_ext_sales_price#49, cs_net_profit#50, cr_return_amount#55, cr_net_loss#56, cp_catalog_page_id#63] Keys [1]: [cp_catalog_page_id#63] -Functions [3]: [partial_sum(UnscaledValue(cs_ext_sales_price#49)), partial_sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(cs_ext_sales_price#49)), partial_sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#65, sum#66, isEmpty#67, sum#68, isEmpty#69] Results [6]: [cp_catalog_page_id#63, sum#70, sum#71, isEmpty#72, sum#73, isEmpty#74] @@ -420,9 +420,9 @@ Arguments: hashpartitioning(cp_catalog_page_id#63, 5), ENSURE_REQUIREMENTS, [id= (70) HashAggregate [codegen id : 20] Input [6]: [cp_catalog_page_id#63, sum#70, sum#71, isEmpty#72, sum#73, isEmpty#74] Keys [1]: [cp_catalog_page_id#63] -Functions [3]: [sum(UnscaledValue(cs_ext_sales_price#49)), sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(cs_ext_sales_price#49))#76, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#78] -Results [5]: [MakeDecimal(sum(UnscaledValue(cs_ext_sales_price#49))#76,17,2) AS sales#79, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 
0.00))#77 AS returns#80, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#78 AS profit#81, catalog channel AS channel#82, concat(catalog_page, cp_catalog_page_id#63) AS id#83] +Functions [3]: [sum(UnscaledValue(cs_ext_sales_price#49)), sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(cs_ext_sales_price#49))#76, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#78] +Results [5]: [MakeDecimal(sum(UnscaledValue(cs_ext_sales_price#49))#76,17,2) AS sales#79, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77 AS returns#80, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#78 AS profit#81, catalog channel AS channel#82, concat(catalog_page, cp_catalog_page_id#63) AS id#83] (71) Scan parquet default.web_sales Output [7]: [ws_item_sk#84, ws_web_site_sk#85, ws_promo_sk#86, ws_order_number#87, ws_ext_sales_price#88, ws_net_profit#89, ws_sold_date_sk#90] @@ -548,7 +548,7 @@ Input [7]: [ws_web_site_sk#85, ws_ext_sales_price#88, ws_net_profit#89, wr_retur (99) HashAggregate [codegen id : 29] Input [5]: [ws_ext_sales_price#88, ws_net_profit#89, wr_return_amt#94, wr_net_loss#95, web_site_id#102] Keys [1]: [web_site_id#102] -Functions [3]: [partial_sum(UnscaledValue(ws_ext_sales_price#88)), partial_sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(ws_ext_sales_price#88)), partial_sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#104, sum#105, isEmpty#106, sum#107, isEmpty#108] Results [6]: [web_site_id#102, sum#109, sum#110, isEmpty#111, sum#112, isEmpty#113] @@ -559,9 +559,9 @@ Arguments: hashpartitioning(web_site_id#102, 5), ENSURE_REQUIREMENTS, [id=#114] (101) HashAggregate [codegen id : 30] Input [6]: [web_site_id#102, sum#109, sum#110, isEmpty#111, sum#112, isEmpty#113] Keys [1]: [web_site_id#102] -Functions [3]: [sum(UnscaledValue(ws_ext_sales_price#88)), sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(ws_ext_sales_price#88))#115, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116, 
sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#117] -Results [5]: [MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#88))#115,17,2) AS sales#118, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116 AS returns#119, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#117 AS profit#120, web channel AS channel#121, concat(web_site, web_site_id#102) AS id#122] +Functions [3]: [sum(UnscaledValue(ws_ext_sales_price#88)), sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(ws_ext_sales_price#88))#115, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#117] +Results [5]: [MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#88))#115,17,2) AS sales#118, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116 AS returns#119, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#117 AS profit#120, web channel AS channel#121, concat(web_site, web_site_id#102) AS id#122] (102) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/simplified.txt index b742daa007454..7de3dd817429d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/simplified.txt @@ -9,7 +9,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter Union WholeStageCodegen (10) - HashAggregate [s_store_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ss_ext_sales_price)),sum(coalesce(cast(sr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ss_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [s_store_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ss_ext_sales_price)),sum(coalesce(cast(sr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ss_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [s_store_id] #2 WholeStageCodegen (9) @@ -79,7 +79,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter Scan parquet default.store [s_store_sk,s_store_id] WholeStageCodegen (20) - HashAggregate [cp_catalog_page_id,sum,sum,isEmpty,sum,isEmpty] 
[sum(UnscaledValue(cs_ext_sales_price)),sum(coalesce(cast(cr_return_amount as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(cs_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [cp_catalog_page_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(cs_ext_sales_price)),sum(coalesce(cast(cr_return_amount as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(cs_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [cp_catalog_page_id] #9 WholeStageCodegen (19) @@ -130,7 +130,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter Scan parquet default.catalog_page [cp_catalog_page_sk,cp_catalog_page_id] WholeStageCodegen (30) - HashAggregate [web_site_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ws_ext_sales_price)),sum(coalesce(cast(wr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ws_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [web_site_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ws_ext_sales_price)),sum(coalesce(cast(wr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ws_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [web_site_id] #13 WholeStageCodegen (29) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt index c18e9a125335e..20cf55dba4482 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt @@ -270,7 +270,7 @@ Input [7]: [ss_promo_sk#3, ss_ext_sales_price#5, ss_net_profit#6, sr_return_amt# (37) HashAggregate [codegen id : 9] Input [5]: [ss_ext_sales_price#5, ss_net_profit#6, sr_return_amt#12, sr_net_loss#13, s_store_id#18] Keys [1]: [s_store_id#18] -Functions [3]: [partial_sum(UnscaledValue(ss_ext_sales_price#5)), partial_sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(ss_ext_sales_price#5)), partial_sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#26, sum#27, isEmpty#28, sum#29, isEmpty#30] Results [6]: [s_store_id#18, sum#31, sum#32, isEmpty#33, sum#34, isEmpty#35] @@ -281,9 +281,9 @@ Arguments: hashpartitioning(s_store_id#18, 5), ENSURE_REQUIREMENTS, [id=#36] (39) HashAggregate 
[codegen id : 10] Input [6]: [s_store_id#18, sum#31, sum#32, isEmpty#33, sum#34, isEmpty#35] Keys [1]: [s_store_id#18] -Functions [3]: [sum(UnscaledValue(ss_ext_sales_price#5)), sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(ss_ext_sales_price#5))#37, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#39] -Results [5]: [MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#5))#37,17,2) AS sales#40, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38 AS returns#41, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#39 AS profit#42, store channel AS channel#43, concat(store, s_store_id#18) AS id#44] +Functions [3]: [sum(UnscaledValue(ss_ext_sales_price#5)), sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(ss_ext_sales_price#5))#37, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#39] +Results [5]: [MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#5))#37,17,2) AS sales#40, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38 AS returns#41, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#39 AS profit#42, store channel AS channel#43, concat(store, s_store_id#18) AS id#44] (40) Scan parquet default.catalog_sales Output [7]: [cs_catalog_page_sk#45, cs_item_sk#46, cs_promo_sk#47, cs_order_number#48, cs_ext_sales_price#49, cs_net_profit#50, cs_sold_date_sk#51] @@ -409,7 +409,7 @@ Input [7]: [cs_promo_sk#47, cs_ext_sales_price#49, cs_net_profit#50, cr_return_a (68) HashAggregate [codegen id : 19] Input [5]: [cs_ext_sales_price#49, cs_net_profit#50, cr_return_amount#55, cr_net_loss#56, cp_catalog_page_id#61] Keys [1]: [cp_catalog_page_id#61] -Functions [3]: [partial_sum(UnscaledValue(cs_ext_sales_price#49)), partial_sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(cs_ext_sales_price#49)), partial_sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes 
[5]: [sum#65, sum#66, isEmpty#67, sum#68, isEmpty#69] Results [6]: [cp_catalog_page_id#61, sum#70, sum#71, isEmpty#72, sum#73, isEmpty#74] @@ -420,9 +420,9 @@ Arguments: hashpartitioning(cp_catalog_page_id#61, 5), ENSURE_REQUIREMENTS, [id= (70) HashAggregate [codegen id : 20] Input [6]: [cp_catalog_page_id#61, sum#70, sum#71, isEmpty#72, sum#73, isEmpty#74] Keys [1]: [cp_catalog_page_id#61] -Functions [3]: [sum(UnscaledValue(cs_ext_sales_price#49)), sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(cs_ext_sales_price#49))#76, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#78] -Results [5]: [MakeDecimal(sum(UnscaledValue(cs_ext_sales_price#49))#76,17,2) AS sales#79, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77 AS returns#80, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#78 AS profit#81, catalog channel AS channel#82, concat(catalog_page, cp_catalog_page_id#61) AS id#83] +Functions [3]: [sum(UnscaledValue(cs_ext_sales_price#49)), sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(cs_ext_sales_price#49))#76, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#78] +Results [5]: [MakeDecimal(sum(UnscaledValue(cs_ext_sales_price#49))#76,17,2) AS sales#79, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77 AS returns#80, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#78 AS profit#81, catalog channel AS channel#82, concat(catalog_page, cp_catalog_page_id#61) AS id#83] (71) Scan parquet default.web_sales Output [7]: [ws_item_sk#84, ws_web_site_sk#85, ws_promo_sk#86, ws_order_number#87, ws_ext_sales_price#88, ws_net_profit#89, ws_sold_date_sk#90] @@ -548,7 +548,7 @@ Input [7]: [ws_promo_sk#86, ws_ext_sales_price#88, ws_net_profit#89, wr_return_a (99) HashAggregate [codegen id : 29] Input [5]: [ws_ext_sales_price#88, ws_net_profit#89, wr_return_amt#94, wr_net_loss#95, web_site_id#100] Keys [1]: [web_site_id#100] -Functions [3]: [partial_sum(UnscaledValue(ws_ext_sales_price#88)), partial_sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: 
[partial_sum(UnscaledValue(ws_ext_sales_price#88)), partial_sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#104, sum#105, isEmpty#106, sum#107, isEmpty#108] Results [6]: [web_site_id#100, sum#109, sum#110, isEmpty#111, sum#112, isEmpty#113] @@ -559,9 +559,9 @@ Arguments: hashpartitioning(web_site_id#100, 5), ENSURE_REQUIREMENTS, [id=#114] (101) HashAggregate [codegen id : 30] Input [6]: [web_site_id#100, sum#109, sum#110, isEmpty#111, sum#112, isEmpty#113] Keys [1]: [web_site_id#100] -Functions [3]: [sum(UnscaledValue(ws_ext_sales_price#88)), sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(ws_ext_sales_price#88))#115, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#117] -Results [5]: [MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#88))#115,17,2) AS sales#118, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116 AS returns#119, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#117 AS profit#120, web channel AS channel#121, concat(web_site, web_site_id#100) AS id#122] +Functions [3]: [sum(UnscaledValue(ws_ext_sales_price#88)), sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(ws_ext_sales_price#88))#115, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#117] +Results [5]: [MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#88))#115,17,2) AS sales#118, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116 AS returns#119, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#117 AS profit#120, web channel AS channel#121, concat(web_site, web_site_id#100) AS id#122] (102) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/simplified.txt index b8122c8270984..a6fd641bc2434 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/simplified.txt @@ -9,7 +9,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter Union 
WholeStageCodegen (10) - HashAggregate [s_store_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ss_ext_sales_price)),sum(coalesce(cast(sr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ss_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [s_store_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ss_ext_sales_price)),sum(coalesce(cast(sr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ss_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [s_store_id] #2 WholeStageCodegen (9) @@ -79,7 +79,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter Scan parquet default.promotion [p_promo_sk,p_channel_tv] WholeStageCodegen (20) - HashAggregate [cp_catalog_page_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(cs_ext_sales_price)),sum(coalesce(cast(cr_return_amount as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(cs_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [cp_catalog_page_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(cs_ext_sales_price)),sum(coalesce(cast(cr_return_amount as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(cs_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [cp_catalog_page_id] #9 WholeStageCodegen (19) @@ -130,7 +130,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter ReusedExchange [p_promo_sk] #8 WholeStageCodegen (30) - HashAggregate [web_site_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ws_ext_sales_price)),sum(coalesce(cast(wr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ws_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [web_site_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ws_ext_sales_price)),sum(coalesce(cast(wr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ws_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),sales,returns,profit,channel,id,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [web_site_id] #13 WholeStageCodegen (29) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q81.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q81.sf100/explain.txt index 83d227688cf61..288df2457edf2 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q81.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q81.sf100/explain.txt @@ -297,7 +297,7 @@ Input [3]: [ctr_state#36, sum#45, count#46] Keys [1]: [ctr_state#36] Functions [1]: 
[avg(ctr_total_return#37)] Aggregate Attributes [1]: [avg(ctr_total_return#37)#48] -Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#37)#48) * 1.200000), DecimalType(24,7), true) AS (avg(ctr_total_return) * 1.2)#49, ctr_state#36 AS ctr_state#36#50] +Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#37)#48) * 1.200000), DecimalType(24,7)) AS (avg(ctr_total_return) * 1.2)#49, ctr_state#36 AS ctr_state#36#50] (53) Filter [codegen id : 19] Input [2]: [(avg(ctr_total_return) * 1.2)#49, ctr_state#36#50] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q81/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q81/explain.txt index 260224e41b7f7..91bd90224827a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q81/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q81/explain.txt @@ -198,7 +198,7 @@ Input [3]: [ctr_state#15, sum#22, count#23] Keys [1]: [ctr_state#15] Functions [1]: [avg(ctr_total_return#16)] Aggregate Attributes [1]: [avg(ctr_total_return#16)#25] -Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#16)#25) * 1.200000), DecimalType(24,7), true) AS (avg(ctr_total_return) * 1.2)#26, ctr_state#15 AS ctr_state#15#27] +Results [2]: [CheckOverflow((promote_precision(avg(ctr_total_return#16)#25) * 1.200000), DecimalType(24,7)) AS (avg(ctr_total_return) * 1.2)#26, ctr_state#15 AS ctr_state#15#27] (32) Filter [codegen id : 8] Input [2]: [(avg(ctr_total_return) * 1.2)#26, ctr_state#15#27] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.ansi/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.ansi/explain.txt new file mode 100644 index 0000000000000..c46fce21c25a2 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.ansi/explain.txt @@ -0,0 +1,362 @@ +== Physical Plan == +TakeOrderedAndProject (46) ++- * Project (45) + +- * BroadcastHashJoin Inner BuildRight (44) + :- * Project (30) + : +- * BroadcastHashJoin Inner BuildRight (29) + : :- * HashAggregate (15) + : : +- Exchange (14) + : : +- * HashAggregate (13) + : : +- * Project (12) + : : +- * BroadcastHashJoin Inner BuildRight (11) + : : :- * Project (9) + : : : +- * BroadcastHashJoin Inner BuildRight (8) + : : : :- * Filter (3) + : : : : +- * ColumnarToRow (2) + : : : : +- Scan parquet default.store_returns (1) + : : : +- BroadcastExchange (7) + : : : +- * Filter (6) + : : : +- * ColumnarToRow (5) + : : : +- Scan parquet default.item (4) + : : +- ReusedExchange (10) + : +- BroadcastExchange (28) + : +- * HashAggregate (27) + : +- Exchange (26) + : +- * HashAggregate (25) + : +- * Project (24) + : +- * BroadcastHashJoin Inner BuildRight (23) + : :- * Project (21) + : : +- * BroadcastHashJoin Inner BuildRight (20) + : : :- * Filter (18) + : : : +- * ColumnarToRow (17) + : : : +- Scan parquet default.catalog_returns (16) + : : +- ReusedExchange (19) + : +- ReusedExchange (22) + +- BroadcastExchange (43) + +- * HashAggregate (42) + +- Exchange (41) + +- * HashAggregate (40) + +- * Project (39) + +- * BroadcastHashJoin Inner BuildRight (38) + :- * Project (36) + : +- * BroadcastHashJoin Inner BuildRight (35) + : :- * Filter (33) + : : +- * ColumnarToRow (32) + : : +- Scan parquet default.web_returns (31) + : +- ReusedExchange (34) + +- ReusedExchange (37) + + +(1) Scan parquet default.store_returns +Output [3]: [sr_item_sk#1, sr_return_quantity#2, 
sr_returned_date_sk#3] +Batched: true +Location: InMemoryFileIndex [] +PartitionFilters: [isnotnull(sr_returned_date_sk#3), dynamicpruningexpression(sr_returned_date_sk#3 IN dynamicpruning#4)] +PushedFilters: [IsNotNull(sr_item_sk)] +ReadSchema: struct + +(2) ColumnarToRow [codegen id : 5] +Input [3]: [sr_item_sk#1, sr_return_quantity#2, sr_returned_date_sk#3] + +(3) Filter [codegen id : 5] +Input [3]: [sr_item_sk#1, sr_return_quantity#2, sr_returned_date_sk#3] +Condition : isnotnull(sr_item_sk#1) + +(4) Scan parquet default.item +Output [2]: [i_item_sk#5, i_item_id#6] +Batched: true +Location [not included in comparison]/{warehouse_dir}/item] +PushedFilters: [IsNotNull(i_item_sk), IsNotNull(i_item_id)] +ReadSchema: struct + +(5) ColumnarToRow [codegen id : 1] +Input [2]: [i_item_sk#5, i_item_id#6] + +(6) Filter [codegen id : 1] +Input [2]: [i_item_sk#5, i_item_id#6] +Condition : (isnotnull(i_item_sk#5) AND isnotnull(i_item_id#6)) + +(7) BroadcastExchange +Input [2]: [i_item_sk#5, i_item_id#6] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#7] + +(8) BroadcastHashJoin [codegen id : 5] +Left keys [1]: [sr_item_sk#1] +Right keys [1]: [i_item_sk#5] +Join condition: None + +(9) Project [codegen id : 5] +Output [3]: [sr_return_quantity#2, sr_returned_date_sk#3, i_item_id#6] +Input [5]: [sr_item_sk#1, sr_return_quantity#2, sr_returned_date_sk#3, i_item_sk#5, i_item_id#6] + +(10) ReusedExchange [Reuses operator id: 62] +Output [1]: [d_date_sk#8] + +(11) BroadcastHashJoin [codegen id : 5] +Left keys [1]: [sr_returned_date_sk#3] +Right keys [1]: [d_date_sk#8] +Join condition: None + +(12) Project [codegen id : 5] +Output [2]: [sr_return_quantity#2, i_item_id#6] +Input [4]: [sr_return_quantity#2, sr_returned_date_sk#3, i_item_id#6, d_date_sk#8] + +(13) HashAggregate [codegen id : 5] +Input [2]: [sr_return_quantity#2, i_item_id#6] +Keys [1]: [i_item_id#6] +Functions [1]: [partial_sum(sr_return_quantity#2)] +Aggregate Attributes [1]: [sum#9] +Results [2]: [i_item_id#6, sum#10] + +(14) Exchange +Input [2]: [i_item_id#6, sum#10] +Arguments: hashpartitioning(i_item_id#6, 5), ENSURE_REQUIREMENTS, [id=#11] + +(15) HashAggregate [codegen id : 18] +Input [2]: [i_item_id#6, sum#10] +Keys [1]: [i_item_id#6] +Functions [1]: [sum(sr_return_quantity#2)] +Aggregate Attributes [1]: [sum(sr_return_quantity#2)#12] +Results [2]: [i_item_id#6 AS item_id#13, sum(sr_return_quantity#2)#12 AS sr_item_qty#14] + +(16) Scan parquet default.catalog_returns +Output [3]: [cr_item_sk#15, cr_return_quantity#16, cr_returned_date_sk#17] +Batched: true +Location: InMemoryFileIndex [] +PartitionFilters: [isnotnull(cr_returned_date_sk#17), dynamicpruningexpression(cr_returned_date_sk#17 IN dynamicpruning#4)] +PushedFilters: [IsNotNull(cr_item_sk)] +ReadSchema: struct + +(17) ColumnarToRow [codegen id : 10] +Input [3]: [cr_item_sk#15, cr_return_quantity#16, cr_returned_date_sk#17] + +(18) Filter [codegen id : 10] +Input [3]: [cr_item_sk#15, cr_return_quantity#16, cr_returned_date_sk#17] +Condition : isnotnull(cr_item_sk#15) + +(19) ReusedExchange [Reuses operator id: 7] +Output [2]: [i_item_sk#18, i_item_id#19] + +(20) BroadcastHashJoin [codegen id : 10] +Left keys [1]: [cr_item_sk#15] +Right keys [1]: [i_item_sk#18] +Join condition: None + +(21) Project [codegen id : 10] +Output [3]: [cr_return_quantity#16, cr_returned_date_sk#17, i_item_id#19] +Input [5]: [cr_item_sk#15, cr_return_quantity#16, cr_returned_date_sk#17, i_item_sk#18, i_item_id#19] + +(22) ReusedExchange [Reuses operator id: 
62] +Output [1]: [d_date_sk#20] + +(23) BroadcastHashJoin [codegen id : 10] +Left keys [1]: [cr_returned_date_sk#17] +Right keys [1]: [d_date_sk#20] +Join condition: None + +(24) Project [codegen id : 10] +Output [2]: [cr_return_quantity#16, i_item_id#19] +Input [4]: [cr_return_quantity#16, cr_returned_date_sk#17, i_item_id#19, d_date_sk#20] + +(25) HashAggregate [codegen id : 10] +Input [2]: [cr_return_quantity#16, i_item_id#19] +Keys [1]: [i_item_id#19] +Functions [1]: [partial_sum(cr_return_quantity#16)] +Aggregate Attributes [1]: [sum#21] +Results [2]: [i_item_id#19, sum#22] + +(26) Exchange +Input [2]: [i_item_id#19, sum#22] +Arguments: hashpartitioning(i_item_id#19, 5), ENSURE_REQUIREMENTS, [id=#23] + +(27) HashAggregate [codegen id : 11] +Input [2]: [i_item_id#19, sum#22] +Keys [1]: [i_item_id#19] +Functions [1]: [sum(cr_return_quantity#16)] +Aggregate Attributes [1]: [sum(cr_return_quantity#16)#24] +Results [2]: [i_item_id#19 AS item_id#25, sum(cr_return_quantity#16)#24 AS cr_item_qty#26] + +(28) BroadcastExchange +Input [2]: [item_id#25, cr_item_qty#26] +Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#27] + +(29) BroadcastHashJoin [codegen id : 18] +Left keys [1]: [item_id#13] +Right keys [1]: [item_id#25] +Join condition: None + +(30) Project [codegen id : 18] +Output [3]: [item_id#13, sr_item_qty#14, cr_item_qty#26] +Input [4]: [item_id#13, sr_item_qty#14, item_id#25, cr_item_qty#26] + +(31) Scan parquet default.web_returns +Output [3]: [wr_item_sk#28, wr_return_quantity#29, wr_returned_date_sk#30] +Batched: true +Location: InMemoryFileIndex [] +PartitionFilters: [isnotnull(wr_returned_date_sk#30), dynamicpruningexpression(wr_returned_date_sk#30 IN dynamicpruning#4)] +PushedFilters: [IsNotNull(wr_item_sk)] +ReadSchema: struct + +(32) ColumnarToRow [codegen id : 16] +Input [3]: [wr_item_sk#28, wr_return_quantity#29, wr_returned_date_sk#30] + +(33) Filter [codegen id : 16] +Input [3]: [wr_item_sk#28, wr_return_quantity#29, wr_returned_date_sk#30] +Condition : isnotnull(wr_item_sk#28) + +(34) ReusedExchange [Reuses operator id: 7] +Output [2]: [i_item_sk#31, i_item_id#32] + +(35) BroadcastHashJoin [codegen id : 16] +Left keys [1]: [wr_item_sk#28] +Right keys [1]: [i_item_sk#31] +Join condition: None + +(36) Project [codegen id : 16] +Output [3]: [wr_return_quantity#29, wr_returned_date_sk#30, i_item_id#32] +Input [5]: [wr_item_sk#28, wr_return_quantity#29, wr_returned_date_sk#30, i_item_sk#31, i_item_id#32] + +(37) ReusedExchange [Reuses operator id: 62] +Output [1]: [d_date_sk#33] + +(38) BroadcastHashJoin [codegen id : 16] +Left keys [1]: [wr_returned_date_sk#30] +Right keys [1]: [d_date_sk#33] +Join condition: None + +(39) Project [codegen id : 16] +Output [2]: [wr_return_quantity#29, i_item_id#32] +Input [4]: [wr_return_quantity#29, wr_returned_date_sk#30, i_item_id#32, d_date_sk#33] + +(40) HashAggregate [codegen id : 16] +Input [2]: [wr_return_quantity#29, i_item_id#32] +Keys [1]: [i_item_id#32] +Functions [1]: [partial_sum(wr_return_quantity#29)] +Aggregate Attributes [1]: [sum#34] +Results [2]: [i_item_id#32, sum#35] + +(41) Exchange +Input [2]: [i_item_id#32, sum#35] +Arguments: hashpartitioning(i_item_id#32, 5), ENSURE_REQUIREMENTS, [id=#36] + +(42) HashAggregate [codegen id : 17] +Input [2]: [i_item_id#32, sum#35] +Keys [1]: [i_item_id#32] +Functions [1]: [sum(wr_return_quantity#29)] +Aggregate Attributes [1]: [sum(wr_return_quantity#29)#37] +Results [2]: [i_item_id#32 AS item_id#38, sum(wr_return_quantity#29)#37 AS wr_item_qty#39] + 
+(43) BroadcastExchange +Input [2]: [item_id#38, wr_item_qty#39] +Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#40] + +(44) BroadcastHashJoin [codegen id : 18] +Left keys [1]: [item_id#13] +Right keys [1]: [item_id#38] +Join condition: None + +(45) Project [codegen id : 18] +Output [8]: [item_id#13, sr_item_qty#14, (((cast(sr_item_qty#14 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS sr_dev#41, cr_item_qty#26, (((cast(cr_item_qty#26 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS cr_dev#42, wr_item_qty#39, (((cast(wr_item_qty#39 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS wr_dev#43, CheckOverflow((promote_precision(cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as decimal(21,1))) / 3.0), DecimalType(27,6)) AS average#44] +Input [5]: [item_id#13, sr_item_qty#14, cr_item_qty#26, item_id#38, wr_item_qty#39] + +(46) TakeOrderedAndProject +Input [8]: [item_id#13, sr_item_qty#14, sr_dev#41, cr_item_qty#26, cr_dev#42, wr_item_qty#39, wr_dev#43, average#44] +Arguments: 100, [item_id#13 ASC NULLS FIRST, sr_item_qty#14 ASC NULLS FIRST], [item_id#13, sr_item_qty#14, sr_dev#41, cr_item_qty#26, cr_dev#42, wr_item_qty#39, wr_dev#43, average#44] + +===== Subqueries ===== + +Subquery:1 Hosting operator id = 1 Hosting Expression = sr_returned_date_sk#3 IN dynamicpruning#4 +BroadcastExchange (62) ++- * Project (61) + +- * BroadcastHashJoin LeftSemi BuildRight (60) + :- * Filter (49) + : +- * ColumnarToRow (48) + : +- Scan parquet default.date_dim (47) + +- BroadcastExchange (59) + +- * Project (58) + +- * BroadcastHashJoin LeftSemi BuildRight (57) + :- * ColumnarToRow (51) + : +- Scan parquet default.date_dim (50) + +- BroadcastExchange (56) + +- * Project (55) + +- * Filter (54) + +- * ColumnarToRow (53) + +- Scan parquet default.date_dim (52) + + +(47) Scan parquet default.date_dim +Output [2]: [d_date_sk#8, d_date#45] +Batched: true +Location [not included in comparison]/{warehouse_dir}/date_dim] +PushedFilters: [IsNotNull(d_date_sk)] +ReadSchema: struct + +(48) ColumnarToRow [codegen id : 3] +Input [2]: [d_date_sk#8, d_date#45] + +(49) Filter [codegen id : 3] +Input [2]: [d_date_sk#8, d_date#45] +Condition : isnotnull(d_date_sk#8) + +(50) Scan parquet default.date_dim +Output [2]: [d_date#46, d_week_seq#47] +Batched: true +Location [not included in comparison]/{warehouse_dir}/date_dim] +ReadSchema: struct + +(51) ColumnarToRow [codegen id : 2] +Input [2]: [d_date#46, d_week_seq#47] + +(52) Scan parquet default.date_dim +Output [2]: [d_date#48, d_week_seq#49] +Batched: true +Location [not included in comparison]/{warehouse_dir}/date_dim] +PushedFilters: [In(d_date, [2000-06-30,2000-09-27,2000-11-17])] +ReadSchema: struct + +(53) ColumnarToRow [codegen id : 1] +Input [2]: [d_date#48, d_week_seq#49] + +(54) Filter [codegen id : 1] +Input [2]: [d_date#48, d_week_seq#49] +Condition : d_date#48 IN (2000-06-30,2000-09-27,2000-11-17) + +(55) Project [codegen id : 1] +Output [1]: [d_week_seq#49] +Input [2]: [d_date#48, d_week_seq#49] + +(56) BroadcastExchange +Input [1]: [d_week_seq#49] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#50] + +(57) BroadcastHashJoin [codegen id : 2] +Left keys [1]: [d_week_seq#47] +Right keys [1]: [d_week_seq#49] +Join condition: None + +(58) Project [codegen id : 2] +Output [1]: [d_date#46] +Input [2]: [d_date#46, 
d_week_seq#47] + +(59) BroadcastExchange +Input [1]: [d_date#46] +Arguments: HashedRelationBroadcastMode(List(input[0, date, true]),false), [id=#51] + +(60) BroadcastHashJoin [codegen id : 3] +Left keys [1]: [d_date#45] +Right keys [1]: [d_date#46] +Join condition: None + +(61) Project [codegen id : 3] +Output [1]: [d_date_sk#8] +Input [2]: [d_date_sk#8, d_date#45] + +(62) BroadcastExchange +Input [1]: [d_date_sk#8] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#52] + +Subquery:2 Hosting operator id = 16 Hosting Expression = cr_returned_date_sk#17 IN dynamicpruning#4 + +Subquery:3 Hosting operator id = 31 Hosting Expression = wr_returned_date_sk#30 IN dynamicpruning#4 + + diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.ansi/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.ansi/simplified.txt new file mode 100644 index 0000000000000..29ff19d7450c8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.ansi/simplified.txt @@ -0,0 +1,95 @@ +TakeOrderedAndProject [item_id,sr_item_qty,sr_dev,cr_item_qty,cr_dev,wr_item_qty,wr_dev,average] + WholeStageCodegen (18) + Project [item_id,sr_item_qty,cr_item_qty,wr_item_qty] + BroadcastHashJoin [item_id,item_id] + Project [item_id,sr_item_qty,cr_item_qty] + BroadcastHashJoin [item_id,item_id] + HashAggregate [i_item_id,sum] [sum(sr_return_quantity),item_id,sr_item_qty,sum] + InputAdapter + Exchange [i_item_id] #1 + WholeStageCodegen (5) + HashAggregate [i_item_id,sr_return_quantity] [sum,sum] + Project [sr_return_quantity,i_item_id] + BroadcastHashJoin [sr_returned_date_sk,d_date_sk] + Project [sr_return_quantity,sr_returned_date_sk,i_item_id] + BroadcastHashJoin [sr_item_sk,i_item_sk] + Filter [sr_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_returns [sr_item_sk,sr_return_quantity,sr_returned_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #2 + WholeStageCodegen (3) + Project [d_date_sk] + BroadcastHashJoin [d_date,d_date] + Filter [d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date] + InputAdapter + BroadcastExchange #3 + WholeStageCodegen (2) + Project [d_date] + BroadcastHashJoin [d_week_seq,d_week_seq] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date,d_week_seq] + InputAdapter + BroadcastExchange #4 + WholeStageCodegen (1) + Project [d_week_seq] + Filter [d_date] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date,d_week_seq] + InputAdapter + BroadcastExchange #5 + WholeStageCodegen (1) + Filter [i_item_sk,i_item_id] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_item_id] + InputAdapter + ReusedExchange [d_date_sk] #2 + InputAdapter + BroadcastExchange #6 + WholeStageCodegen (11) + HashAggregate [i_item_id,sum] [sum(cr_return_quantity),item_id,cr_item_qty,sum] + InputAdapter + Exchange [i_item_id] #7 + WholeStageCodegen (10) + HashAggregate [i_item_id,cr_return_quantity] [sum,sum] + Project [cr_return_quantity,i_item_id] + BroadcastHashJoin [cr_returned_date_sk,d_date_sk] + Project [cr_return_quantity,cr_returned_date_sk,i_item_id] + BroadcastHashJoin [cr_item_sk,i_item_sk] + Filter [cr_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_returns [cr_item_sk,cr_return_quantity,cr_returned_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [i_item_sk,i_item_id] #5 + InputAdapter + ReusedExchange 
[d_date_sk] #2 + InputAdapter + BroadcastExchange #8 + WholeStageCodegen (17) + HashAggregate [i_item_id,sum] [sum(wr_return_quantity),item_id,wr_item_qty,sum] + InputAdapter + Exchange [i_item_id] #9 + WholeStageCodegen (16) + HashAggregate [i_item_id,wr_return_quantity] [sum,sum] + Project [wr_return_quantity,i_item_id] + BroadcastHashJoin [wr_returned_date_sk,d_date_sk] + Project [wr_return_quantity,wr_returned_date_sk,i_item_id] + BroadcastHashJoin [wr_item_sk,i_item_sk] + Filter [wr_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_returns [wr_item_sk,wr_return_quantity,wr_returned_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [i_item_sk,i_item_id] #5 + InputAdapter + ReusedExchange [d_date_sk] #2 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.sf100.ansi/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.sf100.ansi/explain.txt new file mode 100644 index 0000000000000..bda63681ef500 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.sf100.ansi/explain.txt @@ -0,0 +1,362 @@ +== Physical Plan == +TakeOrderedAndProject (46) ++- * Project (45) + +- * BroadcastHashJoin Inner BuildRight (44) + :- * Project (30) + : +- * BroadcastHashJoin Inner BuildRight (29) + : :- * HashAggregate (15) + : : +- Exchange (14) + : : +- * HashAggregate (13) + : : +- * Project (12) + : : +- * BroadcastHashJoin Inner BuildRight (11) + : : :- * Project (6) + : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : :- * Filter (3) + : : : : +- * ColumnarToRow (2) + : : : : +- Scan parquet default.store_returns (1) + : : : +- ReusedExchange (4) + : : +- BroadcastExchange (10) + : : +- * Filter (9) + : : +- * ColumnarToRow (8) + : : +- Scan parquet default.item (7) + : +- BroadcastExchange (28) + : +- * HashAggregate (27) + : +- Exchange (26) + : +- * HashAggregate (25) + : +- * Project (24) + : +- * BroadcastHashJoin Inner BuildRight (23) + : :- * Project (21) + : : +- * BroadcastHashJoin Inner BuildRight (20) + : : :- * Filter (18) + : : : +- * ColumnarToRow (17) + : : : +- Scan parquet default.catalog_returns (16) + : : +- ReusedExchange (19) + : +- ReusedExchange (22) + +- BroadcastExchange (43) + +- * HashAggregate (42) + +- Exchange (41) + +- * HashAggregate (40) + +- * Project (39) + +- * BroadcastHashJoin Inner BuildRight (38) + :- * Project (36) + : +- * BroadcastHashJoin Inner BuildRight (35) + : :- * Filter (33) + : : +- * ColumnarToRow (32) + : : +- Scan parquet default.web_returns (31) + : +- ReusedExchange (34) + +- ReusedExchange (37) + + +(1) Scan parquet default.store_returns +Output [3]: [sr_item_sk#1, sr_return_quantity#2, sr_returned_date_sk#3] +Batched: true +Location: InMemoryFileIndex [] +PartitionFilters: [isnotnull(sr_returned_date_sk#3), dynamicpruningexpression(sr_returned_date_sk#3 IN dynamicpruning#4)] +PushedFilters: [IsNotNull(sr_item_sk)] +ReadSchema: struct + +(2) ColumnarToRow [codegen id : 5] +Input [3]: [sr_item_sk#1, sr_return_quantity#2, sr_returned_date_sk#3] + +(3) Filter [codegen id : 5] +Input [3]: [sr_item_sk#1, sr_return_quantity#2, sr_returned_date_sk#3] +Condition : isnotnull(sr_item_sk#1) + +(4) ReusedExchange [Reuses operator id: 62] +Output [1]: [d_date_sk#5] + +(5) BroadcastHashJoin [codegen id : 5] +Left keys [1]: [sr_returned_date_sk#3] +Right keys [1]: [d_date_sk#5] +Join condition: None + +(6) Project [codegen id : 5] +Output [2]: [sr_item_sk#1, sr_return_quantity#2] +Input [4]: [sr_item_sk#1, 
sr_return_quantity#2, sr_returned_date_sk#3, d_date_sk#5] + +(7) Scan parquet default.item +Output [2]: [i_item_sk#6, i_item_id#7] +Batched: true +Location [not included in comparison]/{warehouse_dir}/item] +PushedFilters: [IsNotNull(i_item_sk), IsNotNull(i_item_id)] +ReadSchema: struct + +(8) ColumnarToRow [codegen id : 4] +Input [2]: [i_item_sk#6, i_item_id#7] + +(9) Filter [codegen id : 4] +Input [2]: [i_item_sk#6, i_item_id#7] +Condition : (isnotnull(i_item_sk#6) AND isnotnull(i_item_id#7)) + +(10) BroadcastExchange +Input [2]: [i_item_sk#6, i_item_id#7] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#8] + +(11) BroadcastHashJoin [codegen id : 5] +Left keys [1]: [sr_item_sk#1] +Right keys [1]: [i_item_sk#6] +Join condition: None + +(12) Project [codegen id : 5] +Output [2]: [sr_return_quantity#2, i_item_id#7] +Input [4]: [sr_item_sk#1, sr_return_quantity#2, i_item_sk#6, i_item_id#7] + +(13) HashAggregate [codegen id : 5] +Input [2]: [sr_return_quantity#2, i_item_id#7] +Keys [1]: [i_item_id#7] +Functions [1]: [partial_sum(sr_return_quantity#2)] +Aggregate Attributes [1]: [sum#9] +Results [2]: [i_item_id#7, sum#10] + +(14) Exchange +Input [2]: [i_item_id#7, sum#10] +Arguments: hashpartitioning(i_item_id#7, 5), ENSURE_REQUIREMENTS, [id=#11] + +(15) HashAggregate [codegen id : 18] +Input [2]: [i_item_id#7, sum#10] +Keys [1]: [i_item_id#7] +Functions [1]: [sum(sr_return_quantity#2)] +Aggregate Attributes [1]: [sum(sr_return_quantity#2)#12] +Results [2]: [i_item_id#7 AS item_id#13, sum(sr_return_quantity#2)#12 AS sr_item_qty#14] + +(16) Scan parquet default.catalog_returns +Output [3]: [cr_item_sk#15, cr_return_quantity#16, cr_returned_date_sk#17] +Batched: true +Location: InMemoryFileIndex [] +PartitionFilters: [isnotnull(cr_returned_date_sk#17), dynamicpruningexpression(cr_returned_date_sk#17 IN dynamicpruning#4)] +PushedFilters: [IsNotNull(cr_item_sk)] +ReadSchema: struct + +(17) ColumnarToRow [codegen id : 10] +Input [3]: [cr_item_sk#15, cr_return_quantity#16, cr_returned_date_sk#17] + +(18) Filter [codegen id : 10] +Input [3]: [cr_item_sk#15, cr_return_quantity#16, cr_returned_date_sk#17] +Condition : isnotnull(cr_item_sk#15) + +(19) ReusedExchange [Reuses operator id: 62] +Output [1]: [d_date_sk#18] + +(20) BroadcastHashJoin [codegen id : 10] +Left keys [1]: [cr_returned_date_sk#17] +Right keys [1]: [d_date_sk#18] +Join condition: None + +(21) Project [codegen id : 10] +Output [2]: [cr_item_sk#15, cr_return_quantity#16] +Input [4]: [cr_item_sk#15, cr_return_quantity#16, cr_returned_date_sk#17, d_date_sk#18] + +(22) ReusedExchange [Reuses operator id: 10] +Output [2]: [i_item_sk#19, i_item_id#20] + +(23) BroadcastHashJoin [codegen id : 10] +Left keys [1]: [cr_item_sk#15] +Right keys [1]: [i_item_sk#19] +Join condition: None + +(24) Project [codegen id : 10] +Output [2]: [cr_return_quantity#16, i_item_id#20] +Input [4]: [cr_item_sk#15, cr_return_quantity#16, i_item_sk#19, i_item_id#20] + +(25) HashAggregate [codegen id : 10] +Input [2]: [cr_return_quantity#16, i_item_id#20] +Keys [1]: [i_item_id#20] +Functions [1]: [partial_sum(cr_return_quantity#16)] +Aggregate Attributes [1]: [sum#21] +Results [2]: [i_item_id#20, sum#22] + +(26) Exchange +Input [2]: [i_item_id#20, sum#22] +Arguments: hashpartitioning(i_item_id#20, 5), ENSURE_REQUIREMENTS, [id=#23] + +(27) HashAggregate [codegen id : 11] +Input [2]: [i_item_id#20, sum#22] +Keys [1]: [i_item_id#20] +Functions [1]: [sum(cr_return_quantity#16)] +Aggregate Attributes [1]: 
[sum(cr_return_quantity#16)#24] +Results [2]: [i_item_id#20 AS item_id#25, sum(cr_return_quantity#16)#24 AS cr_item_qty#26] + +(28) BroadcastExchange +Input [2]: [item_id#25, cr_item_qty#26] +Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#27] + +(29) BroadcastHashJoin [codegen id : 18] +Left keys [1]: [item_id#13] +Right keys [1]: [item_id#25] +Join condition: None + +(30) Project [codegen id : 18] +Output [3]: [item_id#13, sr_item_qty#14, cr_item_qty#26] +Input [4]: [item_id#13, sr_item_qty#14, item_id#25, cr_item_qty#26] + +(31) Scan parquet default.web_returns +Output [3]: [wr_item_sk#28, wr_return_quantity#29, wr_returned_date_sk#30] +Batched: true +Location: InMemoryFileIndex [] +PartitionFilters: [isnotnull(wr_returned_date_sk#30), dynamicpruningexpression(wr_returned_date_sk#30 IN dynamicpruning#4)] +PushedFilters: [IsNotNull(wr_item_sk)] +ReadSchema: struct + +(32) ColumnarToRow [codegen id : 16] +Input [3]: [wr_item_sk#28, wr_return_quantity#29, wr_returned_date_sk#30] + +(33) Filter [codegen id : 16] +Input [3]: [wr_item_sk#28, wr_return_quantity#29, wr_returned_date_sk#30] +Condition : isnotnull(wr_item_sk#28) + +(34) ReusedExchange [Reuses operator id: 62] +Output [1]: [d_date_sk#31] + +(35) BroadcastHashJoin [codegen id : 16] +Left keys [1]: [wr_returned_date_sk#30] +Right keys [1]: [d_date_sk#31] +Join condition: None + +(36) Project [codegen id : 16] +Output [2]: [wr_item_sk#28, wr_return_quantity#29] +Input [4]: [wr_item_sk#28, wr_return_quantity#29, wr_returned_date_sk#30, d_date_sk#31] + +(37) ReusedExchange [Reuses operator id: 10] +Output [2]: [i_item_sk#32, i_item_id#33] + +(38) BroadcastHashJoin [codegen id : 16] +Left keys [1]: [wr_item_sk#28] +Right keys [1]: [i_item_sk#32] +Join condition: None + +(39) Project [codegen id : 16] +Output [2]: [wr_return_quantity#29, i_item_id#33] +Input [4]: [wr_item_sk#28, wr_return_quantity#29, i_item_sk#32, i_item_id#33] + +(40) HashAggregate [codegen id : 16] +Input [2]: [wr_return_quantity#29, i_item_id#33] +Keys [1]: [i_item_id#33] +Functions [1]: [partial_sum(wr_return_quantity#29)] +Aggregate Attributes [1]: [sum#34] +Results [2]: [i_item_id#33, sum#35] + +(41) Exchange +Input [2]: [i_item_id#33, sum#35] +Arguments: hashpartitioning(i_item_id#33, 5), ENSURE_REQUIREMENTS, [id=#36] + +(42) HashAggregate [codegen id : 17] +Input [2]: [i_item_id#33, sum#35] +Keys [1]: [i_item_id#33] +Functions [1]: [sum(wr_return_quantity#29)] +Aggregate Attributes [1]: [sum(wr_return_quantity#29)#37] +Results [2]: [i_item_id#33 AS item_id#38, sum(wr_return_quantity#29)#37 AS wr_item_qty#39] + +(43) BroadcastExchange +Input [2]: [item_id#38, wr_item_qty#39] +Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#40] + +(44) BroadcastHashJoin [codegen id : 18] +Left keys [1]: [item_id#13] +Right keys [1]: [item_id#38] +Join condition: None + +(45) Project [codegen id : 18] +Output [8]: [item_id#13, sr_item_qty#14, (((cast(sr_item_qty#14 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS sr_dev#41, cr_item_qty#26, (((cast(cr_item_qty#26 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS cr_dev#42, wr_item_qty#39, (((cast(wr_item_qty#39 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS wr_dev#43, CheckOverflow((promote_precision(cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as decimal(21,1))) / 3.0), DecimalType(27,6)) AS 
average#44] +Input [5]: [item_id#13, sr_item_qty#14, cr_item_qty#26, item_id#38, wr_item_qty#39] + +(46) TakeOrderedAndProject +Input [8]: [item_id#13, sr_item_qty#14, sr_dev#41, cr_item_qty#26, cr_dev#42, wr_item_qty#39, wr_dev#43, average#44] +Arguments: 100, [item_id#13 ASC NULLS FIRST, sr_item_qty#14 ASC NULLS FIRST], [item_id#13, sr_item_qty#14, sr_dev#41, cr_item_qty#26, cr_dev#42, wr_item_qty#39, wr_dev#43, average#44] + +===== Subqueries ===== + +Subquery:1 Hosting operator id = 1 Hosting Expression = sr_returned_date_sk#3 IN dynamicpruning#4 +BroadcastExchange (62) ++- * Project (61) + +- * BroadcastHashJoin LeftSemi BuildRight (60) + :- * Filter (49) + : +- * ColumnarToRow (48) + : +- Scan parquet default.date_dim (47) + +- BroadcastExchange (59) + +- * Project (58) + +- * BroadcastHashJoin LeftSemi BuildRight (57) + :- * ColumnarToRow (51) + : +- Scan parquet default.date_dim (50) + +- BroadcastExchange (56) + +- * Project (55) + +- * Filter (54) + +- * ColumnarToRow (53) + +- Scan parquet default.date_dim (52) + + +(47) Scan parquet default.date_dim +Output [2]: [d_date_sk#5, d_date#45] +Batched: true +Location [not included in comparison]/{warehouse_dir}/date_dim] +PushedFilters: [IsNotNull(d_date_sk)] +ReadSchema: struct + +(48) ColumnarToRow [codegen id : 3] +Input [2]: [d_date_sk#5, d_date#45] + +(49) Filter [codegen id : 3] +Input [2]: [d_date_sk#5, d_date#45] +Condition : isnotnull(d_date_sk#5) + +(50) Scan parquet default.date_dim +Output [2]: [d_date#46, d_week_seq#47] +Batched: true +Location [not included in comparison]/{warehouse_dir}/date_dim] +ReadSchema: struct + +(51) ColumnarToRow [codegen id : 2] +Input [2]: [d_date#46, d_week_seq#47] + +(52) Scan parquet default.date_dim +Output [2]: [d_date#48, d_week_seq#49] +Batched: true +Location [not included in comparison]/{warehouse_dir}/date_dim] +PushedFilters: [In(d_date, [2000-06-30,2000-09-27,2000-11-17])] +ReadSchema: struct + +(53) ColumnarToRow [codegen id : 1] +Input [2]: [d_date#48, d_week_seq#49] + +(54) Filter [codegen id : 1] +Input [2]: [d_date#48, d_week_seq#49] +Condition : d_date#48 IN (2000-06-30,2000-09-27,2000-11-17) + +(55) Project [codegen id : 1] +Output [1]: [d_week_seq#49] +Input [2]: [d_date#48, d_week_seq#49] + +(56) BroadcastExchange +Input [1]: [d_week_seq#49] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#50] + +(57) BroadcastHashJoin [codegen id : 2] +Left keys [1]: [d_week_seq#47] +Right keys [1]: [d_week_seq#49] +Join condition: None + +(58) Project [codegen id : 2] +Output [1]: [d_date#46] +Input [2]: [d_date#46, d_week_seq#47] + +(59) BroadcastExchange +Input [1]: [d_date#46] +Arguments: HashedRelationBroadcastMode(List(input[0, date, true]),false), [id=#51] + +(60) BroadcastHashJoin [codegen id : 3] +Left keys [1]: [d_date#45] +Right keys [1]: [d_date#46] +Join condition: None + +(61) Project [codegen id : 3] +Output [1]: [d_date_sk#5] +Input [2]: [d_date_sk#5, d_date#45] + +(62) BroadcastExchange +Input [1]: [d_date_sk#5] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#52] + +Subquery:2 Hosting operator id = 16 Hosting Expression = cr_returned_date_sk#17 IN dynamicpruning#4 + +Subquery:3 Hosting operator id = 31 Hosting Expression = wr_returned_date_sk#30 IN dynamicpruning#4 + + diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.sf100.ansi/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.sf100.ansi/simplified.txt 
new file mode 100644 index 0000000000000..7f38503363767 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.sf100.ansi/simplified.txt @@ -0,0 +1,95 @@ +TakeOrderedAndProject [item_id,sr_item_qty,sr_dev,cr_item_qty,cr_dev,wr_item_qty,wr_dev,average] + WholeStageCodegen (18) + Project [item_id,sr_item_qty,cr_item_qty,wr_item_qty] + BroadcastHashJoin [item_id,item_id] + Project [item_id,sr_item_qty,cr_item_qty] + BroadcastHashJoin [item_id,item_id] + HashAggregate [i_item_id,sum] [sum(sr_return_quantity),item_id,sr_item_qty,sum] + InputAdapter + Exchange [i_item_id] #1 + WholeStageCodegen (5) + HashAggregate [i_item_id,sr_return_quantity] [sum,sum] + Project [sr_return_quantity,i_item_id] + BroadcastHashJoin [sr_item_sk,i_item_sk] + Project [sr_item_sk,sr_return_quantity] + BroadcastHashJoin [sr_returned_date_sk,d_date_sk] + Filter [sr_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_returns [sr_item_sk,sr_return_quantity,sr_returned_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #2 + WholeStageCodegen (3) + Project [d_date_sk] + BroadcastHashJoin [d_date,d_date] + Filter [d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date] + InputAdapter + BroadcastExchange #3 + WholeStageCodegen (2) + Project [d_date] + BroadcastHashJoin [d_week_seq,d_week_seq] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date,d_week_seq] + InputAdapter + BroadcastExchange #4 + WholeStageCodegen (1) + Project [d_week_seq] + Filter [d_date] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date,d_week_seq] + InputAdapter + ReusedExchange [d_date_sk] #2 + InputAdapter + BroadcastExchange #5 + WholeStageCodegen (4) + Filter [i_item_sk,i_item_id] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_item_id] + InputAdapter + BroadcastExchange #6 + WholeStageCodegen (11) + HashAggregate [i_item_id,sum] [sum(cr_return_quantity),item_id,cr_item_qty,sum] + InputAdapter + Exchange [i_item_id] #7 + WholeStageCodegen (10) + HashAggregate [i_item_id,cr_return_quantity] [sum,sum] + Project [cr_return_quantity,i_item_id] + BroadcastHashJoin [cr_item_sk,i_item_sk] + Project [cr_item_sk,cr_return_quantity] + BroadcastHashJoin [cr_returned_date_sk,d_date_sk] + Filter [cr_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_returns [cr_item_sk,cr_return_quantity,cr_returned_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [d_date_sk] #2 + InputAdapter + ReusedExchange [i_item_sk,i_item_id] #5 + InputAdapter + BroadcastExchange #8 + WholeStageCodegen (17) + HashAggregate [i_item_id,sum] [sum(wr_return_quantity),item_id,wr_item_qty,sum] + InputAdapter + Exchange [i_item_id] #9 + WholeStageCodegen (16) + HashAggregate [i_item_id,wr_return_quantity] [sum,sum] + Project [wr_return_quantity,i_item_id] + BroadcastHashJoin [wr_item_sk,i_item_sk] + Project [wr_item_sk,wr_return_quantity] + BroadcastHashJoin [wr_returned_date_sk,d_date_sk] + Filter [wr_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_returns [wr_item_sk,wr_return_quantity,wr_returned_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [d_date_sk] #2 + InputAdapter + ReusedExchange [i_item_sk,i_item_id] #5 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.sf100/explain.txt index 175a1c675675f..3374a3dc3daae 
100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83.sf100/explain.txt @@ -256,7 +256,7 @@ Right keys [1]: [item_id#38] Join condition: None (45) Project [codegen id : 18] -Output [8]: [item_id#13, sr_item_qty#14, (((cast(sr_item_qty#14 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS sr_dev#41, cr_item_qty#26, (((cast(cr_item_qty#26 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS cr_dev#42, wr_item_qty#39, (((cast(wr_item_qty#39 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS wr_dev#43, CheckOverflow((promote_precision(cast(cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as decimal(20,0)) as decimal(21,1))) / 3.0), DecimalType(27,6), true) AS average#44] +Output [8]: [item_id#13, sr_item_qty#14, (((cast(sr_item_qty#14 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS sr_dev#41, cr_item_qty#26, (((cast(cr_item_qty#26 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS cr_dev#42, wr_item_qty#39, (((cast(wr_item_qty#39 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS wr_dev#43, CheckOverflow((promote_precision(cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as decimal(21,1))) / 3.0), DecimalType(27,6)) AS average#44] Input [5]: [item_id#13, sr_item_qty#14, cr_item_qty#26, item_id#38, wr_item_qty#39] (46) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83/explain.txt index 8332d48905e48..106d5dd3090e3 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q83/explain.txt @@ -256,7 +256,7 @@ Right keys [1]: [item_id#38] Join condition: None (45) Project [codegen id : 18] -Output [8]: [item_id#13, sr_item_qty#14, (((cast(sr_item_qty#14 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS sr_dev#41, cr_item_qty#26, (((cast(cr_item_qty#26 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS cr_dev#42, wr_item_qty#39, (((cast(wr_item_qty#39 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS wr_dev#43, CheckOverflow((promote_precision(cast(cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as decimal(20,0)) as decimal(21,1))) / 3.0), DecimalType(27,6), true) AS average#44] +Output [8]: [item_id#13, sr_item_qty#14, (((cast(sr_item_qty#14 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS sr_dev#41, cr_item_qty#26, (((cast(cr_item_qty#26 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS cr_dev#42, wr_item_qty#39, (((cast(wr_item_qty#39 as double) / cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as double)) / 3.0) * 100.0) AS wr_dev#43, CheckOverflow((promote_precision(cast(((sr_item_qty#14 + cr_item_qty#26) + wr_item_qty#39) as decimal(21,1))) / 3.0), DecimalType(27,6)) AS average#44] Input [5]: [item_id#13, 
sr_item_qty#14, cr_item_qty#26, item_id#38, wr_item_qty#39] (46) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87.sf100/explain.txt index 408b0defda53c..38ecc6f3ed822 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87.sf100/explain.txt @@ -1,71 +1,64 @@ == Physical Plan == -* HashAggregate (67) -+- Exchange (66) - +- * HashAggregate (65) - +- * HashAggregate (64) - +- Exchange (63) - +- * HashAggregate (62) - +- * SortMergeJoin LeftAnti (61) - :- * Sort (43) - : +- Exchange (42) - : +- * HashAggregate (41) - : +- Exchange (40) - : +- * HashAggregate (39) - : +- * SortMergeJoin LeftAnti (38) - : :- * Sort (20) - : : +- Exchange (19) - : : +- * HashAggregate (18) - : : +- Exchange (17) - : : +- * HashAggregate (16) - : : +- * Project (15) - : : +- * SortMergeJoin Inner (14) - : : :- * Sort (8) - : : : +- Exchange (7) - : : : +- * Project (6) - : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : :- * Filter (3) - : : : : +- * ColumnarToRow (2) - : : : : +- Scan parquet default.store_sales (1) - : : : +- ReusedExchange (4) - : : +- * Sort (13) - : : +- Exchange (12) - : : +- * Filter (11) - : : +- * ColumnarToRow (10) - : : +- Scan parquet default.customer (9) - : +- * Sort (37) - : +- Exchange (36) - : +- * HashAggregate (35) - : +- Exchange (34) - : +- * HashAggregate (33) - : +- * Project (32) - : +- * SortMergeJoin Inner (31) - : :- * Sort (28) - : : +- Exchange (27) - : : +- * Project (26) - : : +- * BroadcastHashJoin Inner BuildRight (25) - : : :- * Filter (23) - : : : +- * ColumnarToRow (22) - : : : +- Scan parquet default.catalog_sales (21) - : : +- ReusedExchange (24) - : +- * Sort (30) - : +- ReusedExchange (29) - +- * Sort (60) - +- Exchange (59) - +- * HashAggregate (58) - +- Exchange (57) - +- * HashAggregate (56) - +- * Project (55) - +- * SortMergeJoin Inner (54) - :- * Sort (51) - : +- Exchange (50) - : +- * Project (49) - : +- * BroadcastHashJoin Inner BuildRight (48) - : :- * Filter (46) - : : +- * ColumnarToRow (45) - : : +- Scan parquet default.web_sales (44) - : +- ReusedExchange (47) - +- * Sort (53) - +- ReusedExchange (52) +* HashAggregate (60) ++- Exchange (59) + +- * HashAggregate (58) + +- * Project (57) + +- * SortMergeJoin LeftAnti (56) + :- * SortMergeJoin LeftAnti (38) + : :- * Sort (20) + : : +- Exchange (19) + : : +- * HashAggregate (18) + : : +- Exchange (17) + : : +- * HashAggregate (16) + : : +- * Project (15) + : : +- * SortMergeJoin Inner (14) + : : :- * Sort (8) + : : : +- Exchange (7) + : : : +- * Project (6) + : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : :- * Filter (3) + : : : : +- * ColumnarToRow (2) + : : : : +- Scan parquet default.store_sales (1) + : : : +- ReusedExchange (4) + : : +- * Sort (13) + : : +- Exchange (12) + : : +- * Filter (11) + : : +- * ColumnarToRow (10) + : : +- Scan parquet default.customer (9) + : +- * Sort (37) + : +- Exchange (36) + : +- * HashAggregate (35) + : +- Exchange (34) + : +- * HashAggregate (33) + : +- * Project (32) + : +- * SortMergeJoin Inner (31) + : :- * Sort (28) + : : +- Exchange (27) + : : +- * Project (26) + : : +- * BroadcastHashJoin Inner BuildRight (25) + : : :- * Filter (23) + : : : +- * ColumnarToRow (22) + : : : +- Scan parquet default.catalog_sales (21) + : : +- ReusedExchange (24) + : +- * Sort (30) + 
: +- ReusedExchange (29) + +- * Sort (55) + +- Exchange (54) + +- * HashAggregate (53) + +- Exchange (52) + +- * HashAggregate (51) + +- * Project (50) + +- * SortMergeJoin Inner (49) + :- * Sort (46) + : +- Exchange (45) + : +- * Project (44) + : +- * BroadcastHashJoin Inner BuildRight (43) + : :- * Filter (41) + : : +- * ColumnarToRow (40) + : : +- Scan parquet default.web_sales (39) + : +- ReusedExchange (42) + +- * Sort (48) + +- ReusedExchange (47) (1) Scan parquet default.store_sales @@ -83,7 +76,7 @@ Input [2]: [ss_customer_sk#1, ss_sold_date_sk#2] Input [2]: [ss_customer_sk#1, ss_sold_date_sk#2] Condition : isnotnull(ss_customer_sk#1) -(4) ReusedExchange [Reuses operator id: 72] +(4) ReusedExchange [Reuses operator id: 65] Output [2]: [d_date_sk#4, d_date#5] (5) BroadcastHashJoin [codegen id : 2] @@ -175,7 +168,7 @@ Input [2]: [cs_bill_customer_sk#13, cs_sold_date_sk#14] Input [2]: [cs_bill_customer_sk#13, cs_sold_date_sk#14] Condition : isnotnull(cs_bill_customer_sk#13) -(24) ReusedExchange [Reuses operator id: 72] +(24) ReusedExchange [Reuses operator id: 65] Output [2]: [d_date_sk#15, d_date#16] (25) BroadcastHashJoin [codegen id : 10] @@ -242,184 +235,144 @@ Left keys [6]: [coalesce(c_last_name#9, ), isnull(c_last_name#9), coalesce(c_fir Right keys [6]: [coalesce(c_last_name#20, ), isnull(c_last_name#20), coalesce(c_first_name#19, ), isnull(c_first_name#19), coalesce(d_date#16, 1970-01-01), isnull(d_date#16)] Join condition: None -(39) HashAggregate [codegen id : 17] -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Keys [3]: [c_last_name#9, c_first_name#8, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#9, c_first_name#8, d_date#5] - -(40) Exchange -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Arguments: hashpartitioning(c_last_name#9, c_first_name#8, d_date#5, 5), ENSURE_REQUIREMENTS, [id=#23] - -(41) HashAggregate [codegen id : 18] -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Keys [3]: [c_last_name#9, c_first_name#8, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#9, c_first_name#8, d_date#5] - -(42) Exchange -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Arguments: hashpartitioning(coalesce(c_last_name#9, ), isnull(c_last_name#9), coalesce(c_first_name#8, ), isnull(c_first_name#8), coalesce(d_date#5, 1970-01-01), isnull(d_date#5), 5), ENSURE_REQUIREMENTS, [id=#24] - -(43) Sort [codegen id : 19] -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Arguments: [coalesce(c_last_name#9, ) ASC NULLS FIRST, isnull(c_last_name#9) ASC NULLS FIRST, coalesce(c_first_name#8, ) ASC NULLS FIRST, isnull(c_first_name#8) ASC NULLS FIRST, coalesce(d_date#5, 1970-01-01) ASC NULLS FIRST, isnull(d_date#5) ASC NULLS FIRST], false, 0 - -(44) Scan parquet default.web_sales -Output [2]: [ws_bill_customer_sk#25, ws_sold_date_sk#26] +(39) Scan parquet default.web_sales +Output [2]: [ws_bill_customer_sk#23, ws_sold_date_sk#24] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ws_sold_date_sk#26), dynamicpruningexpression(ws_sold_date_sk#26 IN dynamicpruning#3)] +PartitionFilters: [isnotnull(ws_sold_date_sk#24), dynamicpruningexpression(ws_sold_date_sk#24 IN dynamicpruning#3)] PushedFilters: [IsNotNull(ws_bill_customer_sk)] ReadSchema: struct -(45) ColumnarToRow [codegen id : 21] -Input [2]: [ws_bill_customer_sk#25, ws_sold_date_sk#26] +(40) ColumnarToRow [codegen id : 19] +Input [2]: [ws_bill_customer_sk#23, ws_sold_date_sk#24] -(46) Filter [codegen id : 21] -Input [2]: 
[ws_bill_customer_sk#25, ws_sold_date_sk#26] -Condition : isnotnull(ws_bill_customer_sk#25) +(41) Filter [codegen id : 19] +Input [2]: [ws_bill_customer_sk#23, ws_sold_date_sk#24] +Condition : isnotnull(ws_bill_customer_sk#23) -(47) ReusedExchange [Reuses operator id: 72] -Output [2]: [d_date_sk#27, d_date#28] +(42) ReusedExchange [Reuses operator id: 65] +Output [2]: [d_date_sk#25, d_date#26] -(48) BroadcastHashJoin [codegen id : 21] -Left keys [1]: [ws_sold_date_sk#26] -Right keys [1]: [d_date_sk#27] +(43) BroadcastHashJoin [codegen id : 19] +Left keys [1]: [ws_sold_date_sk#24] +Right keys [1]: [d_date_sk#25] Join condition: None -(49) Project [codegen id : 21] -Output [2]: [ws_bill_customer_sk#25, d_date#28] -Input [4]: [ws_bill_customer_sk#25, ws_sold_date_sk#26, d_date_sk#27, d_date#28] +(44) Project [codegen id : 19] +Output [2]: [ws_bill_customer_sk#23, d_date#26] +Input [4]: [ws_bill_customer_sk#23, ws_sold_date_sk#24, d_date_sk#25, d_date#26] -(50) Exchange -Input [2]: [ws_bill_customer_sk#25, d_date#28] -Arguments: hashpartitioning(ws_bill_customer_sk#25, 5), ENSURE_REQUIREMENTS, [id=#29] +(45) Exchange +Input [2]: [ws_bill_customer_sk#23, d_date#26] +Arguments: hashpartitioning(ws_bill_customer_sk#23, 5), ENSURE_REQUIREMENTS, [id=#27] -(51) Sort [codegen id : 22] -Input [2]: [ws_bill_customer_sk#25, d_date#28] -Arguments: [ws_bill_customer_sk#25 ASC NULLS FIRST], false, 0 +(46) Sort [codegen id : 20] +Input [2]: [ws_bill_customer_sk#23, d_date#26] +Arguments: [ws_bill_customer_sk#23 ASC NULLS FIRST], false, 0 -(52) ReusedExchange [Reuses operator id: 12] -Output [3]: [c_customer_sk#30, c_first_name#31, c_last_name#32] +(47) ReusedExchange [Reuses operator id: 12] +Output [3]: [c_customer_sk#28, c_first_name#29, c_last_name#30] -(53) Sort [codegen id : 24] -Input [3]: [c_customer_sk#30, c_first_name#31, c_last_name#32] -Arguments: [c_customer_sk#30 ASC NULLS FIRST], false, 0 +(48) Sort [codegen id : 22] +Input [3]: [c_customer_sk#28, c_first_name#29, c_last_name#30] +Arguments: [c_customer_sk#28 ASC NULLS FIRST], false, 0 -(54) SortMergeJoin [codegen id : 25] -Left keys [1]: [ws_bill_customer_sk#25] -Right keys [1]: [c_customer_sk#30] +(49) SortMergeJoin [codegen id : 23] +Left keys [1]: [ws_bill_customer_sk#23] +Right keys [1]: [c_customer_sk#28] Join condition: None -(55) Project [codegen id : 25] -Output [3]: [c_last_name#32, c_first_name#31, d_date#28] -Input [5]: [ws_bill_customer_sk#25, d_date#28, c_customer_sk#30, c_first_name#31, c_last_name#32] +(50) Project [codegen id : 23] +Output [3]: [c_last_name#30, c_first_name#29, d_date#26] +Input [5]: [ws_bill_customer_sk#23, d_date#26, c_customer_sk#28, c_first_name#29, c_last_name#30] -(56) HashAggregate [codegen id : 25] -Input [3]: [c_last_name#32, c_first_name#31, d_date#28] -Keys [3]: [c_last_name#32, c_first_name#31, d_date#28] +(51) HashAggregate [codegen id : 23] +Input [3]: [c_last_name#30, c_first_name#29, d_date#26] +Keys [3]: [c_last_name#30, c_first_name#29, d_date#26] Functions: [] Aggregate Attributes: [] -Results [3]: [c_last_name#32, c_first_name#31, d_date#28] +Results [3]: [c_last_name#30, c_first_name#29, d_date#26] -(57) Exchange -Input [3]: [c_last_name#32, c_first_name#31, d_date#28] -Arguments: hashpartitioning(c_last_name#32, c_first_name#31, d_date#28, 5), ENSURE_REQUIREMENTS, [id=#33] +(52) Exchange +Input [3]: [c_last_name#30, c_first_name#29, d_date#26] +Arguments: hashpartitioning(c_last_name#30, c_first_name#29, d_date#26, 5), ENSURE_REQUIREMENTS, [id=#31] -(58) HashAggregate [codegen id : 26] 
-Input [3]: [c_last_name#32, c_first_name#31, d_date#28] -Keys [3]: [c_last_name#32, c_first_name#31, d_date#28] +(53) HashAggregate [codegen id : 24] +Input [3]: [c_last_name#30, c_first_name#29, d_date#26] +Keys [3]: [c_last_name#30, c_first_name#29, d_date#26] Functions: [] Aggregate Attributes: [] -Results [3]: [c_last_name#32, c_first_name#31, d_date#28] +Results [3]: [c_last_name#30, c_first_name#29, d_date#26] -(59) Exchange -Input [3]: [c_last_name#32, c_first_name#31, d_date#28] -Arguments: hashpartitioning(coalesce(c_last_name#32, ), isnull(c_last_name#32), coalesce(c_first_name#31, ), isnull(c_first_name#31), coalesce(d_date#28, 1970-01-01), isnull(d_date#28), 5), ENSURE_REQUIREMENTS, [id=#34] +(54) Exchange +Input [3]: [c_last_name#30, c_first_name#29, d_date#26] +Arguments: hashpartitioning(coalesce(c_last_name#30, ), isnull(c_last_name#30), coalesce(c_first_name#29, ), isnull(c_first_name#29), coalesce(d_date#26, 1970-01-01), isnull(d_date#26), 5), ENSURE_REQUIREMENTS, [id=#32] -(60) Sort [codegen id : 27] -Input [3]: [c_last_name#32, c_first_name#31, d_date#28] -Arguments: [coalesce(c_last_name#32, ) ASC NULLS FIRST, isnull(c_last_name#32) ASC NULLS FIRST, coalesce(c_first_name#31, ) ASC NULLS FIRST, isnull(c_first_name#31) ASC NULLS FIRST, coalesce(d_date#28, 1970-01-01) ASC NULLS FIRST, isnull(d_date#28) ASC NULLS FIRST], false, 0 +(55) Sort [codegen id : 25] +Input [3]: [c_last_name#30, c_first_name#29, d_date#26] +Arguments: [coalesce(c_last_name#30, ) ASC NULLS FIRST, isnull(c_last_name#30) ASC NULLS FIRST, coalesce(c_first_name#29, ) ASC NULLS FIRST, isnull(c_first_name#29) ASC NULLS FIRST, coalesce(d_date#26, 1970-01-01) ASC NULLS FIRST, isnull(d_date#26) ASC NULLS FIRST], false, 0 -(61) SortMergeJoin [codegen id : 28] +(56) SortMergeJoin [codegen id : 26] Left keys [6]: [coalesce(c_last_name#9, ), isnull(c_last_name#9), coalesce(c_first_name#8, ), isnull(c_first_name#8), coalesce(d_date#5, 1970-01-01), isnull(d_date#5)] -Right keys [6]: [coalesce(c_last_name#32, ), isnull(c_last_name#32), coalesce(c_first_name#31, ), isnull(c_first_name#31), coalesce(d_date#28, 1970-01-01), isnull(d_date#28)] +Right keys [6]: [coalesce(c_last_name#30, ), isnull(c_last_name#30), coalesce(c_first_name#29, ), isnull(c_first_name#29), coalesce(d_date#26, 1970-01-01), isnull(d_date#26)] Join condition: None -(62) HashAggregate [codegen id : 28] +(57) Project [codegen id : 26] +Output: [] Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Keys [3]: [c_last_name#9, c_first_name#8, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#9, c_first_name#8, d_date#5] -(63) Exchange -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Arguments: hashpartitioning(c_last_name#9, c_first_name#8, d_date#5, 5), ENSURE_REQUIREMENTS, [id=#35] - -(64) HashAggregate [codegen id : 29] -Input [3]: [c_last_name#9, c_first_name#8, d_date#5] -Keys [3]: [c_last_name#9, c_first_name#8, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results: [] - -(65) HashAggregate [codegen id : 29] +(58) HashAggregate [codegen id : 26] Input: [] Keys: [] Functions [1]: [partial_count(1)] -Aggregate Attributes [1]: [count#36] -Results [1]: [count#37] +Aggregate Attributes [1]: [count#33] +Results [1]: [count#34] -(66) Exchange -Input [1]: [count#37] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#38] +(59) Exchange +Input [1]: [count#34] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#35] -(67) HashAggregate [codegen id : 30] -Input [1]: [count#37] +(60) HashAggregate [codegen id : 
27] +Input [1]: [count#34] Keys: [] Functions [1]: [count(1)] -Aggregate Attributes [1]: [count(1)#39] -Results [1]: [count(1)#39 AS count(1)#40] +Aggregate Attributes [1]: [count(1)#36] +Results [1]: [count(1)#36 AS count(1)#37] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#2 IN dynamicpruning#3 -BroadcastExchange (72) -+- * Project (71) - +- * Filter (70) - +- * ColumnarToRow (69) - +- Scan parquet default.date_dim (68) +BroadcastExchange (65) ++- * Project (64) + +- * Filter (63) + +- * ColumnarToRow (62) + +- Scan parquet default.date_dim (61) -(68) Scan parquet default.date_dim -Output [3]: [d_date_sk#4, d_date#5, d_month_seq#41] +(61) Scan parquet default.date_dim +Output [3]: [d_date_sk#4, d_date#5, d_month_seq#38] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1200), LessThanOrEqual(d_month_seq,1211), IsNotNull(d_date_sk)] ReadSchema: struct -(69) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#4, d_date#5, d_month_seq#41] +(62) ColumnarToRow [codegen id : 1] +Input [3]: [d_date_sk#4, d_date#5, d_month_seq#38] -(70) Filter [codegen id : 1] -Input [3]: [d_date_sk#4, d_date#5, d_month_seq#41] -Condition : (((isnotnull(d_month_seq#41) AND (d_month_seq#41 >= 1200)) AND (d_month_seq#41 <= 1211)) AND isnotnull(d_date_sk#4)) +(63) Filter [codegen id : 1] +Input [3]: [d_date_sk#4, d_date#5, d_month_seq#38] +Condition : (((isnotnull(d_month_seq#38) AND (d_month_seq#38 >= 1200)) AND (d_month_seq#38 <= 1211)) AND isnotnull(d_date_sk#4)) -(71) Project [codegen id : 1] +(64) Project [codegen id : 1] Output [2]: [d_date_sk#4, d_date#5] -Input [3]: [d_date_sk#4, d_date#5, d_month_seq#41] +Input [3]: [d_date_sk#4, d_date#5, d_month_seq#38] -(72) BroadcastExchange +(65) BroadcastExchange Input [2]: [d_date_sk#4, d_date#5] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#42] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#39] Subquery:2 Hosting operator id = 21 Hosting Expression = cs_sold_date_sk#14 IN dynamicpruning#3 -Subquery:3 Hosting operator id = 44 Hosting Expression = ws_sold_date_sk#26 IN dynamicpruning#3 +Subquery:3 Hosting operator id = 39 Hosting Expression = ws_sold_date_sk#24 IN dynamicpruning#3 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87.sf100/simplified.txt index eda0d4b03f483..cc66a0040ef9a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87.sf100/simplified.txt @@ -1,135 +1,122 @@ -WholeStageCodegen (30) +WholeStageCodegen (27) HashAggregate [count] [count(1),count(1),count] InputAdapter Exchange #1 - WholeStageCodegen (29) + WholeStageCodegen (26) HashAggregate [count,count] - HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #2 - WholeStageCodegen (28) - HashAggregate [c_last_name,c_first_name,d_date] - SortMergeJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] - InputAdapter - WholeStageCodegen (19) - Sort [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #3 - WholeStageCodegen (18) - HashAggregate [c_last_name,c_first_name,d_date] - 
InputAdapter - Exchange [c_last_name,c_first_name,d_date] #4 - WholeStageCodegen (17) - HashAggregate [c_last_name,c_first_name,d_date] - SortMergeJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] + Project + SortMergeJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] + InputAdapter + WholeStageCodegen (17) + SortMergeJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] + InputAdapter + WholeStageCodegen (8) + Sort [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #2 + WholeStageCodegen (7) + HashAggregate [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #3 + WholeStageCodegen (6) + HashAggregate [c_last_name,c_first_name,d_date] + Project [c_last_name,c_first_name,d_date] + SortMergeJoin [ss_customer_sk,c_customer_sk] InputAdapter - WholeStageCodegen (8) - Sort [c_last_name,c_first_name,d_date] + WholeStageCodegen (3) + Sort [ss_customer_sk] InputAdapter - Exchange [c_last_name,c_first_name,d_date] #5 - WholeStageCodegen (7) - HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #6 - WholeStageCodegen (6) - HashAggregate [c_last_name,c_first_name,d_date] - Project [c_last_name,c_first_name,d_date] - SortMergeJoin [ss_customer_sk,c_customer_sk] - InputAdapter - WholeStageCodegen (3) - Sort [ss_customer_sk] - InputAdapter - Exchange [ss_customer_sk] #7 - WholeStageCodegen (2) - Project [ss_customer_sk,d_date] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_customer_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #8 - WholeStageCodegen (1) - Project [d_date_sk,d_date] - Filter [d_month_seq,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] - InputAdapter - ReusedExchange [d_date_sk,d_date] #8 - InputAdapter - WholeStageCodegen (5) - Sort [c_customer_sk] - InputAdapter - Exchange [c_customer_sk] #9 - WholeStageCodegen (4) - Filter [c_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.customer [c_customer_sk,c_first_name,c_last_name] + Exchange [ss_customer_sk] #4 + WholeStageCodegen (2) + Project [ss_customer_sk,d_date] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_customer_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (1) + Project [d_date_sk,d_date] + Filter [d_month_seq,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] + InputAdapter + ReusedExchange [d_date_sk,d_date] #5 InputAdapter - WholeStageCodegen (16) - Sort [c_last_name,c_first_name,d_date] + WholeStageCodegen (5) + Sort [c_customer_sk] InputAdapter - Exchange [c_last_name,c_first_name,d_date] #10 - WholeStageCodegen (15) - HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #11 - WholeStageCodegen (14) - HashAggregate [c_last_name,c_first_name,d_date] - Project [c_last_name,c_first_name,d_date] - SortMergeJoin [cs_bill_customer_sk,c_customer_sk] - InputAdapter - WholeStageCodegen (11) - Sort [cs_bill_customer_sk] - InputAdapter - Exchange [cs_bill_customer_sk] #12 - WholeStageCodegen (10) - Project [cs_bill_customer_sk,d_date] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Filter 
[cs_bill_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_bill_customer_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #1 - InputAdapter - ReusedExchange [d_date_sk,d_date] #8 - InputAdapter - WholeStageCodegen (13) - Sort [c_customer_sk] - InputAdapter - ReusedExchange [c_customer_sk,c_first_name,c_last_name] #9 - InputAdapter - WholeStageCodegen (27) - Sort [c_last_name,c_first_name,d_date] + Exchange [c_customer_sk] #6 + WholeStageCodegen (4) + Filter [c_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.customer [c_customer_sk,c_first_name,c_last_name] + InputAdapter + WholeStageCodegen (16) + Sort [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #7 + WholeStageCodegen (15) + HashAggregate [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #8 + WholeStageCodegen (14) + HashAggregate [c_last_name,c_first_name,d_date] + Project [c_last_name,c_first_name,d_date] + SortMergeJoin [cs_bill_customer_sk,c_customer_sk] + InputAdapter + WholeStageCodegen (11) + Sort [cs_bill_customer_sk] + InputAdapter + Exchange [cs_bill_customer_sk] #9 + WholeStageCodegen (10) + Project [cs_bill_customer_sk,d_date] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Filter [cs_bill_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_bill_customer_sk,cs_sold_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [d_date_sk,d_date] #5 + InputAdapter + WholeStageCodegen (13) + Sort [c_customer_sk] + InputAdapter + ReusedExchange [c_customer_sk,c_first_name,c_last_name] #6 + InputAdapter + WholeStageCodegen (25) + Sort [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #10 + WholeStageCodegen (24) + HashAggregate [c_last_name,c_first_name,d_date] InputAdapter - Exchange [c_last_name,c_first_name,d_date] #13 - WholeStageCodegen (26) + Exchange [c_last_name,c_first_name,d_date] #11 + WholeStageCodegen (23) HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #14 - WholeStageCodegen (25) - HashAggregate [c_last_name,c_first_name,d_date] - Project [c_last_name,c_first_name,d_date] - SortMergeJoin [ws_bill_customer_sk,c_customer_sk] - InputAdapter - WholeStageCodegen (22) - Sort [ws_bill_customer_sk] - InputAdapter - Exchange [ws_bill_customer_sk] #15 - WholeStageCodegen (21) - Project [ws_bill_customer_sk,d_date] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Filter [ws_bill_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.web_sales [ws_bill_customer_sk,ws_sold_date_sk] - ReusedSubquery [d_date_sk] #1 - InputAdapter - ReusedExchange [d_date_sk,d_date] #8 - InputAdapter - WholeStageCodegen (24) - Sort [c_customer_sk] - InputAdapter - ReusedExchange [c_customer_sk,c_first_name,c_last_name] #9 + Project [c_last_name,c_first_name,d_date] + SortMergeJoin [ws_bill_customer_sk,c_customer_sk] + InputAdapter + WholeStageCodegen (20) + Sort [ws_bill_customer_sk] + InputAdapter + Exchange [ws_bill_customer_sk] #12 + WholeStageCodegen (19) + Project [ws_bill_customer_sk,d_date] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Filter [ws_bill_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_bill_customer_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [d_date_sk,d_date] #5 + InputAdapter + WholeStageCodegen (22) + Sort [c_customer_sk] + InputAdapter + 
ReusedExchange [c_customer_sk,c_first_name,c_last_name] #6 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87/explain.txt index 7193c4f8c57ef..ed2a97704b2f7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87/explain.txt @@ -1,54 +1,51 @@ == Physical Plan == -* HashAggregate (50) -+- Exchange (49) - +- * HashAggregate (48) - +- * HashAggregate (47) - +- * HashAggregate (46) - +- * BroadcastHashJoin LeftAnti BuildRight (45) - :- * HashAggregate (31) - : +- * HashAggregate (30) - : +- * BroadcastHashJoin LeftAnti BuildRight (29) - : :- * HashAggregate (15) - : : +- Exchange (14) - : : +- * HashAggregate (13) - : : +- * Project (12) - : : +- * BroadcastHashJoin Inner BuildRight (11) - : : :- * Project (6) - : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : :- * Filter (3) - : : : : +- * ColumnarToRow (2) - : : : : +- Scan parquet default.store_sales (1) - : : : +- ReusedExchange (4) - : : +- BroadcastExchange (10) - : : +- * Filter (9) - : : +- * ColumnarToRow (8) - : : +- Scan parquet default.customer (7) - : +- BroadcastExchange (28) - : +- * HashAggregate (27) - : +- Exchange (26) - : +- * HashAggregate (25) - : +- * Project (24) - : +- * BroadcastHashJoin Inner BuildRight (23) - : :- * Project (21) - : : +- * BroadcastHashJoin Inner BuildRight (20) - : : :- * Filter (18) - : : : +- * ColumnarToRow (17) - : : : +- Scan parquet default.catalog_sales (16) - : : +- ReusedExchange (19) - : +- ReusedExchange (22) - +- BroadcastExchange (44) - +- * HashAggregate (43) - +- Exchange (42) - +- * HashAggregate (41) - +- * Project (40) - +- * BroadcastHashJoin Inner BuildRight (39) - :- * Project (37) - : +- * BroadcastHashJoin Inner BuildRight (36) - : :- * Filter (34) - : : +- * ColumnarToRow (33) - : : +- Scan parquet default.web_sales (32) - : +- ReusedExchange (35) - +- ReusedExchange (38) +* HashAggregate (47) ++- Exchange (46) + +- * HashAggregate (45) + +- * Project (44) + +- * BroadcastHashJoin LeftAnti BuildRight (43) + :- * BroadcastHashJoin LeftAnti BuildRight (29) + : :- * HashAggregate (15) + : : +- Exchange (14) + : : +- * HashAggregate (13) + : : +- * Project (12) + : : +- * BroadcastHashJoin Inner BuildRight (11) + : : :- * Project (6) + : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : :- * Filter (3) + : : : : +- * ColumnarToRow (2) + : : : : +- Scan parquet default.store_sales (1) + : : : +- ReusedExchange (4) + : : +- BroadcastExchange (10) + : : +- * Filter (9) + : : +- * ColumnarToRow (8) + : : +- Scan parquet default.customer (7) + : +- BroadcastExchange (28) + : +- * HashAggregate (27) + : +- Exchange (26) + : +- * HashAggregate (25) + : +- * Project (24) + : +- * BroadcastHashJoin Inner BuildRight (23) + : :- * Project (21) + : : +- * BroadcastHashJoin Inner BuildRight (20) + : : :- * Filter (18) + : : : +- * ColumnarToRow (17) + : : : +- Scan parquet default.catalog_sales (16) + : : +- ReusedExchange (19) + : +- ReusedExchange (22) + +- BroadcastExchange (42) + +- * HashAggregate (41) + +- Exchange (40) + +- * HashAggregate (39) + +- * Project (38) + +- * BroadcastHashJoin Inner BuildRight (37) + :- * Project (35) + : +- * BroadcastHashJoin Inner BuildRight (34) + : :- * Filter (32) + : : +- * ColumnarToRow (31) + : : +- Scan parquet default.web_sales (30) + : +- ReusedExchange (33) + +- ReusedExchange (36) (1) Scan 
parquet default.store_sales @@ -66,7 +63,7 @@ Input [2]: [ss_customer_sk#1, ss_sold_date_sk#2] Input [2]: [ss_customer_sk#1, ss_sold_date_sk#2] Condition : isnotnull(ss_customer_sk#1) -(4) ReusedExchange [Reuses operator id: 55] +(4) ReusedExchange [Reuses operator id: 52] Output [2]: [d_date_sk#4, d_date#5] (5) BroadcastHashJoin [codegen id : 3] @@ -138,7 +135,7 @@ Input [2]: [cs_bill_customer_sk#11, cs_sold_date_sk#12] Input [2]: [cs_bill_customer_sk#11, cs_sold_date_sk#12] Condition : isnotnull(cs_bill_customer_sk#11) -(19) ReusedExchange [Reuses operator id: 55] +(19) ReusedExchange [Reuses operator id: 52] Output [2]: [d_date_sk#13, d_date#14] (20) BroadcastHashJoin [codegen id : 6] @@ -189,21 +186,7 @@ Left keys [6]: [coalesce(c_last_name#8, ), isnull(c_last_name#8), coalesce(c_fir Right keys [6]: [coalesce(c_last_name#17, ), isnull(c_last_name#17), coalesce(c_first_name#16, ), isnull(c_first_name#16), coalesce(d_date#14, 1970-01-01), isnull(d_date#14)] Join condition: None -(30) HashAggregate [codegen id : 12] -Input [3]: [c_last_name#8, c_first_name#7, d_date#5] -Keys [3]: [c_last_name#8, c_first_name#7, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#8, c_first_name#7, d_date#5] - -(31) HashAggregate [codegen id : 12] -Input [3]: [c_last_name#8, c_first_name#7, d_date#5] -Keys [3]: [c_last_name#8, c_first_name#7, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#8, c_first_name#7, d_date#5] - -(32) Scan parquet default.web_sales +(30) Scan parquet default.web_sales Output [2]: [ws_bill_customer_sk#20, ws_sold_date_sk#21] Batched: true Location: InMemoryFileIndex [] @@ -211,90 +194,80 @@ PartitionFilters: [isnotnull(ws_sold_date_sk#21), dynamicpruningexpression(ws_so PushedFilters: [IsNotNull(ws_bill_customer_sk)] ReadSchema: struct -(33) ColumnarToRow [codegen id : 10] +(31) ColumnarToRow [codegen id : 10] Input [2]: [ws_bill_customer_sk#20, ws_sold_date_sk#21] -(34) Filter [codegen id : 10] +(32) Filter [codegen id : 10] Input [2]: [ws_bill_customer_sk#20, ws_sold_date_sk#21] Condition : isnotnull(ws_bill_customer_sk#20) -(35) ReusedExchange [Reuses operator id: 55] +(33) ReusedExchange [Reuses operator id: 52] Output [2]: [d_date_sk#22, d_date#23] -(36) BroadcastHashJoin [codegen id : 10] +(34) BroadcastHashJoin [codegen id : 10] Left keys [1]: [ws_sold_date_sk#21] Right keys [1]: [d_date_sk#22] Join condition: None -(37) Project [codegen id : 10] +(35) Project [codegen id : 10] Output [2]: [ws_bill_customer_sk#20, d_date#23] Input [4]: [ws_bill_customer_sk#20, ws_sold_date_sk#21, d_date_sk#22, d_date#23] -(38) ReusedExchange [Reuses operator id: 10] +(36) ReusedExchange [Reuses operator id: 10] Output [3]: [c_customer_sk#24, c_first_name#25, c_last_name#26] -(39) BroadcastHashJoin [codegen id : 10] +(37) BroadcastHashJoin [codegen id : 10] Left keys [1]: [ws_bill_customer_sk#20] Right keys [1]: [c_customer_sk#24] Join condition: None -(40) Project [codegen id : 10] +(38) Project [codegen id : 10] Output [3]: [c_last_name#26, c_first_name#25, d_date#23] Input [5]: [ws_bill_customer_sk#20, d_date#23, c_customer_sk#24, c_first_name#25, c_last_name#26] -(41) HashAggregate [codegen id : 10] +(39) HashAggregate [codegen id : 10] Input [3]: [c_last_name#26, c_first_name#25, d_date#23] Keys [3]: [c_last_name#26, c_first_name#25, d_date#23] Functions: [] Aggregate Attributes: [] Results [3]: [c_last_name#26, c_first_name#25, d_date#23] -(42) Exchange +(40) Exchange Input [3]: [c_last_name#26, c_first_name#25, d_date#23] 
Arguments: hashpartitioning(c_last_name#26, c_first_name#25, d_date#23, 5), ENSURE_REQUIREMENTS, [id=#27] -(43) HashAggregate [codegen id : 11] +(41) HashAggregate [codegen id : 11] Input [3]: [c_last_name#26, c_first_name#25, d_date#23] Keys [3]: [c_last_name#26, c_first_name#25, d_date#23] Functions: [] Aggregate Attributes: [] Results [3]: [c_last_name#26, c_first_name#25, d_date#23] -(44) BroadcastExchange +(42) BroadcastExchange Input [3]: [c_last_name#26, c_first_name#25, d_date#23] Arguments: HashedRelationBroadcastMode(List(coalesce(input[0, string, true], ), isnull(input[0, string, true]), coalesce(input[1, string, true], ), isnull(input[1, string, true]), coalesce(input[2, date, true], 1970-01-01), isnull(input[2, date, true])),false), [id=#28] -(45) BroadcastHashJoin [codegen id : 12] +(43) BroadcastHashJoin [codegen id : 12] Left keys [6]: [coalesce(c_last_name#8, ), isnull(c_last_name#8), coalesce(c_first_name#7, ), isnull(c_first_name#7), coalesce(d_date#5, 1970-01-01), isnull(d_date#5)] Right keys [6]: [coalesce(c_last_name#26, ), isnull(c_last_name#26), coalesce(c_first_name#25, ), isnull(c_first_name#25), coalesce(d_date#23, 1970-01-01), isnull(d_date#23)] Join condition: None -(46) HashAggregate [codegen id : 12] -Input [3]: [c_last_name#8, c_first_name#7, d_date#5] -Keys [3]: [c_last_name#8, c_first_name#7, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results [3]: [c_last_name#8, c_first_name#7, d_date#5] - -(47) HashAggregate [codegen id : 12] +(44) Project [codegen id : 12] +Output: [] Input [3]: [c_last_name#8, c_first_name#7, d_date#5] -Keys [3]: [c_last_name#8, c_first_name#7, d_date#5] -Functions: [] -Aggregate Attributes: [] -Results: [] -(48) HashAggregate [codegen id : 12] +(45) HashAggregate [codegen id : 12] Input: [] Keys: [] Functions [1]: [partial_count(1)] Aggregate Attributes [1]: [count#29] Results [1]: [count#30] -(49) Exchange +(46) Exchange Input [1]: [count#30] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#31] -(50) HashAggregate [codegen id : 13] +(47) HashAggregate [codegen id : 13] Input [1]: [count#30] Keys: [] Functions [1]: [count(1)] @@ -304,37 +277,37 @@ Results [1]: [count(1)#32 AS count(1)#33] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#2 IN dynamicpruning#3 -BroadcastExchange (55) -+- * Project (54) - +- * Filter (53) - +- * ColumnarToRow (52) - +- Scan parquet default.date_dim (51) +BroadcastExchange (52) ++- * Project (51) + +- * Filter (50) + +- * ColumnarToRow (49) + +- Scan parquet default.date_dim (48) -(51) Scan parquet default.date_dim +(48) Scan parquet default.date_dim Output [3]: [d_date_sk#4, d_date#5, d_month_seq#34] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1200), LessThanOrEqual(d_month_seq,1211), IsNotNull(d_date_sk)] ReadSchema: struct -(52) ColumnarToRow [codegen id : 1] +(49) ColumnarToRow [codegen id : 1] Input [3]: [d_date_sk#4, d_date#5, d_month_seq#34] -(53) Filter [codegen id : 1] +(50) Filter [codegen id : 1] Input [3]: [d_date_sk#4, d_date#5, d_month_seq#34] Condition : (((isnotnull(d_month_seq#34) AND (d_month_seq#34 >= 1200)) AND (d_month_seq#34 <= 1211)) AND isnotnull(d_date_sk#4)) -(54) Project [codegen id : 1] +(51) Project [codegen id : 1] Output [2]: [d_date_sk#4, d_date#5] Input [3]: [d_date_sk#4, d_date#5, d_month_seq#34] -(55) BroadcastExchange +(52) BroadcastExchange Input [2]: [d_date_sk#4, d_date#5] Arguments: 
HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#35] Subquery:2 Hosting operator id = 16 Hosting Expression = cs_sold_date_sk#12 IN dynamicpruning#3 -Subquery:3 Hosting operator id = 32 Hosting Expression = ws_sold_date_sk#21 IN dynamicpruning#3 +Subquery:3 Hosting operator id = 30 Hosting Expression = ws_sold_date_sk#21 IN dynamicpruning#3 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87/simplified.txt index 7f96f5657836a..34d46c5671774 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q87/simplified.txt @@ -4,81 +4,78 @@ WholeStageCodegen (13) Exchange #1 WholeStageCodegen (12) HashAggregate [count,count] - HashAggregate [c_last_name,c_first_name,d_date] - HashAggregate [c_last_name,c_first_name,d_date] + Project + BroadcastHashJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] BroadcastHashJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] HashAggregate [c_last_name,c_first_name,d_date] - HashAggregate [c_last_name,c_first_name,d_date] - BroadcastHashJoin [c_last_name,c_first_name,d_date,c_last_name,c_first_name,d_date] - HashAggregate [c_last_name,c_first_name,d_date] - InputAdapter - Exchange [c_last_name,c_first_name,d_date] #2 - WholeStageCodegen (3) - HashAggregate [c_last_name,c_first_name,d_date] - Project [c_last_name,c_first_name,d_date] - BroadcastHashJoin [ss_customer_sk,c_customer_sk] - Project [ss_customer_sk,d_date] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_customer_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #3 - WholeStageCodegen (1) - Project [d_date_sk,d_date] - Filter [d_month_seq,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] - InputAdapter - ReusedExchange [d_date_sk,d_date] #3 - InputAdapter - BroadcastExchange #4 - WholeStageCodegen (2) - Filter [c_customer_sk] - ColumnarToRow - InputAdapter - Scan parquet default.customer [c_customer_sk,c_first_name,c_last_name] - InputAdapter - BroadcastExchange #5 - WholeStageCodegen (7) - HashAggregate [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #2 + WholeStageCodegen (3) + HashAggregate [c_last_name,c_first_name,d_date] + Project [c_last_name,c_first_name,d_date] + BroadcastHashJoin [ss_customer_sk,c_customer_sk] + Project [ss_customer_sk,d_date] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_customer_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #3 + WholeStageCodegen (1) + Project [d_date_sk,d_date] + Filter [d_month_seq,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] + InputAdapter + ReusedExchange [d_date_sk,d_date] #3 InputAdapter - Exchange [c_last_name,c_first_name,d_date] #6 - WholeStageCodegen (6) - HashAggregate [c_last_name,c_first_name,d_date] - Project [c_last_name,c_first_name,d_date] - BroadcastHashJoin [cs_bill_customer_sk,c_customer_sk] - Project [cs_bill_customer_sk,d_date] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Filter [cs_bill_customer_sk] - ColumnarToRow - 
InputAdapter - Scan parquet default.catalog_sales [cs_bill_customer_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #1 - InputAdapter - ReusedExchange [d_date_sk,d_date] #3 - InputAdapter - ReusedExchange [c_customer_sk,c_first_name,c_last_name] #4 + BroadcastExchange #4 + WholeStageCodegen (2) + Filter [c_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.customer [c_customer_sk,c_first_name,c_last_name] InputAdapter - BroadcastExchange #7 - WholeStageCodegen (11) + BroadcastExchange #5 + WholeStageCodegen (7) HashAggregate [c_last_name,c_first_name,d_date] InputAdapter - Exchange [c_last_name,c_first_name,d_date] #8 - WholeStageCodegen (10) + Exchange [c_last_name,c_first_name,d_date] #6 + WholeStageCodegen (6) HashAggregate [c_last_name,c_first_name,d_date] Project [c_last_name,c_first_name,d_date] - BroadcastHashJoin [ws_bill_customer_sk,c_customer_sk] - Project [ws_bill_customer_sk,d_date] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Filter [ws_bill_customer_sk] + BroadcastHashJoin [cs_bill_customer_sk,c_customer_sk] + Project [cs_bill_customer_sk,d_date] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Filter [cs_bill_customer_sk] ColumnarToRow InputAdapter - Scan parquet default.web_sales [ws_bill_customer_sk,ws_sold_date_sk] + Scan parquet default.catalog_sales [cs_bill_customer_sk,cs_sold_date_sk] ReusedSubquery [d_date_sk] #1 InputAdapter ReusedExchange [d_date_sk,d_date] #3 InputAdapter ReusedExchange [c_customer_sk,c_first_name,c_last_name] #4 + InputAdapter + BroadcastExchange #7 + WholeStageCodegen (11) + HashAggregate [c_last_name,c_first_name,d_date] + InputAdapter + Exchange [c_last_name,c_first_name,d_date] #8 + WholeStageCodegen (10) + HashAggregate [c_last_name,c_first_name,d_date] + Project [c_last_name,c_first_name,d_date] + BroadcastHashJoin [ws_bill_customer_sk,c_customer_sk] + Project [ws_bill_customer_sk,d_date] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Filter [ws_bill_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_bill_customer_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [d_date_sk,d_date] #3 + InputAdapter + ReusedExchange [c_customer_sk,c_first_name,c_last_name] #4 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q89.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q89.sf100/explain.txt index 9c798856baa66..6325bd574530a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q89.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q89.sf100/explain.txt @@ -141,7 +141,7 @@ Arguments: [avg(_w0#22) windowspecdefinition(i_category#4, i_brand#2, s_store_na (25) Filter [codegen id : 7] Input [9]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, _w0#22, avg_monthly_sales#24] -Condition : (isnotnull(avg_monthly_sales#24) AND (NOT (avg_monthly_sales#24 = 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_monthly_sales#24) AND (NOT (avg_monthly_sales#24 = 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as 
decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (26) Project [codegen id : 7] Output [8]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] @@ -149,7 +149,7 @@ Input [9]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name# (27) TakeOrderedAndProject Input [8]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, s_store_name#14 ASC NULLS FIRST], [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, s_store_name#14 ASC NULLS FIRST], [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] ===== Subqueries ===== diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q89/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q89/explain.txt index 4c6124960bb0d..770ab84503645 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q89/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q89/explain.txt @@ -141,7 +141,7 @@ Arguments: [avg(_w0#22) windowspecdefinition(i_category#4, i_brand#2, s_store_na (25) Filter [codegen id : 7] Input [9]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, _w0#22, avg_monthly_sales#24] -Condition : (isnotnull(avg_monthly_sales#24) AND (NOT (avg_monthly_sales#24 = 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000))) +Condition : (isnotnull(avg_monthly_sales#24) AND (NOT (avg_monthly_sales#24 = 0.000000) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000))) (26) Project [codegen id : 7] Output [8]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] @@ -149,7 +149,7 @@ Input [9]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name# (27) TakeOrderedAndProject Input [8]: [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, s_store_name#14 ASC NULLS FIRST], [i_category#4, i_class#3, i_brand#2, s_store_name#14, 
s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, s_store_name#14 ASC NULLS FIRST], [i_category#4, i_class#3, i_brand#2, s_store_name#14, s_company_name#15, d_moy#12, sum_sales#21, avg_monthly_sales#24] ===== Subqueries ===== diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q90.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q90.sf100/explain.txt index 39b6534100574..095c3d531a509 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q90.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q90.sf100/explain.txt @@ -280,6 +280,6 @@ Arguments: IdentityBroadcastMode, [id=#33] Join condition: None (51) Project [codegen id : 10] -Output [1]: [CheckOverflow((promote_precision(cast(amc#18 as decimal(15,4))) / promote_precision(cast(pmc#32 as decimal(15,4)))), DecimalType(35,20), true) AS am_pm_ratio#34] +Output [1]: [CheckOverflow((promote_precision(cast(amc#18 as decimal(15,4))) / promote_precision(cast(pmc#32 as decimal(15,4)))), DecimalType(35,20)) AS am_pm_ratio#34] Input [2]: [amc#18, pmc#32] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q90/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q90/explain.txt index 80ab6fd9d8a3f..e9884d694852d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q90/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q90/explain.txt @@ -280,6 +280,6 @@ Arguments: IdentityBroadcastMode, [id=#33] Join condition: None (51) Project [codegen id : 10] -Output [1]: [CheckOverflow((promote_precision(cast(amc#18 as decimal(15,4))) / promote_precision(cast(pmc#32 as decimal(15,4)))), DecimalType(35,20), true) AS am_pm_ratio#34] +Output [1]: [CheckOverflow((promote_precision(cast(amc#18 as decimal(15,4))) / promote_precision(cast(pmc#32 as decimal(15,4)))), DecimalType(35,20)) AS am_pm_ratio#34] Input [2]: [amc#18, pmc#32] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q92.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q92.sf100/explain.txt index d13b0f1c9bb91..71aa2bb603946 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q92.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q92.sf100/explain.txt @@ -95,7 +95,7 @@ Input [3]: [ws_item_sk#4, sum#11, count#12] Keys [1]: [ws_item_sk#4] Functions [1]: [avg(UnscaledValue(ws_ext_discount_amt#5))] Aggregate Attributes [1]: [avg(UnscaledValue(ws_ext_discount_amt#5))#14] -Results [2]: [CheckOverflow((1.300000 * promote_precision(cast((avg(UnscaledValue(ws_ext_discount_amt#5))#14 / 100.0) as decimal(11,6)))), DecimalType(14,7), true) AS (1.3 * avg(ws_ext_discount_amt))#15, ws_item_sk#4] +Results [2]: [CheckOverflow((1.300000 * promote_precision(cast((avg(UnscaledValue(ws_ext_discount_amt#5))#14 / 100.0) as decimal(11,6)))), DecimalType(14,7)) AS (1.3 * avg(ws_ext_discount_amt))#15, ws_item_sk#4] (15) Filter Input [2]: [(1.3 * avg(ws_ext_discount_amt))#15, ws_item_sk#4] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q92/explain.txt 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q92/explain.txt index 72c206a372644..bec857eb2489a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q92/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q92/explain.txt @@ -119,7 +119,7 @@ Input [3]: [ws_item_sk#8, sum#14, count#15] Keys [1]: [ws_item_sk#8] Functions [1]: [avg(UnscaledValue(ws_ext_discount_amt#9))] Aggregate Attributes [1]: [avg(UnscaledValue(ws_ext_discount_amt#9))#17] -Results [2]: [CheckOverflow((1.300000 * promote_precision(cast((avg(UnscaledValue(ws_ext_discount_amt#9))#17 / 100.0) as decimal(11,6)))), DecimalType(14,7), true) AS (1.3 * avg(ws_ext_discount_amt))#18, ws_item_sk#8] +Results [2]: [CheckOverflow((1.300000 * promote_precision(cast((avg(UnscaledValue(ws_ext_discount_amt#9))#17 / 100.0) as decimal(11,6)))), DecimalType(14,7)) AS (1.3 * avg(ws_ext_discount_amt))#18, ws_item_sk#8] (20) Filter [codegen id : 4] Input [2]: [(1.3 * avg(ws_ext_discount_amt))#18, ws_item_sk#8] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q93.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q93.sf100/explain.txt index 01b7b7f5e20c8..3f6b5ffb48a67 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q93.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q93.sf100/explain.txt @@ -109,7 +109,7 @@ Right keys [2]: [ss_item_sk#10, ss_ticket_number#12] Join condition: None (20) Project [codegen id : 6] -Output [2]: [ss_customer_sk#11, CASE WHEN isnotnull(sr_return_quantity#4) THEN CheckOverflow((promote_precision(cast(cast((ss_quantity#13 - sr_return_quantity#4) as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#14 as decimal(12,2)))), DecimalType(18,2), true) ELSE CheckOverflow((promote_precision(cast(cast(ss_quantity#13 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#14 as decimal(12,2)))), DecimalType(18,2), true) END AS act_sales#17] +Output [2]: [ss_customer_sk#11, CASE WHEN isnotnull(sr_return_quantity#4) THEN CheckOverflow((promote_precision(cast((ss_quantity#13 - sr_return_quantity#4) as decimal(12,2))) * promote_precision(cast(ss_sales_price#14 as decimal(12,2)))), DecimalType(18,2)) ELSE CheckOverflow((promote_precision(cast(ss_quantity#13 as decimal(12,2))) * promote_precision(cast(ss_sales_price#14 as decimal(12,2)))), DecimalType(18,2)) END AS act_sales#17] Input [8]: [sr_item_sk#1, sr_ticket_number#3, sr_return_quantity#4, ss_item_sk#10, ss_customer_sk#11, ss_ticket_number#12, ss_quantity#13, ss_sales_price#14] (21) HashAggregate [codegen id : 6] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q93/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q93/explain.txt index 54b9ae752c7a0..11f69606ece91 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q93/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q93/explain.txt @@ -109,7 +109,7 @@ Right keys [1]: [r_reason_sk#14] Join condition: None (20) Project [codegen id : 6] -Output [2]: [ss_customer_sk#2, CASE WHEN isnotnull(sr_return_quantity#11) THEN CheckOverflow((promote_precision(cast(cast((ss_quantity#4 - sr_return_quantity#11) as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#5 as decimal(12,2)))), DecimalType(18,2), true) ELSE 
CheckOverflow((promote_precision(cast(cast(ss_quantity#4 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_sales_price#5 as decimal(12,2)))), DecimalType(18,2), true) END AS act_sales#17] +Output [2]: [ss_customer_sk#2, CASE WHEN isnotnull(sr_return_quantity#11) THEN CheckOverflow((promote_precision(cast((ss_quantity#4 - sr_return_quantity#11) as decimal(12,2))) * promote_precision(cast(ss_sales_price#5 as decimal(12,2)))), DecimalType(18,2)) ELSE CheckOverflow((promote_precision(cast(ss_quantity#4 as decimal(12,2))) * promote_precision(cast(ss_sales_price#5 as decimal(12,2)))), DecimalType(18,2)) END AS act_sales#17] Input [6]: [ss_customer_sk#2, ss_quantity#4, ss_sales_price#5, sr_reason_sk#9, sr_return_quantity#11, r_reason_sk#14] (21) HashAggregate [codegen id : 6] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q98.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q98.sf100/explain.txt index 310321f5cf372..b3528e4b6881b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q98.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q98.sf100/explain.txt @@ -123,7 +123,7 @@ Input [8]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrev Arguments: [sum(_w1#20) windowspecdefinition(i_class#10, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#22], [i_class#10] (22) Project [codegen id : 9] -Output [7]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17), true) AS revenueratio#23, i_item_id#7] +Output [7]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17)) AS revenueratio#23, i_item_id#7] Input [9]: [i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, _w0#19, _w1#20, i_item_id#7, _we0#22] (23) Exchange diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q98/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q98/explain.txt index 95f856b398707..ec1192af4d398 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q98/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q98/explain.txt @@ -108,7 +108,7 @@ Input [8]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemreve Arguments: [sum(_w1#19) windowspecdefinition(i_class#9, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#21], [i_class#9] (19) Project [codegen id : 6] -Output [7]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17), true) AS revenueratio#22, i_item_id#6] +Output [7]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#21)), 
DecimalType(38,17)) AS revenueratio#22, i_item_id#6] Input [9]: [i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, _w0#18, _w1#19, i_item_id#6, _we0#21] (20) Exchange diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11.sf100/explain.txt index 732f510b80d1b..7591e3bdb30c7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11.sf100/explain.txt @@ -149,7 +149,7 @@ Input [12]: [ss_customer_sk#1, ss_ext_discount_amt#2, ss_ext_list_price#3, d_yea (16) HashAggregate [codegen id : 6] Input [10]: [c_customer_id#10, c_first_name#11, c_last_name#12, c_preferred_cust_flag#13, c_birth_country#14, c_login#15, c_email_address#16, ss_ext_discount_amt#2, ss_ext_list_price#3, d_year#7] Keys [8]: [c_customer_id#10, c_first_name#11, c_last_name#12, d_year#7, c_preferred_cust_flag#13, c_birth_country#14, c_login#15, c_email_address#16] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#18] Results [9]: [c_customer_id#10, c_first_name#11, c_last_name#12, d_year#7, c_preferred_cust_flag#13, c_birth_country#14, c_login#15, c_email_address#16, sum#19] @@ -160,9 +160,9 @@ Arguments: hashpartitioning(c_customer_id#10, c_first_name#11, c_last_name#12, d (18) HashAggregate [codegen id : 7] Input [9]: [c_customer_id#10, c_first_name#11, c_last_name#12, d_year#7, c_preferred_cust_flag#13, c_birth_country#14, c_login#15, c_email_address#16, sum#19] Keys [8]: [c_customer_id#10, c_first_name#11, c_last_name#12, d_year#7, c_preferred_cust_flag#13, c_birth_country#14, c_login#15, c_email_address#16] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2), true)))#21] -Results [2]: [c_customer_id#10 AS customer_id#22, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2), true)))#21,18,2) AS year_total#23] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2))))#21] +Results [2]: [c_customer_id#10 AS customer_id#22, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#3 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#2 as decimal(8,2)))), DecimalType(8,2))))#21,18,2) AS year_total#23] 
(19) Filter [codegen id : 7] Input [2]: [customer_id#22, year_total#23] @@ -230,7 +230,7 @@ Input [12]: [ss_customer_sk#25, ss_ext_discount_amt#26, ss_ext_list_price#27, d_ (34) HashAggregate [codegen id : 14] Input [10]: [c_customer_id#34, c_first_name#35, c_last_name#36, c_preferred_cust_flag#37, c_birth_country#38, c_login#39, c_email_address#40, ss_ext_discount_amt#26, ss_ext_list_price#27, d_year#31] Keys [8]: [c_customer_id#34, c_first_name#35, c_last_name#36, d_year#31, c_preferred_cust_flag#37, c_birth_country#38, c_login#39, c_email_address#40] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#41] Results [9]: [c_customer_id#34, c_first_name#35, c_last_name#36, d_year#31, c_preferred_cust_flag#37, c_birth_country#38, c_login#39, c_email_address#40, sum#42] @@ -241,9 +241,9 @@ Arguments: hashpartitioning(c_customer_id#34, c_first_name#35, c_last_name#36, d (36) HashAggregate [codegen id : 15] Input [9]: [c_customer_id#34, c_first_name#35, c_last_name#36, d_year#31, c_preferred_cust_flag#37, c_birth_country#38, c_login#39, c_email_address#40, sum#42] Keys [8]: [c_customer_id#34, c_first_name#35, c_last_name#36, d_year#31, c_preferred_cust_flag#37, c_birth_country#38, c_login#39, c_email_address#40] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2), true)))#21] -Results [5]: [c_customer_id#34 AS customer_id#44, c_first_name#35 AS customer_first_name#45, c_last_name#36 AS customer_last_name#46, c_email_address#40 AS customer_email_address#47, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2), true)))#21,18,2) AS year_total#48] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2))))#21] +Results [5]: [c_customer_id#34 AS customer_id#44, c_first_name#35 AS customer_first_name#45, c_last_name#36 AS customer_last_name#46, c_email_address#40 AS customer_email_address#47, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#27 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#26 as decimal(8,2)))), DecimalType(8,2))))#21,18,2) AS year_total#48] (37) Exchange Input [5]: [customer_id#44, customer_first_name#45, customer_last_name#46, customer_email_address#47, year_total#48] @@ -312,7 +312,7 @@ Input [12]: [ws_bill_customer_sk#50, ws_ext_discount_amt#51, ws_ext_list_price#5 (52) HashAggregate [codegen id 
: 23] Input [10]: [c_customer_id#58, c_first_name#59, c_last_name#60, c_preferred_cust_flag#61, c_birth_country#62, c_login#63, c_email_address#64, ws_ext_discount_amt#51, ws_ext_list_price#52, d_year#55] Keys [8]: [c_customer_id#58, c_first_name#59, c_last_name#60, c_preferred_cust_flag#61, c_birth_country#62, c_login#63, c_email_address#64, d_year#55] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#52 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#51 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#52 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#51 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#65] Results [9]: [c_customer_id#58, c_first_name#59, c_last_name#60, c_preferred_cust_flag#61, c_birth_country#62, c_login#63, c_email_address#64, d_year#55, sum#66] @@ -323,9 +323,9 @@ Arguments: hashpartitioning(c_customer_id#58, c_first_name#59, c_last_name#60, c (54) HashAggregate [codegen id : 24] Input [9]: [c_customer_id#58, c_first_name#59, c_last_name#60, c_preferred_cust_flag#61, c_birth_country#62, c_login#63, c_email_address#64, d_year#55, sum#66] Keys [8]: [c_customer_id#58, c_first_name#59, c_last_name#60, c_preferred_cust_flag#61, c_birth_country#62, c_login#63, c_email_address#64, d_year#55] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#52 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#51 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#52 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#51 as decimal(8,2)))), DecimalType(8,2), true)))#68] -Results [2]: [c_customer_id#58 AS customer_id#69, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#52 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#51 as decimal(8,2)))), DecimalType(8,2), true)))#68,18,2) AS year_total#70] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#52 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#51 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#52 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#51 as decimal(8,2)))), DecimalType(8,2))))#68] +Results [2]: [c_customer_id#58 AS customer_id#69, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#52 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#51 as decimal(8,2)))), DecimalType(8,2))))#68,18,2) AS year_total#70] (55) Filter [codegen id : 24] Input [2]: [customer_id#69, year_total#70] @@ -402,7 +402,7 @@ Input [12]: [ws_bill_customer_sk#72, ws_ext_discount_amt#73, ws_ext_list_price#7 (72) HashAggregate [codegen id : 32] Input [10]: [c_customer_id#80, c_first_name#81, c_last_name#82, c_preferred_cust_flag#83, c_birth_country#84, c_login#85, c_email_address#86, ws_ext_discount_amt#73, ws_ext_list_price#74, d_year#77] Keys [8]: [c_customer_id#80, c_first_name#81, c_last_name#82, c_preferred_cust_flag#83, c_birth_country#84, c_login#85, c_email_address#86, d_year#77] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#74 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#73 as 
decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#74 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#73 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#87] Results [9]: [c_customer_id#80, c_first_name#81, c_last_name#82, c_preferred_cust_flag#83, c_birth_country#84, c_login#85, c_email_address#86, d_year#77, sum#88] @@ -413,9 +413,9 @@ Arguments: hashpartitioning(c_customer_id#80, c_first_name#81, c_last_name#82, c (74) HashAggregate [codegen id : 33] Input [9]: [c_customer_id#80, c_first_name#81, c_last_name#82, c_preferred_cust_flag#83, c_birth_country#84, c_login#85, c_email_address#86, d_year#77, sum#88] Keys [8]: [c_customer_id#80, c_first_name#81, c_last_name#82, c_preferred_cust_flag#83, c_birth_country#84, c_login#85, c_email_address#86, d_year#77] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#74 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#73 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#74 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#73 as decimal(8,2)))), DecimalType(8,2), true)))#68] -Results [2]: [c_customer_id#80 AS customer_id#90, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#74 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#73 as decimal(8,2)))), DecimalType(8,2), true)))#68,18,2) AS year_total#91] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#74 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#73 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#74 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#73 as decimal(8,2)))), DecimalType(8,2))))#68] +Results [2]: [c_customer_id#80 AS customer_id#90, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#74 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#73 as decimal(8,2)))), DecimalType(8,2))))#68,18,2) AS year_total#91] (75) Exchange Input [2]: [customer_id#90, year_total#91] @@ -428,7 +428,7 @@ Arguments: [customer_id#90 ASC NULLS FIRST], false, 0 (77) SortMergeJoin [codegen id : 35] Left keys [1]: [customer_id#22] Right keys [1]: [customer_id#90] -Join condition: (CASE WHEN (year_total#70 > 0.00) THEN CheckOverflow((promote_precision(year_total#91) / promote_precision(year_total#70)), DecimalType(38,20), true) ELSE 0E-20 END > CASE WHEN (year_total#23 > 0.00) THEN CheckOverflow((promote_precision(year_total#48) / promote_precision(year_total#23)), DecimalType(38,20), true) ELSE 0E-20 END) +Join condition: (CASE WHEN (year_total#70 > 0.00) THEN CheckOverflow((promote_precision(year_total#91) / promote_precision(year_total#70)), DecimalType(38,20)) ELSE 0E-20 END > CASE WHEN (year_total#23 > 0.00) THEN CheckOverflow((promote_precision(year_total#48) / promote_precision(year_total#23)), DecimalType(38,20)) ELSE 0E-20 END) (78) Project [codegen id : 35] Output [4]: [customer_id#44, customer_first_name#45, customer_last_name#46, customer_email_address#47] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11.sf100/simplified.txt 
index cc47c3516b497..a97e1ed828a9c 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11.sf100/simplified.txt @@ -16,7 +16,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom Exchange [customer_id] #1 WholeStageCodegen (7) Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address] #2 WholeStageCodegen (6) @@ -60,7 +60,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom InputAdapter Exchange [customer_id] #6 WholeStageCodegen (15) - HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,customer_first_name,customer_last_name,customer_email_address,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,customer_first_name,customer_last_name,customer_email_address,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address] #7 WholeStageCodegen (14) @@ -100,7 +100,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom Exchange [customer_id] #10 WholeStageCodegen (24) Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #11 WholeStageCodegen (23) @@ -133,7 +133,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom InputAdapter Exchange [customer_id] #13 WholeStageCodegen (33) - HashAggregate 
[c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #14 WholeStageCodegen (32) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11/explain.txt index cb7fe2568123f..69d3f4ac97247 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11/explain.txt @@ -129,7 +129,7 @@ Input [12]: [c_customer_id#2, c_first_name#3, c_last_name#4, c_preferred_cust_fl (13) HashAggregate [codegen id : 3] Input [10]: [c_customer_id#2, c_first_name#3, c_last_name#4, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, ss_ext_discount_amt#10, ss_ext_list_price#11, d_year#16] Keys [8]: [c_customer_id#2, c_first_name#3, c_last_name#4, d_year#16, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#17] Results [9]: [c_customer_id#2, c_first_name#3, c_last_name#4, d_year#16, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, sum#18] @@ -140,9 +140,9 @@ Arguments: hashpartitioning(c_customer_id#2, c_first_name#3, c_last_name#4, d_ye (15) HashAggregate [codegen id : 16] Input [9]: [c_customer_id#2, c_first_name#3, c_last_name#4, d_year#16, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8, sum#18] Keys [8]: [c_customer_id#2, c_first_name#3, c_last_name#4, d_year#16, c_preferred_cust_flag#5, c_birth_country#6, c_login#7, c_email_address#8] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2), true)))#20] -Results [2]: [c_customer_id#2 AS customer_id#21, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2), true)))#20,18,2) AS year_total#22] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as 
decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2))))#20] +Results [2]: [c_customer_id#2 AS customer_id#21, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#11 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#10 as decimal(8,2)))), DecimalType(8,2))))#20,18,2) AS year_total#22] (16) Filter [codegen id : 16] Input [2]: [customer_id#21, year_total#22] @@ -205,7 +205,7 @@ Input [12]: [c_customer_id#24, c_first_name#25, c_last_name#26, c_preferred_cust (29) HashAggregate [codegen id : 6] Input [10]: [c_customer_id#24, c_first_name#25, c_last_name#26, c_preferred_cust_flag#27, c_birth_country#28, c_login#29, c_email_address#30, ss_ext_discount_amt#32, ss_ext_list_price#33, d_year#38] Keys [8]: [c_customer_id#24, c_first_name#25, c_last_name#26, d_year#38, c_preferred_cust_flag#27, c_birth_country#28, c_login#29, c_email_address#30] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#39] Results [9]: [c_customer_id#24, c_first_name#25, c_last_name#26, d_year#38, c_preferred_cust_flag#27, c_birth_country#28, c_login#29, c_email_address#30, sum#40] @@ -216,9 +216,9 @@ Arguments: hashpartitioning(c_customer_id#24, c_first_name#25, c_last_name#26, d (31) HashAggregate [codegen id : 7] Input [9]: [c_customer_id#24, c_first_name#25, c_last_name#26, d_year#38, c_preferred_cust_flag#27, c_birth_country#28, c_login#29, c_email_address#30, sum#40] Keys [8]: [c_customer_id#24, c_first_name#25, c_last_name#26, d_year#38, c_preferred_cust_flag#27, c_birth_country#28, c_login#29, c_email_address#30] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2), true)))#20] -Results [5]: [c_customer_id#24 AS customer_id#42, c_first_name#25 AS customer_first_name#43, c_last_name#26 AS customer_last_name#44, c_email_address#30 AS customer_email_address#45, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2), true)))#20,18,2) AS year_total#46] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2))))#20] +Results [5]: [c_customer_id#24 AS customer_id#42, c_first_name#25 AS 
customer_first_name#43, c_last_name#26 AS customer_last_name#44, c_email_address#30 AS customer_email_address#45, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price#33 as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt#32 as decimal(8,2)))), DecimalType(8,2))))#20,18,2) AS year_total#46] (32) BroadcastExchange Input [5]: [customer_id#42, customer_first_name#43, customer_last_name#44, customer_email_address#45, year_total#46] @@ -286,7 +286,7 @@ Input [12]: [c_customer_id#49, c_first_name#50, c_last_name#51, c_preferred_cust (46) HashAggregate [codegen id : 10] Input [10]: [c_customer_id#49, c_first_name#50, c_last_name#51, c_preferred_cust_flag#52, c_birth_country#53, c_login#54, c_email_address#55, ws_ext_discount_amt#57, ws_ext_list_price#58, d_year#62] Keys [8]: [c_customer_id#49, c_first_name#50, c_last_name#51, c_preferred_cust_flag#52, c_birth_country#53, c_login#54, c_email_address#55, d_year#62] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#58 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#57 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#58 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#57 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#63] Results [9]: [c_customer_id#49, c_first_name#50, c_last_name#51, c_preferred_cust_flag#52, c_birth_country#53, c_login#54, c_email_address#55, d_year#62, sum#64] @@ -297,9 +297,9 @@ Arguments: hashpartitioning(c_customer_id#49, c_first_name#50, c_last_name#51, c (48) HashAggregate [codegen id : 11] Input [9]: [c_customer_id#49, c_first_name#50, c_last_name#51, c_preferred_cust_flag#52, c_birth_country#53, c_login#54, c_email_address#55, d_year#62, sum#64] Keys [8]: [c_customer_id#49, c_first_name#50, c_last_name#51, c_preferred_cust_flag#52, c_birth_country#53, c_login#54, c_email_address#55, d_year#62] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#58 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#57 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#58 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#57 as decimal(8,2)))), DecimalType(8,2), true)))#66] -Results [2]: [c_customer_id#49 AS customer_id#67, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#58 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#57 as decimal(8,2)))), DecimalType(8,2), true)))#66,18,2) AS year_total#68] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#58 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#57 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#58 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#57 as decimal(8,2)))), DecimalType(8,2))))#66] +Results [2]: [c_customer_id#49 AS customer_id#67, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#58 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#57 as decimal(8,2)))), DecimalType(8,2))))#66,18,2) AS year_total#68] (49) Filter [codegen id : 11] Input [2]: [customer_id#67, year_total#68] @@ -375,7 +375,7 @@ Input [12]: [c_customer_id#71, 
c_first_name#72, c_last_name#73, c_preferred_cust (65) HashAggregate [codegen id : 14] Input [10]: [c_customer_id#71, c_first_name#72, c_last_name#73, c_preferred_cust_flag#74, c_birth_country#75, c_login#76, c_email_address#77, ws_ext_discount_amt#79, ws_ext_list_price#80, d_year#84] Keys [8]: [c_customer_id#71, c_first_name#72, c_last_name#73, c_preferred_cust_flag#74, c_birth_country#75, c_login#76, c_email_address#77, d_year#84] -Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#80 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#79 as decimal(8,2)))), DecimalType(8,2), true)))] +Functions [1]: [partial_sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#80 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#79 as decimal(8,2)))), DecimalType(8,2))))] Aggregate Attributes [1]: [sum#85] Results [9]: [c_customer_id#71, c_first_name#72, c_last_name#73, c_preferred_cust_flag#74, c_birth_country#75, c_login#76, c_email_address#77, d_year#84, sum#86] @@ -386,9 +386,9 @@ Arguments: hashpartitioning(c_customer_id#71, c_first_name#72, c_last_name#73, c (67) HashAggregate [codegen id : 15] Input [9]: [c_customer_id#71, c_first_name#72, c_last_name#73, c_preferred_cust_flag#74, c_birth_country#75, c_login#76, c_email_address#77, d_year#84, sum#86] Keys [8]: [c_customer_id#71, c_first_name#72, c_last_name#73, c_preferred_cust_flag#74, c_birth_country#75, c_login#76, c_email_address#77, d_year#84] -Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#80 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#79 as decimal(8,2)))), DecimalType(8,2), true)))] -Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#80 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#79 as decimal(8,2)))), DecimalType(8,2), true)))#66] -Results [2]: [c_customer_id#71 AS customer_id#88, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#80 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#79 as decimal(8,2)))), DecimalType(8,2), true)))#66,18,2) AS year_total#89] +Functions [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#80 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#79 as decimal(8,2)))), DecimalType(8,2))))] +Aggregate Attributes [1]: [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#80 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#79 as decimal(8,2)))), DecimalType(8,2))))#66] +Results [2]: [c_customer_id#71 AS customer_id#88, MakeDecimal(sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price#80 as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt#79 as decimal(8,2)))), DecimalType(8,2))))#66,18,2) AS year_total#89] (68) BroadcastExchange Input [2]: [customer_id#88, year_total#89] @@ -397,7 +397,7 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (69) BroadcastHashJoin [codegen id : 16] Left keys [1]: [customer_id#21] Right keys [1]: [customer_id#88] -Join condition: (CASE WHEN (year_total#68 > 0.00) THEN CheckOverflow((promote_precision(year_total#89) / promote_precision(year_total#68)), DecimalType(38,20), true) ELSE 0E-20 END > CASE WHEN (year_total#22 > 0.00) THEN CheckOverflow((promote_precision(year_total#46) / promote_precision(year_total#22)), DecimalType(38,20), true) ELSE 0E-20 END) +Join condition: (CASE WHEN 
(year_total#68 > 0.00) THEN CheckOverflow((promote_precision(year_total#89) / promote_precision(year_total#68)), DecimalType(38,20)) ELSE 0E-20 END > CASE WHEN (year_total#22 > 0.00) THEN CheckOverflow((promote_precision(year_total#46) / promote_precision(year_total#22)), DecimalType(38,20)) ELSE 0E-20 END) (70) Project [codegen id : 16] Output [4]: [customer_id#42, customer_first_name#43, customer_last_name#44, customer_email_address#45] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11/simplified.txt index 5fc4dacd55273..91974a295b774 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q11/simplified.txt @@ -6,7 +6,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom BroadcastHashJoin [customer_id,customer_id] BroadcastHashJoin [customer_id,customer_id] Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address] #1 WholeStageCodegen (3) @@ -38,7 +38,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom InputAdapter BroadcastExchange #4 WholeStageCodegen (7) - HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,customer_first_name,customer_last_name,customer_email_address,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ss_ext_list_price as decimal(8,2))) - promote_precision(cast(ss_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,customer_first_name,customer_last_name,customer_email_address,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,d_year,c_preferred_cust_flag,c_birth_country,c_login,c_email_address] #5 WholeStageCodegen (6) @@ -71,7 +71,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom BroadcastExchange #8 WholeStageCodegen (11) Filter [year_total] - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate 
[c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #9 WholeStageCodegen (10) @@ -97,7 +97,7 @@ TakeOrderedAndProject [customer_id,customer_first_name,customer_last_name,custom InputAdapter BroadcastExchange #11 WholeStageCodegen (15) - HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2), true))),customer_id,year_total,sum] + HashAggregate [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year,sum] [sum(UnscaledValue(CheckOverflow((promote_precision(cast(ws_ext_list_price as decimal(8,2))) - promote_precision(cast(ws_ext_discount_amt as decimal(8,2)))), DecimalType(8,2)))),customer_id,year_total,sum] InputAdapter Exchange [c_customer_id,c_first_name,c_last_name,c_preferred_cust_flag,c_birth_country,c_login,c_email_address,d_year] #12 WholeStageCodegen (14) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q12.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q12.sf100/explain.txt index 40a9cea61aecc..40793508f4786 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q12.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q12.sf100/explain.txt @@ -121,7 +121,7 @@ Input [8]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_pri Arguments: [sum(_w1#20) windowspecdefinition(i_class#10, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#22], [i_class#10] (22) Project [codegen id : 9] -Output [7]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17), true) AS revenueratio#23] +Output [7]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17)) AS revenueratio#23] Input [9]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, _w0#19, _w1#20, _we0#22] (23) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q12/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q12/explain.txt index 479a27f8fee47..02f8baa5a0b81 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q12/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q12/explain.txt @@ -106,7 +106,7 @@ Input [8]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_pric Arguments: [sum(_w1#19) windowspecdefinition(i_class#9, 
specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#21], [i_class#9] (19) Project [codegen id : 6] -Output [7]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17), true) AS revenueratio#22] +Output [7]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17)) AS revenueratio#22] Input [9]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, _w0#18, _w1#19, _we0#21] (20) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt index 5c3fbb7946f1f..92b80b4085c67 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt @@ -1,106 +1,103 @@ == Physical Plan == -TakeOrderedAndProject (102) -+- * BroadcastHashJoin Inner BuildRight (101) - :- * Filter (81) - : +- * HashAggregate (80) - : +- Exchange (79) - : +- * HashAggregate (78) - : +- * Project (77) - : +- * BroadcastHashJoin Inner BuildRight (76) - : :- * Project (66) - : : +- * BroadcastHashJoin Inner BuildRight (65) - : : :- * SortMergeJoin LeftSemi (63) +TakeOrderedAndProject (99) ++- * BroadcastHashJoin Inner BuildRight (98) + :- * Filter (78) + : +- * HashAggregate (77) + : +- Exchange (76) + : +- * HashAggregate (75) + : +- * Project (74) + : +- * BroadcastHashJoin Inner BuildRight (73) + : :- * Project (63) + : : +- * BroadcastHashJoin Inner BuildRight (62) + : : :- * SortMergeJoin LeftSemi (60) : : : :- * Sort (5) : : : : +- Exchange (4) : : : : +- * Filter (3) : : : : +- * ColumnarToRow (2) : : : : +- Scan parquet default.store_sales (1) - : : : +- * Sort (62) - : : : +- Exchange (61) - : : : +- * Project (60) - : : : +- * BroadcastHashJoin Inner BuildRight (59) + : : : +- * Sort (59) + : : : +- Exchange (58) + : : : +- * Project (57) + : : : +- * BroadcastHashJoin Inner BuildRight (56) : : : :- * Filter (8) : : : : +- * ColumnarToRow (7) : : : : +- Scan parquet default.item (6) - : : : +- BroadcastExchange (58) - : : : +- * HashAggregate (57) - : : : +- Exchange (56) - : : : +- * HashAggregate (55) - : : : +- * SortMergeJoin LeftSemi (54) - : : : :- * Sort (42) - : : : : +- Exchange (41) - : : : : +- * HashAggregate (40) - : : : : +- Exchange (39) - : : : : +- * HashAggregate (38) - : : : : +- * Project (37) - : : : : +- * BroadcastHashJoin Inner BuildRight (36) - : : : : :- * Project (14) - : : : : : +- * BroadcastHashJoin Inner BuildRight (13) - : : : : : :- * Filter (11) - : : : : : : +- * ColumnarToRow (10) - : : : : : : +- Scan parquet default.store_sales (9) - : : : : : +- ReusedExchange (12) - : : : : +- BroadcastExchange (35) - : : : : +- * SortMergeJoin LeftSemi (34) - : : : : :- * Sort (19) - : : : : : +- Exchange (18) - : : : : : +- * Filter (17) - : : : : : +- * ColumnarToRow (16) - : : : : : +- Scan parquet default.item (15) - : : : : +- * Sort (33) - : : : : +- Exchange (32) - : : : : +- * Project (31) - : : : : +- * 
BroadcastHashJoin Inner BuildRight (30) - : : : : :- * Project (25) - : : : : : +- * BroadcastHashJoin Inner BuildRight (24) - : : : : : :- * Filter (22) - : : : : : : +- * ColumnarToRow (21) - : : : : : : +- Scan parquet default.catalog_sales (20) - : : : : : +- ReusedExchange (23) - : : : : +- BroadcastExchange (29) - : : : : +- * Filter (28) - : : : : +- * ColumnarToRow (27) - : : : : +- Scan parquet default.item (26) - : : : +- * Sort (53) - : : : +- Exchange (52) - : : : +- * Project (51) - : : : +- * BroadcastHashJoin Inner BuildRight (50) - : : : :- * Project (48) - : : : : +- * BroadcastHashJoin Inner BuildRight (47) - : : : : :- * Filter (45) - : : : : : +- * ColumnarToRow (44) - : : : : : +- Scan parquet default.web_sales (43) - : : : : +- ReusedExchange (46) - : : : +- ReusedExchange (49) - : : +- ReusedExchange (64) - : +- BroadcastExchange (75) - : +- * SortMergeJoin LeftSemi (74) - : :- * Sort (71) - : : +- Exchange (70) - : : +- * Filter (69) - : : +- * ColumnarToRow (68) - : : +- Scan parquet default.item (67) - : +- * Sort (73) - : +- ReusedExchange (72) - +- BroadcastExchange (100) - +- * Filter (99) - +- * HashAggregate (98) - +- Exchange (97) - +- * HashAggregate (96) - +- * Project (95) - +- * BroadcastHashJoin Inner BuildRight (94) - :- * Project (92) - : +- * BroadcastHashJoin Inner BuildRight (91) - : :- * SortMergeJoin LeftSemi (89) - : : :- * Sort (86) - : : : +- Exchange (85) - : : : +- * Filter (84) - : : : +- * ColumnarToRow (83) - : : : +- Scan parquet default.store_sales (82) - : : +- * Sort (88) - : : +- ReusedExchange (87) - : +- ReusedExchange (90) - +- ReusedExchange (93) + : : : +- BroadcastExchange (55) + : : : +- * SortMergeJoin LeftSemi (54) + : : : :- * Sort (42) + : : : : +- Exchange (41) + : : : : +- * HashAggregate (40) + : : : : +- Exchange (39) + : : : : +- * HashAggregate (38) + : : : : +- * Project (37) + : : : : +- * BroadcastHashJoin Inner BuildRight (36) + : : : : :- * Project (14) + : : : : : +- * BroadcastHashJoin Inner BuildRight (13) + : : : : : :- * Filter (11) + : : : : : : +- * ColumnarToRow (10) + : : : : : : +- Scan parquet default.store_sales (9) + : : : : : +- ReusedExchange (12) + : : : : +- BroadcastExchange (35) + : : : : +- * SortMergeJoin LeftSemi (34) + : : : : :- * Sort (19) + : : : : : +- Exchange (18) + : : : : : +- * Filter (17) + : : : : : +- * ColumnarToRow (16) + : : : : : +- Scan parquet default.item (15) + : : : : +- * Sort (33) + : : : : +- Exchange (32) + : : : : +- * Project (31) + : : : : +- * BroadcastHashJoin Inner BuildRight (30) + : : : : :- * Project (25) + : : : : : +- * BroadcastHashJoin Inner BuildRight (24) + : : : : : :- * Filter (22) + : : : : : : +- * ColumnarToRow (21) + : : : : : : +- Scan parquet default.catalog_sales (20) + : : : : : +- ReusedExchange (23) + : : : : +- BroadcastExchange (29) + : : : : +- * Filter (28) + : : : : +- * ColumnarToRow (27) + : : : : +- Scan parquet default.item (26) + : : : +- * Sort (53) + : : : +- Exchange (52) + : : : +- * Project (51) + : : : +- * BroadcastHashJoin Inner BuildRight (50) + : : : :- * Project (48) + : : : : +- * BroadcastHashJoin Inner BuildRight (47) + : : : : :- * Filter (45) + : : : : : +- * ColumnarToRow (44) + : : : : : +- Scan parquet default.web_sales (43) + : : : : +- ReusedExchange (46) + : : : +- ReusedExchange (49) + : : +- ReusedExchange (61) + : +- BroadcastExchange (72) + : +- * SortMergeJoin LeftSemi (71) + : :- * Sort (68) + : : +- Exchange (67) + : : +- * Filter (66) + : : +- * ColumnarToRow (65) + : : +- Scan parquet default.item 
(64) + : +- * Sort (70) + : +- ReusedExchange (69) + +- BroadcastExchange (97) + +- * Filter (96) + +- * HashAggregate (95) + +- Exchange (94) + +- * HashAggregate (93) + +- * Project (92) + +- * BroadcastHashJoin Inner BuildRight (91) + :- * Project (89) + : +- * BroadcastHashJoin Inner BuildRight (88) + : :- * SortMergeJoin LeftSemi (86) + : : :- * Sort (83) + : : : +- Exchange (82) + : : : +- * Filter (81) + : : : +- * ColumnarToRow (80) + : : : +- Scan parquet default.store_sales (79) + : : +- * Sort (85) + : : +- ReusedExchange (84) + : +- ReusedExchange (87) + +- ReusedExchange (90) (1) Scan parquet default.store_sales @@ -133,10 +130,10 @@ Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_brand_id), IsNotNull(i_class_id), IsNotNull(i_category_id)] ReadSchema: struct -(7) ColumnarToRow [codegen id : 20] +(7) ColumnarToRow [codegen id : 19] Input [4]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10] -(8) Filter [codegen id : 20] +(8) Filter [codegen id : 19] Input [4]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10] Condition : ((isnotnull(i_brand_id#8) AND isnotnull(i_class_id#9)) AND isnotnull(i_category_id#10)) @@ -155,7 +152,7 @@ Input [2]: [ss_item_sk#11, ss_sold_date_sk#12] Input [2]: [ss_item_sk#11, ss_sold_date_sk#12] Condition : isnotnull(ss_item_sk#11) -(12) ReusedExchange [Reuses operator id: 135] +(12) ReusedExchange [Reuses operator id: 132] Output [1]: [d_date_sk#14] (13) BroadcastHashJoin [codegen id : 11] @@ -204,7 +201,7 @@ Input [2]: [cs_item_sk#20, cs_sold_date_sk#21] Input [2]: [cs_item_sk#20, cs_sold_date_sk#21] Condition : isnotnull(cs_item_sk#20) -(23) ReusedExchange [Reuses operator id: 135] +(23) ReusedExchange [Reuses operator id: 132] Output [1]: [d_date_sk#22] (24) BroadcastHashJoin [codegen id : 8] @@ -310,7 +307,7 @@ Input [2]: [ws_item_sk#35, ws_sold_date_sk#36] Input [2]: [ws_item_sk#35, ws_sold_date_sk#36] Condition : isnotnull(ws_item_sk#35) -(46) ReusedExchange [Reuses operator id: 135] +(46) ReusedExchange [Reuses operator id: 132] Output [1]: [d_date_sk#37] (47) BroadcastHashJoin [codegen id : 16] @@ -347,485 +344,467 @@ Left keys [6]: [coalesce(brand_id#30, 0), isnull(brand_id#30), coalesce(class_id Right keys [6]: [coalesce(i_brand_id#39, 0), isnull(i_brand_id#39), coalesce(i_class_id#40, 0), isnull(i_class_id#40), coalesce(i_category_id#41, 0), isnull(i_category_id#41)] Join condition: None -(55) HashAggregate [codegen id : 18] +(55) BroadcastExchange Input [3]: [brand_id#30, class_id#31, category_id#32] -Keys [3]: [brand_id#30, class_id#31, category_id#32] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#30, class_id#31, category_id#32] - -(56) Exchange -Input [3]: [brand_id#30, class_id#31, category_id#32] -Arguments: hashpartitioning(brand_id#30, class_id#31, category_id#32, 5), ENSURE_REQUIREMENTS, [id=#43] - -(57) HashAggregate [codegen id : 19] -Input [3]: [brand_id#30, class_id#31, category_id#32] -Keys [3]: [brand_id#30, class_id#31, category_id#32] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#30, class_id#31, category_id#32] - -(58) BroadcastExchange -Input [3]: [brand_id#30, class_id#31, category_id#32] -Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#44] +Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#43] -(59) BroadcastHashJoin [codegen id : 20] +(56) BroadcastHashJoin [codegen id : 19] Left keys 
[3]: [i_brand_id#8, i_class_id#9, i_category_id#10] Right keys [3]: [brand_id#30, class_id#31, category_id#32] Join condition: None -(60) Project [codegen id : 20] -Output [1]: [i_item_sk#7 AS ss_item_sk#45] +(57) Project [codegen id : 19] +Output [1]: [i_item_sk#7 AS ss_item_sk#44] Input [7]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10, brand_id#30, class_id#31, category_id#32] -(61) Exchange -Input [1]: [ss_item_sk#45] -Arguments: hashpartitioning(ss_item_sk#45, 5), ENSURE_REQUIREMENTS, [id=#46] +(58) Exchange +Input [1]: [ss_item_sk#44] +Arguments: hashpartitioning(ss_item_sk#44, 5), ENSURE_REQUIREMENTS, [id=#45] -(62) Sort [codegen id : 21] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(59) Sort [codegen id : 20] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(63) SortMergeJoin [codegen id : 45] +(60) SortMergeJoin [codegen id : 43] Left keys [1]: [ss_item_sk#1] -Right keys [1]: [ss_item_sk#45] +Right keys [1]: [ss_item_sk#44] Join condition: None -(64) ReusedExchange [Reuses operator id: 126] -Output [1]: [d_date_sk#47] +(61) ReusedExchange [Reuses operator id: 123] +Output [1]: [d_date_sk#46] -(65) BroadcastHashJoin [codegen id : 45] +(62) BroadcastHashJoin [codegen id : 43] Left keys [1]: [ss_sold_date_sk#4] -Right keys [1]: [d_date_sk#47] +Right keys [1]: [d_date_sk#46] Join condition: None -(66) Project [codegen id : 45] +(63) Project [codegen id : 43] Output [3]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3] -Input [5]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, d_date_sk#47] +Input [5]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, d_date_sk#46] -(67) Scan parquet default.item -Output [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] +(64) Scan parquet default.item +Output [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] Batched: true Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_item_sk), IsNotNull(i_brand_id), IsNotNull(i_class_id), IsNotNull(i_category_id)] ReadSchema: struct -(68) ColumnarToRow [codegen id : 23] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] +(65) ColumnarToRow [codegen id : 22] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] -(69) Filter [codegen id : 23] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Condition : (((isnotnull(i_item_sk#48) AND isnotnull(i_brand_id#49)) AND isnotnull(i_class_id#50)) AND isnotnull(i_category_id#51)) +(66) Filter [codegen id : 22] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Condition : (((isnotnull(i_item_sk#47) AND isnotnull(i_brand_id#48)) AND isnotnull(i_class_id#49)) AND isnotnull(i_category_id#50)) -(70) Exchange -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: hashpartitioning(i_item_sk#48, 5), ENSURE_REQUIREMENTS, [id=#52] +(67) Exchange +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: hashpartitioning(i_item_sk#47, 5), ENSURE_REQUIREMENTS, [id=#51] -(71) Sort [codegen id : 24] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: [i_item_sk#48 ASC NULLS FIRST], false, 0 +(68) Sort [codegen id : 23] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: [i_item_sk#47 ASC NULLS FIRST], false, 0 -(72) ReusedExchange [Reuses operator id: 61] -Output [1]: 
[ss_item_sk#45] +(69) ReusedExchange [Reuses operator id: 58] +Output [1]: [ss_item_sk#44] -(73) Sort [codegen id : 43] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(70) Sort [codegen id : 41] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(74) SortMergeJoin [codegen id : 44] -Left keys [1]: [i_item_sk#48] -Right keys [1]: [ss_item_sk#45] +(71) SortMergeJoin [codegen id : 42] +Left keys [1]: [i_item_sk#47] +Right keys [1]: [ss_item_sk#44] Join condition: None -(75) BroadcastExchange -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#53] +(72) BroadcastExchange +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#52] -(76) BroadcastHashJoin [codegen id : 45] +(73) BroadcastHashJoin [codegen id : 43] Left keys [1]: [ss_item_sk#1] -Right keys [1]: [i_item_sk#48] +Right keys [1]: [i_item_sk#47] Join condition: None -(77) Project [codegen id : 45] -Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#49, i_class_id#50, i_category_id#51] -Input [7]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] - -(78) HashAggregate [codegen id : 45] -Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#49, i_class_id#50, i_category_id#51] -Keys [3]: [i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] -Aggregate Attributes [3]: [sum#54, isEmpty#55, count#56] -Results [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] - -(79) Exchange -Input [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] -Arguments: hashpartitioning(i_brand_id#49, i_class_id#50, i_category_id#51, 5), ENSURE_REQUIREMENTS, [id=#60] - -(80) HashAggregate [codegen id : 92] -Input [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] -Keys [3]: [i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#61, count(1)#62] -Results [6]: [store AS channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#61 AS sales#64, count(1)#62 AS number_sales#65] - -(81) Filter [codegen id : 92] -Input [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sales#64, number_sales#65] -Condition : (isnotnull(sales#64) AND (cast(sales#64 as decimal(32,6)) > cast(Subquery scalar-subquery#66, [id=#67] as decimal(32,6)))) - -(82) Scan parquet default.store_sales -Output [4]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, ss_sold_date_sk#71] +(74) Project [codegen id : 43] 
+Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#48, i_class_id#49, i_category_id#50] +Input [7]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] + +(75) HashAggregate [codegen id : 43] +Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#48, i_class_id#49, i_category_id#50] +Keys [3]: [i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] +Aggregate Attributes [3]: [sum#53, isEmpty#54, count#55] +Results [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] + +(76) Exchange +Input [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] +Arguments: hashpartitioning(i_brand_id#48, i_class_id#49, i_category_id#50, 5), ENSURE_REQUIREMENTS, [id=#59] + +(77) HashAggregate [codegen id : 88] +Input [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] +Keys [3]: [i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#60, count(1)#61] +Results [6]: [store AS channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#60 AS sales#63, count(1)#61 AS number_sales#64] + +(78) Filter [codegen id : 88] +Input [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sales#63, number_sales#64] +Condition : (isnotnull(sales#63) AND (cast(sales#63 as decimal(32,6)) > cast(Subquery scalar-subquery#65, [id=#66] as decimal(32,6)))) + +(79) Scan parquet default.store_sales +Output [4]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#71), dynamicpruningexpression(ss_sold_date_sk#71 IN dynamicpruning#72)] +PartitionFilters: [isnotnull(ss_sold_date_sk#70), dynamicpruningexpression(ss_sold_date_sk#70 IN dynamicpruning#71)] PushedFilters: [IsNotNull(ss_item_sk)] ReadSchema: struct -(83) ColumnarToRow [codegen id : 46] -Input [4]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, ss_sold_date_sk#71] +(80) ColumnarToRow [codegen id : 44] +Input [4]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70] -(84) Filter [codegen id : 46] -Input [4]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, ss_sold_date_sk#71] -Condition : isnotnull(ss_item_sk#68) +(81) Filter [codegen id : 44] +Input [4]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70] +Condition : isnotnull(ss_item_sk#67) -(85) Exchange -Input [4]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, ss_sold_date_sk#71] -Arguments: hashpartitioning(ss_item_sk#68, 5), ENSURE_REQUIREMENTS, [id=#73] +(82) Exchange +Input [4]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70] +Arguments: hashpartitioning(ss_item_sk#67, 5), ENSURE_REQUIREMENTS, [id=#72] -(86) Sort [codegen id : 47] -Input [4]: [ss_item_sk#68, ss_quantity#69, 
ss_list_price#70, ss_sold_date_sk#71] -Arguments: [ss_item_sk#68 ASC NULLS FIRST], false, 0 +(83) Sort [codegen id : 45] +Input [4]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70] +Arguments: [ss_item_sk#67 ASC NULLS FIRST], false, 0 -(87) ReusedExchange [Reuses operator id: 61] -Output [1]: [ss_item_sk#45] +(84) ReusedExchange [Reuses operator id: 58] +Output [1]: [ss_item_sk#44] -(88) Sort [codegen id : 66] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(85) Sort [codegen id : 63] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(89) SortMergeJoin [codegen id : 90] -Left keys [1]: [ss_item_sk#68] -Right keys [1]: [ss_item_sk#45] +(86) SortMergeJoin [codegen id : 86] +Left keys [1]: [ss_item_sk#67] +Right keys [1]: [ss_item_sk#44] Join condition: None -(90) ReusedExchange [Reuses operator id: 140] -Output [1]: [d_date_sk#74] +(87) ReusedExchange [Reuses operator id: 137] +Output [1]: [d_date_sk#73] -(91) BroadcastHashJoin [codegen id : 90] -Left keys [1]: [ss_sold_date_sk#71] -Right keys [1]: [d_date_sk#74] +(88) BroadcastHashJoin [codegen id : 86] +Left keys [1]: [ss_sold_date_sk#70] +Right keys [1]: [d_date_sk#73] Join condition: None -(92) Project [codegen id : 90] -Output [3]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70] -Input [5]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, ss_sold_date_sk#71, d_date_sk#74] +(89) Project [codegen id : 86] +Output [3]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69] +Input [5]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, ss_sold_date_sk#70, d_date_sk#73] -(93) ReusedExchange [Reuses operator id: 75] -Output [4]: [i_item_sk#75, i_brand_id#76, i_class_id#77, i_category_id#78] +(90) ReusedExchange [Reuses operator id: 72] +Output [4]: [i_item_sk#74, i_brand_id#75, i_class_id#76, i_category_id#77] -(94) BroadcastHashJoin [codegen id : 90] -Left keys [1]: [ss_item_sk#68] -Right keys [1]: [i_item_sk#75] +(91) BroadcastHashJoin [codegen id : 86] +Left keys [1]: [ss_item_sk#67] +Right keys [1]: [i_item_sk#74] Join condition: None -(95) Project [codegen id : 90] -Output [5]: [ss_quantity#69, ss_list_price#70, i_brand_id#76, i_class_id#77, i_category_id#78] -Input [7]: [ss_item_sk#68, ss_quantity#69, ss_list_price#70, i_item_sk#75, i_brand_id#76, i_class_id#77, i_category_id#78] - -(96) HashAggregate [codegen id : 90] -Input [5]: [ss_quantity#69, ss_list_price#70, i_brand_id#76, i_class_id#77, i_category_id#78] -Keys [3]: [i_brand_id#76, i_class_id#77, i_category_id#78] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#70 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] -Aggregate Attributes [3]: [sum#79, isEmpty#80, count#81] -Results [6]: [i_brand_id#76, i_class_id#77, i_category_id#78, sum#82, isEmpty#83, count#84] - -(97) Exchange -Input [6]: [i_brand_id#76, i_class_id#77, i_category_id#78, sum#82, isEmpty#83, count#84] -Arguments: hashpartitioning(i_brand_id#76, i_class_id#77, i_category_id#78, 5), ENSURE_REQUIREMENTS, [id=#85] - -(98) HashAggregate [codegen id : 91] -Input [6]: [i_brand_id#76, i_class_id#77, i_category_id#78, sum#82, isEmpty#83, count#84] -Keys [3]: [i_brand_id#76, i_class_id#77, i_category_id#78] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#70 as decimal(12,2)))), DecimalType(18,2), true)), 
count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#70 as decimal(12,2)))), DecimalType(18,2), true))#86, count(1)#87] -Results [6]: [store AS channel#88, i_brand_id#76, i_class_id#77, i_category_id#78, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#70 as decimal(12,2)))), DecimalType(18,2), true))#86 AS sales#89, count(1)#87 AS number_sales#90] - -(99) Filter [codegen id : 91] -Input [6]: [channel#88, i_brand_id#76, i_class_id#77, i_category_id#78, sales#89, number_sales#90] -Condition : (isnotnull(sales#89) AND (cast(sales#89 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#66, [id=#67] as decimal(32,6)))) - -(100) BroadcastExchange -Input [6]: [channel#88, i_brand_id#76, i_class_id#77, i_category_id#78, sales#89, number_sales#90] -Arguments: HashedRelationBroadcastMode(List(input[1, int, true], input[2, int, true], input[3, int, true]),false), [id=#91] - -(101) BroadcastHashJoin [codegen id : 92] -Left keys [3]: [i_brand_id#49, i_class_id#50, i_category_id#51] -Right keys [3]: [i_brand_id#76, i_class_id#77, i_category_id#78] +(92) Project [codegen id : 86] +Output [5]: [ss_quantity#68, ss_list_price#69, i_brand_id#75, i_class_id#76, i_category_id#77] +Input [7]: [ss_item_sk#67, ss_quantity#68, ss_list_price#69, i_item_sk#74, i_brand_id#75, i_class_id#76, i_category_id#77] + +(93) HashAggregate [codegen id : 86] +Input [5]: [ss_quantity#68, ss_list_price#69, i_brand_id#75, i_class_id#76, i_category_id#77] +Keys [3]: [i_brand_id#75, i_class_id#76, i_category_id#77] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_list_price#69 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] +Aggregate Attributes [3]: [sum#78, isEmpty#79, count#80] +Results [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] + +(94) Exchange +Input [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] +Arguments: hashpartitioning(i_brand_id#75, i_class_id#76, i_category_id#77, 5), ENSURE_REQUIREMENTS, [id=#84] + +(95) HashAggregate [codegen id : 87] +Input [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] +Keys [3]: [i_brand_id#75, i_class_id#76, i_category_id#77] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_list_price#69 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_list_price#69 as decimal(12,2)))), DecimalType(18,2)))#85, count(1)#86] +Results [6]: [store AS channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_list_price#69 as decimal(12,2)))), DecimalType(18,2)))#85 AS sales#88, count(1)#86 AS number_sales#89] + +(96) Filter [codegen id : 87] +Input [6]: [channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sales#88, number_sales#89] +Condition : (isnotnull(sales#88) AND (cast(sales#88 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#65, [id=#66] as decimal(32,6)))) + +(97) BroadcastExchange +Input [6]: [channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, 
sales#88, number_sales#89] +Arguments: HashedRelationBroadcastMode(List(input[1, int, true], input[2, int, true], input[3, int, true]),false), [id=#90] + +(98) BroadcastHashJoin [codegen id : 88] +Left keys [3]: [i_brand_id#48, i_class_id#49, i_category_id#50] +Right keys [3]: [i_brand_id#75, i_class_id#76, i_category_id#77] Join condition: None -(102) TakeOrderedAndProject -Input [12]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sales#64, number_sales#65, channel#88, i_brand_id#76, i_class_id#77, i_category_id#78, sales#89, number_sales#90] -Arguments: 100, [i_brand_id#49 ASC NULLS FIRST, i_class_id#50 ASC NULLS FIRST, i_category_id#51 ASC NULLS FIRST], [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sales#64, number_sales#65, channel#88, i_brand_id#76, i_class_id#77, i_category_id#78, sales#89, number_sales#90] +(99) TakeOrderedAndProject +Input [12]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sales#63, number_sales#64, channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sales#88, number_sales#89] +Arguments: 100, [i_brand_id#48 ASC NULLS FIRST, i_class_id#49 ASC NULLS FIRST, i_category_id#50 ASC NULLS FIRST], [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sales#63, number_sales#64, channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sales#88, number_sales#89] ===== Subqueries ===== -Subquery:1 Hosting operator id = 81 Hosting Expression = Subquery scalar-subquery#66, [id=#67] -* HashAggregate (121) -+- Exchange (120) - +- * HashAggregate (119) - +- Union (118) - :- * Project (107) - : +- * BroadcastHashJoin Inner BuildRight (106) - : :- * ColumnarToRow (104) - : : +- Scan parquet default.store_sales (103) - : +- ReusedExchange (105) - :- * Project (112) - : +- * BroadcastHashJoin Inner BuildRight (111) - : :- * ColumnarToRow (109) - : : +- Scan parquet default.catalog_sales (108) - : +- ReusedExchange (110) - +- * Project (117) - +- * BroadcastHashJoin Inner BuildRight (116) - :- * ColumnarToRow (114) - : +- Scan parquet default.web_sales (113) - +- ReusedExchange (115) - - -(103) Scan parquet default.store_sales -Output [3]: [ss_quantity#92, ss_list_price#93, ss_sold_date_sk#94] +Subquery:1 Hosting operator id = 78 Hosting Expression = Subquery scalar-subquery#65, [id=#66] +* HashAggregate (118) ++- Exchange (117) + +- * HashAggregate (116) + +- Union (115) + :- * Project (104) + : +- * BroadcastHashJoin Inner BuildRight (103) + : :- * ColumnarToRow (101) + : : +- Scan parquet default.store_sales (100) + : +- ReusedExchange (102) + :- * Project (109) + : +- * BroadcastHashJoin Inner BuildRight (108) + : :- * ColumnarToRow (106) + : : +- Scan parquet default.catalog_sales (105) + : +- ReusedExchange (107) + +- * Project (114) + +- * BroadcastHashJoin Inner BuildRight (113) + :- * ColumnarToRow (111) + : +- Scan parquet default.web_sales (110) + +- ReusedExchange (112) + + +(100) Scan parquet default.store_sales +Output [3]: [ss_quantity#91, ss_list_price#92, ss_sold_date_sk#93] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#94), dynamicpruningexpression(ss_sold_date_sk#94 IN dynamicpruning#13)] +PartitionFilters: [isnotnull(ss_sold_date_sk#93), dynamicpruningexpression(ss_sold_date_sk#93 IN dynamicpruning#13)] ReadSchema: struct -(104) ColumnarToRow [codegen id : 2] -Input [3]: [ss_quantity#92, ss_list_price#93, ss_sold_date_sk#94] +(101) ColumnarToRow [codegen id : 2] +Input [3]: [ss_quantity#91, ss_list_price#92, ss_sold_date_sk#93] -(105) ReusedExchange [Reuses 
operator id: 135] -Output [1]: [d_date_sk#95] +(102) ReusedExchange [Reuses operator id: 132] +Output [1]: [d_date_sk#94] -(106) BroadcastHashJoin [codegen id : 2] -Left keys [1]: [ss_sold_date_sk#94] -Right keys [1]: [d_date_sk#95] +(103) BroadcastHashJoin [codegen id : 2] +Left keys [1]: [ss_sold_date_sk#93] +Right keys [1]: [d_date_sk#94] Join condition: None -(107) Project [codegen id : 2] -Output [2]: [ss_quantity#92 AS quantity#96, ss_list_price#93 AS list_price#97] -Input [4]: [ss_quantity#92, ss_list_price#93, ss_sold_date_sk#94, d_date_sk#95] +(104) Project [codegen id : 2] +Output [2]: [ss_quantity#91 AS quantity#95, ss_list_price#92 AS list_price#96] +Input [4]: [ss_quantity#91, ss_list_price#92, ss_sold_date_sk#93, d_date_sk#94] -(108) Scan parquet default.catalog_sales -Output [3]: [cs_quantity#98, cs_list_price#99, cs_sold_date_sk#100] +(105) Scan parquet default.catalog_sales +Output [3]: [cs_quantity#97, cs_list_price#98, cs_sold_date_sk#99] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(cs_sold_date_sk#100), dynamicpruningexpression(cs_sold_date_sk#100 IN dynamicpruning#13)] +PartitionFilters: [isnotnull(cs_sold_date_sk#99), dynamicpruningexpression(cs_sold_date_sk#99 IN dynamicpruning#13)] ReadSchema: struct -(109) ColumnarToRow [codegen id : 4] -Input [3]: [cs_quantity#98, cs_list_price#99, cs_sold_date_sk#100] +(106) ColumnarToRow [codegen id : 4] +Input [3]: [cs_quantity#97, cs_list_price#98, cs_sold_date_sk#99] -(110) ReusedExchange [Reuses operator id: 135] -Output [1]: [d_date_sk#101] +(107) ReusedExchange [Reuses operator id: 132] +Output [1]: [d_date_sk#100] -(111) BroadcastHashJoin [codegen id : 4] -Left keys [1]: [cs_sold_date_sk#100] -Right keys [1]: [d_date_sk#101] +(108) BroadcastHashJoin [codegen id : 4] +Left keys [1]: [cs_sold_date_sk#99] +Right keys [1]: [d_date_sk#100] Join condition: None -(112) Project [codegen id : 4] -Output [2]: [cs_quantity#98 AS quantity#102, cs_list_price#99 AS list_price#103] -Input [4]: [cs_quantity#98, cs_list_price#99, cs_sold_date_sk#100, d_date_sk#101] +(109) Project [codegen id : 4] +Output [2]: [cs_quantity#97 AS quantity#101, cs_list_price#98 AS list_price#102] +Input [4]: [cs_quantity#97, cs_list_price#98, cs_sold_date_sk#99, d_date_sk#100] -(113) Scan parquet default.web_sales -Output [3]: [ws_quantity#104, ws_list_price#105, ws_sold_date_sk#106] +(110) Scan parquet default.web_sales +Output [3]: [ws_quantity#103, ws_list_price#104, ws_sold_date_sk#105] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ws_sold_date_sk#106), dynamicpruningexpression(ws_sold_date_sk#106 IN dynamicpruning#13)] +PartitionFilters: [isnotnull(ws_sold_date_sk#105), dynamicpruningexpression(ws_sold_date_sk#105 IN dynamicpruning#13)] ReadSchema: struct -(114) ColumnarToRow [codegen id : 6] -Input [3]: [ws_quantity#104, ws_list_price#105, ws_sold_date_sk#106] +(111) ColumnarToRow [codegen id : 6] +Input [3]: [ws_quantity#103, ws_list_price#104, ws_sold_date_sk#105] -(115) ReusedExchange [Reuses operator id: 135] -Output [1]: [d_date_sk#107] +(112) ReusedExchange [Reuses operator id: 132] +Output [1]: [d_date_sk#106] -(116) BroadcastHashJoin [codegen id : 6] -Left keys [1]: [ws_sold_date_sk#106] -Right keys [1]: [d_date_sk#107] +(113) BroadcastHashJoin [codegen id : 6] +Left keys [1]: [ws_sold_date_sk#105] +Right keys [1]: [d_date_sk#106] Join condition: None -(117) Project [codegen id : 6] -Output [2]: [ws_quantity#104 AS quantity#108, ws_list_price#105 AS list_price#109] -Input [4]: 
[ws_quantity#104, ws_list_price#105, ws_sold_date_sk#106, d_date_sk#107] +(114) Project [codegen id : 6] +Output [2]: [ws_quantity#103 AS quantity#107, ws_list_price#104 AS list_price#108] +Input [4]: [ws_quantity#103, ws_list_price#104, ws_sold_date_sk#105, d_date_sk#106] -(118) Union +(115) Union -(119) HashAggregate [codegen id : 7] -Input [2]: [quantity#96, list_price#97] +(116) HashAggregate [codegen id : 7] +Input [2]: [quantity#95, list_price#96] Keys: [] -Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(cast(quantity#96 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#97 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#110, count#111] -Results [2]: [sum#112, count#113] +Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(quantity#95 as decimal(12,2))) * promote_precision(cast(list_price#96 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#109, count#110] +Results [2]: [sum#111, count#112] -(120) Exchange -Input [2]: [sum#112, count#113] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#114] +(117) Exchange +Input [2]: [sum#111, count#112] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#113] -(121) HashAggregate [codegen id : 8] -Input [2]: [sum#112, count#113] +(118) HashAggregate [codegen id : 8] +Input [2]: [sum#111, count#112] Keys: [] -Functions [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#96 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#97 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#96 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#97 as decimal(12,2)))), DecimalType(18,2), true))#115] -Results [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#96 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#97 as decimal(12,2)))), DecimalType(18,2), true))#115 AS average_sales#116] +Functions [1]: [avg(CheckOverflow((promote_precision(cast(quantity#95 as decimal(12,2))) * promote_precision(cast(list_price#96 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(quantity#95 as decimal(12,2))) * promote_precision(cast(list_price#96 as decimal(12,2)))), DecimalType(18,2)))#114] +Results [1]: [avg(CheckOverflow((promote_precision(cast(quantity#95 as decimal(12,2))) * promote_precision(cast(list_price#96 as decimal(12,2)))), DecimalType(18,2)))#114 AS average_sales#115] -Subquery:2 Hosting operator id = 103 Hosting Expression = ss_sold_date_sk#94 IN dynamicpruning#13 +Subquery:2 Hosting operator id = 100 Hosting Expression = ss_sold_date_sk#93 IN dynamicpruning#13 -Subquery:3 Hosting operator id = 108 Hosting Expression = cs_sold_date_sk#100 IN dynamicpruning#13 +Subquery:3 Hosting operator id = 105 Hosting Expression = cs_sold_date_sk#99 IN dynamicpruning#13 -Subquery:4 Hosting operator id = 113 Hosting Expression = ws_sold_date_sk#106 IN dynamicpruning#13 +Subquery:4 Hosting operator id = 110 Hosting Expression = ws_sold_date_sk#105 IN dynamicpruning#13 Subquery:5 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (126) -+- * Project (125) - +- * Filter (124) - +- * ColumnarToRow (123) - +- Scan parquet default.date_dim (122) +BroadcastExchange (123) ++- * Project (122) + +- * Filter (121) + +- * ColumnarToRow (120) + +- Scan parquet default.date_dim (119) -(122) Scan parquet 
default.date_dim -Output [2]: [d_date_sk#47, d_week_seq#117] +(119) Scan parquet default.date_dim +Output [2]: [d_date_sk#46, d_week_seq#116] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] ReadSchema: struct -(123) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#47, d_week_seq#117] +(120) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#46, d_week_seq#116] -(124) Filter [codegen id : 1] -Input [2]: [d_date_sk#47, d_week_seq#117] -Condition : ((isnotnull(d_week_seq#117) AND (d_week_seq#117 = Subquery scalar-subquery#118, [id=#119])) AND isnotnull(d_date_sk#47)) +(121) Filter [codegen id : 1] +Input [2]: [d_date_sk#46, d_week_seq#116] +Condition : ((isnotnull(d_week_seq#116) AND (d_week_seq#116 = Subquery scalar-subquery#117, [id=#118])) AND isnotnull(d_date_sk#46)) -(125) Project [codegen id : 1] -Output [1]: [d_date_sk#47] -Input [2]: [d_date_sk#47, d_week_seq#117] +(122) Project [codegen id : 1] +Output [1]: [d_date_sk#46] +Input [2]: [d_date_sk#46, d_week_seq#116] -(126) BroadcastExchange -Input [1]: [d_date_sk#47] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#120] +(123) BroadcastExchange +Input [1]: [d_date_sk#46] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#119] -Subquery:6 Hosting operator id = 124 Hosting Expression = Subquery scalar-subquery#118, [id=#119] -* Project (130) -+- * Filter (129) - +- * ColumnarToRow (128) - +- Scan parquet default.date_dim (127) +Subquery:6 Hosting operator id = 121 Hosting Expression = Subquery scalar-subquery#117, [id=#118] +* Project (127) ++- * Filter (126) + +- * ColumnarToRow (125) + +- Scan parquet default.date_dim (124) -(127) Scan parquet default.date_dim -Output [4]: [d_week_seq#121, d_year#122, d_moy#123, d_dom#124] +(124) Scan parquet default.date_dim +Output [4]: [d_week_seq#120, d_year#121, d_moy#122, d_dom#123] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), IsNotNull(d_dom), EqualTo(d_year,1999), EqualTo(d_moy,12), EqualTo(d_dom,16)] ReadSchema: struct -(128) ColumnarToRow [codegen id : 1] -Input [4]: [d_week_seq#121, d_year#122, d_moy#123, d_dom#124] +(125) ColumnarToRow [codegen id : 1] +Input [4]: [d_week_seq#120, d_year#121, d_moy#122, d_dom#123] -(129) Filter [codegen id : 1] -Input [4]: [d_week_seq#121, d_year#122, d_moy#123, d_dom#124] -Condition : (((((isnotnull(d_year#122) AND isnotnull(d_moy#123)) AND isnotnull(d_dom#124)) AND (d_year#122 = 1999)) AND (d_moy#123 = 12)) AND (d_dom#124 = 16)) +(126) Filter [codegen id : 1] +Input [4]: [d_week_seq#120, d_year#121, d_moy#122, d_dom#123] +Condition : (((((isnotnull(d_year#121) AND isnotnull(d_moy#122)) AND isnotnull(d_dom#123)) AND (d_year#121 = 1999)) AND (d_moy#122 = 12)) AND (d_dom#123 = 16)) -(130) Project [codegen id : 1] -Output [1]: [d_week_seq#121] -Input [4]: [d_week_seq#121, d_year#122, d_moy#123, d_dom#124] +(127) Project [codegen id : 1] +Output [1]: [d_week_seq#120] +Input [4]: [d_week_seq#120, d_year#121, d_moy#122, d_dom#123] Subquery:7 Hosting operator id = 9 Hosting Expression = ss_sold_date_sk#12 IN dynamicpruning#13 -BroadcastExchange (135) -+- * Project (134) - +- * Filter (133) - +- * ColumnarToRow (132) - +- Scan parquet default.date_dim (131) +BroadcastExchange (132) ++- * Project (131) + +- * Filter (130) + +- * ColumnarToRow (129) + +- Scan parquet default.date_dim (128) -(131) 
Scan parquet default.date_dim -Output [2]: [d_date_sk#14, d_year#125] +(128) Scan parquet default.date_dim +Output [2]: [d_date_sk#14, d_year#124] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), GreaterThanOrEqual(d_year,1998), LessThanOrEqual(d_year,2000), IsNotNull(d_date_sk)] ReadSchema: struct -(132) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#14, d_year#125] +(129) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#14, d_year#124] -(133) Filter [codegen id : 1] -Input [2]: [d_date_sk#14, d_year#125] -Condition : (((isnotnull(d_year#125) AND (d_year#125 >= 1998)) AND (d_year#125 <= 2000)) AND isnotnull(d_date_sk#14)) +(130) Filter [codegen id : 1] +Input [2]: [d_date_sk#14, d_year#124] +Condition : (((isnotnull(d_year#124) AND (d_year#124 >= 1998)) AND (d_year#124 <= 2000)) AND isnotnull(d_date_sk#14)) -(134) Project [codegen id : 1] +(131) Project [codegen id : 1] Output [1]: [d_date_sk#14] -Input [2]: [d_date_sk#14, d_year#125] +Input [2]: [d_date_sk#14, d_year#124] -(135) BroadcastExchange +(132) BroadcastExchange Input [1]: [d_date_sk#14] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#126] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#125] Subquery:8 Hosting operator id = 20 Hosting Expression = cs_sold_date_sk#21 IN dynamicpruning#13 Subquery:9 Hosting operator id = 43 Hosting Expression = ws_sold_date_sk#36 IN dynamicpruning#13 -Subquery:10 Hosting operator id = 99 Hosting Expression = ReusedSubquery Subquery scalar-subquery#66, [id=#67] +Subquery:10 Hosting operator id = 96 Hosting Expression = ReusedSubquery Subquery scalar-subquery#65, [id=#66] -Subquery:11 Hosting operator id = 82 Hosting Expression = ss_sold_date_sk#71 IN dynamicpruning#72 -BroadcastExchange (140) -+- * Project (139) - +- * Filter (138) - +- * ColumnarToRow (137) - +- Scan parquet default.date_dim (136) +Subquery:11 Hosting operator id = 79 Hosting Expression = ss_sold_date_sk#70 IN dynamicpruning#71 +BroadcastExchange (137) ++- * Project (136) + +- * Filter (135) + +- * ColumnarToRow (134) + +- Scan parquet default.date_dim (133) -(136) Scan parquet default.date_dim -Output [2]: [d_date_sk#74, d_week_seq#127] +(133) Scan parquet default.date_dim +Output [2]: [d_date_sk#73, d_week_seq#126] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] ReadSchema: struct -(137) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#74, d_week_seq#127] +(134) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#73, d_week_seq#126] -(138) Filter [codegen id : 1] -Input [2]: [d_date_sk#74, d_week_seq#127] -Condition : ((isnotnull(d_week_seq#127) AND (d_week_seq#127 = Subquery scalar-subquery#128, [id=#129])) AND isnotnull(d_date_sk#74)) +(135) Filter [codegen id : 1] +Input [2]: [d_date_sk#73, d_week_seq#126] +Condition : ((isnotnull(d_week_seq#126) AND (d_week_seq#126 = Subquery scalar-subquery#127, [id=#128])) AND isnotnull(d_date_sk#73)) -(139) Project [codegen id : 1] -Output [1]: [d_date_sk#74] -Input [2]: [d_date_sk#74, d_week_seq#127] +(136) Project [codegen id : 1] +Output [1]: [d_date_sk#73] +Input [2]: [d_date_sk#73, d_week_seq#126] -(140) BroadcastExchange -Input [1]: [d_date_sk#74] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#130] +(137) BroadcastExchange +Input [1]: [d_date_sk#73] +Arguments: 
HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#129] -Subquery:12 Hosting operator id = 138 Hosting Expression = Subquery scalar-subquery#128, [id=#129] -* Project (144) -+- * Filter (143) - +- * ColumnarToRow (142) - +- Scan parquet default.date_dim (141) +Subquery:12 Hosting operator id = 135 Hosting Expression = Subquery scalar-subquery#127, [id=#128] +* Project (141) ++- * Filter (140) + +- * ColumnarToRow (139) + +- Scan parquet default.date_dim (138) -(141) Scan parquet default.date_dim -Output [4]: [d_week_seq#131, d_year#132, d_moy#133, d_dom#134] +(138) Scan parquet default.date_dim +Output [4]: [d_week_seq#130, d_year#131, d_moy#132, d_dom#133] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), IsNotNull(d_dom), EqualTo(d_year,1998), EqualTo(d_moy,12), EqualTo(d_dom,16)] ReadSchema: struct -(142) ColumnarToRow [codegen id : 1] -Input [4]: [d_week_seq#131, d_year#132, d_moy#133, d_dom#134] +(139) ColumnarToRow [codegen id : 1] +Input [4]: [d_week_seq#130, d_year#131, d_moy#132, d_dom#133] -(143) Filter [codegen id : 1] -Input [4]: [d_week_seq#131, d_year#132, d_moy#133, d_dom#134] -Condition : (((((isnotnull(d_year#132) AND isnotnull(d_moy#133)) AND isnotnull(d_dom#134)) AND (d_year#132 = 1998)) AND (d_moy#133 = 12)) AND (d_dom#134 = 16)) +(140) Filter [codegen id : 1] +Input [4]: [d_week_seq#130, d_year#131, d_moy#132, d_dom#133] +Condition : (((((isnotnull(d_year#131) AND isnotnull(d_moy#132)) AND isnotnull(d_dom#133)) AND (d_year#131 = 1998)) AND (d_moy#132 = 12)) AND (d_dom#133 = 16)) -(144) Project [codegen id : 1] -Output [1]: [d_week_seq#131] -Input [4]: [d_week_seq#131, d_year#132, d_moy#133, d_dom#134] +(141) Project [codegen id : 1] +Output [1]: [d_week_seq#130] +Input [4]: [d_week_seq#130, d_year#131, d_moy#132, d_dom#133] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/simplified.txt index 695a7c13381d8..82e338515f431 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/simplified.txt @@ -1,12 +1,12 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_sales,channel,i_brand_id,i_class_id,i_category_id,sales,number_sales] - WholeStageCodegen (92) + WholeStageCodegen (88) BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] Filter [sales] Subquery #4 WholeStageCodegen (8) - HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(cast(quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2), true)),average_sales,sum,count] + HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(quantity as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2))),average_sales,sum,count] InputAdapter - Exchange #17 + Exchange #16 WholeStageCodegen (7) HashAggregate [quantity,list_price] [sum,count,sum,count] InputAdapter @@ -19,7 +19,7 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ Scan parquet default.store_sales [ss_quantity,ss_list_price,ss_sold_date_sk] ReusedSubquery [d_date_sk] #3 InputAdapter - ReusedExchange [d_date_sk] #9 + ReusedExchange [d_date_sk] #8 
WholeStageCodegen (4) Project [cs_quantity,cs_list_price] BroadcastHashJoin [cs_sold_date_sk,d_date_sk] @@ -28,7 +28,7 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ Scan parquet default.catalog_sales [cs_quantity,cs_list_price,cs_sold_date_sk] ReusedSubquery [d_date_sk] #3 InputAdapter - ReusedExchange [d_date_sk] #9 + ReusedExchange [d_date_sk] #8 WholeStageCodegen (6) Project [ws_quantity,ws_list_price] BroadcastHashJoin [ws_sold_date_sk,d_date_sk] @@ -37,11 +37,11 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ Scan parquet default.web_sales [ws_quantity,ws_list_price,ws_sold_date_sk] ReusedSubquery [d_date_sk] #3 InputAdapter - ReusedExchange [d_date_sk] #9 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + ReusedExchange [d_date_sk] #8 + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #1 - WholeStageCodegen (45) + WholeStageCodegen (43) HashAggregate [i_brand_id,i_class_id,i_category_id,ss_quantity,ss_list_price] [sum,isEmpty,count,sum,isEmpty,count] Project [ss_quantity,ss_list_price,i_brand_id,i_class_id,i_category_id] BroadcastHashJoin [ss_item_sk,i_item_sk] @@ -74,11 +74,11 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ InputAdapter Scan parquet default.date_dim [d_date_sk,d_week_seq] InputAdapter - WholeStageCodegen (21) + WholeStageCodegen (20) Sort [ss_item_sk] InputAdapter Exchange [ss_item_sk] #4 - WholeStageCodegen (20) + WholeStageCodegen (19) Project [i_item_sk] BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,brand_id,class_id,category_id] Filter [i_brand_id,i_class_id,i_category_id] @@ -87,129 +87,124 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] InputAdapter BroadcastExchange #5 - WholeStageCodegen (19) - HashAggregate [brand_id,class_id,category_id] + WholeStageCodegen (18) + SortMergeJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] InputAdapter - Exchange [brand_id,class_id,category_id] #6 - WholeStageCodegen (18) - HashAggregate [brand_id,class_id,category_id] - SortMergeJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (13) - Sort [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #7 - WholeStageCodegen (12) - HashAggregate [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #8 - WholeStageCodegen (11) - HashAggregate [brand_id,class_id,category_id] - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_item_sk,i_item_sk] - Project [ss_item_sk] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #3 - BroadcastExchange #9 - WholeStageCodegen (1) - Project [d_date_sk] - Filter 
[d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] + WholeStageCodegen (13) + Sort [brand_id,class_id,category_id] + InputAdapter + Exchange [brand_id,class_id,category_id] #6 + WholeStageCodegen (12) + HashAggregate [brand_id,class_id,category_id] + InputAdapter + Exchange [brand_id,class_id,category_id] #7 + WholeStageCodegen (11) + HashAggregate [brand_id,class_id,category_id] + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ss_item_sk,i_item_sk] + Project [ss_item_sk] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #3 + BroadcastExchange #8 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] + InputAdapter + ReusedExchange [d_date_sk] #8 + InputAdapter + BroadcastExchange #9 + WholeStageCodegen (10) + SortMergeJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (5) + Sort [i_brand_id,i_class_id,i_category_id] InputAdapter - ReusedExchange [d_date_sk] #9 - InputAdapter - BroadcastExchange #10 - WholeStageCodegen (10) - SortMergeJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (5) - Sort [i_brand_id,i_class_id,i_category_id] + Exchange [i_brand_id,i_class_id,i_category_id] #10 + WholeStageCodegen (4) + Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + ColumnarToRow InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #11 - WholeStageCodegen (4) - Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (9) + Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + Exchange [i_brand_id,i_class_id,i_category_id] #11 + WholeStageCodegen (8) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_item_sk,i_item_sk] + Project [cs_item_sk] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Filter [cs_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] + ReusedSubquery [d_date_sk] #3 + InputAdapter + ReusedExchange [d_date_sk] #8 + InputAdapter + BroadcastExchange #12 + WholeStageCodegen (7) + Filter [i_item_sk] ColumnarToRow InputAdapter Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (9) - Sort [i_brand_id,i_class_id,i_category_id] - InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #12 - WholeStageCodegen (8) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_item_sk,i_item_sk] - Project [cs_item_sk] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Filter [cs_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #3 - InputAdapter - ReusedExchange [d_date_sk] #9 - InputAdapter - BroadcastExchange #13 - WholeStageCodegen (7) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (17) - Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (17) + Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + Exchange [i_brand_id,i_class_id,i_category_id] #13 + 
WholeStageCodegen (16) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_item_sk,i_item_sk] + Project [ws_item_sk] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Filter [ws_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #3 + InputAdapter + ReusedExchange [d_date_sk] #8 InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #14 - WholeStageCodegen (16) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_item_sk,i_item_sk] - Project [ws_item_sk] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Filter [ws_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] - ReusedSubquery [d_date_sk] #3 - InputAdapter - ReusedExchange [d_date_sk] #9 - InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #13 + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #12 InputAdapter ReusedExchange [d_date_sk] #3 InputAdapter - BroadcastExchange #15 - WholeStageCodegen (44) + BroadcastExchange #14 + WholeStageCodegen (42) SortMergeJoin [i_item_sk,ss_item_sk] InputAdapter - WholeStageCodegen (24) + WholeStageCodegen (23) Sort [i_item_sk] InputAdapter - Exchange [i_item_sk] #16 - WholeStageCodegen (23) + Exchange [i_item_sk] #15 + WholeStageCodegen (22) Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] ColumnarToRow InputAdapter Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] InputAdapter - WholeStageCodegen (43) + WholeStageCodegen (41) Sort [ss_item_sk] InputAdapter ReusedExchange [ss_item_sk] #4 InputAdapter - BroadcastExchange #18 - WholeStageCodegen (91) + BroadcastExchange #17 + WholeStageCodegen (87) Filter [sales] ReusedSubquery [average_sales] #4 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #19 - WholeStageCodegen (90) + Exchange [i_brand_id,i_class_id,i_category_id] #18 + WholeStageCodegen (86) HashAggregate [i_brand_id,i_class_id,i_category_id,ss_quantity,ss_list_price] [sum,isEmpty,count,sum,isEmpty,count] Project [ss_quantity,ss_list_price,i_brand_id,i_class_id,i_category_id] BroadcastHashJoin [ss_item_sk,i_item_sk] @@ -217,17 +212,17 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ BroadcastHashJoin [ss_sold_date_sk,d_date_sk] SortMergeJoin [ss_item_sk,ss_item_sk] InputAdapter - WholeStageCodegen (47) + WholeStageCodegen (45) Sort [ss_item_sk] InputAdapter - Exchange [ss_item_sk] #20 - WholeStageCodegen (46) + Exchange [ss_item_sk] #19 + WholeStageCodegen (44) Filter [ss_item_sk] ColumnarToRow InputAdapter Scan parquet default.store_sales [ss_item_sk,ss_quantity,ss_list_price,ss_sold_date_sk] SubqueryBroadcast [d_date_sk] #5 - BroadcastExchange #21 + BroadcastExchange #20 WholeStageCodegen (1) Project [d_date_sk] Filter [d_week_seq,d_date_sk] @@ -242,11 +237,11 @@ TakeOrderedAndProject 
[i_brand_id,i_class_id,i_category_id,channel,sales,number_ InputAdapter Scan parquet default.date_dim [d_date_sk,d_week_seq] InputAdapter - WholeStageCodegen (66) + WholeStageCodegen (63) Sort [ss_item_sk] InputAdapter ReusedExchange [ss_item_sk] #4 InputAdapter - ReusedExchange [d_date_sk] #21 + ReusedExchange [d_date_sk] #20 InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #15 + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #14 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt index 212cb97de2873..86bbc553e8c31 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt @@ -1,90 +1,88 @@ == Physical Plan == -TakeOrderedAndProject (86) -+- * BroadcastHashJoin Inner BuildRight (85) - :- * Filter (68) - : +- * HashAggregate (67) - : +- Exchange (66) - : +- * HashAggregate (65) - : +- * Project (64) - : +- * BroadcastHashJoin Inner BuildRight (63) - : :- * Project (61) - : : +- * BroadcastHashJoin Inner BuildRight (60) - : : :- * BroadcastHashJoin LeftSemi BuildRight (53) +TakeOrderedAndProject (84) ++- * BroadcastHashJoin Inner BuildRight (83) + :- * Filter (66) + : +- * HashAggregate (65) + : +- Exchange (64) + : +- * HashAggregate (63) + : +- * Project (62) + : +- * BroadcastHashJoin Inner BuildRight (61) + : :- * Project (59) + : : +- * BroadcastHashJoin Inner BuildRight (58) + : : :- * BroadcastHashJoin LeftSemi BuildRight (51) : : : :- * Filter (3) : : : : +- * ColumnarToRow (2) : : : : +- Scan parquet default.store_sales (1) - : : : +- BroadcastExchange (52) - : : : +- * Project (51) - : : : +- * BroadcastHashJoin Inner BuildRight (50) + : : : +- BroadcastExchange (50) + : : : +- * Project (49) + : : : +- * BroadcastHashJoin Inner BuildRight (48) : : : :- * Filter (6) : : : : +- * ColumnarToRow (5) : : : : +- Scan parquet default.item (4) - : : : +- BroadcastExchange (49) - : : : +- * HashAggregate (48) - : : : +- * HashAggregate (47) - : : : +- * BroadcastHashJoin LeftSemi BuildRight (46) - : : : :- * HashAggregate (35) - : : : : +- Exchange (34) - : : : : +- * HashAggregate (33) - : : : : +- * Project (32) - : : : : +- * BroadcastHashJoin Inner BuildRight (31) - : : : : :- * Project (29) - : : : : : +- * BroadcastHashJoin Inner BuildRight (28) - : : : : : :- * Filter (9) - : : : : : : +- * ColumnarToRow (8) - : : : : : : +- Scan parquet default.store_sales (7) - : : : : : +- BroadcastExchange (27) - : : : : : +- * BroadcastHashJoin LeftSemi BuildRight (26) - : : : : : :- * Filter (12) - : : : : : : +- * ColumnarToRow (11) - : : : : : : +- Scan parquet default.item (10) - : : : : : +- BroadcastExchange (25) - : : : : : +- * Project (24) - : : : : : +- * BroadcastHashJoin Inner BuildRight (23) - : : : : : :- * Project (21) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (20) - : : : : : : :- * Filter (15) - : : : : : : : +- * ColumnarToRow (14) - : : : : : : : +- Scan parquet default.catalog_sales (13) - : : : : : : +- BroadcastExchange (19) - : : : : : : +- * Filter (18) - : : : : : : +- * ColumnarToRow (17) - : : : : : : +- Scan parquet default.item (16) - : : : : : +- ReusedExchange (22) - : : : : +- ReusedExchange (30) - : : : +- BroadcastExchange (45) - : : : +- * Project (44) - : : : +- * BroadcastHashJoin Inner BuildRight (43) - : : : :- * Project (41) - : : : : +- * 
BroadcastHashJoin Inner BuildRight (40) - : : : : :- * Filter (38) - : : : : : +- * ColumnarToRow (37) - : : : : : +- Scan parquet default.web_sales (36) - : : : : +- ReusedExchange (39) - : : : +- ReusedExchange (42) - : : +- BroadcastExchange (59) - : : +- * BroadcastHashJoin LeftSemi BuildRight (58) - : : :- * Filter (56) - : : : +- * ColumnarToRow (55) - : : : +- Scan parquet default.item (54) - : : +- ReusedExchange (57) - : +- ReusedExchange (62) - +- BroadcastExchange (84) - +- * Filter (83) - +- * HashAggregate (82) - +- Exchange (81) - +- * HashAggregate (80) - +- * Project (79) - +- * BroadcastHashJoin Inner BuildRight (78) - :- * Project (76) - : +- * BroadcastHashJoin Inner BuildRight (75) - : :- * BroadcastHashJoin LeftSemi BuildRight (73) - : : :- * Filter (71) - : : : +- * ColumnarToRow (70) - : : : +- Scan parquet default.store_sales (69) - : : +- ReusedExchange (72) - : +- ReusedExchange (74) - +- ReusedExchange (77) + : : : +- BroadcastExchange (47) + : : : +- * BroadcastHashJoin LeftSemi BuildRight (46) + : : : :- * HashAggregate (35) + : : : : +- Exchange (34) + : : : : +- * HashAggregate (33) + : : : : +- * Project (32) + : : : : +- * BroadcastHashJoin Inner BuildRight (31) + : : : : :- * Project (29) + : : : : : +- * BroadcastHashJoin Inner BuildRight (28) + : : : : : :- * Filter (9) + : : : : : : +- * ColumnarToRow (8) + : : : : : : +- Scan parquet default.store_sales (7) + : : : : : +- BroadcastExchange (27) + : : : : : +- * BroadcastHashJoin LeftSemi BuildRight (26) + : : : : : :- * Filter (12) + : : : : : : +- * ColumnarToRow (11) + : : : : : : +- Scan parquet default.item (10) + : : : : : +- BroadcastExchange (25) + : : : : : +- * Project (24) + : : : : : +- * BroadcastHashJoin Inner BuildRight (23) + : : : : : :- * Project (21) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (20) + : : : : : : :- * Filter (15) + : : : : : : : +- * ColumnarToRow (14) + : : : : : : : +- Scan parquet default.catalog_sales (13) + : : : : : : +- BroadcastExchange (19) + : : : : : : +- * Filter (18) + : : : : : : +- * ColumnarToRow (17) + : : : : : : +- Scan parquet default.item (16) + : : : : : +- ReusedExchange (22) + : : : : +- ReusedExchange (30) + : : : +- BroadcastExchange (45) + : : : +- * Project (44) + : : : +- * BroadcastHashJoin Inner BuildRight (43) + : : : :- * Project (41) + : : : : +- * BroadcastHashJoin Inner BuildRight (40) + : : : : :- * Filter (38) + : : : : : +- * ColumnarToRow (37) + : : : : : +- Scan parquet default.web_sales (36) + : : : : +- ReusedExchange (39) + : : : +- ReusedExchange (42) + : : +- BroadcastExchange (57) + : : +- * BroadcastHashJoin LeftSemi BuildRight (56) + : : :- * Filter (54) + : : : +- * ColumnarToRow (53) + : : : +- Scan parquet default.item (52) + : : +- ReusedExchange (55) + : +- ReusedExchange (60) + +- BroadcastExchange (82) + +- * Filter (81) + +- * HashAggregate (80) + +- Exchange (79) + +- * HashAggregate (78) + +- * Project (77) + +- * BroadcastHashJoin Inner BuildRight (76) + :- * Project (74) + : +- * BroadcastHashJoin Inner BuildRight (73) + : :- * BroadcastHashJoin LeftSemi BuildRight (71) + : : :- * Filter (69) + : : : +- * ColumnarToRow (68) + : : : +- Scan parquet default.store_sales (67) + : : +- ReusedExchange (70) + : +- ReusedExchange (72) + +- ReusedExchange (75) (1) Scan parquet default.store_sales @@ -187,7 +185,7 @@ Join condition: None Output [4]: [cs_sold_date_sk#18, i_brand_id#20, i_class_id#21, i_category_id#22] Input [6]: [cs_item_sk#17, cs_sold_date_sk#18, i_item_sk#19, i_brand_id#20, i_class_id#21, 
i_category_id#22] -(22) ReusedExchange [Reuses operator id: 119] +(22) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#24] (23) BroadcastHashJoin [codegen id : 3] @@ -221,7 +219,7 @@ Join condition: None Output [4]: [ss_sold_date_sk#11, i_brand_id#14, i_class_id#15, i_category_id#16] Input [6]: [ss_item_sk#10, ss_sold_date_sk#11, i_item_sk#13, i_brand_id#14, i_class_id#15, i_category_id#16] -(30) ReusedExchange [Reuses operator id: 119] +(30) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#27] (31) BroadcastHashJoin [codegen id : 6] @@ -278,7 +276,7 @@ Join condition: None Output [4]: [ws_sold_date_sk#33, i_brand_id#35, i_class_id#36, i_category_id#37] Input [6]: [ws_item_sk#32, ws_sold_date_sk#33, i_item_sk#34, i_brand_id#35, i_class_id#36, i_category_id#37] -(42) ReusedExchange [Reuses operator id: 119] +(42) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#38] (43) BroadcastHashJoin [codegen id : 9] @@ -299,112 +297,98 @@ Left keys [6]: [coalesce(brand_id#28, 0), isnull(brand_id#28), coalesce(class_id Right keys [6]: [coalesce(i_brand_id#35, 0), isnull(i_brand_id#35), coalesce(i_class_id#36, 0), isnull(i_class_id#36), coalesce(i_category_id#37, 0), isnull(i_category_id#37)] Join condition: None -(47) HashAggregate [codegen id : 10] -Input [3]: [brand_id#28, class_id#29, category_id#30] -Keys [3]: [brand_id#28, class_id#29, category_id#30] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#28, class_id#29, category_id#30] - -(48) HashAggregate [codegen id : 10] -Input [3]: [brand_id#28, class_id#29, category_id#30] -Keys [3]: [brand_id#28, class_id#29, category_id#30] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#28, class_id#29, category_id#30] - -(49) BroadcastExchange +(47) BroadcastExchange Input [3]: [brand_id#28, class_id#29, category_id#30] Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#40] -(50) BroadcastHashJoin [codegen id : 11] +(48) BroadcastHashJoin [codegen id : 11] Left keys [3]: [i_brand_id#7, i_class_id#8, i_category_id#9] Right keys [3]: [brand_id#28, class_id#29, category_id#30] Join condition: None -(51) Project [codegen id : 11] +(49) Project [codegen id : 11] Output [1]: [i_item_sk#6 AS ss_item_sk#41] Input [7]: [i_item_sk#6, i_brand_id#7, i_class_id#8, i_category_id#9, brand_id#28, class_id#29, category_id#30] -(52) BroadcastExchange +(50) BroadcastExchange Input [1]: [ss_item_sk#41] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#42] -(53) BroadcastHashJoin [codegen id : 25] +(51) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_item_sk#1] Right keys [1]: [ss_item_sk#41] Join condition: None -(54) Scan parquet default.item +(52) Scan parquet default.item Output [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Batched: true Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_item_sk), IsNotNull(i_brand_id), IsNotNull(i_class_id), IsNotNull(i_category_id)] ReadSchema: struct -(55) ColumnarToRow [codegen id : 23] +(53) ColumnarToRow [codegen id : 23] Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] -(56) Filter [codegen id : 23] +(54) Filter [codegen id : 23] Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Condition : (((isnotnull(i_item_sk#43) AND isnotnull(i_brand_id#44)) AND isnotnull(i_class_id#45)) AND isnotnull(i_category_id#46)) -(57) ReusedExchange 
[Reuses operator id: 52] +(55) ReusedExchange [Reuses operator id: 50] Output [1]: [ss_item_sk#41] -(58) BroadcastHashJoin [codegen id : 23] +(56) BroadcastHashJoin [codegen id : 23] Left keys [1]: [i_item_sk#43] Right keys [1]: [ss_item_sk#41] Join condition: None -(59) BroadcastExchange +(57) BroadcastExchange Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#47] -(60) BroadcastHashJoin [codegen id : 25] +(58) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_item_sk#1] Right keys [1]: [i_item_sk#43] Join condition: None -(61) Project [codegen id : 25] +(59) Project [codegen id : 25] Output [6]: [ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_brand_id#44, i_class_id#45, i_category_id#46] Input [8]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] -(62) ReusedExchange [Reuses operator id: 110] +(60) ReusedExchange [Reuses operator id: 108] Output [1]: [d_date_sk#48] -(63) BroadcastHashJoin [codegen id : 25] +(61) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_sold_date_sk#4] Right keys [1]: [d_date_sk#48] Join condition: None -(64) Project [codegen id : 25] +(62) Project [codegen id : 25] Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#44, i_class_id#45, i_category_id#46] Input [7]: [ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_brand_id#44, i_class_id#45, i_category_id#46, d_date_sk#48] -(65) HashAggregate [codegen id : 25] +(63) HashAggregate [codegen id : 25] Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#44, i_class_id#45, i_category_id#46] Keys [3]: [i_brand_id#44, i_class_id#45, i_category_id#46] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] Aggregate Attributes [3]: [sum#49, isEmpty#50, count#51] Results [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] -(66) Exchange +(64) Exchange Input [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] Arguments: hashpartitioning(i_brand_id#44, i_class_id#45, i_category_id#46, 5), ENSURE_REQUIREMENTS, [id=#55] -(67) HashAggregate [codegen id : 52] +(65) HashAggregate [codegen id : 52] Input [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] Keys [3]: [i_brand_id#44, i_class_id#45, i_category_id#46] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#56, count(1)#57] -Results [6]: [store AS channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#56 AS sales#59, count(1)#57 AS 
number_sales#60] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#56, count(1)#57] +Results [6]: [store AS channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#56 AS sales#59, count(1)#57 AS number_sales#60] -(68) Filter [codegen id : 52] +(66) Filter [codegen id : 52] Input [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sales#59, number_sales#60] Condition : (isnotnull(sales#59) AND (cast(sales#59 as decimal(32,6)) > cast(Subquery scalar-subquery#61, [id=#62] as decimal(32,6)))) -(69) Scan parquet default.store_sales +(67) Scan parquet default.store_sales Output [4]: [ss_item_sk#63, ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66] Batched: true Location: InMemoryFileIndex [] @@ -412,278 +396,278 @@ PartitionFilters: [isnotnull(ss_sold_date_sk#66), dynamicpruningexpression(ss_so PushedFilters: [IsNotNull(ss_item_sk)] ReadSchema: struct -(70) ColumnarToRow [codegen id : 50] +(68) ColumnarToRow [codegen id : 50] Input [4]: [ss_item_sk#63, ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66] -(71) Filter [codegen id : 50] +(69) Filter [codegen id : 50] Input [4]: [ss_item_sk#63, ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66] Condition : isnotnull(ss_item_sk#63) -(72) ReusedExchange [Reuses operator id: 52] +(70) ReusedExchange [Reuses operator id: 50] Output [1]: [ss_item_sk#41] -(73) BroadcastHashJoin [codegen id : 50] +(71) BroadcastHashJoin [codegen id : 50] Left keys [1]: [ss_item_sk#63] Right keys [1]: [ss_item_sk#41] Join condition: None -(74) ReusedExchange [Reuses operator id: 59] +(72) ReusedExchange [Reuses operator id: 57] Output [4]: [i_item_sk#68, i_brand_id#69, i_class_id#70, i_category_id#71] -(75) BroadcastHashJoin [codegen id : 50] +(73) BroadcastHashJoin [codegen id : 50] Left keys [1]: [ss_item_sk#63] Right keys [1]: [i_item_sk#68] Join condition: None -(76) Project [codegen id : 50] +(74) Project [codegen id : 50] Output [6]: [ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66, i_brand_id#69, i_class_id#70, i_category_id#71] Input [8]: [ss_item_sk#63, ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66, i_item_sk#68, i_brand_id#69, i_class_id#70, i_category_id#71] -(77) ReusedExchange [Reuses operator id: 124] +(75) ReusedExchange [Reuses operator id: 122] Output [1]: [d_date_sk#72] -(78) BroadcastHashJoin [codegen id : 50] +(76) BroadcastHashJoin [codegen id : 50] Left keys [1]: [ss_sold_date_sk#66] Right keys [1]: [d_date_sk#72] Join condition: None -(79) Project [codegen id : 50] +(77) Project [codegen id : 50] Output [5]: [ss_quantity#64, ss_list_price#65, i_brand_id#69, i_class_id#70, i_category_id#71] Input [7]: [ss_quantity#64, ss_list_price#65, ss_sold_date_sk#66, i_brand_id#69, i_class_id#70, i_category_id#71, d_date_sk#72] -(80) HashAggregate [codegen id : 50] +(78) HashAggregate [codegen id : 50] Input [5]: [ss_quantity#64, ss_list_price#65, i_brand_id#69, i_class_id#70, i_category_id#71] Keys [3]: [i_brand_id#69, i_class_id#70, i_category_id#71] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as 
decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] Aggregate Attributes [3]: [sum#73, isEmpty#74, count#75] Results [6]: [i_brand_id#69, i_class_id#70, i_category_id#71, sum#76, isEmpty#77, count#78] -(81) Exchange +(79) Exchange Input [6]: [i_brand_id#69, i_class_id#70, i_category_id#71, sum#76, isEmpty#77, count#78] Arguments: hashpartitioning(i_brand_id#69, i_class_id#70, i_category_id#71, 5), ENSURE_REQUIREMENTS, [id=#79] -(82) HashAggregate [codegen id : 51] +(80) HashAggregate [codegen id : 51] Input [6]: [i_brand_id#69, i_class_id#70, i_category_id#71, sum#76, isEmpty#77, count#78] Keys [3]: [i_brand_id#69, i_class_id#70, i_category_id#71] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2), true))#80, count(1)#81] -Results [6]: [store AS channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2), true))#80 AS sales#83, count(1)#81 AS number_sales#84] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2)))#80, count(1)#81] +Results [6]: [store AS channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sum(CheckOverflow((promote_precision(cast(ss_quantity#64 as decimal(12,2))) * promote_precision(cast(ss_list_price#65 as decimal(12,2)))), DecimalType(18,2)))#80 AS sales#83, count(1)#81 AS number_sales#84] -(83) Filter [codegen id : 51] +(81) Filter [codegen id : 51] Input [6]: [channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sales#83, number_sales#84] Condition : (isnotnull(sales#83) AND (cast(sales#83 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#61, [id=#62] as decimal(32,6)))) -(84) BroadcastExchange +(82) BroadcastExchange Input [6]: [channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sales#83, number_sales#84] Arguments: HashedRelationBroadcastMode(List(input[1, int, true], input[2, int, true], input[3, int, true]),false), [id=#85] -(85) BroadcastHashJoin [codegen id : 52] +(83) BroadcastHashJoin [codegen id : 52] Left keys [3]: [i_brand_id#44, i_class_id#45, i_category_id#46] Right keys [3]: [i_brand_id#69, i_class_id#70, i_category_id#71] Join condition: None -(86) TakeOrderedAndProject +(84) TakeOrderedAndProject Input [12]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sales#59, number_sales#60, channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sales#83, number_sales#84] Arguments: 100, [i_brand_id#44 ASC NULLS FIRST, i_class_id#45 ASC NULLS FIRST, i_category_id#46 ASC NULLS FIRST], 
[channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sales#59, number_sales#60, channel#82, i_brand_id#69, i_class_id#70, i_category_id#71, sales#83, number_sales#84] ===== Subqueries ===== -Subquery:1 Hosting operator id = 68 Hosting Expression = Subquery scalar-subquery#61, [id=#62] -* HashAggregate (105) -+- Exchange (104) - +- * HashAggregate (103) - +- Union (102) - :- * Project (91) - : +- * BroadcastHashJoin Inner BuildRight (90) - : :- * ColumnarToRow (88) - : : +- Scan parquet default.store_sales (87) - : +- ReusedExchange (89) - :- * Project (96) - : +- * BroadcastHashJoin Inner BuildRight (95) - : :- * ColumnarToRow (93) - : : +- Scan parquet default.catalog_sales (92) - : +- ReusedExchange (94) - +- * Project (101) - +- * BroadcastHashJoin Inner BuildRight (100) - :- * ColumnarToRow (98) - : +- Scan parquet default.web_sales (97) - +- ReusedExchange (99) - - -(87) Scan parquet default.store_sales +Subquery:1 Hosting operator id = 66 Hosting Expression = Subquery scalar-subquery#61, [id=#62] +* HashAggregate (103) ++- Exchange (102) + +- * HashAggregate (101) + +- Union (100) + :- * Project (89) + : +- * BroadcastHashJoin Inner BuildRight (88) + : :- * ColumnarToRow (86) + : : +- Scan parquet default.store_sales (85) + : +- ReusedExchange (87) + :- * Project (94) + : +- * BroadcastHashJoin Inner BuildRight (93) + : :- * ColumnarToRow (91) + : : +- Scan parquet default.catalog_sales (90) + : +- ReusedExchange (92) + +- * Project (99) + +- * BroadcastHashJoin Inner BuildRight (98) + :- * ColumnarToRow (96) + : +- Scan parquet default.web_sales (95) + +- ReusedExchange (97) + + +(85) Scan parquet default.store_sales Output [3]: [ss_quantity#86, ss_list_price#87, ss_sold_date_sk#88] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(ss_sold_date_sk#88), dynamicpruningexpression(ss_sold_date_sk#88 IN dynamicpruning#12)] ReadSchema: struct -(88) ColumnarToRow [codegen id : 2] +(86) ColumnarToRow [codegen id : 2] Input [3]: [ss_quantity#86, ss_list_price#87, ss_sold_date_sk#88] -(89) ReusedExchange [Reuses operator id: 119] +(87) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#89] -(90) BroadcastHashJoin [codegen id : 2] +(88) BroadcastHashJoin [codegen id : 2] Left keys [1]: [ss_sold_date_sk#88] Right keys [1]: [d_date_sk#89] Join condition: None -(91) Project [codegen id : 2] +(89) Project [codegen id : 2] Output [2]: [ss_quantity#86 AS quantity#90, ss_list_price#87 AS list_price#91] Input [4]: [ss_quantity#86, ss_list_price#87, ss_sold_date_sk#88, d_date_sk#89] -(92) Scan parquet default.catalog_sales +(90) Scan parquet default.catalog_sales Output [3]: [cs_quantity#92, cs_list_price#93, cs_sold_date_sk#94] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(cs_sold_date_sk#94), dynamicpruningexpression(cs_sold_date_sk#94 IN dynamicpruning#12)] ReadSchema: struct -(93) ColumnarToRow [codegen id : 4] +(91) ColumnarToRow [codegen id : 4] Input [3]: [cs_quantity#92, cs_list_price#93, cs_sold_date_sk#94] -(94) ReusedExchange [Reuses operator id: 119] +(92) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#95] -(95) BroadcastHashJoin [codegen id : 4] +(93) BroadcastHashJoin [codegen id : 4] Left keys [1]: [cs_sold_date_sk#94] Right keys [1]: [d_date_sk#95] Join condition: None -(96) Project [codegen id : 4] +(94) Project [codegen id : 4] Output [2]: [cs_quantity#92 AS quantity#96, cs_list_price#93 AS list_price#97] Input [4]: [cs_quantity#92, cs_list_price#93, cs_sold_date_sk#94, d_date_sk#95] -(97) Scan 
parquet default.web_sales +(95) Scan parquet default.web_sales Output [3]: [ws_quantity#98, ws_list_price#99, ws_sold_date_sk#100] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(ws_sold_date_sk#100), dynamicpruningexpression(ws_sold_date_sk#100 IN dynamicpruning#12)] ReadSchema: struct -(98) ColumnarToRow [codegen id : 6] +(96) ColumnarToRow [codegen id : 6] Input [3]: [ws_quantity#98, ws_list_price#99, ws_sold_date_sk#100] -(99) ReusedExchange [Reuses operator id: 119] +(97) ReusedExchange [Reuses operator id: 117] Output [1]: [d_date_sk#101] -(100) BroadcastHashJoin [codegen id : 6] +(98) BroadcastHashJoin [codegen id : 6] Left keys [1]: [ws_sold_date_sk#100] Right keys [1]: [d_date_sk#101] Join condition: None -(101) Project [codegen id : 6] +(99) Project [codegen id : 6] Output [2]: [ws_quantity#98 AS quantity#102, ws_list_price#99 AS list_price#103] Input [4]: [ws_quantity#98, ws_list_price#99, ws_sold_date_sk#100, d_date_sk#101] -(102) Union +(100) Union -(103) HashAggregate [codegen id : 7] +(101) HashAggregate [codegen id : 7] Input [2]: [quantity#90, list_price#91] Keys: [] -Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(cast(quantity#90 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2), true))] +Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(quantity#90 as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [2]: [sum#104, count#105] Results [2]: [sum#106, count#107] -(104) Exchange +(102) Exchange Input [2]: [sum#106, count#107] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#108] -(105) HashAggregate [codegen id : 8] +(103) HashAggregate [codegen id : 8] Input [2]: [sum#106, count#107] Keys: [] -Functions [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#90 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#90 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2), true))#109] -Results [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#90 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2), true))#109 AS average_sales#110] +Functions [1]: [avg(CheckOverflow((promote_precision(cast(quantity#90 as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(quantity#90 as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2)))#109] +Results [1]: [avg(CheckOverflow((promote_precision(cast(quantity#90 as decimal(12,2))) * promote_precision(cast(list_price#91 as decimal(12,2)))), DecimalType(18,2)))#109 AS average_sales#110] -Subquery:2 Hosting operator id = 87 Hosting Expression = ss_sold_date_sk#88 IN dynamicpruning#12 +Subquery:2 Hosting operator id = 85 Hosting Expression = ss_sold_date_sk#88 IN dynamicpruning#12 -Subquery:3 Hosting operator id = 92 Hosting Expression = cs_sold_date_sk#94 IN dynamicpruning#12 +Subquery:3 Hosting operator id = 90 Hosting Expression = cs_sold_date_sk#94 IN dynamicpruning#12 -Subquery:4 Hosting operator id = 97 Hosting Expression = ws_sold_date_sk#100 IN dynamicpruning#12 +Subquery:4 Hosting operator id = 95 
Hosting Expression = ws_sold_date_sk#100 IN dynamicpruning#12 Subquery:5 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (110) -+- * Project (109) - +- * Filter (108) - +- * ColumnarToRow (107) - +- Scan parquet default.date_dim (106) +BroadcastExchange (108) ++- * Project (107) + +- * Filter (106) + +- * ColumnarToRow (105) + +- Scan parquet default.date_dim (104) -(106) Scan parquet default.date_dim +(104) Scan parquet default.date_dim Output [2]: [d_date_sk#48, d_week_seq#111] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] ReadSchema: struct -(107) ColumnarToRow [codegen id : 1] +(105) ColumnarToRow [codegen id : 1] Input [2]: [d_date_sk#48, d_week_seq#111] -(108) Filter [codegen id : 1] +(106) Filter [codegen id : 1] Input [2]: [d_date_sk#48, d_week_seq#111] Condition : ((isnotnull(d_week_seq#111) AND (d_week_seq#111 = Subquery scalar-subquery#112, [id=#113])) AND isnotnull(d_date_sk#48)) -(109) Project [codegen id : 1] +(107) Project [codegen id : 1] Output [1]: [d_date_sk#48] Input [2]: [d_date_sk#48, d_week_seq#111] -(110) BroadcastExchange +(108) BroadcastExchange Input [1]: [d_date_sk#48] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#114] -Subquery:6 Hosting operator id = 108 Hosting Expression = Subquery scalar-subquery#112, [id=#113] -* Project (114) -+- * Filter (113) - +- * ColumnarToRow (112) - +- Scan parquet default.date_dim (111) +Subquery:6 Hosting operator id = 106 Hosting Expression = Subquery scalar-subquery#112, [id=#113] +* Project (112) ++- * Filter (111) + +- * ColumnarToRow (110) + +- Scan parquet default.date_dim (109) -(111) Scan parquet default.date_dim +(109) Scan parquet default.date_dim Output [4]: [d_week_seq#115, d_year#116, d_moy#117, d_dom#118] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), IsNotNull(d_dom), EqualTo(d_year,1999), EqualTo(d_moy,12), EqualTo(d_dom,16)] ReadSchema: struct -(112) ColumnarToRow [codegen id : 1] +(110) ColumnarToRow [codegen id : 1] Input [4]: [d_week_seq#115, d_year#116, d_moy#117, d_dom#118] -(113) Filter [codegen id : 1] +(111) Filter [codegen id : 1] Input [4]: [d_week_seq#115, d_year#116, d_moy#117, d_dom#118] Condition : (((((isnotnull(d_year#116) AND isnotnull(d_moy#117)) AND isnotnull(d_dom#118)) AND (d_year#116 = 1999)) AND (d_moy#117 = 12)) AND (d_dom#118 = 16)) -(114) Project [codegen id : 1] +(112) Project [codegen id : 1] Output [1]: [d_week_seq#115] Input [4]: [d_week_seq#115, d_year#116, d_moy#117, d_dom#118] Subquery:7 Hosting operator id = 7 Hosting Expression = ss_sold_date_sk#11 IN dynamicpruning#12 -BroadcastExchange (119) -+- * Project (118) - +- * Filter (117) - +- * ColumnarToRow (116) - +- Scan parquet default.date_dim (115) +BroadcastExchange (117) ++- * Project (116) + +- * Filter (115) + +- * ColumnarToRow (114) + +- Scan parquet default.date_dim (113) -(115) Scan parquet default.date_dim +(113) Scan parquet default.date_dim Output [2]: [d_date_sk#27, d_year#119] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), GreaterThanOrEqual(d_year,1998), LessThanOrEqual(d_year,2000), IsNotNull(d_date_sk)] ReadSchema: struct -(116) ColumnarToRow [codegen id : 1] +(114) ColumnarToRow [codegen id : 1] Input [2]: [d_date_sk#27, d_year#119] -(117) Filter [codegen id : 1] +(115) 
Filter [codegen id : 1] Input [2]: [d_date_sk#27, d_year#119] Condition : (((isnotnull(d_year#119) AND (d_year#119 >= 1998)) AND (d_year#119 <= 2000)) AND isnotnull(d_date_sk#27)) -(118) Project [codegen id : 1] +(116) Project [codegen id : 1] Output [1]: [d_date_sk#27] Input [2]: [d_date_sk#27, d_year#119] -(119) BroadcastExchange +(117) BroadcastExchange Input [1]: [d_date_sk#27] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#120] @@ -691,60 +675,60 @@ Subquery:8 Hosting operator id = 13 Hosting Expression = cs_sold_date_sk#18 IN d Subquery:9 Hosting operator id = 36 Hosting Expression = ws_sold_date_sk#33 IN dynamicpruning#12 -Subquery:10 Hosting operator id = 83 Hosting Expression = ReusedSubquery Subquery scalar-subquery#61, [id=#62] +Subquery:10 Hosting operator id = 81 Hosting Expression = ReusedSubquery Subquery scalar-subquery#61, [id=#62] -Subquery:11 Hosting operator id = 69 Hosting Expression = ss_sold_date_sk#66 IN dynamicpruning#67 -BroadcastExchange (124) -+- * Project (123) - +- * Filter (122) - +- * ColumnarToRow (121) - +- Scan parquet default.date_dim (120) +Subquery:11 Hosting operator id = 67 Hosting Expression = ss_sold_date_sk#66 IN dynamicpruning#67 +BroadcastExchange (122) ++- * Project (121) + +- * Filter (120) + +- * ColumnarToRow (119) + +- Scan parquet default.date_dim (118) -(120) Scan parquet default.date_dim +(118) Scan parquet default.date_dim Output [2]: [d_date_sk#72, d_week_seq#121] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] ReadSchema: struct -(121) ColumnarToRow [codegen id : 1] +(119) ColumnarToRow [codegen id : 1] Input [2]: [d_date_sk#72, d_week_seq#121] -(122) Filter [codegen id : 1] +(120) Filter [codegen id : 1] Input [2]: [d_date_sk#72, d_week_seq#121] Condition : ((isnotnull(d_week_seq#121) AND (d_week_seq#121 = Subquery scalar-subquery#122, [id=#123])) AND isnotnull(d_date_sk#72)) -(123) Project [codegen id : 1] +(121) Project [codegen id : 1] Output [1]: [d_date_sk#72] Input [2]: [d_date_sk#72, d_week_seq#121] -(124) BroadcastExchange +(122) BroadcastExchange Input [1]: [d_date_sk#72] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#124] -Subquery:12 Hosting operator id = 122 Hosting Expression = Subquery scalar-subquery#122, [id=#123] -* Project (128) -+- * Filter (127) - +- * ColumnarToRow (126) - +- Scan parquet default.date_dim (125) +Subquery:12 Hosting operator id = 120 Hosting Expression = Subquery scalar-subquery#122, [id=#123] +* Project (126) ++- * Filter (125) + +- * ColumnarToRow (124) + +- Scan parquet default.date_dim (123) -(125) Scan parquet default.date_dim +(123) Scan parquet default.date_dim Output [4]: [d_week_seq#125, d_year#126, d_moy#127, d_dom#128] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), IsNotNull(d_dom), EqualTo(d_year,1998), EqualTo(d_moy,12), EqualTo(d_dom,16)] ReadSchema: struct -(126) ColumnarToRow [codegen id : 1] +(124) ColumnarToRow [codegen id : 1] Input [4]: [d_week_seq#125, d_year#126, d_moy#127, d_dom#128] -(127) Filter [codegen id : 1] +(125) Filter [codegen id : 1] Input [4]: [d_week_seq#125, d_year#126, d_moy#127, d_dom#128] Condition : (((((isnotnull(d_year#126) AND isnotnull(d_moy#127)) AND isnotnull(d_dom#128)) AND (d_year#126 = 1998)) AND (d_moy#127 = 12)) AND (d_dom#128 = 16)) -(128) Project [codegen id : 1] +(126) Project 
[codegen id : 1] Output [1]: [d_week_seq#125] Input [4]: [d_week_seq#125, d_year#126, d_moy#127, d_dom#128] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/simplified.txt index 2df0810ddba28..259178d0e432f 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/simplified.txt @@ -4,7 +4,7 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ Filter [sales] Subquery #4 WholeStageCodegen (8) - HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(cast(quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2), true)),average_sales,sum,count] + HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(quantity as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2))),average_sales,sum,count] InputAdapter Exchange #12 WholeStageCodegen (7) @@ -38,7 +38,7 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ ReusedSubquery [d_date_sk] #3 InputAdapter ReusedExchange [d_date_sk] #6 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #1 WholeStageCodegen (25) @@ -79,77 +79,75 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ InputAdapter BroadcastExchange #4 WholeStageCodegen (10) - HashAggregate [brand_id,class_id,category_id] + BroadcastHashJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] HashAggregate [brand_id,class_id,category_id] - BroadcastHashJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] - HashAggregate [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #5 - WholeStageCodegen (6) - HashAggregate [brand_id,class_id,category_id] - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Project [ss_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_item_sk,i_item_sk] - Filter [ss_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #3 - BroadcastExchange #6 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] - InputAdapter - BroadcastExchange #7 - WholeStageCodegen (4) - BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] - Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - BroadcastExchange #8 - WholeStageCodegen (3) - Project 
[i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Project [cs_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_item_sk,i_item_sk] - Filter [cs_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #3 - InputAdapter - BroadcastExchange #9 - WholeStageCodegen (1) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - ReusedExchange [d_date_sk] #6 - InputAdapter - ReusedExchange [d_date_sk] #6 - InputAdapter - BroadcastExchange #10 - WholeStageCodegen (9) + InputAdapter + Exchange [brand_id,class_id,category_id] #5 + WholeStageCodegen (6) + HashAggregate [brand_id,class_id,category_id] Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Project [ws_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_item_sk,i_item_sk] - Filter [ws_item_sk] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Project [ss_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ss_item_sk,i_item_sk] + Filter [ss_item_sk] ColumnarToRow InputAdapter - Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] - ReusedSubquery [d_date_sk] #3 + Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #3 + BroadcastExchange #6 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #9 + BroadcastExchange #7 + WholeStageCodegen (4) + BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] + Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + BroadcastExchange #8 + WholeStageCodegen (3) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Project [cs_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_item_sk,i_item_sk] + Filter [cs_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] + ReusedSubquery [d_date_sk] #3 + InputAdapter + BroadcastExchange #9 + WholeStageCodegen (1) + Filter [i_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + ReusedExchange [d_date_sk] #6 InputAdapter ReusedExchange [d_date_sk] #6 + InputAdapter + BroadcastExchange #10 + WholeStageCodegen (9) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Project [ws_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_item_sk,i_item_sk] + Filter [ws_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #3 + InputAdapter + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #9 + InputAdapter + ReusedExchange [d_date_sk] #6 InputAdapter BroadcastExchange #11 WholeStageCodegen (23) @@ -167,7 +165,7 @@ TakeOrderedAndProject [i_brand_id,i_class_id,i_category_id,channel,sales,number_ WholeStageCodegen (51) Filter [sales] ReusedSubquery [average_sales] #4 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] 
[sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #14 WholeStageCodegen (50) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a.sf100/explain.txt index 5595e1a12b3fc..88d71316966c6 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a.sf100/explain.txt @@ -1,150 +1,147 @@ == Physical Plan == -TakeOrderedAndProject (146) -+- * HashAggregate (145) - +- Exchange (144) - +- * HashAggregate (143) - +- Union (142) - :- * HashAggregate (121) - : +- Exchange (120) - : +- * HashAggregate (119) - : +- Union (118) - : :- * Filter (81) - : : +- * HashAggregate (80) - : : +- Exchange (79) - : : +- * HashAggregate (78) - : : +- * Project (77) - : : +- * BroadcastHashJoin Inner BuildRight (76) - : : :- * Project (66) - : : : +- * BroadcastHashJoin Inner BuildRight (65) - : : : :- * SortMergeJoin LeftSemi (63) +TakeOrderedAndProject (143) ++- * HashAggregate (142) + +- Exchange (141) + +- * HashAggregate (140) + +- Union (139) + :- * HashAggregate (118) + : +- Exchange (117) + : +- * HashAggregate (116) + : +- Union (115) + : :- * Filter (78) + : : +- * HashAggregate (77) + : : +- Exchange (76) + : : +- * HashAggregate (75) + : : +- * Project (74) + : : +- * BroadcastHashJoin Inner BuildRight (73) + : : :- * Project (63) + : : : +- * BroadcastHashJoin Inner BuildRight (62) + : : : :- * SortMergeJoin LeftSemi (60) : : : : :- * Sort (5) : : : : : +- Exchange (4) : : : : : +- * Filter (3) : : : : : +- * ColumnarToRow (2) : : : : : +- Scan parquet default.store_sales (1) - : : : : +- * Sort (62) - : : : : +- Exchange (61) - : : : : +- * Project (60) - : : : : +- * BroadcastHashJoin Inner BuildRight (59) + : : : : +- * Sort (59) + : : : : +- Exchange (58) + : : : : +- * Project (57) + : : : : +- * BroadcastHashJoin Inner BuildRight (56) : : : : :- * Filter (8) : : : : : +- * ColumnarToRow (7) : : : : : +- Scan parquet default.item (6) - : : : : +- BroadcastExchange (58) - : : : : +- * HashAggregate (57) - : : : : +- Exchange (56) - : : : : +- * HashAggregate (55) - : : : : +- * SortMergeJoin LeftSemi (54) - : : : : :- * Sort (42) - : : : : : +- Exchange (41) - : : : : : +- * HashAggregate (40) - : : : : : +- Exchange (39) - : : : : : +- * HashAggregate (38) - : : : : : +- * Project (37) - : : : : : +- * BroadcastHashJoin Inner BuildRight (36) - : : : : : :- * Project (14) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (13) - : : : : : : :- * Filter (11) - : : : : : : : +- * ColumnarToRow (10) - : : : : : : : +- Scan parquet default.store_sales (9) - : : : : : : +- ReusedExchange (12) - : : : : : +- BroadcastExchange (35) - : : : : : +- * SortMergeJoin LeftSemi (34) - : : : : : :- * Sort (19) - : : : : : : +- Exchange (18) - : : : : : : +- * Filter (17) - : : : : : : +- * ColumnarToRow (16) - : : : : : : +- Scan parquet 
default.item (15) - : : : : : +- * Sort (33) - : : : : : +- Exchange (32) - : : : : : +- * Project (31) - : : : : : +- * BroadcastHashJoin Inner BuildRight (30) - : : : : : :- * Project (25) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (24) - : : : : : : :- * Filter (22) - : : : : : : : +- * ColumnarToRow (21) - : : : : : : : +- Scan parquet default.catalog_sales (20) - : : : : : : +- ReusedExchange (23) - : : : : : +- BroadcastExchange (29) - : : : : : +- * Filter (28) - : : : : : +- * ColumnarToRow (27) - : : : : : +- Scan parquet default.item (26) - : : : : +- * Sort (53) - : : : : +- Exchange (52) - : : : : +- * Project (51) - : : : : +- * BroadcastHashJoin Inner BuildRight (50) - : : : : :- * Project (48) - : : : : : +- * BroadcastHashJoin Inner BuildRight (47) - : : : : : :- * Filter (45) - : : : : : : +- * ColumnarToRow (44) - : : : : : : +- Scan parquet default.web_sales (43) - : : : : : +- ReusedExchange (46) - : : : : +- ReusedExchange (49) - : : : +- ReusedExchange (64) - : : +- BroadcastExchange (75) - : : +- * SortMergeJoin LeftSemi (74) - : : :- * Sort (71) - : : : +- Exchange (70) - : : : +- * Filter (69) - : : : +- * ColumnarToRow (68) - : : : +- Scan parquet default.item (67) - : : +- * Sort (73) - : : +- ReusedExchange (72) - : :- * Filter (99) - : : +- * HashAggregate (98) - : : +- Exchange (97) - : : +- * HashAggregate (96) - : : +- * Project (95) - : : +- * BroadcastHashJoin Inner BuildRight (94) - : : :- * Project (92) - : : : +- * BroadcastHashJoin Inner BuildRight (91) - : : : :- * SortMergeJoin LeftSemi (89) - : : : : :- * Sort (86) - : : : : : +- Exchange (85) - : : : : : +- * Filter (84) - : : : : : +- * ColumnarToRow (83) - : : : : : +- Scan parquet default.catalog_sales (82) - : : : : +- * Sort (88) - : : : : +- ReusedExchange (87) - : : : +- ReusedExchange (90) - : : +- ReusedExchange (93) - : +- * Filter (117) - : +- * HashAggregate (116) - : +- Exchange (115) - : +- * HashAggregate (114) - : +- * Project (113) - : +- * BroadcastHashJoin Inner BuildRight (112) - : :- * Project (110) - : : +- * BroadcastHashJoin Inner BuildRight (109) - : : :- * SortMergeJoin LeftSemi (107) - : : : :- * Sort (104) - : : : : +- Exchange (103) - : : : : +- * Filter (102) - : : : : +- * ColumnarToRow (101) - : : : : +- Scan parquet default.web_sales (100) - : : : +- * Sort (106) - : : : +- ReusedExchange (105) - : : +- ReusedExchange (108) - : +- ReusedExchange (111) - :- * HashAggregate (126) - : +- Exchange (125) - : +- * HashAggregate (124) - : +- * HashAggregate (123) - : +- ReusedExchange (122) - :- * HashAggregate (131) - : +- Exchange (130) - : +- * HashAggregate (129) - : +- * HashAggregate (128) - : +- ReusedExchange (127) - :- * HashAggregate (136) - : +- Exchange (135) - : +- * HashAggregate (134) - : +- * HashAggregate (133) - : +- ReusedExchange (132) - +- * HashAggregate (141) - +- Exchange (140) - +- * HashAggregate (139) - +- * HashAggregate (138) - +- ReusedExchange (137) + : : : : +- BroadcastExchange (55) + : : : : +- * SortMergeJoin LeftSemi (54) + : : : : :- * Sort (42) + : : : : : +- Exchange (41) + : : : : : +- * HashAggregate (40) + : : : : : +- Exchange (39) + : : : : : +- * HashAggregate (38) + : : : : : +- * Project (37) + : : : : : +- * BroadcastHashJoin Inner BuildRight (36) + : : : : : :- * Project (14) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (13) + : : : : : : :- * Filter (11) + : : : : : : : +- * ColumnarToRow (10) + : : : : : : : +- Scan parquet default.store_sales (9) + : : : : : : +- ReusedExchange (12) + : : : : : +- 
BroadcastExchange (35) + : : : : : +- * SortMergeJoin LeftSemi (34) + : : : : : :- * Sort (19) + : : : : : : +- Exchange (18) + : : : : : : +- * Filter (17) + : : : : : : +- * ColumnarToRow (16) + : : : : : : +- Scan parquet default.item (15) + : : : : : +- * Sort (33) + : : : : : +- Exchange (32) + : : : : : +- * Project (31) + : : : : : +- * BroadcastHashJoin Inner BuildRight (30) + : : : : : :- * Project (25) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (24) + : : : : : : :- * Filter (22) + : : : : : : : +- * ColumnarToRow (21) + : : : : : : : +- Scan parquet default.catalog_sales (20) + : : : : : : +- ReusedExchange (23) + : : : : : +- BroadcastExchange (29) + : : : : : +- * Filter (28) + : : : : : +- * ColumnarToRow (27) + : : : : : +- Scan parquet default.item (26) + : : : : +- * Sort (53) + : : : : +- Exchange (52) + : : : : +- * Project (51) + : : : : +- * BroadcastHashJoin Inner BuildRight (50) + : : : : :- * Project (48) + : : : : : +- * BroadcastHashJoin Inner BuildRight (47) + : : : : : :- * Filter (45) + : : : : : : +- * ColumnarToRow (44) + : : : : : : +- Scan parquet default.web_sales (43) + : : : : : +- ReusedExchange (46) + : : : : +- ReusedExchange (49) + : : : +- ReusedExchange (61) + : : +- BroadcastExchange (72) + : : +- * SortMergeJoin LeftSemi (71) + : : :- * Sort (68) + : : : +- Exchange (67) + : : : +- * Filter (66) + : : : +- * ColumnarToRow (65) + : : : +- Scan parquet default.item (64) + : : +- * Sort (70) + : : +- ReusedExchange (69) + : :- * Filter (96) + : : +- * HashAggregate (95) + : : +- Exchange (94) + : : +- * HashAggregate (93) + : : +- * Project (92) + : : +- * BroadcastHashJoin Inner BuildRight (91) + : : :- * Project (89) + : : : +- * BroadcastHashJoin Inner BuildRight (88) + : : : :- * SortMergeJoin LeftSemi (86) + : : : : :- * Sort (83) + : : : : : +- Exchange (82) + : : : : : +- * Filter (81) + : : : : : +- * ColumnarToRow (80) + : : : : : +- Scan parquet default.catalog_sales (79) + : : : : +- * Sort (85) + : : : : +- ReusedExchange (84) + : : : +- ReusedExchange (87) + : : +- ReusedExchange (90) + : +- * Filter (114) + : +- * HashAggregate (113) + : +- Exchange (112) + : +- * HashAggregate (111) + : +- * Project (110) + : +- * BroadcastHashJoin Inner BuildRight (109) + : :- * Project (107) + : : +- * BroadcastHashJoin Inner BuildRight (106) + : : :- * SortMergeJoin LeftSemi (104) + : : : :- * Sort (101) + : : : : +- Exchange (100) + : : : : +- * Filter (99) + : : : : +- * ColumnarToRow (98) + : : : : +- Scan parquet default.web_sales (97) + : : : +- * Sort (103) + : : : +- ReusedExchange (102) + : : +- ReusedExchange (105) + : +- ReusedExchange (108) + :- * HashAggregate (123) + : +- Exchange (122) + : +- * HashAggregate (121) + : +- * HashAggregate (120) + : +- ReusedExchange (119) + :- * HashAggregate (128) + : +- Exchange (127) + : +- * HashAggregate (126) + : +- * HashAggregate (125) + : +- ReusedExchange (124) + :- * HashAggregate (133) + : +- Exchange (132) + : +- * HashAggregate (131) + : +- * HashAggregate (130) + : +- ReusedExchange (129) + +- * HashAggregate (138) + +- Exchange (137) + +- * HashAggregate (136) + +- * HashAggregate (135) + +- ReusedExchange (134) (1) Scan parquet default.store_sales @@ -177,10 +174,10 @@ Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_brand_id), IsNotNull(i_class_id), IsNotNull(i_category_id)] ReadSchema: struct -(7) ColumnarToRow [codegen id : 20] +(7) ColumnarToRow [codegen id : 19] Input [4]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10] 
-(8) Filter [codegen id : 20] +(8) Filter [codegen id : 19] Input [4]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10] Condition : ((isnotnull(i_brand_id#8) AND isnotnull(i_class_id#9)) AND isnotnull(i_category_id#10)) @@ -199,7 +196,7 @@ Input [2]: [ss_item_sk#11, ss_sold_date_sk#12] Input [2]: [ss_item_sk#11, ss_sold_date_sk#12] Condition : isnotnull(ss_item_sk#11) -(12) ReusedExchange [Reuses operator id: 180] +(12) ReusedExchange [Reuses operator id: 177] Output [1]: [d_date_sk#14] (13) BroadcastHashJoin [codegen id : 11] @@ -248,7 +245,7 @@ Input [2]: [cs_item_sk#20, cs_sold_date_sk#21] Input [2]: [cs_item_sk#20, cs_sold_date_sk#21] Condition : isnotnull(cs_item_sk#20) -(23) ReusedExchange [Reuses operator id: 180] +(23) ReusedExchange [Reuses operator id: 177] Output [1]: [d_date_sk#22] (24) BroadcastHashJoin [codegen id : 8] @@ -354,7 +351,7 @@ Input [2]: [ws_item_sk#35, ws_sold_date_sk#36] Input [2]: [ws_item_sk#35, ws_sold_date_sk#36] Condition : isnotnull(ws_item_sk#35) -(46) ReusedExchange [Reuses operator id: 180] +(46) ReusedExchange [Reuses operator id: 177] Output [1]: [d_date_sk#37] (47) BroadcastHashJoin [codegen id : 16] @@ -391,663 +388,645 @@ Left keys [6]: [coalesce(brand_id#30, 0), isnull(brand_id#30), coalesce(class_id Right keys [6]: [coalesce(i_brand_id#39, 0), isnull(i_brand_id#39), coalesce(i_class_id#40, 0), isnull(i_class_id#40), coalesce(i_category_id#41, 0), isnull(i_category_id#41)] Join condition: None -(55) HashAggregate [codegen id : 18] +(55) BroadcastExchange Input [3]: [brand_id#30, class_id#31, category_id#32] -Keys [3]: [brand_id#30, class_id#31, category_id#32] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#30, class_id#31, category_id#32] - -(56) Exchange -Input [3]: [brand_id#30, class_id#31, category_id#32] -Arguments: hashpartitioning(brand_id#30, class_id#31, category_id#32, 5), ENSURE_REQUIREMENTS, [id=#43] - -(57) HashAggregate [codegen id : 19] -Input [3]: [brand_id#30, class_id#31, category_id#32] -Keys [3]: [brand_id#30, class_id#31, category_id#32] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#30, class_id#31, category_id#32] - -(58) BroadcastExchange -Input [3]: [brand_id#30, class_id#31, category_id#32] -Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#44] +Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#43] -(59) BroadcastHashJoin [codegen id : 20] +(56) BroadcastHashJoin [codegen id : 19] Left keys [3]: [i_brand_id#8, i_class_id#9, i_category_id#10] Right keys [3]: [brand_id#30, class_id#31, category_id#32] Join condition: None -(60) Project [codegen id : 20] -Output [1]: [i_item_sk#7 AS ss_item_sk#45] +(57) Project [codegen id : 19] +Output [1]: [i_item_sk#7 AS ss_item_sk#44] Input [7]: [i_item_sk#7, i_brand_id#8, i_class_id#9, i_category_id#10, brand_id#30, class_id#31, category_id#32] -(61) Exchange -Input [1]: [ss_item_sk#45] -Arguments: hashpartitioning(ss_item_sk#45, 5), ENSURE_REQUIREMENTS, [id=#46] +(58) Exchange +Input [1]: [ss_item_sk#44] +Arguments: hashpartitioning(ss_item_sk#44, 5), ENSURE_REQUIREMENTS, [id=#45] -(62) Sort [codegen id : 21] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(59) Sort [codegen id : 20] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(63) SortMergeJoin [codegen id : 45] +(60) SortMergeJoin [codegen id : 43] Left keys [1]: [ss_item_sk#1] 
-Right keys [1]: [ss_item_sk#45] +Right keys [1]: [ss_item_sk#44] Join condition: None -(64) ReusedExchange [Reuses operator id: 175] -Output [1]: [d_date_sk#47] +(61) ReusedExchange [Reuses operator id: 172] +Output [1]: [d_date_sk#46] -(65) BroadcastHashJoin [codegen id : 45] +(62) BroadcastHashJoin [codegen id : 43] Left keys [1]: [ss_sold_date_sk#4] -Right keys [1]: [d_date_sk#47] +Right keys [1]: [d_date_sk#46] Join condition: None -(66) Project [codegen id : 45] +(63) Project [codegen id : 43] Output [3]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3] -Input [5]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, d_date_sk#47] +Input [5]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, d_date_sk#46] -(67) Scan parquet default.item -Output [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] +(64) Scan parquet default.item +Output [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] Batched: true Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_item_sk)] ReadSchema: struct -(68) ColumnarToRow [codegen id : 23] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] +(65) ColumnarToRow [codegen id : 22] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] -(69) Filter [codegen id : 23] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Condition : isnotnull(i_item_sk#48) +(66) Filter [codegen id : 22] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Condition : isnotnull(i_item_sk#47) -(70) Exchange -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: hashpartitioning(i_item_sk#48, 5), ENSURE_REQUIREMENTS, [id=#52] +(67) Exchange +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: hashpartitioning(i_item_sk#47, 5), ENSURE_REQUIREMENTS, [id=#51] -(71) Sort [codegen id : 24] -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: [i_item_sk#48 ASC NULLS FIRST], false, 0 +(68) Sort [codegen id : 23] +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: [i_item_sk#47 ASC NULLS FIRST], false, 0 -(72) ReusedExchange [Reuses operator id: 61] -Output [1]: [ss_item_sk#45] +(69) ReusedExchange [Reuses operator id: 58] +Output [1]: [ss_item_sk#44] -(73) Sort [codegen id : 43] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(70) Sort [codegen id : 41] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(74) SortMergeJoin [codegen id : 44] -Left keys [1]: [i_item_sk#48] -Right keys [1]: [ss_item_sk#45] +(71) SortMergeJoin [codegen id : 42] +Left keys [1]: [i_item_sk#47] +Right keys [1]: [ss_item_sk#44] Join condition: None -(75) BroadcastExchange -Input [4]: [i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#53] +(72) BroadcastExchange +Input [4]: [i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#52] -(76) BroadcastHashJoin [codegen id : 45] +(73) BroadcastHashJoin [codegen id : 43] Left keys [1]: [ss_item_sk#1] -Right keys [1]: [i_item_sk#48] +Right keys [1]: [i_item_sk#47] Join condition: None -(77) Project [codegen id : 45] -Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#49, 
i_class_id#50, i_category_id#51] -Input [7]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, i_item_sk#48, i_brand_id#49, i_class_id#50, i_category_id#51] - -(78) HashAggregate [codegen id : 45] -Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#49, i_class_id#50, i_category_id#51] -Keys [3]: [i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] -Aggregate Attributes [3]: [sum#54, isEmpty#55, count#56] -Results [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] - -(79) Exchange -Input [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] -Arguments: hashpartitioning(i_brand_id#49, i_class_id#50, i_category_id#51, 5), ENSURE_REQUIREMENTS, [id=#60] - -(80) HashAggregate [codegen id : 46] -Input [6]: [i_brand_id#49, i_class_id#50, i_category_id#51, sum#57, isEmpty#58, count#59] -Keys [3]: [i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#61, count(1)#62] -Results [6]: [store AS channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#61 AS sales#64, count(1)#62 AS number_sales#65] - -(81) Filter [codegen id : 46] -Input [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sales#64, number_sales#65] -Condition : (isnotnull(sales#64) AND (cast(sales#64 as decimal(32,6)) > cast(Subquery scalar-subquery#66, [id=#67] as decimal(32,6)))) - -(82) Scan parquet default.catalog_sales -Output [4]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71] +(74) Project [codegen id : 43] +Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#48, i_class_id#49, i_category_id#50] +Input [7]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, i_item_sk#47, i_brand_id#48, i_class_id#49, i_category_id#50] + +(75) HashAggregate [codegen id : 43] +Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#48, i_class_id#49, i_category_id#50] +Keys [3]: [i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] +Aggregate Attributes [3]: [sum#53, isEmpty#54, count#55] +Results [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] + +(76) Exchange +Input [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] +Arguments: hashpartitioning(i_brand_id#48, i_class_id#49, i_category_id#50, 5), ENSURE_REQUIREMENTS, [id=#59] + +(77) HashAggregate [codegen id : 44] +Input [6]: [i_brand_id#48, i_class_id#49, i_category_id#50, sum#56, isEmpty#57, count#58] +Keys [3]: [i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: 
[sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#60, count(1)#61] +Results [6]: [store AS channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#60 AS sales#63, count(1)#61 AS number_sales#64] + +(78) Filter [codegen id : 44] +Input [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sales#63, number_sales#64] +Condition : (isnotnull(sales#63) AND (cast(sales#63 as decimal(32,6)) > cast(Subquery scalar-subquery#65, [id=#66] as decimal(32,6)))) + +(79) Scan parquet default.catalog_sales +Output [4]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(cs_sold_date_sk#71), dynamicpruningexpression(cs_sold_date_sk#71 IN dynamicpruning#5)] +PartitionFilters: [isnotnull(cs_sold_date_sk#70), dynamicpruningexpression(cs_sold_date_sk#70 IN dynamicpruning#5)] PushedFilters: [IsNotNull(cs_item_sk)] ReadSchema: struct -(83) ColumnarToRow [codegen id : 47] -Input [4]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71] +(80) ColumnarToRow [codegen id : 45] +Input [4]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70] -(84) Filter [codegen id : 47] -Input [4]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71] -Condition : isnotnull(cs_item_sk#68) +(81) Filter [codegen id : 45] +Input [4]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70] +Condition : isnotnull(cs_item_sk#67) -(85) Exchange -Input [4]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71] -Arguments: hashpartitioning(cs_item_sk#68, 5), ENSURE_REQUIREMENTS, [id=#72] +(82) Exchange +Input [4]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70] +Arguments: hashpartitioning(cs_item_sk#67, 5), ENSURE_REQUIREMENTS, [id=#71] -(86) Sort [codegen id : 48] -Input [4]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71] -Arguments: [cs_item_sk#68 ASC NULLS FIRST], false, 0 +(83) Sort [codegen id : 46] +Input [4]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70] +Arguments: [cs_item_sk#67 ASC NULLS FIRST], false, 0 -(87) ReusedExchange [Reuses operator id: 61] -Output [1]: [ss_item_sk#45] +(84) ReusedExchange [Reuses operator id: 58] +Output [1]: [ss_item_sk#44] -(88) Sort [codegen id : 67] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(85) Sort [codegen id : 64] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(89) SortMergeJoin [codegen id : 91] -Left keys [1]: [cs_item_sk#68] -Right keys [1]: [ss_item_sk#45] +(86) SortMergeJoin [codegen id : 87] +Left keys [1]: [cs_item_sk#67] +Right keys [1]: [ss_item_sk#44] Join condition: None -(90) ReusedExchange [Reuses operator id: 175] -Output [1]: [d_date_sk#73] +(87) ReusedExchange [Reuses operator id: 172] +Output [1]: [d_date_sk#72] -(91) BroadcastHashJoin [codegen id : 91] -Left keys [1]: [cs_sold_date_sk#71] -Right keys [1]: [d_date_sk#73] +(88) BroadcastHashJoin [codegen id : 87] +Left keys [1]: [cs_sold_date_sk#70] 
+Right keys [1]: [d_date_sk#72] Join condition: None -(92) Project [codegen id : 91] -Output [3]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70] -Input [5]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, cs_sold_date_sk#71, d_date_sk#73] +(89) Project [codegen id : 87] +Output [3]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69] +Input [5]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, cs_sold_date_sk#70, d_date_sk#72] -(93) ReusedExchange [Reuses operator id: 75] -Output [4]: [i_item_sk#74, i_brand_id#75, i_class_id#76, i_category_id#77] +(90) ReusedExchange [Reuses operator id: 72] +Output [4]: [i_item_sk#73, i_brand_id#74, i_class_id#75, i_category_id#76] -(94) BroadcastHashJoin [codegen id : 91] -Left keys [1]: [cs_item_sk#68] -Right keys [1]: [i_item_sk#74] +(91) BroadcastHashJoin [codegen id : 87] +Left keys [1]: [cs_item_sk#67] +Right keys [1]: [i_item_sk#73] Join condition: None -(95) Project [codegen id : 91] -Output [5]: [cs_quantity#69, cs_list_price#70, i_brand_id#75, i_class_id#76, i_category_id#77] -Input [7]: [cs_item_sk#68, cs_quantity#69, cs_list_price#70, i_item_sk#74, i_brand_id#75, i_class_id#76, i_category_id#77] - -(96) HashAggregate [codegen id : 91] -Input [5]: [cs_quantity#69, cs_list_price#70, i_brand_id#75, i_class_id#76, i_category_id#77] -Keys [3]: [i_brand_id#75, i_class_id#76, i_category_id#77] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#70 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] -Aggregate Attributes [3]: [sum#78, isEmpty#79, count#80] -Results [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] - -(97) Exchange -Input [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] -Arguments: hashpartitioning(i_brand_id#75, i_class_id#76, i_category_id#77, 5), ENSURE_REQUIREMENTS, [id=#84] - -(98) HashAggregate [codegen id : 92] -Input [6]: [i_brand_id#75, i_class_id#76, i_category_id#77, sum#81, isEmpty#82, count#83] -Keys [3]: [i_brand_id#75, i_class_id#76, i_category_id#77] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#70 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#70 as decimal(12,2)))), DecimalType(18,2), true))#85, count(1)#86] -Results [6]: [catalog AS channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#69 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#70 as decimal(12,2)))), DecimalType(18,2), true))#85 AS sales#88, count(1)#86 AS number_sales#89] - -(99) Filter [codegen id : 92] -Input [6]: [channel#87, i_brand_id#75, i_class_id#76, i_category_id#77, sales#88, number_sales#89] -Condition : (isnotnull(sales#88) AND (cast(sales#88 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#66, [id=#67] as decimal(32,6)))) - -(100) Scan parquet default.web_sales -Output [4]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93] +(92) Project [codegen id : 87] +Output [5]: [cs_quantity#68, cs_list_price#69, i_brand_id#74, i_class_id#75, i_category_id#76] +Input [7]: [cs_item_sk#67, cs_quantity#68, cs_list_price#69, i_item_sk#73, i_brand_id#74, 
i_class_id#75, i_category_id#76] + +(93) HashAggregate [codegen id : 87] +Input [5]: [cs_quantity#68, cs_list_price#69, i_brand_id#74, i_class_id#75, i_category_id#76] +Keys [3]: [i_brand_id#74, i_class_id#75, i_category_id#76] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cs_quantity#68 as decimal(12,2))) * promote_precision(cast(cs_list_price#69 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] +Aggregate Attributes [3]: [sum#77, isEmpty#78, count#79] +Results [6]: [i_brand_id#74, i_class_id#75, i_category_id#76, sum#80, isEmpty#81, count#82] + +(94) Exchange +Input [6]: [i_brand_id#74, i_class_id#75, i_category_id#76, sum#80, isEmpty#81, count#82] +Arguments: hashpartitioning(i_brand_id#74, i_class_id#75, i_category_id#76, 5), ENSURE_REQUIREMENTS, [id=#83] + +(95) HashAggregate [codegen id : 88] +Input [6]: [i_brand_id#74, i_class_id#75, i_category_id#76, sum#80, isEmpty#81, count#82] +Keys [3]: [i_brand_id#74, i_class_id#75, i_category_id#76] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#68 as decimal(12,2))) * promote_precision(cast(cs_list_price#69 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#68 as decimal(12,2))) * promote_precision(cast(cs_list_price#69 as decimal(12,2)))), DecimalType(18,2)))#84, count(1)#85] +Results [6]: [catalog AS channel#86, i_brand_id#74, i_class_id#75, i_category_id#76, sum(CheckOverflow((promote_precision(cast(cs_quantity#68 as decimal(12,2))) * promote_precision(cast(cs_list_price#69 as decimal(12,2)))), DecimalType(18,2)))#84 AS sales#87, count(1)#85 AS number_sales#88] + +(96) Filter [codegen id : 88] +Input [6]: [channel#86, i_brand_id#74, i_class_id#75, i_category_id#76, sales#87, number_sales#88] +Condition : (isnotnull(sales#87) AND (cast(sales#87 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#65, [id=#66] as decimal(32,6)))) + +(97) Scan parquet default.web_sales +Output [4]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ws_sold_date_sk#93), dynamicpruningexpression(ws_sold_date_sk#93 IN dynamicpruning#5)] +PartitionFilters: [isnotnull(ws_sold_date_sk#92), dynamicpruningexpression(ws_sold_date_sk#92 IN dynamicpruning#5)] PushedFilters: [IsNotNull(ws_item_sk)] ReadSchema: struct -(101) ColumnarToRow [codegen id : 93] -Input [4]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93] +(98) ColumnarToRow [codegen id : 89] +Input [4]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92] -(102) Filter [codegen id : 93] -Input [4]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93] -Condition : isnotnull(ws_item_sk#90) +(99) Filter [codegen id : 89] +Input [4]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92] +Condition : isnotnull(ws_item_sk#89) -(103) Exchange -Input [4]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93] -Arguments: hashpartitioning(ws_item_sk#90, 5), ENSURE_REQUIREMENTS, [id=#94] +(100) Exchange +Input [4]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92] +Arguments: hashpartitioning(ws_item_sk#89, 5), ENSURE_REQUIREMENTS, [id=#93] -(104) Sort [codegen id : 94] -Input [4]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93] -Arguments: [ws_item_sk#90 ASC NULLS FIRST], false, 0 +(101) Sort [codegen id : 90] +Input [4]: [ws_item_sk#89, ws_quantity#90, 
ws_list_price#91, ws_sold_date_sk#92] +Arguments: [ws_item_sk#89 ASC NULLS FIRST], false, 0 -(105) ReusedExchange [Reuses operator id: 61] -Output [1]: [ss_item_sk#45] +(102) ReusedExchange [Reuses operator id: 58] +Output [1]: [ss_item_sk#44] -(106) Sort [codegen id : 113] -Input [1]: [ss_item_sk#45] -Arguments: [ss_item_sk#45 ASC NULLS FIRST], false, 0 +(103) Sort [codegen id : 108] +Input [1]: [ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST], false, 0 -(107) SortMergeJoin [codegen id : 137] -Left keys [1]: [ws_item_sk#90] -Right keys [1]: [ss_item_sk#45] +(104) SortMergeJoin [codegen id : 131] +Left keys [1]: [ws_item_sk#89] +Right keys [1]: [ss_item_sk#44] Join condition: None -(108) ReusedExchange [Reuses operator id: 175] -Output [1]: [d_date_sk#95] +(105) ReusedExchange [Reuses operator id: 172] +Output [1]: [d_date_sk#94] -(109) BroadcastHashJoin [codegen id : 137] -Left keys [1]: [ws_sold_date_sk#93] -Right keys [1]: [d_date_sk#95] +(106) BroadcastHashJoin [codegen id : 131] +Left keys [1]: [ws_sold_date_sk#92] +Right keys [1]: [d_date_sk#94] Join condition: None -(110) Project [codegen id : 137] -Output [3]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92] -Input [5]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, ws_sold_date_sk#93, d_date_sk#95] +(107) Project [codegen id : 131] +Output [3]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91] +Input [5]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, ws_sold_date_sk#92, d_date_sk#94] -(111) ReusedExchange [Reuses operator id: 75] -Output [4]: [i_item_sk#96, i_brand_id#97, i_class_id#98, i_category_id#99] +(108) ReusedExchange [Reuses operator id: 72] +Output [4]: [i_item_sk#95, i_brand_id#96, i_class_id#97, i_category_id#98] -(112) BroadcastHashJoin [codegen id : 137] -Left keys [1]: [ws_item_sk#90] -Right keys [1]: [i_item_sk#96] +(109) BroadcastHashJoin [codegen id : 131] +Left keys [1]: [ws_item_sk#89] +Right keys [1]: [i_item_sk#95] Join condition: None -(113) Project [codegen id : 137] -Output [5]: [ws_quantity#91, ws_list_price#92, i_brand_id#97, i_class_id#98, i_category_id#99] -Input [7]: [ws_item_sk#90, ws_quantity#91, ws_list_price#92, i_item_sk#96, i_brand_id#97, i_class_id#98, i_category_id#99] - -(114) HashAggregate [codegen id : 137] -Input [5]: [ws_quantity#91, ws_list_price#92, i_brand_id#97, i_class_id#98, i_category_id#99] -Keys [3]: [i_brand_id#97, i_class_id#98, i_category_id#99] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#91 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#92 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] -Aggregate Attributes [3]: [sum#100, isEmpty#101, count#102] -Results [6]: [i_brand_id#97, i_class_id#98, i_category_id#99, sum#103, isEmpty#104, count#105] - -(115) Exchange -Input [6]: [i_brand_id#97, i_class_id#98, i_category_id#99, sum#103, isEmpty#104, count#105] -Arguments: hashpartitioning(i_brand_id#97, i_class_id#98, i_category_id#99, 5), ENSURE_REQUIREMENTS, [id=#106] - -(116) HashAggregate [codegen id : 138] -Input [6]: [i_brand_id#97, i_class_id#98, i_category_id#99, sum#103, isEmpty#104, count#105] -Keys [3]: [i_brand_id#97, i_class_id#98, i_category_id#99] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#91 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#92 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#91 as decimal(10,0)) 
as decimal(12,2))) * promote_precision(cast(ws_list_price#92 as decimal(12,2)))), DecimalType(18,2), true))#107, count(1)#108] -Results [6]: [web AS channel#109, i_brand_id#97, i_class_id#98, i_category_id#99, sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#91 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#92 as decimal(12,2)))), DecimalType(18,2), true))#107 AS sales#110, count(1)#108 AS number_sales#111] - -(117) Filter [codegen id : 138] -Input [6]: [channel#109, i_brand_id#97, i_class_id#98, i_category_id#99, sales#110, number_sales#111] -Condition : (isnotnull(sales#110) AND (cast(sales#110 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#66, [id=#67] as decimal(32,6)))) - -(118) Union - -(119) HashAggregate [codegen id : 139] -Input [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sales#64, number_sales#65] -Keys [4]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [partial_sum(sales#64), partial_sum(number_sales#65)] -Aggregate Attributes [3]: [sum#112, isEmpty#113, sum#114] -Results [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] - -(120) Exchange -Input [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] -Arguments: hashpartitioning(channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, 5), ENSURE_REQUIREMENTS, [id=#118] - -(121) HashAggregate [codegen id : 140] -Input [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] -Keys [4]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [sum(sales#64), sum(number_sales#65)] -Aggregate Attributes [2]: [sum(sales#64)#119, sum(number_sales#65)#120] -Results [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum(sales#64)#119 AS sum_sales#121, sum(number_sales#65)#120 AS number_sales#122] - -(122) ReusedExchange [Reuses operator id: 120] -Output [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] - -(123) HashAggregate [codegen id : 280] -Input [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] -Keys [4]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [sum(sales#64), sum(number_sales#65)] -Aggregate Attributes [2]: [sum(sales#64)#119, sum(number_sales#65)#120] -Results [5]: [channel#63, i_brand_id#49, i_class_id#50, sum(sales#64)#119 AS sum_sales#121, sum(number_sales#65)#120 AS number_sales#122] - -(124) HashAggregate [codegen id : 280] -Input [5]: [channel#63, i_brand_id#49, i_class_id#50, sum_sales#121, number_sales#122] -Keys [3]: [channel#63, i_brand_id#49, i_class_id#50] -Functions [2]: [partial_sum(sum_sales#121), partial_sum(number_sales#122)] -Aggregate Attributes [3]: [sum#123, isEmpty#124, sum#125] -Results [6]: [channel#63, i_brand_id#49, i_class_id#50, sum#126, isEmpty#127, sum#128] - -(125) Exchange -Input [6]: [channel#63, i_brand_id#49, i_class_id#50, sum#126, isEmpty#127, sum#128] -Arguments: hashpartitioning(channel#63, i_brand_id#49, i_class_id#50, 5), ENSURE_REQUIREMENTS, [id=#129] - -(126) HashAggregate [codegen id : 281] -Input [6]: [channel#63, i_brand_id#49, i_class_id#50, sum#126, isEmpty#127, sum#128] -Keys [3]: [channel#63, i_brand_id#49, i_class_id#50] -Functions [2]: [sum(sum_sales#121), sum(number_sales#122)] -Aggregate Attributes [2]: [sum(sum_sales#121)#130, sum(number_sales#122)#131] -Results [6]: 
[channel#63, i_brand_id#49, i_class_id#50, null AS i_category_id#132, sum(sum_sales#121)#130 AS sum(sum_sales)#133, sum(number_sales#122)#131 AS sum(number_sales)#134] - -(127) ReusedExchange [Reuses operator id: 120] -Output [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] - -(128) HashAggregate [codegen id : 421] -Input [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] -Keys [4]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [sum(sales#64), sum(number_sales#65)] -Aggregate Attributes [2]: [sum(sales#64)#119, sum(number_sales#65)#120] -Results [4]: [channel#63, i_brand_id#49, sum(sales#64)#119 AS sum_sales#121, sum(number_sales#65)#120 AS number_sales#122] - -(129) HashAggregate [codegen id : 421] -Input [4]: [channel#63, i_brand_id#49, sum_sales#121, number_sales#122] -Keys [2]: [channel#63, i_brand_id#49] -Functions [2]: [partial_sum(sum_sales#121), partial_sum(number_sales#122)] -Aggregate Attributes [3]: [sum#135, isEmpty#136, sum#137] -Results [5]: [channel#63, i_brand_id#49, sum#138, isEmpty#139, sum#140] - -(130) Exchange -Input [5]: [channel#63, i_brand_id#49, sum#138, isEmpty#139, sum#140] -Arguments: hashpartitioning(channel#63, i_brand_id#49, 5), ENSURE_REQUIREMENTS, [id=#141] - -(131) HashAggregate [codegen id : 422] -Input [5]: [channel#63, i_brand_id#49, sum#138, isEmpty#139, sum#140] -Keys [2]: [channel#63, i_brand_id#49] -Functions [2]: [sum(sum_sales#121), sum(number_sales#122)] -Aggregate Attributes [2]: [sum(sum_sales#121)#142, sum(number_sales#122)#143] -Results [6]: [channel#63, i_brand_id#49, null AS i_class_id#144, null AS i_category_id#145, sum(sum_sales#121)#142 AS sum(sum_sales)#146, sum(number_sales#122)#143 AS sum(number_sales)#147] - -(132) ReusedExchange [Reuses operator id: 120] -Output [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] - -(133) HashAggregate [codegen id : 562] -Input [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] -Keys [4]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [sum(sales#64), sum(number_sales#65)] -Aggregate Attributes [2]: [sum(sales#64)#119, sum(number_sales#65)#120] -Results [3]: [channel#63, sum(sales#64)#119 AS sum_sales#121, sum(number_sales#65)#120 AS number_sales#122] - -(134) HashAggregate [codegen id : 562] -Input [3]: [channel#63, sum_sales#121, number_sales#122] -Keys [1]: [channel#63] -Functions [2]: [partial_sum(sum_sales#121), partial_sum(number_sales#122)] -Aggregate Attributes [3]: [sum#148, isEmpty#149, sum#150] -Results [4]: [channel#63, sum#151, isEmpty#152, sum#153] - -(135) Exchange -Input [4]: [channel#63, sum#151, isEmpty#152, sum#153] -Arguments: hashpartitioning(channel#63, 5), ENSURE_REQUIREMENTS, [id=#154] - -(136) HashAggregate [codegen id : 563] -Input [4]: [channel#63, sum#151, isEmpty#152, sum#153] -Keys [1]: [channel#63] -Functions [2]: [sum(sum_sales#121), sum(number_sales#122)] -Aggregate Attributes [2]: [sum(sum_sales#121)#155, sum(number_sales#122)#156] -Results [6]: [channel#63, null AS i_brand_id#157, null AS i_class_id#158, null AS i_category_id#159, sum(sum_sales#121)#155 AS sum(sum_sales)#160, sum(number_sales#122)#156 AS sum(number_sales)#161] - -(137) ReusedExchange [Reuses operator id: 120] -Output [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] - -(138) HashAggregate [codegen id : 703] 
-Input [7]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum#115, isEmpty#116, sum#117] -Keys [4]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51] -Functions [2]: [sum(sales#64), sum(number_sales#65)] -Aggregate Attributes [2]: [sum(sales#64)#119, sum(number_sales#65)#120] -Results [2]: [sum(sales#64)#119 AS sum_sales#121, sum(number_sales#65)#120 AS number_sales#122] - -(139) HashAggregate [codegen id : 703] -Input [2]: [sum_sales#121, number_sales#122] +(110) Project [codegen id : 131] +Output [5]: [ws_quantity#90, ws_list_price#91, i_brand_id#96, i_class_id#97, i_category_id#98] +Input [7]: [ws_item_sk#89, ws_quantity#90, ws_list_price#91, i_item_sk#95, i_brand_id#96, i_class_id#97, i_category_id#98] + +(111) HashAggregate [codegen id : 131] +Input [5]: [ws_quantity#90, ws_list_price#91, i_brand_id#96, i_class_id#97, i_category_id#98] +Keys [3]: [i_brand_id#96, i_class_id#97, i_category_id#98] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ws_quantity#90 as decimal(12,2))) * promote_precision(cast(ws_list_price#91 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] +Aggregate Attributes [3]: [sum#99, isEmpty#100, count#101] +Results [6]: [i_brand_id#96, i_class_id#97, i_category_id#98, sum#102, isEmpty#103, count#104] + +(112) Exchange +Input [6]: [i_brand_id#96, i_class_id#97, i_category_id#98, sum#102, isEmpty#103, count#104] +Arguments: hashpartitioning(i_brand_id#96, i_class_id#97, i_category_id#98, 5), ENSURE_REQUIREMENTS, [id=#105] + +(113) HashAggregate [codegen id : 132] +Input [6]: [i_brand_id#96, i_class_id#97, i_category_id#98, sum#102, isEmpty#103, count#104] +Keys [3]: [i_brand_id#96, i_class_id#97, i_category_id#98] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#90 as decimal(12,2))) * promote_precision(cast(ws_list_price#91 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#90 as decimal(12,2))) * promote_precision(cast(ws_list_price#91 as decimal(12,2)))), DecimalType(18,2)))#106, count(1)#107] +Results [6]: [web AS channel#108, i_brand_id#96, i_class_id#97, i_category_id#98, sum(CheckOverflow((promote_precision(cast(ws_quantity#90 as decimal(12,2))) * promote_precision(cast(ws_list_price#91 as decimal(12,2)))), DecimalType(18,2)))#106 AS sales#109, count(1)#107 AS number_sales#110] + +(114) Filter [codegen id : 132] +Input [6]: [channel#108, i_brand_id#96, i_class_id#97, i_category_id#98, sales#109, number_sales#110] +Condition : (isnotnull(sales#109) AND (cast(sales#109 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#65, [id=#66] as decimal(32,6)))) + +(115) Union + +(116) HashAggregate [codegen id : 133] +Input [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sales#63, number_sales#64] +Keys [4]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [partial_sum(sales#63), partial_sum(number_sales#64)] +Aggregate Attributes [3]: [sum#111, isEmpty#112, sum#113] +Results [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] + +(117) Exchange +Input [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] +Arguments: hashpartitioning(channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, 5), ENSURE_REQUIREMENTS, [id=#117] + +(118) HashAggregate [codegen id : 134] +Input [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] 
+Keys [4]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [sum(sales#63), sum(number_sales#64)] +Aggregate Attributes [2]: [sum(sales#63)#118, sum(number_sales#64)#119] +Results [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum(sales#63)#118 AS sum_sales#120, sum(number_sales#64)#119 AS number_sales#121] + +(119) ReusedExchange [Reuses operator id: 117] +Output [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] + +(120) HashAggregate [codegen id : 268] +Input [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] +Keys [4]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [sum(sales#63), sum(number_sales#64)] +Aggregate Attributes [2]: [sum(sales#63)#118, sum(number_sales#64)#119] +Results [5]: [channel#62, i_brand_id#48, i_class_id#49, sum(sales#63)#118 AS sum_sales#120, sum(number_sales#64)#119 AS number_sales#121] + +(121) HashAggregate [codegen id : 268] +Input [5]: [channel#62, i_brand_id#48, i_class_id#49, sum_sales#120, number_sales#121] +Keys [3]: [channel#62, i_brand_id#48, i_class_id#49] +Functions [2]: [partial_sum(sum_sales#120), partial_sum(number_sales#121)] +Aggregate Attributes [3]: [sum#122, isEmpty#123, sum#124] +Results [6]: [channel#62, i_brand_id#48, i_class_id#49, sum#125, isEmpty#126, sum#127] + +(122) Exchange +Input [6]: [channel#62, i_brand_id#48, i_class_id#49, sum#125, isEmpty#126, sum#127] +Arguments: hashpartitioning(channel#62, i_brand_id#48, i_class_id#49, 5), ENSURE_REQUIREMENTS, [id=#128] + +(123) HashAggregate [codegen id : 269] +Input [6]: [channel#62, i_brand_id#48, i_class_id#49, sum#125, isEmpty#126, sum#127] +Keys [3]: [channel#62, i_brand_id#48, i_class_id#49] +Functions [2]: [sum(sum_sales#120), sum(number_sales#121)] +Aggregate Attributes [2]: [sum(sum_sales#120)#129, sum(number_sales#121)#130] +Results [6]: [channel#62, i_brand_id#48, i_class_id#49, null AS i_category_id#131, sum(sum_sales#120)#129 AS sum(sum_sales)#132, sum(number_sales#121)#130 AS sum(number_sales)#133] + +(124) ReusedExchange [Reuses operator id: 117] +Output [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] + +(125) HashAggregate [codegen id : 403] +Input [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] +Keys [4]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [sum(sales#63), sum(number_sales#64)] +Aggregate Attributes [2]: [sum(sales#63)#118, sum(number_sales#64)#119] +Results [4]: [channel#62, i_brand_id#48, sum(sales#63)#118 AS sum_sales#120, sum(number_sales#64)#119 AS number_sales#121] + +(126) HashAggregate [codegen id : 403] +Input [4]: [channel#62, i_brand_id#48, sum_sales#120, number_sales#121] +Keys [2]: [channel#62, i_brand_id#48] +Functions [2]: [partial_sum(sum_sales#120), partial_sum(number_sales#121)] +Aggregate Attributes [3]: [sum#134, isEmpty#135, sum#136] +Results [5]: [channel#62, i_brand_id#48, sum#137, isEmpty#138, sum#139] + +(127) Exchange +Input [5]: [channel#62, i_brand_id#48, sum#137, isEmpty#138, sum#139] +Arguments: hashpartitioning(channel#62, i_brand_id#48, 5), ENSURE_REQUIREMENTS, [id=#140] + +(128) HashAggregate [codegen id : 404] +Input [5]: [channel#62, i_brand_id#48, sum#137, isEmpty#138, sum#139] +Keys [2]: [channel#62, i_brand_id#48] +Functions [2]: [sum(sum_sales#120), sum(number_sales#121)] +Aggregate Attributes [2]: [sum(sum_sales#120)#141, 
sum(number_sales#121)#142] +Results [6]: [channel#62, i_brand_id#48, null AS i_class_id#143, null AS i_category_id#144, sum(sum_sales#120)#141 AS sum(sum_sales)#145, sum(number_sales#121)#142 AS sum(number_sales)#146] + +(129) ReusedExchange [Reuses operator id: 117] +Output [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] + +(130) HashAggregate [codegen id : 538] +Input [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] +Keys [4]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [sum(sales#63), sum(number_sales#64)] +Aggregate Attributes [2]: [sum(sales#63)#118, sum(number_sales#64)#119] +Results [3]: [channel#62, sum(sales#63)#118 AS sum_sales#120, sum(number_sales#64)#119 AS number_sales#121] + +(131) HashAggregate [codegen id : 538] +Input [3]: [channel#62, sum_sales#120, number_sales#121] +Keys [1]: [channel#62] +Functions [2]: [partial_sum(sum_sales#120), partial_sum(number_sales#121)] +Aggregate Attributes [3]: [sum#147, isEmpty#148, sum#149] +Results [4]: [channel#62, sum#150, isEmpty#151, sum#152] + +(132) Exchange +Input [4]: [channel#62, sum#150, isEmpty#151, sum#152] +Arguments: hashpartitioning(channel#62, 5), ENSURE_REQUIREMENTS, [id=#153] + +(133) HashAggregate [codegen id : 539] +Input [4]: [channel#62, sum#150, isEmpty#151, sum#152] +Keys [1]: [channel#62] +Functions [2]: [sum(sum_sales#120), sum(number_sales#121)] +Aggregate Attributes [2]: [sum(sum_sales#120)#154, sum(number_sales#121)#155] +Results [6]: [channel#62, null AS i_brand_id#156, null AS i_class_id#157, null AS i_category_id#158, sum(sum_sales#120)#154 AS sum(sum_sales)#159, sum(number_sales#121)#155 AS sum(number_sales)#160] + +(134) ReusedExchange [Reuses operator id: 117] +Output [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] + +(135) HashAggregate [codegen id : 673] +Input [7]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum#114, isEmpty#115, sum#116] +Keys [4]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50] +Functions [2]: [sum(sales#63), sum(number_sales#64)] +Aggregate Attributes [2]: [sum(sales#63)#118, sum(number_sales#64)#119] +Results [2]: [sum(sales#63)#118 AS sum_sales#120, sum(number_sales#64)#119 AS number_sales#121] + +(136) HashAggregate [codegen id : 673] +Input [2]: [sum_sales#120, number_sales#121] Keys: [] -Functions [2]: [partial_sum(sum_sales#121), partial_sum(number_sales#122)] -Aggregate Attributes [3]: [sum#162, isEmpty#163, sum#164] -Results [3]: [sum#165, isEmpty#166, sum#167] +Functions [2]: [partial_sum(sum_sales#120), partial_sum(number_sales#121)] +Aggregate Attributes [3]: [sum#161, isEmpty#162, sum#163] +Results [3]: [sum#164, isEmpty#165, sum#166] -(140) Exchange -Input [3]: [sum#165, isEmpty#166, sum#167] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#168] +(137) Exchange +Input [3]: [sum#164, isEmpty#165, sum#166] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#167] -(141) HashAggregate [codegen id : 704] -Input [3]: [sum#165, isEmpty#166, sum#167] +(138) HashAggregate [codegen id : 674] +Input [3]: [sum#164, isEmpty#165, sum#166] Keys: [] -Functions [2]: [sum(sum_sales#121), sum(number_sales#122)] -Aggregate Attributes [2]: [sum(sum_sales#121)#169, sum(number_sales#122)#170] -Results [6]: [null AS channel#171, null AS i_brand_id#172, null AS i_class_id#173, null AS i_category_id#174, sum(sum_sales#121)#169 AS sum(sum_sales)#175, sum(number_sales#122)#170 AS 
sum(number_sales)#176] +Functions [2]: [sum(sum_sales#120), sum(number_sales#121)] +Aggregate Attributes [2]: [sum(sum_sales#120)#168, sum(number_sales#121)#169] +Results [6]: [null AS channel#170, null AS i_brand_id#171, null AS i_class_id#172, null AS i_category_id#173, sum(sum_sales#120)#168 AS sum(sum_sales)#174, sum(number_sales#121)#169 AS sum(number_sales)#175] -(142) Union +(139) Union -(143) HashAggregate [codegen id : 705] -Input [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum_sales#121, number_sales#122] -Keys [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum_sales#121, number_sales#122] +(140) HashAggregate [codegen id : 675] +Input [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum_sales#120, number_sales#121] +Keys [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum_sales#120, number_sales#121] Functions: [] Aggregate Attributes: [] -Results [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum_sales#121, number_sales#122] +Results [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum_sales#120, number_sales#121] -(144) Exchange -Input [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum_sales#121, number_sales#122] -Arguments: hashpartitioning(channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum_sales#121, number_sales#122, 5), ENSURE_REQUIREMENTS, [id=#177] +(141) Exchange +Input [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum_sales#120, number_sales#121] +Arguments: hashpartitioning(channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum_sales#120, number_sales#121, 5), ENSURE_REQUIREMENTS, [id=#176] -(145) HashAggregate [codegen id : 706] -Input [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum_sales#121, number_sales#122] -Keys [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum_sales#121, number_sales#122] +(142) HashAggregate [codegen id : 676] +Input [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum_sales#120, number_sales#121] +Keys [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum_sales#120, number_sales#121] Functions: [] Aggregate Attributes: [] -Results [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum_sales#121, number_sales#122] +Results [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum_sales#120, number_sales#121] -(146) TakeOrderedAndProject -Input [6]: [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum_sales#121, number_sales#122] -Arguments: 100, [channel#63 ASC NULLS FIRST, i_brand_id#49 ASC NULLS FIRST, i_class_id#50 ASC NULLS FIRST, i_category_id#51 ASC NULLS FIRST], [channel#63, i_brand_id#49, i_class_id#50, i_category_id#51, sum_sales#121, number_sales#122] +(143) TakeOrderedAndProject +Input [6]: [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum_sales#120, number_sales#121] +Arguments: 100, [channel#62 ASC NULLS FIRST, i_brand_id#48 ASC NULLS FIRST, i_class_id#49 ASC NULLS FIRST, i_category_id#50 ASC NULLS FIRST], [channel#62, i_brand_id#48, i_class_id#49, i_category_id#50, sum_sales#120, number_sales#121] ===== Subqueries ===== -Subquery:1 Hosting operator id = 81 Hosting Expression = Subquery scalar-subquery#66, [id=#67] -* HashAggregate (165) -+- Exchange (164) - +- * HashAggregate (163) - +- Union (162) - :- * Project (151) - : +- * BroadcastHashJoin Inner BuildRight (150) - : :- * ColumnarToRow (148) - : : +- Scan parquet 
default.store_sales (147) - : +- ReusedExchange (149) - :- * Project (156) - : +- * BroadcastHashJoin Inner BuildRight (155) - : :- * ColumnarToRow (153) - : : +- Scan parquet default.catalog_sales (152) - : +- ReusedExchange (154) - +- * Project (161) - +- * BroadcastHashJoin Inner BuildRight (160) - :- * ColumnarToRow (158) - : +- Scan parquet default.web_sales (157) - +- ReusedExchange (159) - - -(147) Scan parquet default.store_sales -Output [3]: [ss_quantity#178, ss_list_price#179, ss_sold_date_sk#180] +Subquery:1 Hosting operator id = 78 Hosting Expression = Subquery scalar-subquery#65, [id=#66] +* HashAggregate (162) ++- Exchange (161) + +- * HashAggregate (160) + +- Union (159) + :- * Project (148) + : +- * BroadcastHashJoin Inner BuildRight (147) + : :- * ColumnarToRow (145) + : : +- Scan parquet default.store_sales (144) + : +- ReusedExchange (146) + :- * Project (153) + : +- * BroadcastHashJoin Inner BuildRight (152) + : :- * ColumnarToRow (150) + : : +- Scan parquet default.catalog_sales (149) + : +- ReusedExchange (151) + +- * Project (158) + +- * BroadcastHashJoin Inner BuildRight (157) + :- * ColumnarToRow (155) + : +- Scan parquet default.web_sales (154) + +- ReusedExchange (156) + + +(144) Scan parquet default.store_sales +Output [3]: [ss_quantity#177, ss_list_price#178, ss_sold_date_sk#179] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#180), dynamicpruningexpression(ss_sold_date_sk#180 IN dynamicpruning#13)] +PartitionFilters: [isnotnull(ss_sold_date_sk#179), dynamicpruningexpression(ss_sold_date_sk#179 IN dynamicpruning#13)] ReadSchema: struct -(148) ColumnarToRow [codegen id : 2] -Input [3]: [ss_quantity#178, ss_list_price#179, ss_sold_date_sk#180] +(145) ColumnarToRow [codegen id : 2] +Input [3]: [ss_quantity#177, ss_list_price#178, ss_sold_date_sk#179] -(149) ReusedExchange [Reuses operator id: 180] -Output [1]: [d_date_sk#181] +(146) ReusedExchange [Reuses operator id: 177] +Output [1]: [d_date_sk#180] -(150) BroadcastHashJoin [codegen id : 2] -Left keys [1]: [ss_sold_date_sk#180] -Right keys [1]: [d_date_sk#181] +(147) BroadcastHashJoin [codegen id : 2] +Left keys [1]: [ss_sold_date_sk#179] +Right keys [1]: [d_date_sk#180] Join condition: None -(151) Project [codegen id : 2] -Output [2]: [ss_quantity#178 AS quantity#182, ss_list_price#179 AS list_price#183] -Input [4]: [ss_quantity#178, ss_list_price#179, ss_sold_date_sk#180, d_date_sk#181] +(148) Project [codegen id : 2] +Output [2]: [ss_quantity#177 AS quantity#181, ss_list_price#178 AS list_price#182] +Input [4]: [ss_quantity#177, ss_list_price#178, ss_sold_date_sk#179, d_date_sk#180] -(152) Scan parquet default.catalog_sales -Output [3]: [cs_quantity#184, cs_list_price#185, cs_sold_date_sk#186] +(149) Scan parquet default.catalog_sales +Output [3]: [cs_quantity#183, cs_list_price#184, cs_sold_date_sk#185] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(cs_sold_date_sk#186), dynamicpruningexpression(cs_sold_date_sk#186 IN dynamicpruning#187)] +PartitionFilters: [isnotnull(cs_sold_date_sk#185), dynamicpruningexpression(cs_sold_date_sk#185 IN dynamicpruning#186)] ReadSchema: struct -(153) ColumnarToRow [codegen id : 4] -Input [3]: [cs_quantity#184, cs_list_price#185, cs_sold_date_sk#186] +(150) ColumnarToRow [codegen id : 4] +Input [3]: [cs_quantity#183, cs_list_price#184, cs_sold_date_sk#185] -(154) ReusedExchange [Reuses operator id: 170] -Output [1]: [d_date_sk#188] +(151) ReusedExchange [Reuses operator id: 167] +Output [1]: [d_date_sk#187] 
-(155) BroadcastHashJoin [codegen id : 4] -Left keys [1]: [cs_sold_date_sk#186] -Right keys [1]: [d_date_sk#188] +(152) BroadcastHashJoin [codegen id : 4] +Left keys [1]: [cs_sold_date_sk#185] +Right keys [1]: [d_date_sk#187] Join condition: None -(156) Project [codegen id : 4] -Output [2]: [cs_quantity#184 AS quantity#189, cs_list_price#185 AS list_price#190] -Input [4]: [cs_quantity#184, cs_list_price#185, cs_sold_date_sk#186, d_date_sk#188] +(153) Project [codegen id : 4] +Output [2]: [cs_quantity#183 AS quantity#188, cs_list_price#184 AS list_price#189] +Input [4]: [cs_quantity#183, cs_list_price#184, cs_sold_date_sk#185, d_date_sk#187] -(157) Scan parquet default.web_sales -Output [3]: [ws_quantity#191, ws_list_price#192, ws_sold_date_sk#193] +(154) Scan parquet default.web_sales +Output [3]: [ws_quantity#190, ws_list_price#191, ws_sold_date_sk#192] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ws_sold_date_sk#193), dynamicpruningexpression(ws_sold_date_sk#193 IN dynamicpruning#187)] +PartitionFilters: [isnotnull(ws_sold_date_sk#192), dynamicpruningexpression(ws_sold_date_sk#192 IN dynamicpruning#186)] ReadSchema: struct -(158) ColumnarToRow [codegen id : 6] -Input [3]: [ws_quantity#191, ws_list_price#192, ws_sold_date_sk#193] +(155) ColumnarToRow [codegen id : 6] +Input [3]: [ws_quantity#190, ws_list_price#191, ws_sold_date_sk#192] -(159) ReusedExchange [Reuses operator id: 170] -Output [1]: [d_date_sk#194] +(156) ReusedExchange [Reuses operator id: 167] +Output [1]: [d_date_sk#193] -(160) BroadcastHashJoin [codegen id : 6] -Left keys [1]: [ws_sold_date_sk#193] -Right keys [1]: [d_date_sk#194] +(157) BroadcastHashJoin [codegen id : 6] +Left keys [1]: [ws_sold_date_sk#192] +Right keys [1]: [d_date_sk#193] Join condition: None -(161) Project [codegen id : 6] -Output [2]: [ws_quantity#191 AS quantity#195, ws_list_price#192 AS list_price#196] -Input [4]: [ws_quantity#191, ws_list_price#192, ws_sold_date_sk#193, d_date_sk#194] +(158) Project [codegen id : 6] +Output [2]: [ws_quantity#190 AS quantity#194, ws_list_price#191 AS list_price#195] +Input [4]: [ws_quantity#190, ws_list_price#191, ws_sold_date_sk#192, d_date_sk#193] -(162) Union +(159) Union -(163) HashAggregate [codegen id : 7] -Input [2]: [quantity#182, list_price#183] +(160) HashAggregate [codegen id : 7] +Input [2]: [quantity#181, list_price#182] Keys: [] -Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(cast(quantity#182 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#183 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [2]: [sum#197, count#198] -Results [2]: [sum#199, count#200] +Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(quantity#181 as decimal(12,2))) * promote_precision(cast(list_price#182 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#196, count#197] +Results [2]: [sum#198, count#199] -(164) Exchange -Input [2]: [sum#199, count#200] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#201] +(161) Exchange +Input [2]: [sum#198, count#199] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#200] -(165) HashAggregate [codegen id : 8] -Input [2]: [sum#199, count#200] +(162) HashAggregate [codegen id : 8] +Input [2]: [sum#198, count#199] Keys: [] -Functions [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#182 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#183 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: 
[avg(CheckOverflow((promote_precision(cast(cast(quantity#182 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#183 as decimal(12,2)))), DecimalType(18,2), true))#202] -Results [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#182 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#183 as decimal(12,2)))), DecimalType(18,2), true))#202 AS average_sales#203] +Functions [1]: [avg(CheckOverflow((promote_precision(cast(quantity#181 as decimal(12,2))) * promote_precision(cast(list_price#182 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(quantity#181 as decimal(12,2))) * promote_precision(cast(list_price#182 as decimal(12,2)))), DecimalType(18,2)))#201] +Results [1]: [avg(CheckOverflow((promote_precision(cast(quantity#181 as decimal(12,2))) * promote_precision(cast(list_price#182 as decimal(12,2)))), DecimalType(18,2)))#201 AS average_sales#202] -Subquery:2 Hosting operator id = 147 Hosting Expression = ss_sold_date_sk#180 IN dynamicpruning#13 +Subquery:2 Hosting operator id = 144 Hosting Expression = ss_sold_date_sk#179 IN dynamicpruning#13 -Subquery:3 Hosting operator id = 152 Hosting Expression = cs_sold_date_sk#186 IN dynamicpruning#187 -BroadcastExchange (170) -+- * Project (169) - +- * Filter (168) - +- * ColumnarToRow (167) - +- Scan parquet default.date_dim (166) +Subquery:3 Hosting operator id = 149 Hosting Expression = cs_sold_date_sk#185 IN dynamicpruning#186 +BroadcastExchange (167) ++- * Project (166) + +- * Filter (165) + +- * ColumnarToRow (164) + +- Scan parquet default.date_dim (163) -(166) Scan parquet default.date_dim -Output [2]: [d_date_sk#188, d_year#204] +(163) Scan parquet default.date_dim +Output [2]: [d_date_sk#187, d_year#203] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), GreaterThanOrEqual(d_year,1998), LessThanOrEqual(d_year,2000), IsNotNull(d_date_sk)] ReadSchema: struct -(167) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#188, d_year#204] +(164) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#187, d_year#203] -(168) Filter [codegen id : 1] -Input [2]: [d_date_sk#188, d_year#204] -Condition : (((isnotnull(d_year#204) AND (d_year#204 >= 1998)) AND (d_year#204 <= 2000)) AND isnotnull(d_date_sk#188)) +(165) Filter [codegen id : 1] +Input [2]: [d_date_sk#187, d_year#203] +Condition : (((isnotnull(d_year#203) AND (d_year#203 >= 1998)) AND (d_year#203 <= 2000)) AND isnotnull(d_date_sk#187)) -(169) Project [codegen id : 1] -Output [1]: [d_date_sk#188] -Input [2]: [d_date_sk#188, d_year#204] +(166) Project [codegen id : 1] +Output [1]: [d_date_sk#187] +Input [2]: [d_date_sk#187, d_year#203] -(170) BroadcastExchange -Input [1]: [d_date_sk#188] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#205] +(167) BroadcastExchange +Input [1]: [d_date_sk#187] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#204] -Subquery:4 Hosting operator id = 157 Hosting Expression = ws_sold_date_sk#193 IN dynamicpruning#187 +Subquery:4 Hosting operator id = 154 Hosting Expression = ws_sold_date_sk#192 IN dynamicpruning#186 Subquery:5 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (175) -+- * Project (174) - +- * Filter (173) - +- * ColumnarToRow (172) - +- Scan parquet default.date_dim (171) +BroadcastExchange (172) ++- * Project (171) + +- * Filter (170) + 
+- * ColumnarToRow (169) + +- Scan parquet default.date_dim (168) -(171) Scan parquet default.date_dim -Output [3]: [d_date_sk#47, d_year#206, d_moy#207] +(168) Scan parquet default.date_dim +Output [3]: [d_date_sk#46, d_year#205, d_moy#206] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2000), EqualTo(d_moy,11), IsNotNull(d_date_sk)] ReadSchema: struct -(172) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#47, d_year#206, d_moy#207] +(169) ColumnarToRow [codegen id : 1] +Input [3]: [d_date_sk#46, d_year#205, d_moy#206] -(173) Filter [codegen id : 1] -Input [3]: [d_date_sk#47, d_year#206, d_moy#207] -Condition : ((((isnotnull(d_year#206) AND isnotnull(d_moy#207)) AND (d_year#206 = 2000)) AND (d_moy#207 = 11)) AND isnotnull(d_date_sk#47)) +(170) Filter [codegen id : 1] +Input [3]: [d_date_sk#46, d_year#205, d_moy#206] +Condition : ((((isnotnull(d_year#205) AND isnotnull(d_moy#206)) AND (d_year#205 = 2000)) AND (d_moy#206 = 11)) AND isnotnull(d_date_sk#46)) -(174) Project [codegen id : 1] -Output [1]: [d_date_sk#47] -Input [3]: [d_date_sk#47, d_year#206, d_moy#207] +(171) Project [codegen id : 1] +Output [1]: [d_date_sk#46] +Input [3]: [d_date_sk#46, d_year#205, d_moy#206] -(175) BroadcastExchange -Input [1]: [d_date_sk#47] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#208] +(172) BroadcastExchange +Input [1]: [d_date_sk#46] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#207] Subquery:6 Hosting operator id = 9 Hosting Expression = ss_sold_date_sk#12 IN dynamicpruning#13 -BroadcastExchange (180) -+- * Project (179) - +- * Filter (178) - +- * ColumnarToRow (177) - +- Scan parquet default.date_dim (176) +BroadcastExchange (177) ++- * Project (176) + +- * Filter (175) + +- * ColumnarToRow (174) + +- Scan parquet default.date_dim (173) -(176) Scan parquet default.date_dim -Output [2]: [d_date_sk#14, d_year#209] +(173) Scan parquet default.date_dim +Output [2]: [d_date_sk#14, d_year#208] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), GreaterThanOrEqual(d_year,1999), LessThanOrEqual(d_year,2001), IsNotNull(d_date_sk)] ReadSchema: struct -(177) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#14, d_year#209] +(174) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#14, d_year#208] -(178) Filter [codegen id : 1] -Input [2]: [d_date_sk#14, d_year#209] -Condition : (((isnotnull(d_year#209) AND (d_year#209 >= 1999)) AND (d_year#209 <= 2001)) AND isnotnull(d_date_sk#14)) +(175) Filter [codegen id : 1] +Input [2]: [d_date_sk#14, d_year#208] +Condition : (((isnotnull(d_year#208) AND (d_year#208 >= 1999)) AND (d_year#208 <= 2001)) AND isnotnull(d_date_sk#14)) -(179) Project [codegen id : 1] +(176) Project [codegen id : 1] Output [1]: [d_date_sk#14] -Input [2]: [d_date_sk#14, d_year#209] +Input [2]: [d_date_sk#14, d_year#208] -(180) BroadcastExchange +(177) BroadcastExchange Input [1]: [d_date_sk#14] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#210] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#209] Subquery:7 Hosting operator id = 20 Hosting Expression = cs_sold_date_sk#21 IN dynamicpruning#13 Subquery:8 Hosting operator id = 43 Hosting Expression = ws_sold_date_sk#36 IN dynamicpruning#13 -Subquery:9 Hosting operator id = 99 Hosting 
Expression = ReusedSubquery Subquery scalar-subquery#66, [id=#67] +Subquery:9 Hosting operator id = 96 Hosting Expression = ReusedSubquery Subquery scalar-subquery#65, [id=#66] -Subquery:10 Hosting operator id = 82 Hosting Expression = cs_sold_date_sk#71 IN dynamicpruning#5 +Subquery:10 Hosting operator id = 79 Hosting Expression = cs_sold_date_sk#70 IN dynamicpruning#5 -Subquery:11 Hosting operator id = 117 Hosting Expression = ReusedSubquery Subquery scalar-subquery#66, [id=#67] +Subquery:11 Hosting operator id = 114 Hosting Expression = ReusedSubquery Subquery scalar-subquery#65, [id=#66] -Subquery:12 Hosting operator id = 100 Hosting Expression = ws_sold_date_sk#93 IN dynamicpruning#5 +Subquery:12 Hosting operator id = 97 Hosting Expression = ws_sold_date_sk#92 IN dynamicpruning#5 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a.sf100/simplified.txt index d494944f8e4d5..856de20a40ca8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a.sf100/simplified.txt @@ -1,27 +1,27 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,number_sales] - WholeStageCodegen (706) + WholeStageCodegen (676) HashAggregate [channel,i_brand_id,i_class_id,i_category_id,sum_sales,number_sales] InputAdapter Exchange [channel,i_brand_id,i_class_id,i_category_id,sum_sales,number_sales] #1 - WholeStageCodegen (705) + WholeStageCodegen (675) HashAggregate [channel,i_brand_id,i_class_id,i_category_id,sum_sales,number_sales] InputAdapter Union - WholeStageCodegen (140) + WholeStageCodegen (134) HashAggregate [channel,i_brand_id,i_class_id,i_category_id,sum,isEmpty,sum] [sum(sales),sum(number_salesL),sum_sales,number_sales,sum,isEmpty,sum] InputAdapter Exchange [channel,i_brand_id,i_class_id,i_category_id] #2 - WholeStageCodegen (139) + WholeStageCodegen (133) HashAggregate [channel,i_brand_id,i_class_id,i_category_id,sales,number_sales] [sum,isEmpty,sum,sum,isEmpty,sum] InputAdapter Union - WholeStageCodegen (46) + WholeStageCodegen (44) Filter [sales] Subquery #3 WholeStageCodegen (8) - HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(cast(quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2), true)),average_sales,sum,count] + HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(quantity as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2))),average_sales,sum,count] InputAdapter - Exchange #19 + Exchange #18 WholeStageCodegen (7) HashAggregate [quantity,list_price] [sum,count,sum,count] InputAdapter @@ -34,7 +34,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num Scan parquet default.store_sales [ss_quantity,ss_list_price,ss_sold_date_sk] ReusedSubquery [d_date_sk] #2 InputAdapter - ReusedExchange [d_date_sk] #11 + ReusedExchange [d_date_sk] #10 WholeStageCodegen (4) Project [cs_quantity,cs_list_price] BroadcastHashJoin [cs_sold_date_sk,d_date_sk] @@ -42,7 +42,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num InputAdapter Scan parquet default.catalog_sales [cs_quantity,cs_list_price,cs_sold_date_sk] SubqueryBroadcast [d_date_sk] #4 - BroadcastExchange #20 + BroadcastExchange #19 WholeStageCodegen (1) Project 
[d_date_sk] Filter [d_year,d_date_sk] @@ -50,7 +50,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num InputAdapter Scan parquet default.date_dim [d_date_sk,d_year] InputAdapter - ReusedExchange [d_date_sk] #20 + ReusedExchange [d_date_sk] #19 WholeStageCodegen (6) Project [ws_quantity,ws_list_price] BroadcastHashJoin [ws_sold_date_sk,d_date_sk] @@ -59,11 +59,11 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num Scan parquet default.web_sales [ws_quantity,ws_list_price,ws_sold_date_sk] ReusedSubquery [d_date_sk] #4 InputAdapter - ReusedExchange [d_date_sk] #20 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + ReusedExchange [d_date_sk] #19 + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #3 - WholeStageCodegen (45) + WholeStageCodegen (43) HashAggregate [i_brand_id,i_class_id,i_category_id,ss_quantity,ss_list_price] [sum,isEmpty,count,sum,isEmpty,count] Project [ss_quantity,ss_list_price,i_brand_id,i_class_id,i_category_id] BroadcastHashJoin [ss_item_sk,i_item_sk] @@ -89,11 +89,11 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num InputAdapter Scan parquet default.date_dim [d_date_sk,d_year,d_moy] InputAdapter - WholeStageCodegen (21) + WholeStageCodegen (20) Sort [ss_item_sk] InputAdapter Exchange [ss_item_sk] #6 - WholeStageCodegen (20) + WholeStageCodegen (19) Project [i_item_sk] BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,brand_id,class_id,category_id] Filter [i_brand_id,i_class_id,i_category_id] @@ -102,127 +102,122 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] InputAdapter BroadcastExchange #7 - WholeStageCodegen (19) - HashAggregate [brand_id,class_id,category_id] + WholeStageCodegen (18) + SortMergeJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] InputAdapter - Exchange [brand_id,class_id,category_id] #8 - WholeStageCodegen (18) - HashAggregate [brand_id,class_id,category_id] - SortMergeJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (13) - Sort [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #9 - WholeStageCodegen (12) - HashAggregate [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #10 - WholeStageCodegen (11) - HashAggregate [brand_id,class_id,category_id] - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_item_sk,i_item_sk] - Project [ss_item_sk] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #2 - BroadcastExchange #11 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] + 
WholeStageCodegen (13) + Sort [brand_id,class_id,category_id] + InputAdapter + Exchange [brand_id,class_id,category_id] #8 + WholeStageCodegen (12) + HashAggregate [brand_id,class_id,category_id] + InputAdapter + Exchange [brand_id,class_id,category_id] #9 + WholeStageCodegen (11) + HashAggregate [brand_id,class_id,category_id] + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ss_item_sk,i_item_sk] + Project [ss_item_sk] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #2 + BroadcastExchange #10 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] + InputAdapter + ReusedExchange [d_date_sk] #10 + InputAdapter + BroadcastExchange #11 + WholeStageCodegen (10) + SortMergeJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (5) + Sort [i_brand_id,i_class_id,i_category_id] InputAdapter - ReusedExchange [d_date_sk] #11 - InputAdapter - BroadcastExchange #12 - WholeStageCodegen (10) - SortMergeJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (5) - Sort [i_brand_id,i_class_id,i_category_id] + Exchange [i_brand_id,i_class_id,i_category_id] #12 + WholeStageCodegen (4) + Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + ColumnarToRow InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #13 - WholeStageCodegen (4) - Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (9) + Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + Exchange [i_brand_id,i_class_id,i_category_id] #13 + WholeStageCodegen (8) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_item_sk,i_item_sk] + Project [cs_item_sk] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Filter [cs_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] + ReusedSubquery [d_date_sk] #2 + InputAdapter + ReusedExchange [d_date_sk] #10 + InputAdapter + BroadcastExchange #14 + WholeStageCodegen (7) + Filter [i_item_sk] ColumnarToRow InputAdapter Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (9) - Sort [i_brand_id,i_class_id,i_category_id] - InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #14 - WholeStageCodegen (8) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_item_sk,i_item_sk] - Project [cs_item_sk] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Filter [cs_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #2 - InputAdapter - ReusedExchange [d_date_sk] #11 - InputAdapter - BroadcastExchange #15 - WholeStageCodegen (7) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - WholeStageCodegen (17) - Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + WholeStageCodegen (17) + Sort [i_brand_id,i_class_id,i_category_id] + InputAdapter + Exchange [i_brand_id,i_class_id,i_category_id] #15 + WholeStageCodegen (16) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin 
[ws_item_sk,i_item_sk] + Project [ws_item_sk] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Filter [ws_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #2 + InputAdapter + ReusedExchange [d_date_sk] #10 InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #16 - WholeStageCodegen (16) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_item_sk,i_item_sk] - Project [ws_item_sk] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Filter [ws_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] - ReusedSubquery [d_date_sk] #2 - InputAdapter - ReusedExchange [d_date_sk] #11 - InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #15 + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #14 InputAdapter ReusedExchange [d_date_sk] #5 InputAdapter - BroadcastExchange #17 - WholeStageCodegen (44) + BroadcastExchange #16 + WholeStageCodegen (42) SortMergeJoin [i_item_sk,ss_item_sk] InputAdapter - WholeStageCodegen (24) + WholeStageCodegen (23) Sort [i_item_sk] InputAdapter - Exchange [i_item_sk] #18 - WholeStageCodegen (23) + Exchange [i_item_sk] #17 + WholeStageCodegen (22) Filter [i_item_sk] ColumnarToRow InputAdapter Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] InputAdapter - WholeStageCodegen (43) + WholeStageCodegen (41) Sort [ss_item_sk] InputAdapter ReusedExchange [ss_item_sk] #6 - WholeStageCodegen (92) + WholeStageCodegen (88) Filter [sales] ReusedSubquery [average_sales] #3 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cs_quantity as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #21 - WholeStageCodegen (91) + Exchange [i_brand_id,i_class_id,i_category_id] #20 + WholeStageCodegen (87) HashAggregate [i_brand_id,i_class_id,i_category_id,cs_quantity,cs_list_price] [sum,isEmpty,count,sum,isEmpty,count] Project [cs_quantity,cs_list_price,i_brand_id,i_class_id,i_category_id] BroadcastHashJoin [cs_item_sk,i_item_sk] @@ -230,32 +225,32 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num BroadcastHashJoin [cs_sold_date_sk,d_date_sk] SortMergeJoin [cs_item_sk,ss_item_sk] InputAdapter - WholeStageCodegen (48) + WholeStageCodegen (46) Sort [cs_item_sk] InputAdapter - Exchange [cs_item_sk] #22 - WholeStageCodegen (47) + Exchange [cs_item_sk] #21 + WholeStageCodegen (45) Filter [cs_item_sk] ColumnarToRow InputAdapter Scan parquet default.catalog_sales [cs_item_sk,cs_quantity,cs_list_price,cs_sold_date_sk] ReusedSubquery [d_date_sk] #1 InputAdapter - WholeStageCodegen (67) + WholeStageCodegen (64) Sort [ss_item_sk] InputAdapter ReusedExchange [ss_item_sk] #6 InputAdapter ReusedExchange [d_date_sk] #5 InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #17 - WholeStageCodegen (138) + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #16 + WholeStageCodegen (132) Filter [sales] 
ReusedSubquery [average_sales] #3 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ws_quantity as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter - Exchange [i_brand_id,i_class_id,i_category_id] #23 - WholeStageCodegen (137) + Exchange [i_brand_id,i_class_id,i_category_id] #22 + WholeStageCodegen (131) HashAggregate [i_brand_id,i_class_id,i_category_id,ws_quantity,ws_list_price] [sum,isEmpty,count,sum,isEmpty,count] Project [ws_quantity,ws_list_price,i_brand_id,i_class_id,i_category_id] BroadcastHashJoin [ws_item_sk,i_item_sk] @@ -263,57 +258,57 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num BroadcastHashJoin [ws_sold_date_sk,d_date_sk] SortMergeJoin [ws_item_sk,ss_item_sk] InputAdapter - WholeStageCodegen (94) + WholeStageCodegen (90) Sort [ws_item_sk] InputAdapter - Exchange [ws_item_sk] #24 - WholeStageCodegen (93) + Exchange [ws_item_sk] #23 + WholeStageCodegen (89) Filter [ws_item_sk] ColumnarToRow InputAdapter Scan parquet default.web_sales [ws_item_sk,ws_quantity,ws_list_price,ws_sold_date_sk] ReusedSubquery [d_date_sk] #1 InputAdapter - WholeStageCodegen (113) + WholeStageCodegen (108) Sort [ss_item_sk] InputAdapter ReusedExchange [ss_item_sk] #6 InputAdapter ReusedExchange [d_date_sk] #5 InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #17 - WholeStageCodegen (281) + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #16 + WholeStageCodegen (269) HashAggregate [channel,i_brand_id,i_class_id,sum,isEmpty,sum] [sum(sum_sales),sum(number_salesL),i_category_id,sum(sum_sales),sum(number_sales),sum,isEmpty,sum] InputAdapter - Exchange [channel,i_brand_id,i_class_id] #25 - WholeStageCodegen (280) + Exchange [channel,i_brand_id,i_class_id] #24 + WholeStageCodegen (268) HashAggregate [channel,i_brand_id,i_class_id,sum_sales,number_sales] [sum,isEmpty,sum,sum,isEmpty,sum] HashAggregate [channel,i_brand_id,i_class_id,i_category_id,sum,isEmpty,sum] [sum(sales),sum(number_salesL),sum_sales,number_sales,sum,isEmpty,sum] InputAdapter ReusedExchange [channel,i_brand_id,i_class_id,i_category_id,sum,isEmpty,sum] #2 - WholeStageCodegen (422) + WholeStageCodegen (404) HashAggregate [channel,i_brand_id,sum,isEmpty,sum] [sum(sum_sales),sum(number_salesL),i_class_id,i_category_id,sum(sum_sales),sum(number_sales),sum,isEmpty,sum] InputAdapter - Exchange [channel,i_brand_id] #26 - WholeStageCodegen (421) + Exchange [channel,i_brand_id] #25 + WholeStageCodegen (403) HashAggregate [channel,i_brand_id,sum_sales,number_sales] [sum,isEmpty,sum,sum,isEmpty,sum] HashAggregate [channel,i_brand_id,i_class_id,i_category_id,sum,isEmpty,sum] [sum(sales),sum(number_salesL),sum_sales,number_sales,sum,isEmpty,sum] InputAdapter ReusedExchange [channel,i_brand_id,i_class_id,i_category_id,sum,isEmpty,sum] #2 - WholeStageCodegen (563) + WholeStageCodegen (539) HashAggregate [channel,sum,isEmpty,sum] [sum(sum_sales),sum(number_salesL),i_brand_id,i_class_id,i_category_id,sum(sum_sales),sum(number_sales),sum,isEmpty,sum] InputAdapter - Exchange [channel] #27 - 
WholeStageCodegen (562) + Exchange [channel] #26 + WholeStageCodegen (538) HashAggregate [channel,sum_sales,number_sales] [sum,isEmpty,sum,sum,isEmpty,sum] HashAggregate [channel,i_brand_id,i_class_id,i_category_id,sum,isEmpty,sum] [sum(sales),sum(number_salesL),sum_sales,number_sales,sum,isEmpty,sum] InputAdapter ReusedExchange [channel,i_brand_id,i_class_id,i_category_id,sum,isEmpty,sum] #2 - WholeStageCodegen (704) + WholeStageCodegen (674) HashAggregate [sum,isEmpty,sum] [sum(sum_sales),sum(number_salesL),channel,i_brand_id,i_class_id,i_category_id,sum(sum_sales),sum(number_sales),sum,isEmpty,sum] InputAdapter - Exchange #28 - WholeStageCodegen (703) + Exchange #27 + WholeStageCodegen (673) HashAggregate [sum_sales,number_sales] [sum,isEmpty,sum,sum,isEmpty,sum] HashAggregate [channel,i_brand_id,i_class_id,i_category_id,sum,isEmpty,sum] [sum(sales),sum(number_salesL),sum_sales,number_sales,sum,isEmpty,sum] InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a/explain.txt index bd3290f8c55b4..2438fa9d7eb57 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a/explain.txt @@ -1,131 +1,129 @@ == Physical Plan == -TakeOrderedAndProject (127) -+- * HashAggregate (126) - +- Exchange (125) - +- * HashAggregate (124) - +- Union (123) - :- * HashAggregate (102) - : +- Exchange (101) - : +- * HashAggregate (100) - : +- Union (99) - : :- * Filter (68) - : : +- * HashAggregate (67) - : : +- Exchange (66) - : : +- * HashAggregate (65) - : : +- * Project (64) - : : +- * BroadcastHashJoin Inner BuildRight (63) - : : :- * Project (61) - : : : +- * BroadcastHashJoin Inner BuildRight (60) - : : : :- * BroadcastHashJoin LeftSemi BuildRight (53) +TakeOrderedAndProject (125) ++- * HashAggregate (124) + +- Exchange (123) + +- * HashAggregate (122) + +- Union (121) + :- * HashAggregate (100) + : +- Exchange (99) + : +- * HashAggregate (98) + : +- Union (97) + : :- * Filter (66) + : : +- * HashAggregate (65) + : : +- Exchange (64) + : : +- * HashAggregate (63) + : : +- * Project (62) + : : +- * BroadcastHashJoin Inner BuildRight (61) + : : :- * Project (59) + : : : +- * BroadcastHashJoin Inner BuildRight (58) + : : : :- * BroadcastHashJoin LeftSemi BuildRight (51) : : : : :- * Filter (3) : : : : : +- * ColumnarToRow (2) : : : : : +- Scan parquet default.store_sales (1) - : : : : +- BroadcastExchange (52) - : : : : +- * Project (51) - : : : : +- * BroadcastHashJoin Inner BuildRight (50) + : : : : +- BroadcastExchange (50) + : : : : +- * Project (49) + : : : : +- * BroadcastHashJoin Inner BuildRight (48) : : : : :- * Filter (6) : : : : : +- * ColumnarToRow (5) : : : : : +- Scan parquet default.item (4) - : : : : +- BroadcastExchange (49) - : : : : +- * HashAggregate (48) - : : : : +- * HashAggregate (47) - : : : : +- * BroadcastHashJoin LeftSemi BuildRight (46) - : : : : :- * HashAggregate (35) - : : : : : +- Exchange (34) - : : : : : +- * HashAggregate (33) - : : : : : +- * Project (32) - : : : : : +- * BroadcastHashJoin Inner BuildRight (31) - : : : : : :- * Project (29) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (28) - : : : : : : :- * Filter (9) - : : : : : : : +- * ColumnarToRow (8) - : : : : : : : +- Scan parquet default.store_sales (7) - : : : : : : +- BroadcastExchange (27) - : : : : : : +- * BroadcastHashJoin LeftSemi BuildRight (26) - : 
: : : : : :- * Filter (12) - : : : : : : : +- * ColumnarToRow (11) - : : : : : : : +- Scan parquet default.item (10) - : : : : : : +- BroadcastExchange (25) - : : : : : : +- * Project (24) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (23) - : : : : : : :- * Project (21) - : : : : : : : +- * BroadcastHashJoin Inner BuildRight (20) - : : : : : : : :- * Filter (15) - : : : : : : : : +- * ColumnarToRow (14) - : : : : : : : : +- Scan parquet default.catalog_sales (13) - : : : : : : : +- BroadcastExchange (19) - : : : : : : : +- * Filter (18) - : : : : : : : +- * ColumnarToRow (17) - : : : : : : : +- Scan parquet default.item (16) - : : : : : : +- ReusedExchange (22) - : : : : : +- ReusedExchange (30) - : : : : +- BroadcastExchange (45) - : : : : +- * Project (44) - : : : : +- * BroadcastHashJoin Inner BuildRight (43) - : : : : :- * Project (41) - : : : : : +- * BroadcastHashJoin Inner BuildRight (40) - : : : : : :- * Filter (38) - : : : : : : +- * ColumnarToRow (37) - : : : : : : +- Scan parquet default.web_sales (36) - : : : : : +- ReusedExchange (39) - : : : : +- ReusedExchange (42) - : : : +- BroadcastExchange (59) - : : : +- * BroadcastHashJoin LeftSemi BuildRight (58) - : : : :- * Filter (56) - : : : : +- * ColumnarToRow (55) - : : : : +- Scan parquet default.item (54) - : : : +- ReusedExchange (57) - : : +- ReusedExchange (62) - : :- * Filter (83) - : : +- * HashAggregate (82) - : : +- Exchange (81) - : : +- * HashAggregate (80) - : : +- * Project (79) - : : +- * BroadcastHashJoin Inner BuildRight (78) - : : :- * Project (76) - : : : +- * BroadcastHashJoin Inner BuildRight (75) - : : : :- * BroadcastHashJoin LeftSemi BuildRight (73) - : : : : :- * Filter (71) - : : : : : +- * ColumnarToRow (70) - : : : : : +- Scan parquet default.catalog_sales (69) - : : : : +- ReusedExchange (72) - : : : +- ReusedExchange (74) - : : +- ReusedExchange (77) - : +- * Filter (98) - : +- * HashAggregate (97) - : +- Exchange (96) - : +- * HashAggregate (95) - : +- * Project (94) - : +- * BroadcastHashJoin Inner BuildRight (93) - : :- * Project (91) - : : +- * BroadcastHashJoin Inner BuildRight (90) - : : :- * BroadcastHashJoin LeftSemi BuildRight (88) - : : : :- * Filter (86) - : : : : +- * ColumnarToRow (85) - : : : : +- Scan parquet default.web_sales (84) - : : : +- ReusedExchange (87) - : : +- ReusedExchange (89) - : +- ReusedExchange (92) - :- * HashAggregate (107) - : +- Exchange (106) - : +- * HashAggregate (105) - : +- * HashAggregate (104) - : +- ReusedExchange (103) - :- * HashAggregate (112) - : +- Exchange (111) - : +- * HashAggregate (110) - : +- * HashAggregate (109) - : +- ReusedExchange (108) - :- * HashAggregate (117) - : +- Exchange (116) - : +- * HashAggregate (115) - : +- * HashAggregate (114) - : +- ReusedExchange (113) - +- * HashAggregate (122) - +- Exchange (121) - +- * HashAggregate (120) - +- * HashAggregate (119) - +- ReusedExchange (118) + : : : : +- BroadcastExchange (47) + : : : : +- * BroadcastHashJoin LeftSemi BuildRight (46) + : : : : :- * HashAggregate (35) + : : : : : +- Exchange (34) + : : : : : +- * HashAggregate (33) + : : : : : +- * Project (32) + : : : : : +- * BroadcastHashJoin Inner BuildRight (31) + : : : : : :- * Project (29) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (28) + : : : : : : :- * Filter (9) + : : : : : : : +- * ColumnarToRow (8) + : : : : : : : +- Scan parquet default.store_sales (7) + : : : : : : +- BroadcastExchange (27) + : : : : : : +- * BroadcastHashJoin LeftSemi BuildRight (26) + : : : : : : :- * Filter (12) + : : : : : : : +- * 
ColumnarToRow (11) + : : : : : : : +- Scan parquet default.item (10) + : : : : : : +- BroadcastExchange (25) + : : : : : : +- * Project (24) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (23) + : : : : : : :- * Project (21) + : : : : : : : +- * BroadcastHashJoin Inner BuildRight (20) + : : : : : : : :- * Filter (15) + : : : : : : : : +- * ColumnarToRow (14) + : : : : : : : : +- Scan parquet default.catalog_sales (13) + : : : : : : : +- BroadcastExchange (19) + : : : : : : : +- * Filter (18) + : : : : : : : +- * ColumnarToRow (17) + : : : : : : : +- Scan parquet default.item (16) + : : : : : : +- ReusedExchange (22) + : : : : : +- ReusedExchange (30) + : : : : +- BroadcastExchange (45) + : : : : +- * Project (44) + : : : : +- * BroadcastHashJoin Inner BuildRight (43) + : : : : :- * Project (41) + : : : : : +- * BroadcastHashJoin Inner BuildRight (40) + : : : : : :- * Filter (38) + : : : : : : +- * ColumnarToRow (37) + : : : : : : +- Scan parquet default.web_sales (36) + : : : : : +- ReusedExchange (39) + : : : : +- ReusedExchange (42) + : : : +- BroadcastExchange (57) + : : : +- * BroadcastHashJoin LeftSemi BuildRight (56) + : : : :- * Filter (54) + : : : : +- * ColumnarToRow (53) + : : : : +- Scan parquet default.item (52) + : : : +- ReusedExchange (55) + : : +- ReusedExchange (60) + : :- * Filter (81) + : : +- * HashAggregate (80) + : : +- Exchange (79) + : : +- * HashAggregate (78) + : : +- * Project (77) + : : +- * BroadcastHashJoin Inner BuildRight (76) + : : :- * Project (74) + : : : +- * BroadcastHashJoin Inner BuildRight (73) + : : : :- * BroadcastHashJoin LeftSemi BuildRight (71) + : : : : :- * Filter (69) + : : : : : +- * ColumnarToRow (68) + : : : : : +- Scan parquet default.catalog_sales (67) + : : : : +- ReusedExchange (70) + : : : +- ReusedExchange (72) + : : +- ReusedExchange (75) + : +- * Filter (96) + : +- * HashAggregate (95) + : +- Exchange (94) + : +- * HashAggregate (93) + : +- * Project (92) + : +- * BroadcastHashJoin Inner BuildRight (91) + : :- * Project (89) + : : +- * BroadcastHashJoin Inner BuildRight (88) + : : :- * BroadcastHashJoin LeftSemi BuildRight (86) + : : : :- * Filter (84) + : : : : +- * ColumnarToRow (83) + : : : : +- Scan parquet default.web_sales (82) + : : : +- ReusedExchange (85) + : : +- ReusedExchange (87) + : +- ReusedExchange (90) + :- * HashAggregate (105) + : +- Exchange (104) + : +- * HashAggregate (103) + : +- * HashAggregate (102) + : +- ReusedExchange (101) + :- * HashAggregate (110) + : +- Exchange (109) + : +- * HashAggregate (108) + : +- * HashAggregate (107) + : +- ReusedExchange (106) + :- * HashAggregate (115) + : +- Exchange (114) + : +- * HashAggregate (113) + : +- * HashAggregate (112) + : +- ReusedExchange (111) + +- * HashAggregate (120) + +- Exchange (119) + +- * HashAggregate (118) + +- * HashAggregate (117) + +- ReusedExchange (116) (1) Scan parquet default.store_sales @@ -228,7 +226,7 @@ Join condition: None Output [4]: [cs_sold_date_sk#18, i_brand_id#20, i_class_id#21, i_category_id#22] Input [6]: [cs_item_sk#17, cs_sold_date_sk#18, i_item_sk#19, i_brand_id#20, i_class_id#21, i_category_id#22] -(22) ReusedExchange [Reuses operator id: 161] +(22) ReusedExchange [Reuses operator id: 159] Output [1]: [d_date_sk#24] (23) BroadcastHashJoin [codegen id : 3] @@ -262,7 +260,7 @@ Join condition: None Output [4]: [ss_sold_date_sk#11, i_brand_id#14, i_class_id#15, i_category_id#16] Input [6]: [ss_item_sk#10, ss_sold_date_sk#11, i_item_sk#13, i_brand_id#14, i_class_id#15, i_category_id#16] -(30) ReusedExchange [Reuses operator 
id: 161] +(30) ReusedExchange [Reuses operator id: 159] Output [1]: [d_date_sk#27] (31) BroadcastHashJoin [codegen id : 6] @@ -319,7 +317,7 @@ Join condition: None Output [4]: [ws_sold_date_sk#33, i_brand_id#35, i_class_id#36, i_category_id#37] Input [6]: [ws_item_sk#32, ws_sold_date_sk#33, i_item_sk#34, i_brand_id#35, i_class_id#36, i_category_id#37] -(42) ReusedExchange [Reuses operator id: 161] +(42) ReusedExchange [Reuses operator id: 159] Output [1]: [d_date_sk#38] (43) BroadcastHashJoin [codegen id : 9] @@ -340,112 +338,98 @@ Left keys [6]: [coalesce(brand_id#28, 0), isnull(brand_id#28), coalesce(class_id Right keys [6]: [coalesce(i_brand_id#35, 0), isnull(i_brand_id#35), coalesce(i_class_id#36, 0), isnull(i_class_id#36), coalesce(i_category_id#37, 0), isnull(i_category_id#37)] Join condition: None -(47) HashAggregate [codegen id : 10] -Input [3]: [brand_id#28, class_id#29, category_id#30] -Keys [3]: [brand_id#28, class_id#29, category_id#30] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#28, class_id#29, category_id#30] - -(48) HashAggregate [codegen id : 10] -Input [3]: [brand_id#28, class_id#29, category_id#30] -Keys [3]: [brand_id#28, class_id#29, category_id#30] -Functions: [] -Aggregate Attributes: [] -Results [3]: [brand_id#28, class_id#29, category_id#30] - -(49) BroadcastExchange +(47) BroadcastExchange Input [3]: [brand_id#28, class_id#29, category_id#30] Arguments: HashedRelationBroadcastMode(List(input[0, int, true], input[1, int, true], input[2, int, true]),false), [id=#40] -(50) BroadcastHashJoin [codegen id : 11] +(48) BroadcastHashJoin [codegen id : 11] Left keys [3]: [i_brand_id#7, i_class_id#8, i_category_id#9] Right keys [3]: [brand_id#28, class_id#29, category_id#30] Join condition: None -(51) Project [codegen id : 11] +(49) Project [codegen id : 11] Output [1]: [i_item_sk#6 AS ss_item_sk#41] Input [7]: [i_item_sk#6, i_brand_id#7, i_class_id#8, i_category_id#9, brand_id#28, class_id#29, category_id#30] -(52) BroadcastExchange +(50) BroadcastExchange Input [1]: [ss_item_sk#41] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#42] -(53) BroadcastHashJoin [codegen id : 25] +(51) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_item_sk#1] Right keys [1]: [ss_item_sk#41] Join condition: None -(54) Scan parquet default.item +(52) Scan parquet default.item Output [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Batched: true Location [not included in comparison]/{warehouse_dir}/item] PushedFilters: [IsNotNull(i_item_sk)] ReadSchema: struct -(55) ColumnarToRow [codegen id : 23] +(53) ColumnarToRow [codegen id : 23] Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] -(56) Filter [codegen id : 23] +(54) Filter [codegen id : 23] Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Condition : isnotnull(i_item_sk#43) -(57) ReusedExchange [Reuses operator id: 52] +(55) ReusedExchange [Reuses operator id: 50] Output [1]: [ss_item_sk#41] -(58) BroadcastHashJoin [codegen id : 23] +(56) BroadcastHashJoin [codegen id : 23] Left keys [1]: [i_item_sk#43] Right keys [1]: [ss_item_sk#41] Join condition: None -(59) BroadcastExchange +(57) BroadcastExchange Input [4]: [i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#47] -(60) BroadcastHashJoin [codegen id : 25] +(58) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_item_sk#1] Right keys [1]: 
[i_item_sk#43] Join condition: None -(61) Project [codegen id : 25] +(59) Project [codegen id : 25] Output [6]: [ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_brand_id#44, i_class_id#45, i_category_id#46] Input [8]: [ss_item_sk#1, ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_item_sk#43, i_brand_id#44, i_class_id#45, i_category_id#46] -(62) ReusedExchange [Reuses operator id: 156] +(60) ReusedExchange [Reuses operator id: 154] Output [1]: [d_date_sk#48] -(63) BroadcastHashJoin [codegen id : 25] +(61) BroadcastHashJoin [codegen id : 25] Left keys [1]: [ss_sold_date_sk#4] Right keys [1]: [d_date_sk#48] Join condition: None -(64) Project [codegen id : 25] +(62) Project [codegen id : 25] Output [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#44, i_class_id#45, i_category_id#46] Input [7]: [ss_quantity#2, ss_list_price#3, ss_sold_date_sk#4, i_brand_id#44, i_class_id#45, i_category_id#46, d_date_sk#48] -(65) HashAggregate [codegen id : 25] +(63) HashAggregate [codegen id : 25] Input [5]: [ss_quantity#2, ss_list_price#3, i_brand_id#44, i_class_id#45, i_category_id#46] Keys [3]: [i_brand_id#44, i_class_id#45, i_category_id#46] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] Aggregate Attributes [3]: [sum#49, isEmpty#50, count#51] Results [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] -(66) Exchange +(64) Exchange Input [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] Arguments: hashpartitioning(i_brand_id#44, i_class_id#45, i_category_id#46, 5), ENSURE_REQUIREMENTS, [id=#55] -(67) HashAggregate [codegen id : 26] +(65) HashAggregate [codegen id : 26] Input [6]: [i_brand_id#44, i_class_id#45, i_category_id#46, sum#52, isEmpty#53, count#54] Keys [3]: [i_brand_id#44, i_class_id#45, i_category_id#46] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#56, count(1)#57] -Results [6]: [store AS channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum(CheckOverflow((promote_precision(cast(cast(ss_quantity#2 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2), true))#56 AS sales#59, count(1)#57 AS number_sales#60] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as decimal(12,2)))), DecimalType(18,2)))#56, count(1)#57] +Results [6]: [store AS channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum(CheckOverflow((promote_precision(cast(ss_quantity#2 as decimal(12,2))) * promote_precision(cast(ss_list_price#3 as 
decimal(12,2)))), DecimalType(18,2)))#56 AS sales#59, count(1)#57 AS number_sales#60] -(68) Filter [codegen id : 26] +(66) Filter [codegen id : 26] Input [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sales#59, number_sales#60] Condition : (isnotnull(sales#59) AND (cast(sales#59 as decimal(32,6)) > cast(Subquery scalar-subquery#61, [id=#62] as decimal(32,6)))) -(69) Scan parquet default.catalog_sales +(67) Scan parquet default.catalog_sales Output [4]: [cs_item_sk#63, cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66] Batched: true Location: InMemoryFileIndex [] @@ -453,68 +437,68 @@ PartitionFilters: [isnotnull(cs_sold_date_sk#66), dynamicpruningexpression(cs_so PushedFilters: [IsNotNull(cs_item_sk)] ReadSchema: struct -(70) ColumnarToRow [codegen id : 51] +(68) ColumnarToRow [codegen id : 51] Input [4]: [cs_item_sk#63, cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66] -(71) Filter [codegen id : 51] +(69) Filter [codegen id : 51] Input [4]: [cs_item_sk#63, cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66] Condition : isnotnull(cs_item_sk#63) -(72) ReusedExchange [Reuses operator id: 52] +(70) ReusedExchange [Reuses operator id: 50] Output [1]: [ss_item_sk#41] -(73) BroadcastHashJoin [codegen id : 51] +(71) BroadcastHashJoin [codegen id : 51] Left keys [1]: [cs_item_sk#63] Right keys [1]: [ss_item_sk#41] Join condition: None -(74) ReusedExchange [Reuses operator id: 59] +(72) ReusedExchange [Reuses operator id: 57] Output [4]: [i_item_sk#67, i_brand_id#68, i_class_id#69, i_category_id#70] -(75) BroadcastHashJoin [codegen id : 51] +(73) BroadcastHashJoin [codegen id : 51] Left keys [1]: [cs_item_sk#63] Right keys [1]: [i_item_sk#67] Join condition: None -(76) Project [codegen id : 51] +(74) Project [codegen id : 51] Output [6]: [cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66, i_brand_id#68, i_class_id#69, i_category_id#70] Input [8]: [cs_item_sk#63, cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66, i_item_sk#67, i_brand_id#68, i_class_id#69, i_category_id#70] -(77) ReusedExchange [Reuses operator id: 156] +(75) ReusedExchange [Reuses operator id: 154] Output [1]: [d_date_sk#71] -(78) BroadcastHashJoin [codegen id : 51] +(76) BroadcastHashJoin [codegen id : 51] Left keys [1]: [cs_sold_date_sk#66] Right keys [1]: [d_date_sk#71] Join condition: None -(79) Project [codegen id : 51] +(77) Project [codegen id : 51] Output [5]: [cs_quantity#64, cs_list_price#65, i_brand_id#68, i_class_id#69, i_category_id#70] Input [7]: [cs_quantity#64, cs_list_price#65, cs_sold_date_sk#66, i_brand_id#68, i_class_id#69, i_category_id#70, d_date_sk#71] -(80) HashAggregate [codegen id : 51] +(78) HashAggregate [codegen id : 51] Input [5]: [cs_quantity#64, cs_list_price#65, i_brand_id#68, i_class_id#69, i_category_id#70] Keys [3]: [i_brand_id#68, i_class_id#69, i_category_id#70] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cs_quantity#64 as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] Aggregate Attributes [3]: [sum#72, isEmpty#73, count#74] Results [6]: [i_brand_id#68, i_class_id#69, i_category_id#70, sum#75, isEmpty#76, count#77] -(81) Exchange +(79) Exchange Input [6]: [i_brand_id#68, i_class_id#69, i_category_id#70, sum#75, isEmpty#76, count#77] 
Arguments: hashpartitioning(i_brand_id#68, i_class_id#69, i_category_id#70, 5), ENSURE_REQUIREMENTS, [id=#78] -(82) HashAggregate [codegen id : 52] +(80) HashAggregate [codegen id : 52] Input [6]: [i_brand_id#68, i_class_id#69, i_category_id#70, sum#75, isEmpty#76, count#77] Keys [3]: [i_brand_id#68, i_class_id#69, i_category_id#70] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2), true))#79, count(1)#80] -Results [6]: [catalog AS channel#81, i_brand_id#68, i_class_id#69, i_category_id#70, sum(CheckOverflow((promote_precision(cast(cast(cs_quantity#64 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2), true))#79 AS sales#82, count(1)#80 AS number_sales#83] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#64 as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cs_quantity#64 as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2)))#79, count(1)#80] +Results [6]: [catalog AS channel#81, i_brand_id#68, i_class_id#69, i_category_id#70, sum(CheckOverflow((promote_precision(cast(cs_quantity#64 as decimal(12,2))) * promote_precision(cast(cs_list_price#65 as decimal(12,2)))), DecimalType(18,2)))#79 AS sales#82, count(1)#80 AS number_sales#83] -(83) Filter [codegen id : 52] +(81) Filter [codegen id : 52] Input [6]: [channel#81, i_brand_id#68, i_class_id#69, i_category_id#70, sales#82, number_sales#83] Condition : (isnotnull(sales#82) AND (cast(sales#82 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#61, [id=#62] as decimal(32,6)))) -(84) Scan parquet default.web_sales +(82) Scan parquet default.web_sales Output [4]: [ws_item_sk#84, ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87] Batched: true Location: InMemoryFileIndex [] @@ -522,424 +506,424 @@ PartitionFilters: [isnotnull(ws_sold_date_sk#87), dynamicpruningexpression(ws_so PushedFilters: [IsNotNull(ws_item_sk)] ReadSchema: struct -(85) ColumnarToRow [codegen id : 77] +(83) ColumnarToRow [codegen id : 77] Input [4]: [ws_item_sk#84, ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87] -(86) Filter [codegen id : 77] +(84) Filter [codegen id : 77] Input [4]: [ws_item_sk#84, ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87] Condition : isnotnull(ws_item_sk#84) -(87) ReusedExchange [Reuses operator id: 52] +(85) ReusedExchange [Reuses operator id: 50] Output [1]: [ss_item_sk#41] -(88) BroadcastHashJoin [codegen id : 77] +(86) BroadcastHashJoin [codegen id : 77] Left keys [1]: [ws_item_sk#84] Right keys [1]: [ss_item_sk#41] Join condition: None -(89) ReusedExchange [Reuses operator id: 59] +(87) ReusedExchange [Reuses operator id: 57] Output [4]: [i_item_sk#88, i_brand_id#89, i_class_id#90, i_category_id#91] -(90) BroadcastHashJoin [codegen id : 77] +(88) BroadcastHashJoin [codegen id : 77] Left keys [1]: [ws_item_sk#84] Right keys [1]: [i_item_sk#88] Join condition: None -(91) Project [codegen id : 77] +(89) Project [codegen id : 77] Output [6]: [ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87, 
i_brand_id#89, i_class_id#90, i_category_id#91] Input [8]: [ws_item_sk#84, ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87, i_item_sk#88, i_brand_id#89, i_class_id#90, i_category_id#91] -(92) ReusedExchange [Reuses operator id: 156] +(90) ReusedExchange [Reuses operator id: 154] Output [1]: [d_date_sk#92] -(93) BroadcastHashJoin [codegen id : 77] +(91) BroadcastHashJoin [codegen id : 77] Left keys [1]: [ws_sold_date_sk#87] Right keys [1]: [d_date_sk#92] Join condition: None -(94) Project [codegen id : 77] +(92) Project [codegen id : 77] Output [5]: [ws_quantity#85, ws_list_price#86, i_brand_id#89, i_class_id#90, i_category_id#91] Input [7]: [ws_quantity#85, ws_list_price#86, ws_sold_date_sk#87, i_brand_id#89, i_class_id#90, i_category_id#91, d_date_sk#92] -(95) HashAggregate [codegen id : 77] +(93) HashAggregate [codegen id : 77] Input [5]: [ws_quantity#85, ws_list_price#86, i_brand_id#89, i_class_id#90, i_category_id#91] Keys [3]: [i_brand_id#89, i_class_id#90, i_category_id#91] -Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#85 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2), true)), partial_count(1)] +Functions [2]: [partial_sum(CheckOverflow((promote_precision(cast(ws_quantity#85 as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2))), partial_count(1)] Aggregate Attributes [3]: [sum#93, isEmpty#94, count#95] Results [6]: [i_brand_id#89, i_class_id#90, i_category_id#91, sum#96, isEmpty#97, count#98] -(96) Exchange +(94) Exchange Input [6]: [i_brand_id#89, i_class_id#90, i_category_id#91, sum#96, isEmpty#97, count#98] Arguments: hashpartitioning(i_brand_id#89, i_class_id#90, i_category_id#91, 5), ENSURE_REQUIREMENTS, [id=#99] -(97) HashAggregate [codegen id : 78] +(95) HashAggregate [codegen id : 78] Input [6]: [i_brand_id#89, i_class_id#90, i_category_id#91, sum#96, isEmpty#97, count#98] Keys [3]: [i_brand_id#89, i_class_id#90, i_category_id#91] -Functions [2]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#85 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2), true)), count(1)] -Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#85 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2), true))#100, count(1)#101] -Results [6]: [web AS channel#102, i_brand_id#89, i_class_id#90, i_category_id#91, sum(CheckOverflow((promote_precision(cast(cast(ws_quantity#85 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2), true))#100 AS sales#103, count(1)#101 AS number_sales#104] +Functions [2]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#85 as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2))), count(1)] +Aggregate Attributes [2]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#85 as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2)))#100, count(1)#101] +Results [6]: [web AS channel#102, i_brand_id#89, i_class_id#90, i_category_id#91, sum(CheckOverflow((promote_precision(cast(ws_quantity#85 as decimal(12,2))) * promote_precision(cast(ws_list_price#86 as decimal(12,2)))), DecimalType(18,2)))#100 AS sales#103, count(1)#101 AS number_sales#104] -(98) Filter [codegen id : 78] +(96) Filter [codegen 
id : 78] Input [6]: [channel#102, i_brand_id#89, i_class_id#90, i_category_id#91, sales#103, number_sales#104] Condition : (isnotnull(sales#103) AND (cast(sales#103 as decimal(32,6)) > cast(ReusedSubquery Subquery scalar-subquery#61, [id=#62] as decimal(32,6)))) -(99) Union +(97) Union -(100) HashAggregate [codegen id : 79] +(98) HashAggregate [codegen id : 79] Input [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sales#59, number_sales#60] Keys [4]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46] Functions [2]: [partial_sum(sales#59), partial_sum(number_sales#60)] Aggregate Attributes [3]: [sum#105, isEmpty#106, sum#107] Results [7]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] -(101) Exchange +(99) Exchange Input [7]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] Arguments: hashpartitioning(channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, 5), ENSURE_REQUIREMENTS, [id=#111] -(102) HashAggregate [codegen id : 80] +(100) HashAggregate [codegen id : 80] Input [7]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] Keys [4]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46] Functions [2]: [sum(sales#59), sum(number_sales#60)] Aggregate Attributes [2]: [sum(sales#59)#112, sum(number_sales#60)#113] Results [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum(sales#59)#112 AS sum_sales#114, sum(number_sales#60)#113 AS number_sales#115] -(103) ReusedExchange [Reuses operator id: 101] +(101) ReusedExchange [Reuses operator id: 99] Output [7]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] -(104) HashAggregate [codegen id : 160] +(102) HashAggregate [codegen id : 160] Input [7]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] Keys [4]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46] Functions [2]: [sum(sales#59), sum(number_sales#60)] Aggregate Attributes [2]: [sum(sales#59)#112, sum(number_sales#60)#113] Results [5]: [channel#58, i_brand_id#44, i_class_id#45, sum(sales#59)#112 AS sum_sales#114, sum(number_sales#60)#113 AS number_sales#115] -(105) HashAggregate [codegen id : 160] +(103) HashAggregate [codegen id : 160] Input [5]: [channel#58, i_brand_id#44, i_class_id#45, sum_sales#114, number_sales#115] Keys [3]: [channel#58, i_brand_id#44, i_class_id#45] Functions [2]: [partial_sum(sum_sales#114), partial_sum(number_sales#115)] Aggregate Attributes [3]: [sum#116, isEmpty#117, sum#118] Results [6]: [channel#58, i_brand_id#44, i_class_id#45, sum#119, isEmpty#120, sum#121] -(106) Exchange +(104) Exchange Input [6]: [channel#58, i_brand_id#44, i_class_id#45, sum#119, isEmpty#120, sum#121] Arguments: hashpartitioning(channel#58, i_brand_id#44, i_class_id#45, 5), ENSURE_REQUIREMENTS, [id=#122] -(107) HashAggregate [codegen id : 161] +(105) HashAggregate [codegen id : 161] Input [6]: [channel#58, i_brand_id#44, i_class_id#45, sum#119, isEmpty#120, sum#121] Keys [3]: [channel#58, i_brand_id#44, i_class_id#45] Functions [2]: [sum(sum_sales#114), sum(number_sales#115)] Aggregate Attributes [2]: [sum(sum_sales#114)#123, sum(number_sales#115)#124] Results [6]: [channel#58, i_brand_id#44, i_class_id#45, null AS i_category_id#125, sum(sum_sales#114)#123 AS sum(sum_sales)#126, sum(number_sales#115)#124 AS sum(number_sales)#127] -(108) ReusedExchange [Reuses operator id: 101] +(106) ReusedExchange [Reuses 
operator id: 99] Output [7]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] -(109) HashAggregate [codegen id : 241] +(107) HashAggregate [codegen id : 241] Input [7]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] Keys [4]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46] Functions [2]: [sum(sales#59), sum(number_sales#60)] Aggregate Attributes [2]: [sum(sales#59)#112, sum(number_sales#60)#113] Results [4]: [channel#58, i_brand_id#44, sum(sales#59)#112 AS sum_sales#114, sum(number_sales#60)#113 AS number_sales#115] -(110) HashAggregate [codegen id : 241] +(108) HashAggregate [codegen id : 241] Input [4]: [channel#58, i_brand_id#44, sum_sales#114, number_sales#115] Keys [2]: [channel#58, i_brand_id#44] Functions [2]: [partial_sum(sum_sales#114), partial_sum(number_sales#115)] Aggregate Attributes [3]: [sum#128, isEmpty#129, sum#130] Results [5]: [channel#58, i_brand_id#44, sum#131, isEmpty#132, sum#133] -(111) Exchange +(109) Exchange Input [5]: [channel#58, i_brand_id#44, sum#131, isEmpty#132, sum#133] Arguments: hashpartitioning(channel#58, i_brand_id#44, 5), ENSURE_REQUIREMENTS, [id=#134] -(112) HashAggregate [codegen id : 242] +(110) HashAggregate [codegen id : 242] Input [5]: [channel#58, i_brand_id#44, sum#131, isEmpty#132, sum#133] Keys [2]: [channel#58, i_brand_id#44] Functions [2]: [sum(sum_sales#114), sum(number_sales#115)] Aggregate Attributes [2]: [sum(sum_sales#114)#135, sum(number_sales#115)#136] Results [6]: [channel#58, i_brand_id#44, null AS i_class_id#137, null AS i_category_id#138, sum(sum_sales#114)#135 AS sum(sum_sales)#139, sum(number_sales#115)#136 AS sum(number_sales)#140] -(113) ReusedExchange [Reuses operator id: 101] +(111) ReusedExchange [Reuses operator id: 99] Output [7]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] -(114) HashAggregate [codegen id : 322] +(112) HashAggregate [codegen id : 322] Input [7]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] Keys [4]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46] Functions [2]: [sum(sales#59), sum(number_sales#60)] Aggregate Attributes [2]: [sum(sales#59)#112, sum(number_sales#60)#113] Results [3]: [channel#58, sum(sales#59)#112 AS sum_sales#114, sum(number_sales#60)#113 AS number_sales#115] -(115) HashAggregate [codegen id : 322] +(113) HashAggregate [codegen id : 322] Input [3]: [channel#58, sum_sales#114, number_sales#115] Keys [1]: [channel#58] Functions [2]: [partial_sum(sum_sales#114), partial_sum(number_sales#115)] Aggregate Attributes [3]: [sum#141, isEmpty#142, sum#143] Results [4]: [channel#58, sum#144, isEmpty#145, sum#146] -(116) Exchange +(114) Exchange Input [4]: [channel#58, sum#144, isEmpty#145, sum#146] Arguments: hashpartitioning(channel#58, 5), ENSURE_REQUIREMENTS, [id=#147] -(117) HashAggregate [codegen id : 323] +(115) HashAggregate [codegen id : 323] Input [4]: [channel#58, sum#144, isEmpty#145, sum#146] Keys [1]: [channel#58] Functions [2]: [sum(sum_sales#114), sum(number_sales#115)] Aggregate Attributes [2]: [sum(sum_sales#114)#148, sum(number_sales#115)#149] Results [6]: [channel#58, null AS i_brand_id#150, null AS i_class_id#151, null AS i_category_id#152, sum(sum_sales#114)#148 AS sum(sum_sales)#153, sum(number_sales#115)#149 AS sum(number_sales)#154] -(118) ReusedExchange [Reuses operator id: 101] +(116) ReusedExchange [Reuses operator id: 99] Output [7]: [channel#58, i_brand_id#44, 
i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] -(119) HashAggregate [codegen id : 403] +(117) HashAggregate [codegen id : 403] Input [7]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum#108, isEmpty#109, sum#110] Keys [4]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46] Functions [2]: [sum(sales#59), sum(number_sales#60)] Aggregate Attributes [2]: [sum(sales#59)#112, sum(number_sales#60)#113] Results [2]: [sum(sales#59)#112 AS sum_sales#114, sum(number_sales#60)#113 AS number_sales#115] -(120) HashAggregate [codegen id : 403] +(118) HashAggregate [codegen id : 403] Input [2]: [sum_sales#114, number_sales#115] Keys: [] Functions [2]: [partial_sum(sum_sales#114), partial_sum(number_sales#115)] Aggregate Attributes [3]: [sum#155, isEmpty#156, sum#157] Results [3]: [sum#158, isEmpty#159, sum#160] -(121) Exchange +(119) Exchange Input [3]: [sum#158, isEmpty#159, sum#160] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#161] -(122) HashAggregate [codegen id : 404] +(120) HashAggregate [codegen id : 404] Input [3]: [sum#158, isEmpty#159, sum#160] Keys: [] Functions [2]: [sum(sum_sales#114), sum(number_sales#115)] Aggregate Attributes [2]: [sum(sum_sales#114)#162, sum(number_sales#115)#163] Results [6]: [null AS channel#164, null AS i_brand_id#165, null AS i_class_id#166, null AS i_category_id#167, sum(sum_sales#114)#162 AS sum(sum_sales)#168, sum(number_sales#115)#163 AS sum(number_sales)#169] -(123) Union +(121) Union -(124) HashAggregate [codegen id : 405] +(122) HashAggregate [codegen id : 405] Input [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum_sales#114, number_sales#115] Keys [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum_sales#114, number_sales#115] Functions: [] Aggregate Attributes: [] Results [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum_sales#114, number_sales#115] -(125) Exchange +(123) Exchange Input [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum_sales#114, number_sales#115] Arguments: hashpartitioning(channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum_sales#114, number_sales#115, 5), ENSURE_REQUIREMENTS, [id=#170] -(126) HashAggregate [codegen id : 406] +(124) HashAggregate [codegen id : 406] Input [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum_sales#114, number_sales#115] Keys [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum_sales#114, number_sales#115] Functions: [] Aggregate Attributes: [] Results [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum_sales#114, number_sales#115] -(127) TakeOrderedAndProject +(125) TakeOrderedAndProject Input [6]: [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum_sales#114, number_sales#115] Arguments: 100, [channel#58 ASC NULLS FIRST, i_brand_id#44 ASC NULLS FIRST, i_class_id#45 ASC NULLS FIRST, i_category_id#46 ASC NULLS FIRST], [channel#58, i_brand_id#44, i_class_id#45, i_category_id#46, sum_sales#114, number_sales#115] ===== Subqueries ===== -Subquery:1 Hosting operator id = 68 Hosting Expression = Subquery scalar-subquery#61, [id=#62] -* HashAggregate (146) -+- Exchange (145) - +- * HashAggregate (144) - +- Union (143) - :- * Project (132) - : +- * BroadcastHashJoin Inner BuildRight (131) - : :- * ColumnarToRow (129) - : : +- Scan parquet default.store_sales (128) - : +- ReusedExchange (130) - :- * Project (137) - : +- * BroadcastHashJoin Inner BuildRight (136) - : :- * ColumnarToRow (134) - : : +- Scan parquet 
default.catalog_sales (133) - : +- ReusedExchange (135) - +- * Project (142) - +- * BroadcastHashJoin Inner BuildRight (141) - :- * ColumnarToRow (139) - : +- Scan parquet default.web_sales (138) - +- ReusedExchange (140) - - -(128) Scan parquet default.store_sales +Subquery:1 Hosting operator id = 66 Hosting Expression = Subquery scalar-subquery#61, [id=#62] +* HashAggregate (144) ++- Exchange (143) + +- * HashAggregate (142) + +- Union (141) + :- * Project (130) + : +- * BroadcastHashJoin Inner BuildRight (129) + : :- * ColumnarToRow (127) + : : +- Scan parquet default.store_sales (126) + : +- ReusedExchange (128) + :- * Project (135) + : +- * BroadcastHashJoin Inner BuildRight (134) + : :- * ColumnarToRow (132) + : : +- Scan parquet default.catalog_sales (131) + : +- ReusedExchange (133) + +- * Project (140) + +- * BroadcastHashJoin Inner BuildRight (139) + :- * ColumnarToRow (137) + : +- Scan parquet default.web_sales (136) + +- ReusedExchange (138) + + +(126) Scan parquet default.store_sales Output [3]: [ss_quantity#171, ss_list_price#172, ss_sold_date_sk#173] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(ss_sold_date_sk#173), dynamicpruningexpression(ss_sold_date_sk#173 IN dynamicpruning#12)] ReadSchema: struct -(129) ColumnarToRow [codegen id : 2] +(127) ColumnarToRow [codegen id : 2] Input [3]: [ss_quantity#171, ss_list_price#172, ss_sold_date_sk#173] -(130) ReusedExchange [Reuses operator id: 161] +(128) ReusedExchange [Reuses operator id: 159] Output [1]: [d_date_sk#174] -(131) BroadcastHashJoin [codegen id : 2] +(129) BroadcastHashJoin [codegen id : 2] Left keys [1]: [ss_sold_date_sk#173] Right keys [1]: [d_date_sk#174] Join condition: None -(132) Project [codegen id : 2] +(130) Project [codegen id : 2] Output [2]: [ss_quantity#171 AS quantity#175, ss_list_price#172 AS list_price#176] Input [4]: [ss_quantity#171, ss_list_price#172, ss_sold_date_sk#173, d_date_sk#174] -(133) Scan parquet default.catalog_sales +(131) Scan parquet default.catalog_sales Output [3]: [cs_quantity#177, cs_list_price#178, cs_sold_date_sk#179] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(cs_sold_date_sk#179), dynamicpruningexpression(cs_sold_date_sk#179 IN dynamicpruning#180)] ReadSchema: struct -(134) ColumnarToRow [codegen id : 4] +(132) ColumnarToRow [codegen id : 4] Input [3]: [cs_quantity#177, cs_list_price#178, cs_sold_date_sk#179] -(135) ReusedExchange [Reuses operator id: 151] +(133) ReusedExchange [Reuses operator id: 149] Output [1]: [d_date_sk#181] -(136) BroadcastHashJoin [codegen id : 4] +(134) BroadcastHashJoin [codegen id : 4] Left keys [1]: [cs_sold_date_sk#179] Right keys [1]: [d_date_sk#181] Join condition: None -(137) Project [codegen id : 4] +(135) Project [codegen id : 4] Output [2]: [cs_quantity#177 AS quantity#182, cs_list_price#178 AS list_price#183] Input [4]: [cs_quantity#177, cs_list_price#178, cs_sold_date_sk#179, d_date_sk#181] -(138) Scan parquet default.web_sales +(136) Scan parquet default.web_sales Output [3]: [ws_quantity#184, ws_list_price#185, ws_sold_date_sk#186] Batched: true Location: InMemoryFileIndex [] PartitionFilters: [isnotnull(ws_sold_date_sk#186), dynamicpruningexpression(ws_sold_date_sk#186 IN dynamicpruning#180)] ReadSchema: struct -(139) ColumnarToRow [codegen id : 6] +(137) ColumnarToRow [codegen id : 6] Input [3]: [ws_quantity#184, ws_list_price#185, ws_sold_date_sk#186] -(140) ReusedExchange [Reuses operator id: 151] +(138) ReusedExchange [Reuses operator id: 149] Output [1]: [d_date_sk#187] 
-(141) BroadcastHashJoin [codegen id : 6] +(139) BroadcastHashJoin [codegen id : 6] Left keys [1]: [ws_sold_date_sk#186] Right keys [1]: [d_date_sk#187] Join condition: None -(142) Project [codegen id : 6] +(140) Project [codegen id : 6] Output [2]: [ws_quantity#184 AS quantity#188, ws_list_price#185 AS list_price#189] Input [4]: [ws_quantity#184, ws_list_price#185, ws_sold_date_sk#186, d_date_sk#187] -(143) Union +(141) Union -(144) HashAggregate [codegen id : 7] +(142) HashAggregate [codegen id : 7] Input [2]: [quantity#175, list_price#176] Keys: [] -Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(cast(quantity#175 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#176 as decimal(12,2)))), DecimalType(18,2), true))] +Functions [1]: [partial_avg(CheckOverflow((promote_precision(cast(quantity#175 as decimal(12,2))) * promote_precision(cast(list_price#176 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [2]: [sum#190, count#191] Results [2]: [sum#192, count#193] -(145) Exchange +(143) Exchange Input [2]: [sum#192, count#193] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#194] -(146) HashAggregate [codegen id : 8] +(144) HashAggregate [codegen id : 8] Input [2]: [sum#192, count#193] Keys: [] -Functions [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#175 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#176 as decimal(12,2)))), DecimalType(18,2), true))] -Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#175 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#176 as decimal(12,2)))), DecimalType(18,2), true))#195] -Results [1]: [avg(CheckOverflow((promote_precision(cast(cast(quantity#175 as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price#176 as decimal(12,2)))), DecimalType(18,2), true))#195 AS average_sales#196] +Functions [1]: [avg(CheckOverflow((promote_precision(cast(quantity#175 as decimal(12,2))) * promote_precision(cast(list_price#176 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [avg(CheckOverflow((promote_precision(cast(quantity#175 as decimal(12,2))) * promote_precision(cast(list_price#176 as decimal(12,2)))), DecimalType(18,2)))#195] +Results [1]: [avg(CheckOverflow((promote_precision(cast(quantity#175 as decimal(12,2))) * promote_precision(cast(list_price#176 as decimal(12,2)))), DecimalType(18,2)))#195 AS average_sales#196] -Subquery:2 Hosting operator id = 128 Hosting Expression = ss_sold_date_sk#173 IN dynamicpruning#12 +Subquery:2 Hosting operator id = 126 Hosting Expression = ss_sold_date_sk#173 IN dynamicpruning#12 -Subquery:3 Hosting operator id = 133 Hosting Expression = cs_sold_date_sk#179 IN dynamicpruning#180 -BroadcastExchange (151) -+- * Project (150) - +- * Filter (149) - +- * ColumnarToRow (148) - +- Scan parquet default.date_dim (147) +Subquery:3 Hosting operator id = 131 Hosting Expression = cs_sold_date_sk#179 IN dynamicpruning#180 +BroadcastExchange (149) ++- * Project (148) + +- * Filter (147) + +- * ColumnarToRow (146) + +- Scan parquet default.date_dim (145) -(147) Scan parquet default.date_dim +(145) Scan parquet default.date_dim Output [2]: [d_date_sk#181, d_year#197] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), GreaterThanOrEqual(d_year,1998), LessThanOrEqual(d_year,2000), IsNotNull(d_date_sk)] ReadSchema: struct -(148) ColumnarToRow [codegen id : 1] +(146) ColumnarToRow [codegen id : 1] Input 
[2]: [d_date_sk#181, d_year#197] -(149) Filter [codegen id : 1] +(147) Filter [codegen id : 1] Input [2]: [d_date_sk#181, d_year#197] Condition : (((isnotnull(d_year#197) AND (d_year#197 >= 1998)) AND (d_year#197 <= 2000)) AND isnotnull(d_date_sk#181)) -(150) Project [codegen id : 1] +(148) Project [codegen id : 1] Output [1]: [d_date_sk#181] Input [2]: [d_date_sk#181, d_year#197] -(151) BroadcastExchange +(149) BroadcastExchange Input [1]: [d_date_sk#181] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#198] -Subquery:4 Hosting operator id = 138 Hosting Expression = ws_sold_date_sk#186 IN dynamicpruning#180 +Subquery:4 Hosting operator id = 136 Hosting Expression = ws_sold_date_sk#186 IN dynamicpruning#180 Subquery:5 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (156) -+- * Project (155) - +- * Filter (154) - +- * ColumnarToRow (153) - +- Scan parquet default.date_dim (152) +BroadcastExchange (154) ++- * Project (153) + +- * Filter (152) + +- * ColumnarToRow (151) + +- Scan parquet default.date_dim (150) -(152) Scan parquet default.date_dim +(150) Scan parquet default.date_dim Output [3]: [d_date_sk#48, d_year#199, d_moy#200] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2000), EqualTo(d_moy,11), IsNotNull(d_date_sk)] ReadSchema: struct -(153) ColumnarToRow [codegen id : 1] +(151) ColumnarToRow [codegen id : 1] Input [3]: [d_date_sk#48, d_year#199, d_moy#200] -(154) Filter [codegen id : 1] +(152) Filter [codegen id : 1] Input [3]: [d_date_sk#48, d_year#199, d_moy#200] Condition : ((((isnotnull(d_year#199) AND isnotnull(d_moy#200)) AND (d_year#199 = 2000)) AND (d_moy#200 = 11)) AND isnotnull(d_date_sk#48)) -(155) Project [codegen id : 1] +(153) Project [codegen id : 1] Output [1]: [d_date_sk#48] Input [3]: [d_date_sk#48, d_year#199, d_moy#200] -(156) BroadcastExchange +(154) BroadcastExchange Input [1]: [d_date_sk#48] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#201] Subquery:6 Hosting operator id = 7 Hosting Expression = ss_sold_date_sk#11 IN dynamicpruning#12 -BroadcastExchange (161) -+- * Project (160) - +- * Filter (159) - +- * ColumnarToRow (158) - +- Scan parquet default.date_dim (157) +BroadcastExchange (159) ++- * Project (158) + +- * Filter (157) + +- * ColumnarToRow (156) + +- Scan parquet default.date_dim (155) -(157) Scan parquet default.date_dim +(155) Scan parquet default.date_dim Output [2]: [d_date_sk#27, d_year#202] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), GreaterThanOrEqual(d_year,1999), LessThanOrEqual(d_year,2001), IsNotNull(d_date_sk)] ReadSchema: struct -(158) ColumnarToRow [codegen id : 1] +(156) ColumnarToRow [codegen id : 1] Input [2]: [d_date_sk#27, d_year#202] -(159) Filter [codegen id : 1] +(157) Filter [codegen id : 1] Input [2]: [d_date_sk#27, d_year#202] Condition : (((isnotnull(d_year#202) AND (d_year#202 >= 1999)) AND (d_year#202 <= 2001)) AND isnotnull(d_date_sk#27)) -(160) Project [codegen id : 1] +(158) Project [codegen id : 1] Output [1]: [d_date_sk#27] Input [2]: [d_date_sk#27, d_year#202] -(161) BroadcastExchange +(159) BroadcastExchange Input [1]: [d_date_sk#27] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#203] @@ -947,12 +931,12 @@ Subquery:7 Hosting operator id = 13 Hosting Expression 
= cs_sold_date_sk#18 IN d Subquery:8 Hosting operator id = 36 Hosting Expression = ws_sold_date_sk#33 IN dynamicpruning#12 -Subquery:9 Hosting operator id = 83 Hosting Expression = ReusedSubquery Subquery scalar-subquery#61, [id=#62] +Subquery:9 Hosting operator id = 81 Hosting Expression = ReusedSubquery Subquery scalar-subquery#61, [id=#62] -Subquery:10 Hosting operator id = 69 Hosting Expression = cs_sold_date_sk#66 IN dynamicpruning#5 +Subquery:10 Hosting operator id = 67 Hosting Expression = cs_sold_date_sk#66 IN dynamicpruning#5 -Subquery:11 Hosting operator id = 98 Hosting Expression = ReusedSubquery Subquery scalar-subquery#61, [id=#62] +Subquery:11 Hosting operator id = 96 Hosting Expression = ReusedSubquery Subquery scalar-subquery#61, [id=#62] -Subquery:12 Hosting operator id = 84 Hosting Expression = ws_sold_date_sk#87 IN dynamicpruning#5 +Subquery:12 Hosting operator id = 82 Hosting Expression = ws_sold_date_sk#87 IN dynamicpruning#5 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a/simplified.txt index 3a56d26b3b2d3..086c36864ebdb 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14a/simplified.txt @@ -19,7 +19,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num Filter [sales] Subquery #3 WholeStageCodegen (8) - HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(cast(quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2), true)),average_sales,sum,count] + HashAggregate [sum,count] [avg(CheckOverflow((promote_precision(cast(quantity as decimal(12,2))) * promote_precision(cast(list_price as decimal(12,2)))), DecimalType(18,2))),average_sales,sum,count] InputAdapter Exchange #14 WholeStageCodegen (7) @@ -60,7 +60,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num ReusedSubquery [d_date_sk] #4 InputAdapter ReusedExchange [d_date_sk] #15 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ss_quantity as decimal(12,2))) * promote_precision(cast(ss_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #3 WholeStageCodegen (25) @@ -94,77 +94,75 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num InputAdapter BroadcastExchange #6 WholeStageCodegen (10) - HashAggregate [brand_id,class_id,category_id] + BroadcastHashJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] HashAggregate [brand_id,class_id,category_id] - BroadcastHashJoin [brand_id,class_id,category_id,i_brand_id,i_class_id,i_category_id] - HashAggregate [brand_id,class_id,category_id] - InputAdapter - Exchange [brand_id,class_id,category_id] #7 - WholeStageCodegen (6) - HashAggregate [brand_id,class_id,category_id] - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin 
[ss_sold_date_sk,d_date_sk] - Project [ss_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ss_item_sk,i_item_sk] - Filter [ss_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #2 - BroadcastExchange #8 - WholeStageCodegen (1) - Project [d_date_sk] - Filter [d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] - InputAdapter - BroadcastExchange #9 - WholeStageCodegen (4) - BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] - Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - BroadcastExchange #10 - WholeStageCodegen (3) - Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Project [cs_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [cs_item_sk,i_item_sk] - Filter [cs_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] - ReusedSubquery [d_date_sk] #2 - InputAdapter - BroadcastExchange #11 - WholeStageCodegen (1) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] - InputAdapter - ReusedExchange [d_date_sk] #8 - InputAdapter - ReusedExchange [d_date_sk] #8 - InputAdapter - BroadcastExchange #12 - WholeStageCodegen (9) + InputAdapter + Exchange [brand_id,class_id,category_id] #7 + WholeStageCodegen (6) + HashAggregate [brand_id,class_id,category_id] Project [i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Project [ws_sold_date_sk,i_brand_id,i_class_id,i_category_id] - BroadcastHashJoin [ws_item_sk,i_item_sk] - Filter [ws_item_sk] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Project [ss_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ss_item_sk,i_item_sk] + Filter [ss_item_sk] ColumnarToRow InputAdapter - Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] - ReusedSubquery [d_date_sk] #2 + Scan parquet default.store_sales [ss_item_sk,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #2 + BroadcastExchange #8 + WholeStageCodegen (1) + Project [d_date_sk] + Filter [d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] InputAdapter - ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #11 + BroadcastExchange #9 + WholeStageCodegen (4) + BroadcastHashJoin [i_brand_id,i_class_id,i_category_id,i_brand_id,i_class_id,i_category_id] + Filter [i_item_sk,i_brand_id,i_class_id,i_category_id] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + BroadcastExchange #10 + WholeStageCodegen (3) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Project [cs_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [cs_item_sk,i_item_sk] + Filter [cs_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_item_sk,cs_sold_date_sk] + ReusedSubquery [d_date_sk] #2 + InputAdapter + BroadcastExchange #11 + WholeStageCodegen (1) + Filter [i_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand_id,i_class_id,i_category_id] + InputAdapter + ReusedExchange [d_date_sk] #8 InputAdapter ReusedExchange 
[d_date_sk] #8 + InputAdapter + BroadcastExchange #12 + WholeStageCodegen (9) + Project [i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Project [ws_sold_date_sk,i_brand_id,i_class_id,i_category_id] + BroadcastHashJoin [ws_item_sk,i_item_sk] + Filter [ws_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_item_sk,ws_sold_date_sk] + ReusedSubquery [d_date_sk] #2 + InputAdapter + ReusedExchange [i_item_sk,i_brand_id,i_class_id,i_category_id] #11 + InputAdapter + ReusedExchange [d_date_sk] #8 InputAdapter BroadcastExchange #13 WholeStageCodegen (23) @@ -180,7 +178,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num WholeStageCodegen (52) Filter [sales] ReusedSubquery [average_sales] #3 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(cs_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cs_quantity as decimal(12,2))) * promote_precision(cast(cs_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #16 WholeStageCodegen (51) @@ -204,7 +202,7 @@ TakeOrderedAndProject [channel,i_brand_id,i_class_id,i_category_id,sum_sales,num WholeStageCodegen (78) Filter [sales] ReusedSubquery [average_sales] #3 - HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(cast(ws_quantity as decimal(10,0)) as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2), true)),count(1),channel,sales,number_sales,sum,isEmpty,count] + HashAggregate [i_brand_id,i_class_id,i_category_id,sum,isEmpty,count] [sum(CheckOverflow((promote_precision(cast(ws_quantity as decimal(12,2))) * promote_precision(cast(ws_list_price as decimal(12,2)))), DecimalType(18,2))),count(1),channel,sales,number_sales,sum,isEmpty,count] InputAdapter Exchange [i_brand_id,i_class_id,i_category_id] #17 WholeStageCodegen (77) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q20.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q20.sf100/explain.txt index 64a92b9e727bc..c925197336e95 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q20.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q20.sf100/explain.txt @@ -121,7 +121,7 @@ Input [8]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_pri Arguments: [sum(_w1#20) windowspecdefinition(i_class#10, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#22], [i_class#10] (22) Project [codegen id : 9] -Output [7]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17), true) AS revenueratio#23] +Output [7]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), 
DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17)) AS revenueratio#23] Input [9]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, _w0#19, _w1#20, _we0#22] (23) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q20/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q20/explain.txt index 5ea1cda2f68d5..ff461dafc09c0 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q20/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q20/explain.txt @@ -106,7 +106,7 @@ Input [8]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_pric Arguments: [sum(_w1#19) windowspecdefinition(i_class#9, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#21], [i_class#9] (19) Project [codegen id : 6] -Output [7]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17), true) AS revenueratio#22] +Output [7]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17)) AS revenueratio#22] Input [9]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, _w0#18, _w1#19, _we0#21] (20) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt index 332a0b9220538..db2116117c81e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt @@ -273,37 +273,34 @@ Arguments: [c_last_name#15 ASC NULLS FIRST, c_first_name#14 ASC NULLS FIRST, s_s ===== Subqueries ===== Subquery:1 Hosting operator id = 46 Hosting Expression = Subquery scalar-subquery#48, [id=#49] -* HashAggregate (79) -+- Exchange (78) - +- * HashAggregate (77) - +- * HashAggregate (76) - +- Exchange (75) - +- * HashAggregate (74) - +- * Project (73) - +- * SortMergeJoin Inner (72) - :- * Sort (65) - : +- * Project (64) - : +- * SortMergeJoin Inner (63) - : :- * Sort (57) - : : +- Exchange (56) - : : +- * Project (55) - : : +- * BroadcastHashJoin Inner BuildLeft (54) - : : :- ReusedExchange (49) - : : +- * Project (53) - : : +- * Filter (52) - : : +- * ColumnarToRow (51) - : : +- Scan parquet default.store_sales (50) - : +- * Sort (62) - : +- Exchange (61) - : +- * Filter (60) - : +- * ColumnarToRow (59) - : +- Scan parquet default.item (58) - +- * Sort (71) - +- Exchange (70) - +- * Project (69) - +- * Filter (68) - +- * ColumnarToRow (67) - +- Scan parquet default.store_returns (66) +* HashAggregate (76) ++- Exchange (75) + +- * HashAggregate (74) + +- * HashAggregate (73) + +- Exchange (72) + +- * HashAggregate (71) + +- * Project (70) + +- * SortMergeJoin Inner (69) + :- * Sort (66) + : +- Exchange (65) + : +- * Project (64) + : +- * SortMergeJoin Inner (63) + : :- * Sort (57) + : : +- Exchange (56) + : : +- * Project (55) + : : +- * 
BroadcastHashJoin Inner BuildLeft (54) + : : :- ReusedExchange (49) + : : +- * Project (53) + : : +- * Filter (52) + : : +- * ColumnarToRow (51) + : : +- Scan parquet default.store_sales (50) + : +- * Sort (62) + : +- Exchange (61) + : +- * Filter (60) + : +- * ColumnarToRow (59) + : +- Scan parquet default.item (58) + +- * Sort (68) + +- ReusedExchange (67) (49) ReusedExchange [Reuses operator id: 17] @@ -375,79 +372,64 @@ Join condition: None Output [13]: [s_store_name#2, s_state#4, ca_state#8, c_first_name#14, c_last_name#15, ss_item_sk#18, ss_ticket_number#21, ss_net_paid#22, i_current_price#25, i_size#26, i_color#27, i_units#28, i_manager_id#29] Input [14]: [s_store_name#2, s_state#4, ca_state#8, c_first_name#14, c_last_name#15, ss_item_sk#18, ss_ticket_number#21, ss_net_paid#22, i_item_sk#24, i_current_price#25, i_size#26, i_color#27, i_units#28, i_manager_id#29] -(65) Sort [codegen id : 8] +(65) Exchange Input [13]: [s_store_name#2, s_state#4, ca_state#8, c_first_name#14, c_last_name#15, ss_item_sk#18, ss_ticket_number#21, ss_net_paid#22, i_current_price#25, i_size#26, i_color#27, i_units#28, i_manager_id#29] -Arguments: [ss_ticket_number#21 ASC NULLS FIRST, ss_item_sk#18 ASC NULLS FIRST], false, 0 - -(66) Scan parquet default.store_returns -Output [3]: [sr_item_sk#32, sr_ticket_number#33, sr_returned_date_sk#34] -Batched: true -Location [not included in comparison]/{warehouse_dir}/store_returns] -PushedFilters: [IsNotNull(sr_ticket_number), IsNotNull(sr_item_sk)] -ReadSchema: struct - -(67) ColumnarToRow [codegen id : 9] -Input [3]: [sr_item_sk#32, sr_ticket_number#33, sr_returned_date_sk#34] +Arguments: hashpartitioning(ss_ticket_number#21, ss_item_sk#18, 5), ENSURE_REQUIREMENTS, [id=#53] -(68) Filter [codegen id : 9] -Input [3]: [sr_item_sk#32, sr_ticket_number#33, sr_returned_date_sk#34] -Condition : (isnotnull(sr_ticket_number#33) AND isnotnull(sr_item_sk#32)) +(66) Sort [codegen id : 9] +Input [13]: [s_store_name#2, s_state#4, ca_state#8, c_first_name#14, c_last_name#15, ss_item_sk#18, ss_ticket_number#21, ss_net_paid#22, i_current_price#25, i_size#26, i_color#27, i_units#28, i_manager_id#29] +Arguments: [ss_ticket_number#21 ASC NULLS FIRST, ss_item_sk#18 ASC NULLS FIRST], false, 0 -(69) Project [codegen id : 9] +(67) ReusedExchange [Reuses operator id: 36] Output [2]: [sr_item_sk#32, sr_ticket_number#33] -Input [3]: [sr_item_sk#32, sr_ticket_number#33, sr_returned_date_sk#34] -(70) Exchange -Input [2]: [sr_item_sk#32, sr_ticket_number#33] -Arguments: hashpartitioning(sr_item_sk#32, 5), ENSURE_REQUIREMENTS, [id=#53] - -(71) Sort [codegen id : 10] +(68) Sort [codegen id : 11] Input [2]: [sr_item_sk#32, sr_ticket_number#33] Arguments: [sr_ticket_number#33 ASC NULLS FIRST, sr_item_sk#32 ASC NULLS FIRST], false, 0 -(72) SortMergeJoin [codegen id : 11] +(69) SortMergeJoin [codegen id : 12] Left keys [2]: [ss_ticket_number#21, ss_item_sk#18] Right keys [2]: [sr_ticket_number#33, sr_item_sk#32] Join condition: None -(73) Project [codegen id : 11] +(70) Project [codegen id : 12] Output [11]: [ss_net_paid#22, s_store_name#2, s_state#4, i_current_price#25, i_size#26, i_color#27, i_units#28, i_manager_id#29, c_first_name#14, c_last_name#15, ca_state#8] Input [15]: [s_store_name#2, s_state#4, ca_state#8, c_first_name#14, c_last_name#15, ss_item_sk#18, ss_ticket_number#21, ss_net_paid#22, i_current_price#25, i_size#26, i_color#27, i_units#28, i_manager_id#29, sr_item_sk#32, sr_ticket_number#33] -(74) HashAggregate [codegen id : 11] +(71) HashAggregate [codegen id : 12] Input [11]: 
[ss_net_paid#22, s_store_name#2, s_state#4, i_current_price#25, i_size#26, i_color#27, i_units#28, i_manager_id#29, c_first_name#14, c_last_name#15, ca_state#8] Keys [10]: [c_last_name#15, c_first_name#14, s_store_name#2, ca_state#8, s_state#4, i_color#27, i_current_price#25, i_manager_id#29, i_units#28, i_size#26] Functions [1]: [partial_sum(UnscaledValue(ss_net_paid#22))] Aggregate Attributes [1]: [sum#54] Results [11]: [c_last_name#15, c_first_name#14, s_store_name#2, ca_state#8, s_state#4, i_color#27, i_current_price#25, i_manager_id#29, i_units#28, i_size#26, sum#55] -(75) Exchange +(72) Exchange Input [11]: [c_last_name#15, c_first_name#14, s_store_name#2, ca_state#8, s_state#4, i_color#27, i_current_price#25, i_manager_id#29, i_units#28, i_size#26, sum#55] Arguments: hashpartitioning(c_last_name#15, c_first_name#14, s_store_name#2, ca_state#8, s_state#4, i_color#27, i_current_price#25, i_manager_id#29, i_units#28, i_size#26, 5), ENSURE_REQUIREMENTS, [id=#56] -(76) HashAggregate [codegen id : 12] +(73) HashAggregate [codegen id : 13] Input [11]: [c_last_name#15, c_first_name#14, s_store_name#2, ca_state#8, s_state#4, i_color#27, i_current_price#25, i_manager_id#29, i_units#28, i_size#26, sum#55] Keys [10]: [c_last_name#15, c_first_name#14, s_store_name#2, ca_state#8, s_state#4, i_color#27, i_current_price#25, i_manager_id#29, i_units#28, i_size#26] Functions [1]: [sum(UnscaledValue(ss_net_paid#22))] Aggregate Attributes [1]: [sum(UnscaledValue(ss_net_paid#22))#39] Results [1]: [MakeDecimal(sum(UnscaledValue(ss_net_paid#22))#39,17,2) AS netpaid#40] -(77) HashAggregate [codegen id : 12] +(74) HashAggregate [codegen id : 13] Input [1]: [netpaid#40] Keys: [] Functions [1]: [partial_avg(netpaid#40)] Aggregate Attributes [2]: [sum#57, count#58] Results [2]: [sum#59, count#60] -(78) Exchange +(75) Exchange Input [2]: [sum#59, count#60] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#61] -(79) HashAggregate [codegen id : 13] +(76) HashAggregate [codegen id : 14] Input [2]: [sum#59, count#60] Keys: [] Functions [1]: [avg(netpaid#40)] Aggregate Attributes [1]: [avg(netpaid#40)#62] -Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#40)#62)), DecimalType(24,8), true) AS (0.05 * avg(netpaid))#63] +Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#40)#62)), DecimalType(24,8)) AS (0.05 * avg(netpaid))#63] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/simplified.txt index d12b734269651..4beebcbbe52ef 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/simplified.txt @@ -5,60 +5,57 @@ WholeStageCodegen (12) WholeStageCodegen (11) Filter [paid] Subquery #1 - WholeStageCodegen (13) + WholeStageCodegen (14) HashAggregate [sum,count] [avg(netpaid),(0.05 * avg(netpaid)),sum,count] InputAdapter Exchange #10 - WholeStageCodegen (12) + WholeStageCodegen (13) HashAggregate [netpaid] [sum,count,sum,count] HashAggregate [c_last_name,c_first_name,s_store_name,ca_state,s_state,i_color,i_current_price,i_manager_id,i_units,i_size,sum] [sum(UnscaledValue(ss_net_paid)),netpaid,sum] InputAdapter Exchange [c_last_name,c_first_name,s_store_name,ca_state,s_state,i_color,i_current_price,i_manager_id,i_units,i_size] #11 - WholeStageCodegen (11) + WholeStageCodegen (12) HashAggregate 
[c_last_name,c_first_name,s_store_name,ca_state,s_state,i_color,i_current_price,i_manager_id,i_units,i_size,ss_net_paid] [sum,sum] Project [ss_net_paid,s_store_name,s_state,i_current_price,i_size,i_color,i_units,i_manager_id,c_first_name,c_last_name,ca_state] SortMergeJoin [ss_ticket_number,ss_item_sk,sr_ticket_number,sr_item_sk] InputAdapter - WholeStageCodegen (8) + WholeStageCodegen (9) Sort [ss_ticket_number,ss_item_sk] - Project [s_store_name,s_state,ca_state,c_first_name,c_last_name,ss_item_sk,ss_ticket_number,ss_net_paid,i_current_price,i_size,i_color,i_units,i_manager_id] - SortMergeJoin [ss_item_sk,i_item_sk] - InputAdapter - WholeStageCodegen (5) - Sort [ss_item_sk] + InputAdapter + Exchange [ss_ticket_number,ss_item_sk] #12 + WholeStageCodegen (8) + Project [s_store_name,s_state,ca_state,c_first_name,c_last_name,ss_item_sk,ss_ticket_number,ss_net_paid,i_current_price,i_size,i_color,i_units,i_manager_id] + SortMergeJoin [ss_item_sk,i_item_sk] InputAdapter - Exchange [ss_item_sk] #12 - WholeStageCodegen (4) - Project [s_store_name,s_state,ca_state,c_first_name,c_last_name,ss_item_sk,ss_ticket_number,ss_net_paid] - BroadcastHashJoin [s_store_sk,c_customer_sk,ss_store_sk,ss_customer_sk] - InputAdapter - ReusedExchange [s_store_sk,s_store_name,s_state,ca_state,c_customer_sk,c_first_name,c_last_name] #5 - Project [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_net_paid] - Filter [ss_ticket_number,ss_item_sk,ss_store_sk,ss_customer_sk] - ColumnarToRow + WholeStageCodegen (5) + Sort [ss_item_sk] + InputAdapter + Exchange [ss_item_sk] #13 + WholeStageCodegen (4) + Project [s_store_name,s_state,ca_state,c_first_name,c_last_name,ss_item_sk,ss_ticket_number,ss_net_paid] + BroadcastHashJoin [s_store_sk,c_customer_sk,ss_store_sk,ss_customer_sk] InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_net_paid,ss_sold_date_sk] - InputAdapter - WholeStageCodegen (7) - Sort [i_item_sk] + ReusedExchange [s_store_sk,s_store_name,s_state,ca_state,c_customer_sk,c_first_name,c_last_name] #5 + Project [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_net_paid] + Filter [ss_ticket_number,ss_item_sk,ss_store_sk,ss_customer_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_store_sk,ss_ticket_number,ss_net_paid,ss_sold_date_sk] InputAdapter - Exchange [i_item_sk] #13 - WholeStageCodegen (6) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_current_price,i_size,i_color,i_units,i_manager_id] + WholeStageCodegen (7) + Sort [i_item_sk] + InputAdapter + Exchange [i_item_sk] #14 + WholeStageCodegen (6) + Filter [i_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_current_price,i_size,i_color,i_units,i_manager_id] InputAdapter - WholeStageCodegen (10) + WholeStageCodegen (11) Sort [sr_ticket_number,sr_item_sk] InputAdapter - Exchange [sr_item_sk] #14 - WholeStageCodegen (9) - Project [sr_item_sk,sr_ticket_number] - Filter [sr_ticket_number,sr_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_returns [sr_item_sk,sr_ticket_number,sr_returned_date_sk] + ReusedExchange [sr_item_sk,sr_ticket_number] #9 HashAggregate [c_last_name,c_first_name,s_store_name,sum,isEmpty] [sum(netpaid),paid,sum,isEmpty] InputAdapter Exchange [c_last_name,c_first_name,s_store_name] #2 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24/explain.txt 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24/explain.txt index d27e5af04e2dc..ea90187cb53ad 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24/explain.txt @@ -422,6 +422,6 @@ Input [2]: [sum#57, count#58] Keys: [] Functions [1]: [avg(netpaid#40)] Aggregate Attributes [1]: [avg(netpaid#40)#60] -Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#40)#60)), DecimalType(24,8), true) AS (0.05 * avg(netpaid))#61] +Results [1]: [CheckOverflow((0.050000 * promote_precision(avg(netpaid#40)#60)), DecimalType(24,8)) AS (0.05 * avg(netpaid))#61] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q36a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q36a.sf100/explain.txt index 0e20331e83484..9224fbda95e47 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q36a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q36a.sf100/explain.txt @@ -143,7 +143,7 @@ Input [4]: [i_category#13, i_class#12, sum#17, sum#18] Keys [2]: [i_category#13, i_class#12] Functions [2]: [sum(UnscaledValue(ss_net_profit#4)), sum(UnscaledValue(ss_ext_sales_price#3))] Aggregate Attributes [2]: [sum(UnscaledValue(ss_net_profit#4))#20, sum(UnscaledValue(ss_ext_sales_price#3))#21] -Results [6]: [cast(CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2))), DecimalType(37,20), true) as decimal(38,20)) AS gross_margin#22, i_category#13, i_class#12, 0 AS t_category#23, 0 AS t_class#24, 0 AS lochierarchy#25] +Results [6]: [cast(CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2))), DecimalType(37,20)) as decimal(38,20)) AS gross_margin#22, i_category#13, i_class#12, 0 AS t_category#23, 0 AS t_class#24, 0 AS lochierarchy#25] (23) ReusedExchange [Reuses operator id: 21] Output [4]: [i_category#13, i_class#12, sum#26, sum#27] @@ -171,7 +171,7 @@ Input [5]: [i_category#13, sum#36, isEmpty#37, sum#38, isEmpty#39] Keys [1]: [i_category#13] Functions [2]: [sum(ss_net_profit#30), sum(ss_ext_sales_price#31)] Aggregate Attributes [2]: [sum(ss_net_profit#30)#41, sum(ss_ext_sales_price#31)#42] -Results [6]: [cast(CheckOverflow((promote_precision(sum(ss_net_profit#30)#41) / promote_precision(sum(ss_ext_sales_price#31)#42)), DecimalType(38,11), true) as decimal(38,20)) AS gross_margin#43, i_category#13, null AS i_class#44, 0 AS t_category#45, 1 AS t_class#46, 1 AS lochierarchy#47] +Results [6]: [cast(CheckOverflow((promote_precision(sum(ss_net_profit#30)#41) / promote_precision(sum(ss_ext_sales_price#31)#42)), DecimalType(38,11)) as decimal(38,20)) AS gross_margin#43, i_category#13, null AS i_class#44, 0 AS t_category#45, 1 AS t_class#46, 1 AS lochierarchy#47] (28) ReusedExchange [Reuses operator id: 21] Output [4]: [i_category#13, i_class#12, sum#48, sum#49] @@ -199,7 +199,7 @@ Input [4]: [sum#54, isEmpty#55, sum#56, isEmpty#57] Keys: [] Functions [2]: [sum(ss_net_profit#30), sum(ss_ext_sales_price#31)] Aggregate Attributes [2]: [sum(ss_net_profit#30)#59, sum(ss_ext_sales_price#31)#60] -Results [6]: [cast(CheckOverflow((promote_precision(sum(ss_net_profit#30)#59) / 
promote_precision(sum(ss_ext_sales_price#31)#60)), DecimalType(38,11), true) as decimal(38,20)) AS gross_margin#61, null AS i_category#62, null AS i_class#63, 1 AS t_category#64, 1 AS t_class#65, 2 AS lochierarchy#66] +Results [6]: [cast(CheckOverflow((promote_precision(sum(ss_net_profit#30)#59) / promote_precision(sum(ss_ext_sales_price#31)#60)), DecimalType(38,11)) as decimal(38,20)) AS gross_margin#61, null AS i_category#62, null AS i_class#63, 1 AS t_category#64, 1 AS t_class#65, 2 AS lochierarchy#66] (33) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q36a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q36a/explain.txt index 5470bf61ac502..f036e3e8fef42 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q36a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q36a/explain.txt @@ -143,7 +143,7 @@ Input [4]: [i_category#10, i_class#9, sum#17, sum#18] Keys [2]: [i_category#10, i_class#9] Functions [2]: [sum(UnscaledValue(ss_net_profit#4)), sum(UnscaledValue(ss_ext_sales_price#3))] Aggregate Attributes [2]: [sum(UnscaledValue(ss_net_profit#4))#20, sum(UnscaledValue(ss_ext_sales_price#3))#21] -Results [6]: [cast(CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2))), DecimalType(37,20), true) as decimal(38,20)) AS gross_margin#22, i_category#10, i_class#9, 0 AS t_category#23, 0 AS t_class#24, 0 AS lochierarchy#25] +Results [6]: [cast(CheckOverflow((promote_precision(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2)) / promote_precision(MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2))), DecimalType(37,20)) as decimal(38,20)) AS gross_margin#22, i_category#10, i_class#9, 0 AS t_category#23, 0 AS t_class#24, 0 AS lochierarchy#25] (23) ReusedExchange [Reuses operator id: 21] Output [4]: [i_category#10, i_class#9, sum#26, sum#27] @@ -171,7 +171,7 @@ Input [5]: [i_category#10, sum#36, isEmpty#37, sum#38, isEmpty#39] Keys [1]: [i_category#10] Functions [2]: [sum(ss_net_profit#30), sum(ss_ext_sales_price#31)] Aggregate Attributes [2]: [sum(ss_net_profit#30)#41, sum(ss_ext_sales_price#31)#42] -Results [6]: [cast(CheckOverflow((promote_precision(sum(ss_net_profit#30)#41) / promote_precision(sum(ss_ext_sales_price#31)#42)), DecimalType(38,11), true) as decimal(38,20)) AS gross_margin#43, i_category#10, null AS i_class#44, 0 AS t_category#45, 1 AS t_class#46, 1 AS lochierarchy#47] +Results [6]: [cast(CheckOverflow((promote_precision(sum(ss_net_profit#30)#41) / promote_precision(sum(ss_ext_sales_price#31)#42)), DecimalType(38,11)) as decimal(38,20)) AS gross_margin#43, i_category#10, null AS i_class#44, 0 AS t_category#45, 1 AS t_class#46, 1 AS lochierarchy#47] (28) ReusedExchange [Reuses operator id: 21] Output [4]: [i_category#10, i_class#9, sum#48, sum#49] @@ -199,7 +199,7 @@ Input [4]: [sum#54, isEmpty#55, sum#56, isEmpty#57] Keys: [] Functions [2]: [sum(ss_net_profit#30), sum(ss_ext_sales_price#31)] Aggregate Attributes [2]: [sum(ss_net_profit#30)#59, sum(ss_ext_sales_price#31)#60] -Results [6]: [cast(CheckOverflow((promote_precision(sum(ss_net_profit#30)#59) / promote_precision(sum(ss_ext_sales_price#31)#60)), DecimalType(38,11), true) as decimal(38,20)) AS gross_margin#61, null AS i_category#62, null AS i_class#63, 1 AS t_category#64, 1 AS t_class#65, 2 AS lochierarchy#66] +Results [6]: 
[cast(CheckOverflow((promote_precision(sum(ss_net_profit#30)#59) / promote_precision(sum(ss_ext_sales_price#31)#60)), DecimalType(38,11)) as decimal(38,20)) AS gross_margin#61, null AS i_category#62, null AS i_class#63, 1 AS t_category#64, 1 AS t_class#65, 2 AS lochierarchy#66] (33) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47.sf100/explain.txt index 51b2f051403e6..d2a5ecef9c900 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47.sf100/explain.txt @@ -1,53 +1,56 @@ == Physical Plan == -TakeOrderedAndProject (49) -+- * Project (48) - +- * SortMergeJoin Inner (47) - :- * Project (41) - : +- * SortMergeJoin Inner (40) - : :- * Sort (32) - : : +- * Project (31) - : : +- * Filter (30) - : : +- Window (29) - : : +- * Filter (28) - : : +- Window (27) - : : +- * Sort (26) - : : +- Exchange (25) - : : +- * HashAggregate (24) - : : +- Exchange (23) - : : +- * HashAggregate (22) - : : +- * Project (21) - : : +- * SortMergeJoin Inner (20) - : : :- * Sort (14) - : : : +- Exchange (13) - : : : +- * Project (12) - : : : +- * BroadcastHashJoin Inner BuildRight (11) - : : : :- * Project (6) - : : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : : :- * Filter (3) - : : : : : +- * ColumnarToRow (2) - : : : : : +- Scan parquet default.store_sales (1) - : : : : +- ReusedExchange (4) - : : : +- BroadcastExchange (10) - : : : +- * Filter (9) - : : : +- * ColumnarToRow (8) - : : : +- Scan parquet default.store (7) - : : +- * Sort (19) - : : +- Exchange (18) - : : +- * Filter (17) - : : +- * ColumnarToRow (16) - : : +- Scan parquet default.item (15) - : +- * Sort (39) - : +- * Project (38) - : +- Window (37) - : +- * Sort (36) - : +- Exchange (35) - : +- * HashAggregate (34) - : +- ReusedExchange (33) - +- * Sort (46) - +- * Project (45) - +- Window (44) - +- * Sort (43) - +- ReusedExchange (42) +TakeOrderedAndProject (52) ++- * Project (51) + +- * SortMergeJoin Inner (50) + :- * Project (43) + : +- * SortMergeJoin Inner (42) + : :- * Sort (33) + : : +- Exchange (32) + : : +- * Project (31) + : : +- * Filter (30) + : : +- Window (29) + : : +- * Filter (28) + : : +- Window (27) + : : +- * Sort (26) + : : +- Exchange (25) + : : +- * HashAggregate (24) + : : +- Exchange (23) + : : +- * HashAggregate (22) + : : +- * Project (21) + : : +- * SortMergeJoin Inner (20) + : : :- * Sort (14) + : : : +- Exchange (13) + : : : +- * Project (12) + : : : +- * BroadcastHashJoin Inner BuildRight (11) + : : : :- * Project (6) + : : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : : :- * Filter (3) + : : : : : +- * ColumnarToRow (2) + : : : : : +- Scan parquet default.store_sales (1) + : : : : +- ReusedExchange (4) + : : : +- BroadcastExchange (10) + : : : +- * Filter (9) + : : : +- * ColumnarToRow (8) + : : : +- Scan parquet default.store (7) + : : +- * Sort (19) + : : +- Exchange (18) + : : +- * Filter (17) + : : +- * ColumnarToRow (16) + : : +- Scan parquet default.item (15) + : +- * Sort (41) + : +- Exchange (40) + : +- * Project (39) + : +- Window (38) + : +- * Sort (37) + : +- Exchange (36) + : +- * HashAggregate (35) + : +- ReusedExchange (34) + +- * Sort (49) + +- Exchange (48) + +- * Project (47) + +- Window (46) + +- * Sort (45) + +- ReusedExchange (44) (1) Scan parquet default.store_sales @@ -65,7 +68,7 @@ Input [4]: [ss_item_sk#1, 
ss_store_sk#2, ss_sales_price#3, ss_sold_date_sk#4] Input [4]: [ss_item_sk#1, ss_store_sk#2, ss_sales_price#3, ss_sold_date_sk#4] Condition : (isnotnull(ss_item_sk#1) AND isnotnull(ss_store_sk#2)) -(4) ReusedExchange [Reuses operator id: 53] +(4) ReusedExchange [Reuses operator id: 56] Output [3]: [d_date_sk#6, d_year#7, d_moy#8] (5) BroadcastHashJoin [codegen id : 3] @@ -183,112 +186,124 @@ Arguments: [avg(_w0#23) windowspecdefinition(i_category#16, i_brand#15, s_store_ (30) Filter [codegen id : 11] Input [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, _w0#23, rn#25, avg_monthly_sales#26] -Condition : ((isnotnull(avg_monthly_sales#26) AND (avg_monthly_sales#26 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#22 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000)) +Condition : ((isnotnull(avg_monthly_sales#26) AND (avg_monthly_sales#26 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#22 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000)) (31) Project [codegen id : 11] Output [9]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25] Input [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, _w0#23, rn#25, avg_monthly_sales#26] -(32) Sort [codegen id : 11] +(32) Exchange +Input [9]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25] +Arguments: hashpartitioning(i_category#16, i_brand#15, s_store_name#10, s_company_name#11, rn#25, 5), ENSURE_REQUIREMENTS, [id=#27] + +(33) Sort [codegen id : 12] Input [9]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25] Arguments: [i_category#16 ASC NULLS FIRST, i_brand#15 ASC NULLS FIRST, s_store_name#10 ASC NULLS FIRST, s_company_name#11 ASC NULLS FIRST, rn#25 ASC NULLS FIRST], false, 0 -(33) ReusedExchange [Reuses operator id: 23] -Output [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum#33] +(34) ReusedExchange [Reuses operator id: 23] +Output [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum#34] -(34) HashAggregate [codegen id : 19] -Input [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum#33] -Keys [6]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32] -Functions [1]: [sum(UnscaledValue(ss_sales_price#34))] -Aggregate Attributes [1]: [sum(UnscaledValue(ss_sales_price#34))#21] -Results [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, MakeDecimal(sum(UnscaledValue(ss_sales_price#34))#21,17,2) AS sum_sales#22] +(35) HashAggregate [codegen id : 20] +Input [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum#34] +Keys [6]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33] +Functions [1]: 
[sum(UnscaledValue(ss_sales_price#35))] +Aggregate Attributes [1]: [sum(UnscaledValue(ss_sales_price#35))#21] +Results [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, MakeDecimal(sum(UnscaledValue(ss_sales_price#35))#21,17,2) AS sum_sales#22] -(35) Exchange -Input [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum_sales#22] -Arguments: hashpartitioning(i_category#27, i_brand#28, s_store_name#29, s_company_name#30, 5), ENSURE_REQUIREMENTS, [id=#35] +(36) Exchange +Input [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum_sales#22] +Arguments: hashpartitioning(i_category#28, i_brand#29, s_store_name#30, s_company_name#31, 5), ENSURE_REQUIREMENTS, [id=#36] -(36) Sort [codegen id : 20] -Input [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum_sales#22] -Arguments: [i_category#27 ASC NULLS FIRST, i_brand#28 ASC NULLS FIRST, s_store_name#29 ASC NULLS FIRST, s_company_name#30 ASC NULLS FIRST, d_year#31 ASC NULLS FIRST, d_moy#32 ASC NULLS FIRST], false, 0 +(37) Sort [codegen id : 21] +Input [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum_sales#22] +Arguments: [i_category#28 ASC NULLS FIRST, i_brand#29 ASC NULLS FIRST, s_store_name#30 ASC NULLS FIRST, s_company_name#31 ASC NULLS FIRST, d_year#32 ASC NULLS FIRST, d_moy#33 ASC NULLS FIRST], false, 0 -(37) Window -Input [7]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum_sales#22] -Arguments: [rank(d_year#31, d_moy#32) windowspecdefinition(i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31 ASC NULLS FIRST, d_moy#32 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#36], [i_category#27, i_brand#28, s_store_name#29, s_company_name#30], [d_year#31 ASC NULLS FIRST, d_moy#32 ASC NULLS FIRST] +(38) Window +Input [7]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum_sales#22] +Arguments: [rank(d_year#32, d_moy#33) windowspecdefinition(i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32 ASC NULLS FIRST, d_moy#33 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#37], [i_category#28, i_brand#29, s_store_name#30, s_company_name#31], [d_year#32 ASC NULLS FIRST, d_moy#33 ASC NULLS FIRST] -(38) Project [codegen id : 21] -Output [6]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, sum_sales#22 AS sum_sales#37, rn#36] -Input [8]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, d_year#31, d_moy#32, sum_sales#22, rn#36] +(39) Project [codegen id : 22] +Output [6]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, sum_sales#22 AS sum_sales#38, rn#37] +Input [8]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, d_year#32, d_moy#33, sum_sales#22, rn#37] -(39) Sort [codegen id : 21] -Input [6]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, sum_sales#37, rn#36] -Arguments: [i_category#27 ASC NULLS FIRST, i_brand#28 ASC NULLS FIRST, s_store_name#29 ASC NULLS FIRST, s_company_name#30 ASC NULLS FIRST, (rn#36 + 1) ASC NULLS FIRST], false, 0 +(40) Exchange +Input [6]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, sum_sales#38, rn#37] +Arguments: hashpartitioning(i_category#28, i_brand#29, s_store_name#30, s_company_name#31, (rn#37 + 1), 5), ENSURE_REQUIREMENTS, 
[id=#39] -(40) SortMergeJoin [codegen id : 22] +(41) Sort [codegen id : 23] +Input [6]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, sum_sales#38, rn#37] +Arguments: [i_category#28 ASC NULLS FIRST, i_brand#29 ASC NULLS FIRST, s_store_name#30 ASC NULLS FIRST, s_company_name#31 ASC NULLS FIRST, (rn#37 + 1) ASC NULLS FIRST], false, 0 + +(42) SortMergeJoin [codegen id : 24] Left keys [5]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, rn#25] -Right keys [5]: [i_category#27, i_brand#28, s_store_name#29, s_company_name#30, (rn#36 + 1)] +Right keys [5]: [i_category#28, i_brand#29, s_store_name#30, s_company_name#31, (rn#37 + 1)] Join condition: None -(41) Project [codegen id : 22] -Output [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, sum_sales#37] -Input [15]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, i_category#27, i_brand#28, s_store_name#29, s_company_name#30, sum_sales#37, rn#36] +(43) Project [codegen id : 24] +Output [10]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, sum_sales#38] +Input [15]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, i_category#28, i_brand#29, s_store_name#30, s_company_name#31, sum_sales#38, rn#37] + +(44) ReusedExchange [Reuses operator id: 36] +Output [7]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, d_year#44, d_moy#45, sum_sales#22] -(42) ReusedExchange [Reuses operator id: 35] -Output [7]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, d_year#42, d_moy#43, sum_sales#22] +(45) Sort [codegen id : 33] +Input [7]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, d_year#44, d_moy#45, sum_sales#22] +Arguments: [i_category#40 ASC NULLS FIRST, i_brand#41 ASC NULLS FIRST, s_store_name#42 ASC NULLS FIRST, s_company_name#43 ASC NULLS FIRST, d_year#44 ASC NULLS FIRST, d_moy#45 ASC NULLS FIRST], false, 0 -(43) Sort [codegen id : 31] -Input [7]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, d_year#42, d_moy#43, sum_sales#22] -Arguments: [i_category#38 ASC NULLS FIRST, i_brand#39 ASC NULLS FIRST, s_store_name#40 ASC NULLS FIRST, s_company_name#41 ASC NULLS FIRST, d_year#42 ASC NULLS FIRST, d_moy#43 ASC NULLS FIRST], false, 0 +(46) Window +Input [7]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, d_year#44, d_moy#45, sum_sales#22] +Arguments: [rank(d_year#44, d_moy#45) windowspecdefinition(i_category#40, i_brand#41, s_store_name#42, s_company_name#43, d_year#44 ASC NULLS FIRST, d_moy#45 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#46], [i_category#40, i_brand#41, s_store_name#42, s_company_name#43], [d_year#44 ASC NULLS FIRST, d_moy#45 ASC NULLS FIRST] -(44) Window -Input [7]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, d_year#42, d_moy#43, sum_sales#22] -Arguments: [rank(d_year#42, d_moy#43) windowspecdefinition(i_category#38, i_brand#39, s_store_name#40, s_company_name#41, d_year#42 ASC NULLS FIRST, d_moy#43 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#44], [i_category#38, i_brand#39, s_store_name#40, s_company_name#41], [d_year#42 ASC NULLS FIRST, d_moy#43 ASC NULLS FIRST] +(47) Project [codegen id : 34] +Output [6]: 
[i_category#40, i_brand#41, s_store_name#42, s_company_name#43, sum_sales#22 AS sum_sales#47, rn#46] +Input [8]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, d_year#44, d_moy#45, sum_sales#22, rn#46] -(45) Project [codegen id : 32] -Output [6]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, sum_sales#22 AS sum_sales#45, rn#44] -Input [8]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, d_year#42, d_moy#43, sum_sales#22, rn#44] +(48) Exchange +Input [6]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, sum_sales#47, rn#46] +Arguments: hashpartitioning(i_category#40, i_brand#41, s_store_name#42, s_company_name#43, (rn#46 - 1), 5), ENSURE_REQUIREMENTS, [id=#48] -(46) Sort [codegen id : 32] -Input [6]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, sum_sales#45, rn#44] -Arguments: [i_category#38 ASC NULLS FIRST, i_brand#39 ASC NULLS FIRST, s_store_name#40 ASC NULLS FIRST, s_company_name#41 ASC NULLS FIRST, (rn#44 - 1) ASC NULLS FIRST], false, 0 +(49) Sort [codegen id : 35] +Input [6]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, sum_sales#47, rn#46] +Arguments: [i_category#40 ASC NULLS FIRST, i_brand#41 ASC NULLS FIRST, s_store_name#42 ASC NULLS FIRST, s_company_name#43 ASC NULLS FIRST, (rn#46 - 1) ASC NULLS FIRST], false, 0 -(47) SortMergeJoin [codegen id : 33] +(50) SortMergeJoin [codegen id : 36] Left keys [5]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, rn#25] -Right keys [5]: [i_category#38, i_brand#39, s_store_name#40, s_company_name#41, (rn#44 - 1)] +Right keys [5]: [i_category#40, i_brand#41, s_store_name#42, s_company_name#43, (rn#46 - 1)] Join condition: None -(48) Project [codegen id : 33] -Output [7]: [i_category#16, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, sum_sales#37 AS psum#46, sum_sales#45 AS nsum#47] -Input [16]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, sum_sales#37, i_category#38, i_brand#39, s_store_name#40, s_company_name#41, sum_sales#45, rn#44] +(51) Project [codegen id : 36] +Output [7]: [i_category#16, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, sum_sales#38 AS psum#49, sum_sales#47 AS nsum#50] +Input [16]: [i_category#16, i_brand#15, s_store_name#10, s_company_name#11, d_year#7, d_moy#8, sum_sales#22, avg_monthly_sales#26, rn#25, sum_sales#38, i_category#40, i_brand#41, s_store_name#42, s_company_name#43, sum_sales#47, rn#46] -(49) TakeOrderedAndProject -Input [7]: [i_category#16, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, psum#46, nsum#47] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#22 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, d_moy#8 ASC NULLS FIRST], [i_category#16, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, psum#46, nsum#47] +(52) TakeOrderedAndProject +Input [7]: [i_category#16, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, psum#49, nsum#50] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#22 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#26 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, d_moy#8 ASC NULLS FIRST], [i_category#16, d_year#7, d_moy#8, avg_monthly_sales#26, sum_sales#22, psum#49, nsum#50] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (53) -+- * Filter 
(52) - +- * ColumnarToRow (51) - +- Scan parquet default.date_dim (50) +BroadcastExchange (56) ++- * Filter (55) + +- * ColumnarToRow (54) + +- Scan parquet default.date_dim (53) -(50) Scan parquet default.date_dim +(53) Scan parquet default.date_dim Output [3]: [d_date_sk#6, d_year#7, d_moy#8] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [Or(Or(EqualTo(d_year,1999),And(EqualTo(d_year,1998),EqualTo(d_moy,12))),And(EqualTo(d_year,2000),EqualTo(d_moy,1))), IsNotNull(d_date_sk)] ReadSchema: struct -(51) ColumnarToRow [codegen id : 1] +(54) ColumnarToRow [codegen id : 1] Input [3]: [d_date_sk#6, d_year#7, d_moy#8] -(52) Filter [codegen id : 1] +(55) Filter [codegen id : 1] Input [3]: [d_date_sk#6, d_year#7, d_moy#8] Condition : ((((d_year#7 = 1999) OR ((d_year#7 = 1998) AND (d_moy#8 = 12))) OR ((d_year#7 = 2000) AND (d_moy#8 = 1))) AND isnotnull(d_date_sk#6)) -(53) BroadcastExchange +(56) BroadcastExchange Input [3]: [d_date_sk#6, d_year#7, d_moy#8] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#48] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#51] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47.sf100/simplified.txt index 65bcf10a8518b..5f64a22717270 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47.sf100/simplified.txt @@ -1,95 +1,104 @@ TakeOrderedAndProject [sum_sales,avg_monthly_sales,d_moy,i_category,d_year,psum,nsum] - WholeStageCodegen (33) + WholeStageCodegen (36) Project [i_category,d_year,d_moy,avg_monthly_sales,sum_sales,sum_sales,sum_sales] SortMergeJoin [i_category,i_brand,s_store_name,s_company_name,rn,i_category,i_brand,s_store_name,s_company_name,rn] InputAdapter - WholeStageCodegen (22) + WholeStageCodegen (24) Project [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn,sum_sales] SortMergeJoin [i_category,i_brand,s_store_name,s_company_name,rn,i_category,i_brand,s_store_name,s_company_name,rn] InputAdapter - WholeStageCodegen (11) + WholeStageCodegen (12) Sort [i_category,i_brand,s_store_name,s_company_name,rn] - Project [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn] - Filter [avg_monthly_sales,sum_sales] - InputAdapter - Window [_w0,i_category,i_brand,s_store_name,s_company_name,d_year] - WholeStageCodegen (10) - Filter [d_year] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] - WholeStageCodegen (9) - Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] - InputAdapter - Exchange [i_category,i_brand,s_store_name,s_company_name] #1 - WholeStageCodegen (8) - HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] [sum(UnscaledValue(ss_sales_price)),sum_sales,_w0,sum] - InputAdapter - Exchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] #2 - WholeStageCodegen (7) - HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,ss_sales_price] [sum,sum] - Project [i_brand,i_category,ss_sales_price,d_year,d_moy,s_store_name,s_company_name] - SortMergeJoin [ss_item_sk,i_item_sk] - InputAdapter - WholeStageCodegen (4) - Sort [ss_item_sk] + InputAdapter + Exchange 
[i_category,i_brand,s_store_name,s_company_name,rn] #1 + WholeStageCodegen (11) + Project [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn] + Filter [avg_monthly_sales,sum_sales] + InputAdapter + Window [_w0,i_category,i_brand,s_store_name,s_company_name,d_year] + WholeStageCodegen (10) + Filter [d_year] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] + WholeStageCodegen (9) + Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name] #2 + WholeStageCodegen (8) + HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] [sum(UnscaledValue(ss_sales_price)),sum_sales,_w0,sum] + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] #3 + WholeStageCodegen (7) + HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,ss_sales_price] [sum,sum] + Project [i_brand,i_category,ss_sales_price,d_year,d_moy,s_store_name,s_company_name] + SortMergeJoin [ss_item_sk,i_item_sk] InputAdapter - Exchange [ss_item_sk] #3 - WholeStageCodegen (3) - Project [ss_item_sk,ss_sales_price,d_year,d_moy,s_store_name,s_company_name] - BroadcastHashJoin [ss_store_sk,s_store_sk] - Project [ss_item_sk,ss_store_sk,ss_sales_price,d_year,d_moy] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_item_sk,ss_store_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_store_sk,ss_sales_price,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #4 - WholeStageCodegen (1) - Filter [d_year,d_moy,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year,d_moy] - InputAdapter - ReusedExchange [d_date_sk,d_year,d_moy] #4 - InputAdapter - BroadcastExchange #5 - WholeStageCodegen (2) - Filter [s_store_sk,s_store_name,s_company_name] - ColumnarToRow + WholeStageCodegen (4) + Sort [ss_item_sk] + InputAdapter + Exchange [ss_item_sk] #4 + WholeStageCodegen (3) + Project [ss_item_sk,ss_sales_price,d_year,d_moy,s_store_name,s_company_name] + BroadcastHashJoin [ss_store_sk,s_store_sk] + Project [ss_item_sk,ss_store_sk,ss_sales_price,d_year,d_moy] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_item_sk,ss_store_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_store_sk,ss_sales_price,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (1) + Filter [d_year,d_moy,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year,d_moy] InputAdapter - Scan parquet default.store [s_store_sk,s_store_name,s_company_name] - InputAdapter - WholeStageCodegen (6) - Sort [i_item_sk] + ReusedExchange [d_date_sk,d_year,d_moy] #5 + InputAdapter + BroadcastExchange #6 + WholeStageCodegen (2) + Filter [s_store_sk,s_store_name,s_company_name] + ColumnarToRow + InputAdapter + Scan parquet default.store [s_store_sk,s_store_name,s_company_name] InputAdapter - Exchange [i_item_sk] #6 - WholeStageCodegen (5) - Filter [i_item_sk,i_category,i_brand] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand,i_category] + WholeStageCodegen (6) + Sort [i_item_sk] + InputAdapter + Exchange [i_item_sk] #7 + WholeStageCodegen (5) + Filter [i_item_sk,i_category,i_brand] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand,i_category] InputAdapter - WholeStageCodegen (21) + WholeStageCodegen 
(23) Sort [i_category,i_brand,s_store_name,s_company_name,rn] - Project [i_category,i_brand,s_store_name,s_company_name,sum_sales,rn] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] - WholeStageCodegen (20) - Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] - InputAdapter - Exchange [i_category,i_brand,s_store_name,s_company_name] #7 - WholeStageCodegen (19) - HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] [sum(UnscaledValue(ss_sales_price)),sum_sales,sum] - InputAdapter - ReusedExchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] #2 + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name,rn] #8 + WholeStageCodegen (22) + Project [i_category,i_brand,s_store_name,s_company_name,sum_sales,rn] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] + WholeStageCodegen (21) + Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name] #9 + WholeStageCodegen (20) + HashAggregate [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] [sum(UnscaledValue(ss_sales_price)),sum_sales,sum] + InputAdapter + ReusedExchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum] #3 InputAdapter - WholeStageCodegen (32) + WholeStageCodegen (35) Sort [i_category,i_brand,s_store_name,s_company_name,rn] - Project [i_category,i_brand,s_store_name,s_company_name,sum_sales,rn] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] - WholeStageCodegen (31) - Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] - InputAdapter - ReusedExchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum_sales] #7 + InputAdapter + Exchange [i_category,i_brand,s_store_name,s_company_name,rn] #10 + WholeStageCodegen (34) + Project [i_category,i_brand,s_store_name,s_company_name,sum_sales,rn] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,s_store_name,s_company_name] + WholeStageCodegen (33) + Sort [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy] + InputAdapter + ReusedExchange [i_category,i_brand,s_store_name,s_company_name,d_year,d_moy,sum_sales] #9 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47/explain.txt index 21944f91237a0..8abc8fda35cef 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q47/explain.txt @@ -167,7 +167,7 @@ Arguments: [avg(_w0#22) windowspecdefinition(i_category#3, i_brand#2, s_store_na (27) Filter [codegen id : 22] Input [10]: [i_category#3, i_brand#2, s_store_name#14, s_company_name#15, d_year#11, d_moy#12, sum_sales#21, _w0#22, rn#24, avg_monthly_sales#25] -Condition : ((isnotnull(avg_monthly_sales#25) AND (avg_monthly_sales#25 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000)) +Condition : ((isnotnull(avg_monthly_sales#25) AND (avg_monthly_sales#25 > 0.000000)) AND 
(CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000)) (28) Project [codegen id : 22] Output [9]: [i_category#3, i_brand#2, s_store_name#14, s_company_name#15, d_year#11, d_moy#12, sum_sales#21, avg_monthly_sales#25, rn#24] @@ -242,7 +242,7 @@ Input [16]: [i_category#3, i_brand#2, s_store_name#14, s_company_name#15, d_year (45) TakeOrderedAndProject Input [7]: [i_category#3, d_year#11, d_moy#12, avg_monthly_sales#25, sum_sales#21, psum#47, nsum#48] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, d_moy#12 ASC NULLS FIRST], [i_category#3, d_year#11, d_moy#12, avg_monthly_sales#25, sum_sales#21, psum#47, nsum#48] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, d_moy#12 ASC NULLS FIRST], [i_category#3, d_year#11, d_moy#12, avg_monthly_sales#25, sum_sales#21, psum#47, nsum#48] ===== Subqueries ===== diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49.sf100/explain.txt index b1b28f1a20048..5efc0bfaed99e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49.sf100/explain.txt @@ -177,7 +177,7 @@ Input [7]: [ws_item_sk#1, sum#22, sum#23, sum#24, isEmpty#25, sum#26, isEmpty#27 Keys [1]: [ws_item_sk#1] Functions [4]: [sum(coalesce(wr_return_quantity#12, 0)), sum(coalesce(ws_quantity#3, 0)), sum(coalesce(cast(wr_return_amt#13 as decimal(12,2)), 0.00)), sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(wr_return_quantity#12, 0))#29, sum(coalesce(ws_quantity#3, 0))#30, sum(coalesce(cast(wr_return_amt#13 as decimal(12,2)), 0.00))#31, sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#32] -Results [3]: [ws_item_sk#1 AS item#33, CheckOverflow((promote_precision(cast(sum(coalesce(wr_return_quantity#12, 0))#29 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ws_quantity#3, 0))#30 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#34, CheckOverflow((promote_precision(cast(sum(coalesce(cast(wr_return_amt#13 as decimal(12,2)), 0.00))#31 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#32 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#35] +Results [3]: [ws_item_sk#1 AS item#33, CheckOverflow((promote_precision(cast(sum(coalesce(wr_return_quantity#12, 0))#29 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ws_quantity#3, 0))#30 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#34, CheckOverflow((promote_precision(cast(sum(coalesce(cast(wr_return_amt#13 as decimal(12,2)), 0.00))#31 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#32 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#35] (21) Exchange Input [3]: [item#33, return_ratio#34, currency_ratio#35] @@ -297,7 +297,7 @@ Input [7]: [cs_item_sk#40, sum#60, 
sum#61, sum#62, isEmpty#63, sum#64, isEmpty#6 Keys [1]: [cs_item_sk#40] Functions [4]: [sum(coalesce(cr_return_quantity#50, 0)), sum(coalesce(cs_quantity#42, 0)), sum(coalesce(cast(cr_return_amount#51 as decimal(12,2)), 0.00)), sum(coalesce(cast(cs_net_paid#43 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(cr_return_quantity#50, 0))#67, sum(coalesce(cs_quantity#42, 0))#68, sum(coalesce(cast(cr_return_amount#51 as decimal(12,2)), 0.00))#69, sum(coalesce(cast(cs_net_paid#43 as decimal(12,2)), 0.00))#70] -Results [3]: [cs_item_sk#40 AS item#71, CheckOverflow((promote_precision(cast(sum(coalesce(cr_return_quantity#50, 0))#67 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cs_quantity#42, 0))#68 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#72, CheckOverflow((promote_precision(cast(sum(coalesce(cast(cr_return_amount#51 as decimal(12,2)), 0.00))#69 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(cs_net_paid#43 as decimal(12,2)), 0.00))#70 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#73] +Results [3]: [cs_item_sk#40 AS item#71, CheckOverflow((promote_precision(cast(sum(coalesce(cr_return_quantity#50, 0))#67 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cs_quantity#42, 0))#68 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#72, CheckOverflow((promote_precision(cast(sum(coalesce(cast(cr_return_amount#51 as decimal(12,2)), 0.00))#69 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(cs_net_paid#43 as decimal(12,2)), 0.00))#70 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#73] (48) Exchange Input [3]: [item#71, return_ratio#72, currency_ratio#73] @@ -417,7 +417,7 @@ Input [7]: [ss_item_sk#78, sum#98, sum#99, sum#100, isEmpty#101, sum#102, isEmpt Keys [1]: [ss_item_sk#78] Functions [4]: [sum(coalesce(sr_return_quantity#88, 0)), sum(coalesce(ss_quantity#80, 0)), sum(coalesce(cast(sr_return_amt#89 as decimal(12,2)), 0.00)), sum(coalesce(cast(ss_net_paid#81 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(sr_return_quantity#88, 0))#105, sum(coalesce(ss_quantity#80, 0))#106, sum(coalesce(cast(sr_return_amt#89 as decimal(12,2)), 0.00))#107, sum(coalesce(cast(ss_net_paid#81 as decimal(12,2)), 0.00))#108] -Results [3]: [ss_item_sk#78 AS item#109, CheckOverflow((promote_precision(cast(sum(coalesce(sr_return_quantity#88, 0))#105 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ss_quantity#80, 0))#106 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#110, CheckOverflow((promote_precision(cast(sum(coalesce(cast(sr_return_amt#89 as decimal(12,2)), 0.00))#107 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ss_net_paid#81 as decimal(12,2)), 0.00))#108 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#111] +Results [3]: [ss_item_sk#78 AS item#109, CheckOverflow((promote_precision(cast(sum(coalesce(sr_return_quantity#88, 0))#105 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ss_quantity#80, 0))#106 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#110, CheckOverflow((promote_precision(cast(sum(coalesce(cast(sr_return_amt#89 as decimal(12,2)), 0.00))#107 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ss_net_paid#81 as decimal(12,2)), 0.00))#108 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#111] (75) Exchange Input [3]: [item#109, return_ratio#110, currency_ratio#111] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt index 1e11686ade7cc..657a1a1f358c6 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt @@ -156,7 +156,7 @@ Input [7]: [ws_item_sk#1, sum#21, sum#22, sum#23, isEmpty#24, sum#25, isEmpty#26 Keys [1]: [ws_item_sk#1] Functions [4]: [sum(coalesce(wr_return_quantity#11, 0)), sum(coalesce(ws_quantity#3, 0)), sum(coalesce(cast(wr_return_amt#12 as decimal(12,2)), 0.00)), sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(wr_return_quantity#11, 0))#28, sum(coalesce(ws_quantity#3, 0))#29, sum(coalesce(cast(wr_return_amt#12 as decimal(12,2)), 0.00))#30, sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#31] -Results [3]: [ws_item_sk#1 AS item#32, CheckOverflow((promote_precision(cast(sum(coalesce(wr_return_quantity#11, 0))#28 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ws_quantity#3, 0))#29 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#33, CheckOverflow((promote_precision(cast(sum(coalesce(cast(wr_return_amt#12 as decimal(12,2)), 0.00))#30 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#31 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#34] +Results [3]: [ws_item_sk#1 AS item#32, CheckOverflow((promote_precision(cast(sum(coalesce(wr_return_quantity#11, 0))#28 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ws_quantity#3, 0))#29 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#33, CheckOverflow((promote_precision(cast(sum(coalesce(cast(wr_return_amt#12 as decimal(12,2)), 0.00))#30 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ws_net_paid#4 as decimal(12,2)), 0.00))#31 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#34] (18) Exchange Input [3]: [item#32, return_ratio#33, currency_ratio#34] @@ -264,7 +264,7 @@ Input [7]: [cs_item_sk#39, sum#58, sum#59, sum#60, isEmpty#61, sum#62, isEmpty#6 Keys [1]: [cs_item_sk#39] Functions [4]: [sum(coalesce(cr_return_quantity#48, 0)), sum(coalesce(cs_quantity#41, 0)), sum(coalesce(cast(cr_return_amount#49 as decimal(12,2)), 0.00)), sum(coalesce(cast(cs_net_paid#42 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(cr_return_quantity#48, 0))#65, sum(coalesce(cs_quantity#41, 0))#66, sum(coalesce(cast(cr_return_amount#49 as decimal(12,2)), 0.00))#67, sum(coalesce(cast(cs_net_paid#42 as decimal(12,2)), 0.00))#68] -Results [3]: [cs_item_sk#39 AS item#69, CheckOverflow((promote_precision(cast(sum(coalesce(cr_return_quantity#48, 0))#65 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cs_quantity#41, 0))#66 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#70, CheckOverflow((promote_precision(cast(sum(coalesce(cast(cr_return_amount#49 as decimal(12,2)), 0.00))#67 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(cs_net_paid#42 as decimal(12,2)), 0.00))#68 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#71] +Results [3]: [cs_item_sk#39 AS item#69, CheckOverflow((promote_precision(cast(sum(coalesce(cr_return_quantity#48, 0))#65 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cs_quantity#41, 0))#66 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#70, CheckOverflow((promote_precision(cast(sum(coalesce(cast(cr_return_amount#49 as decimal(12,2)), 0.00))#67 as decimal(15,4))) / 
promote_precision(cast(sum(coalesce(cast(cs_net_paid#42 as decimal(12,2)), 0.00))#68 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#71] (42) Exchange Input [3]: [item#69, return_ratio#70, currency_ratio#71] @@ -372,7 +372,7 @@ Input [7]: [ss_item_sk#76, sum#95, sum#96, sum#97, isEmpty#98, sum#99, isEmpty#1 Keys [1]: [ss_item_sk#76] Functions [4]: [sum(coalesce(sr_return_quantity#85, 0)), sum(coalesce(ss_quantity#78, 0)), sum(coalesce(cast(sr_return_amt#86 as decimal(12,2)), 0.00)), sum(coalesce(cast(ss_net_paid#79 as decimal(12,2)), 0.00))] Aggregate Attributes [4]: [sum(coalesce(sr_return_quantity#85, 0))#102, sum(coalesce(ss_quantity#78, 0))#103, sum(coalesce(cast(sr_return_amt#86 as decimal(12,2)), 0.00))#104, sum(coalesce(cast(ss_net_paid#79 as decimal(12,2)), 0.00))#105] -Results [3]: [ss_item_sk#76 AS item#106, CheckOverflow((promote_precision(cast(sum(coalesce(sr_return_quantity#85, 0))#102 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ss_quantity#78, 0))#103 as decimal(15,4)))), DecimalType(35,20), true) AS return_ratio#107, CheckOverflow((promote_precision(cast(sum(coalesce(cast(sr_return_amt#86 as decimal(12,2)), 0.00))#104 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ss_net_paid#79 as decimal(12,2)), 0.00))#105 as decimal(15,4)))), DecimalType(35,20), true) AS currency_ratio#108] +Results [3]: [ss_item_sk#76 AS item#106, CheckOverflow((promote_precision(cast(sum(coalesce(sr_return_quantity#85, 0))#102 as decimal(15,4))) / promote_precision(cast(sum(coalesce(ss_quantity#78, 0))#103 as decimal(15,4)))), DecimalType(35,20)) AS return_ratio#107, CheckOverflow((promote_precision(cast(sum(coalesce(cast(sr_return_amt#86 as decimal(12,2)), 0.00))#104 as decimal(15,4))) / promote_precision(cast(sum(coalesce(cast(ss_net_paid#79 as decimal(12,2)), 0.00))#105 as decimal(15,4)))), DecimalType(35,20)) AS currency_ratio#108] (66) Exchange Input [3]: [item#106, return_ratio#107, currency_ratio#108] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt index e3d76bfea8c2c..64111eef627d2 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt @@ -1,72 +1,74 @@ == Physical Plan == -TakeOrderedAndProject (68) -+- * Filter (67) - +- * HashAggregate (66) - +- * HashAggregate (65) - +- * Project (64) - +- * SortMergeJoin Inner (63) - :- Window (58) - : +- * Sort (57) - : +- Exchange (56) - : +- * Project (55) - : +- * Filter (54) - : +- * SortMergeJoin FullOuter (53) - : :- * Sort (26) - : : +- * HashAggregate (25) - : : +- * HashAggregate (24) - : : +- * Project (23) - : : +- * SortMergeJoin Inner (22) - : : :- * Sort (15) - : : : +- Exchange (14) - : : : +- * Project (13) - : : : +- Window (12) - : : : +- * Sort (11) - : : : +- Exchange (10) - : : : +- * HashAggregate (9) - : : : +- Exchange (8) - : : : +- * HashAggregate (7) - : : : +- * Project (6) - : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : :- * Filter (3) - : : : : +- * ColumnarToRow (2) - : : : : +- Scan parquet default.web_sales (1) - : : : +- ReusedExchange (4) - : : +- * Sort (21) - : : +- Exchange (20) - : : +- * Project (19) - : : +- Window (18) - : : +- * Sort (17) - : : +- ReusedExchange (16) - : +- * Sort (52) - : +- * HashAggregate (51) - : +- * HashAggregate (50) - : +- * Project (49) - : +- * 
SortMergeJoin Inner (48) - : :- * Sort (41) - : : +- Exchange (40) - : : +- * Project (39) - : : +- Window (38) - : : +- * Sort (37) - : : +- Exchange (36) - : : +- * HashAggregate (35) - : : +- Exchange (34) - : : +- * HashAggregate (33) - : : +- * Project (32) - : : +- * BroadcastHashJoin Inner BuildRight (31) - : : :- * Filter (29) - : : : +- * ColumnarToRow (28) - : : : +- Scan parquet default.store_sales (27) - : : +- ReusedExchange (30) - : +- * Sort (47) - : +- Exchange (46) - : +- * Project (45) - : +- Window (44) - : +- * Sort (43) - : +- ReusedExchange (42) - +- * Project (62) - +- Window (61) - +- * Sort (60) - +- ReusedExchange (59) +TakeOrderedAndProject (70) ++- * Filter (69) + +- * HashAggregate (68) + +- * HashAggregate (67) + +- * Project (66) + +- * SortMergeJoin Inner (65) + :- Window (60) + : +- * Sort (59) + : +- Exchange (58) + : +- * Project (57) + : +- * Filter (56) + : +- * SortMergeJoin FullOuter (55) + : :- * Sort (27) + : : +- Exchange (26) + : : +- * HashAggregate (25) + : : +- * HashAggregate (24) + : : +- * Project (23) + : : +- * SortMergeJoin Inner (22) + : : :- * Sort (15) + : : : +- Exchange (14) + : : : +- * Project (13) + : : : +- Window (12) + : : : +- * Sort (11) + : : : +- Exchange (10) + : : : +- * HashAggregate (9) + : : : +- Exchange (8) + : : : +- * HashAggregate (7) + : : : +- * Project (6) + : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : :- * Filter (3) + : : : : +- * ColumnarToRow (2) + : : : : +- Scan parquet default.web_sales (1) + : : : +- ReusedExchange (4) + : : +- * Sort (21) + : : +- Exchange (20) + : : +- * Project (19) + : : +- Window (18) + : : +- * Sort (17) + : : +- ReusedExchange (16) + : +- * Sort (54) + : +- Exchange (53) + : +- * HashAggregate (52) + : +- * HashAggregate (51) + : +- * Project (50) + : +- * SortMergeJoin Inner (49) + : :- * Sort (42) + : : +- Exchange (41) + : : +- * Project (40) + : : +- Window (39) + : : +- * Sort (38) + : : +- Exchange (37) + : : +- * HashAggregate (36) + : : +- Exchange (35) + : : +- * HashAggregate (34) + : : +- * Project (33) + : : +- * BroadcastHashJoin Inner BuildRight (32) + : : :- * Filter (30) + : : : +- * ColumnarToRow (29) + : : : +- Scan parquet default.store_sales (28) + : : +- ReusedExchange (31) + : +- * Sort (48) + : +- Exchange (47) + : +- * Project (46) + : +- Window (45) + : +- * Sort (44) + : +- ReusedExchange (43) + +- * Project (64) + +- Window (63) + +- * Sort (62) + +- ReusedExchange (61) (1) Scan parquet default.web_sales @@ -84,7 +86,7 @@ Input [3]: [ws_item_sk#1, ws_sales_price#2, ws_sold_date_sk#3] Input [3]: [ws_item_sk#1, ws_sales_price#2, ws_sold_date_sk#3] Condition : isnotnull(ws_item_sk#1) -(4) ReusedExchange [Reuses operator id: 73] +(4) ReusedExchange [Reuses operator id: 75] Output [2]: [d_date_sk#5, d_date#6] (5) BroadcastHashJoin [codegen id : 2] @@ -184,232 +186,240 @@ Functions [1]: [sum(sumws#20)] Aggregate Attributes [1]: [sum(sumws#20)#26] Results [3]: [item_sk#11, d_date#6, sum(sumws#20)#26 AS cume_sales#27] -(26) Sort [codegen id : 13] +(26) Exchange +Input [3]: [item_sk#11, d_date#6, cume_sales#27] +Arguments: hashpartitioning(item_sk#11, d_date#6, 5), ENSURE_REQUIREMENTS, [id=#28] + +(27) Sort [codegen id : 14] Input [3]: [item_sk#11, d_date#6, cume_sales#27] Arguments: [item_sk#11 ASC NULLS FIRST, d_date#6 ASC NULLS FIRST], false, 0 -(27) Scan parquet default.store_sales -Output [3]: [ss_item_sk#28, ss_sales_price#29, ss_sold_date_sk#30] +(28) Scan parquet default.store_sales +Output [3]: [ss_item_sk#29, ss_sales_price#30, 
ss_sold_date_sk#31] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#30), dynamicpruningexpression(ss_sold_date_sk#30 IN dynamicpruning#4)] +PartitionFilters: [isnotnull(ss_sold_date_sk#31), dynamicpruningexpression(ss_sold_date_sk#31 IN dynamicpruning#4)] PushedFilters: [IsNotNull(ss_item_sk)] ReadSchema: struct -(28) ColumnarToRow [codegen id : 15] -Input [3]: [ss_item_sk#28, ss_sales_price#29, ss_sold_date_sk#30] +(29) ColumnarToRow [codegen id : 16] +Input [3]: [ss_item_sk#29, ss_sales_price#30, ss_sold_date_sk#31] -(29) Filter [codegen id : 15] -Input [3]: [ss_item_sk#28, ss_sales_price#29, ss_sold_date_sk#30] -Condition : isnotnull(ss_item_sk#28) +(30) Filter [codegen id : 16] +Input [3]: [ss_item_sk#29, ss_sales_price#30, ss_sold_date_sk#31] +Condition : isnotnull(ss_item_sk#29) -(30) ReusedExchange [Reuses operator id: 73] -Output [2]: [d_date_sk#31, d_date#32] +(31) ReusedExchange [Reuses operator id: 75] +Output [2]: [d_date_sk#32, d_date#33] -(31) BroadcastHashJoin [codegen id : 15] -Left keys [1]: [ss_sold_date_sk#30] -Right keys [1]: [d_date_sk#31] +(32) BroadcastHashJoin [codegen id : 16] +Left keys [1]: [ss_sold_date_sk#31] +Right keys [1]: [d_date_sk#32] Join condition: None -(32) Project [codegen id : 15] -Output [3]: [ss_item_sk#28, ss_sales_price#29, d_date#32] -Input [5]: [ss_item_sk#28, ss_sales_price#29, ss_sold_date_sk#30, d_date_sk#31, d_date#32] - -(33) HashAggregate [codegen id : 15] -Input [3]: [ss_item_sk#28, ss_sales_price#29, d_date#32] -Keys [2]: [ss_item_sk#28, d_date#32] -Functions [1]: [partial_sum(UnscaledValue(ss_sales_price#29))] -Aggregate Attributes [1]: [sum#33] -Results [3]: [ss_item_sk#28, d_date#32, sum#34] - -(34) Exchange -Input [3]: [ss_item_sk#28, d_date#32, sum#34] -Arguments: hashpartitioning(ss_item_sk#28, d_date#32, 5), ENSURE_REQUIREMENTS, [id=#35] - -(35) HashAggregate [codegen id : 16] -Input [3]: [ss_item_sk#28, d_date#32, sum#34] -Keys [2]: [ss_item_sk#28, d_date#32] -Functions [1]: [sum(UnscaledValue(ss_sales_price#29))] -Aggregate Attributes [1]: [sum(UnscaledValue(ss_sales_price#29))#36] -Results [4]: [ss_item_sk#28 AS item_sk#37, d_date#32, MakeDecimal(sum(UnscaledValue(ss_sales_price#29))#36,17,2) AS sumss#38, ss_item_sk#28] - -(36) Exchange -Input [4]: [item_sk#37, d_date#32, sumss#38, ss_item_sk#28] -Arguments: hashpartitioning(ss_item_sk#28, 5), ENSURE_REQUIREMENTS, [id=#39] - -(37) Sort [codegen id : 17] -Input [4]: [item_sk#37, d_date#32, sumss#38, ss_item_sk#28] -Arguments: [ss_item_sk#28 ASC NULLS FIRST, d_date#32 ASC NULLS FIRST], false, 0 - -(38) Window -Input [4]: [item_sk#37, d_date#32, sumss#38, ss_item_sk#28] -Arguments: [row_number() windowspecdefinition(ss_item_sk#28, d_date#32 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#40], [ss_item_sk#28], [d_date#32 ASC NULLS FIRST] - -(39) Project [codegen id : 18] -Output [4]: [item_sk#37, d_date#32, sumss#38, rk#40] -Input [5]: [item_sk#37, d_date#32, sumss#38, ss_item_sk#28, rk#40] - -(40) Exchange -Input [4]: [item_sk#37, d_date#32, sumss#38, rk#40] -Arguments: hashpartitioning(item_sk#37, 5), ENSURE_REQUIREMENTS, [id=#41] - -(41) Sort [codegen id : 19] -Input [4]: [item_sk#37, d_date#32, sumss#38, rk#40] -Arguments: [item_sk#37 ASC NULLS FIRST], false, 0 - -(42) ReusedExchange [Reuses operator id: 36] -Output [4]: [item_sk#37, d_date#42, sumss#38, ss_item_sk#43] - -(43) Sort [codegen id : 23] -Input [4]: [item_sk#37, d_date#42, sumss#38, ss_item_sk#43] -Arguments: [ss_item_sk#43 ASC 
NULLS FIRST, d_date#42 ASC NULLS FIRST], false, 0 - -(44) Window -Input [4]: [item_sk#37, d_date#42, sumss#38, ss_item_sk#43] -Arguments: [row_number() windowspecdefinition(ss_item_sk#43, d_date#42 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#44], [ss_item_sk#43], [d_date#42 ASC NULLS FIRST] - -(45) Project [codegen id : 24] -Output [3]: [item_sk#37 AS item_sk#45, sumss#38 AS sumss#46, rk#44] -Input [5]: [item_sk#37, d_date#42, sumss#38, ss_item_sk#43, rk#44] - -(46) Exchange -Input [3]: [item_sk#45, sumss#46, rk#44] -Arguments: hashpartitioning(item_sk#45, 5), ENSURE_REQUIREMENTS, [id=#47] - -(47) Sort [codegen id : 25] -Input [3]: [item_sk#45, sumss#46, rk#44] -Arguments: [item_sk#45 ASC NULLS FIRST], false, 0 - -(48) SortMergeJoin [codegen id : 26] -Left keys [1]: [item_sk#37] -Right keys [1]: [item_sk#45] -Join condition: (rk#40 >= rk#44) - -(49) Project [codegen id : 26] -Output [4]: [item_sk#37, d_date#32, sumss#38, sumss#46] -Input [7]: [item_sk#37, d_date#32, sumss#38, rk#40, item_sk#45, sumss#46, rk#44] - -(50) HashAggregate [codegen id : 26] -Input [4]: [item_sk#37, d_date#32, sumss#38, sumss#46] -Keys [3]: [item_sk#37, d_date#32, sumss#38] -Functions [1]: [partial_sum(sumss#46)] -Aggregate Attributes [2]: [sum#48, isEmpty#49] -Results [5]: [item_sk#37, d_date#32, sumss#38, sum#50, isEmpty#51] - -(51) HashAggregate [codegen id : 26] -Input [5]: [item_sk#37, d_date#32, sumss#38, sum#50, isEmpty#51] -Keys [3]: [item_sk#37, d_date#32, sumss#38] -Functions [1]: [sum(sumss#46)] -Aggregate Attributes [1]: [sum(sumss#46)#52] -Results [3]: [item_sk#37, d_date#32, sum(sumss#46)#52 AS cume_sales#53] - -(52) Sort [codegen id : 26] -Input [3]: [item_sk#37, d_date#32, cume_sales#53] -Arguments: [item_sk#37 ASC NULLS FIRST, d_date#32 ASC NULLS FIRST], false, 0 - -(53) SortMergeJoin [codegen id : 27] +(33) Project [codegen id : 16] +Output [3]: [ss_item_sk#29, ss_sales_price#30, d_date#33] +Input [5]: [ss_item_sk#29, ss_sales_price#30, ss_sold_date_sk#31, d_date_sk#32, d_date#33] + +(34) HashAggregate [codegen id : 16] +Input [3]: [ss_item_sk#29, ss_sales_price#30, d_date#33] +Keys [2]: [ss_item_sk#29, d_date#33] +Functions [1]: [partial_sum(UnscaledValue(ss_sales_price#30))] +Aggregate Attributes [1]: [sum#34] +Results [3]: [ss_item_sk#29, d_date#33, sum#35] + +(35) Exchange +Input [3]: [ss_item_sk#29, d_date#33, sum#35] +Arguments: hashpartitioning(ss_item_sk#29, d_date#33, 5), ENSURE_REQUIREMENTS, [id=#36] + +(36) HashAggregate [codegen id : 17] +Input [3]: [ss_item_sk#29, d_date#33, sum#35] +Keys [2]: [ss_item_sk#29, d_date#33] +Functions [1]: [sum(UnscaledValue(ss_sales_price#30))] +Aggregate Attributes [1]: [sum(UnscaledValue(ss_sales_price#30))#37] +Results [4]: [ss_item_sk#29 AS item_sk#38, d_date#33, MakeDecimal(sum(UnscaledValue(ss_sales_price#30))#37,17,2) AS sumss#39, ss_item_sk#29] + +(37) Exchange +Input [4]: [item_sk#38, d_date#33, sumss#39, ss_item_sk#29] +Arguments: hashpartitioning(ss_item_sk#29, 5), ENSURE_REQUIREMENTS, [id=#40] + +(38) Sort [codegen id : 18] +Input [4]: [item_sk#38, d_date#33, sumss#39, ss_item_sk#29] +Arguments: [ss_item_sk#29 ASC NULLS FIRST, d_date#33 ASC NULLS FIRST], false, 0 + +(39) Window +Input [4]: [item_sk#38, d_date#33, sumss#39, ss_item_sk#29] +Arguments: [row_number() windowspecdefinition(ss_item_sk#29, d_date#33 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#41], [ss_item_sk#29], [d_date#33 ASC NULLS FIRST] + +(40) Project [codegen id : 19] +Output 
[4]: [item_sk#38, d_date#33, sumss#39, rk#41] +Input [5]: [item_sk#38, d_date#33, sumss#39, ss_item_sk#29, rk#41] + +(41) Exchange +Input [4]: [item_sk#38, d_date#33, sumss#39, rk#41] +Arguments: hashpartitioning(item_sk#38, 5), ENSURE_REQUIREMENTS, [id=#42] + +(42) Sort [codegen id : 20] +Input [4]: [item_sk#38, d_date#33, sumss#39, rk#41] +Arguments: [item_sk#38 ASC NULLS FIRST], false, 0 + +(43) ReusedExchange [Reuses operator id: 37] +Output [4]: [item_sk#38, d_date#43, sumss#39, ss_item_sk#44] + +(44) Sort [codegen id : 24] +Input [4]: [item_sk#38, d_date#43, sumss#39, ss_item_sk#44] +Arguments: [ss_item_sk#44 ASC NULLS FIRST, d_date#43 ASC NULLS FIRST], false, 0 + +(45) Window +Input [4]: [item_sk#38, d_date#43, sumss#39, ss_item_sk#44] +Arguments: [row_number() windowspecdefinition(ss_item_sk#44, d_date#43 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#45], [ss_item_sk#44], [d_date#43 ASC NULLS FIRST] + +(46) Project [codegen id : 25] +Output [3]: [item_sk#38 AS item_sk#46, sumss#39 AS sumss#47, rk#45] +Input [5]: [item_sk#38, d_date#43, sumss#39, ss_item_sk#44, rk#45] + +(47) Exchange +Input [3]: [item_sk#46, sumss#47, rk#45] +Arguments: hashpartitioning(item_sk#46, 5), ENSURE_REQUIREMENTS, [id=#48] + +(48) Sort [codegen id : 26] +Input [3]: [item_sk#46, sumss#47, rk#45] +Arguments: [item_sk#46 ASC NULLS FIRST], false, 0 + +(49) SortMergeJoin [codegen id : 27] +Left keys [1]: [item_sk#38] +Right keys [1]: [item_sk#46] +Join condition: (rk#41 >= rk#45) + +(50) Project [codegen id : 27] +Output [4]: [item_sk#38, d_date#33, sumss#39, sumss#47] +Input [7]: [item_sk#38, d_date#33, sumss#39, rk#41, item_sk#46, sumss#47, rk#45] + +(51) HashAggregate [codegen id : 27] +Input [4]: [item_sk#38, d_date#33, sumss#39, sumss#47] +Keys [3]: [item_sk#38, d_date#33, sumss#39] +Functions [1]: [partial_sum(sumss#47)] +Aggregate Attributes [2]: [sum#49, isEmpty#50] +Results [5]: [item_sk#38, d_date#33, sumss#39, sum#51, isEmpty#52] + +(52) HashAggregate [codegen id : 27] +Input [5]: [item_sk#38, d_date#33, sumss#39, sum#51, isEmpty#52] +Keys [3]: [item_sk#38, d_date#33, sumss#39] +Functions [1]: [sum(sumss#47)] +Aggregate Attributes [1]: [sum(sumss#47)#53] +Results [3]: [item_sk#38, d_date#33, sum(sumss#47)#53 AS cume_sales#54] + +(53) Exchange +Input [3]: [item_sk#38, d_date#33, cume_sales#54] +Arguments: hashpartitioning(item_sk#38, d_date#33, 5), ENSURE_REQUIREMENTS, [id=#55] + +(54) Sort [codegen id : 28] +Input [3]: [item_sk#38, d_date#33, cume_sales#54] +Arguments: [item_sk#38 ASC NULLS FIRST, d_date#33 ASC NULLS FIRST], false, 0 + +(55) SortMergeJoin [codegen id : 29] Left keys [2]: [item_sk#11, d_date#6] -Right keys [2]: [item_sk#37, d_date#32] +Right keys [2]: [item_sk#38, d_date#33] Join condition: None -(54) Filter [codegen id : 27] -Input [6]: [item_sk#11, d_date#6, cume_sales#27, item_sk#37, d_date#32, cume_sales#53] -Condition : isnotnull(CASE WHEN isnotnull(item_sk#11) THEN item_sk#11 ELSE item_sk#37 END) +(56) Filter [codegen id : 29] +Input [6]: [item_sk#11, d_date#6, cume_sales#27, item_sk#38, d_date#33, cume_sales#54] +Condition : isnotnull(CASE WHEN isnotnull(item_sk#11) THEN item_sk#11 ELSE item_sk#38 END) -(55) Project [codegen id : 27] -Output [4]: [CASE WHEN isnotnull(item_sk#11) THEN item_sk#11 ELSE item_sk#37 END AS item_sk#54, CASE WHEN isnotnull(d_date#6) THEN d_date#6 ELSE d_date#32 END AS d_date#55, cume_sales#27 AS web_sales#56, cume_sales#53 AS store_sales#57] -Input [6]: [item_sk#11, d_date#6, cume_sales#27, item_sk#37, 
d_date#32, cume_sales#53] +(57) Project [codegen id : 29] +Output [4]: [CASE WHEN isnotnull(item_sk#11) THEN item_sk#11 ELSE item_sk#38 END AS item_sk#56, CASE WHEN isnotnull(d_date#6) THEN d_date#6 ELSE d_date#33 END AS d_date#57, cume_sales#27 AS web_sales#58, cume_sales#54 AS store_sales#59] +Input [6]: [item_sk#11, d_date#6, cume_sales#27, item_sk#38, d_date#33, cume_sales#54] -(56) Exchange -Input [4]: [item_sk#54, d_date#55, web_sales#56, store_sales#57] -Arguments: hashpartitioning(item_sk#54, 5), ENSURE_REQUIREMENTS, [id=#58] +(58) Exchange +Input [4]: [item_sk#56, d_date#57, web_sales#58, store_sales#59] +Arguments: hashpartitioning(item_sk#56, 5), ENSURE_REQUIREMENTS, [id=#60] -(57) Sort [codegen id : 28] -Input [4]: [item_sk#54, d_date#55, web_sales#56, store_sales#57] -Arguments: [item_sk#54 ASC NULLS FIRST, d_date#55 ASC NULLS FIRST], false, 0 +(59) Sort [codegen id : 30] +Input [4]: [item_sk#56, d_date#57, web_sales#58, store_sales#59] +Arguments: [item_sk#56 ASC NULLS FIRST, d_date#57 ASC NULLS FIRST], false, 0 -(58) Window -Input [4]: [item_sk#54, d_date#55, web_sales#56, store_sales#57] -Arguments: [row_number() windowspecdefinition(item_sk#54, d_date#55 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#59], [item_sk#54], [d_date#55 ASC NULLS FIRST] +(60) Window +Input [4]: [item_sk#56, d_date#57, web_sales#58, store_sales#59] +Arguments: [row_number() windowspecdefinition(item_sk#56, d_date#57 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#61], [item_sk#56], [d_date#57 ASC NULLS FIRST] -(59) ReusedExchange [Reuses operator id: 56] -Output [4]: [item_sk#54, d_date#55, web_sales#56, store_sales#57] +(61) ReusedExchange [Reuses operator id: 58] +Output [4]: [item_sk#56, d_date#57, web_sales#58, store_sales#59] -(60) Sort [codegen id : 56] -Input [4]: [item_sk#54, d_date#55, web_sales#56, store_sales#57] -Arguments: [item_sk#54 ASC NULLS FIRST, d_date#55 ASC NULLS FIRST], false, 0 +(62) Sort [codegen id : 60] +Input [4]: [item_sk#56, d_date#57, web_sales#58, store_sales#59] +Arguments: [item_sk#56 ASC NULLS FIRST, d_date#57 ASC NULLS FIRST], false, 0 -(61) Window -Input [4]: [item_sk#54, d_date#55, web_sales#56, store_sales#57] -Arguments: [row_number() windowspecdefinition(item_sk#54, d_date#55 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#60], [item_sk#54], [d_date#55 ASC NULLS FIRST] +(63) Window +Input [4]: [item_sk#56, d_date#57, web_sales#58, store_sales#59] +Arguments: [row_number() windowspecdefinition(item_sk#56, d_date#57 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#62], [item_sk#56], [d_date#57 ASC NULLS FIRST] -(62) Project [codegen id : 57] -Output [4]: [item_sk#54 AS item_sk#61, web_sales#56 AS web_sales#62, store_sales#57 AS store_sales#63, rk#60] -Input [5]: [item_sk#54, d_date#55, web_sales#56, store_sales#57, rk#60] +(64) Project [codegen id : 61] +Output [4]: [item_sk#56 AS item_sk#63, web_sales#58 AS web_sales#64, store_sales#59 AS store_sales#65, rk#62] +Input [5]: [item_sk#56, d_date#57, web_sales#58, store_sales#59, rk#62] -(63) SortMergeJoin [codegen id : 58] -Left keys [1]: [item_sk#54] -Right keys [1]: [item_sk#61] -Join condition: (rk#59 >= rk#60) +(65) SortMergeJoin [codegen id : 62] +Left keys [1]: [item_sk#56] +Right keys [1]: [item_sk#63] +Join condition: (rk#61 >= rk#62) -(64) Project [codegen id : 58] -Output [6]: [item_sk#54, d_date#55, web_sales#56, 
store_sales#57, web_sales#62, store_sales#63] -Input [9]: [item_sk#54, d_date#55, web_sales#56, store_sales#57, rk#59, item_sk#61, web_sales#62, store_sales#63, rk#60] +(66) Project [codegen id : 62] +Output [6]: [item_sk#56, d_date#57, web_sales#58, store_sales#59, web_sales#64, store_sales#65] +Input [9]: [item_sk#56, d_date#57, web_sales#58, store_sales#59, rk#61, item_sk#63, web_sales#64, store_sales#65, rk#62] -(65) HashAggregate [codegen id : 58] -Input [6]: [item_sk#54, d_date#55, web_sales#56, store_sales#57, web_sales#62, store_sales#63] -Keys [4]: [item_sk#54, d_date#55, web_sales#56, store_sales#57] -Functions [2]: [partial_max(web_sales#62), partial_max(store_sales#63)] -Aggregate Attributes [2]: [max#64, max#65] -Results [6]: [item_sk#54, d_date#55, web_sales#56, store_sales#57, max#66, max#67] +(67) HashAggregate [codegen id : 62] +Input [6]: [item_sk#56, d_date#57, web_sales#58, store_sales#59, web_sales#64, store_sales#65] +Keys [4]: [item_sk#56, d_date#57, web_sales#58, store_sales#59] +Functions [2]: [partial_max(web_sales#64), partial_max(store_sales#65)] +Aggregate Attributes [2]: [max#66, max#67] +Results [6]: [item_sk#56, d_date#57, web_sales#58, store_sales#59, max#68, max#69] -(66) HashAggregate [codegen id : 58] -Input [6]: [item_sk#54, d_date#55, web_sales#56, store_sales#57, max#66, max#67] -Keys [4]: [item_sk#54, d_date#55, web_sales#56, store_sales#57] -Functions [2]: [max(web_sales#62), max(store_sales#63)] -Aggregate Attributes [2]: [max(web_sales#62)#68, max(store_sales#63)#69] -Results [6]: [item_sk#54, d_date#55, web_sales#56, store_sales#57, max(web_sales#62)#68 AS web_cumulative#70, max(store_sales#63)#69 AS store_cumulative#71] +(68) HashAggregate [codegen id : 62] +Input [6]: [item_sk#56, d_date#57, web_sales#58, store_sales#59, max#68, max#69] +Keys [4]: [item_sk#56, d_date#57, web_sales#58, store_sales#59] +Functions [2]: [max(web_sales#64), max(store_sales#65)] +Aggregate Attributes [2]: [max(web_sales#64)#70, max(store_sales#65)#71] +Results [6]: [item_sk#56, d_date#57, web_sales#58, store_sales#59, max(web_sales#64)#70 AS web_cumulative#72, max(store_sales#65)#71 AS store_cumulative#73] -(67) Filter [codegen id : 58] -Input [6]: [item_sk#54, d_date#55, web_sales#56, store_sales#57, web_cumulative#70, store_cumulative#71] -Condition : ((isnotnull(web_cumulative#70) AND isnotnull(store_cumulative#71)) AND (web_cumulative#70 > store_cumulative#71)) +(69) Filter [codegen id : 62] +Input [6]: [item_sk#56, d_date#57, web_sales#58, store_sales#59, web_cumulative#72, store_cumulative#73] +Condition : ((isnotnull(web_cumulative#72) AND isnotnull(store_cumulative#73)) AND (web_cumulative#72 > store_cumulative#73)) -(68) TakeOrderedAndProject -Input [6]: [item_sk#54, d_date#55, web_sales#56, store_sales#57, web_cumulative#70, store_cumulative#71] -Arguments: 100, [item_sk#54 ASC NULLS FIRST, d_date#55 ASC NULLS FIRST], [item_sk#54, d_date#55, web_sales#56, store_sales#57, web_cumulative#70, store_cumulative#71] +(70) TakeOrderedAndProject +Input [6]: [item_sk#56, d_date#57, web_sales#58, store_sales#59, web_cumulative#72, store_cumulative#73] +Arguments: 100, [item_sk#56 ASC NULLS FIRST, d_date#57 ASC NULLS FIRST], [item_sk#56, d_date#57, web_sales#58, store_sales#59, web_cumulative#72, store_cumulative#73] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ws_sold_date_sk#3 IN dynamicpruning#4 -BroadcastExchange (73) -+- * Project (72) - +- * Filter (71) - +- * ColumnarToRow (70) - +- Scan parquet default.date_dim (69) 
+BroadcastExchange (75) ++- * Project (74) + +- * Filter (73) + +- * ColumnarToRow (72) + +- Scan parquet default.date_dim (71) -(69) Scan parquet default.date_dim -Output [3]: [d_date_sk#5, d_date#6, d_month_seq#72] +(71) Scan parquet default.date_dim +Output [3]: [d_date_sk#5, d_date#6, d_month_seq#74] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct -(70) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#5, d_date#6, d_month_seq#72] +(72) ColumnarToRow [codegen id : 1] +Input [3]: [d_date_sk#5, d_date#6, d_month_seq#74] -(71) Filter [codegen id : 1] -Input [3]: [d_date_sk#5, d_date#6, d_month_seq#72] -Condition : (((isnotnull(d_month_seq#72) AND (d_month_seq#72 >= 1212)) AND (d_month_seq#72 <= 1223)) AND isnotnull(d_date_sk#5)) +(73) Filter [codegen id : 1] +Input [3]: [d_date_sk#5, d_date#6, d_month_seq#74] +Condition : (((isnotnull(d_month_seq#74) AND (d_month_seq#74 >= 1212)) AND (d_month_seq#74 <= 1223)) AND isnotnull(d_date_sk#5)) -(72) Project [codegen id : 1] +(74) Project [codegen id : 1] Output [2]: [d_date_sk#5, d_date#6] -Input [3]: [d_date_sk#5, d_date#6, d_month_seq#72] +Input [3]: [d_date_sk#5, d_date#6, d_month_seq#74] -(73) BroadcastExchange +(75) BroadcastExchange Input [2]: [d_date_sk#5, d_date#6] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#73] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#75] -Subquery:2 Hosting operator id = 27 Hosting Expression = ss_sold_date_sk#30 IN dynamicpruning#4 +Subquery:2 Hosting operator id = 28 Hosting Expression = ss_sold_date_sk#31 IN dynamicpruning#4 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt index b1d245a9ffc43..1a89b7c72a169 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt @@ -1,5 +1,5 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store_cumulative] - WholeStageCodegen (58) + WholeStageCodegen (62) Filter [web_cumulative,store_cumulative] HashAggregate [item_sk,d_date,web_sales,store_sales,max,max] [max(web_sales),max(store_sales),web_cumulative,store_cumulative,max,max] HashAggregate [item_sk,d_date,web_sales,store_sales,web_sales,store_sales] [max,max,max,max] @@ -7,123 +7,129 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store SortMergeJoin [item_sk,item_sk,rk,rk] InputAdapter Window [item_sk,d_date] - WholeStageCodegen (28) + WholeStageCodegen (30) Sort [item_sk,d_date] InputAdapter Exchange [item_sk] #1 - WholeStageCodegen (27) + WholeStageCodegen (29) Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales] Filter [item_sk,item_sk] SortMergeJoin [item_sk,d_date,item_sk,d_date] InputAdapter - WholeStageCodegen (13) + WholeStageCodegen (14) Sort [item_sk,d_date] - HashAggregate [item_sk,d_date,sumws,sum,isEmpty] [sum(sumws),cume_sales,sum,isEmpty] - HashAggregate [item_sk,d_date,sumws,sumws] [sum,isEmpty,sum,isEmpty] - Project [item_sk,d_date,sumws,sumws] - SortMergeJoin [item_sk,item_sk,rk,rk] - InputAdapter - WholeStageCodegen (6) - Sort [item_sk] + InputAdapter + 
Exchange [item_sk,d_date] #2 + WholeStageCodegen (13) + HashAggregate [item_sk,d_date,sumws,sum,isEmpty] [sum(sumws),cume_sales,sum,isEmpty] + HashAggregate [item_sk,d_date,sumws,sumws] [sum,isEmpty,sum,isEmpty] + Project [item_sk,d_date,sumws,sumws] + SortMergeJoin [item_sk,item_sk,rk,rk] InputAdapter - Exchange [item_sk] #2 - WholeStageCodegen (5) - Project [item_sk,d_date,sumws,rk] - InputAdapter - Window [ws_item_sk,d_date] - WholeStageCodegen (4) - Sort [ws_item_sk,d_date] - InputAdapter - Exchange [ws_item_sk] #3 - WholeStageCodegen (3) - HashAggregate [ws_item_sk,d_date,sum] [sum(UnscaledValue(ws_sales_price)),item_sk,sumws,sum] - InputAdapter - Exchange [ws_item_sk,d_date] #4 - WholeStageCodegen (2) - HashAggregate [ws_item_sk,d_date,ws_sales_price] [sum,sum] - Project [ws_item_sk,ws_sales_price,d_date] - BroadcastHashJoin [ws_sold_date_sk,d_date_sk] - Filter [ws_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.web_sales [ws_item_sk,ws_sales_price,ws_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #5 - WholeStageCodegen (1) - Project [d_date_sk,d_date] - Filter [d_month_seq,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] - InputAdapter - ReusedExchange [d_date_sk,d_date] #5 - InputAdapter - WholeStageCodegen (12) - Sort [item_sk] + WholeStageCodegen (6) + Sort [item_sk] + InputAdapter + Exchange [item_sk] #3 + WholeStageCodegen (5) + Project [item_sk,d_date,sumws,rk] + InputAdapter + Window [ws_item_sk,d_date] + WholeStageCodegen (4) + Sort [ws_item_sk,d_date] + InputAdapter + Exchange [ws_item_sk] #4 + WholeStageCodegen (3) + HashAggregate [ws_item_sk,d_date,sum] [sum(UnscaledValue(ws_sales_price)),item_sk,sumws,sum] + InputAdapter + Exchange [ws_item_sk,d_date] #5 + WholeStageCodegen (2) + HashAggregate [ws_item_sk,d_date,ws_sales_price] [sum,sum] + Project [ws_item_sk,ws_sales_price,d_date] + BroadcastHashJoin [ws_sold_date_sk,d_date_sk] + Filter [ws_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.web_sales [ws_item_sk,ws_sales_price,ws_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #6 + WholeStageCodegen (1) + Project [d_date_sk,d_date] + Filter [d_month_seq,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] + InputAdapter + ReusedExchange [d_date_sk,d_date] #6 InputAdapter - Exchange [item_sk] #6 - WholeStageCodegen (11) - Project [item_sk,sumws,rk] - InputAdapter - Window [ws_item_sk,d_date] - WholeStageCodegen (10) - Sort [ws_item_sk,d_date] - InputAdapter - ReusedExchange [item_sk,d_date,sumws,ws_item_sk] #3 + WholeStageCodegen (12) + Sort [item_sk] + InputAdapter + Exchange [item_sk] #7 + WholeStageCodegen (11) + Project [item_sk,sumws,rk] + InputAdapter + Window [ws_item_sk,d_date] + WholeStageCodegen (10) + Sort [ws_item_sk,d_date] + InputAdapter + ReusedExchange [item_sk,d_date,sumws,ws_item_sk] #4 InputAdapter - WholeStageCodegen (26) + WholeStageCodegen (28) Sort [item_sk,d_date] - HashAggregate [item_sk,d_date,sumss,sum,isEmpty] [sum(sumss),cume_sales,sum,isEmpty] - HashAggregate [item_sk,d_date,sumss,sumss] [sum,isEmpty,sum,isEmpty] - Project [item_sk,d_date,sumss,sumss] - SortMergeJoin [item_sk,item_sk,rk,rk] - InputAdapter - WholeStageCodegen (19) - Sort [item_sk] + InputAdapter + Exchange [item_sk,d_date] #8 + WholeStageCodegen (27) + HashAggregate [item_sk,d_date,sumss,sum,isEmpty] [sum(sumss),cume_sales,sum,isEmpty] + HashAggregate [item_sk,d_date,sumss,sumss] 
[sum,isEmpty,sum,isEmpty] + Project [item_sk,d_date,sumss,sumss] + SortMergeJoin [item_sk,item_sk,rk,rk] InputAdapter - Exchange [item_sk] #7 - WholeStageCodegen (18) - Project [item_sk,d_date,sumss,rk] - InputAdapter - Window [ss_item_sk,d_date] - WholeStageCodegen (17) - Sort [ss_item_sk,d_date] - InputAdapter - Exchange [ss_item_sk] #8 - WholeStageCodegen (16) - HashAggregate [ss_item_sk,d_date,sum] [sum(UnscaledValue(ss_sales_price)),item_sk,sumss,sum] - InputAdapter - Exchange [ss_item_sk,d_date] #9 - WholeStageCodegen (15) - HashAggregate [ss_item_sk,d_date,ss_sales_price] [sum,sum] - Project [ss_item_sk,ss_sales_price,d_date] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Filter [ss_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_sales_price,ss_sold_date_sk] - ReusedSubquery [d_date_sk] #1 - InputAdapter - ReusedExchange [d_date_sk,d_date] #5 - InputAdapter - WholeStageCodegen (25) - Sort [item_sk] + WholeStageCodegen (20) + Sort [item_sk] + InputAdapter + Exchange [item_sk] #9 + WholeStageCodegen (19) + Project [item_sk,d_date,sumss,rk] + InputAdapter + Window [ss_item_sk,d_date] + WholeStageCodegen (18) + Sort [ss_item_sk,d_date] + InputAdapter + Exchange [ss_item_sk] #10 + WholeStageCodegen (17) + HashAggregate [ss_item_sk,d_date,sum] [sum(UnscaledValue(ss_sales_price)),item_sk,sumss,sum] + InputAdapter + Exchange [ss_item_sk,d_date] #11 + WholeStageCodegen (16) + HashAggregate [ss_item_sk,d_date,ss_sales_price] [sum,sum] + Project [ss_item_sk,ss_sales_price,d_date] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Filter [ss_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_sales_price,ss_sold_date_sk] + ReusedSubquery [d_date_sk] #1 + InputAdapter + ReusedExchange [d_date_sk,d_date] #6 InputAdapter - Exchange [item_sk] #10 - WholeStageCodegen (24) - Project [item_sk,sumss,rk] - InputAdapter - Window [ss_item_sk,d_date] - WholeStageCodegen (23) - Sort [ss_item_sk,d_date] - InputAdapter - ReusedExchange [item_sk,d_date,sumss,ss_item_sk] #8 + WholeStageCodegen (26) + Sort [item_sk] + InputAdapter + Exchange [item_sk] #12 + WholeStageCodegen (25) + Project [item_sk,sumss,rk] + InputAdapter + Window [ss_item_sk,d_date] + WholeStageCodegen (24) + Sort [ss_item_sk,d_date] + InputAdapter + ReusedExchange [item_sk,d_date,sumss,ss_item_sk] #10 InputAdapter - WholeStageCodegen (57) + WholeStageCodegen (61) Project [item_sk,web_sales,store_sales,rk] InputAdapter Window [item_sk,d_date] - WholeStageCodegen (56) + WholeStageCodegen (60) Sort [item_sk,d_date] InputAdapter ReusedExchange [item_sk,d_date,web_sales,store_sales] #1 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57.sf100/explain.txt index aa9b899a9308c..d46c1d8c7e336 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57.sf100/explain.txt @@ -1,53 +1,56 @@ == Physical Plan == -TakeOrderedAndProject (49) -+- * Project (48) - +- * SortMergeJoin Inner (47) - :- * Project (41) - : +- * SortMergeJoin Inner (40) - : :- * Sort (32) - : : +- * Project (31) - : : +- * Filter (30) - : : +- Window (29) - : : +- * Filter (28) - : : +- Window (27) - : : +- * Sort (26) - : : +- Exchange (25) - : : +- * HashAggregate (24) - : : +- Exchange (23) - : : +- * HashAggregate (22) - : : +- * Project (21) - : : +- * 
SortMergeJoin Inner (20) - : : :- * Sort (14) - : : : +- Exchange (13) - : : : +- * Project (12) - : : : +- * BroadcastHashJoin Inner BuildRight (11) - : : : :- * Project (6) - : : : : +- * BroadcastHashJoin Inner BuildRight (5) - : : : : :- * Filter (3) - : : : : : +- * ColumnarToRow (2) - : : : : : +- Scan parquet default.catalog_sales (1) - : : : : +- ReusedExchange (4) - : : : +- BroadcastExchange (10) - : : : +- * Filter (9) - : : : +- * ColumnarToRow (8) - : : : +- Scan parquet default.call_center (7) - : : +- * Sort (19) - : : +- Exchange (18) - : : +- * Filter (17) - : : +- * ColumnarToRow (16) - : : +- Scan parquet default.item (15) - : +- * Sort (39) - : +- * Project (38) - : +- Window (37) - : +- * Sort (36) - : +- Exchange (35) - : +- * HashAggregate (34) - : +- ReusedExchange (33) - +- * Sort (46) - +- * Project (45) - +- Window (44) - +- * Sort (43) - +- ReusedExchange (42) +TakeOrderedAndProject (52) ++- * Project (51) + +- * SortMergeJoin Inner (50) + :- * Project (43) + : +- * SortMergeJoin Inner (42) + : :- * Sort (33) + : : +- Exchange (32) + : : +- * Project (31) + : : +- * Filter (30) + : : +- Window (29) + : : +- * Filter (28) + : : +- Window (27) + : : +- * Sort (26) + : : +- Exchange (25) + : : +- * HashAggregate (24) + : : +- Exchange (23) + : : +- * HashAggregate (22) + : : +- * Project (21) + : : +- * SortMergeJoin Inner (20) + : : :- * Sort (14) + : : : +- Exchange (13) + : : : +- * Project (12) + : : : +- * BroadcastHashJoin Inner BuildRight (11) + : : : :- * Project (6) + : : : : +- * BroadcastHashJoin Inner BuildRight (5) + : : : : :- * Filter (3) + : : : : : +- * ColumnarToRow (2) + : : : : : +- Scan parquet default.catalog_sales (1) + : : : : +- ReusedExchange (4) + : : : +- BroadcastExchange (10) + : : : +- * Filter (9) + : : : +- * ColumnarToRow (8) + : : : +- Scan parquet default.call_center (7) + : : +- * Sort (19) + : : +- Exchange (18) + : : +- * Filter (17) + : : +- * ColumnarToRow (16) + : : +- Scan parquet default.item (15) + : +- * Sort (41) + : +- Exchange (40) + : +- * Project (39) + : +- Window (38) + : +- * Sort (37) + : +- Exchange (36) + : +- * HashAggregate (35) + : +- ReusedExchange (34) + +- * Sort (49) + +- Exchange (48) + +- * Project (47) + +- Window (46) + +- * Sort (45) + +- ReusedExchange (44) (1) Scan parquet default.catalog_sales @@ -65,7 +68,7 @@ Input [4]: [cs_call_center_sk#1, cs_item_sk#2, cs_sales_price#3, cs_sold_date_sk Input [4]: [cs_call_center_sk#1, cs_item_sk#2, cs_sales_price#3, cs_sold_date_sk#4] Condition : (isnotnull(cs_item_sk#2) AND isnotnull(cs_call_center_sk#1)) -(4) ReusedExchange [Reuses operator id: 53] +(4) ReusedExchange [Reuses operator id: 56] Output [3]: [d_date_sk#6, d_year#7, d_moy#8] (5) BroadcastHashJoin [codegen id : 3] @@ -183,112 +186,124 @@ Arguments: [avg(_w0#22) windowspecdefinition(i_category#15, i_brand#14, cc_name# (30) Filter [codegen id : 11] Input [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, _w0#22, rn#24, avg_monthly_sales#25] -Condition : ((isnotnull(avg_monthly_sales#25) AND (avg_monthly_sales#25 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000)) +Condition : ((isnotnull(avg_monthly_sales#25) AND (avg_monthly_sales#25 > 0.000000)) AND 
(CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000)) (31) Project [codegen id : 11] Output [8]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24] Input [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, _w0#22, rn#24, avg_monthly_sales#25] -(32) Sort [codegen id : 11] +(32) Exchange +Input [8]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24] +Arguments: hashpartitioning(i_category#15, i_brand#14, cc_name#10, rn#24, 5), ENSURE_REQUIREMENTS, [id=#26] + +(33) Sort [codegen id : 12] Input [8]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24] Arguments: [i_category#15 ASC NULLS FIRST, i_brand#14 ASC NULLS FIRST, cc_name#10 ASC NULLS FIRST, rn#24 ASC NULLS FIRST], false, 0 -(33) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum#31] +(34) ReusedExchange [Reuses operator id: 23] +Output [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum#32] -(34) HashAggregate [codegen id : 19] -Input [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum#31] -Keys [5]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30] -Functions [1]: [sum(UnscaledValue(cs_sales_price#32))] -Aggregate Attributes [1]: [sum(UnscaledValue(cs_sales_price#32))#20] -Results [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, MakeDecimal(sum(UnscaledValue(cs_sales_price#32))#20,17,2) AS sum_sales#21] +(35) HashAggregate [codegen id : 20] +Input [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum#32] +Keys [5]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31] +Functions [1]: [sum(UnscaledValue(cs_sales_price#33))] +Aggregate Attributes [1]: [sum(UnscaledValue(cs_sales_price#33))#20] +Results [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, MakeDecimal(sum(UnscaledValue(cs_sales_price#33))#20,17,2) AS sum_sales#21] -(35) Exchange -Input [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum_sales#21] -Arguments: hashpartitioning(i_category#26, i_brand#27, cc_name#28, 5), ENSURE_REQUIREMENTS, [id=#33] +(36) Exchange +Input [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum_sales#21] +Arguments: hashpartitioning(i_category#27, i_brand#28, cc_name#29, 5), ENSURE_REQUIREMENTS, [id=#34] -(36) Sort [codegen id : 20] -Input [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum_sales#21] -Arguments: [i_category#26 ASC NULLS FIRST, i_brand#27 ASC NULLS FIRST, cc_name#28 ASC NULLS FIRST, d_year#29 ASC NULLS FIRST, d_moy#30 ASC NULLS FIRST], false, 0 +(37) Sort [codegen id : 21] +Input [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum_sales#21] +Arguments: [i_category#27 ASC NULLS FIRST, i_brand#28 ASC NULLS FIRST, cc_name#29 ASC NULLS FIRST, d_year#30 ASC NULLS FIRST, d_moy#31 ASC NULLS FIRST], false, 0 -(37) Window -Input [6]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum_sales#21] -Arguments: [rank(d_year#29, d_moy#30) windowspecdefinition(i_category#26, i_brand#27, cc_name#28, d_year#29 ASC NULLS FIRST, d_moy#30 ASC NULLS FIRST, 
specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#34], [i_category#26, i_brand#27, cc_name#28], [d_year#29 ASC NULLS FIRST, d_moy#30 ASC NULLS FIRST] +(38) Window +Input [6]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum_sales#21] +Arguments: [rank(d_year#30, d_moy#31) windowspecdefinition(i_category#27, i_brand#28, cc_name#29, d_year#30 ASC NULLS FIRST, d_moy#31 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#35], [i_category#27, i_brand#28, cc_name#29], [d_year#30 ASC NULLS FIRST, d_moy#31 ASC NULLS FIRST] -(38) Project [codegen id : 21] -Output [5]: [i_category#26, i_brand#27, cc_name#28, sum_sales#21 AS sum_sales#35, rn#34] -Input [7]: [i_category#26, i_brand#27, cc_name#28, d_year#29, d_moy#30, sum_sales#21, rn#34] +(39) Project [codegen id : 22] +Output [5]: [i_category#27, i_brand#28, cc_name#29, sum_sales#21 AS sum_sales#36, rn#35] +Input [7]: [i_category#27, i_brand#28, cc_name#29, d_year#30, d_moy#31, sum_sales#21, rn#35] -(39) Sort [codegen id : 21] -Input [5]: [i_category#26, i_brand#27, cc_name#28, sum_sales#35, rn#34] -Arguments: [i_category#26 ASC NULLS FIRST, i_brand#27 ASC NULLS FIRST, cc_name#28 ASC NULLS FIRST, (rn#34 + 1) ASC NULLS FIRST], false, 0 +(40) Exchange +Input [5]: [i_category#27, i_brand#28, cc_name#29, sum_sales#36, rn#35] +Arguments: hashpartitioning(i_category#27, i_brand#28, cc_name#29, (rn#35 + 1), 5), ENSURE_REQUIREMENTS, [id=#37] -(40) SortMergeJoin [codegen id : 22] +(41) Sort [codegen id : 23] +Input [5]: [i_category#27, i_brand#28, cc_name#29, sum_sales#36, rn#35] +Arguments: [i_category#27 ASC NULLS FIRST, i_brand#28 ASC NULLS FIRST, cc_name#29 ASC NULLS FIRST, (rn#35 + 1) ASC NULLS FIRST], false, 0 + +(42) SortMergeJoin [codegen id : 24] Left keys [4]: [i_category#15, i_brand#14, cc_name#10, rn#24] -Right keys [4]: [i_category#26, i_brand#27, cc_name#28, (rn#34 + 1)] +Right keys [4]: [i_category#27, i_brand#28, cc_name#29, (rn#35 + 1)] Join condition: None -(41) Project [codegen id : 22] -Output [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, sum_sales#35] -Input [13]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, i_category#26, i_brand#27, cc_name#28, sum_sales#35, rn#34] +(43) Project [codegen id : 24] +Output [9]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, sum_sales#36] +Input [13]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, i_category#27, i_brand#28, cc_name#29, sum_sales#36, rn#35] + +(44) ReusedExchange [Reuses operator id: 36] +Output [6]: [i_category#38, i_brand#39, cc_name#40, d_year#41, d_moy#42, sum_sales#21] -(42) ReusedExchange [Reuses operator id: 35] -Output [6]: [i_category#36, i_brand#37, cc_name#38, d_year#39, d_moy#40, sum_sales#21] +(45) Sort [codegen id : 33] +Input [6]: [i_category#38, i_brand#39, cc_name#40, d_year#41, d_moy#42, sum_sales#21] +Arguments: [i_category#38 ASC NULLS FIRST, i_brand#39 ASC NULLS FIRST, cc_name#40 ASC NULLS FIRST, d_year#41 ASC NULLS FIRST, d_moy#42 ASC NULLS FIRST], false, 0 -(43) Sort [codegen id : 31] -Input [6]: [i_category#36, i_brand#37, cc_name#38, d_year#39, d_moy#40, sum_sales#21] -Arguments: [i_category#36 ASC NULLS FIRST, i_brand#37 ASC NULLS FIRST, cc_name#38 ASC NULLS FIRST, d_year#39 ASC NULLS FIRST, d_moy#40 ASC NULLS FIRST], false, 0 +(46) Window +Input [6]: 
[i_category#38, i_brand#39, cc_name#40, d_year#41, d_moy#42, sum_sales#21] +Arguments: [rank(d_year#41, d_moy#42) windowspecdefinition(i_category#38, i_brand#39, cc_name#40, d_year#41 ASC NULLS FIRST, d_moy#42 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#43], [i_category#38, i_brand#39, cc_name#40], [d_year#41 ASC NULLS FIRST, d_moy#42 ASC NULLS FIRST] -(44) Window -Input [6]: [i_category#36, i_brand#37, cc_name#38, d_year#39, d_moy#40, sum_sales#21] -Arguments: [rank(d_year#39, d_moy#40) windowspecdefinition(i_category#36, i_brand#37, cc_name#38, d_year#39 ASC NULLS FIRST, d_moy#40 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#41], [i_category#36, i_brand#37, cc_name#38], [d_year#39 ASC NULLS FIRST, d_moy#40 ASC NULLS FIRST] +(47) Project [codegen id : 34] +Output [5]: [i_category#38, i_brand#39, cc_name#40, sum_sales#21 AS sum_sales#44, rn#43] +Input [7]: [i_category#38, i_brand#39, cc_name#40, d_year#41, d_moy#42, sum_sales#21, rn#43] -(45) Project [codegen id : 32] -Output [5]: [i_category#36, i_brand#37, cc_name#38, sum_sales#21 AS sum_sales#42, rn#41] -Input [7]: [i_category#36, i_brand#37, cc_name#38, d_year#39, d_moy#40, sum_sales#21, rn#41] +(48) Exchange +Input [5]: [i_category#38, i_brand#39, cc_name#40, sum_sales#44, rn#43] +Arguments: hashpartitioning(i_category#38, i_brand#39, cc_name#40, (rn#43 - 1), 5), ENSURE_REQUIREMENTS, [id=#45] -(46) Sort [codegen id : 32] -Input [5]: [i_category#36, i_brand#37, cc_name#38, sum_sales#42, rn#41] -Arguments: [i_category#36 ASC NULLS FIRST, i_brand#37 ASC NULLS FIRST, cc_name#38 ASC NULLS FIRST, (rn#41 - 1) ASC NULLS FIRST], false, 0 +(49) Sort [codegen id : 35] +Input [5]: [i_category#38, i_brand#39, cc_name#40, sum_sales#44, rn#43] +Arguments: [i_category#38 ASC NULLS FIRST, i_brand#39 ASC NULLS FIRST, cc_name#40 ASC NULLS FIRST, (rn#43 - 1) ASC NULLS FIRST], false, 0 -(47) SortMergeJoin [codegen id : 33] +(50) SortMergeJoin [codegen id : 36] Left keys [4]: [i_category#15, i_brand#14, cc_name#10, rn#24] -Right keys [4]: [i_category#36, i_brand#37, cc_name#38, (rn#41 - 1)] +Right keys [4]: [i_category#38, i_brand#39, cc_name#40, (rn#43 - 1)] Join condition: None -(48) Project [codegen id : 33] -Output [8]: [i_category#15, i_brand#14, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, sum_sales#35 AS psum#43, sum_sales#42 AS nsum#44] -Input [14]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, sum_sales#35, i_category#36, i_brand#37, cc_name#38, sum_sales#42, rn#41] +(51) Project [codegen id : 36] +Output [8]: [i_category#15, i_brand#14, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, sum_sales#36 AS psum#46, sum_sales#44 AS nsum#47] +Input [14]: [i_category#15, i_brand#14, cc_name#10, d_year#7, d_moy#8, sum_sales#21, avg_monthly_sales#25, rn#24, sum_sales#36, i_category#38, i_brand#39, cc_name#40, sum_sales#44, rn#43] -(49) TakeOrderedAndProject -Input [8]: [i_category#15, i_brand#14, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, psum#43, nsum#44] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, d_year#7 ASC NULLS FIRST], [i_category#15, i_brand#14, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, psum#43, nsum#44] +(52) TakeOrderedAndProject +Input [8]: [i_category#15, i_brand#14, d_year#7, d_moy#8, 
avg_monthly_sales#25, sum_sales#21, psum#46, nsum#47] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#21 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#25 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, d_year#7 ASC NULLS FIRST], [i_category#15, i_brand#14, d_year#7, d_moy#8, avg_monthly_sales#25, sum_sales#21, psum#46, nsum#47] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = cs_sold_date_sk#4 IN dynamicpruning#5 -BroadcastExchange (53) -+- * Filter (52) - +- * ColumnarToRow (51) - +- Scan parquet default.date_dim (50) +BroadcastExchange (56) ++- * Filter (55) + +- * ColumnarToRow (54) + +- Scan parquet default.date_dim (53) -(50) Scan parquet default.date_dim +(53) Scan parquet default.date_dim Output [3]: [d_date_sk#6, d_year#7, d_moy#8] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [Or(Or(EqualTo(d_year,1999),And(EqualTo(d_year,1998),EqualTo(d_moy,12))),And(EqualTo(d_year,2000),EqualTo(d_moy,1))), IsNotNull(d_date_sk)] ReadSchema: struct -(51) ColumnarToRow [codegen id : 1] +(54) ColumnarToRow [codegen id : 1] Input [3]: [d_date_sk#6, d_year#7, d_moy#8] -(52) Filter [codegen id : 1] +(55) Filter [codegen id : 1] Input [3]: [d_date_sk#6, d_year#7, d_moy#8] Condition : ((((d_year#7 = 1999) OR ((d_year#7 = 1998) AND (d_moy#8 = 12))) OR ((d_year#7 = 2000) AND (d_moy#8 = 1))) AND isnotnull(d_date_sk#6)) -(53) BroadcastExchange +(56) BroadcastExchange Input [3]: [d_date_sk#6, d_year#7, d_moy#8] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#45] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#48] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57.sf100/simplified.txt index 4389f6035a41b..b464f558bbc1a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57.sf100/simplified.txt @@ -1,95 +1,104 @@ TakeOrderedAndProject [sum_sales,avg_monthly_sales,d_year,i_category,i_brand,d_moy,psum,nsum] - WholeStageCodegen (33) + WholeStageCodegen (36) Project [i_category,i_brand,d_year,d_moy,avg_monthly_sales,sum_sales,sum_sales,sum_sales] SortMergeJoin [i_category,i_brand,cc_name,rn,i_category,i_brand,cc_name,rn] InputAdapter - WholeStageCodegen (22) + WholeStageCodegen (24) Project [i_category,i_brand,cc_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn,sum_sales] SortMergeJoin [i_category,i_brand,cc_name,rn,i_category,i_brand,cc_name,rn] InputAdapter - WholeStageCodegen (11) + WholeStageCodegen (12) Sort [i_category,i_brand,cc_name,rn] - Project [i_category,i_brand,cc_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn] - Filter [avg_monthly_sales,sum_sales] - InputAdapter - Window [_w0,i_category,i_brand,cc_name,d_year] - WholeStageCodegen (10) - Filter [d_year] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,cc_name] - WholeStageCodegen (9) - Sort [i_category,i_brand,cc_name,d_year,d_moy] - InputAdapter - Exchange [i_category,i_brand,cc_name] #1 - WholeStageCodegen (8) - HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,sum] [sum(UnscaledValue(cs_sales_price)),sum_sales,_w0,sum] - InputAdapter - Exchange [i_category,i_brand,cc_name,d_year,d_moy] #2 - WholeStageCodegen (7) - HashAggregate 
[i_category,i_brand,cc_name,d_year,d_moy,cs_sales_price] [sum,sum] - Project [i_brand,i_category,cs_sales_price,d_year,d_moy,cc_name] - SortMergeJoin [cs_item_sk,i_item_sk] - InputAdapter - WholeStageCodegen (4) - Sort [cs_item_sk] + InputAdapter + Exchange [i_category,i_brand,cc_name,rn] #1 + WholeStageCodegen (11) + Project [i_category,i_brand,cc_name,d_year,d_moy,sum_sales,avg_monthly_sales,rn] + Filter [avg_monthly_sales,sum_sales] + InputAdapter + Window [_w0,i_category,i_brand,cc_name,d_year] + WholeStageCodegen (10) + Filter [d_year] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,cc_name] + WholeStageCodegen (9) + Sort [i_category,i_brand,cc_name,d_year,d_moy] + InputAdapter + Exchange [i_category,i_brand,cc_name] #2 + WholeStageCodegen (8) + HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,sum] [sum(UnscaledValue(cs_sales_price)),sum_sales,_w0,sum] + InputAdapter + Exchange [i_category,i_brand,cc_name,d_year,d_moy] #3 + WholeStageCodegen (7) + HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,cs_sales_price] [sum,sum] + Project [i_brand,i_category,cs_sales_price,d_year,d_moy,cc_name] + SortMergeJoin [cs_item_sk,i_item_sk] InputAdapter - Exchange [cs_item_sk] #3 - WholeStageCodegen (3) - Project [cs_item_sk,cs_sales_price,d_year,d_moy,cc_name] - BroadcastHashJoin [cs_call_center_sk,cc_call_center_sk] - Project [cs_call_center_sk,cs_item_sk,cs_sales_price,d_year,d_moy] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk] - Filter [cs_item_sk,cs_call_center_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_call_center_sk,cs_item_sk,cs_sales_price,cs_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #4 - WholeStageCodegen (1) - Filter [d_year,d_moy,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year,d_moy] - InputAdapter - ReusedExchange [d_date_sk,d_year,d_moy] #4 - InputAdapter - BroadcastExchange #5 - WholeStageCodegen (2) - Filter [cc_call_center_sk,cc_name] - ColumnarToRow + WholeStageCodegen (4) + Sort [cs_item_sk] + InputAdapter + Exchange [cs_item_sk] #4 + WholeStageCodegen (3) + Project [cs_item_sk,cs_sales_price,d_year,d_moy,cc_name] + BroadcastHashJoin [cs_call_center_sk,cc_call_center_sk] + Project [cs_call_center_sk,cs_item_sk,cs_sales_price,d_year,d_moy] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk] + Filter [cs_item_sk,cs_call_center_sk] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_call_center_sk,cs_item_sk,cs_sales_price,cs_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (1) + Filter [d_year,d_moy,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year,d_moy] InputAdapter - Scan parquet default.call_center [cc_call_center_sk,cc_name] - InputAdapter - WholeStageCodegen (6) - Sort [i_item_sk] + ReusedExchange [d_date_sk,d_year,d_moy] #5 + InputAdapter + BroadcastExchange #6 + WholeStageCodegen (2) + Filter [cc_call_center_sk,cc_name] + ColumnarToRow + InputAdapter + Scan parquet default.call_center [cc_call_center_sk,cc_name] InputAdapter - Exchange [i_item_sk] #6 - WholeStageCodegen (5) - Filter [i_item_sk,i_category,i_brand] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_brand,i_category] + WholeStageCodegen (6) + Sort [i_item_sk] + InputAdapter + Exchange [i_item_sk] #7 + WholeStageCodegen (5) + Filter [i_item_sk,i_category,i_brand] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_brand,i_category] InputAdapter - 
WholeStageCodegen (21) + WholeStageCodegen (23) Sort [i_category,i_brand,cc_name,rn] - Project [i_category,i_brand,cc_name,sum_sales,rn] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,cc_name] - WholeStageCodegen (20) - Sort [i_category,i_brand,cc_name,d_year,d_moy] - InputAdapter - Exchange [i_category,i_brand,cc_name] #7 - WholeStageCodegen (19) - HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,sum] [sum(UnscaledValue(cs_sales_price)),sum_sales,sum] - InputAdapter - ReusedExchange [i_category,i_brand,cc_name,d_year,d_moy,sum] #2 + InputAdapter + Exchange [i_category,i_brand,cc_name,rn] #8 + WholeStageCodegen (22) + Project [i_category,i_brand,cc_name,sum_sales,rn] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,cc_name] + WholeStageCodegen (21) + Sort [i_category,i_brand,cc_name,d_year,d_moy] + InputAdapter + Exchange [i_category,i_brand,cc_name] #9 + WholeStageCodegen (20) + HashAggregate [i_category,i_brand,cc_name,d_year,d_moy,sum] [sum(UnscaledValue(cs_sales_price)),sum_sales,sum] + InputAdapter + ReusedExchange [i_category,i_brand,cc_name,d_year,d_moy,sum] #3 InputAdapter - WholeStageCodegen (32) + WholeStageCodegen (35) Sort [i_category,i_brand,cc_name,rn] - Project [i_category,i_brand,cc_name,sum_sales,rn] - InputAdapter - Window [d_year,d_moy,i_category,i_brand,cc_name] - WholeStageCodegen (31) - Sort [i_category,i_brand,cc_name,d_year,d_moy] - InputAdapter - ReusedExchange [i_category,i_brand,cc_name,d_year,d_moy,sum_sales] #7 + InputAdapter + Exchange [i_category,i_brand,cc_name,rn] #10 + WholeStageCodegen (34) + Project [i_category,i_brand,cc_name,sum_sales,rn] + InputAdapter + Window [d_year,d_moy,i_category,i_brand,cc_name] + WholeStageCodegen (33) + Sort [i_category,i_brand,cc_name,d_year,d_moy] + InputAdapter + ReusedExchange [i_category,i_brand,cc_name,d_year,d_moy,sum_sales] #9 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57/explain.txt index 65a811671c32d..675acedcd9cad 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q57/explain.txt @@ -167,7 +167,7 @@ Arguments: [avg(_w0#21) windowspecdefinition(i_category#3, i_brand#2, cc_name#14 (27) Filter [codegen id : 22] Input [9]: [i_category#3, i_brand#2, cc_name#14, d_year#11, d_moy#12, sum_sales#20, _w0#21, rn#23, avg_monthly_sales#24] -Condition : ((isnotnull(avg_monthly_sales#24) AND (avg_monthly_sales#24 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#20 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true), false)) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16), true) > 0.1000000000000000)) +Condition : ((isnotnull(avg_monthly_sales#24) AND (avg_monthly_sales#24 > 0.000000)) AND (CheckOverflow((promote_precision(abs(CheckOverflow((promote_precision(cast(sum_sales#20 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)))) / promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(38,16)) > 0.1000000000000000)) (28) Project [codegen id : 22] Output [8]: [i_category#3, i_brand#2, cc_name#14, d_year#11, d_moy#12, sum_sales#20, avg_monthly_sales#24, rn#23] @@ -242,7 +242,7 @@ Input [14]: [i_category#3, i_brand#2, cc_name#14, d_year#11, 
d_moy#12, sum_sales (45) TakeOrderedAndProject Input [8]: [i_category#3, i_brand#2, d_year#11, d_moy#12, avg_monthly_sales#24, sum_sales#20, psum#44, nsum#45] -Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#20 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6), true) ASC NULLS FIRST, d_year#11 ASC NULLS FIRST], [i_category#3, i_brand#2, d_year#11, d_moy#12, avg_monthly_sales#24, sum_sales#20, psum#44, nsum#45] +Arguments: 100, [CheckOverflow((promote_precision(cast(sum_sales#20 as decimal(22,6))) - promote_precision(cast(avg_monthly_sales#24 as decimal(22,6)))), DecimalType(22,6)) ASC NULLS FIRST, d_year#11 ASC NULLS FIRST], [i_category#3, i_brand#2, d_year#11, d_moy#12, avg_monthly_sales#24, sum_sales#20, psum#44, nsum#45] ===== Subqueries ===== diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a.sf100/explain.txt index b6a5a36a10c6c..88d3ec5d20f2b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a.sf100/explain.txt @@ -186,7 +186,7 @@ Input [5]: [s_store_id#23, sum#30, sum#31, sum#32, sum#33] Keys [1]: [s_store_id#23] Functions [4]: [sum(UnscaledValue(sales_price#8)), sum(UnscaledValue(return_amt#10)), sum(UnscaledValue(profit#9)), sum(UnscaledValue(net_loss#11))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#8))#35, sum(UnscaledValue(return_amt#10))#36, sum(UnscaledValue(profit#9))#37, sum(UnscaledValue(net_loss#11))#38] -Results [5]: [store channel AS channel#39, concat(store, s_store_id#23) AS id#40, MakeDecimal(sum(UnscaledValue(sales_price#8))#35,17,2) AS sales#41, MakeDecimal(sum(UnscaledValue(return_amt#10))#36,17,2) AS returns#42, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#9))#37,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#11))#38,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#43] +Results [5]: [store channel AS channel#39, concat(store, s_store_id#23) AS id#40, MakeDecimal(sum(UnscaledValue(sales_price#8))#35,17,2) AS sales#41, MakeDecimal(sum(UnscaledValue(return_amt#10))#36,17,2) AS returns#42, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#9))#37,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#11))#38,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#43] (22) Scan parquet default.catalog_sales Output [4]: [cs_catalog_page_sk#44, cs_ext_sales_price#45, cs_net_profit#46, cs_sold_date_sk#47] @@ -283,7 +283,7 @@ Input [5]: [cp_catalog_page_id#65, sum#72, sum#73, sum#74, sum#75] Keys [1]: [cp_catalog_page_id#65] Functions [4]: [sum(UnscaledValue(sales_price#50)), sum(UnscaledValue(return_amt#52)), sum(UnscaledValue(profit#51)), sum(UnscaledValue(net_loss#53))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#50))#77, sum(UnscaledValue(return_amt#52))#78, sum(UnscaledValue(profit#51))#79, sum(UnscaledValue(net_loss#53))#80] -Results [5]: [catalog channel AS channel#81, concat(catalog_page, cp_catalog_page_id#65) AS id#82, MakeDecimal(sum(UnscaledValue(sales_price#50))#77,17,2) AS sales#83, MakeDecimal(sum(UnscaledValue(return_amt#52))#78,17,2) AS returns#84, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#51))#79,17,2) as decimal(18,2))) - 
promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#53))#80,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#85] +Results [5]: [catalog channel AS channel#81, concat(catalog_page, cp_catalog_page_id#65) AS id#82, MakeDecimal(sum(UnscaledValue(sales_price#50))#77,17,2) AS sales#83, MakeDecimal(sum(UnscaledValue(return_amt#52))#78,17,2) AS returns#84, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#51))#79,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#53))#80,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#85] (43) Scan parquet default.web_sales Output [4]: [ws_web_site_sk#86, ws_ext_sales_price#87, ws_net_profit#88, ws_sold_date_sk#89] @@ -414,7 +414,7 @@ Input [5]: [web_site_id#114, sum#121, sum#122, sum#123, sum#124] Keys [1]: [web_site_id#114] Functions [4]: [sum(UnscaledValue(sales_price#92)), sum(UnscaledValue(return_amt#94)), sum(UnscaledValue(profit#93)), sum(UnscaledValue(net_loss#95))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#92))#126, sum(UnscaledValue(return_amt#94))#127, sum(UnscaledValue(profit#93))#128, sum(UnscaledValue(net_loss#95))#129] -Results [5]: [web channel AS channel#130, concat(web_site, web_site_id#114) AS id#131, MakeDecimal(sum(UnscaledValue(sales_price#92))#126,17,2) AS sales#132, MakeDecimal(sum(UnscaledValue(return_amt#94))#127,17,2) AS returns#133, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#93))#128,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#95))#129,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#134] +Results [5]: [web channel AS channel#130, concat(web_site, web_site_id#114) AS id#131, MakeDecimal(sum(UnscaledValue(sales_price#92))#126,17,2) AS sales#132, MakeDecimal(sum(UnscaledValue(return_amt#94))#127,17,2) AS returns#133, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#93))#128,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#95))#129,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#134] (72) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt index 05636f5f44067..cadbb12000ba3 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt @@ -183,7 +183,7 @@ Input [5]: [s_store_id#24, sum#30, sum#31, sum#32, sum#33] Keys [1]: [s_store_id#24] Functions [4]: [sum(UnscaledValue(sales_price#8)), sum(UnscaledValue(return_amt#10)), sum(UnscaledValue(profit#9)), sum(UnscaledValue(net_loss#11))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#8))#35, sum(UnscaledValue(return_amt#10))#36, sum(UnscaledValue(profit#9))#37, sum(UnscaledValue(net_loss#11))#38] -Results [5]: [store channel AS channel#39, concat(store, s_store_id#24) AS id#40, MakeDecimal(sum(UnscaledValue(sales_price#8))#35,17,2) AS sales#41, MakeDecimal(sum(UnscaledValue(return_amt#10))#36,17,2) AS returns#42, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#9))#37,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#11))#38,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#43] +Results [5]: [store channel AS channel#39, concat(store, s_store_id#24) AS id#40, 
MakeDecimal(sum(UnscaledValue(sales_price#8))#35,17,2) AS sales#41, MakeDecimal(sum(UnscaledValue(return_amt#10))#36,17,2) AS returns#42, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#9))#37,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#11))#38,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#43] (22) Scan parquet default.catalog_sales Output [4]: [cs_catalog_page_sk#44, cs_ext_sales_price#45, cs_net_profit#46, cs_sold_date_sk#47] @@ -280,7 +280,7 @@ Input [5]: [cp_catalog_page_id#66, sum#72, sum#73, sum#74, sum#75] Keys [1]: [cp_catalog_page_id#66] Functions [4]: [sum(UnscaledValue(sales_price#50)), sum(UnscaledValue(return_amt#52)), sum(UnscaledValue(profit#51)), sum(UnscaledValue(net_loss#53))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#50))#77, sum(UnscaledValue(return_amt#52))#78, sum(UnscaledValue(profit#51))#79, sum(UnscaledValue(net_loss#53))#80] -Results [5]: [catalog channel AS channel#81, concat(catalog_page, cp_catalog_page_id#66) AS id#82, MakeDecimal(sum(UnscaledValue(sales_price#50))#77,17,2) AS sales#83, MakeDecimal(sum(UnscaledValue(return_amt#52))#78,17,2) AS returns#84, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#51))#79,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#53))#80,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#85] +Results [5]: [catalog channel AS channel#81, concat(catalog_page, cp_catalog_page_id#66) AS id#82, MakeDecimal(sum(UnscaledValue(sales_price#50))#77,17,2) AS sales#83, MakeDecimal(sum(UnscaledValue(return_amt#52))#78,17,2) AS returns#84, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#51))#79,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#53))#80,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#85] (43) Scan parquet default.web_sales Output [4]: [ws_web_site_sk#86, ws_ext_sales_price#87, ws_net_profit#88, ws_sold_date_sk#89] @@ -399,7 +399,7 @@ Input [5]: [web_site_id#114, sum#120, sum#121, sum#122, sum#123] Keys [1]: [web_site_id#114] Functions [4]: [sum(UnscaledValue(sales_price#92)), sum(UnscaledValue(return_amt#94)), sum(UnscaledValue(profit#93)), sum(UnscaledValue(net_loss#95))] Aggregate Attributes [4]: [sum(UnscaledValue(sales_price#92))#125, sum(UnscaledValue(return_amt#94))#126, sum(UnscaledValue(profit#93))#127, sum(UnscaledValue(net_loss#95))#128] -Results [5]: [web channel AS channel#129, concat(web_site, web_site_id#114) AS id#130, MakeDecimal(sum(UnscaledValue(sales_price#92))#125,17,2) AS sales#131, MakeDecimal(sum(UnscaledValue(return_amt#94))#126,17,2) AS returns#132, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#93))#127,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#95))#128,17,2) as decimal(18,2)))), DecimalType(18,2), true) AS profit#133] +Results [5]: [web channel AS channel#129, concat(web_site, web_site_id#114) AS id#130, MakeDecimal(sum(UnscaledValue(sales_price#92))#125,17,2) AS sales#131, MakeDecimal(sum(UnscaledValue(return_amt#94))#126,17,2) AS returns#132, CheckOverflow((promote_precision(cast(MakeDecimal(sum(UnscaledValue(profit#93))#127,17,2) as decimal(18,2))) - promote_precision(cast(MakeDecimal(sum(UnscaledValue(net_loss#95))#128,17,2) as decimal(18,2)))), DecimalType(18,2)) AS profit#133] (69) Union diff --git 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt index c7ccb242056f7..1992b08c26b23 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt @@ -118,7 +118,7 @@ Join condition: None (15) Filter [codegen id : 3] Input [5]: [i_item_sk#5, i_current_price#6, i_category#7, avg(i_current_price)#16, i_category#9] -Condition : (cast(i_current_price#6 as decimal(14,7)) > CheckOverflow((1.200000 * promote_precision(avg(i_current_price)#16)), DecimalType(14,7), true)) +Condition : (cast(i_current_price#6 as decimal(14,7)) > CheckOverflow((1.200000 * promote_precision(avg(i_current_price)#16)), DecimalType(14,7))) (16) Project [codegen id : 3] Output [1]: [i_item_sk#5] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt index 0e1ea31859b3d..918c6c375a9ea 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt @@ -178,7 +178,7 @@ Join condition: None (30) Filter [codegen id : 6] Input [5]: [i_item_sk#12, i_current_price#13, i_category#14, avg(i_current_price)#23, i_category#16] -Condition : (cast(i_current_price#13 as decimal(14,7)) > CheckOverflow((1.200000 * promote_precision(avg(i_current_price)#23)), DecimalType(14,7), true)) +Condition : (cast(i_current_price#13 as decimal(14,7)) > CheckOverflow((1.200000 * promote_precision(avg(i_current_price)#23)), DecimalType(14,7))) (31) Project [codegen id : 6] Output [1]: [i_item_sk#12] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64.sf100/explain.txt index 19240a79cc91c..868f1f26459aa 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64.sf100/explain.txt @@ -332,7 +332,7 @@ Input [8]: [cs_item_sk#19, cs_order_number#20, cs_ext_list_price#21, cr_item_sk# (28) HashAggregate [codegen id : 9] Input [5]: [cs_item_sk#19, cs_ext_list_price#21, cr_refunded_cash#26, cr_reversed_charge#27, cr_store_credit#28] Keys [1]: [cs_item_sk#19] -Functions [2]: [partial_sum(UnscaledValue(cs_ext_list_price#21)), partial_sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2), true))] +Functions [2]: [partial_sum(UnscaledValue(cs_ext_list_price#21)), partial_sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2)))] Aggregate Attributes [3]: [sum#31, sum#32, isEmpty#33] Results [4]: [cs_item_sk#19, sum#34, sum#35, isEmpty#36] @@ -343,13 +343,13 @@ Arguments: hashpartitioning(cs_item_sk#19, 5), 
ENSURE_REQUIREMENTS, [id=#37] (30) HashAggregate [codegen id : 10] Input [4]: [cs_item_sk#19, sum#34, sum#35, isEmpty#36] Keys [1]: [cs_item_sk#19] -Functions [2]: [sum(UnscaledValue(cs_ext_list_price#21)), sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2), true))] -Aggregate Attributes [2]: [sum(UnscaledValue(cs_ext_list_price#21))#38, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2), true))#39] -Results [3]: [cs_item_sk#19, MakeDecimal(sum(UnscaledValue(cs_ext_list_price#21))#38,17,2) AS sale#40, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2), true))#39 AS refund#41] +Functions [2]: [sum(UnscaledValue(cs_ext_list_price#21)), sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2)))] +Aggregate Attributes [2]: [sum(UnscaledValue(cs_ext_list_price#21))#38, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2)))#39] +Results [3]: [cs_item_sk#19, MakeDecimal(sum(UnscaledValue(cs_ext_list_price#21))#38,17,2) AS sale#40, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2)))#39 AS refund#41] (31) Filter [codegen id : 10] Input [3]: [cs_item_sk#19, sale#40, refund#41] -Condition : (isnotnull(sale#40) AND (cast(sale#40 as decimal(21,2)) > CheckOverflow((2.00 * promote_precision(refund#41)), DecimalType(21,2), true))) +Condition : (isnotnull(sale#40) AND (cast(sale#40 as decimal(21,2)) > CheckOverflow((2.00 * promote_precision(refund#41)), DecimalType(21,2)))) (32) Project [codegen id : 10] Output [1]: [cs_item_sk#19] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64.sf100/simplified.txt index b5ebf7af31bed..00becee05ec8c 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64.sf100/simplified.txt @@ -113,7 +113,7 @@ WholeStageCodegen (88) WholeStageCodegen (10) Project [cs_item_sk] Filter [sale,refund] - HashAggregate [cs_item_sk,sum,sum,isEmpty] 
[sum(UnscaledValue(cs_ext_list_price)),sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash as decimal(8,2))) + promote_precision(cast(cr_reversed_charge as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit as decimal(9,2)))), DecimalType(9,2), true)),sale,refund,sum,sum,isEmpty] + HashAggregate [cs_item_sk,sum,sum,isEmpty] [sum(UnscaledValue(cs_ext_list_price)),sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash as decimal(8,2))) + promote_precision(cast(cr_reversed_charge as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit as decimal(9,2)))), DecimalType(9,2))),sale,refund,sum,sum,isEmpty] InputAdapter Exchange [cs_item_sk] #13 WholeStageCodegen (9) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt index cfee2290adff9..426b408190662 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt @@ -1,185 +1,187 @@ == Physical Plan == -* Sort (181) -+- Exchange (180) - +- * Project (179) - +- * SortMergeJoin Inner (178) - :- * Sort (110) - : +- * HashAggregate (109) - : +- * HashAggregate (108) - : +- * Project (107) - : +- * BroadcastHashJoin Inner BuildRight (106) - : :- * Project (100) - : : +- * BroadcastHashJoin Inner BuildRight (99) - : : :- * Project (97) - : : : +- * BroadcastHashJoin Inner BuildRight (96) - : : : :- * Project (91) - : : : : +- * BroadcastHashJoin Inner BuildRight (90) - : : : : :- * Project (88) - : : : : : +- * BroadcastHashJoin Inner BuildRight (87) - : : : : : :- * Project (82) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (81) - : : : : : : :- * Project (79) - : : : : : : : +- * BroadcastHashJoin Inner BuildRight (78) - : : : : : : : :- * Project (73) - : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (72) - : : : : : : : : :- * Project (67) - : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (66) - : : : : : : : : : :- * Project (64) - : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (63) - : : : : : : : : : : :- * Project (58) - : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (57) - : : : : : : : : : : : :- * Project (55) - : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (54) - : : : : : : : : : : : : :- * Project (49) - : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (48) - : : : : : : : : : : : : : :- * Project (43) - : : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (42) - : : : : : : : : : : : : : : :- * Project (37) - : : : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (36) - : : : : : : : : : : : : : : : :- * Project (34) - : : : : : : : : : : : : : : : : +- * SortMergeJoin Inner (33) - : : : : : : : : : : : : : : : : :- * Sort (12) - : : : : : : : : : : : : : : : : : +- Exchange (11) - : : : : : : : : : : : : : : : : : +- * Project (10) - : : : : : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildLeft (9) - : : : : : : : : : : : : : : : : : :- BroadcastExchange (4) - : : : : : : : : : : : : : : : : : : +- * Filter (3) - : : : : : : : : : : : : : : : : : : +- * ColumnarToRow (2) - : : : : : : : : : : : : : : : : : : +- Scan parquet default.store_sales (1) - : : : : : : : : : : : : : : : : : 
+- * Project (8) - : : : : : : : : : : : : : : : : : +- * Filter (7) - : : : : : : : : : : : : : : : : : +- * ColumnarToRow (6) - : : : : : : : : : : : : : : : : : +- Scan parquet default.store_returns (5) - : : : : : : : : : : : : : : : : +- * Sort (32) - : : : : : : : : : : : : : : : : +- * Project (31) - : : : : : : : : : : : : : : : : +- * Filter (30) - : : : : : : : : : : : : : : : : +- * HashAggregate (29) - : : : : : : : : : : : : : : : : +- Exchange (28) - : : : : : : : : : : : : : : : : +- * HashAggregate (27) - : : : : : : : : : : : : : : : : +- * Project (26) - : : : : : : : : : : : : : : : : +- * SortMergeJoin Inner (25) - : : : : : : : : : : : : : : : : :- * Sort (18) - : : : : : : : : : : : : : : : : : +- Exchange (17) - : : : : : : : : : : : : : : : : : +- * Project (16) - : : : : : : : : : : : : : : : : : +- * Filter (15) - : : : : : : : : : : : : : : : : : +- * ColumnarToRow (14) - : : : : : : : : : : : : : : : : : +- Scan parquet default.catalog_sales (13) - : : : : : : : : : : : : : : : : +- * Sort (24) - : : : : : : : : : : : : : : : : +- Exchange (23) - : : : : : : : : : : : : : : : : +- * Project (22) - : : : : : : : : : : : : : : : : +- * Filter (21) - : : : : : : : : : : : : : : : : +- * ColumnarToRow (20) - : : : : : : : : : : : : : : : : +- Scan parquet default.catalog_returns (19) - : : : : : : : : : : : : : : : +- ReusedExchange (35) - : : : : : : : : : : : : : : +- BroadcastExchange (41) - : : : : : : : : : : : : : : +- * Filter (40) - : : : : : : : : : : : : : : +- * ColumnarToRow (39) - : : : : : : : : : : : : : : +- Scan parquet default.store (38) - : : : : : : : : : : : : : +- BroadcastExchange (47) - : : : : : : : : : : : : : +- * Filter (46) - : : : : : : : : : : : : : +- * ColumnarToRow (45) - : : : : : : : : : : : : : +- Scan parquet default.customer (44) - : : : : : : : : : : : : +- BroadcastExchange (53) - : : : : : : : : : : : : +- * Filter (52) - : : : : : : : : : : : : +- * ColumnarToRow (51) - : : : : : : : : : : : : +- Scan parquet default.date_dim (50) - : : : : : : : : : : : +- ReusedExchange (56) - : : : : : : : : : : +- BroadcastExchange (62) - : : : : : : : : : : +- * Filter (61) - : : : : : : : : : : +- * ColumnarToRow (60) - : : : : : : : : : : +- Scan parquet default.customer_demographics (59) - : : : : : : : : : +- ReusedExchange (65) - : : : : : : : : +- BroadcastExchange (71) - : : : : : : : : +- * Filter (70) - : : : : : : : : +- * ColumnarToRow (69) - : : : : : : : : +- Scan parquet default.promotion (68) - : : : : : : : +- BroadcastExchange (77) - : : : : : : : +- * Filter (76) - : : : : : : : +- * ColumnarToRow (75) - : : : : : : : +- Scan parquet default.household_demographics (74) - : : : : : : +- ReusedExchange (80) - : : : : : +- BroadcastExchange (86) - : : : : : +- * Filter (85) - : : : : : +- * ColumnarToRow (84) - : : : : : +- Scan parquet default.customer_address (83) - : : : : +- ReusedExchange (89) - : : : +- BroadcastExchange (95) - : : : +- * Filter (94) - : : : +- * ColumnarToRow (93) - : : : +- Scan parquet default.income_band (92) - : : +- ReusedExchange (98) - : +- BroadcastExchange (105) - : +- * Project (104) - : +- * Filter (103) - : +- * ColumnarToRow (102) - : +- Scan parquet default.item (101) - +- * Sort (177) - +- * HashAggregate (176) - +- * HashAggregate (175) - +- * Project (174) - +- * BroadcastHashJoin Inner BuildRight (173) - :- * Project (171) - : +- * BroadcastHashJoin Inner BuildRight (170) - : :- * Project (168) - : : +- * BroadcastHashJoin Inner BuildRight (167) - : : :- * Project (165) - : : : 
+- * BroadcastHashJoin Inner BuildRight (164) - : : : :- * Project (162) - : : : : +- * BroadcastHashJoin Inner BuildRight (161) - : : : : :- * Project (159) - : : : : : +- * BroadcastHashJoin Inner BuildRight (158) - : : : : : :- * Project (156) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (155) - : : : : : : :- * Project (153) - : : : : : : : +- * BroadcastHashJoin Inner BuildRight (152) - : : : : : : : :- * Project (150) - : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (149) - : : : : : : : : :- * Project (147) - : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (146) - : : : : : : : : : :- * Project (144) - : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (143) - : : : : : : : : : : :- * Project (141) - : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (140) - : : : : : : : : : : : :- * Project (138) - : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (137) - : : : : : : : : : : : : :- * Project (135) - : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (134) - : : : : : : : : : : : : : :- * Project (132) - : : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (131) - : : : : : : : : : : : : : : :- * Project (129) - : : : : : : : : : : : : : : : +- * SortMergeJoin Inner (128) - : : : : : : : : : : : : : : : :- * Sort (122) - : : : : : : : : : : : : : : : : +- Exchange (121) - : : : : : : : : : : : : : : : : +- * Project (120) - : : : : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildLeft (119) - : : : : : : : : : : : : : : : : :- BroadcastExchange (114) - : : : : : : : : : : : : : : : : : +- * Filter (113) - : : : : : : : : : : : : : : : : : +- * ColumnarToRow (112) - : : : : : : : : : : : : : : : : : +- Scan parquet default.store_sales (111) - : : : : : : : : : : : : : : : : +- * Project (118) - : : : : : : : : : : : : : : : : +- * Filter (117) - : : : : : : : : : : : : : : : : +- * ColumnarToRow (116) - : : : : : : : : : : : : : : : : +- Scan parquet default.store_returns (115) - : : : : : : : : : : : : : : : +- * Sort (127) - : : : : : : : : : : : : : : : +- * Project (126) - : : : : : : : : : : : : : : : +- * Filter (125) - : : : : : : : : : : : : : : : +- * HashAggregate (124) - : : : : : : : : : : : : : : : +- ReusedExchange (123) - : : : : : : : : : : : : : : +- ReusedExchange (130) - : : : : : : : : : : : : : +- ReusedExchange (133) - : : : : : : : : : : : : +- ReusedExchange (136) - : : : : : : : : : : : +- ReusedExchange (139) - : : : : : : : : : : +- ReusedExchange (142) - : : : : : : : : : +- ReusedExchange (145) - : : : : : : : : +- ReusedExchange (148) - : : : : : : : +- ReusedExchange (151) - : : : : : : +- ReusedExchange (154) - : : : : : +- ReusedExchange (157) - : : : : +- ReusedExchange (160) - : : : +- ReusedExchange (163) - : : +- ReusedExchange (166) - : +- ReusedExchange (169) - +- ReusedExchange (172) +* Sort (183) ++- Exchange (182) + +- * Project (181) + +- * SortMergeJoin Inner (180) + :- * Sort (111) + : +- Exchange (110) + : +- * HashAggregate (109) + : +- * HashAggregate (108) + : +- * Project (107) + : +- * BroadcastHashJoin Inner BuildRight (106) + : :- * Project (100) + : : +- * BroadcastHashJoin Inner BuildRight (99) + : : :- * Project (97) + : : : +- * BroadcastHashJoin Inner BuildRight (96) + : : : :- * Project (91) + : : : : +- * BroadcastHashJoin Inner BuildRight (90) + : : : : :- * Project (88) + : : : : : +- * BroadcastHashJoin Inner BuildRight (87) + : : : : : :- * Project (82) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (81) + : : : : 
: : :- * Project (79) + : : : : : : : +- * BroadcastHashJoin Inner BuildRight (78) + : : : : : : : :- * Project (73) + : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (72) + : : : : : : : : :- * Project (67) + : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (66) + : : : : : : : : : :- * Project (64) + : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (63) + : : : : : : : : : : :- * Project (58) + : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (57) + : : : : : : : : : : : :- * Project (55) + : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (54) + : : : : : : : : : : : : :- * Project (49) + : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (48) + : : : : : : : : : : : : : :- * Project (43) + : : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (42) + : : : : : : : : : : : : : : :- * Project (37) + : : : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (36) + : : : : : : : : : : : : : : : :- * Project (34) + : : : : : : : : : : : : : : : : +- * SortMergeJoin Inner (33) + : : : : : : : : : : : : : : : : :- * Sort (12) + : : : : : : : : : : : : : : : : : +- Exchange (11) + : : : : : : : : : : : : : : : : : +- * Project (10) + : : : : : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildLeft (9) + : : : : : : : : : : : : : : : : : :- BroadcastExchange (4) + : : : : : : : : : : : : : : : : : : +- * Filter (3) + : : : : : : : : : : : : : : : : : : +- * ColumnarToRow (2) + : : : : : : : : : : : : : : : : : : +- Scan parquet default.store_sales (1) + : : : : : : : : : : : : : : : : : +- * Project (8) + : : : : : : : : : : : : : : : : : +- * Filter (7) + : : : : : : : : : : : : : : : : : +- * ColumnarToRow (6) + : : : : : : : : : : : : : : : : : +- Scan parquet default.store_returns (5) + : : : : : : : : : : : : : : : : +- * Sort (32) + : : : : : : : : : : : : : : : : +- * Project (31) + : : : : : : : : : : : : : : : : +- * Filter (30) + : : : : : : : : : : : : : : : : +- * HashAggregate (29) + : : : : : : : : : : : : : : : : +- Exchange (28) + : : : : : : : : : : : : : : : : +- * HashAggregate (27) + : : : : : : : : : : : : : : : : +- * Project (26) + : : : : : : : : : : : : : : : : +- * SortMergeJoin Inner (25) + : : : : : : : : : : : : : : : : :- * Sort (18) + : : : : : : : : : : : : : : : : : +- Exchange (17) + : : : : : : : : : : : : : : : : : +- * Project (16) + : : : : : : : : : : : : : : : : : +- * Filter (15) + : : : : : : : : : : : : : : : : : +- * ColumnarToRow (14) + : : : : : : : : : : : : : : : : : +- Scan parquet default.catalog_sales (13) + : : : : : : : : : : : : : : : : +- * Sort (24) + : : : : : : : : : : : : : : : : +- Exchange (23) + : : : : : : : : : : : : : : : : +- * Project (22) + : : : : : : : : : : : : : : : : +- * Filter (21) + : : : : : : : : : : : : : : : : +- * ColumnarToRow (20) + : : : : : : : : : : : : : : : : +- Scan parquet default.catalog_returns (19) + : : : : : : : : : : : : : : : +- ReusedExchange (35) + : : : : : : : : : : : : : : +- BroadcastExchange (41) + : : : : : : : : : : : : : : +- * Filter (40) + : : : : : : : : : : : : : : +- * ColumnarToRow (39) + : : : : : : : : : : : : : : +- Scan parquet default.store (38) + : : : : : : : : : : : : : +- BroadcastExchange (47) + : : : : : : : : : : : : : +- * Filter (46) + : : : : : : : : : : : : : +- * ColumnarToRow (45) + : : : : : : : : : : : : : +- Scan parquet default.customer (44) + : : : : : : : : : : : : +- BroadcastExchange (53) + : : : : : : : : : : : : +- * Filter (52) + : : : : : : : : : : : : 
+- * ColumnarToRow (51) + : : : : : : : : : : : : +- Scan parquet default.date_dim (50) + : : : : : : : : : : : +- ReusedExchange (56) + : : : : : : : : : : +- BroadcastExchange (62) + : : : : : : : : : : +- * Filter (61) + : : : : : : : : : : +- * ColumnarToRow (60) + : : : : : : : : : : +- Scan parquet default.customer_demographics (59) + : : : : : : : : : +- ReusedExchange (65) + : : : : : : : : +- BroadcastExchange (71) + : : : : : : : : +- * Filter (70) + : : : : : : : : +- * ColumnarToRow (69) + : : : : : : : : +- Scan parquet default.promotion (68) + : : : : : : : +- BroadcastExchange (77) + : : : : : : : +- * Filter (76) + : : : : : : : +- * ColumnarToRow (75) + : : : : : : : +- Scan parquet default.household_demographics (74) + : : : : : : +- ReusedExchange (80) + : : : : : +- BroadcastExchange (86) + : : : : : +- * Filter (85) + : : : : : +- * ColumnarToRow (84) + : : : : : +- Scan parquet default.customer_address (83) + : : : : +- ReusedExchange (89) + : : : +- BroadcastExchange (95) + : : : +- * Filter (94) + : : : +- * ColumnarToRow (93) + : : : +- Scan parquet default.income_band (92) + : : +- ReusedExchange (98) + : +- BroadcastExchange (105) + : +- * Project (104) + : +- * Filter (103) + : +- * ColumnarToRow (102) + : +- Scan parquet default.item (101) + +- * Sort (179) + +- Exchange (178) + +- * HashAggregate (177) + +- * HashAggregate (176) + +- * Project (175) + +- * BroadcastHashJoin Inner BuildRight (174) + :- * Project (172) + : +- * BroadcastHashJoin Inner BuildRight (171) + : :- * Project (169) + : : +- * BroadcastHashJoin Inner BuildRight (168) + : : :- * Project (166) + : : : +- * BroadcastHashJoin Inner BuildRight (165) + : : : :- * Project (163) + : : : : +- * BroadcastHashJoin Inner BuildRight (162) + : : : : :- * Project (160) + : : : : : +- * BroadcastHashJoin Inner BuildRight (159) + : : : : : :- * Project (157) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (156) + : : : : : : :- * Project (154) + : : : : : : : +- * BroadcastHashJoin Inner BuildRight (153) + : : : : : : : :- * Project (151) + : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (150) + : : : : : : : : :- * Project (148) + : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (147) + : : : : : : : : : :- * Project (145) + : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (144) + : : : : : : : : : : :- * Project (142) + : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (141) + : : : : : : : : : : : :- * Project (139) + : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (138) + : : : : : : : : : : : : :- * Project (136) + : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (135) + : : : : : : : : : : : : : :- * Project (133) + : : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildRight (132) + : : : : : : : : : : : : : : :- * Project (130) + : : : : : : : : : : : : : : : +- * SortMergeJoin Inner (129) + : : : : : : : : : : : : : : : :- * Sort (123) + : : : : : : : : : : : : : : : : +- Exchange (122) + : : : : : : : : : : : : : : : : +- * Project (121) + : : : : : : : : : : : : : : : : +- * BroadcastHashJoin Inner BuildLeft (120) + : : : : : : : : : : : : : : : : :- BroadcastExchange (115) + : : : : : : : : : : : : : : : : : +- * Filter (114) + : : : : : : : : : : : : : : : : : +- * ColumnarToRow (113) + : : : : : : : : : : : : : : : : : +- Scan parquet default.store_sales (112) + : : : : : : : : : : : : : : : : +- * Project (119) + : : : : : : : : : : : : : : : : +- * Filter (118) + : : : : : : : : : : : : : : : : +- * 
ColumnarToRow (117) + : : : : : : : : : : : : : : : : +- Scan parquet default.store_returns (116) + : : : : : : : : : : : : : : : +- * Sort (128) + : : : : : : : : : : : : : : : +- * Project (127) + : : : : : : : : : : : : : : : +- * Filter (126) + : : : : : : : : : : : : : : : +- * HashAggregate (125) + : : : : : : : : : : : : : : : +- ReusedExchange (124) + : : : : : : : : : : : : : : +- ReusedExchange (131) + : : : : : : : : : : : : : +- ReusedExchange (134) + : : : : : : : : : : : : +- ReusedExchange (137) + : : : : : : : : : : : +- ReusedExchange (140) + : : : : : : : : : : +- ReusedExchange (143) + : : : : : : : : : +- ReusedExchange (146) + : : : : : : : : +- ReusedExchange (149) + : : : : : : : +- ReusedExchange (152) + : : : : : : +- ReusedExchange (155) + : : : : : +- ReusedExchange (158) + : : : : +- ReusedExchange (161) + : : : +- ReusedExchange (164) + : : +- ReusedExchange (167) + : +- ReusedExchange (170) + +- ReusedExchange (173) (1) Scan parquet default.store_sales @@ -300,7 +302,7 @@ Input [8]: [cs_item_sk#19, cs_order_number#20, cs_ext_list_price#21, cr_item_sk# (27) HashAggregate [codegen id : 8] Input [5]: [cs_item_sk#19, cs_ext_list_price#21, cr_refunded_cash#26, cr_reversed_charge#27, cr_store_credit#28] Keys [1]: [cs_item_sk#19] -Functions [2]: [partial_sum(UnscaledValue(cs_ext_list_price#21)), partial_sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2), true))] +Functions [2]: [partial_sum(UnscaledValue(cs_ext_list_price#21)), partial_sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2)))] Aggregate Attributes [3]: [sum#31, sum#32, isEmpty#33] Results [4]: [cs_item_sk#19, sum#34, sum#35, isEmpty#36] @@ -311,13 +313,13 @@ Arguments: hashpartitioning(cs_item_sk#19, 5), ENSURE_REQUIREMENTS, [id=#37] (29) HashAggregate [codegen id : 9] Input [4]: [cs_item_sk#19, sum#34, sum#35, isEmpty#36] Keys [1]: [cs_item_sk#19] -Functions [2]: [sum(UnscaledValue(cs_ext_list_price#21)), sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2), true))] -Aggregate Attributes [2]: [sum(UnscaledValue(cs_ext_list_price#21))#38, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2), true))#39] -Results [3]: [cs_item_sk#19, MakeDecimal(sum(UnscaledValue(cs_ext_list_price#21))#38,17,2) AS sale#40, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2), 
true))#39 AS refund#41] +Functions [2]: [sum(UnscaledValue(cs_ext_list_price#21)), sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2)))] +Aggregate Attributes [2]: [sum(UnscaledValue(cs_ext_list_price#21))#38, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2)))#39] +Results [3]: [cs_item_sk#19, MakeDecimal(sum(UnscaledValue(cs_ext_list_price#21))#38,17,2) AS sale#40, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#26 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#27 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit#28 as decimal(9,2)))), DecimalType(9,2)))#39 AS refund#41] (30) Filter [codegen id : 9] Input [3]: [cs_item_sk#19, sale#40, refund#41] -Condition : (isnotnull(sale#40) AND (cast(sale#40 as decimal(21,2)) > CheckOverflow((2.00 * promote_precision(refund#41)), DecimalType(21,2), true))) +Condition : (isnotnull(sale#40) AND (cast(sale#40 as decimal(21,2)) > CheckOverflow((2.00 * promote_precision(refund#41)), DecimalType(21,2)))) (31) Project [codegen id : 9] Output [1]: [cs_item_sk#19] @@ -336,7 +338,7 @@ Join condition: None Output [11]: [ss_item_sk#1, ss_customer_sk#2, ss_cdemo_sk#3, ss_hdemo_sk#4, ss_addr_sk#5, ss_store_sk#6, ss_promo_sk#7, ss_wholesale_cost#9, ss_list_price#10, ss_coupon_amt#11, ss_sold_date_sk#12] Input [12]: [ss_item_sk#1, ss_customer_sk#2, ss_cdemo_sk#3, ss_hdemo_sk#4, ss_addr_sk#5, ss_store_sk#6, ss_promo_sk#7, ss_wholesale_cost#9, ss_list_price#10, ss_coupon_amt#11, ss_sold_date_sk#12, cs_item_sk#19] -(35) ReusedExchange [Reuses operator id: 185] +(35) ReusedExchange [Reuses operator id: 187] Output [2]: [d_date_sk#42, d_year#43] (36) BroadcastHashJoin [codegen id : 25] @@ -669,360 +671,368 @@ Functions [4]: [count(1), sum(UnscaledValue(ss_wholesale_cost#9)), sum(UnscaledV Aggregate Attributes [4]: [count(1)#99, sum(UnscaledValue(ss_wholesale_cost#9))#100, sum(UnscaledValue(ss_list_price#10))#101, sum(UnscaledValue(ss_coupon_amt#11))#102] Results [17]: [i_product_name#89 AS product_name#103, i_item_sk#86 AS item_sk#104, s_store_name#45 AS store_name#105, s_zip#46 AS store_zip#106, ca_street_number#73 AS b_street_number#107, ca_street_name#74 AS b_streen_name#108, ca_city#75 AS b_city#109, ca_zip#76 AS b_zip#110, ca_street_number#79 AS c_street_number#111, ca_street_name#80 AS c_street_name#112, ca_city#81 AS c_city#113, ca_zip#82 AS c_zip#114, d_year#43 AS syear#115, count(1)#99 AS cnt#116, MakeDecimal(sum(UnscaledValue(ss_wholesale_cost#9))#100,17,2) AS s1#117, MakeDecimal(sum(UnscaledValue(ss_list_price#10))#101,17,2) AS s2#118, MakeDecimal(sum(UnscaledValue(ss_coupon_amt#11))#102,17,2) AS s3#119] -(110) Sort [codegen id : 25] +(110) Exchange +Input [17]: [product_name#103, item_sk#104, store_name#105, store_zip#106, b_street_number#107, b_streen_name#108, b_city#109, b_zip#110, c_street_number#111, c_street_name#112, c_city#113, c_zip#114, syear#115, cnt#116, s1#117, s2#118, s3#119] +Arguments: hashpartitioning(item_sk#104, store_name#105, 
store_zip#106, 5), ENSURE_REQUIREMENTS, [id=#120] + +(111) Sort [codegen id : 26] Input [17]: [product_name#103, item_sk#104, store_name#105, store_zip#106, b_street_number#107, b_streen_name#108, b_city#109, b_zip#110, c_street_number#111, c_street_name#112, c_city#113, c_zip#114, syear#115, cnt#116, s1#117, s2#118, s3#119] Arguments: [item_sk#104 ASC NULLS FIRST, store_name#105 ASC NULLS FIRST, store_zip#106 ASC NULLS FIRST], false, 0 -(111) Scan parquet default.store_sales -Output [12]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_ticket_number#127, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131] +(112) Scan parquet default.store_sales +Output [12]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_ticket_number#128, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, ss_sold_date_sk#132] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#131), dynamicpruningexpression(ss_sold_date_sk#131 IN dynamicpruning#132)] +PartitionFilters: [isnotnull(ss_sold_date_sk#132), dynamicpruningexpression(ss_sold_date_sk#132 IN dynamicpruning#133)] PushedFilters: [IsNotNull(ss_item_sk), IsNotNull(ss_ticket_number), IsNotNull(ss_store_sk), IsNotNull(ss_customer_sk), IsNotNull(ss_cdemo_sk), IsNotNull(ss_promo_sk), IsNotNull(ss_hdemo_sk), IsNotNull(ss_addr_sk)] ReadSchema: struct -(112) ColumnarToRow [codegen id : 26] -Input [12]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_ticket_number#127, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131] +(113) ColumnarToRow [codegen id : 27] +Input [12]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_ticket_number#128, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, ss_sold_date_sk#132] -(113) Filter [codegen id : 26] -Input [12]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_ticket_number#127, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131] -Condition : (((((((isnotnull(ss_item_sk#120) AND isnotnull(ss_ticket_number#127)) AND isnotnull(ss_store_sk#125)) AND isnotnull(ss_customer_sk#121)) AND isnotnull(ss_cdemo_sk#122)) AND isnotnull(ss_promo_sk#126)) AND isnotnull(ss_hdemo_sk#123)) AND isnotnull(ss_addr_sk#124)) +(114) Filter [codegen id : 27] +Input [12]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_ticket_number#128, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, ss_sold_date_sk#132] +Condition : (((((((isnotnull(ss_item_sk#121) AND isnotnull(ss_ticket_number#128)) AND isnotnull(ss_store_sk#126)) AND isnotnull(ss_customer_sk#122)) AND isnotnull(ss_cdemo_sk#123)) AND isnotnull(ss_promo_sk#127)) AND isnotnull(ss_hdemo_sk#124)) AND isnotnull(ss_addr_sk#125)) -(114) BroadcastExchange -Input [12]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_ticket_number#127, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[7, int, 
false] as bigint) & 4294967295))),false), [id=#133] +(115) BroadcastExchange +Input [12]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_ticket_number#128, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, ss_sold_date_sk#132] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [id=#134] -(115) Scan parquet default.store_returns -Output [3]: [sr_item_sk#134, sr_ticket_number#135, sr_returned_date_sk#136] +(116) Scan parquet default.store_returns +Output [3]: [sr_item_sk#135, sr_ticket_number#136, sr_returned_date_sk#137] Batched: true Location [not included in comparison]/{warehouse_dir}/store_returns] PushedFilters: [IsNotNull(sr_item_sk), IsNotNull(sr_ticket_number)] ReadSchema: struct -(116) ColumnarToRow -Input [3]: [sr_item_sk#134, sr_ticket_number#135, sr_returned_date_sk#136] +(117) ColumnarToRow +Input [3]: [sr_item_sk#135, sr_ticket_number#136, sr_returned_date_sk#137] -(117) Filter -Input [3]: [sr_item_sk#134, sr_ticket_number#135, sr_returned_date_sk#136] -Condition : (isnotnull(sr_item_sk#134) AND isnotnull(sr_ticket_number#135)) +(118) Filter +Input [3]: [sr_item_sk#135, sr_ticket_number#136, sr_returned_date_sk#137] +Condition : (isnotnull(sr_item_sk#135) AND isnotnull(sr_ticket_number#136)) -(118) Project -Output [2]: [sr_item_sk#134, sr_ticket_number#135] -Input [3]: [sr_item_sk#134, sr_ticket_number#135, sr_returned_date_sk#136] +(119) Project +Output [2]: [sr_item_sk#135, sr_ticket_number#136] +Input [3]: [sr_item_sk#135, sr_ticket_number#136, sr_returned_date_sk#137] -(119) BroadcastHashJoin [codegen id : 27] -Left keys [2]: [ss_item_sk#120, ss_ticket_number#127] -Right keys [2]: [sr_item_sk#134, sr_ticket_number#135] +(120) BroadcastHashJoin [codegen id : 28] +Left keys [2]: [ss_item_sk#121, ss_ticket_number#128] +Right keys [2]: [sr_item_sk#135, sr_ticket_number#136] Join condition: None -(120) Project [codegen id : 27] -Output [11]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131] -Input [14]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_ticket_number#127, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131, sr_item_sk#134, sr_ticket_number#135] +(121) Project [codegen id : 28] +Output [11]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, ss_sold_date_sk#132] +Input [14]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_ticket_number#128, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, ss_sold_date_sk#132, sr_item_sk#135, sr_ticket_number#136] -(121) Exchange -Input [11]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131] -Arguments: hashpartitioning(ss_item_sk#120, 5), ENSURE_REQUIREMENTS, [id=#137] +(122) Exchange +Input [11]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, 
ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, ss_sold_date_sk#132] +Arguments: hashpartitioning(ss_item_sk#121, 5), ENSURE_REQUIREMENTS, [id=#138] -(122) Sort [codegen id : 28] -Input [11]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131] -Arguments: [ss_item_sk#120 ASC NULLS FIRST], false, 0 +(123) Sort [codegen id : 29] +Input [11]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, ss_sold_date_sk#132] +Arguments: [ss_item_sk#121 ASC NULLS FIRST], false, 0 -(123) ReusedExchange [Reuses operator id: 28] -Output [4]: [cs_item_sk#138, sum#139, sum#140, isEmpty#141] +(124) ReusedExchange [Reuses operator id: 28] +Output [4]: [cs_item_sk#139, sum#140, sum#141, isEmpty#142] -(124) HashAggregate [codegen id : 34] -Input [4]: [cs_item_sk#138, sum#139, sum#140, isEmpty#141] -Keys [1]: [cs_item_sk#138] -Functions [2]: [sum(UnscaledValue(cs_ext_list_price#142)), sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#143 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#144 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#145 as decimal(9,2)))), DecimalType(9,2), true))] -Aggregate Attributes [2]: [sum(UnscaledValue(cs_ext_list_price#142))#38, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#143 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#144 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#145 as decimal(9,2)))), DecimalType(9,2), true))#39] -Results [3]: [cs_item_sk#138, MakeDecimal(sum(UnscaledValue(cs_ext_list_price#142))#38,17,2) AS sale#40, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#143 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#144 as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit#145 as decimal(9,2)))), DecimalType(9,2), true))#39 AS refund#41] +(125) HashAggregate [codegen id : 35] +Input [4]: [cs_item_sk#139, sum#140, sum#141, isEmpty#142] +Keys [1]: [cs_item_sk#139] +Functions [2]: [sum(UnscaledValue(cs_ext_list_price#143)), sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#144 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#145 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit#146 as decimal(9,2)))), DecimalType(9,2)))] +Aggregate Attributes [2]: [sum(UnscaledValue(cs_ext_list_price#143))#38, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#144 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#145 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit#146 as decimal(9,2)))), DecimalType(9,2)))#39] +Results [3]: [cs_item_sk#139, MakeDecimal(sum(UnscaledValue(cs_ext_list_price#143))#38,17,2) AS sale#40, sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash#144 as decimal(8,2))) + promote_precision(cast(cr_reversed_charge#145 as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + 
promote_precision(cast(cr_store_credit#146 as decimal(9,2)))), DecimalType(9,2)))#39 AS refund#41] -(125) Filter [codegen id : 34] -Input [3]: [cs_item_sk#138, sale#40, refund#41] -Condition : (isnotnull(sale#40) AND (cast(sale#40 as decimal(21,2)) > CheckOverflow((2.00 * promote_precision(refund#41)), DecimalType(21,2), true))) +(126) Filter [codegen id : 35] +Input [3]: [cs_item_sk#139, sale#40, refund#41] +Condition : (isnotnull(sale#40) AND (cast(sale#40 as decimal(21,2)) > CheckOverflow((2.00 * promote_precision(refund#41)), DecimalType(21,2)))) -(126) Project [codegen id : 34] -Output [1]: [cs_item_sk#138] -Input [3]: [cs_item_sk#138, sale#40, refund#41] +(127) Project [codegen id : 35] +Output [1]: [cs_item_sk#139] +Input [3]: [cs_item_sk#139, sale#40, refund#41] -(127) Sort [codegen id : 34] -Input [1]: [cs_item_sk#138] -Arguments: [cs_item_sk#138 ASC NULLS FIRST], false, 0 +(128) Sort [codegen id : 35] +Input [1]: [cs_item_sk#139] +Arguments: [cs_item_sk#139 ASC NULLS FIRST], false, 0 -(128) SortMergeJoin [codegen id : 50] -Left keys [1]: [ss_item_sk#120] -Right keys [1]: [cs_item_sk#138] +(129) SortMergeJoin [codegen id : 51] +Left keys [1]: [ss_item_sk#121] +Right keys [1]: [cs_item_sk#139] Join condition: None -(129) Project [codegen id : 50] -Output [11]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131] -Input [12]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131, cs_item_sk#138] +(130) Project [codegen id : 51] +Output [11]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, ss_sold_date_sk#132] +Input [12]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, ss_sold_date_sk#132, cs_item_sk#139] -(130) ReusedExchange [Reuses operator id: 189] -Output [2]: [d_date_sk#146, d_year#147] +(131) ReusedExchange [Reuses operator id: 191] +Output [2]: [d_date_sk#147, d_year#148] -(131) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [ss_sold_date_sk#131] -Right keys [1]: [d_date_sk#146] +(132) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [ss_sold_date_sk#132] +Right keys [1]: [d_date_sk#147] Join condition: None -(132) Project [codegen id : 50] -Output [11]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147] -Input [13]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, ss_sold_date_sk#131, d_date_sk#146, d_year#147] +(133) Project [codegen id : 51] +Output [11]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148] +Input [13]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, 
ss_sold_date_sk#132, d_date_sk#147, d_year#148] -(133) ReusedExchange [Reuses operator id: 41] -Output [3]: [s_store_sk#148, s_store_name#149, s_zip#150] +(134) ReusedExchange [Reuses operator id: 41] +Output [3]: [s_store_sk#149, s_store_name#150, s_zip#151] -(134) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [ss_store_sk#125] -Right keys [1]: [s_store_sk#148] +(135) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [ss_store_sk#126] +Right keys [1]: [s_store_sk#149] Join condition: None -(135) Project [codegen id : 50] -Output [12]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150] -Input [14]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_store_sk#125, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_sk#148, s_store_name#149, s_zip#150] +(136) Project [codegen id : 51] +Output [12]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151] +Input [14]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_store_sk#126, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_sk#149, s_store_name#150, s_zip#151] -(136) ReusedExchange [Reuses operator id: 47] -Output [6]: [c_customer_sk#151, c_current_cdemo_sk#152, c_current_hdemo_sk#153, c_current_addr_sk#154, c_first_shipto_date_sk#155, c_first_sales_date_sk#156] +(137) ReusedExchange [Reuses operator id: 47] +Output [6]: [c_customer_sk#152, c_current_cdemo_sk#153, c_current_hdemo_sk#154, c_current_addr_sk#155, c_first_shipto_date_sk#156, c_first_sales_date_sk#157] -(137) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [ss_customer_sk#121] -Right keys [1]: [c_customer_sk#151] +(138) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [ss_customer_sk#122] +Right keys [1]: [c_customer_sk#152] Join condition: None -(138) Project [codegen id : 50] -Output [16]: [ss_item_sk#120, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_cdemo_sk#152, c_current_hdemo_sk#153, c_current_addr_sk#154, c_first_shipto_date_sk#155, c_first_sales_date_sk#156] -Input [18]: [ss_item_sk#120, ss_customer_sk#121, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_customer_sk#151, c_current_cdemo_sk#152, c_current_hdemo_sk#153, c_current_addr_sk#154, c_first_shipto_date_sk#155, c_first_sales_date_sk#156] +(139) Project [codegen id : 51] +Output [16]: [ss_item_sk#121, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_cdemo_sk#153, c_current_hdemo_sk#154, c_current_addr_sk#155, c_first_shipto_date_sk#156, c_first_sales_date_sk#157] +Input [18]: [ss_item_sk#121, ss_customer_sk#122, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_customer_sk#152, c_current_cdemo_sk#153, c_current_hdemo_sk#154, 
c_current_addr_sk#155, c_first_shipto_date_sk#156, c_first_sales_date_sk#157] -(139) ReusedExchange [Reuses operator id: 53] -Output [2]: [d_date_sk#157, d_year#158] +(140) ReusedExchange [Reuses operator id: 53] +Output [2]: [d_date_sk#158, d_year#159] -(140) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [c_first_sales_date_sk#156] -Right keys [1]: [d_date_sk#157] +(141) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [c_first_sales_date_sk#157] +Right keys [1]: [d_date_sk#158] Join condition: None -(141) Project [codegen id : 50] -Output [16]: [ss_item_sk#120, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_cdemo_sk#152, c_current_hdemo_sk#153, c_current_addr_sk#154, c_first_shipto_date_sk#155, d_year#158] -Input [18]: [ss_item_sk#120, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_cdemo_sk#152, c_current_hdemo_sk#153, c_current_addr_sk#154, c_first_shipto_date_sk#155, c_first_sales_date_sk#156, d_date_sk#157, d_year#158] +(142) Project [codegen id : 51] +Output [16]: [ss_item_sk#121, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_cdemo_sk#153, c_current_hdemo_sk#154, c_current_addr_sk#155, c_first_shipto_date_sk#156, d_year#159] +Input [18]: [ss_item_sk#121, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_cdemo_sk#153, c_current_hdemo_sk#154, c_current_addr_sk#155, c_first_shipto_date_sk#156, c_first_sales_date_sk#157, d_date_sk#158, d_year#159] -(142) ReusedExchange [Reuses operator id: 53] -Output [2]: [d_date_sk#159, d_year#160] +(143) ReusedExchange [Reuses operator id: 53] +Output [2]: [d_date_sk#160, d_year#161] -(143) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [c_first_shipto_date_sk#155] -Right keys [1]: [d_date_sk#159] +(144) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [c_first_shipto_date_sk#156] +Right keys [1]: [d_date_sk#160] Join condition: None -(144) Project [codegen id : 50] -Output [16]: [ss_item_sk#120, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_cdemo_sk#152, c_current_hdemo_sk#153, c_current_addr_sk#154, d_year#158, d_year#160] -Input [18]: [ss_item_sk#120, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_cdemo_sk#152, c_current_hdemo_sk#153, c_current_addr_sk#154, c_first_shipto_date_sk#155, d_year#158, d_date_sk#159, d_year#160] +(145) Project [codegen id : 51] +Output [16]: [ss_item_sk#121, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_cdemo_sk#153, c_current_hdemo_sk#154, c_current_addr_sk#155, d_year#159, d_year#161] +Input [18]: [ss_item_sk#121, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, 
c_current_cdemo_sk#153, c_current_hdemo_sk#154, c_current_addr_sk#155, c_first_shipto_date_sk#156, d_year#159, d_date_sk#160, d_year#161] -(145) ReusedExchange [Reuses operator id: 62] -Output [2]: [cd_demo_sk#161, cd_marital_status#162] +(146) ReusedExchange [Reuses operator id: 62] +Output [2]: [cd_demo_sk#162, cd_marital_status#163] -(146) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [ss_cdemo_sk#122] -Right keys [1]: [cd_demo_sk#161] +(147) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [ss_cdemo_sk#123] +Right keys [1]: [cd_demo_sk#162] Join condition: None -(147) Project [codegen id : 50] -Output [16]: [ss_item_sk#120, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_cdemo_sk#152, c_current_hdemo_sk#153, c_current_addr_sk#154, d_year#158, d_year#160, cd_marital_status#162] -Input [18]: [ss_item_sk#120, ss_cdemo_sk#122, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_cdemo_sk#152, c_current_hdemo_sk#153, c_current_addr_sk#154, d_year#158, d_year#160, cd_demo_sk#161, cd_marital_status#162] +(148) Project [codegen id : 51] +Output [16]: [ss_item_sk#121, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_cdemo_sk#153, c_current_hdemo_sk#154, c_current_addr_sk#155, d_year#159, d_year#161, cd_marital_status#163] +Input [18]: [ss_item_sk#121, ss_cdemo_sk#123, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_cdemo_sk#153, c_current_hdemo_sk#154, c_current_addr_sk#155, d_year#159, d_year#161, cd_demo_sk#162, cd_marital_status#163] -(148) ReusedExchange [Reuses operator id: 62] -Output [2]: [cd_demo_sk#163, cd_marital_status#164] +(149) ReusedExchange [Reuses operator id: 62] +Output [2]: [cd_demo_sk#164, cd_marital_status#165] -(149) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [c_current_cdemo_sk#152] -Right keys [1]: [cd_demo_sk#163] -Join condition: NOT (cd_marital_status#162 = cd_marital_status#164) +(150) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [c_current_cdemo_sk#153] +Right keys [1]: [cd_demo_sk#164] +Join condition: NOT (cd_marital_status#163 = cd_marital_status#165) -(150) Project [codegen id : 50] -Output [14]: [ss_item_sk#120, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_hdemo_sk#153, c_current_addr_sk#154, d_year#158, d_year#160] -Input [18]: [ss_item_sk#120, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_cdemo_sk#152, c_current_hdemo_sk#153, c_current_addr_sk#154, d_year#158, d_year#160, cd_marital_status#162, cd_demo_sk#163, cd_marital_status#164] +(151) Project [codegen id : 51] +Output [14]: [ss_item_sk#121, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_hdemo_sk#154, c_current_addr_sk#155, d_year#159, d_year#161] +Input [18]: [ss_item_sk#121, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, 
s_store_name#150, s_zip#151, c_current_cdemo_sk#153, c_current_hdemo_sk#154, c_current_addr_sk#155, d_year#159, d_year#161, cd_marital_status#163, cd_demo_sk#164, cd_marital_status#165] -(151) ReusedExchange [Reuses operator id: 71] -Output [1]: [p_promo_sk#165] +(152) ReusedExchange [Reuses operator id: 71] +Output [1]: [p_promo_sk#166] -(152) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [ss_promo_sk#126] -Right keys [1]: [p_promo_sk#165] +(153) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [ss_promo_sk#127] +Right keys [1]: [p_promo_sk#166] Join condition: None -(153) Project [codegen id : 50] -Output [13]: [ss_item_sk#120, ss_hdemo_sk#123, ss_addr_sk#124, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_hdemo_sk#153, c_current_addr_sk#154, d_year#158, d_year#160] -Input [15]: [ss_item_sk#120, ss_hdemo_sk#123, ss_addr_sk#124, ss_promo_sk#126, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_hdemo_sk#153, c_current_addr_sk#154, d_year#158, d_year#160, p_promo_sk#165] +(154) Project [codegen id : 51] +Output [13]: [ss_item_sk#121, ss_hdemo_sk#124, ss_addr_sk#125, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_hdemo_sk#154, c_current_addr_sk#155, d_year#159, d_year#161] +Input [15]: [ss_item_sk#121, ss_hdemo_sk#124, ss_addr_sk#125, ss_promo_sk#127, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_hdemo_sk#154, c_current_addr_sk#155, d_year#159, d_year#161, p_promo_sk#166] -(154) ReusedExchange [Reuses operator id: 77] -Output [2]: [hd_demo_sk#166, hd_income_band_sk#167] +(155) ReusedExchange [Reuses operator id: 77] +Output [2]: [hd_demo_sk#167, hd_income_band_sk#168] -(155) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [ss_hdemo_sk#123] -Right keys [1]: [hd_demo_sk#166] +(156) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [ss_hdemo_sk#124] +Right keys [1]: [hd_demo_sk#167] Join condition: None -(156) Project [codegen id : 50] -Output [13]: [ss_item_sk#120, ss_addr_sk#124, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_hdemo_sk#153, c_current_addr_sk#154, d_year#158, d_year#160, hd_income_band_sk#167] -Input [15]: [ss_item_sk#120, ss_hdemo_sk#123, ss_addr_sk#124, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_hdemo_sk#153, c_current_addr_sk#154, d_year#158, d_year#160, hd_demo_sk#166, hd_income_band_sk#167] +(157) Project [codegen id : 51] +Output [13]: [ss_item_sk#121, ss_addr_sk#125, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_hdemo_sk#154, c_current_addr_sk#155, d_year#159, d_year#161, hd_income_band_sk#168] +Input [15]: [ss_item_sk#121, ss_hdemo_sk#124, ss_addr_sk#125, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_hdemo_sk#154, c_current_addr_sk#155, d_year#159, d_year#161, hd_demo_sk#167, hd_income_band_sk#168] -(157) ReusedExchange [Reuses operator id: 77] -Output [2]: [hd_demo_sk#168, hd_income_band_sk#169] +(158) ReusedExchange [Reuses operator id: 77] +Output [2]: [hd_demo_sk#169, hd_income_band_sk#170] -(158) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [c_current_hdemo_sk#153] -Right keys [1]: [hd_demo_sk#168] +(159) BroadcastHashJoin 
[codegen id : 51] +Left keys [1]: [c_current_hdemo_sk#154] +Right keys [1]: [hd_demo_sk#169] Join condition: None -(159) Project [codegen id : 50] -Output [13]: [ss_item_sk#120, ss_addr_sk#124, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_addr_sk#154, d_year#158, d_year#160, hd_income_band_sk#167, hd_income_band_sk#169] -Input [15]: [ss_item_sk#120, ss_addr_sk#124, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_hdemo_sk#153, c_current_addr_sk#154, d_year#158, d_year#160, hd_income_band_sk#167, hd_demo_sk#168, hd_income_band_sk#169] +(160) Project [codegen id : 51] +Output [13]: [ss_item_sk#121, ss_addr_sk#125, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_addr_sk#155, d_year#159, d_year#161, hd_income_band_sk#168, hd_income_band_sk#170] +Input [15]: [ss_item_sk#121, ss_addr_sk#125, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_hdemo_sk#154, c_current_addr_sk#155, d_year#159, d_year#161, hd_income_band_sk#168, hd_demo_sk#169, hd_income_band_sk#170] -(160) ReusedExchange [Reuses operator id: 86] -Output [5]: [ca_address_sk#170, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174] +(161) ReusedExchange [Reuses operator id: 86] +Output [5]: [ca_address_sk#171, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175] -(161) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [ss_addr_sk#124] -Right keys [1]: [ca_address_sk#170] +(162) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [ss_addr_sk#125] +Right keys [1]: [ca_address_sk#171] Join condition: None -(162) Project [codegen id : 50] -Output [16]: [ss_item_sk#120, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_addr_sk#154, d_year#158, d_year#160, hd_income_band_sk#167, hd_income_band_sk#169, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174] -Input [18]: [ss_item_sk#120, ss_addr_sk#124, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_addr_sk#154, d_year#158, d_year#160, hd_income_band_sk#167, hd_income_band_sk#169, ca_address_sk#170, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174] +(163) Project [codegen id : 51] +Output [16]: [ss_item_sk#121, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_addr_sk#155, d_year#159, d_year#161, hd_income_band_sk#168, hd_income_band_sk#170, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175] +Input [18]: [ss_item_sk#121, ss_addr_sk#125, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_addr_sk#155, d_year#159, d_year#161, hd_income_band_sk#168, hd_income_band_sk#170, ca_address_sk#171, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175] -(163) ReusedExchange [Reuses operator id: 86] -Output [5]: [ca_address_sk#175, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179] +(164) ReusedExchange [Reuses operator id: 86] +Output [5]: [ca_address_sk#176, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180] -(164) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [c_current_addr_sk#154] -Right keys [1]: [ca_address_sk#175] +(165) BroadcastHashJoin [codegen id : 51] +Left keys [1]: 
[c_current_addr_sk#155] +Right keys [1]: [ca_address_sk#176] Join condition: None -(165) Project [codegen id : 50] -Output [19]: [ss_item_sk#120, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, d_year#158, d_year#160, hd_income_band_sk#167, hd_income_band_sk#169, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179] -Input [21]: [ss_item_sk#120, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, c_current_addr_sk#154, d_year#158, d_year#160, hd_income_band_sk#167, hd_income_band_sk#169, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_address_sk#175, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179] +(166) Project [codegen id : 51] +Output [19]: [ss_item_sk#121, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, d_year#159, d_year#161, hd_income_band_sk#168, hd_income_band_sk#170, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180] +Input [21]: [ss_item_sk#121, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, c_current_addr_sk#155, d_year#159, d_year#161, hd_income_band_sk#168, hd_income_band_sk#170, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_address_sk#176, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180] -(166) ReusedExchange [Reuses operator id: 95] -Output [1]: [ib_income_band_sk#180] +(167) ReusedExchange [Reuses operator id: 95] +Output [1]: [ib_income_band_sk#181] -(167) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [hd_income_band_sk#167] -Right keys [1]: [ib_income_band_sk#180] +(168) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [hd_income_band_sk#168] +Right keys [1]: [ib_income_band_sk#181] Join condition: None -(168) Project [codegen id : 50] -Output [18]: [ss_item_sk#120, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, d_year#158, d_year#160, hd_income_band_sk#169, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179] -Input [20]: [ss_item_sk#120, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, d_year#158, d_year#160, hd_income_band_sk#167, hd_income_band_sk#169, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179, ib_income_band_sk#180] +(169) Project [codegen id : 51] +Output [18]: [ss_item_sk#121, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, d_year#159, d_year#161, hd_income_band_sk#170, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180] +Input [20]: [ss_item_sk#121, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, d_year#159, d_year#161, hd_income_band_sk#168, hd_income_band_sk#170, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180, ib_income_band_sk#181] -(169) ReusedExchange [Reuses operator id: 95] -Output [1]: [ib_income_band_sk#181] +(170) ReusedExchange [Reuses operator id: 
95] +Output [1]: [ib_income_band_sk#182] -(170) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [hd_income_band_sk#169] -Right keys [1]: [ib_income_band_sk#181] +(171) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [hd_income_band_sk#170] +Right keys [1]: [ib_income_band_sk#182] Join condition: None -(171) Project [codegen id : 50] -Output [17]: [ss_item_sk#120, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, d_year#158, d_year#160, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179] -Input [19]: [ss_item_sk#120, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, d_year#158, d_year#160, hd_income_band_sk#169, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179, ib_income_band_sk#181] +(172) Project [codegen id : 51] +Output [17]: [ss_item_sk#121, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, d_year#159, d_year#161, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180] +Input [19]: [ss_item_sk#121, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, d_year#159, d_year#161, hd_income_band_sk#170, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180, ib_income_band_sk#182] -(172) ReusedExchange [Reuses operator id: 105] -Output [2]: [i_item_sk#182, i_product_name#183] +(173) ReusedExchange [Reuses operator id: 105] +Output [2]: [i_item_sk#183, i_product_name#184] -(173) BroadcastHashJoin [codegen id : 50] -Left keys [1]: [ss_item_sk#120] -Right keys [1]: [i_item_sk#182] +(174) BroadcastHashJoin [codegen id : 51] +Left keys [1]: [ss_item_sk#121] +Right keys [1]: [i_item_sk#183] Join condition: None -(174) Project [codegen id : 50] -Output [18]: [ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, d_year#158, d_year#160, s_store_name#149, s_zip#150, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179, i_item_sk#182, i_product_name#183] -Input [19]: [ss_item_sk#120, ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, s_store_name#149, s_zip#150, d_year#158, d_year#160, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179, i_item_sk#182, i_product_name#183] - -(175) HashAggregate [codegen id : 50] -Input [18]: [ss_wholesale_cost#128, ss_list_price#129, ss_coupon_amt#130, d_year#147, d_year#158, d_year#160, s_store_name#149, s_zip#150, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179, i_item_sk#182, i_product_name#183] -Keys [15]: [i_product_name#183, i_item_sk#182, s_store_name#149, s_zip#150, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179, d_year#147, d_year#158, d_year#160] -Functions [4]: [partial_count(1), partial_sum(UnscaledValue(ss_wholesale_cost#128)), partial_sum(UnscaledValue(ss_list_price#129)), partial_sum(UnscaledValue(ss_coupon_amt#130))] -Aggregate Attributes [4]: [count#91, sum#184, 
sum#185, sum#186] -Results [19]: [i_product_name#183, i_item_sk#182, s_store_name#149, s_zip#150, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179, d_year#147, d_year#158, d_year#160, count#95, sum#187, sum#188, sum#189] - -(176) HashAggregate [codegen id : 50] -Input [19]: [i_product_name#183, i_item_sk#182, s_store_name#149, s_zip#150, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179, d_year#147, d_year#158, d_year#160, count#95, sum#187, sum#188, sum#189] -Keys [15]: [i_product_name#183, i_item_sk#182, s_store_name#149, s_zip#150, ca_street_number#171, ca_street_name#172, ca_city#173, ca_zip#174, ca_street_number#176, ca_street_name#177, ca_city#178, ca_zip#179, d_year#147, d_year#158, d_year#160] -Functions [4]: [count(1), sum(UnscaledValue(ss_wholesale_cost#128)), sum(UnscaledValue(ss_list_price#129)), sum(UnscaledValue(ss_coupon_amt#130))] -Aggregate Attributes [4]: [count(1)#99, sum(UnscaledValue(ss_wholesale_cost#128))#100, sum(UnscaledValue(ss_list_price#129))#101, sum(UnscaledValue(ss_coupon_amt#130))#102] -Results [8]: [i_item_sk#182 AS item_sk#190, s_store_name#149 AS store_name#191, s_zip#150 AS store_zip#192, d_year#147 AS syear#193, count(1)#99 AS cnt#194, MakeDecimal(sum(UnscaledValue(ss_wholesale_cost#128))#100,17,2) AS s1#195, MakeDecimal(sum(UnscaledValue(ss_list_price#129))#101,17,2) AS s2#196, MakeDecimal(sum(UnscaledValue(ss_coupon_amt#130))#102,17,2) AS s3#197] - -(177) Sort [codegen id : 50] -Input [8]: [item_sk#190, store_name#191, store_zip#192, syear#193, cnt#194, s1#195, s2#196, s3#197] -Arguments: [item_sk#190 ASC NULLS FIRST, store_name#191 ASC NULLS FIRST, store_zip#192 ASC NULLS FIRST], false, 0 - -(178) SortMergeJoin [codegen id : 51] +(175) Project [codegen id : 51] +Output [18]: [ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, d_year#159, d_year#161, s_store_name#150, s_zip#151, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180, i_item_sk#183, i_product_name#184] +Input [19]: [ss_item_sk#121, ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, s_store_name#150, s_zip#151, d_year#159, d_year#161, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180, i_item_sk#183, i_product_name#184] + +(176) HashAggregate [codegen id : 51] +Input [18]: [ss_wholesale_cost#129, ss_list_price#130, ss_coupon_amt#131, d_year#148, d_year#159, d_year#161, s_store_name#150, s_zip#151, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180, i_item_sk#183, i_product_name#184] +Keys [15]: [i_product_name#184, i_item_sk#183, s_store_name#150, s_zip#151, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180, d_year#148, d_year#159, d_year#161] +Functions [4]: [partial_count(1), partial_sum(UnscaledValue(ss_wholesale_cost#129)), partial_sum(UnscaledValue(ss_list_price#130)), partial_sum(UnscaledValue(ss_coupon_amt#131))] +Aggregate Attributes [4]: [count#91, sum#185, sum#186, sum#187] +Results [19]: [i_product_name#184, i_item_sk#183, s_store_name#150, s_zip#151, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, 
ca_street_name#178, ca_city#179, ca_zip#180, d_year#148, d_year#159, d_year#161, count#95, sum#188, sum#189, sum#190] + +(177) HashAggregate [codegen id : 51] +Input [19]: [i_product_name#184, i_item_sk#183, s_store_name#150, s_zip#151, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180, d_year#148, d_year#159, d_year#161, count#95, sum#188, sum#189, sum#190] +Keys [15]: [i_product_name#184, i_item_sk#183, s_store_name#150, s_zip#151, ca_street_number#172, ca_street_name#173, ca_city#174, ca_zip#175, ca_street_number#177, ca_street_name#178, ca_city#179, ca_zip#180, d_year#148, d_year#159, d_year#161] +Functions [4]: [count(1), sum(UnscaledValue(ss_wholesale_cost#129)), sum(UnscaledValue(ss_list_price#130)), sum(UnscaledValue(ss_coupon_amt#131))] +Aggregate Attributes [4]: [count(1)#99, sum(UnscaledValue(ss_wholesale_cost#129))#100, sum(UnscaledValue(ss_list_price#130))#101, sum(UnscaledValue(ss_coupon_amt#131))#102] +Results [8]: [i_item_sk#183 AS item_sk#191, s_store_name#150 AS store_name#192, s_zip#151 AS store_zip#193, d_year#148 AS syear#194, count(1)#99 AS cnt#195, MakeDecimal(sum(UnscaledValue(ss_wholesale_cost#129))#100,17,2) AS s1#196, MakeDecimal(sum(UnscaledValue(ss_list_price#130))#101,17,2) AS s2#197, MakeDecimal(sum(UnscaledValue(ss_coupon_amt#131))#102,17,2) AS s3#198] + +(178) Exchange +Input [8]: [item_sk#191, store_name#192, store_zip#193, syear#194, cnt#195, s1#196, s2#197, s3#198] +Arguments: hashpartitioning(item_sk#191, store_name#192, store_zip#193, 5), ENSURE_REQUIREMENTS, [id=#199] + +(179) Sort [codegen id : 52] +Input [8]: [item_sk#191, store_name#192, store_zip#193, syear#194, cnt#195, s1#196, s2#197, s3#198] +Arguments: [item_sk#191 ASC NULLS FIRST, store_name#192 ASC NULLS FIRST, store_zip#193 ASC NULLS FIRST], false, 0 + +(180) SortMergeJoin [codegen id : 53] Left keys [3]: [item_sk#104, store_name#105, store_zip#106] -Right keys [3]: [item_sk#190, store_name#191, store_zip#192] -Join condition: (cnt#194 <= cnt#116) +Right keys [3]: [item_sk#191, store_name#192, store_zip#193] +Join condition: (cnt#195 <= cnt#116) -(179) Project [codegen id : 51] -Output [21]: [product_name#103, store_name#105, store_zip#106, b_street_number#107, b_streen_name#108, b_city#109, b_zip#110, c_street_number#111, c_street_name#112, c_city#113, c_zip#114, syear#115, cnt#116, s1#117, s2#118, s3#119, s1#195, s2#196, s3#197, syear#193, cnt#194] -Input [25]: [product_name#103, item_sk#104, store_name#105, store_zip#106, b_street_number#107, b_streen_name#108, b_city#109, b_zip#110, c_street_number#111, c_street_name#112, c_city#113, c_zip#114, syear#115, cnt#116, s1#117, s2#118, s3#119, item_sk#190, store_name#191, store_zip#192, syear#193, cnt#194, s1#195, s2#196, s3#197] +(181) Project [codegen id : 53] +Output [21]: [product_name#103, store_name#105, store_zip#106, b_street_number#107, b_streen_name#108, b_city#109, b_zip#110, c_street_number#111, c_street_name#112, c_city#113, c_zip#114, syear#115, cnt#116, s1#117, s2#118, s3#119, s1#196, s2#197, s3#198, syear#194, cnt#195] +Input [25]: [product_name#103, item_sk#104, store_name#105, store_zip#106, b_street_number#107, b_streen_name#108, b_city#109, b_zip#110, c_street_number#111, c_street_name#112, c_city#113, c_zip#114, syear#115, cnt#116, s1#117, s2#118, s3#119, item_sk#191, store_name#192, store_zip#193, syear#194, cnt#195, s1#196, s2#197, s3#198] -(180) Exchange -Input [21]: [product_name#103, store_name#105, store_zip#106, b_street_number#107, 
b_streen_name#108, b_city#109, b_zip#110, c_street_number#111, c_street_name#112, c_city#113, c_zip#114, syear#115, cnt#116, s1#117, s2#118, s3#119, s1#195, s2#196, s3#197, syear#193, cnt#194] -Arguments: rangepartitioning(product_name#103 ASC NULLS FIRST, store_name#105 ASC NULLS FIRST, cnt#194 ASC NULLS FIRST, s1#117 ASC NULLS FIRST, s1#195 ASC NULLS FIRST, 5), ENSURE_REQUIREMENTS, [id=#198] +(182) Exchange +Input [21]: [product_name#103, store_name#105, store_zip#106, b_street_number#107, b_streen_name#108, b_city#109, b_zip#110, c_street_number#111, c_street_name#112, c_city#113, c_zip#114, syear#115, cnt#116, s1#117, s2#118, s3#119, s1#196, s2#197, s3#198, syear#194, cnt#195] +Arguments: rangepartitioning(product_name#103 ASC NULLS FIRST, store_name#105 ASC NULLS FIRST, cnt#195 ASC NULLS FIRST, s1#117 ASC NULLS FIRST, s1#196 ASC NULLS FIRST, 5), ENSURE_REQUIREMENTS, [id=#200] -(181) Sort [codegen id : 52] -Input [21]: [product_name#103, store_name#105, store_zip#106, b_street_number#107, b_streen_name#108, b_city#109, b_zip#110, c_street_number#111, c_street_name#112, c_city#113, c_zip#114, syear#115, cnt#116, s1#117, s2#118, s3#119, s1#195, s2#196, s3#197, syear#193, cnt#194] -Arguments: [product_name#103 ASC NULLS FIRST, store_name#105 ASC NULLS FIRST, cnt#194 ASC NULLS FIRST, s1#117 ASC NULLS FIRST, s1#195 ASC NULLS FIRST], true, 0 +(183) Sort [codegen id : 54] +Input [21]: [product_name#103, store_name#105, store_zip#106, b_street_number#107, b_streen_name#108, b_city#109, b_zip#110, c_street_number#111, c_street_name#112, c_city#113, c_zip#114, syear#115, cnt#116, s1#117, s2#118, s3#119, s1#196, s2#197, s3#198, syear#194, cnt#195] +Arguments: [product_name#103 ASC NULLS FIRST, store_name#105 ASC NULLS FIRST, cnt#195 ASC NULLS FIRST, s1#117 ASC NULLS FIRST, s1#196 ASC NULLS FIRST], true, 0 ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = ss_sold_date_sk#12 IN dynamicpruning#13 -BroadcastExchange (185) -+- * Filter (184) - +- * ColumnarToRow (183) - +- Scan parquet default.date_dim (182) +BroadcastExchange (187) ++- * Filter (186) + +- * ColumnarToRow (185) + +- Scan parquet default.date_dim (184) -(182) Scan parquet default.date_dim +(184) Scan parquet default.date_dim Output [2]: [d_date_sk#42, d_year#43] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), EqualTo(d_year,1999), IsNotNull(d_date_sk)] ReadSchema: struct -(183) ColumnarToRow [codegen id : 1] +(185) ColumnarToRow [codegen id : 1] Input [2]: [d_date_sk#42, d_year#43] -(184) Filter [codegen id : 1] +(186) Filter [codegen id : 1] Input [2]: [d_date_sk#42, d_year#43] Condition : ((isnotnull(d_year#43) AND (d_year#43 = 1999)) AND isnotnull(d_date_sk#42)) -(185) BroadcastExchange +(187) BroadcastExchange Input [2]: [d_date_sk#42, d_year#43] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#199] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#201] -Subquery:2 Hosting operator id = 111 Hosting Expression = ss_sold_date_sk#131 IN dynamicpruning#132 -BroadcastExchange (189) -+- * Filter (188) - +- * ColumnarToRow (187) - +- Scan parquet default.date_dim (186) +Subquery:2 Hosting operator id = 112 Hosting Expression = ss_sold_date_sk#132 IN dynamicpruning#133 +BroadcastExchange (191) ++- * Filter (190) + +- * ColumnarToRow (189) + +- Scan parquet default.date_dim (188) -(186) Scan parquet default.date_dim -Output [2]: [d_date_sk#146, d_year#147] 
+(188) Scan parquet default.date_dim +Output [2]: [d_date_sk#147, d_year#148] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), EqualTo(d_year,2000), IsNotNull(d_date_sk)] ReadSchema: struct -(187) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#146, d_year#147] +(189) ColumnarToRow [codegen id : 1] +Input [2]: [d_date_sk#147, d_year#148] -(188) Filter [codegen id : 1] -Input [2]: [d_date_sk#146, d_year#147] -Condition : ((isnotnull(d_year#147) AND (d_year#147 = 2000)) AND isnotnull(d_date_sk#146)) +(190) Filter [codegen id : 1] +Input [2]: [d_date_sk#147, d_year#148] +Condition : ((isnotnull(d_year#148) AND (d_year#148 = 2000)) AND isnotnull(d_date_sk#147)) -(189) BroadcastExchange -Input [2]: [d_date_sk#146, d_year#147] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#200] +(191) BroadcastExchange +Input [2]: [d_date_sk#147, d_year#148] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#202] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/simplified.txt index 716aaa2663630..859101af5baf2 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/simplified.txt @@ -1,283 +1,289 @@ -WholeStageCodegen (52) +WholeStageCodegen (54) Sort [product_name,store_name,cnt,s1,s1] InputAdapter Exchange [product_name,store_name,cnt,s1,s1] #1 - WholeStageCodegen (51) + WholeStageCodegen (53) Project [product_name,store_name,store_zip,b_street_number,b_streen_name,b_city,b_zip,c_street_number,c_street_name,c_city,c_zip,syear,cnt,s1,s2,s3,s1,s2,s3,syear,cnt] SortMergeJoin [item_sk,store_name,store_zip,item_sk,store_name,store_zip,cnt,cnt] InputAdapter - WholeStageCodegen (25) + WholeStageCodegen (26) Sort [item_sk,store_name,store_zip] - HashAggregate [i_product_name,i_item_sk,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,d_year,d_year,d_year,count,sum,sum,sum] [count(1),sum(UnscaledValue(ss_wholesale_cost)),sum(UnscaledValue(ss_list_price)),sum(UnscaledValue(ss_coupon_amt)),product_name,item_sk,store_name,store_zip,b_street_number,b_streen_name,b_city,b_zip,c_street_number,c_street_name,c_city,c_zip,syear,cnt,s1,s2,s3,count,sum,sum,sum] - HashAggregate [i_product_name,i_item_sk,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,d_year,d_year,d_year,ss_wholesale_cost,ss_list_price,ss_coupon_amt] [count,sum,sum,sum,count,sum,sum,sum] - Project [ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,d_year,d_year,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,i_item_sk,i_product_name] - BroadcastHashJoin [ss_item_sk,i_item_sk] - Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] - BroadcastHashJoin [hd_income_band_sk,ib_income_band_sk] - Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] - 
BroadcastHashJoin [hd_income_band_sk,ib_income_band_sk] - Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,hd_income_band_sk,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] - BroadcastHashJoin [c_current_addr_sk,ca_address_sk] - Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_addr_sk,d_year,d_year,hd_income_band_sk,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip] - BroadcastHashJoin [ss_addr_sk,ca_address_sk] - Project [ss_item_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_addr_sk,d_year,d_year,hd_income_band_sk,hd_income_band_sk] - BroadcastHashJoin [c_current_hdemo_sk,hd_demo_sk] - Project [ss_item_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year,hd_income_band_sk] - BroadcastHashJoin [ss_hdemo_sk,hd_demo_sk] - Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] - BroadcastHashJoin [ss_promo_sk,p_promo_sk] - Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] - BroadcastHashJoin [c_current_cdemo_sk,cd_demo_sk,cd_marital_status,cd_marital_status] - Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year,cd_marital_status] - BroadcastHashJoin [ss_cdemo_sk,cd_demo_sk] - Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] - BroadcastHashJoin [c_first_shipto_date_sk,d_date_sk] - Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,d_year] - BroadcastHashJoin [c_first_sales_date_sk,d_date_sk] - Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,c_first_sales_date_sk] - BroadcastHashJoin [ss_customer_sk,c_customer_sk] - Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip] - BroadcastHashJoin [ss_store_sk,s_store_sk] - Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] - SortMergeJoin [ss_item_sk,cs_item_sk] - InputAdapter - WholeStageCodegen (3) - Sort [ss_item_sk] + InputAdapter + Exchange [item_sk,store_name,store_zip] #2 + WholeStageCodegen (25) + HashAggregate 
[i_product_name,i_item_sk,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,d_year,d_year,d_year,count,sum,sum,sum] [count(1),sum(UnscaledValue(ss_wholesale_cost)),sum(UnscaledValue(ss_list_price)),sum(UnscaledValue(ss_coupon_amt)),product_name,item_sk,store_name,store_zip,b_street_number,b_streen_name,b_city,b_zip,c_street_number,c_street_name,c_city,c_zip,syear,cnt,s1,s2,s3,count,sum,sum,sum] + HashAggregate [i_product_name,i_item_sk,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,d_year,d_year,d_year,ss_wholesale_cost,ss_list_price,ss_coupon_amt] [count,sum,sum,sum,count,sum,sum,sum] + Project [ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,d_year,d_year,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,i_item_sk,i_product_name] + BroadcastHashJoin [ss_item_sk,i_item_sk] + Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] + BroadcastHashJoin [hd_income_band_sk,ib_income_band_sk] + Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] + BroadcastHashJoin [hd_income_band_sk,ib_income_band_sk] + Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,hd_income_band_sk,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] + BroadcastHashJoin [c_current_addr_sk,ca_address_sk] + Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_addr_sk,d_year,d_year,hd_income_band_sk,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip] + BroadcastHashJoin [ss_addr_sk,ca_address_sk] + Project [ss_item_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_addr_sk,d_year,d_year,hd_income_band_sk,hd_income_band_sk] + BroadcastHashJoin [c_current_hdemo_sk,hd_demo_sk] + Project [ss_item_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year,hd_income_band_sk] + BroadcastHashJoin [ss_hdemo_sk,hd_demo_sk] + Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] + BroadcastHashJoin [ss_promo_sk,p_promo_sk] + Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] + BroadcastHashJoin [c_current_cdemo_sk,cd_demo_sk,cd_marital_status,cd_marital_status] + Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year,cd_marital_status] + BroadcastHashJoin [ss_cdemo_sk,cd_demo_sk] + Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] + BroadcastHashJoin [c_first_shipto_date_sk,d_date_sk] + Project 
[ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,d_year] + BroadcastHashJoin [c_first_sales_date_sk,d_date_sk] + Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,c_first_sales_date_sk] + BroadcastHashJoin [ss_customer_sk,c_customer_sk] + Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip] + BroadcastHashJoin [ss_store_sk,s_store_sk] + Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] + SortMergeJoin [ss_item_sk,cs_item_sk] InputAdapter - Exchange [ss_item_sk] #2 - WholeStageCodegen (2) - Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] - BroadcastHashJoin [ss_item_sk,ss_ticket_number,sr_item_sk,sr_ticket_number] - InputAdapter - BroadcastExchange #3 - WholeStageCodegen (1) - Filter [ss_item_sk,ss_ticket_number,ss_store_sk,ss_customer_sk,ss_cdemo_sk,ss_promo_sk,ss_hdemo_sk,ss_addr_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_ticket_number,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #4 - WholeStageCodegen (1) - Filter [d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] - Project [sr_item_sk,sr_ticket_number] - Filter [sr_item_sk,sr_ticket_number] - ColumnarToRow - InputAdapter - Scan parquet default.store_returns [sr_item_sk,sr_ticket_number,sr_returned_date_sk] - InputAdapter - WholeStageCodegen (9) - Sort [cs_item_sk] - Project [cs_item_sk] - Filter [sale,refund] - HashAggregate [cs_item_sk,sum,sum,isEmpty] [sum(UnscaledValue(cs_ext_list_price)),sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash as decimal(8,2))) + promote_precision(cast(cr_reversed_charge as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit as decimal(9,2)))), DecimalType(9,2), true)),sale,refund,sum,sum,isEmpty] + WholeStageCodegen (3) + Sort [ss_item_sk] InputAdapter - Exchange [cs_item_sk] #5 - WholeStageCodegen (8) - HashAggregate [cs_item_sk,cs_ext_list_price,cr_refunded_cash,cr_reversed_charge,cr_store_credit] [sum,sum,isEmpty,sum,sum,isEmpty] - Project [cs_item_sk,cs_ext_list_price,cr_refunded_cash,cr_reversed_charge,cr_store_credit] - SortMergeJoin [cs_item_sk,cs_order_number,cr_item_sk,cr_order_number] - InputAdapter - WholeStageCodegen (5) - Sort [cs_item_sk,cs_order_number] + Exchange [ss_item_sk] #3 + WholeStageCodegen (2) + Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] + BroadcastHashJoin 
[ss_item_sk,ss_ticket_number,sr_item_sk,sr_ticket_number] + InputAdapter + BroadcastExchange #4 + WholeStageCodegen (1) + Filter [ss_item_sk,ss_ticket_number,ss_store_sk,ss_customer_sk,ss_cdemo_sk,ss_promo_sk,ss_hdemo_sk,ss_addr_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_ticket_number,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (1) + Filter [d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] + Project [sr_item_sk,sr_ticket_number] + Filter [sr_item_sk,sr_ticket_number] + ColumnarToRow + InputAdapter + Scan parquet default.store_returns [sr_item_sk,sr_ticket_number,sr_returned_date_sk] + InputAdapter + WholeStageCodegen (9) + Sort [cs_item_sk] + Project [cs_item_sk] + Filter [sale,refund] + HashAggregate [cs_item_sk,sum,sum,isEmpty] [sum(UnscaledValue(cs_ext_list_price)),sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash as decimal(8,2))) + promote_precision(cast(cr_reversed_charge as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit as decimal(9,2)))), DecimalType(9,2))),sale,refund,sum,sum,isEmpty] + InputAdapter + Exchange [cs_item_sk] #6 + WholeStageCodegen (8) + HashAggregate [cs_item_sk,cs_ext_list_price,cr_refunded_cash,cr_reversed_charge,cr_store_credit] [sum,sum,isEmpty,sum,sum,isEmpty] + Project [cs_item_sk,cs_ext_list_price,cr_refunded_cash,cr_reversed_charge,cr_store_credit] + SortMergeJoin [cs_item_sk,cs_order_number,cr_item_sk,cr_order_number] InputAdapter - Exchange [cs_item_sk,cs_order_number] #6 - WholeStageCodegen (4) - Project [cs_item_sk,cs_order_number,cs_ext_list_price] - Filter [cs_item_sk,cs_order_number] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_item_sk,cs_order_number,cs_ext_list_price,cs_sold_date_sk] - InputAdapter - WholeStageCodegen (7) - Sort [cr_item_sk,cr_order_number] + WholeStageCodegen (5) + Sort [cs_item_sk,cs_order_number] + InputAdapter + Exchange [cs_item_sk,cs_order_number] #7 + WholeStageCodegen (4) + Project [cs_item_sk,cs_order_number,cs_ext_list_price] + Filter [cs_item_sk,cs_order_number] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_sales [cs_item_sk,cs_order_number,cs_ext_list_price,cs_sold_date_sk] InputAdapter - Exchange [cr_item_sk,cr_order_number] #7 - WholeStageCodegen (6) - Project [cr_item_sk,cr_order_number,cr_refunded_cash,cr_reversed_charge,cr_store_credit] - Filter [cr_item_sk,cr_order_number] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_returns [cr_item_sk,cr_order_number,cr_refunded_cash,cr_reversed_charge,cr_store_credit,cr_returned_date_sk] - InputAdapter - ReusedExchange [d_date_sk,d_year] #4 - InputAdapter - BroadcastExchange #8 - WholeStageCodegen (11) - Filter [s_store_sk,s_store_name,s_zip] - ColumnarToRow + WholeStageCodegen (7) + Sort [cr_item_sk,cr_order_number] + InputAdapter + Exchange [cr_item_sk,cr_order_number] #8 + WholeStageCodegen (6) + Project [cr_item_sk,cr_order_number,cr_refunded_cash,cr_reversed_charge,cr_store_credit] + Filter [cr_item_sk,cr_order_number] + ColumnarToRow + InputAdapter + Scan parquet default.catalog_returns [cr_item_sk,cr_order_number,cr_refunded_cash,cr_reversed_charge,cr_store_credit,cr_returned_date_sk] InputAdapter - Scan parquet default.store [s_store_sk,s_store_name,s_zip] - 
InputAdapter - BroadcastExchange #9 - WholeStageCodegen (12) - Filter [c_customer_sk,c_first_sales_date_sk,c_first_shipto_date_sk,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk] - ColumnarToRow + ReusedExchange [d_date_sk,d_year] #5 InputAdapter - Scan parquet default.customer [c_customer_sk,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,c_first_sales_date_sk] - InputAdapter - BroadcastExchange #10 - WholeStageCodegen (13) - Filter [d_date_sk] - ColumnarToRow + BroadcastExchange #9 + WholeStageCodegen (11) + Filter [s_store_sk,s_store_name,s_zip] + ColumnarToRow + InputAdapter + Scan parquet default.store [s_store_sk,s_store_name,s_zip] InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] - InputAdapter - ReusedExchange [d_date_sk,d_year] #10 - InputAdapter - BroadcastExchange #11 - WholeStageCodegen (15) - Filter [cd_demo_sk,cd_marital_status] - ColumnarToRow + BroadcastExchange #10 + WholeStageCodegen (12) + Filter [c_customer_sk,c_first_sales_date_sk,c_first_shipto_date_sk,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk] + ColumnarToRow + InputAdapter + Scan parquet default.customer [c_customer_sk,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,c_first_sales_date_sk] + InputAdapter + BroadcastExchange #11 + WholeStageCodegen (13) + Filter [d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] InputAdapter - Scan parquet default.customer_demographics [cd_demo_sk,cd_marital_status] - InputAdapter - ReusedExchange [cd_demo_sk,cd_marital_status] #11 - InputAdapter - BroadcastExchange #12 - WholeStageCodegen (17) - Filter [p_promo_sk] - ColumnarToRow + ReusedExchange [d_date_sk,d_year] #11 + InputAdapter + BroadcastExchange #12 + WholeStageCodegen (15) + Filter [cd_demo_sk,cd_marital_status] + ColumnarToRow + InputAdapter + Scan parquet default.customer_demographics [cd_demo_sk,cd_marital_status] InputAdapter - Scan parquet default.promotion [p_promo_sk] - InputAdapter - BroadcastExchange #13 - WholeStageCodegen (18) - Filter [hd_demo_sk,hd_income_band_sk] - ColumnarToRow + ReusedExchange [cd_demo_sk,cd_marital_status] #12 InputAdapter - Scan parquet default.household_demographics [hd_demo_sk,hd_income_band_sk] - InputAdapter - ReusedExchange [hd_demo_sk,hd_income_band_sk] #13 - InputAdapter - BroadcastExchange #14 - WholeStageCodegen (20) - Filter [ca_address_sk] - ColumnarToRow + BroadcastExchange #13 + WholeStageCodegen (17) + Filter [p_promo_sk] + ColumnarToRow + InputAdapter + Scan parquet default.promotion [p_promo_sk] + InputAdapter + BroadcastExchange #14 + WholeStageCodegen (18) + Filter [hd_demo_sk,hd_income_band_sk] + ColumnarToRow + InputAdapter + Scan parquet default.household_demographics [hd_demo_sk,hd_income_band_sk] InputAdapter - Scan parquet default.customer_address [ca_address_sk,ca_street_number,ca_street_name,ca_city,ca_zip] - InputAdapter - ReusedExchange [ca_address_sk,ca_street_number,ca_street_name,ca_city,ca_zip] #14 - InputAdapter - BroadcastExchange #15 - WholeStageCodegen (22) - Filter [ib_income_band_sk] + ReusedExchange [hd_demo_sk,hd_income_band_sk] #14 + InputAdapter + BroadcastExchange #15 + WholeStageCodegen (20) + Filter [ca_address_sk] + ColumnarToRow + InputAdapter + Scan parquet default.customer_address [ca_address_sk,ca_street_number,ca_street_name,ca_city,ca_zip] + InputAdapter + ReusedExchange [ca_address_sk,ca_street_number,ca_street_name,ca_city,ca_zip] #15 + InputAdapter + BroadcastExchange #16 + WholeStageCodegen (22) 
+ Filter [ib_income_band_sk] + ColumnarToRow + InputAdapter + Scan parquet default.income_band [ib_income_band_sk] + InputAdapter + ReusedExchange [ib_income_band_sk] #16 + InputAdapter + BroadcastExchange #17 + WholeStageCodegen (24) + Project [i_item_sk,i_product_name] + Filter [i_current_price,i_color,i_item_sk] ColumnarToRow InputAdapter - Scan parquet default.income_band [ib_income_band_sk] - InputAdapter - ReusedExchange [ib_income_band_sk] #15 - InputAdapter - BroadcastExchange #16 - WholeStageCodegen (24) - Project [i_item_sk,i_product_name] - Filter [i_current_price,i_color,i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_current_price,i_color,i_product_name] + Scan parquet default.item [i_item_sk,i_current_price,i_color,i_product_name] InputAdapter - WholeStageCodegen (50) + WholeStageCodegen (52) Sort [item_sk,store_name,store_zip] - HashAggregate [i_product_name,i_item_sk,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,d_year,d_year,d_year,count,sum,sum,sum] [count(1),sum(UnscaledValue(ss_wholesale_cost)),sum(UnscaledValue(ss_list_price)),sum(UnscaledValue(ss_coupon_amt)),item_sk,store_name,store_zip,syear,cnt,s1,s2,s3,count,sum,sum,sum] - HashAggregate [i_product_name,i_item_sk,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,d_year,d_year,d_year,ss_wholesale_cost,ss_list_price,ss_coupon_amt] [count,sum,sum,sum,count,sum,sum,sum] - Project [ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,d_year,d_year,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,i_item_sk,i_product_name] - BroadcastHashJoin [ss_item_sk,i_item_sk] - Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] - BroadcastHashJoin [hd_income_band_sk,ib_income_band_sk] - Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] - BroadcastHashJoin [hd_income_band_sk,ib_income_band_sk] - Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,hd_income_band_sk,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] - BroadcastHashJoin [c_current_addr_sk,ca_address_sk] - Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_addr_sk,d_year,d_year,hd_income_band_sk,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip] - BroadcastHashJoin [ss_addr_sk,ca_address_sk] - Project [ss_item_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_addr_sk,d_year,d_year,hd_income_band_sk,hd_income_band_sk] - BroadcastHashJoin [c_current_hdemo_sk,hd_demo_sk] - Project [ss_item_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year,hd_income_band_sk] - BroadcastHashJoin [ss_hdemo_sk,hd_demo_sk] - Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] - BroadcastHashJoin [ss_promo_sk,p_promo_sk] - Project 
[ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] - BroadcastHashJoin [c_current_cdemo_sk,cd_demo_sk,cd_marital_status,cd_marital_status] - Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year,cd_marital_status] - BroadcastHashJoin [ss_cdemo_sk,cd_demo_sk] - Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] - BroadcastHashJoin [c_first_shipto_date_sk,d_date_sk] - Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,d_year] - BroadcastHashJoin [c_first_sales_date_sk,d_date_sk] - Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,c_first_sales_date_sk] - BroadcastHashJoin [ss_customer_sk,c_customer_sk] - Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip] - BroadcastHashJoin [ss_store_sk,s_store_sk] - Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year] - BroadcastHashJoin [ss_sold_date_sk,d_date_sk] - Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] - SortMergeJoin [ss_item_sk,cs_item_sk] - InputAdapter - WholeStageCodegen (28) - Sort [ss_item_sk] + InputAdapter + Exchange [item_sk,store_name,store_zip] #18 + WholeStageCodegen (51) + HashAggregate [i_product_name,i_item_sk,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,d_year,d_year,d_year,count,sum,sum,sum] [count(1),sum(UnscaledValue(ss_wholesale_cost)),sum(UnscaledValue(ss_list_price)),sum(UnscaledValue(ss_coupon_amt)),item_sk,store_name,store_zip,syear,cnt,s1,s2,s3,count,sum,sum,sum] + HashAggregate [i_product_name,i_item_sk,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,d_year,d_year,d_year,ss_wholesale_cost,ss_list_price,ss_coupon_amt] [count,sum,sum,sum,count,sum,sum,sum] + Project [ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,d_year,d_year,s_store_name,s_zip,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip,i_item_sk,i_product_name] + BroadcastHashJoin [ss_item_sk,i_item_sk] + Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] + BroadcastHashJoin [hd_income_band_sk,ib_income_band_sk] + Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] + BroadcastHashJoin [hd_income_band_sk,ib_income_band_sk] + 
Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,d_year,d_year,hd_income_band_sk,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip,ca_street_number,ca_street_name,ca_city,ca_zip] + BroadcastHashJoin [c_current_addr_sk,ca_address_sk] + Project [ss_item_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_addr_sk,d_year,d_year,hd_income_band_sk,hd_income_band_sk,ca_street_number,ca_street_name,ca_city,ca_zip] + BroadcastHashJoin [ss_addr_sk,ca_address_sk] + Project [ss_item_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_addr_sk,d_year,d_year,hd_income_band_sk,hd_income_band_sk] + BroadcastHashJoin [c_current_hdemo_sk,hd_demo_sk] + Project [ss_item_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year,hd_income_band_sk] + BroadcastHashJoin [ss_hdemo_sk,hd_demo_sk] + Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] + BroadcastHashJoin [ss_promo_sk,p_promo_sk] + Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] + BroadcastHashJoin [c_current_cdemo_sk,cd_demo_sk,cd_marital_status,cd_marital_status] + Project [ss_item_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year,cd_marital_status] + BroadcastHashJoin [ss_cdemo_sk,cd_demo_sk] + Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,d_year,d_year] + BroadcastHashJoin [c_first_shipto_date_sk,d_date_sk] + Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,d_year] + BroadcastHashJoin [c_first_sales_date_sk,d_date_sk] + Project [ss_item_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,c_first_sales_date_sk] + BroadcastHashJoin [ss_customer_sk,c_customer_sk] + Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year,s_store_name,s_zip] + BroadcastHashJoin [ss_store_sk,s_store_sk] + Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,d_year] + BroadcastHashJoin [ss_sold_date_sk,d_date_sk] + Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] + SortMergeJoin [ss_item_sk,cs_item_sk] InputAdapter - Exchange [ss_item_sk] #17 - WholeStageCodegen (27) - Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] - BroadcastHashJoin [ss_item_sk,ss_ticket_number,sr_item_sk,sr_ticket_number] - InputAdapter - BroadcastExchange #18 - WholeStageCodegen (26) - 
Filter [ss_item_sk,ss_ticket_number,ss_store_sk,ss_customer_sk,ss_cdemo_sk,ss_promo_sk,ss_hdemo_sk,ss_addr_sk] - ColumnarToRow - InputAdapter - Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_ticket_number,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] - SubqueryBroadcast [d_date_sk] #2 - BroadcastExchange #19 - WholeStageCodegen (1) - Filter [d_year,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_year] - Project [sr_item_sk,sr_ticket_number] - Filter [sr_item_sk,sr_ticket_number] - ColumnarToRow - InputAdapter - Scan parquet default.store_returns [sr_item_sk,sr_ticket_number,sr_returned_date_sk] - InputAdapter - WholeStageCodegen (34) - Sort [cs_item_sk] - Project [cs_item_sk] - Filter [sale,refund] - HashAggregate [cs_item_sk,sum,sum,isEmpty] [sum(UnscaledValue(cs_ext_list_price)),sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash as decimal(8,2))) + promote_precision(cast(cr_reversed_charge as decimal(8,2)))), DecimalType(8,2), true) as decimal(9,2))) + promote_precision(cast(cr_store_credit as decimal(9,2)))), DecimalType(9,2), true)),sale,refund,sum,sum,isEmpty] + WholeStageCodegen (29) + Sort [ss_item_sk] InputAdapter - ReusedExchange [cs_item_sk,sum,sum,isEmpty] #5 - InputAdapter - ReusedExchange [d_date_sk,d_year] #19 - InputAdapter - ReusedExchange [s_store_sk,s_store_name,s_zip] #8 - InputAdapter - ReusedExchange [c_customer_sk,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,c_first_sales_date_sk] #9 - InputAdapter - ReusedExchange [d_date_sk,d_year] #10 - InputAdapter - ReusedExchange [d_date_sk,d_year] #10 - InputAdapter - ReusedExchange [cd_demo_sk,cd_marital_status] #11 - InputAdapter - ReusedExchange [cd_demo_sk,cd_marital_status] #11 - InputAdapter - ReusedExchange [p_promo_sk] #12 - InputAdapter - ReusedExchange [hd_demo_sk,hd_income_band_sk] #13 - InputAdapter - ReusedExchange [hd_demo_sk,hd_income_band_sk] #13 - InputAdapter - ReusedExchange [ca_address_sk,ca_street_number,ca_street_name,ca_city,ca_zip] #14 - InputAdapter - ReusedExchange [ca_address_sk,ca_street_number,ca_street_name,ca_city,ca_zip] #14 - InputAdapter - ReusedExchange [ib_income_band_sk] #15 - InputAdapter - ReusedExchange [ib_income_band_sk] #15 - InputAdapter - ReusedExchange [i_item_sk,i_product_name] #16 + Exchange [ss_item_sk] #19 + WholeStageCodegen (28) + Project [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] + BroadcastHashJoin [ss_item_sk,ss_ticket_number,sr_item_sk,sr_ticket_number] + InputAdapter + BroadcastExchange #20 + WholeStageCodegen (27) + Filter [ss_item_sk,ss_ticket_number,ss_store_sk,ss_customer_sk,ss_cdemo_sk,ss_promo_sk,ss_hdemo_sk,ss_addr_sk] + ColumnarToRow + InputAdapter + Scan parquet default.store_sales [ss_item_sk,ss_customer_sk,ss_cdemo_sk,ss_hdemo_sk,ss_addr_sk,ss_store_sk,ss_promo_sk,ss_ticket_number,ss_wholesale_cost,ss_list_price,ss_coupon_amt,ss_sold_date_sk] + SubqueryBroadcast [d_date_sk] #2 + BroadcastExchange #21 + WholeStageCodegen (1) + Filter [d_year,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_year] + Project [sr_item_sk,sr_ticket_number] + Filter [sr_item_sk,sr_ticket_number] + ColumnarToRow + InputAdapter + Scan parquet default.store_returns [sr_item_sk,sr_ticket_number,sr_returned_date_sk] + InputAdapter + 
WholeStageCodegen (35) + Sort [cs_item_sk] + Project [cs_item_sk] + Filter [sale,refund] + HashAggregate [cs_item_sk,sum,sum,isEmpty] [sum(UnscaledValue(cs_ext_list_price)),sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(cr_refunded_cash as decimal(8,2))) + promote_precision(cast(cr_reversed_charge as decimal(8,2)))), DecimalType(8,2)) as decimal(9,2))) + promote_precision(cast(cr_store_credit as decimal(9,2)))), DecimalType(9,2))),sale,refund,sum,sum,isEmpty] + InputAdapter + ReusedExchange [cs_item_sk,sum,sum,isEmpty] #6 + InputAdapter + ReusedExchange [d_date_sk,d_year] #21 + InputAdapter + ReusedExchange [s_store_sk,s_store_name,s_zip] #9 + InputAdapter + ReusedExchange [c_customer_sk,c_current_cdemo_sk,c_current_hdemo_sk,c_current_addr_sk,c_first_shipto_date_sk,c_first_sales_date_sk] #10 + InputAdapter + ReusedExchange [d_date_sk,d_year] #11 + InputAdapter + ReusedExchange [d_date_sk,d_year] #11 + InputAdapter + ReusedExchange [cd_demo_sk,cd_marital_status] #12 + InputAdapter + ReusedExchange [cd_demo_sk,cd_marital_status] #12 + InputAdapter + ReusedExchange [p_promo_sk] #13 + InputAdapter + ReusedExchange [hd_demo_sk,hd_income_band_sk] #14 + InputAdapter + ReusedExchange [hd_demo_sk,hd_income_band_sk] #14 + InputAdapter + ReusedExchange [ca_address_sk,ca_street_number,ca_street_name,ca_city,ca_zip] #15 + InputAdapter + ReusedExchange [ca_address_sk,ca_street_number,ca_street_name,ca_city,ca_zip] #15 + InputAdapter + ReusedExchange [ib_income_band_sk] #16 + InputAdapter + ReusedExchange [ib_income_band_sk] #16 + InputAdapter + ReusedExchange [i_item_sk,i_product_name] #17 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt index b0ecc08ff8b25..00d9676dc2ec9 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt @@ -167,7 +167,7 @@ Input [12]: [ss_item_sk#1, ss_quantity#3, ss_sales_price#4, d_year#8, d_moy#9, d (22) HashAggregate [codegen id : 7] Input [10]: [ss_quantity#3, ss_sales_price#4, d_year#8, d_moy#9, d_qoy#10, s_store_id#12, i_brand#16, i_class#17, i_category#18, i_product_name#19] Keys [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [partial_sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] +Functions [1]: [partial_sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] Aggregate Attributes [2]: [sum#21, isEmpty#22] Results [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#23, isEmpty#24] @@ -178,9 +178,9 @@ Arguments: hashpartitioning(i_category#18, i_class#17, i_brand#16, i_product_nam (24) HashAggregate [codegen id : 8] Input [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#23, isEmpty#24] Keys [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 
as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26] -Results [9]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, cast(sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26 as decimal(38,2)) AS sumsales#27] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26] +Results [9]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, cast(sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26 as decimal(38,2)) AS sumsales#27] (25) ReusedExchange [Reuses operator id: 23] Output [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#28, isEmpty#29] @@ -188,9 +188,9 @@ Output [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8 (26) HashAggregate [codegen id : 16] Input [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#28, isEmpty#29] Keys [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26] -Results [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26 AS sumsales#30] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26] +Results [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26 AS sumsales#30] (27) HashAggregate [codegen id : 16] Input [8]: 
[i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, sumsales#30] @@ -216,9 +216,9 @@ Output [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8 (31) HashAggregate [codegen id : 25] Input [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#39, isEmpty#40] Keys [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26] -Results [7]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26 AS sumsales#30] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26] +Results [7]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26 AS sumsales#30] (32) HashAggregate [codegen id : 25] Input [7]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, sumsales#30] @@ -244,9 +244,9 @@ Output [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8 (36) HashAggregate [codegen id : 34] Input [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#50, isEmpty#51] Keys [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26] -Results [6]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26 AS sumsales#30] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * 
promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26] +Results [6]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26 AS sumsales#30] (37) HashAggregate [codegen id : 34] Input [6]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, sumsales#30] @@ -272,9 +272,9 @@ Output [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8 (41) HashAggregate [codegen id : 43] Input [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#62, isEmpty#63] Keys [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26] -Results [5]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26 AS sumsales#30] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26] +Results [5]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26 AS sumsales#30] (42) HashAggregate [codegen id : 43] Input [5]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, sumsales#30] @@ -300,9 +300,9 @@ Output [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8 (46) HashAggregate [codegen id : 52] Input [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#75, isEmpty#76] Keys [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26] -Results [4]: [i_category#18, i_class#17, i_brand#16, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26 AS sumsales#30] +Functions [1]: 
[sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26] +Results [4]: [i_category#18, i_class#17, i_brand#16, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26 AS sumsales#30] (47) HashAggregate [codegen id : 52] Input [4]: [i_category#18, i_class#17, i_brand#16, sumsales#30] @@ -328,9 +328,9 @@ Output [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8 (51) HashAggregate [codegen id : 61] Input [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#89, isEmpty#90] Keys [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26] -Results [3]: [i_category#18, i_class#17, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26 AS sumsales#30] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26] +Results [3]: [i_category#18, i_class#17, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26 AS sumsales#30] (52) HashAggregate [codegen id : 61] Input [3]: [i_category#18, i_class#17, sumsales#30] @@ -356,9 +356,9 @@ Output [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8 (56) HashAggregate [codegen id : 70] Input [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#104, isEmpty#105] Keys [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26] -Results [2]: [i_category#18, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * 
promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26 AS sumsales#30] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26] +Results [2]: [i_category#18, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26 AS sumsales#30] (57) HashAggregate [codegen id : 70] Input [2]: [i_category#18, sumsales#30] @@ -384,9 +384,9 @@ Output [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8 (61) HashAggregate [codegen id : 79] Input [10]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#120, isEmpty#121] Keys [8]: [i_category#18, i_class#17, i_brand#16, i_product_name#19, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26] -Results [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#26 AS sumsales#30] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26] +Results [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#26 AS sumsales#30] (62) HashAggregate [codegen id : 79] Input [1]: [sumsales#30] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt index ef75e80bde2a5..8b39e27c4ca40 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt @@ -9,7 +9,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category] #1 Union WholeStageCodegen (8) - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate 
[i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id] #2 WholeStageCodegen (7) @@ -63,7 +63,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy] #7 WholeStageCodegen (16) HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (26) @@ -72,7 +72,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy] #8 WholeStageCodegen (25) HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (35) @@ -81,7 +81,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class,i_brand,i_product_name,d_year] #9 WholeStageCodegen (34) HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 
WholeStageCodegen (44) @@ -90,7 +90,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class,i_brand,i_product_name] #10 WholeStageCodegen (43) HashAggregate [i_category,i_class,i_brand,i_product_name,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (53) @@ -99,7 +99,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class,i_brand] #11 WholeStageCodegen (52) HashAggregate [i_category,i_class,i_brand,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (62) @@ -108,7 +108,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class] #12 WholeStageCodegen (61) HashAggregate [i_category,i_class,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (71) @@ -117,7 +117,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category] #13 WholeStageCodegen (70) HashAggregate [i_category,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 
0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (80) @@ -126,6 +126,6 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange #14 WholeStageCodegen (79) HashAggregate [sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt index 48ab2f77ad964..d0208d6e24e2f 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt @@ -152,7 +152,7 @@ Input [12]: [ss_item_sk#1, ss_quantity#3, ss_sales_price#4, d_year#8, d_moy#9, d (19) HashAggregate [codegen id : 4] Input [10]: [ss_quantity#3, ss_sales_price#4, d_year#8, d_moy#9, d_qoy#10, s_store_id#12, i_brand#15, i_class#16, i_category#17, i_product_name#18] Keys [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [partial_sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] +Functions [1]: [partial_sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] Aggregate Attributes [2]: [sum#20, isEmpty#21] Results [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#22, isEmpty#23] @@ -163,9 +163,9 @@ Arguments: hashpartitioning(i_category#17, i_class#16, i_brand#15, i_product_nam (21) HashAggregate [codegen id : 5] Input [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#22, isEmpty#23] Keys [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * 
promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25] -Results [9]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, cast(sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25 as decimal(38,2)) AS sumsales#26] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25] +Results [9]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, cast(sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25 as decimal(38,2)) AS sumsales#26] (22) ReusedExchange [Reuses operator id: 20] Output [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#27, isEmpty#28] @@ -173,9 +173,9 @@ Output [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8 (23) HashAggregate [codegen id : 10] Input [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#27, isEmpty#28] Keys [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25] -Results [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25 AS sumsales#29] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25] +Results [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25 AS sumsales#29] (24) HashAggregate [codegen id : 10] Input [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, sumsales#29] @@ -201,9 +201,9 @@ Output [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8 (28) HashAggregate [codegen id : 16] Input [10]: 
[i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#38, isEmpty#39] Keys [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25] -Results [7]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25 AS sumsales#29] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25] +Results [7]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25 AS sumsales#29] (29) HashAggregate [codegen id : 16] Input [7]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, sumsales#29] @@ -229,9 +229,9 @@ Output [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8 (33) HashAggregate [codegen id : 22] Input [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#49, isEmpty#50] Keys [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25] -Results [6]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25 AS sumsales#29] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25] +Results [6]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * 
promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25 AS sumsales#29] (34) HashAggregate [codegen id : 22] Input [6]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, sumsales#29] @@ -257,9 +257,9 @@ Output [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8 (38) HashAggregate [codegen id : 28] Input [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#61, isEmpty#62] Keys [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25] -Results [5]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25 AS sumsales#29] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25] +Results [5]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25 AS sumsales#29] (39) HashAggregate [codegen id : 28] Input [5]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, sumsales#29] @@ -285,9 +285,9 @@ Output [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8 (43) HashAggregate [codegen id : 34] Input [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#74, isEmpty#75] Keys [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25] -Results [4]: [i_category#17, i_class#16, i_brand#15, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25 AS sumsales#29] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: 
[sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25] +Results [4]: [i_category#17, i_class#16, i_brand#15, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25 AS sumsales#29] (44) HashAggregate [codegen id : 34] Input [4]: [i_category#17, i_class#16, i_brand#15, sumsales#29] @@ -313,9 +313,9 @@ Output [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8 (48) HashAggregate [codegen id : 40] Input [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#88, isEmpty#89] Keys [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25] -Results [3]: [i_category#17, i_class#16, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25 AS sumsales#29] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25] +Results [3]: [i_category#17, i_class#16, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25 AS sumsales#29] (49) HashAggregate [codegen id : 40] Input [3]: [i_category#17, i_class#16, sumsales#29] @@ -341,9 +341,9 @@ Output [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8 (53) HashAggregate [codegen id : 46] Input [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#103, isEmpty#104] Keys [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25] -Results [2]: [i_category#17, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25 AS sumsales#29] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) 
* promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25] +Results [2]: [i_category#17, sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25 AS sumsales#29] (54) HashAggregate [codegen id : 46] Input [2]: [i_category#17, sumsales#29] @@ -369,9 +369,9 @@ Output [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8 (58) HashAggregate [codegen id : 52] Input [10]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sum#119, isEmpty#120] Keys [8]: [i_category#17, i_class#16, i_brand#15, i_product_name#18, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] -Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))] -Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25] -Results [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(cast(ss_quantity#3 as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00))#25 AS sumsales#29] +Functions [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))] +Aggregate Attributes [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25] +Results [1]: [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price#4 as decimal(12,2))) * promote_precision(cast(ss_quantity#3 as decimal(12,2)))), DecimalType(18,2)), 0.00))#25 AS sumsales#29] (59) HashAggregate [codegen id : 52] Input [1]: [sumsales#29] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt index a26fa77b9a6d2..35d285165618b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt @@ -9,7 +9,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category] #1 Union WholeStageCodegen (5) - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter 
Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id] #2 WholeStageCodegen (4) @@ -54,7 +54,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy] #6 WholeStageCodegen (10) HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (17) @@ -63,7 +63,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy] #7 WholeStageCodegen (16) HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (23) @@ -72,7 +72,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class,i_brand,i_product_name,d_year] #8 WholeStageCodegen (22) HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (29) @@ -81,7 +81,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class,i_brand,i_product_name] #9 WholeStageCodegen (28) HashAggregate [i_category,i_class,i_brand,i_product_name,sumsales] [sum,isEmpty,sum,isEmpty] - 
HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (35) @@ -90,7 +90,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class,i_brand] #10 WholeStageCodegen (34) HashAggregate [i_category,i_class,i_brand,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (41) @@ -99,7 +99,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category,i_class] #11 WholeStageCodegen (40) HashAggregate [i_category,i_class,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (47) @@ -108,7 +108,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange [i_category] #12 WholeStageCodegen (46) HashAggregate [i_category,sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 
0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 WholeStageCodegen (53) @@ -117,6 +117,6 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ Exchange #13 WholeStageCodegen (52) HashAggregate [sumsales] [sum,isEmpty,sum,isEmpty] - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(cast(ss_quantity as decimal(10,0)) as decimal(12,2)))), DecimalType(18,2), true), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce(CheckOverflow((promote_precision(cast(ss_sales_price as decimal(12,2))) * promote_precision(cast(ss_quantity as decimal(12,2)))), DecimalType(18,2)), 0.00)),sumsales,sum,isEmpty] InputAdapter ReusedExchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] #2 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72.sf100/explain.txt index 42f7488ad66d3..e5e42f2be1366 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72.sf100/explain.txt @@ -1,72 +1,74 @@ == Physical Plan == -TakeOrderedAndProject (68) -+- * HashAggregate (67) - +- Exchange (66) - +- * HashAggregate (65) - +- * Project (64) - +- * SortMergeJoin LeftOuter (63) - :- * Sort (56) - : +- * Project (55) - : +- * BroadcastHashJoin LeftOuter BuildRight (54) - : :- * Project (49) - : : +- * SortMergeJoin Inner (48) - : : :- * Sort (36) - : : : +- * Project (35) - : : : +- * BroadcastHashJoin Inner BuildRight (34) - : : : :- * Project (32) - : : : : +- * SortMergeJoin Inner (31) - : : : : :- * Sort (25) - : : : : : +- Exchange (24) - : : : : : +- * Project (23) - : : : : : +- * BroadcastHashJoin Inner BuildRight (22) - : : : : : :- * Project (17) - : : : : : : +- * BroadcastHashJoin Inner BuildRight (16) - : : : : : : :- * Project (10) - : : : : : : : +- * BroadcastHashJoin Inner BuildRight (9) - : : : : : : : :- * Filter (3) - : : : : : : : : +- * ColumnarToRow (2) - : : : : : : : : +- Scan parquet default.catalog_sales (1) - : : : : : : : +- BroadcastExchange (8) - : : : : : : : +- * Project (7) - : : : : : : : +- * Filter (6) - : : : : : : : +- * ColumnarToRow (5) - : : : : : : : +- Scan parquet default.household_demographics (4) - : : : : : : +- BroadcastExchange (15) - : : : : : : +- * Project (14) - : : : : : : +- * Filter (13) - : : : : : : +- * ColumnarToRow (12) - : : : : : : +- Scan parquet default.customer_demographics (11) - : : : : : +- BroadcastExchange (21) - : : : : : +- * Filter (20) - : : : : : +- * ColumnarToRow (19) - : : : : : +- Scan parquet default.date_dim (18) - : : : : +- * Sort (30) - : : : : +- Exchange (29) - : : : : +- * Filter (28) - : : : : +- * ColumnarToRow (27) - : : : : +- Scan parquet default.item (26) - : : : +- ReusedExchange (33) - : : +- * Sort (47) - : : +- Exchange (46) - : : +- * Project (45) - : : +- * BroadcastHashJoin Inner BuildRight (44) - : : :- * Filter (39) - : : : +- * ColumnarToRow (38) - : : : +- Scan parquet default.inventory (37) - : : +- BroadcastExchange (43) - : : +- * Filter (42) - : : +- * ColumnarToRow (41) 
- : : +- Scan parquet default.warehouse (40) - : +- BroadcastExchange (53) - : +- * Filter (52) - : +- * ColumnarToRow (51) - : +- Scan parquet default.promotion (50) - +- * Sort (62) - +- Exchange (61) - +- * Project (60) - +- * Filter (59) - +- * ColumnarToRow (58) - +- Scan parquet default.catalog_returns (57) +TakeOrderedAndProject (70) ++- * HashAggregate (69) + +- Exchange (68) + +- * HashAggregate (67) + +- * Project (66) + +- * SortMergeJoin LeftOuter (65) + :- * Sort (58) + : +- Exchange (57) + : +- * Project (56) + : +- * BroadcastHashJoin LeftOuter BuildRight (55) + : :- * Project (50) + : : +- * SortMergeJoin Inner (49) + : : :- * Sort (37) + : : : +- Exchange (36) + : : : +- * Project (35) + : : : +- * BroadcastHashJoin Inner BuildRight (34) + : : : :- * Project (32) + : : : : +- * SortMergeJoin Inner (31) + : : : : :- * Sort (25) + : : : : : +- Exchange (24) + : : : : : +- * Project (23) + : : : : : +- * BroadcastHashJoin Inner BuildRight (22) + : : : : : :- * Project (17) + : : : : : : +- * BroadcastHashJoin Inner BuildRight (16) + : : : : : : :- * Project (10) + : : : : : : : +- * BroadcastHashJoin Inner BuildRight (9) + : : : : : : : :- * Filter (3) + : : : : : : : : +- * ColumnarToRow (2) + : : : : : : : : +- Scan parquet default.catalog_sales (1) + : : : : : : : +- BroadcastExchange (8) + : : : : : : : +- * Project (7) + : : : : : : : +- * Filter (6) + : : : : : : : +- * ColumnarToRow (5) + : : : : : : : +- Scan parquet default.household_demographics (4) + : : : : : : +- BroadcastExchange (15) + : : : : : : +- * Project (14) + : : : : : : +- * Filter (13) + : : : : : : +- * ColumnarToRow (12) + : : : : : : +- Scan parquet default.customer_demographics (11) + : : : : : +- BroadcastExchange (21) + : : : : : +- * Filter (20) + : : : : : +- * ColumnarToRow (19) + : : : : : +- Scan parquet default.date_dim (18) + : : : : +- * Sort (30) + : : : : +- Exchange (29) + : : : : +- * Filter (28) + : : : : +- * ColumnarToRow (27) + : : : : +- Scan parquet default.item (26) + : : : +- ReusedExchange (33) + : : +- * Sort (48) + : : +- Exchange (47) + : : +- * Project (46) + : : +- * BroadcastHashJoin Inner BuildRight (45) + : : :- * Filter (40) + : : : +- * ColumnarToRow (39) + : : : +- Scan parquet default.inventory (38) + : : +- BroadcastExchange (44) + : : +- * Filter (43) + : : +- * ColumnarToRow (42) + : : +- Scan parquet default.warehouse (41) + : +- BroadcastExchange (54) + : +- * Filter (53) + : +- * ColumnarToRow (52) + : +- Scan parquet default.promotion (51) + +- * Sort (64) + +- Exchange (63) + +- * Project (62) + +- * Filter (61) + +- * ColumnarToRow (60) + +- Scan parquet default.catalog_returns (59) (1) Scan parquet default.catalog_sales @@ -212,7 +214,7 @@ Join condition: None Output [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, cs_sold_date_sk#8, d_date#17, i_item_desc#21] Input [8]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, cs_sold_date_sk#8, d_date#17, i_item_sk#20, i_item_desc#21] -(33) ReusedExchange [Reuses operator id: 79] +(33) ReusedExchange [Reuses operator id: 81] Output [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26] (34) BroadcastHashJoin [codegen id : 10] @@ -224,220 +226,228 @@ Join condition: (d_date#17 > date_add(d_date#24, 5)) Output [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, i_item_desc#21, d_week_seq#25, d_date_sk#26] Input [11]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, cs_sold_date_sk#8, d_date#17, i_item_desc#21, d_date_sk#23, d_date#24, 
d_week_seq#25, d_date_sk#26] -(36) Sort [codegen id : 10] +(36) Exchange +Input [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, i_item_desc#21, d_week_seq#25, d_date_sk#26] +Arguments: hashpartitioning(cs_item_sk#4, d_date_sk#26, 5), ENSURE_REQUIREMENTS, [id=#27] + +(37) Sort [codegen id : 11] Input [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, i_item_desc#21, d_week_seq#25, d_date_sk#26] Arguments: [cs_item_sk#4 ASC NULLS FIRST, d_date_sk#26 ASC NULLS FIRST], false, 0 -(37) Scan parquet default.inventory -Output [4]: [inv_item_sk#27, inv_warehouse_sk#28, inv_quantity_on_hand#29, inv_date_sk#30] +(38) Scan parquet default.inventory +Output [4]: [inv_item_sk#28, inv_warehouse_sk#29, inv_quantity_on_hand#30, inv_date_sk#31] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(inv_date_sk#30), dynamicpruningexpression(true)] +PartitionFilters: [isnotnull(inv_date_sk#31), dynamicpruningexpression(true)] PushedFilters: [IsNotNull(inv_quantity_on_hand), IsNotNull(inv_item_sk), IsNotNull(inv_warehouse_sk)] ReadSchema: struct -(38) ColumnarToRow [codegen id : 12] -Input [4]: [inv_item_sk#27, inv_warehouse_sk#28, inv_quantity_on_hand#29, inv_date_sk#30] +(39) ColumnarToRow [codegen id : 13] +Input [4]: [inv_item_sk#28, inv_warehouse_sk#29, inv_quantity_on_hand#30, inv_date_sk#31] -(39) Filter [codegen id : 12] -Input [4]: [inv_item_sk#27, inv_warehouse_sk#28, inv_quantity_on_hand#29, inv_date_sk#30] -Condition : ((isnotnull(inv_quantity_on_hand#29) AND isnotnull(inv_item_sk#27)) AND isnotnull(inv_warehouse_sk#28)) +(40) Filter [codegen id : 13] +Input [4]: [inv_item_sk#28, inv_warehouse_sk#29, inv_quantity_on_hand#30, inv_date_sk#31] +Condition : ((isnotnull(inv_quantity_on_hand#30) AND isnotnull(inv_item_sk#28)) AND isnotnull(inv_warehouse_sk#29)) -(40) Scan parquet default.warehouse -Output [2]: [w_warehouse_sk#31, w_warehouse_name#32] +(41) Scan parquet default.warehouse +Output [2]: [w_warehouse_sk#32, w_warehouse_name#33] Batched: true Location [not included in comparison]/{warehouse_dir}/warehouse] PushedFilters: [IsNotNull(w_warehouse_sk)] ReadSchema: struct -(41) ColumnarToRow [codegen id : 11] -Input [2]: [w_warehouse_sk#31, w_warehouse_name#32] +(42) ColumnarToRow [codegen id : 12] +Input [2]: [w_warehouse_sk#32, w_warehouse_name#33] -(42) Filter [codegen id : 11] -Input [2]: [w_warehouse_sk#31, w_warehouse_name#32] -Condition : isnotnull(w_warehouse_sk#31) +(43) Filter [codegen id : 12] +Input [2]: [w_warehouse_sk#32, w_warehouse_name#33] +Condition : isnotnull(w_warehouse_sk#32) -(43) BroadcastExchange -Input [2]: [w_warehouse_sk#31, w_warehouse_name#32] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#33] +(44) BroadcastExchange +Input [2]: [w_warehouse_sk#32, w_warehouse_name#33] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#34] -(44) BroadcastHashJoin [codegen id : 12] -Left keys [1]: [inv_warehouse_sk#28] -Right keys [1]: [w_warehouse_sk#31] +(45) BroadcastHashJoin [codegen id : 13] +Left keys [1]: [inv_warehouse_sk#29] +Right keys [1]: [w_warehouse_sk#32] Join condition: None -(45) Project [codegen id : 12] -Output [4]: [inv_item_sk#27, inv_quantity_on_hand#29, inv_date_sk#30, w_warehouse_name#32] -Input [6]: [inv_item_sk#27, inv_warehouse_sk#28, inv_quantity_on_hand#29, inv_date_sk#30, w_warehouse_sk#31, w_warehouse_name#32] +(46) Project [codegen id : 13] +Output [4]: [inv_item_sk#28, inv_quantity_on_hand#30, 
inv_date_sk#31, w_warehouse_name#33] +Input [6]: [inv_item_sk#28, inv_warehouse_sk#29, inv_quantity_on_hand#30, inv_date_sk#31, w_warehouse_sk#32, w_warehouse_name#33] -(46) Exchange -Input [4]: [inv_item_sk#27, inv_quantity_on_hand#29, inv_date_sk#30, w_warehouse_name#32] -Arguments: hashpartitioning(inv_item_sk#27, 5), ENSURE_REQUIREMENTS, [id=#34] +(47) Exchange +Input [4]: [inv_item_sk#28, inv_quantity_on_hand#30, inv_date_sk#31, w_warehouse_name#33] +Arguments: hashpartitioning(inv_item_sk#28, inv_date_sk#31, 5), ENSURE_REQUIREMENTS, [id=#35] -(47) Sort [codegen id : 13] -Input [4]: [inv_item_sk#27, inv_quantity_on_hand#29, inv_date_sk#30, w_warehouse_name#32] -Arguments: [inv_item_sk#27 ASC NULLS FIRST, inv_date_sk#30 ASC NULLS FIRST], false, 0 +(48) Sort [codegen id : 14] +Input [4]: [inv_item_sk#28, inv_quantity_on_hand#30, inv_date_sk#31, w_warehouse_name#33] +Arguments: [inv_item_sk#28 ASC NULLS FIRST, inv_date_sk#31 ASC NULLS FIRST], false, 0 -(48) SortMergeJoin [codegen id : 15] +(49) SortMergeJoin [codegen id : 16] Left keys [2]: [cs_item_sk#4, d_date_sk#26] -Right keys [2]: [inv_item_sk#27, inv_date_sk#30] -Join condition: (inv_quantity_on_hand#29 < cs_quantity#7) +Right keys [2]: [inv_item_sk#28, inv_date_sk#31] +Join condition: (inv_quantity_on_hand#30 < cs_quantity#7) -(49) Project [codegen id : 15] -Output [6]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, w_warehouse_name#32, i_item_desc#21, d_week_seq#25] -Input [11]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, i_item_desc#21, d_week_seq#25, d_date_sk#26, inv_item_sk#27, inv_quantity_on_hand#29, inv_date_sk#30, w_warehouse_name#32] +(50) Project [codegen id : 16] +Output [6]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25] +Input [11]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, cs_quantity#7, i_item_desc#21, d_week_seq#25, d_date_sk#26, inv_item_sk#28, inv_quantity_on_hand#30, inv_date_sk#31, w_warehouse_name#33] -(50) Scan parquet default.promotion -Output [1]: [p_promo_sk#35] +(51) Scan parquet default.promotion +Output [1]: [p_promo_sk#36] Batched: true Location [not included in comparison]/{warehouse_dir}/promotion] PushedFilters: [IsNotNull(p_promo_sk)] ReadSchema: struct -(51) ColumnarToRow [codegen id : 14] -Input [1]: [p_promo_sk#35] +(52) ColumnarToRow [codegen id : 15] +Input [1]: [p_promo_sk#36] -(52) Filter [codegen id : 14] -Input [1]: [p_promo_sk#35] -Condition : isnotnull(p_promo_sk#35) +(53) Filter [codegen id : 15] +Input [1]: [p_promo_sk#36] +Condition : isnotnull(p_promo_sk#36) -(53) BroadcastExchange -Input [1]: [p_promo_sk#35] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#36] +(54) BroadcastExchange +Input [1]: [p_promo_sk#36] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#37] -(54) BroadcastHashJoin [codegen id : 15] +(55) BroadcastHashJoin [codegen id : 16] Left keys [1]: [cs_promo_sk#5] -Right keys [1]: [p_promo_sk#35] +Right keys [1]: [p_promo_sk#36] Join condition: None -(55) Project [codegen id : 15] -Output [5]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#32, i_item_desc#21, d_week_seq#25] -Input [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, w_warehouse_name#32, i_item_desc#21, d_week_seq#25, p_promo_sk#35] +(56) Project [codegen id : 16] +Output [5]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25] +Input [7]: [cs_item_sk#4, cs_promo_sk#5, cs_order_number#6, 
w_warehouse_name#33, i_item_desc#21, d_week_seq#25, p_promo_sk#36] + +(57) Exchange +Input [5]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25] +Arguments: hashpartitioning(cs_item_sk#4, cs_order_number#6, 5), ENSURE_REQUIREMENTS, [id=#38] -(56) Sort [codegen id : 15] -Input [5]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#32, i_item_desc#21, d_week_seq#25] +(58) Sort [codegen id : 17] +Input [5]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25] Arguments: [cs_item_sk#4 ASC NULLS FIRST, cs_order_number#6 ASC NULLS FIRST], false, 0 -(57) Scan parquet default.catalog_returns -Output [3]: [cr_item_sk#37, cr_order_number#38, cr_returned_date_sk#39] +(59) Scan parquet default.catalog_returns +Output [3]: [cr_item_sk#39, cr_order_number#40, cr_returned_date_sk#41] Batched: true Location [not included in comparison]/{warehouse_dir}/catalog_returns] PushedFilters: [IsNotNull(cr_item_sk), IsNotNull(cr_order_number)] ReadSchema: struct -(58) ColumnarToRow [codegen id : 16] -Input [3]: [cr_item_sk#37, cr_order_number#38, cr_returned_date_sk#39] +(60) ColumnarToRow [codegen id : 18] +Input [3]: [cr_item_sk#39, cr_order_number#40, cr_returned_date_sk#41] -(59) Filter [codegen id : 16] -Input [3]: [cr_item_sk#37, cr_order_number#38, cr_returned_date_sk#39] -Condition : (isnotnull(cr_item_sk#37) AND isnotnull(cr_order_number#38)) +(61) Filter [codegen id : 18] +Input [3]: [cr_item_sk#39, cr_order_number#40, cr_returned_date_sk#41] +Condition : (isnotnull(cr_item_sk#39) AND isnotnull(cr_order_number#40)) -(60) Project [codegen id : 16] -Output [2]: [cr_item_sk#37, cr_order_number#38] -Input [3]: [cr_item_sk#37, cr_order_number#38, cr_returned_date_sk#39] +(62) Project [codegen id : 18] +Output [2]: [cr_item_sk#39, cr_order_number#40] +Input [3]: [cr_item_sk#39, cr_order_number#40, cr_returned_date_sk#41] -(61) Exchange -Input [2]: [cr_item_sk#37, cr_order_number#38] -Arguments: hashpartitioning(cr_item_sk#37, 5), ENSURE_REQUIREMENTS, [id=#40] +(63) Exchange +Input [2]: [cr_item_sk#39, cr_order_number#40] +Arguments: hashpartitioning(cr_item_sk#39, cr_order_number#40, 5), ENSURE_REQUIREMENTS, [id=#42] -(62) Sort [codegen id : 17] -Input [2]: [cr_item_sk#37, cr_order_number#38] -Arguments: [cr_item_sk#37 ASC NULLS FIRST, cr_order_number#38 ASC NULLS FIRST], false, 0 +(64) Sort [codegen id : 19] +Input [2]: [cr_item_sk#39, cr_order_number#40] +Arguments: [cr_item_sk#39 ASC NULLS FIRST, cr_order_number#40 ASC NULLS FIRST], false, 0 -(63) SortMergeJoin [codegen id : 18] +(65) SortMergeJoin [codegen id : 20] Left keys [2]: [cs_item_sk#4, cs_order_number#6] -Right keys [2]: [cr_item_sk#37, cr_order_number#38] +Right keys [2]: [cr_item_sk#39, cr_order_number#40] Join condition: None -(64) Project [codegen id : 18] -Output [3]: [w_warehouse_name#32, i_item_desc#21, d_week_seq#25] -Input [7]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#32, i_item_desc#21, d_week_seq#25, cr_item_sk#37, cr_order_number#38] +(66) Project [codegen id : 20] +Output [3]: [w_warehouse_name#33, i_item_desc#21, d_week_seq#25] +Input [7]: [cs_item_sk#4, cs_order_number#6, w_warehouse_name#33, i_item_desc#21, d_week_seq#25, cr_item_sk#39, cr_order_number#40] -(65) HashAggregate [codegen id : 18] -Input [3]: [w_warehouse_name#32, i_item_desc#21, d_week_seq#25] -Keys [3]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25] +(67) HashAggregate [codegen id : 20] +Input [3]: [w_warehouse_name#33, i_item_desc#21, d_week_seq#25] +Keys [3]: [i_item_desc#21, 
w_warehouse_name#33, d_week_seq#25] Functions [1]: [partial_count(1)] -Aggregate Attributes [1]: [count#41] -Results [4]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, count#42] +Aggregate Attributes [1]: [count#43] +Results [4]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, count#44] -(66) Exchange -Input [4]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, count#42] -Arguments: hashpartitioning(i_item_desc#21, w_warehouse_name#32, d_week_seq#25, 5), ENSURE_REQUIREMENTS, [id=#43] +(68) Exchange +Input [4]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, count#44] +Arguments: hashpartitioning(i_item_desc#21, w_warehouse_name#33, d_week_seq#25, 5), ENSURE_REQUIREMENTS, [id=#45] -(67) HashAggregate [codegen id : 19] -Input [4]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, count#42] -Keys [3]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25] +(69) HashAggregate [codegen id : 21] +Input [4]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, count#44] +Keys [3]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25] Functions [1]: [count(1)] -Aggregate Attributes [1]: [count(1)#44] -Results [6]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, count(1)#44 AS no_promo#45, count(1)#44 AS promo#46, count(1)#44 AS total_cnt#47] +Aggregate Attributes [1]: [count(1)#46] +Results [6]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, count(1)#46 AS no_promo#47, count(1)#46 AS promo#48, count(1)#46 AS total_cnt#49] -(68) TakeOrderedAndProject -Input [6]: [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, no_promo#45, promo#46, total_cnt#47] -Arguments: 100, [total_cnt#47 DESC NULLS LAST, i_item_desc#21 ASC NULLS FIRST, w_warehouse_name#32 ASC NULLS FIRST, d_week_seq#25 ASC NULLS FIRST], [i_item_desc#21, w_warehouse_name#32, d_week_seq#25, no_promo#45, promo#46, total_cnt#47] +(70) TakeOrderedAndProject +Input [6]: [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, no_promo#47, promo#48, total_cnt#49] +Arguments: 100, [total_cnt#49 DESC NULLS LAST, i_item_desc#21 ASC NULLS FIRST, w_warehouse_name#33 ASC NULLS FIRST, d_week_seq#25 ASC NULLS FIRST], [i_item_desc#21, w_warehouse_name#33, d_week_seq#25, no_promo#47, promo#48, total_cnt#49] ===== Subqueries ===== Subquery:1 Hosting operator id = 1 Hosting Expression = cs_sold_date_sk#8 IN dynamicpruning#9 -BroadcastExchange (79) -+- * Project (78) - +- * BroadcastHashJoin Inner BuildLeft (77) - :- BroadcastExchange (73) - : +- * Project (72) - : +- * Filter (71) - : +- * ColumnarToRow (70) - : +- Scan parquet default.date_dim (69) - +- * Filter (76) - +- * ColumnarToRow (75) - +- Scan parquet default.date_dim (74) - - -(69) Scan parquet default.date_dim -Output [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#48] +BroadcastExchange (81) ++- * Project (80) + +- * BroadcastHashJoin Inner BuildLeft (79) + :- BroadcastExchange (75) + : +- * Project (74) + : +- * Filter (73) + : +- * ColumnarToRow (72) + : +- Scan parquet default.date_dim (71) + +- * Filter (78) + +- * ColumnarToRow (77) + +- Scan parquet default.date_dim (76) + + +(71) Scan parquet default.date_dim +Output [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#50] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), EqualTo(d_year,2001), IsNotNull(d_date_sk), IsNotNull(d_week_seq), IsNotNull(d_date)] ReadSchema: struct -(70) ColumnarToRow [codegen id : 1] -Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#48] +(72) ColumnarToRow [codegen id : 1] +Input [4]: [d_date_sk#23, 
d_date#24, d_week_seq#25, d_year#50] -(71) Filter [codegen id : 1] -Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#48] -Condition : ((((isnotnull(d_year#48) AND (d_year#48 = 2001)) AND isnotnull(d_date_sk#23)) AND isnotnull(d_week_seq#25)) AND isnotnull(d_date#24)) +(73) Filter [codegen id : 1] +Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#50] +Condition : ((((isnotnull(d_year#50) AND (d_year#50 = 2001)) AND isnotnull(d_date_sk#23)) AND isnotnull(d_week_seq#25)) AND isnotnull(d_date#24)) -(72) Project [codegen id : 1] +(74) Project [codegen id : 1] Output [3]: [d_date_sk#23, d_date#24, d_week_seq#25] -Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#48] +Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_year#50] -(73) BroadcastExchange +(75) BroadcastExchange Input [3]: [d_date_sk#23, d_date#24, d_week_seq#25] -Arguments: HashedRelationBroadcastMode(List(cast(input[2, int, true] as bigint)),false), [id=#49] +Arguments: HashedRelationBroadcastMode(List(cast(input[2, int, true] as bigint)),false), [id=#51] -(74) Scan parquet default.date_dim -Output [2]: [d_date_sk#26, d_week_seq#50] +(76) Scan parquet default.date_dim +Output [2]: [d_date_sk#26, d_week_seq#52] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] ReadSchema: struct -(75) ColumnarToRow -Input [2]: [d_date_sk#26, d_week_seq#50] +(77) ColumnarToRow +Input [2]: [d_date_sk#26, d_week_seq#52] -(76) Filter -Input [2]: [d_date_sk#26, d_week_seq#50] -Condition : (isnotnull(d_week_seq#50) AND isnotnull(d_date_sk#26)) +(78) Filter +Input [2]: [d_date_sk#26, d_week_seq#52] +Condition : (isnotnull(d_week_seq#52) AND isnotnull(d_date_sk#26)) -(77) BroadcastHashJoin [codegen id : 2] +(79) BroadcastHashJoin [codegen id : 2] Left keys [1]: [d_week_seq#25] -Right keys [1]: [d_week_seq#50] +Right keys [1]: [d_week_seq#52] Join condition: None -(78) Project [codegen id : 2] +(80) Project [codegen id : 2] Output [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26] -Input [5]: [d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26, d_week_seq#50] +Input [5]: [d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26, d_week_seq#52] -(79) BroadcastExchange +(81) BroadcastExchange Input [4]: [d_date_sk#23, d_date#24, d_week_seq#25, d_date_sk#26] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#51] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#53] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72.sf100/simplified.txt index d84393b2ff106..e838025a71db8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72.sf100/simplified.txt @@ -1,126 +1,132 @@ TakeOrderedAndProject [total_cnt,i_item_desc,w_warehouse_name,d_week_seq,no_promo,promo] - WholeStageCodegen (19) + WholeStageCodegen (21) HashAggregate [i_item_desc,w_warehouse_name,d_week_seq,count] [count(1),no_promo,promo,total_cnt,count] InputAdapter Exchange [i_item_desc,w_warehouse_name,d_week_seq] #1 - WholeStageCodegen (18) + WholeStageCodegen (20) HashAggregate [i_item_desc,w_warehouse_name,d_week_seq] [count,count] Project [w_warehouse_name,i_item_desc,d_week_seq] SortMergeJoin 
[cs_item_sk,cs_order_number,cr_item_sk,cr_order_number] InputAdapter - WholeStageCodegen (15) + WholeStageCodegen (17) Sort [cs_item_sk,cs_order_number] - Project [cs_item_sk,cs_order_number,w_warehouse_name,i_item_desc,d_week_seq] - BroadcastHashJoin [cs_promo_sk,p_promo_sk] - Project [cs_item_sk,cs_promo_sk,cs_order_number,w_warehouse_name,i_item_desc,d_week_seq] - SortMergeJoin [cs_item_sk,d_date_sk,inv_item_sk,inv_date_sk,inv_quantity_on_hand,cs_quantity] - InputAdapter - WholeStageCodegen (10) - Sort [cs_item_sk,d_date_sk] - Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,i_item_desc,d_week_seq,d_date_sk] - BroadcastHashJoin [cs_sold_date_sk,d_date_sk,d_date,d_date] - Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk,d_date,i_item_desc] - SortMergeJoin [cs_item_sk,i_item_sk] - InputAdapter - WholeStageCodegen (5) - Sort [cs_item_sk] - InputAdapter - Exchange [cs_item_sk] #2 - WholeStageCodegen (4) - Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk,d_date] - BroadcastHashJoin [cs_ship_date_sk,d_date_sk] - Project [cs_ship_date_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] - BroadcastHashJoin [cs_bill_cdemo_sk,cd_demo_sk] - Project [cs_ship_date_sk,cs_bill_cdemo_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] - BroadcastHashJoin [cs_bill_hdemo_sk,hd_demo_sk] - Filter [cs_quantity,cs_item_sk,cs_bill_cdemo_sk,cs_bill_hdemo_sk,cs_ship_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.catalog_sales [cs_ship_date_sk,cs_bill_cdemo_sk,cs_bill_hdemo_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] - SubqueryBroadcast [d_date_sk] #1 - BroadcastExchange #3 - WholeStageCodegen (2) - Project [d_date_sk,d_date,d_week_seq,d_date_sk] - BroadcastHashJoin [d_week_seq,d_week_seq] - InputAdapter - BroadcastExchange #4 - WholeStageCodegen (1) - Project [d_date_sk,d_date,d_week_seq] - Filter [d_year,d_date_sk,d_week_seq,d_date] - ColumnarToRow + InputAdapter + Exchange [cs_item_sk,cs_order_number] #2 + WholeStageCodegen (16) + Project [cs_item_sk,cs_order_number,w_warehouse_name,i_item_desc,d_week_seq] + BroadcastHashJoin [cs_promo_sk,p_promo_sk] + Project [cs_item_sk,cs_promo_sk,cs_order_number,w_warehouse_name,i_item_desc,d_week_seq] + SortMergeJoin [cs_item_sk,d_date_sk,inv_item_sk,inv_date_sk,inv_quantity_on_hand,cs_quantity] + InputAdapter + WholeStageCodegen (11) + Sort [cs_item_sk,d_date_sk] + InputAdapter + Exchange [cs_item_sk,d_date_sk] #3 + WholeStageCodegen (10) + Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,i_item_desc,d_week_seq,d_date_sk] + BroadcastHashJoin [cs_sold_date_sk,d_date_sk,d_date,d_date] + Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk,d_date,i_item_desc] + SortMergeJoin [cs_item_sk,i_item_sk] + InputAdapter + WholeStageCodegen (5) + Sort [cs_item_sk] + InputAdapter + Exchange [cs_item_sk] #4 + WholeStageCodegen (4) + Project [cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk,d_date] + BroadcastHashJoin [cs_ship_date_sk,d_date_sk] + Project [cs_ship_date_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] + BroadcastHashJoin [cs_bill_cdemo_sk,cd_demo_sk] + Project [cs_ship_date_sk,cs_bill_cdemo_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] + BroadcastHashJoin [cs_bill_hdemo_sk,hd_demo_sk] + Filter [cs_quantity,cs_item_sk,cs_bill_cdemo_sk,cs_bill_hdemo_sk,cs_ship_date_sk] + ColumnarToRow + InputAdapter + Scan parquet 
default.catalog_sales [cs_ship_date_sk,cs_bill_cdemo_sk,cs_bill_hdemo_sk,cs_item_sk,cs_promo_sk,cs_order_number,cs_quantity,cs_sold_date_sk] + SubqueryBroadcast [d_date_sk] #1 + BroadcastExchange #5 + WholeStageCodegen (2) + Project [d_date_sk,d_date,d_week_seq,d_date_sk] + BroadcastHashJoin [d_week_seq,d_week_seq] InputAdapter - Scan parquet default.date_dim [d_date_sk,d_date,d_week_seq,d_year] - Filter [d_week_seq,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_week_seq] - InputAdapter - BroadcastExchange #5 - WholeStageCodegen (1) - Project [hd_demo_sk] - Filter [hd_buy_potential,hd_demo_sk] - ColumnarToRow + BroadcastExchange #6 + WholeStageCodegen (1) + Project [d_date_sk,d_date,d_week_seq] + Filter [d_year,d_date_sk,d_week_seq,d_date] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date,d_week_seq,d_year] + Filter [d_week_seq,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_week_seq] InputAdapter - Scan parquet default.household_demographics [hd_demo_sk,hd_buy_potential] - InputAdapter - BroadcastExchange #6 - WholeStageCodegen (2) - Project [cd_demo_sk] - Filter [cd_marital_status,cd_demo_sk] - ColumnarToRow + BroadcastExchange #7 + WholeStageCodegen (1) + Project [hd_demo_sk] + Filter [hd_buy_potential,hd_demo_sk] + ColumnarToRow + InputAdapter + Scan parquet default.household_demographics [hd_demo_sk,hd_buy_potential] InputAdapter - Scan parquet default.customer_demographics [cd_demo_sk,cd_marital_status] - InputAdapter - BroadcastExchange #7 - WholeStageCodegen (3) - Filter [d_date,d_date_sk] - ColumnarToRow - InputAdapter - Scan parquet default.date_dim [d_date_sk,d_date] - InputAdapter - WholeStageCodegen (7) - Sort [i_item_sk] - InputAdapter - Exchange [i_item_sk] #8 - WholeStageCodegen (6) - Filter [i_item_sk] - ColumnarToRow - InputAdapter - Scan parquet default.item [i_item_sk,i_item_desc] - InputAdapter - ReusedExchange [d_date_sk,d_date,d_week_seq,d_date_sk] #3 - InputAdapter - WholeStageCodegen (13) - Sort [inv_item_sk,inv_date_sk] + BroadcastExchange #8 + WholeStageCodegen (2) + Project [cd_demo_sk] + Filter [cd_marital_status,cd_demo_sk] + ColumnarToRow + InputAdapter + Scan parquet default.customer_demographics [cd_demo_sk,cd_marital_status] + InputAdapter + BroadcastExchange #9 + WholeStageCodegen (3) + Filter [d_date,d_date_sk] + ColumnarToRow + InputAdapter + Scan parquet default.date_dim [d_date_sk,d_date] + InputAdapter + WholeStageCodegen (7) + Sort [i_item_sk] + InputAdapter + Exchange [i_item_sk] #10 + WholeStageCodegen (6) + Filter [i_item_sk] + ColumnarToRow + InputAdapter + Scan parquet default.item [i_item_sk,i_item_desc] + InputAdapter + ReusedExchange [d_date_sk,d_date,d_week_seq,d_date_sk] #5 InputAdapter - Exchange [inv_item_sk] #9 - WholeStageCodegen (12) - Project [inv_item_sk,inv_quantity_on_hand,inv_date_sk,w_warehouse_name] - BroadcastHashJoin [inv_warehouse_sk,w_warehouse_sk] - Filter [inv_quantity_on_hand,inv_item_sk,inv_warehouse_sk] - ColumnarToRow - InputAdapter - Scan parquet default.inventory [inv_item_sk,inv_warehouse_sk,inv_quantity_on_hand,inv_date_sk] - InputAdapter - BroadcastExchange #10 - WholeStageCodegen (11) - Filter [w_warehouse_sk] + WholeStageCodegen (14) + Sort [inv_item_sk,inv_date_sk] + InputAdapter + Exchange [inv_item_sk,inv_date_sk] #11 + WholeStageCodegen (13) + Project [inv_item_sk,inv_quantity_on_hand,inv_date_sk,w_warehouse_name] + BroadcastHashJoin [inv_warehouse_sk,w_warehouse_sk] + Filter 
[inv_quantity_on_hand,inv_item_sk,inv_warehouse_sk] ColumnarToRow InputAdapter - Scan parquet default.warehouse [w_warehouse_sk,w_warehouse_name] - InputAdapter - BroadcastExchange #11 - WholeStageCodegen (14) - Filter [p_promo_sk] - ColumnarToRow - InputAdapter - Scan parquet default.promotion [p_promo_sk] + Scan parquet default.inventory [inv_item_sk,inv_warehouse_sk,inv_quantity_on_hand,inv_date_sk] + InputAdapter + BroadcastExchange #12 + WholeStageCodegen (12) + Filter [w_warehouse_sk] + ColumnarToRow + InputAdapter + Scan parquet default.warehouse [w_warehouse_sk,w_warehouse_name] + InputAdapter + BroadcastExchange #13 + WholeStageCodegen (15) + Filter [p_promo_sk] + ColumnarToRow + InputAdapter + Scan parquet default.promotion [p_promo_sk] InputAdapter - WholeStageCodegen (17) + WholeStageCodegen (19) Sort [cr_item_sk,cr_order_number] InputAdapter - Exchange [cr_item_sk] #12 - WholeStageCodegen (16) + Exchange [cr_item_sk,cr_order_number] #14 + WholeStageCodegen (18) Project [cr_item_sk,cr_order_number] Filter [cr_item_sk,cr_order_number] ColumnarToRow diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q74.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q74.sf100/explain.txt index 864593f67a1e1..7ee6ada91dfea 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q74.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q74.sf100/explain.txt @@ -428,7 +428,7 @@ Arguments: [customer_id#69 ASC NULLS FIRST], false, 0 (77) SortMergeJoin [codegen id : 35] Left keys [1]: [customer_id#17] Right keys [1]: [customer_id#69] -Join condition: (CASE WHEN (year_total#54 > 0.00) THEN CheckOverflow((promote_precision(year_total#70) / promote_precision(year_total#54)), DecimalType(37,20), true) END > CASE WHEN (year_total#18 > 0.00) THEN CheckOverflow((promote_precision(year_total#37) / promote_precision(year_total#18)), DecimalType(37,20), true) END) +Join condition: (CASE WHEN (year_total#54 > 0.00) THEN CheckOverflow((promote_precision(year_total#70) / promote_precision(year_total#54)), DecimalType(37,20)) END > CASE WHEN (year_total#18 > 0.00) THEN CheckOverflow((promote_precision(year_total#37) / promote_precision(year_total#18)), DecimalType(37,20)) END) (78) Project [codegen id : 35] Output [3]: [customer_id#34, customer_first_name#35, customer_last_name#36] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q74/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q74/explain.txt index 8e7250c4fc4d3..a2c8929c7f285 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q74/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q74/explain.txt @@ -397,7 +397,7 @@ Arguments: HashedRelationBroadcastMode(List(input[0, string, true]),false), [id= (69) BroadcastHashJoin [codegen id : 16] Left keys [1]: [customer_id#16] Right keys [1]: [customer_id#67] -Join condition: (CASE WHEN (year_total#52 > 0.00) THEN CheckOverflow((promote_precision(year_total#68) / promote_precision(year_total#52)), DecimalType(37,20), true) END > CASE WHEN (year_total#17 > 0.00) THEN CheckOverflow((promote_precision(year_total#35) / promote_precision(year_total#17)), DecimalType(37,20), true) END) +Join condition: (CASE WHEN (year_total#52 > 0.00) THEN CheckOverflow((promote_precision(year_total#68) / promote_precision(year_total#52)), DecimalType(37,20)) END > CASE WHEN 
(year_total#17 > 0.00) THEN CheckOverflow((promote_precision(year_total#35) / promote_precision(year_total#17)), DecimalType(37,20)) END) (70) Project [codegen id : 16] Output [3]: [customer_id#32, customer_first_name#33, customer_last_name#34] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q75.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q75.sf100/explain.txt index cd66823f10e8c..27a2b5f734281 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q75.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q75.sf100/explain.txt @@ -226,7 +226,7 @@ Right keys [2]: [cr_order_number#18, cr_item_sk#17] Join condition: None (23) Project [codegen id : 7] -Output [7]: [d_year#15, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, (cs_quantity#3 - coalesce(cr_return_quantity#19, 0)) AS sales_cnt#23, CheckOverflow((promote_precision(cast(cs_ext_sales_price#4 as decimal(8,2))) - promote_precision(cast(coalesce(cr_return_amount#20, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#24] +Output [7]: [d_year#15, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, (cs_quantity#3 - coalesce(cr_return_quantity#19, 0)) AS sales_cnt#23, CheckOverflow((promote_precision(cast(cs_ext_sales_price#4 as decimal(8,2))) - promote_precision(cast(coalesce(cr_return_amount#20, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#24] Input [13]: [cs_item_sk#1, cs_order_number#2, cs_quantity#3, cs_ext_sales_price#4, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, d_year#15, cr_item_sk#17, cr_order_number#18, cr_return_quantity#19, cr_return_amount#20] (24) Scan parquet default.store_sales @@ -308,7 +308,7 @@ Right keys [2]: [sr_ticket_number#39, sr_item_sk#38] Join condition: None (42) Project [codegen id : 14] -Output [7]: [d_year#36, i_brand_id#31, i_class_id#32, i_category_id#33, i_manufact_id#34, (ss_quantity#27 - coalesce(sr_return_quantity#40, 0)) AS sales_cnt#44, CheckOverflow((promote_precision(cast(ss_ext_sales_price#28 as decimal(8,2))) - promote_precision(cast(coalesce(sr_return_amt#41, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#45] +Output [7]: [d_year#36, i_brand_id#31, i_class_id#32, i_category_id#33, i_manufact_id#34, (ss_quantity#27 - coalesce(sr_return_quantity#40, 0)) AS sales_cnt#44, CheckOverflow((promote_precision(cast(ss_ext_sales_price#28 as decimal(8,2))) - promote_precision(cast(coalesce(sr_return_amt#41, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#45] Input [13]: [ss_item_sk#25, ss_ticket_number#26, ss_quantity#27, ss_ext_sales_price#28, i_brand_id#31, i_class_id#32, i_category_id#33, i_manufact_id#34, d_year#36, sr_item_sk#38, sr_ticket_number#39, sr_return_quantity#40, sr_return_amt#41] (43) Scan parquet default.web_sales @@ -390,7 +390,7 @@ Right keys [2]: [wr_order_number#60, wr_item_sk#59] Join condition: None (61) Project [codegen id : 21] -Output [7]: [d_year#57, i_brand_id#52, i_class_id#53, i_category_id#54, i_manufact_id#55, (ws_quantity#48 - coalesce(wr_return_quantity#61, 0)) AS sales_cnt#65, CheckOverflow((promote_precision(cast(ws_ext_sales_price#49 as decimal(8,2))) - promote_precision(cast(coalesce(wr_return_amt#62, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#66] +Output [7]: [d_year#57, i_brand_id#52, i_class_id#53, i_category_id#54, i_manufact_id#55, (ws_quantity#48 - coalesce(wr_return_quantity#61, 0)) AS sales_cnt#65, 
CheckOverflow((promote_precision(cast(ws_ext_sales_price#49 as decimal(8,2))) - promote_precision(cast(coalesce(wr_return_amt#62, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#66] Input [13]: [ws_item_sk#46, ws_order_number#47, ws_quantity#48, ws_ext_sales_price#49, i_brand_id#52, i_class_id#53, i_category_id#54, i_manufact_id#55, d_year#57, wr_item_sk#59, wr_order_number#60, wr_return_quantity#61, wr_return_amt#62] (62) Union @@ -499,7 +499,7 @@ Right keys [2]: [cr_order_number#93, cr_item_sk#92] Join condition: None (85) Project [codegen id : 32] -Output [7]: [d_year#90, i_brand_id#85, i_class_id#86, i_category_id#87, i_manufact_id#88, (cs_quantity#80 - coalesce(cr_return_quantity#94, 0)) AS sales_cnt#23, CheckOverflow((promote_precision(cast(cs_ext_sales_price#81 as decimal(8,2))) - promote_precision(cast(coalesce(cr_return_amount#95, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#24] +Output [7]: [d_year#90, i_brand_id#85, i_class_id#86, i_category_id#87, i_manufact_id#88, (cs_quantity#80 - coalesce(cr_return_quantity#94, 0)) AS sales_cnt#23, CheckOverflow((promote_precision(cast(cs_ext_sales_price#81 as decimal(8,2))) - promote_precision(cast(coalesce(cr_return_amount#95, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#24] Input [13]: [cs_item_sk#78, cs_order_number#79, cs_quantity#80, cs_ext_sales_price#81, i_brand_id#85, i_class_id#86, i_category_id#87, i_manufact_id#88, d_year#90, cr_item_sk#92, cr_order_number#93, cr_return_quantity#94, cr_return_amount#95] (86) Scan parquet default.store_sales @@ -562,7 +562,7 @@ Right keys [2]: [sr_ticket_number#110, sr_item_sk#109] Join condition: None (100) Project [codegen id : 39] -Output [7]: [d_year#107, i_brand_id#102, i_class_id#103, i_category_id#104, i_manufact_id#105, (ss_quantity#98 - coalesce(sr_return_quantity#111, 0)) AS sales_cnt#44, CheckOverflow((promote_precision(cast(ss_ext_sales_price#99 as decimal(8,2))) - promote_precision(cast(coalesce(sr_return_amt#112, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#45] +Output [7]: [d_year#107, i_brand_id#102, i_class_id#103, i_category_id#104, i_manufact_id#105, (ss_quantity#98 - coalesce(sr_return_quantity#111, 0)) AS sales_cnt#44, CheckOverflow((promote_precision(cast(ss_ext_sales_price#99 as decimal(8,2))) - promote_precision(cast(coalesce(sr_return_amt#112, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#45] Input [13]: [ss_item_sk#96, ss_ticket_number#97, ss_quantity#98, ss_ext_sales_price#99, i_brand_id#102, i_class_id#103, i_category_id#104, i_manufact_id#105, d_year#107, sr_item_sk#109, sr_ticket_number#110, sr_return_quantity#111, sr_return_amt#112] (101) Scan parquet default.web_sales @@ -625,7 +625,7 @@ Right keys [2]: [wr_order_number#127, wr_item_sk#126] Join condition: None (115) Project [codegen id : 46] -Output [7]: [d_year#124, i_brand_id#119, i_class_id#120, i_category_id#121, i_manufact_id#122, (ws_quantity#115 - coalesce(wr_return_quantity#128, 0)) AS sales_cnt#65, CheckOverflow((promote_precision(cast(ws_ext_sales_price#116 as decimal(8,2))) - promote_precision(cast(coalesce(wr_return_amt#129, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#66] +Output [7]: [d_year#124, i_brand_id#119, i_class_id#120, i_category_id#121, i_manufact_id#122, (ws_quantity#115 - coalesce(wr_return_quantity#128, 0)) AS sales_cnt#65, CheckOverflow((promote_precision(cast(ws_ext_sales_price#116 as decimal(8,2))) - promote_precision(cast(coalesce(wr_return_amt#129, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS 
sales_amt#66] Input [13]: [ws_item_sk#113, ws_order_number#114, ws_quantity#115, ws_ext_sales_price#116, i_brand_id#119, i_class_id#120, i_category_id#121, i_manufact_id#122, d_year#124, wr_item_sk#126, wr_order_number#127, wr_return_quantity#128, wr_return_amt#129] (116) Union @@ -677,10 +677,10 @@ Arguments: [i_brand_id#85 ASC NULLS FIRST, i_class_id#86 ASC NULLS FIRST, i_cate (125) SortMergeJoin [codegen id : 51] Left keys [4]: [i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12] Right keys [4]: [i_brand_id#85, i_class_id#86, i_category_id#87, i_manufact_id#88] -Join condition: (CheckOverflow((promote_precision(cast(sales_cnt#75 as decimal(17,2))) / promote_precision(cast(sales_cnt#134 as decimal(17,2)))), DecimalType(37,20), true) < 0.90000000000000000000) +Join condition: (CheckOverflow((promote_precision(cast(sales_cnt#75 as decimal(17,2))) / promote_precision(cast(sales_cnt#134 as decimal(17,2)))), DecimalType(37,20)) < 0.90000000000000000000) (126) Project [codegen id : 51] -Output [10]: [d_year#90 AS prev_year#137, d_year#15 AS year#138, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, sales_cnt#134 AS prev_yr_cnt#139, sales_cnt#75 AS curr_yr_cnt#140, (sales_cnt#75 - sales_cnt#134) AS sales_cnt_diff#141, CheckOverflow((promote_precision(cast(sales_amt#76 as decimal(19,2))) - promote_precision(cast(sales_amt#135 as decimal(19,2)))), DecimalType(19,2), true) AS sales_amt_diff#142] +Output [10]: [d_year#90 AS prev_year#137, d_year#15 AS year#138, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, sales_cnt#134 AS prev_yr_cnt#139, sales_cnt#75 AS curr_yr_cnt#140, (sales_cnt#75 - sales_cnt#134) AS sales_cnt_diff#141, CheckOverflow((promote_precision(cast(sales_amt#76 as decimal(19,2))) - promote_precision(cast(sales_amt#135 as decimal(19,2)))), DecimalType(19,2)) AS sales_amt_diff#142] Input [14]: [d_year#15, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, sales_cnt#75, sales_amt#76, d_year#90, i_brand_id#85, i_class_id#86, i_category_id#87, i_manufact_id#88, sales_cnt#134, sales_amt#135] (127) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q75/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q75/explain.txt index cd66823f10e8c..27a2b5f734281 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q75/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q75/explain.txt @@ -226,7 +226,7 @@ Right keys [2]: [cr_order_number#18, cr_item_sk#17] Join condition: None (23) Project [codegen id : 7] -Output [7]: [d_year#15, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, (cs_quantity#3 - coalesce(cr_return_quantity#19, 0)) AS sales_cnt#23, CheckOverflow((promote_precision(cast(cs_ext_sales_price#4 as decimal(8,2))) - promote_precision(cast(coalesce(cr_return_amount#20, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#24] +Output [7]: [d_year#15, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, (cs_quantity#3 - coalesce(cr_return_quantity#19, 0)) AS sales_cnt#23, CheckOverflow((promote_precision(cast(cs_ext_sales_price#4 as decimal(8,2))) - promote_precision(cast(coalesce(cr_return_amount#20, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#24] Input [13]: [cs_item_sk#1, cs_order_number#2, cs_quantity#3, cs_ext_sales_price#4, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, d_year#15, cr_item_sk#17, cr_order_number#18, 
cr_return_quantity#19, cr_return_amount#20] (24) Scan parquet default.store_sales @@ -308,7 +308,7 @@ Right keys [2]: [sr_ticket_number#39, sr_item_sk#38] Join condition: None (42) Project [codegen id : 14] -Output [7]: [d_year#36, i_brand_id#31, i_class_id#32, i_category_id#33, i_manufact_id#34, (ss_quantity#27 - coalesce(sr_return_quantity#40, 0)) AS sales_cnt#44, CheckOverflow((promote_precision(cast(ss_ext_sales_price#28 as decimal(8,2))) - promote_precision(cast(coalesce(sr_return_amt#41, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#45] +Output [7]: [d_year#36, i_brand_id#31, i_class_id#32, i_category_id#33, i_manufact_id#34, (ss_quantity#27 - coalesce(sr_return_quantity#40, 0)) AS sales_cnt#44, CheckOverflow((promote_precision(cast(ss_ext_sales_price#28 as decimal(8,2))) - promote_precision(cast(coalesce(sr_return_amt#41, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#45] Input [13]: [ss_item_sk#25, ss_ticket_number#26, ss_quantity#27, ss_ext_sales_price#28, i_brand_id#31, i_class_id#32, i_category_id#33, i_manufact_id#34, d_year#36, sr_item_sk#38, sr_ticket_number#39, sr_return_quantity#40, sr_return_amt#41] (43) Scan parquet default.web_sales @@ -390,7 +390,7 @@ Right keys [2]: [wr_order_number#60, wr_item_sk#59] Join condition: None (61) Project [codegen id : 21] -Output [7]: [d_year#57, i_brand_id#52, i_class_id#53, i_category_id#54, i_manufact_id#55, (ws_quantity#48 - coalesce(wr_return_quantity#61, 0)) AS sales_cnt#65, CheckOverflow((promote_precision(cast(ws_ext_sales_price#49 as decimal(8,2))) - promote_precision(cast(coalesce(wr_return_amt#62, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#66] +Output [7]: [d_year#57, i_brand_id#52, i_class_id#53, i_category_id#54, i_manufact_id#55, (ws_quantity#48 - coalesce(wr_return_quantity#61, 0)) AS sales_cnt#65, CheckOverflow((promote_precision(cast(ws_ext_sales_price#49 as decimal(8,2))) - promote_precision(cast(coalesce(wr_return_amt#62, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#66] Input [13]: [ws_item_sk#46, ws_order_number#47, ws_quantity#48, ws_ext_sales_price#49, i_brand_id#52, i_class_id#53, i_category_id#54, i_manufact_id#55, d_year#57, wr_item_sk#59, wr_order_number#60, wr_return_quantity#61, wr_return_amt#62] (62) Union @@ -499,7 +499,7 @@ Right keys [2]: [cr_order_number#93, cr_item_sk#92] Join condition: None (85) Project [codegen id : 32] -Output [7]: [d_year#90, i_brand_id#85, i_class_id#86, i_category_id#87, i_manufact_id#88, (cs_quantity#80 - coalesce(cr_return_quantity#94, 0)) AS sales_cnt#23, CheckOverflow((promote_precision(cast(cs_ext_sales_price#81 as decimal(8,2))) - promote_precision(cast(coalesce(cr_return_amount#95, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#24] +Output [7]: [d_year#90, i_brand_id#85, i_class_id#86, i_category_id#87, i_manufact_id#88, (cs_quantity#80 - coalesce(cr_return_quantity#94, 0)) AS sales_cnt#23, CheckOverflow((promote_precision(cast(cs_ext_sales_price#81 as decimal(8,2))) - promote_precision(cast(coalesce(cr_return_amount#95, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#24] Input [13]: [cs_item_sk#78, cs_order_number#79, cs_quantity#80, cs_ext_sales_price#81, i_brand_id#85, i_class_id#86, i_category_id#87, i_manufact_id#88, d_year#90, cr_item_sk#92, cr_order_number#93, cr_return_quantity#94, cr_return_amount#95] (86) Scan parquet default.store_sales @@ -562,7 +562,7 @@ Right keys [2]: [sr_ticket_number#110, sr_item_sk#109] Join condition: None (100) Project [codegen id : 39] -Output [7]: 
[d_year#107, i_brand_id#102, i_class_id#103, i_category_id#104, i_manufact_id#105, (ss_quantity#98 - coalesce(sr_return_quantity#111, 0)) AS sales_cnt#44, CheckOverflow((promote_precision(cast(ss_ext_sales_price#99 as decimal(8,2))) - promote_precision(cast(coalesce(sr_return_amt#112, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#45] +Output [7]: [d_year#107, i_brand_id#102, i_class_id#103, i_category_id#104, i_manufact_id#105, (ss_quantity#98 - coalesce(sr_return_quantity#111, 0)) AS sales_cnt#44, CheckOverflow((promote_precision(cast(ss_ext_sales_price#99 as decimal(8,2))) - promote_precision(cast(coalesce(sr_return_amt#112, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#45] Input [13]: [ss_item_sk#96, ss_ticket_number#97, ss_quantity#98, ss_ext_sales_price#99, i_brand_id#102, i_class_id#103, i_category_id#104, i_manufact_id#105, d_year#107, sr_item_sk#109, sr_ticket_number#110, sr_return_quantity#111, sr_return_amt#112] (101) Scan parquet default.web_sales @@ -625,7 +625,7 @@ Right keys [2]: [wr_order_number#127, wr_item_sk#126] Join condition: None (115) Project [codegen id : 46] -Output [7]: [d_year#124, i_brand_id#119, i_class_id#120, i_category_id#121, i_manufact_id#122, (ws_quantity#115 - coalesce(wr_return_quantity#128, 0)) AS sales_cnt#65, CheckOverflow((promote_precision(cast(ws_ext_sales_price#116 as decimal(8,2))) - promote_precision(cast(coalesce(wr_return_amt#129, 0.00) as decimal(8,2)))), DecimalType(8,2), true) AS sales_amt#66] +Output [7]: [d_year#124, i_brand_id#119, i_class_id#120, i_category_id#121, i_manufact_id#122, (ws_quantity#115 - coalesce(wr_return_quantity#128, 0)) AS sales_cnt#65, CheckOverflow((promote_precision(cast(ws_ext_sales_price#116 as decimal(8,2))) - promote_precision(cast(coalesce(wr_return_amt#129, 0.00) as decimal(8,2)))), DecimalType(8,2)) AS sales_amt#66] Input [13]: [ws_item_sk#113, ws_order_number#114, ws_quantity#115, ws_ext_sales_price#116, i_brand_id#119, i_class_id#120, i_category_id#121, i_manufact_id#122, d_year#124, wr_item_sk#126, wr_order_number#127, wr_return_quantity#128, wr_return_amt#129] (116) Union @@ -677,10 +677,10 @@ Arguments: [i_brand_id#85 ASC NULLS FIRST, i_class_id#86 ASC NULLS FIRST, i_cate (125) SortMergeJoin [codegen id : 51] Left keys [4]: [i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12] Right keys [4]: [i_brand_id#85, i_class_id#86, i_category_id#87, i_manufact_id#88] -Join condition: (CheckOverflow((promote_precision(cast(sales_cnt#75 as decimal(17,2))) / promote_precision(cast(sales_cnt#134 as decimal(17,2)))), DecimalType(37,20), true) < 0.90000000000000000000) +Join condition: (CheckOverflow((promote_precision(cast(sales_cnt#75 as decimal(17,2))) / promote_precision(cast(sales_cnt#134 as decimal(17,2)))), DecimalType(37,20)) < 0.90000000000000000000) (126) Project [codegen id : 51] -Output [10]: [d_year#90 AS prev_year#137, d_year#15 AS year#138, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, sales_cnt#134 AS prev_yr_cnt#139, sales_cnt#75 AS curr_yr_cnt#140, (sales_cnt#75 - sales_cnt#134) AS sales_cnt_diff#141, CheckOverflow((promote_precision(cast(sales_amt#76 as decimal(19,2))) - promote_precision(cast(sales_amt#135 as decimal(19,2)))), DecimalType(19,2), true) AS sales_amt_diff#142] +Output [10]: [d_year#90 AS prev_year#137, d_year#15 AS year#138, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, sales_cnt#134 AS prev_yr_cnt#139, sales_cnt#75 AS curr_yr_cnt#140, (sales_cnt#75 - sales_cnt#134) AS sales_cnt_diff#141, 
CheckOverflow((promote_precision(cast(sales_amt#76 as decimal(19,2))) - promote_precision(cast(sales_amt#135 as decimal(19,2)))), DecimalType(19,2)) AS sales_amt_diff#142] Input [14]: [d_year#15, i_brand_id#8, i_class_id#9, i_category_id#10, i_manufact_id#12, sales_cnt#75, sales_amt#76, d_year#90, i_brand_id#85, i_class_id#86, i_category_id#87, i_manufact_id#88, sales_cnt#134, sales_amt#135] (127) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a.sf100/explain.txt index 4d27141fd8465..335e1aee4e5ca 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a.sf100/explain.txt @@ -238,7 +238,7 @@ Right keys [1]: [s_store_sk#23] Join condition: None (30) Project [codegen id : 8] -Output [5]: [store channel AS channel#34, s_store_sk#7 AS id#35, sales#16, coalesce(returns#31, 0.00) AS returns#36, CheckOverflow((promote_precision(cast(profit#17 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#32, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS profit#37] +Output [5]: [store channel AS channel#34, s_store_sk#7 AS id#35, sales#16, coalesce(returns#31, 0.00) AS returns#36, CheckOverflow((promote_precision(cast(profit#17 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#32, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS profit#37] Input [6]: [s_store_sk#7, sales#16, profit#17, s_store_sk#23, returns#31, profit_loss#32] (31) Scan parquet default.catalog_sales @@ -329,7 +329,7 @@ Arguments: IdentityBroadcastMode, [id=#65] Join condition: None (49) Project [codegen id : 14] -Output [5]: [catalog channel AS channel#66, cs_call_center_sk#38 AS id#67, sales#50, returns#63, CheckOverflow((promote_precision(cast(profit#51 as decimal(18,2))) - promote_precision(cast(profit_loss#64 as decimal(18,2)))), DecimalType(18,2), true) AS profit#68] +Output [5]: [catalog channel AS channel#66, cs_call_center_sk#38 AS id#67, sales#50, returns#63, CheckOverflow((promote_precision(cast(profit#51 as decimal(18,2))) - promote_precision(cast(profit_loss#64 as decimal(18,2)))), DecimalType(18,2)) AS profit#68] Input [5]: [cs_call_center_sk#38, sales#50, profit#51, returns#63, profit_loss#64] (50) Scan parquet default.web_sales @@ -471,7 +471,7 @@ Right keys [1]: [wp_web_page_sk#90] Join condition: None (79) Project [codegen id : 22] -Output [5]: [web channel AS channel#101, wp_web_page_sk#74 AS id#102, sales#83, coalesce(returns#98, 0.00) AS returns#103, CheckOverflow((promote_precision(cast(profit#84 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#99, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS profit#104] +Output [5]: [web channel AS channel#101, wp_web_page_sk#74 AS id#102, sales#83, coalesce(returns#98, 0.00) AS returns#103, CheckOverflow((promote_precision(cast(profit#84 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#99, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS profit#104] Input [6]: [wp_web_page_sk#74, sales#83, profit#84, wp_web_page_sk#90, returns#98, profit_loss#99] (80) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a/explain.txt index a1d99b72c8147..815eabe2fe0e8 100644 --- 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a/explain.txt @@ -238,7 +238,7 @@ Right keys [1]: [s_store_sk#23] Join condition: None (30) Project [codegen id : 8] -Output [5]: [store channel AS channel#34, s_store_sk#7 AS id#35, sales#16, coalesce(returns#31, 0.00) AS returns#36, CheckOverflow((promote_precision(cast(profit#17 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#32, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS profit#37] +Output [5]: [store channel AS channel#34, s_store_sk#7 AS id#35, sales#16, coalesce(returns#31, 0.00) AS returns#36, CheckOverflow((promote_precision(cast(profit#17 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#32, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS profit#37] Input [6]: [s_store_sk#7, sales#16, profit#17, s_store_sk#23, returns#31, profit_loss#32] (31) Scan parquet default.catalog_sales @@ -329,7 +329,7 @@ Results [2]: [MakeDecimal(sum(UnscaledValue(cr_return_amount#53))#62,17,2) AS re Join condition: None (49) Project [codegen id : 14] -Output [5]: [catalog channel AS channel#66, cs_call_center_sk#38 AS id#67, sales#50, returns#64, CheckOverflow((promote_precision(cast(profit#51 as decimal(18,2))) - promote_precision(cast(profit_loss#65 as decimal(18,2)))), DecimalType(18,2), true) AS profit#68] +Output [5]: [catalog channel AS channel#66, cs_call_center_sk#38 AS id#67, sales#50, returns#64, CheckOverflow((promote_precision(cast(profit#51 as decimal(18,2))) - promote_precision(cast(profit_loss#65 as decimal(18,2)))), DecimalType(18,2)) AS profit#68] Input [5]: [cs_call_center_sk#38, sales#50, profit#51, returns#64, profit_loss#65] (50) Scan parquet default.web_sales @@ -471,7 +471,7 @@ Right keys [1]: [wp_web_page_sk#90] Join condition: None (79) Project [codegen id : 22] -Output [5]: [web channel AS channel#101, wp_web_page_sk#74 AS id#102, sales#83, coalesce(returns#98, 0.00) AS returns#103, CheckOverflow((promote_precision(cast(profit#84 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#99, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS profit#104] +Output [5]: [web channel AS channel#101, wp_web_page_sk#74 AS id#102, sales#83, coalesce(returns#98, 0.00) AS returns#103, CheckOverflow((promote_precision(cast(profit#84 as decimal(18,2))) - promote_precision(cast(coalesce(profit_loss#99, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS profit#104] Input [6]: [wp_web_page_sk#74, sales#83, profit#84, wp_web_page_sk#90, returns#98, profit_loss#99] (80) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.sf100/explain.txt index b54f3fa20c63f..133d5272ec111 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.sf100/explain.txt @@ -382,7 +382,7 @@ Right keys [3]: [cs_sold_year#83, cs_item_sk#60, cs_customer_sk#84] Join condition: None (69) Project [codegen id : 23] -Output [13]: [round((cast(ss_qty#27 as double) / cast(coalesce((ws_qty#56 + cs_qty#85), 1) as double)), 2) AS ratio#88, ss_qty#27 AS store_qty#89, ss_wc#28 AS store_wholesale_cost#90, ss_sp#29 AS store_sales_price#91, (coalesce(ws_qty#56, 0) + coalesce(cs_qty#85, 0)) AS other_chan_qty#92, CheckOverflow((promote_precision(cast(coalesce(ws_wc#57, 
0.00) as decimal(18,2))) + promote_precision(cast(coalesce(cs_wc#86, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS other_chan_wholesale_cost#93, CheckOverflow((promote_precision(cast(coalesce(ws_sp#58, 0.00) as decimal(18,2))) + promote_precision(cast(coalesce(cs_sp#87, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS other_chan_sales_price#94, ss_sold_year#26, ss_item_sk#1, ss_customer_sk#2, ss_qty#27, ss_wc#28, ss_sp#29] +Output [13]: [round((cast(ss_qty#27 as double) / cast(coalesce((ws_qty#56 + cs_qty#85), 1) as double)), 2) AS ratio#88, ss_qty#27 AS store_qty#89, ss_wc#28 AS store_wholesale_cost#90, ss_sp#29 AS store_sales_price#91, (coalesce(ws_qty#56, 0) + coalesce(cs_qty#85, 0)) AS other_chan_qty#92, CheckOverflow((promote_precision(cast(coalesce(ws_wc#57, 0.00) as decimal(18,2))) + promote_precision(cast(coalesce(cs_wc#86, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS other_chan_wholesale_cost#93, CheckOverflow((promote_precision(cast(coalesce(ws_sp#58, 0.00) as decimal(18,2))) + promote_precision(cast(coalesce(cs_sp#87, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS other_chan_sales_price#94, ss_sold_year#26, ss_item_sk#1, ss_customer_sk#2, ss_qty#27, ss_wc#28, ss_sp#29] Input [15]: [ss_sold_year#26, ss_item_sk#1, ss_customer_sk#2, ss_qty#27, ss_wc#28, ss_sp#29, ws_qty#56, ws_wc#57, ws_sp#58, cs_sold_year#83, cs_item_sk#60, cs_customer_sk#84, cs_qty#85, cs_wc#86, cs_sp#87] (70) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78/explain.txt index b54f3fa20c63f..133d5272ec111 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78/explain.txt @@ -382,7 +382,7 @@ Right keys [3]: [cs_sold_year#83, cs_item_sk#60, cs_customer_sk#84] Join condition: None (69) Project [codegen id : 23] -Output [13]: [round((cast(ss_qty#27 as double) / cast(coalesce((ws_qty#56 + cs_qty#85), 1) as double)), 2) AS ratio#88, ss_qty#27 AS store_qty#89, ss_wc#28 AS store_wholesale_cost#90, ss_sp#29 AS store_sales_price#91, (coalesce(ws_qty#56, 0) + coalesce(cs_qty#85, 0)) AS other_chan_qty#92, CheckOverflow((promote_precision(cast(coalesce(ws_wc#57, 0.00) as decimal(18,2))) + promote_precision(cast(coalesce(cs_wc#86, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS other_chan_wholesale_cost#93, CheckOverflow((promote_precision(cast(coalesce(ws_sp#58, 0.00) as decimal(18,2))) + promote_precision(cast(coalesce(cs_sp#87, 0.00) as decimal(18,2)))), DecimalType(18,2), true) AS other_chan_sales_price#94, ss_sold_year#26, ss_item_sk#1, ss_customer_sk#2, ss_qty#27, ss_wc#28, ss_sp#29] +Output [13]: [round((cast(ss_qty#27 as double) / cast(coalesce((ws_qty#56 + cs_qty#85), 1) as double)), 2) AS ratio#88, ss_qty#27 AS store_qty#89, ss_wc#28 AS store_wholesale_cost#90, ss_sp#29 AS store_sales_price#91, (coalesce(ws_qty#56, 0) + coalesce(cs_qty#85, 0)) AS other_chan_qty#92, CheckOverflow((promote_precision(cast(coalesce(ws_wc#57, 0.00) as decimal(18,2))) + promote_precision(cast(coalesce(cs_wc#86, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS other_chan_wholesale_cost#93, CheckOverflow((promote_precision(cast(coalesce(ws_sp#58, 0.00) as decimal(18,2))) + promote_precision(cast(coalesce(cs_sp#87, 0.00) as decimal(18,2)))), DecimalType(18,2)) AS other_chan_sales_price#94, ss_sold_year#26, ss_item_sk#1, ss_customer_sk#2, ss_qty#27, 
ss_wc#28, ss_sp#29] Input [15]: [ss_sold_year#26, ss_item_sk#1, ss_customer_sk#2, ss_qty#27, ss_wc#28, ss_sp#29, ws_qty#56, ws_wc#57, ws_sp#58, cs_sold_year#83, cs_item_sk#60, cs_customer_sk#84, cs_qty#85, cs_wc#86, cs_sp#87] (70) TakeOrderedAndProject diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a.sf100/explain.txt index 34777c108a268..a9ea4905b9fb7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a.sf100/explain.txt @@ -283,7 +283,7 @@ Input [7]: [ss_store_sk#2, ss_ext_sales_price#5, ss_net_profit#6, sr_return_amt# (37) HashAggregate [codegen id : 9] Input [5]: [ss_ext_sales_price#5, ss_net_profit#6, sr_return_amt#12, sr_net_loss#13, s_store_id#24] Keys [1]: [s_store_id#24] -Functions [3]: [partial_sum(UnscaledValue(ss_ext_sales_price#5)), partial_sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(ss_ext_sales_price#5)), partial_sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#26, sum#27, isEmpty#28, sum#29, isEmpty#30] Results [6]: [s_store_id#24, sum#31, sum#32, isEmpty#33, sum#34, isEmpty#35] @@ -294,9 +294,9 @@ Arguments: hashpartitioning(s_store_id#24, 5), ENSURE_REQUIREMENTS, [id=#36] (39) HashAggregate [codegen id : 10] Input [6]: [s_store_id#24, sum#31, sum#32, isEmpty#33, sum#34, isEmpty#35] Keys [1]: [s_store_id#24] -Functions [3]: [sum(UnscaledValue(ss_ext_sales_price#5)), sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(ss_ext_sales_price#5))#37, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#39] -Results [5]: [store channel AS channel#40, concat(store, s_store_id#24) AS id#41, MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#5))#37,17,2) AS sales#42, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38 AS returns#43, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#39 AS profit#44] +Functions [3]: [sum(UnscaledValue(ss_ext_sales_price#5)), sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(ss_ext_sales_price#5))#37, 
sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#39] +Results [5]: [store channel AS channel#40, concat(store, s_store_id#24) AS id#41, MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#5))#37,17,2) AS sales#42, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38 AS returns#43, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#39 AS profit#44] (40) Scan parquet default.catalog_sales Output [7]: [cs_catalog_page_sk#45, cs_item_sk#46, cs_promo_sk#47, cs_order_number#48, cs_ext_sales_price#49, cs_net_profit#50, cs_sold_date_sk#51] @@ -422,7 +422,7 @@ Input [7]: [cs_catalog_page_sk#45, cs_ext_sales_price#49, cs_net_profit#50, cr_r (68) HashAggregate [codegen id : 19] Input [5]: [cs_ext_sales_price#49, cs_net_profit#50, cr_return_amount#55, cr_net_loss#56, cp_catalog_page_id#63] Keys [1]: [cp_catalog_page_id#63] -Functions [3]: [partial_sum(UnscaledValue(cs_ext_sales_price#49)), partial_sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(cs_ext_sales_price#49)), partial_sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#65, sum#66, isEmpty#67, sum#68, isEmpty#69] Results [6]: [cp_catalog_page_id#63, sum#70, sum#71, isEmpty#72, sum#73, isEmpty#74] @@ -433,9 +433,9 @@ Arguments: hashpartitioning(cp_catalog_page_id#63, 5), ENSURE_REQUIREMENTS, [id= (70) HashAggregate [codegen id : 20] Input [6]: [cp_catalog_page_id#63, sum#70, sum#71, isEmpty#72, sum#73, isEmpty#74] Keys [1]: [cp_catalog_page_id#63] -Functions [3]: [sum(UnscaledValue(cs_ext_sales_price#49)), sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(cs_ext_sales_price#49))#76, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#78] -Results [5]: [catalog channel AS channel#79, concat(catalog_page, cp_catalog_page_id#63) AS id#80, MakeDecimal(sum(UnscaledValue(cs_ext_sales_price#49))#76,17,2) AS sales#81, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77 AS returns#82, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#78 AS profit#83] +Functions [3]: [sum(UnscaledValue(cs_ext_sales_price#49)), 
sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(cs_ext_sales_price#49))#76, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#78] +Results [5]: [catalog channel AS channel#79, concat(catalog_page, cp_catalog_page_id#63) AS id#80, MakeDecimal(sum(UnscaledValue(cs_ext_sales_price#49))#76,17,2) AS sales#81, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77 AS returns#82, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#78 AS profit#83] (71) Scan parquet default.web_sales Output [7]: [ws_item_sk#84, ws_web_site_sk#85, ws_promo_sk#86, ws_order_number#87, ws_ext_sales_price#88, ws_net_profit#89, ws_sold_date_sk#90] @@ -561,7 +561,7 @@ Input [7]: [ws_web_site_sk#85, ws_ext_sales_price#88, ws_net_profit#89, wr_retur (99) HashAggregate [codegen id : 29] Input [5]: [ws_ext_sales_price#88, ws_net_profit#89, wr_return_amt#94, wr_net_loss#95, web_site_id#102] Keys [1]: [web_site_id#102] -Functions [3]: [partial_sum(UnscaledValue(ws_ext_sales_price#88)), partial_sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(ws_ext_sales_price#88)), partial_sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#104, sum#105, isEmpty#106, sum#107, isEmpty#108] Results [6]: [web_site_id#102, sum#109, sum#110, isEmpty#111, sum#112, isEmpty#113] @@ -572,9 +572,9 @@ Arguments: hashpartitioning(web_site_id#102, 5), ENSURE_REQUIREMENTS, [id=#114] (101) HashAggregate [codegen id : 30] Input [6]: [web_site_id#102, sum#109, sum#110, isEmpty#111, sum#112, isEmpty#113] Keys [1]: [web_site_id#102] -Functions [3]: [sum(UnscaledValue(ws_ext_sales_price#88)), sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(ws_ext_sales_price#88))#115, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#117] -Results [5]: [web channel AS channel#118, concat(web_site, web_site_id#102) AS id#119, MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#88))#115,17,2) AS sales#120, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116 AS 
returns#121, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#117 AS profit#122] +Functions [3]: [sum(UnscaledValue(ws_ext_sales_price#88)), sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(ws_ext_sales_price#88))#115, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#117] +Results [5]: [web channel AS channel#118, concat(web_site, web_site_id#102) AS id#119, MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#88))#115,17,2) AS sales#120, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116 AS returns#121, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#117 AS profit#122] (102) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a.sf100/simplified.txt index ef6c39d87b482..af80e8a825183 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a.sf100/simplified.txt @@ -16,7 +16,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter Union WholeStageCodegen (10) - HashAggregate [s_store_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ss_ext_sales_price)),sum(coalesce(cast(sr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ss_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [s_store_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ss_ext_sales_price)),sum(coalesce(cast(sr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ss_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [s_store_id] #3 WholeStageCodegen (9) @@ -86,7 +86,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter Scan parquet default.store [s_store_sk,s_store_id] WholeStageCodegen (20) - HashAggregate [cp_catalog_page_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(cs_ext_sales_price)),sum(coalesce(cast(cr_return_amount as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(cs_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [cp_catalog_page_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(cs_ext_sales_price)),sum(coalesce(cast(cr_return_amount as 
decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(cs_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [cp_catalog_page_id] #10 WholeStageCodegen (19) @@ -137,7 +137,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter Scan parquet default.catalog_page [cp_catalog_page_sk,cp_catalog_page_id] WholeStageCodegen (30) - HashAggregate [web_site_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ws_ext_sales_price)),sum(coalesce(cast(wr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ws_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [web_site_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ws_ext_sales_price)),sum(coalesce(cast(wr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ws_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [web_site_id] #14 WholeStageCodegen (29) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a/explain.txt index 3e68f3fe694fc..03e744ac87b53 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a/explain.txt @@ -283,7 +283,7 @@ Input [7]: [ss_promo_sk#3, ss_ext_sales_price#5, ss_net_profit#6, sr_return_amt# (37) HashAggregate [codegen id : 9] Input [5]: [ss_ext_sales_price#5, ss_net_profit#6, sr_return_amt#12, sr_net_loss#13, s_store_id#18] Keys [1]: [s_store_id#18] -Functions [3]: [partial_sum(UnscaledValue(ss_ext_sales_price#5)), partial_sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(ss_ext_sales_price#5)), partial_sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#26, sum#27, isEmpty#28, sum#29, isEmpty#30] Results [6]: [s_store_id#18, sum#31, sum#32, isEmpty#33, sum#34, isEmpty#35] @@ -294,9 +294,9 @@ Arguments: hashpartitioning(s_store_id#18, 5), ENSURE_REQUIREMENTS, [id=#36] (39) HashAggregate [codegen id : 10] Input [6]: [s_store_id#18, sum#31, sum#32, isEmpty#33, sum#34, isEmpty#35] Keys [1]: [s_store_id#18] -Functions [3]: [sum(UnscaledValue(ss_ext_sales_price#5)), sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: 
[sum(UnscaledValue(ss_ext_sales_price#5))#37, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#39] -Results [5]: [store channel AS channel#40, concat(store, s_store_id#18) AS id#41, MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#5))#37,17,2) AS sales#42, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38 AS returns#43, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#39 AS profit#44] +Functions [3]: [sum(UnscaledValue(ss_ext_sales_price#5)), sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(ss_ext_sales_price#5))#37, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#39] +Results [5]: [store channel AS channel#40, concat(store, s_store_id#18) AS id#41, MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#5))#37,17,2) AS sales#42, sum(coalesce(cast(sr_return_amt#12 as decimal(12,2)), 0.00))#38 AS returns#43, sum(CheckOverflow((promote_precision(cast(ss_net_profit#6 as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss#13 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#39 AS profit#44] (40) Scan parquet default.catalog_sales Output [7]: [cs_catalog_page_sk#45, cs_item_sk#46, cs_promo_sk#47, cs_order_number#48, cs_ext_sales_price#49, cs_net_profit#50, cs_sold_date_sk#51] @@ -422,7 +422,7 @@ Input [7]: [cs_promo_sk#47, cs_ext_sales_price#49, cs_net_profit#50, cr_return_a (68) HashAggregate [codegen id : 19] Input [5]: [cs_ext_sales_price#49, cs_net_profit#50, cr_return_amount#55, cr_net_loss#56, cp_catalog_page_id#61] Keys [1]: [cp_catalog_page_id#61] -Functions [3]: [partial_sum(UnscaledValue(cs_ext_sales_price#49)), partial_sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(cs_ext_sales_price#49)), partial_sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#65, sum#66, isEmpty#67, sum#68, isEmpty#69] Results [6]: [cp_catalog_page_id#61, sum#70, sum#71, isEmpty#72, sum#73, isEmpty#74] @@ -433,9 +433,9 @@ Arguments: hashpartitioning(cp_catalog_page_id#61, 5), ENSURE_REQUIREMENTS, [id= (70) HashAggregate [codegen id : 20] Input [6]: [cp_catalog_page_id#61, sum#70, sum#71, isEmpty#72, sum#73, isEmpty#74] Keys [1]: [cp_catalog_page_id#61] -Functions [3]: [sum(UnscaledValue(cs_ext_sales_price#49)), 
sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(cs_ext_sales_price#49))#76, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#78] -Results [5]: [catalog channel AS channel#79, concat(catalog_page, cp_catalog_page_id#61) AS id#80, MakeDecimal(sum(UnscaledValue(cs_ext_sales_price#49))#76,17,2) AS sales#81, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77 AS returns#82, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#78 AS profit#83] +Functions [3]: [sum(UnscaledValue(cs_ext_sales_price#49)), sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(cs_ext_sales_price#49))#76, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#78] +Results [5]: [catalog channel AS channel#79, concat(catalog_page, cp_catalog_page_id#61) AS id#80, MakeDecimal(sum(UnscaledValue(cs_ext_sales_price#49))#76,17,2) AS sales#81, sum(coalesce(cast(cr_return_amount#55 as decimal(12,2)), 0.00))#77 AS returns#82, sum(CheckOverflow((promote_precision(cast(cs_net_profit#50 as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss#56 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#78 AS profit#83] (71) Scan parquet default.web_sales Output [7]: [ws_item_sk#84, ws_web_site_sk#85, ws_promo_sk#86, ws_order_number#87, ws_ext_sales_price#88, ws_net_profit#89, ws_sold_date_sk#90] @@ -561,7 +561,7 @@ Input [7]: [ws_promo_sk#86, ws_ext_sales_price#88, ws_net_profit#89, wr_return_a (99) HashAggregate [codegen id : 29] Input [5]: [ws_ext_sales_price#88, ws_net_profit#89, wr_return_amt#94, wr_net_loss#95, web_site_id#100] Keys [1]: [web_site_id#100] -Functions [3]: [partial_sum(UnscaledValue(ws_ext_sales_price#88)), partial_sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] +Functions [3]: [partial_sum(UnscaledValue(ws_ext_sales_price#88)), partial_sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), partial_sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] Aggregate Attributes [5]: [sum#104, sum#105, isEmpty#106, sum#107, isEmpty#108] Results [6]: [web_site_id#100, sum#109, sum#110, isEmpty#111, sum#112, isEmpty#113] @@ 
-572,9 +572,9 @@ Arguments: hashpartitioning(web_site_id#100, 5), ENSURE_REQUIREMENTS, [id=#114] (101) HashAggregate [codegen id : 30] Input [6]: [web_site_id#100, sum#109, sum#110, isEmpty#111, sum#112, isEmpty#113] Keys [1]: [web_site_id#100] -Functions [3]: [sum(UnscaledValue(ws_ext_sales_price#88)), sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))] -Aggregate Attributes [3]: [sum(UnscaledValue(ws_ext_sales_price#88))#115, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#117] -Results [5]: [web channel AS channel#118, concat(web_site, web_site_id#100) AS id#119, MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#88))#115,17,2) AS sales#120, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116 AS returns#121, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true))#117 AS profit#122] +Functions [3]: [sum(UnscaledValue(ws_ext_sales_price#88)), sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00)), sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))] +Aggregate Attributes [3]: [sum(UnscaledValue(ws_ext_sales_price#88))#115, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#117] +Results [5]: [web channel AS channel#118, concat(web_site, web_site_id#100) AS id#119, MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#88))#115,17,2) AS sales#120, sum(coalesce(cast(wr_return_amt#94 as decimal(12,2)), 0.00))#116 AS returns#121, sum(CheckOverflow((promote_precision(cast(ws_net_profit#89 as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss#95 as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2)))#117 AS profit#122] (102) Union diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a/simplified.txt index d3fc38799fe0e..169957c1c164e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q80a/simplified.txt @@ -16,7 +16,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter Union WholeStageCodegen (10) - HashAggregate [s_store_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ss_ext_sales_price)),sum(coalesce(cast(sr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ss_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [s_store_id,sum,sum,isEmpty,sum,isEmpty] 
[sum(UnscaledValue(ss_ext_sales_price)),sum(coalesce(cast(sr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ss_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(sr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [s_store_id] #3 WholeStageCodegen (9) @@ -86,7 +86,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter Scan parquet default.promotion [p_promo_sk,p_channel_tv] WholeStageCodegen (20) - HashAggregate [cp_catalog_page_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(cs_ext_sales_price)),sum(coalesce(cast(cr_return_amount as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(cs_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [cp_catalog_page_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(cs_ext_sales_price)),sum(coalesce(cast(cr_return_amount as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(cs_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(cr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [cp_catalog_page_id] #10 WholeStageCodegen (19) @@ -137,7 +137,7 @@ TakeOrderedAndProject [channel,id,sales,returns,profit] InputAdapter ReusedExchange [p_promo_sk] #9 WholeStageCodegen (30) - HashAggregate [web_site_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ws_ext_sales_price)),sum(coalesce(cast(wr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ws_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2), true)),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] + HashAggregate [web_site_id,sum,sum,isEmpty,sum,isEmpty] [sum(UnscaledValue(ws_ext_sales_price)),sum(coalesce(cast(wr_return_amt as decimal(12,2)), 0.00)),sum(CheckOverflow((promote_precision(cast(ws_net_profit as decimal(13,2))) - promote_precision(cast(coalesce(cast(wr_net_loss as decimal(12,2)), 0.00) as decimal(13,2)))), DecimalType(13,2))),channel,id,sales,returns,profit,sum,sum,isEmpty,sum,isEmpty] InputAdapter Exchange [web_site_id] #14 WholeStageCodegen (29) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q98.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q98.sf100/explain.txt index 2c31ce69e5e5a..fd1c4b503eaa8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q98.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q98.sf100/explain.txt @@ -122,7 +122,7 @@ Input [8]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_pri Arguments: [sum(_w1#20) windowspecdefinition(i_class#10, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#22], [i_class#10] (22) Project [codegen id : 9] -Output [7]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17), true) AS 
revenueratio#23] +Output [7]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#19) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#22)), DecimalType(38,17)) AS revenueratio#23] Input [9]: [i_item_id#7, i_item_desc#8, i_category#11, i_class#10, i_current_price#9, itemrevenue#18, _w0#19, _w1#20, _we0#22] (23) Exchange diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q98/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q98/explain.txt index 259338b39c245..68e7dba19dbab 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q98/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q98/explain.txt @@ -107,7 +107,7 @@ Input [8]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_pric Arguments: [sum(_w1#19) windowspecdefinition(i_class#9, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#21], [i_class#9] (19) Project [codegen id : 6] -Output [7]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2), true) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17), true) AS revenueratio#22] +Output [7]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(_w0#18) * 100.00), DecimalType(21,2)) as decimal(27,2))) / promote_precision(_we0#21)), DecimalType(38,17)) AS revenueratio#22] Input [9]: [i_item_id#6, i_item_desc#7, i_category#10, i_class#9, i_current_price#8, itemrevenue#17, _w0#18, _w1#19, _we0#21] (20) Exchange diff --git a/sql/core/src/test/resources/tpch-plan-stability/q1/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q1/explain.txt index cc0d21d19409b..c2fdfb24d2d85 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q1/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q1/explain.txt @@ -31,7 +31,7 @@ Input [7]: [l_quantity#1, l_extendedprice#2, l_discount#3, l_tax#4, l_returnflag (5) HashAggregate [codegen id : 1] Input [6]: [l_quantity#1, l_extendedprice#2, l_discount#3, l_tax#4, l_returnflag#5, l_linestatus#6] Keys [2]: [l_returnflag#5, l_linestatus#6] -Functions [8]: [partial_sum(l_quantity#1), partial_sum(l_extendedprice#2), partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)), partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)) * promote_precision(cast(CheckOverflow((1 + promote_precision(cast(l_tax#4 as decimal(11,0)))), DecimalType(11,0), true) as decimal(22,0)))), DecimalType(34,0), true)), partial_avg(UnscaledValue(l_quantity#1)), partial_avg(UnscaledValue(l_extendedprice#2)), partial_avg(UnscaledValue(l_discount#3)), partial_count(1)] +Functions [8]: [partial_sum(l_quantity#1), partial_sum(l_extendedprice#2), 
partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))), partial_sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))) * promote_precision(cast(CheckOverflow((1 + promote_precision(cast(l_tax#4 as decimal(11,0)))), DecimalType(11,0)) as decimal(22,0)))), DecimalType(34,0))), partial_avg(UnscaledValue(l_quantity#1)), partial_avg(UnscaledValue(l_extendedprice#2)), partial_avg(UnscaledValue(l_discount#3)), partial_count(1)] Aggregate Attributes [15]: [sum#8, isEmpty#9, sum#10, isEmpty#11, sum#12, isEmpty#13, sum#14, isEmpty#15, sum#16, count#17, sum#18, count#19, sum#20, count#21, count#22] Results [17]: [l_returnflag#5, l_linestatus#6, sum#23, isEmpty#24, sum#25, isEmpty#26, sum#27, isEmpty#28, sum#29, isEmpty#30, sum#31, count#32, sum#33, count#34, sum#35, count#36, count#37] @@ -42,9 +42,9 @@ Arguments: hashpartitioning(l_returnflag#5, l_linestatus#6, 5), ENSURE_REQUIREME (7) HashAggregate [codegen id : 2] Input [17]: [l_returnflag#5, l_linestatus#6, sum#23, isEmpty#24, sum#25, isEmpty#26, sum#27, isEmpty#28, sum#29, isEmpty#30, sum#31, count#32, sum#33, count#34, sum#35, count#36, count#37] Keys [2]: [l_returnflag#5, l_linestatus#6] -Functions [8]: [sum(l_quantity#1), sum(l_extendedprice#2), sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)), sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)) * promote_precision(cast(CheckOverflow((1 + promote_precision(cast(l_tax#4 as decimal(11,0)))), DecimalType(11,0), true) as decimal(22,0)))), DecimalType(34,0), true)), avg(UnscaledValue(l_quantity#1)), avg(UnscaledValue(l_extendedprice#2)), avg(UnscaledValue(l_discount#3)), count(1)] -Aggregate Attributes [8]: [sum(l_quantity#1)#39, sum(l_extendedprice#2)#40, sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#41, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)) * promote_precision(cast(CheckOverflow((1 + promote_precision(cast(l_tax#4 as decimal(11,0)))), DecimalType(11,0), true) as decimal(22,0)))), DecimalType(34,0), true))#42, avg(UnscaledValue(l_quantity#1))#43, avg(UnscaledValue(l_extendedprice#2))#44, avg(UnscaledValue(l_discount#3))#45, count(1)#46] -Results [10]: [l_returnflag#5, l_linestatus#6, sum(l_quantity#1)#39 AS sum_qty#47, sum(l_extendedprice#2)#40 AS sum_base_price#48, sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), 
true))), DecimalType(22,0), true))#41 AS sum_disc_price#49, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)) * promote_precision(cast(CheckOverflow((1 + promote_precision(cast(l_tax#4 as decimal(11,0)))), DecimalType(11,0), true) as decimal(22,0)))), DecimalType(34,0), true))#42 AS sum_charge#50, cast((avg(UnscaledValue(l_quantity#1))#43 / 1.0) as decimal(14,4)) AS avg_qty#51, cast((avg(UnscaledValue(l_extendedprice#2))#44 / 1.0) as decimal(14,4)) AS avg_price#52, cast((avg(UnscaledValue(l_discount#3))#45 / 1.0) as decimal(14,4)) AS avg_disc#53, count(1)#46 AS count_order#54] +Functions [8]: [sum(l_quantity#1), sum(l_extendedprice#2), sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))), sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))) * promote_precision(cast(CheckOverflow((1 + promote_precision(cast(l_tax#4 as decimal(11,0)))), DecimalType(11,0)) as decimal(22,0)))), DecimalType(34,0))), avg(UnscaledValue(l_quantity#1)), avg(UnscaledValue(l_extendedprice#2)), avg(UnscaledValue(l_discount#3)), count(1)] +Aggregate Attributes [8]: [sum(l_quantity#1)#39, sum(l_extendedprice#2)#40, sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#41, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))) * promote_precision(cast(CheckOverflow((1 + promote_precision(cast(l_tax#4 as decimal(11,0)))), DecimalType(11,0)) as decimal(22,0)))), DecimalType(34,0)))#42, avg(UnscaledValue(l_quantity#1))#43, avg(UnscaledValue(l_extendedprice#2))#44, avg(UnscaledValue(l_discount#3))#45, count(1)#46] +Results [10]: [l_returnflag#5, l_linestatus#6, sum(l_quantity#1)#39 AS sum_qty#47, sum(l_extendedprice#2)#40 AS sum_base_price#48, sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#41 AS sum_disc_price#49, sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))) * promote_precision(cast(CheckOverflow((1 + promote_precision(cast(l_tax#4 as decimal(11,0)))), DecimalType(11,0)) as decimal(22,0)))), DecimalType(34,0)))#42 AS sum_charge#50, cast((avg(UnscaledValue(l_quantity#1))#43 / 1.0) as decimal(14,4)) AS avg_qty#51, cast((avg(UnscaledValue(l_extendedprice#2))#44 / 1.0) as decimal(14,4)) AS avg_price#52, cast((avg(UnscaledValue(l_discount#3))#45 / 1.0) as decimal(14,4)) AS avg_disc#53, count(1)#46 AS count_order#54] (8) Exchange Input [10]: 
[l_returnflag#5, l_linestatus#6, sum_qty#47, sum_base_price#48, sum_disc_price#49, sum_charge#50, avg_qty#51, avg_price#52, avg_disc#53, count_order#54] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q1/simplified.txt b/sql/core/src/test/resources/tpch-plan-stability/q1/simplified.txt index f94c3d6b5b4d8..68e8e39486e48 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q1/simplified.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q1/simplified.txt @@ -3,7 +3,7 @@ WholeStageCodegen (3) InputAdapter Exchange [l_returnflag,l_linestatus] #1 WholeStageCodegen (2) - HashAggregate [l_returnflag,l_linestatus,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,count,sum,count,sum,count,count] [sum(l_quantity),sum(l_extendedprice),sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)),sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)) * promote_precision(cast(CheckOverflow((1 + promote_precision(cast(l_tax as decimal(11,0)))), DecimalType(11,0), true) as decimal(22,0)))), DecimalType(34,0), true)),avg(UnscaledValue(l_quantity)),avg(UnscaledValue(l_extendedprice)),avg(UnscaledValue(l_discount)),count(1),sum_qty,sum_base_price,sum_disc_price,sum_charge,avg_qty,avg_price,avg_disc,count_order,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,count,sum,count,sum,count,count] + HashAggregate [l_returnflag,l_linestatus,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,count,sum,count,sum,count,count] [sum(l_quantity),sum(l_extendedprice),sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))),sum(CheckOverflow((promote_precision(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))) * promote_precision(cast(CheckOverflow((1 + promote_precision(cast(l_tax as decimal(11,0)))), DecimalType(11,0)) as decimal(22,0)))), DecimalType(34,0))),avg(UnscaledValue(l_quantity)),avg(UnscaledValue(l_extendedprice)),avg(UnscaledValue(l_discount)),count(1),sum_qty,sum_base_price,sum_disc_price,sum_charge,avg_qty,avg_price,avg_disc,count_order,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,isEmpty,sum,count,sum,count,sum,count,count] InputAdapter Exchange [l_returnflag,l_linestatus] #2 WholeStageCodegen (1) diff --git a/sql/core/src/test/resources/tpch-plan-stability/q10/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q10/explain.txt index 4cd56105a252b..08be511944f36 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q10/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q10/explain.txt @@ -134,7 +134,7 @@ Input [11]: [c_custkey#1, c_name#2, c_address#3, c_nationkey#4, c_phone#5, c_acc (24) HashAggregate [codegen id : 4] Input [9]: [c_custkey#1, c_name#2, c_address#3, c_phone#5, c_acctbal#6, c_comment#7, l_extendedprice#13, l_discount#14, n_name#18] Keys [7]: [c_custkey#1, c_name#2, c_acctbal#6, c_phone#5, n_name#18, c_address#3, c_comment#7] -Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#13 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#14 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#13 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#14 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] Aggregate Attributes [2]: [sum#20, isEmpty#21] Results [9]: [c_custkey#1, c_name#2, c_acctbal#6, c_phone#5, n_name#18, c_address#3, c_comment#7, sum#22, isEmpty#23] @@ -145,9 +145,9 @@ Arguments: hashpartitioning(c_custkey#1, c_name#2, c_acctbal#6, c_phone#5, n_nam (26) HashAggregate [codegen id : 5] Input [9]: [c_custkey#1, c_name#2, c_acctbal#6, c_phone#5, n_name#18, c_address#3, c_comment#7, sum#22, isEmpty#23] Keys [7]: [c_custkey#1, c_name#2, c_acctbal#6, c_phone#5, n_name#18, c_address#3, c_comment#7] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#13 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#14 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#13 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#14 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#25] -Results [8]: [c_custkey#1, c_name#2, sum(CheckOverflow((promote_precision(cast(l_extendedprice#13 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#14 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#25 AS revenue#26, c_acctbal#6, n_name#18, c_address#3, c_phone#5, c_comment#7] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#13 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#14 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#13 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#14 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#25] +Results [8]: [c_custkey#1, c_name#2, sum(CheckOverflow((promote_precision(cast(l_extendedprice#13 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#14 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#25 AS revenue#26, c_acctbal#6, n_name#18, c_address#3, c_phone#5, c_comment#7] (27) TakeOrderedAndProject Input [8]: [c_custkey#1, c_name#2, revenue#26, c_acctbal#6, n_name#18, c_address#3, c_phone#5, c_comment#7] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q10/simplified.txt b/sql/core/src/test/resources/tpch-plan-stability/q10/simplified.txt index eb09255ad799f..86cee35abda3d 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q10/simplified.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q10/simplified.txt @@ -1,6 +1,6 @@ TakeOrderedAndProject [revenue,c_custkey,c_name,c_acctbal,n_name,c_address,c_phone,c_comment] WholeStageCodegen (5) - HashAggregate [c_custkey,c_name,c_acctbal,c_phone,n_name,c_address,c_comment,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), 
DecimalType(11,0), true))), DecimalType(22,0), true)),revenue,sum,isEmpty] + HashAggregate [c_custkey,c_name,c_acctbal,c_phone,n_name,c_address,c_comment,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))),revenue,sum,isEmpty] InputAdapter Exchange [c_custkey,c_name,c_acctbal,c_phone,n_name,c_address,c_comment] #1 WholeStageCodegen (4) diff --git a/sql/core/src/test/resources/tpch-plan-stability/q11/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q11/explain.txt index c210d30019ad8..bc7e629fd7dd8 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q11/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q11/explain.txt @@ -98,7 +98,7 @@ Input [5]: [ps_partkey#1, ps_availqty#3, ps_supplycost#4, s_nationkey#6, n_natio (17) HashAggregate [codegen id : 3] Input [3]: [ps_partkey#1, ps_availqty#3, ps_supplycost#4] Keys [1]: [ps_partkey#1] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(ps_supplycost#4) * promote_precision(cast(ps_availqty#3 as decimal(10,0)))), DecimalType(21,0), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(ps_supplycost#4) * promote_precision(cast(ps_availqty#3 as decimal(10,0)))), DecimalType(21,0)))] Aggregate Attributes [2]: [sum#11, isEmpty#12] Results [3]: [ps_partkey#1, sum#13, isEmpty#14] @@ -109,9 +109,9 @@ Arguments: hashpartitioning(ps_partkey#1, 5), ENSURE_REQUIREMENTS, [id=#15] (19) HashAggregate [codegen id : 4] Input [3]: [ps_partkey#1, sum#13, isEmpty#14] Keys [1]: [ps_partkey#1] -Functions [1]: [sum(CheckOverflow((promote_precision(ps_supplycost#4) * promote_precision(cast(ps_availqty#3 as decimal(10,0)))), DecimalType(21,0), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(ps_supplycost#4) * promote_precision(cast(ps_availqty#3 as decimal(10,0)))), DecimalType(21,0), true))#16] -Results [2]: [ps_partkey#1, sum(CheckOverflow((promote_precision(ps_supplycost#4) * promote_precision(cast(ps_availqty#3 as decimal(10,0)))), DecimalType(21,0), true))#16 AS value#17] +Functions [1]: [sum(CheckOverflow((promote_precision(ps_supplycost#4) * promote_precision(cast(ps_availqty#3 as decimal(10,0)))), DecimalType(21,0)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(ps_supplycost#4) * promote_precision(cast(ps_availqty#3 as decimal(10,0)))), DecimalType(21,0)))#16] +Results [2]: [ps_partkey#1, sum(CheckOverflow((promote_precision(ps_supplycost#4) * promote_precision(cast(ps_availqty#3 as decimal(10,0)))), DecimalType(21,0)))#16 AS value#17] (20) Filter [codegen id : 4] Input [2]: [ps_partkey#1, value#17] @@ -183,7 +183,7 @@ Input [4]: [ps_availqty#22, ps_supplycost#23, s_nationkey#25, n_nationkey#26] (32) HashAggregate [codegen id : 3] Input [2]: [ps_availqty#22, ps_supplycost#23] Keys: [] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(ps_supplycost#23) * promote_precision(cast(ps_availqty#22 as decimal(10,0)))), DecimalType(21,0), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(ps_supplycost#23) * promote_precision(cast(ps_availqty#22 as decimal(10,0)))), DecimalType(21,0)))] Aggregate Attributes [2]: [sum#27, isEmpty#28] Results [2]: [sum#29, isEmpty#30] @@ -194,8 +194,8 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#31] (34) HashAggregate [codegen id : 4] Input [2]: [sum#29, isEmpty#30] Keys: [] -Functions [1]: 
[sum(CheckOverflow((promote_precision(ps_supplycost#23) * promote_precision(cast(ps_availqty#22 as decimal(10,0)))), DecimalType(21,0), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(ps_supplycost#23) * promote_precision(cast(ps_availqty#22 as decimal(10,0)))), DecimalType(21,0), true))#32] -Results [1]: [CheckOverflow((promote_precision(cast(sum(CheckOverflow((promote_precision(ps_supplycost#23) * promote_precision(cast(ps_availqty#22 as decimal(10,0)))), DecimalType(21,0), true))#32 as decimal(38,10))) * 0.0001000000), DecimalType(38,6), true) AS (sum((ps_supplycost * ps_availqty)) * 0.0001000000)#33] +Functions [1]: [sum(CheckOverflow((promote_precision(ps_supplycost#23) * promote_precision(cast(ps_availqty#22 as decimal(10,0)))), DecimalType(21,0)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(ps_supplycost#23) * promote_precision(cast(ps_availqty#22 as decimal(10,0)))), DecimalType(21,0)))#32] +Results [1]: [CheckOverflow((promote_precision(cast(sum(CheckOverflow((promote_precision(ps_supplycost#23) * promote_precision(cast(ps_availqty#22 as decimal(10,0)))), DecimalType(21,0)))#32 as decimal(38,10))) * 0.0001000000), DecimalType(38,6)) AS (sum((ps_supplycost * ps_availqty)) * 0.0001000000)#33] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q11/simplified.txt b/sql/core/src/test/resources/tpch-plan-stability/q11/simplified.txt index f94cf82874cf3..bdafa6c8b43c1 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q11/simplified.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q11/simplified.txt @@ -6,7 +6,7 @@ WholeStageCodegen (5) Filter [value] Subquery #1 WholeStageCodegen (4) - HashAggregate [sum,isEmpty] [sum(CheckOverflow((promote_precision(ps_supplycost) * promote_precision(cast(ps_availqty as decimal(10,0)))), DecimalType(21,0), true)),(sum((ps_supplycost * ps_availqty)) * 0.0001000000),sum,isEmpty] + HashAggregate [sum,isEmpty] [sum(CheckOverflow((promote_precision(ps_supplycost) * promote_precision(cast(ps_availqty as decimal(10,0)))), DecimalType(21,0))),(sum((ps_supplycost * ps_availqty)) * 0.0001000000),sum,isEmpty] InputAdapter Exchange #5 WholeStageCodegen (3) @@ -23,7 +23,7 @@ WholeStageCodegen (5) ReusedExchange [s_suppkey,s_nationkey] #3 InputAdapter ReusedExchange [n_nationkey] #4 - HashAggregate [ps_partkey,sum,isEmpty] [sum(CheckOverflow((promote_precision(ps_supplycost) * promote_precision(cast(ps_availqty as decimal(10,0)))), DecimalType(21,0), true)),value,sum,isEmpty] + HashAggregate [ps_partkey,sum,isEmpty] [sum(CheckOverflow((promote_precision(ps_supplycost) * promote_precision(cast(ps_availqty as decimal(10,0)))), DecimalType(21,0))),value,sum,isEmpty] InputAdapter Exchange [ps_partkey] #2 WholeStageCodegen (3) diff --git a/sql/core/src/test/resources/tpch-plan-stability/q14/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q14/explain.txt index 98e3b4a5e8fac..0e923aebe1e11 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q14/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q14/explain.txt @@ -62,7 +62,7 @@ Input [5]: [l_partkey#1, l_extendedprice#2, l_discount#3, p_partkey#5, p_type#6] (11) HashAggregate [codegen id : 2] Input [3]: [l_extendedprice#2, l_discount#3, p_type#6] Keys: [] -Functions [2]: [partial_sum(CASE WHEN StartsWith(p_type#6, PROMO) THEN CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), 
DecimalType(11,0), true))), DecimalType(22,0), true) ELSE 0 END), partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] +Functions [2]: [partial_sum(CASE WHEN StartsWith(p_type#6, PROMO) THEN CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)) ELSE 0 END), partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] Aggregate Attributes [4]: [sum#8, isEmpty#9, sum#10, isEmpty#11] Results [4]: [sum#12, isEmpty#13, sum#14, isEmpty#15] @@ -73,7 +73,7 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#16] (13) HashAggregate [codegen id : 3] Input [4]: [sum#12, isEmpty#13, sum#14, isEmpty#15] Keys: [] -Functions [2]: [sum(CASE WHEN StartsWith(p_type#6, PROMO) THEN CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true) ELSE 0 END), sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] -Aggregate Attributes [2]: [sum(CASE WHEN StartsWith(p_type#6, PROMO) THEN CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true) ELSE 0 END)#17, sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#18] -Results [1]: [CheckOverflow((promote_precision(CheckOverflow((100.00 * promote_precision(cast(sum(CASE WHEN StartsWith(p_type#6, PROMO) THEN CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true) ELSE 0 END)#17 as decimal(34,2)))), DecimalType(38,2), true)) / promote_precision(cast(sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#18 as decimal(38,2)))), DecimalType(38,6), true) AS promo_revenue#19] +Functions [2]: [sum(CASE WHEN StartsWith(p_type#6, PROMO) THEN CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)) ELSE 0 END), sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] +Aggregate Attributes [2]: [sum(CASE WHEN StartsWith(p_type#6, PROMO) THEN 
CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)) ELSE 0 END)#17, sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#18] +Results [1]: [CheckOverflow((promote_precision(CheckOverflow((100.00 * promote_precision(cast(sum(CASE WHEN StartsWith(p_type#6, PROMO) THEN CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)) ELSE 0 END)#17 as decimal(34,2)))), DecimalType(38,2))) / promote_precision(cast(sum(CheckOverflow((promote_precision(cast(l_extendedprice#2 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#3 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#18 as decimal(38,2)))), DecimalType(38,6)) AS promo_revenue#19] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q14/simplified.txt b/sql/core/src/test/resources/tpch-plan-stability/q14/simplified.txt index 8f46e5fff4efa..ca3c30110de04 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q14/simplified.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q14/simplified.txt @@ -1,5 +1,5 @@ WholeStageCodegen (3) - HashAggregate [sum,isEmpty,sum,isEmpty] [sum(CASE WHEN StartsWith(p_type, PROMO) THEN CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true) ELSE 0 END),sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)),promo_revenue,sum,isEmpty,sum,isEmpty] + HashAggregate [sum,isEmpty,sum,isEmpty] [sum(CASE WHEN StartsWith(p_type, PROMO) THEN CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)) ELSE 0 END),sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))),promo_revenue,sum,isEmpty,sum,isEmpty] InputAdapter Exchange #1 WholeStageCodegen (2) diff --git a/sql/core/src/test/resources/tpch-plan-stability/q15/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q15/explain.txt index a64943b45fefd..a615b73893782 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q15/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q15/explain.txt @@ -52,7 +52,7 @@ Input [4]: [l_suppkey#5, l_extendedprice#6, l_discount#7, l_shipdate#8] (8) HashAggregate [codegen id : 1] Input [3]: [l_suppkey#5, l_extendedprice#6, l_discount#7] Keys [1]: [l_suppkey#5] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] +Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] Aggregate Attributes [2]: [sum#9, isEmpty#10] Results [3]: [l_suppkey#5, sum#11, isEmpty#12] @@ -63,9 +63,9 @@ Arguments: hashpartitioning(l_suppkey#5, 5), ENSURE_REQUIREMENTS, [id=#13] (10) HashAggregate [codegen id : 2] Input [3]: [l_suppkey#5, sum#11, isEmpty#12] Keys [1]: [l_suppkey#5] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#14] -Results [2]: [l_suppkey#5 AS supplier_no#15, sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#14 AS total_revenue#16] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#14] +Results [2]: [l_suppkey#5 AS supplier_no#15, sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#14 AS total_revenue#16] (11) Filter [codegen id : 2] Input [2]: [supplier_no#15, total_revenue#16] @@ -128,7 +128,7 @@ Input [4]: [l_suppkey#5, l_extendedprice#6, l_discount#7, l_shipdate#8] (21) HashAggregate [codegen id : 1] Input [3]: [l_suppkey#5, l_extendedprice#6, l_discount#7] Keys [1]: [l_suppkey#5] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] Aggregate Attributes [2]: [sum#21, isEmpty#22] Results [3]: [l_suppkey#5, sum#23, isEmpty#24] @@ -139,9 +139,9 @@ Arguments: hashpartitioning(l_suppkey#5, 5), ENSURE_REQUIREMENTS, [id=#25] (23) HashAggregate [codegen id : 2] Input [3]: [l_suppkey#5, sum#23, isEmpty#24] Keys [1]: [l_suppkey#5] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * 
promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#14] -Results [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#14 AS total_revenue#16] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#14] +Results [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#14 AS total_revenue#16] (24) HashAggregate [codegen id : 2] Input [1]: [total_revenue#16] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q15/simplified.txt b/sql/core/src/test/resources/tpch-plan-stability/q15/simplified.txt index a492b9e8b5249..ae1de64f65a92 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q15/simplified.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q15/simplified.txt @@ -20,7 +20,7 @@ WholeStageCodegen (4) Exchange #4 WholeStageCodegen (2) HashAggregate [total_revenue] [max,max] - HashAggregate [l_suppkey,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)),total_revenue,sum,isEmpty] + HashAggregate [l_suppkey,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))),total_revenue,sum,isEmpty] InputAdapter Exchange [l_suppkey] #5 WholeStageCodegen (1) @@ -30,7 +30,7 @@ WholeStageCodegen (4) ColumnarToRow InputAdapter Scan parquet default.lineitem [l_suppkey,l_extendedprice,l_discount,l_shipdate] - HashAggregate [l_suppkey,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)),supplier_no,total_revenue,sum,isEmpty] + HashAggregate [l_suppkey,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))),supplier_no,total_revenue,sum,isEmpty] InputAdapter Exchange [l_suppkey] #3 WholeStageCodegen (1) diff --git a/sql/core/src/test/resources/tpch-plan-stability/q17/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q17/explain.txt index 416b5345d6a82..652bf04238ca2 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q17/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q17/explain.txt @@ -99,7 +99,7 @@ Input [3]: [l_partkey#9, sum#13, count#14] Keys [1]: [l_partkey#9] Functions [1]: 
[avg(UnscaledValue(l_quantity#10))] Aggregate Attributes [1]: [avg(UnscaledValue(l_quantity#10))#16] -Results [2]: [CheckOverflow((0.2000 * promote_precision(cast((avg(UnscaledValue(l_quantity#10))#16 / 1.0) as decimal(14,4)))), DecimalType(16,5), true) AS (0.2 * avg(l_quantity))#17, l_partkey#9] +Results [2]: [CheckOverflow((0.2000 * promote_precision(cast((avg(UnscaledValue(l_quantity#10))#16 / 1.0) as decimal(14,4)))), DecimalType(16,5)) AS (0.2 * avg(l_quantity))#17, l_partkey#9] (17) Filter [codegen id : 3] Input [2]: [(0.2 * avg(l_quantity))#17, l_partkey#9] @@ -134,5 +134,5 @@ Input [2]: [sum#21, isEmpty#22] Keys: [] Functions [1]: [sum(l_extendedprice#3)] Aggregate Attributes [1]: [sum(l_extendedprice#3)#24] -Results [1]: [CheckOverflow((promote_precision(cast(sum(l_extendedprice#3)#24 as decimal(21,1))) / 7.0), DecimalType(27,6), true) AS avg_yearly#25] +Results [1]: [CheckOverflow((promote_precision(cast(sum(l_extendedprice#3)#24 as decimal(21,1))) / 7.0), DecimalType(27,6)) AS avg_yearly#25] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q19/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q19/explain.txt index 41bff0f6756ce..b5d84e54efc7e 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q19/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q19/explain.txt @@ -62,7 +62,7 @@ Input [8]: [l_partkey#1, l_quantity#2, l_extendedprice#3, l_discount#4, p_partke (11) HashAggregate [codegen id : 2] Input [2]: [l_extendedprice#3, l_discount#4] Keys: [] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#3 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#4 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#3 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#4 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] Aggregate Attributes [2]: [sum#15, isEmpty#16] Results [2]: [sum#17, isEmpty#18] @@ -73,7 +73,7 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#19] (13) HashAggregate [codegen id : 3] Input [2]: [sum#17, isEmpty#18] Keys: [] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#3 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#4 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#3 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#4 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#20] -Results [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#3 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#4 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#20 AS revenue#21] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#3 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#4 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#3 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#4 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#20] +Results [1]: 
[sum(CheckOverflow((promote_precision(cast(l_extendedprice#3 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#4 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#20 AS revenue#21] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q19/simplified.txt b/sql/core/src/test/resources/tpch-plan-stability/q19/simplified.txt index fc2ac1096938e..24838e5c93109 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q19/simplified.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q19/simplified.txt @@ -1,5 +1,5 @@ WholeStageCodegen (3) - HashAggregate [sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)),revenue,sum,isEmpty] + HashAggregate [sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))),revenue,sum,isEmpty] InputAdapter Exchange #1 WholeStageCodegen (2) diff --git a/sql/core/src/test/resources/tpch-plan-stability/q20/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q20/explain.txt index edf14f1c424e5..43d5431a70f2e 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q20/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q20/explain.txt @@ -135,7 +135,7 @@ Input [4]: [l_partkey#11, l_suppkey#12, sum#17, isEmpty#18] Keys [2]: [l_partkey#11, l_suppkey#12] Functions [1]: [sum(l_quantity#13)] Aggregate Attributes [1]: [sum(l_quantity#13)#20] -Results [3]: [CheckOverflow((0.5 * promote_precision(cast(sum(l_quantity#13)#20 as decimal(21,1)))), DecimalType(22,1), true) AS (0.5 * sum(l_quantity))#21, l_partkey#11, l_suppkey#12] +Results [3]: [CheckOverflow((0.5 * promote_precision(cast(sum(l_quantity#13)#20 as decimal(21,1)))), DecimalType(22,1)) AS (0.5 * sum(l_quantity))#21, l_partkey#11, l_suppkey#12] (22) Filter [codegen id : 4] Input [3]: [(0.5 * sum(l_quantity))#21, l_partkey#11, l_suppkey#12] @@ -148,7 +148,7 @@ Arguments: HashedRelationBroadcastMode(List(input[1, bigint, true], input[2, big (24) BroadcastHashJoin [codegen id : 5] Left keys [2]: [ps_partkey#5, ps_suppkey#6] Right keys [2]: [l_partkey#11, l_suppkey#12] -Join condition: (cast(cast(ps_availqty#7 as decimal(10,0)) as decimal(22,1)) > (0.5 * sum(l_quantity))#21) +Join condition: (cast(ps_availqty#7 as decimal(22,1)) > (0.5 * sum(l_quantity))#21) (25) Project [codegen id : 5] Output [1]: [ps_suppkey#6] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q3/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q3/explain.txt index ee09633bda706..e0243ce3bbd52 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q3/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q3/explain.txt @@ -101,7 +101,7 @@ Input [6]: [o_orderkey#3, o_orderdate#5, o_shippriority#6, l_orderkey#8, l_exten (18) HashAggregate [codegen id : 3] Input [5]: [o_orderdate#5, o_shippriority#6, l_orderkey#8, l_extendedprice#9, l_discount#10] Keys [3]: [l_orderkey#8, o_orderdate#5, o_shippriority#6] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] +Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] Aggregate Attributes [2]: [sum#13, isEmpty#14] Results [5]: [l_orderkey#8, o_orderdate#5, o_shippriority#6, sum#15, isEmpty#16] @@ -112,9 +112,9 @@ Arguments: hashpartitioning(l_orderkey#8, o_orderdate#5, o_shippriority#6, 5), E (20) HashAggregate [codegen id : 4] Input [5]: [l_orderkey#8, o_orderdate#5, o_shippriority#6, sum#15, isEmpty#16] Keys [3]: [l_orderkey#8, o_orderdate#5, o_shippriority#6] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#18] -Results [4]: [l_orderkey#8, sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#18 AS revenue#19, o_orderdate#5, o_shippriority#6] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#18] +Results [4]: [l_orderkey#8, sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#18 AS revenue#19, o_orderdate#5, o_shippriority#6] (21) TakeOrderedAndProject Input [4]: [l_orderkey#8, revenue#19, o_orderdate#5, o_shippriority#6] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q3/simplified.txt b/sql/core/src/test/resources/tpch-plan-stability/q3/simplified.txt index 9e234b2ff6d3d..26c18d19d7e20 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q3/simplified.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q3/simplified.txt @@ -1,6 +1,6 @@ TakeOrderedAndProject [revenue,o_orderdate,l_orderkey,o_shippriority] WholeStageCodegen (4) - HashAggregate [l_orderkey,o_orderdate,o_shippriority,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)),revenue,sum,isEmpty] + HashAggregate [l_orderkey,o_orderdate,o_shippriority,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))),revenue,sum,isEmpty] InputAdapter Exchange [l_orderkey,o_orderdate,o_shippriority] #1 WholeStageCodegen (3) diff --git a/sql/core/src/test/resources/tpch-plan-stability/q5/explain.txt 
b/sql/core/src/test/resources/tpch-plan-stability/q5/explain.txt index fba8d0ea9629d..c3dbd88338317 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q5/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q5/explain.txt @@ -201,7 +201,7 @@ Input [5]: [l_extendedprice#9, l_discount#10, n_name#16, n_regionkey#17, r_regio (36) HashAggregate [codegen id : 6] Input [3]: [l_extendedprice#9, l_discount#10, n_name#16] Keys [1]: [n_name#16] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] Aggregate Attributes [2]: [sum#22, isEmpty#23] Results [3]: [n_name#16, sum#24, isEmpty#25] @@ -212,9 +212,9 @@ Arguments: hashpartitioning(n_name#16, 5), ENSURE_REQUIREMENTS, [id=#26] (38) HashAggregate [codegen id : 7] Input [3]: [n_name#16, sum#24, isEmpty#25] Keys [1]: [n_name#16] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#27] -Results [2]: [n_name#16, sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true))#27 AS revenue#28] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#27] +Results [2]: [n_name#16, sum(CheckOverflow((promote_precision(cast(l_extendedprice#9 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#10 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)))#27 AS revenue#28] (39) Exchange Input [2]: [n_name#16, revenue#28] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q5/simplified.txt b/sql/core/src/test/resources/tpch-plan-stability/q5/simplified.txt index aa5c8b0b0b844..a9d8480dc8b98 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q5/simplified.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q5/simplified.txt @@ -3,7 +3,7 @@ WholeStageCodegen (8) InputAdapter Exchange [revenue] #1 WholeStageCodegen (7) - HashAggregate [n_name,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true)),revenue,sum,isEmpty] + 
HashAggregate [n_name,sum,isEmpty] [sum(CheckOverflow((promote_precision(cast(l_extendedprice as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0))),revenue,sum,isEmpty] InputAdapter Exchange [n_name] #2 WholeStageCodegen (6) diff --git a/sql/core/src/test/resources/tpch-plan-stability/q6/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q6/explain.txt index 3b203b22cc70f..a092574d73c57 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q6/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q6/explain.txt @@ -12,7 +12,7 @@ Arguments: , [l_extendedprice#1, l_discount#2] (2) HashAggregate [codegen id : 1] Input [2]: [l_extendedprice#1, l_discount#2] Keys: [] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(l_extendedprice#1) * promote_precision(l_discount#2)), DecimalType(21,0), true))] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(l_extendedprice#1) * promote_precision(l_discount#2)), DecimalType(21,0)))] Aggregate Attributes [2]: [sum#3, isEmpty#4] Results [2]: [sum#5, isEmpty#6] @@ -23,7 +23,7 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#7] (4) HashAggregate [codegen id : 2] Input [2]: [sum#5, isEmpty#6] Keys: [] -Functions [1]: [sum(CheckOverflow((promote_precision(l_extendedprice#1) * promote_precision(l_discount#2)), DecimalType(21,0), true))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(l_extendedprice#1) * promote_precision(l_discount#2)), DecimalType(21,0), true))#8] -Results [1]: [sum(CheckOverflow((promote_precision(l_extendedprice#1) * promote_precision(l_discount#2)), DecimalType(21,0), true))#8 AS revenue#9] +Functions [1]: [sum(CheckOverflow((promote_precision(l_extendedprice#1) * promote_precision(l_discount#2)), DecimalType(21,0)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(l_extendedprice#1) * promote_precision(l_discount#2)), DecimalType(21,0)))#8] +Results [1]: [sum(CheckOverflow((promote_precision(l_extendedprice#1) * promote_precision(l_discount#2)), DecimalType(21,0)))#8 AS revenue#9] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q6/simplified.txt b/sql/core/src/test/resources/tpch-plan-stability/q6/simplified.txt index 3170df2269ac4..3d026241e9ccd 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q6/simplified.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q6/simplified.txt @@ -1,5 +1,5 @@ WholeStageCodegen (2) - HashAggregate [sum,isEmpty] [sum(CheckOverflow((promote_precision(l_extendedprice) * promote_precision(l_discount)), DecimalType(21,0), true)),revenue,sum,isEmpty] + HashAggregate [sum,isEmpty] [sum(CheckOverflow((promote_precision(l_extendedprice) * promote_precision(l_discount)), DecimalType(21,0))),revenue,sum,isEmpty] InputAdapter Exchange #1 WholeStageCodegen (1) diff --git a/sql/core/src/test/resources/tpch-plan-stability/q7/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q7/explain.txt index 7b20174aa50ce..9994d01a28e5c 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q7/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q7/explain.txt @@ -167,7 +167,7 @@ Right keys [1]: [n_nationkey#18] Join condition: (((n_name#16 = FRANCE) AND (n_name#19 = GERMANY)) OR ((n_name#16 = GERMANY) AND (n_name#19 = FRANCE))) (30) Project [codegen id : 6] -Output [4]: [n_name#16 AS supp_nation#20, n_name#19 AS cust_nation#21, year(l_shipdate#7) AS l_year#22, 
CheckOverflow((promote_precision(cast(l_extendedprice#5 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#6 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true) AS volume#23] +Output [4]: [n_name#16 AS supp_nation#20, n_name#19 AS cust_nation#21, year(l_shipdate#7) AS l_year#22, CheckOverflow((promote_precision(cast(l_extendedprice#5 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#6 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)) AS volume#23] Input [7]: [l_extendedprice#5, l_discount#6, l_shipdate#7, c_nationkey#13, n_name#16, n_nationkey#18, n_name#19] (31) HashAggregate [codegen id : 6] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q8/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q8/explain.txt index eb8ea81ef33a1..4eb4f811035d8 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q8/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q8/explain.txt @@ -261,7 +261,7 @@ Right keys [1]: [r_regionkey#25] Join condition: None (47) Project [codegen id : 8] -Output [3]: [year(o_orderdate#14) AS o_year#28, CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true) AS volume#29, n_name#23 AS nation#30] +Output [3]: [year(o_orderdate#14) AS o_year#28, CheckOverflow((promote_precision(cast(l_extendedprice#6 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#7 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)) AS volume#29, n_name#23 AS nation#30] Input [6]: [l_extendedprice#6, l_discount#7, o_orderdate#14, n_regionkey#20, n_name#23, r_regionkey#25] (48) HashAggregate [codegen id : 8] @@ -280,7 +280,7 @@ Input [5]: [o_year#28, sum#35, isEmpty#36, sum#37, isEmpty#38] Keys [1]: [o_year#28] Functions [2]: [sum(CASE WHEN (nation#30 = BRAZIL) THEN volume#29 ELSE 0 END), sum(volume#29)] Aggregate Attributes [2]: [sum(CASE WHEN (nation#30 = BRAZIL) THEN volume#29 ELSE 0 END)#40, sum(volume#29)#41] -Results [2]: [o_year#28, CheckOverflow((promote_precision(sum(CASE WHEN (nation#30 = BRAZIL) THEN volume#29 ELSE 0 END)#40) / promote_precision(sum(volume#29)#41)), DecimalType(38,6), true) AS mkt_share#42] +Results [2]: [o_year#28, CheckOverflow((promote_precision(sum(CASE WHEN (nation#30 = BRAZIL) THEN volume#29 ELSE 0 END)#40) / promote_precision(sum(volume#29)#41)), DecimalType(38,6)) AS mkt_share#42] (51) Exchange Input [2]: [o_year#28, mkt_share#42] diff --git a/sql/core/src/test/resources/tpch-plan-stability/q9/explain.txt b/sql/core/src/test/resources/tpch-plan-stability/q9/explain.txt index 511c6b80f8cf0..9ed3700e668e0 100644 --- a/sql/core/src/test/resources/tpch-plan-stability/q9/explain.txt +++ b/sql/core/src/test/resources/tpch-plan-stability/q9/explain.txt @@ -190,7 +190,7 @@ Right keys [1]: [n_nationkey#20] Join condition: None (34) Project [codegen id : 6] -Output [3]: [n_name#21 AS nation#23, year(o_orderdate#18) AS o_year#24, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(l_extendedprice#7 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#8 as decimal(11,0)))), DecimalType(11,0), true))), DecimalType(22,0), true) as decimal(23,0))) - promote_precision(cast(CheckOverflow((promote_precision(ps_supplycost#15) * 
promote_precision(l_quantity#6)), DecimalType(21,0), true) as decimal(23,0)))), DecimalType(23,0), true) AS amount#25] +Output [3]: [n_name#21 AS nation#23, year(o_orderdate#18) AS o_year#24, CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(l_extendedprice#7 as decimal(11,0))) * promote_precision(CheckOverflow((1 - promote_precision(cast(l_discount#8 as decimal(11,0)))), DecimalType(11,0)))), DecimalType(22,0)) as decimal(23,0))) - promote_precision(cast(CheckOverflow((promote_precision(ps_supplycost#15) * promote_precision(l_quantity#6)), DecimalType(21,0)) as decimal(23,0)))), DecimalType(23,0)) AS amount#25] Input [8]: [l_quantity#6, l_extendedprice#7, l_discount#8, s_nationkey#11, ps_supplycost#15, o_orderdate#18, n_nationkey#20, n_name#21] (35) HashAggregate [codegen id : 6] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index 7ee533ac26d2b..dd30ff68da417 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -42,7 +42,7 @@ abstract class CTEInlineSuiteBase """.stripMargin) checkAnswer(df, Nil) assert( - df.queryExecution.optimizedPlan.find(_.isInstanceOf[WithCTE]).nonEmpty, + df.queryExecution.optimizedPlan.exists(_.isInstanceOf[WithCTE]), "Non-deterministic With-CTE with multiple references should be not inlined.") } } @@ -59,7 +59,7 @@ abstract class CTEInlineSuiteBase """.stripMargin) checkAnswer(df, Nil) assert( - df.queryExecution.optimizedPlan.find(_.isInstanceOf[WithCTE]).nonEmpty, + df.queryExecution.optimizedPlan.exists(_.isInstanceOf[WithCTE]), "Non-deterministic With-CTE with multiple references should be not inlined.") } } @@ -76,10 +76,10 @@ abstract class CTEInlineSuiteBase """.stripMargin) checkAnswer(df, Row(0, 1) :: Row(1, 2) :: Nil) assert( - df.queryExecution.analyzed.find(_.isInstanceOf[WithCTE]).nonEmpty, + df.queryExecution.analyzed.exists(_.isInstanceOf[WithCTE]), "With-CTE should not be inlined in analyzed plan.") assert( - df.queryExecution.optimizedPlan.find(_.isInstanceOf[WithCTE]).isEmpty, + !df.queryExecution.optimizedPlan.exists(_.isInstanceOf[WithCTE]), "With-CTE with one reference should be inlined in optimized plan.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 10eacdb08c424..6ade7a7c99e37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -332,8 +332,8 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t(c STRUCT) USING $format") sql("INSERT INTO t SELECT struct(null)") checkAnswer(spark.table("t"), Row(Row(null))) - val e = intercept[SparkException](sql("INSERT INTO t SELECT struct('123456')")) - assert(e.getCause.getMessage.contains(s"Exceeds char/varchar type length limitation: 5")) + val e = intercept[RuntimeException](sql("INSERT INTO t SELECT struct('123456')")) + assert(e.getMessage.contains(s"Exceeds char/varchar type length limitation: 5")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index fe56bcb99117e..995bf5d903ad4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -281,9 +281,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { testData.select(isnan($"a"), isnan($"b")), Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) - checkAnswer( - sql("select isnan(15), isnan('invalid')"), - Row(false, false)) + if (!conf.ansiEnabled) { + checkAnswer( + sql("select isnan(15), isnan('invalid')"), + Row(false, false)) + } } test("nanvl") { @@ -934,14 +936,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("SPARK-37646: lit") { assert(lit($"foo") == $"foo") - assert(lit('foo) == $"foo") + assert(lit(Symbol("foo")) == $"foo") assert(lit(1) == Column(Literal(1))) assert(lit(null) == Column(Literal(null, NullType))) } test("typedLit") { assert(typedLit($"foo") == $"foo") - assert(typedLit('foo) == $"foo") + assert(typedLit(Symbol("foo")) == $"foo") assert(typedLit(1) == Column(Literal(1))) assert(typedLit[String](null) == Column(Literal(null, StringType))) @@ -1029,17 +1031,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should throw an exception if any intermediate structs don't exist") { intercept[AnalysisException] { - structLevel2.withColumn("a", 'a.withField("x.b", lit(2))) + structLevel2.withColumn("a", Symbol("a").withField("x.b", lit(2))) }.getMessage should include("No such struct field x in a") intercept[AnalysisException] { - structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2))) + structLevel3.withColumn("a", Symbol("a").withField("a.x.b", lit(2))) }.getMessage should include("No such struct field x in a") } test("withField should throw an exception if intermediate field is not a struct") { intercept[AnalysisException] { - structLevel1.withColumn("a", 'a.withField("b.a", lit(2))) + structLevel1.withColumn("a", Symbol("a").withField("b.a", lit(2))) }.getMessage should include("struct argument should be struct type, got: int") } @@ -1053,7 +1055,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("a", structType, nullable = false))), nullable = false)))) - structLevel2.withColumn("a", 'a.withField("a.b", lit(2))) + structLevel2.withColumn("a", Symbol("a").withField("a.b", lit(2))) }.getMessage should include("Ambiguous reference to fields") } @@ -1072,7 +1074,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field to struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("d", lit(4))), + structLevel1.withColumn("a", Symbol("a").withField("d", lit(4))), Row(Row(1, null, 3, 4)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1113,7 +1115,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add null field to struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))), + structLevel1.withColumn("a", Symbol("a").withField("d", lit(null).cast(IntegerType))), Row(Row(1, null, 3, null)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1126,7 +1128,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add multiple fields to struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), + structLevel1.withColumn("a", Symbol("a").withField("d", lit(4)).withField("e", lit(5))), Row(Row(1, null, 3, 4, 5)) :: Nil, StructType(Seq( StructField("a", 
StructType(Seq( @@ -1140,7 +1142,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add multiple fields to nullable struct") { checkAnswer( - nullableStructLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), + nullableStructLevel1.withColumn("a", Symbol("a") + .withField("d", lit(4)).withField("e", lit(5))), Row(null) :: Row(Row(1, null, 3, 4, 5)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1154,8 +1157,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field to nested struct") { Seq( - structLevel2.withColumn("a", 'a.withField("a.d", lit(4))), - structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4)))) + structLevel2.withColumn("a", Symbol("a").withField("a.d", lit(4))), + structLevel2.withColumn("a", Symbol("a").withField("a", $"a.a".withField("d", lit(4)))) ).foreach { df => checkAnswer( df, @@ -1216,7 +1219,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field to deeply nested struct") { checkAnswer( - structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))), + structLevel3.withColumn("a", Symbol("a").withField("a.a.d", lit(4))), Row(Row(Row(Row(1, null, 3, 4)))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1233,7 +1236,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace field in struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("b", lit(2))), + structLevel1.withColumn("a", Symbol("a").withField("b", lit(2))), Row(Row(1, 2, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1245,7 +1248,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace field in nullable struct") { checkAnswer( - nullableStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))), + nullableStructLevel1.withColumn("a", Symbol("a").withField("b", lit("foo"))), Row(null) :: Row(Row(1, "foo", 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1271,7 +1274,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace field with null value in struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))), + structLevel1.withColumn("a", Symbol("a").withField("c", lit(null).cast(IntegerType))), Row(Row(1, null, null)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1283,7 +1286,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace multiple fields in struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))), + structLevel1.withColumn("a", Symbol("a").withField("a", lit(10)).withField("b", lit(20))), Row(Row(10, 20, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1295,7 +1298,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace multiple fields in nullable struct") { checkAnswer( - nullableStructLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))), + nullableStructLevel1.withColumn("a", Symbol("a").withField("a", lit(10)) + .withField("b", lit(20))), Row(null) :: Row(Row(10, 20, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1308,7 +1312,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField 
should replace field in nested struct") { Seq( structLevel2.withColumn("a", $"a".withField("a.b", lit(2))), - structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2)))) + structLevel2.withColumn("a", Symbol("a").withField("a", $"a.a".withField("b", lit(2)))) ).foreach { df => checkAnswer( df, @@ -1389,7 +1393,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - structLevel1.withColumn("a", 'a.withField("b", lit(100))), + structLevel1.withColumn("a", Symbol("a").withField("b", lit(100))), Row(Row(1, 100, 100)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1401,7 +1405,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace fields in struct in given order") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))), + structLevel1.withColumn("a", Symbol("a").withField("b", lit(2)).withField("b", lit(20))), Row(Row(1, 20, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1413,7 +1417,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field and then replace same field in struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))), + structLevel1.withColumn("a", Symbol("a").withField("d", lit(4)).withField("d", lit(5))), Row(Row(1, null, 3, 5)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1437,7 +1441,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))), + df.withColumn("a", Symbol("a").withField("`a.b`.`e.f`", lit(2))), Row(Row(Row(1, 2, 3))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1449,7 +1453,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) intercept[AnalysisException] { - df.withColumn("a", 'a.withField("a.b.e.f", lit(2))) + df.withColumn("a", Symbol("a").withField("a.b.e.f", lit(2))) }.getMessage should include("No such struct field a in a.b") } @@ -1464,7 +1468,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), + mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("A", lit(2))), Row(Row(2, 1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1473,7 +1477,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), + mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("b", lit(2))), Row(Row(1, 2)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1486,7 +1490,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field to struct because casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), + mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("A", lit(2))), Row(Row(1, 1, 2)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1496,7 +1500,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) 
checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), + mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("b", lit(2))), Row(Row(1, 1, 2)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1524,7 +1528,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace nested field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { checkAnswer( - mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))), + mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("A.a", lit(2))), Row(Row(Row(2, 1), Row(1, 1))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1539,7 +1543,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))), + mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("b.a", lit(2))), Row(Row(Row(1, 1), Row(2, 1))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1558,11 +1562,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should throw an exception because casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { intercept[AnalysisException] { - mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))) + mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("A.a", lit(2))) }.getMessage should include("No such struct field A in a, B") intercept[AnalysisException] { - mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))) + mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("b.a", lit(2))) }.getMessage should include("No such struct field b in a, B") } } @@ -1769,17 +1773,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should throw an exception if any intermediate structs don't exist") { intercept[AnalysisException] { - structLevel2.withColumn("a", 'a.dropFields("x.b")) + structLevel2.withColumn("a", Symbol("a").dropFields("x.b")) }.getMessage should include("No such struct field x in a") intercept[AnalysisException] { - structLevel3.withColumn("a", 'a.dropFields("a.x.b")) + structLevel3.withColumn("a", Symbol("a").dropFields("a.x.b")) }.getMessage should include("No such struct field x in a") } test("dropFields should throw an exception if intermediate field is not a struct") { intercept[AnalysisException] { - structLevel1.withColumn("a", 'a.dropFields("b.a")) + structLevel1.withColumn("a", Symbol("a").dropFields("b.a")) }.getMessage should include("struct argument should be struct type, got: int") } @@ -1793,13 +1797,13 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("a", structType, nullable = false))), nullable = false)))) - structLevel2.withColumn("a", 'a.dropFields("a.b")) + structLevel2.withColumn("a", Symbol("a").dropFields("a.b")) }.getMessage should include("Ambiguous reference to fields") } test("dropFields should drop field in struct") { checkAnswer( - structLevel1.withColumn("a", 'a.dropFields("b")), + structLevel1.withColumn("a", Symbol("a").dropFields("b")), Row(Row(1, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1822,7 +1826,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop multiple fields in struct") { Seq( structLevel1.withColumn("a", $"a".dropFields("b", "c")), - structLevel1.withColumn("a", 
'a.dropFields("b").dropFields("c")) + structLevel1.withColumn("a", Symbol("a").dropFields("b").dropFields("c")) ).foreach { df => checkAnswer( df, @@ -1836,7 +1840,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should throw an exception if no fields will be left in struct") { intercept[AnalysisException] { - structLevel1.withColumn("a", 'a.dropFields("a", "b", "c")) + structLevel1.withColumn("a", Symbol("a").dropFields("a", "b", "c")) }.getMessage should include("cannot drop all fields in struct") } @@ -1860,7 +1864,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop field in nested struct") { checkAnswer( - structLevel2.withColumn("a", 'a.dropFields("a.b")), + structLevel2.withColumn("a", Symbol("a").dropFields("a.b")), Row(Row(Row(1, 3))) :: Nil, StructType( Seq(StructField("a", StructType(Seq( @@ -1873,7 +1877,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop multiple fields in nested struct") { checkAnswer( - structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")), + structLevel2.withColumn("a", Symbol("a").dropFields("a.b", "a.c")), Row(Row(Row(1))) :: Nil, StructType( Seq(StructField("a", StructType(Seq( @@ -1910,7 +1914,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop field in deeply nested struct") { checkAnswer( - structLevel3.withColumn("a", 'a.dropFields("a.a.b")), + structLevel3.withColumn("a", Symbol("a").dropFields("a.a.b")), Row(Row(Row(Row(1, 3)))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1934,7 +1938,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - structLevel1.withColumn("a", 'a.dropFields("b")), + structLevel1.withColumn("a", Symbol("a").dropFields("b")), Row(Row(1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1945,7 +1949,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), + mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("A")), Row(Row(1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1953,7 +1957,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), + mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("b")), Row(Row(1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1965,7 +1969,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should not drop field in struct because casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), + mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("A")), Row(Row(1, 1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1974,7 +1978,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), + mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("b")), Row(Row(1, 1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1987,7 +1991,7 @@ class 
ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop nested field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { checkAnswer( - mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")), + mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("A.a")), Row(Row(Row(1), Row(1, 1))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -2001,7 +2005,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")), + mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("b.a")), Row(Row(Row(1, 1), Row(1))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -2019,18 +2023,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should throw an exception because casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { intercept[AnalysisException] { - mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")) + mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("A.a")) }.getMessage should include("No such struct field A in a, B") intercept[AnalysisException] { - mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")) + mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("b.a")) }.getMessage should include("No such struct field b in a, B") } } test("dropFields should drop only fields that exist") { checkAnswer( - structLevel1.withColumn("a", 'a.dropFields("d")), + structLevel1.withColumn("a", Symbol("a").dropFields("d")), Row(Row(1, null, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -2040,7 +2044,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - structLevel1.withColumn("a", 'a.dropFields("b", "d")), + structLevel1.withColumn("a", Symbol("a").dropFields("b", "d")), Row(Row(1, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 2808652f2998d..461bbd8987cef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -82,16 +82,16 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { test("schema_of_csv - infers schemas") { checkAnswer( spark.range(1).select(schema_of_csv(lit("0.1,1"))), - Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>"))) + Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>"))) checkAnswer( spark.range(1).select(schema_of_csv("0.1,1")), - Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>"))) + Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>"))) } test("schema_of_csv - infers schemas using options") { val df = spark.range(1) .select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava)) - checkAnswer(df, Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>"))) + checkAnswer(df, Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>"))) } test("to_csv - struct") { @@ -220,7 +220,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { val input = concat_ws(",", lit(0.1), lit(1)) checkAnswer( spark.range(1).select(schema_of_csv(input)), - Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>"))) + Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>"))) } test("optional datetime parser does not affect csv time formatting") { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index c3076c5880ae9..157736f9777e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1025,10 +1025,15 @@ class DataFrameAggregateSuite extends QueryTest sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(NULL) > 0"), Nil) - val error = intercept[AnalysisException] { - sql("SELECT COUNT_IF(x) FROM tempView") + // When ANSI mode is on, it will implicit cast the string as boolean and throw a runtime + // error. Here we simply test with ANSI mode off. + if (!conf.ansiEnabled) { + val error = intercept[AnalysisException] { + sql("SELECT COUNT_IF(x) FROM tempView") + } + assert(error.message.contains("cannot resolve 'count_if(tempview.x)' due to data type " + + "mismatch: argument 1 requires boolean type, however, 'tempview.x' is of string type")) } - assert(error.message.contains("function count_if requires boolean type")) } } @@ -1135,9 +1140,11 @@ class DataFrameAggregateSuite extends QueryTest val mapDF = Seq(Tuple1(Map("a" -> "a"))).toDF("col") checkAnswer(mapDF.groupBy(struct($"col.a")).count().select("count"), Row(1)) - val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col") - // Spark implicit casts string literal "a" to int to match the key type. - checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"), Row(1)) + if (!conf.ansiEnabled) { + val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col") + // Spark implicit casts string literal "a" to int to match the key type. + checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"), Row(1)) + } val arrayDF = Seq(Tuple1(Seq(1))).toDF("col") val e = intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count()) @@ -1443,6 +1450,16 @@ class DataFrameAggregateSuite extends QueryTest val res = df.select($"d".cast("decimal(12, 2)").as("d")).agg(avg($"d").cast("string")) checkAnswer(res, Row("9999999999.990000")) } + + test("SPARK-38185: Fix data incorrect if aggregate function is empty") { + val emptyAgg = Map.empty[String, String] + assert(spark.range(2).where("id > 2").agg(emptyAgg).limit(1).count == 1) + } + + test("SPARK-38221: group by stream of complex expressions should not fail") { + val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id") + checkAnswer(df, Row(2, 3, 1)) + } } case class B(c: Option[Double]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 1ddb238d1db2f..4d82d110a4c51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -481,6 +481,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { spark.sql("drop temporary function fStringLength") } + test("SPARK-38130: array_sort with lambda of non-orderable items") { + val df6 = Seq((Array[Map[String, Int]](Map("a" -> 1), Map("b" -> 2, "c" -> 3), + Map()), "x")).toDF("a", "b") + checkAnswer( + df6.selectExpr("array_sort(a, (x, y) -> cardinality(x) - cardinality(y))"), + Seq( + Row(Seq[Map[String, Int]](Map(), Map("a" -> 1), Map("b" -> 2, "c" -> 3)))) + ) + } + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), @@ -569,8 
+579,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } test("array size function - legacy") { - withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { - testSizeOfArray(sizeOfNull = -1) + if (!conf.ansiEnabled) { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfArray(sizeOfNull = -1) + } } } @@ -722,8 +734,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } test("map size function - legacy") { - withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { - testSizeOfMap(sizeOfNull = -1: Int) + if (!conf.ansiEnabled) { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfMap(sizeOfNull = -1: Int) + } } } @@ -1017,15 +1031,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(false)) ) - val e1 = intercept[AnalysisException] { - OneRowRelation().selectExpr("array_contains(array(1), .01234567890123456790123456780)") - } - val errorMsg1 = - s""" - |Input to function array_contains should have been array followed by a - |value with same element type, but it's [array, decimal(38,29)]. + if (!conf.ansiEnabled) { + val e1 = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_contains(array(1), .01234567890123456790123456780)") + } + val errorMsg1 = + s""" + |Input to function array_contains should have been array followed by a + |value with same element type, but it's [array, decimal(38,29)]. """.stripMargin.replace("\n", " ").trim() - assert(e1.message.contains(errorMsg1)) + assert(e1.message.contains(errorMsg1)) + } val e2 = intercept[AnalysisException] { OneRowRelation().selectExpr("array_contains(array(1), 'foo')") @@ -1454,41 +1470,43 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(null), Row(null), Row(null)) ) } - checkAnswer( - df.select(element_at(df("a"), 4)), - Seq(Row(null), Row(null), Row(null)) - ) - checkAnswer( - df.select(element_at(df("a"), df("b"))), - Seq(Row("1"), Row(""), Row(null)) - ) - checkAnswer( - df.selectExpr("element_at(a, b)"), - Seq(Row("1"), Row(""), Row(null)) - ) + if (!conf.ansiEnabled) { + checkAnswer( + df.select(element_at(df("a"), 4)), + Seq(Row(null), Row(null), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), df("b"))), + Seq(Row("1"), Row(""), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, b)"), + Seq(Row("1"), Row(""), Row(null)) + ) - checkAnswer( - df.select(element_at(df("a"), 1)), - Seq(Row("1"), Row(null), Row(null)) - ) - checkAnswer( - df.select(element_at(df("a"), -1)), - Seq(Row("3"), Row(""), Row(null)) - ) + checkAnswer( + df.select(element_at(df("a"), 1)), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), -1)), + Seq(Row("3"), Row(""), Row(null)) + ) - checkAnswer( - df.selectExpr("element_at(a, 4)"), - Seq(Row(null), Row(null), Row(null)) - ) + checkAnswer( + df.selectExpr("element_at(a, 4)"), + Seq(Row(null), Row(null), Row(null)) + ) - checkAnswer( - df.selectExpr("element_at(a, 1)"), - Seq(Row("1"), Row(null), Row(null)) - ) - checkAnswer( - df.selectExpr("element_at(a, -1)"), - Seq(Row("3"), Row(""), Row(null)) - ) + checkAnswer( + df.selectExpr("element_at(a, 1)"), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, -1)"), + Seq(Row("3"), Row(""), Row(null)) + ) + } val e1 = intercept[AnalysisException] { Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") @@ -1550,10 +1568,12 @@ class DataFrameFunctionsSuite extends QueryTest with 
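The element_at checks above are now skipped under ANSI mode; here is a small sketch of why, assuming a SparkSession `spark`. The exact runtime error under ANSI mode is not asserted here; the point is that the non-ANSI NULL result is what the guarded assertions rely on.

import org.apache.spark.sql.functions.element_at
import spark.implicits._

val df = Seq(Seq("1", "2", "3")).toDF("a")

spark.conf.set("spark.sql.ansi.enabled", "false")
df.select(element_at($"a", 4)).show()   // out-of-range index yields NULL

spark.conf.set("spark.sql.ansi.enabled", "true")
// Under ANSI mode the same query is expected to fail at runtime instead of returning NULL,
// which is why the suite only runs these assertions when conf.ansiEnabled is false.
// df.select(element_at($"a", 4)).show()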
SharedSparkSession { Seq(Row("a")) ) - checkAnswer( - OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"), - Seq(Row(null)) - ) + if (!conf.ansiEnabled) { + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"), + Seq(Row(null)) + ) + } val e3 = intercept[AnalysisException] { OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), '1')") @@ -1628,10 +1648,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { // Simple test cases def simpleTest(): Unit = { - checkAnswer ( - df.select(concat($"i1", $"s1")), - Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) - ) + if (!conf.ansiEnabled) { + checkAnswer( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + } checkAnswer( df.select(concat($"i1", $"i2", $"i3")), Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 20ae995af628b..8dbc57c0429c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -444,21 +444,25 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { } test("replace float with nan") { - checkAnswer( - createNaNDF().na.replace("*", Map( - 1.0f -> Float.NaN - )), - Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: - Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) + if (!conf.ansiEnabled) { + checkAnswer( + createNaNDF().na.replace("*", Map( + 1.0f -> Float.NaN + )), + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) + } } test("replace double with nan") { - checkAnswer( - createNaNDF().na.replace("*", Map( - 1.0 -> Double.NaN - )), - Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: - Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) + if (!conf.ansiEnabled) { + checkAnswer( + createNaNDF().na.replace("*", Map( + 1.0 -> Double.NaN + )), + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) + } } test("SPARK-34417: test fillMap() for column with a dot in the name") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 32cbb8b457d86..1a0c95beb18b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.time.LocalDateTime import java.util.Locale import org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst @@ -323,17 +324,6 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession { checkAnswer(df, expected) } - test("pivoting column list") { - val exception = intercept[RuntimeException] { - trainingSales - .groupBy($"sales.year") - .pivot(struct(lower($"sales.course"), $"training")) - .agg(sum($"sales.earnings")) - .collect() - } - assert(exception.getMessage.contains("Unsupported literal type")) - } - test("SPARK-26403: pivoting by array column") { val df = Seq( (2, Seq.empty[String]), @@ -352,4 +342,16 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession { percentile_approx(col("value"), 
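A sketch of why the simpleTest() concat of an INT array with a STRING array above is now ANSI-gated (SparkSession `spark` assumed; the data is illustrative): the query depends on an implicit array<int> to array<string> coercion that ANSI mode is expected to reject.

import org.apache.spark.sql.functions.concat
import spark.implicits._

val df = Seq(
  (Seq(1), Seq("a", "b", "c")),
  (Seq(1, 0), Seq("a"))
).toDF("i1", "s1")

spark.conf.set("spark.sql.ansi.enabled", "false")
// Elements are widened to STRING, giving ["1","a","b","c"] and ["1","0","a"].
df.select(concat($"i1", $"s1")).show(truncate = false)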
array(lit(0.5)), lit(10000))) checkAnswer(actual, Row(Array(2.5), Array(3.0))) } + + test("SPARK-38133: Grouping by TIMESTAMP_NTZ should not corrupt results") { + checkAnswer( + courseSales.withColumn("ts", $"year".cast("string").cast("timestamp_ntz")) + .groupBy("ts") + .pivot("course", Seq("dotNET", "Java")) + .agg(sum($"earnings")) + .select("ts", "dotNET", "Java"), + Row(LocalDateTime.of(2012, 1, 1, 0, 0, 0, 0), 15000.0, 20000.0) :: + Row(LocalDateTime.of(2013, 1, 1, 0, 0, 0, 0), 48000.0, 30000.0) :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index fc549e307c80f..917f80e58108e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -63,13 +63,15 @@ class DataFrameRangeSuite extends QueryTest with SharedSparkSession with Eventua val res7 = spark.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") - assert(res8.count == 3) - assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - - val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") - assert(res9.count == 2) - assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) + if (!conf.ansiEnabled) { + val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + assert(res8.count == 3) + assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) + + val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + assert(res9.count == 2) + assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) + } // only end provided as argument val res10 = spark.range(10).select("id") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index a0ddabcf76043..4d0dd46b9569c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.functions.{count, explode, sum} +import org.apache.spark.sql.functions.{count, explode, sum, year} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.test.SQLTestData.TestData @@ -467,4 +467,21 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assertAmbiguousSelfJoin(df21.join(df22, df21("x") === df22("y"))) assertAmbiguousSelfJoin(df22.join(df21, df21("x") === df22("y"))) } + + test("SPARK-35937: GetDateFieldOperations should skip unresolved nodes") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val df = Seq("1644821603").map(i => (i.toInt, i)).toDF("tsInt", "tsStr") + val df1 = df.select(df("tsStr").cast("timestamp")).as("df1") + val df2 = df.select(df("tsStr").cast("timestamp")).as("df2") + df1.join(df2, $"df1.tsStr" === $"df2.tsStr", "left_outer") 
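A sketch of the TIMESTAMP_NTZ grouping that the SPARK-38133 pivot test above covers, assuming a SparkSession `spark` and a Spark build that includes the TIMESTAMP_NTZ type; the inline data stands in for the suite's courseSales fixture.

import org.apache.spark.sql.functions.sum
import spark.implicits._

val courseSales = Seq(
  (2012, "dotNET", 15000.0), (2012, "Java", 20000.0),
  (2013, "dotNET", 48000.0), (2013, "Java", 30000.0)
).toDF("year", "course", "earnings")

// Grouping by a TIMESTAMP_NTZ column while pivoting should keep the timestamps intact.
courseSales
  .withColumn("ts", $"year".cast("string").cast("timestamp_ntz"))
  .groupBy("ts")
  .pivot("course", Seq("dotNET", "Java"))
  .agg(sum($"earnings"))
  .orderBy("ts")
  .show(truncate = false)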
+ val df3 = df1.join(df2, $"df1.tsStr" === $"df2.tsStr", "left_outer") + .select($"df1.tsStr".as("timeStr")).as("df3") + // Before the fix, it throws "UnresolvedException: Invalid call to + // dataType on unresolved object". + val ex = intercept[AnalysisException]( + df3.join(df1, year($"df1.timeStr") === year($"df3.tsStr")) + ) + assert(ex.message.contains("Column 'df1.timeStr' does not exist.")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index b3d212716dd9a..a5414f3e805fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -21,6 +21,7 @@ import java.time.LocalDateTime import org.scalatest.BeforeAndAfterEach +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand} import org.apache.spark.sql.functions._ @@ -82,7 +83,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession // key "b" => (19:39:27 ~ 19:39:37) checkAnswer( - df.groupBy(session_window($"time", "10 seconds"), 'id) + df.groupBy(session_window($"time", "10 seconds"), Symbol("id")) .agg(count("*").as("counts"), sum("value").as("sum")) .orderBy($"session_window.start".asc) .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", @@ -112,7 +113,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession // key "b" => (19:39:27 ~ 19:39:37) checkAnswer( - df.groupBy(session_window($"time", "10 seconds"), 'id) + df.groupBy(session_window($"time", "10 seconds"), Symbol("id")) .agg(count("*").as("counts"), sum_distinct(col("value")).as("sum")) .orderBy($"session_window.start".asc) .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", @@ -141,7 +142,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession // key "b" => (19:39:27 ~ 19:39:37) checkAnswer( - df.groupBy(session_window($"time", "10 seconds"), 'id) + df.groupBy(session_window($"time", "10 seconds"), Symbol("id")) .agg(sum_distinct(col("value")).as("sum"), sum_distinct(col("value2")).as("sum2")) .orderBy($"session_window.start".asc) .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", @@ -170,7 +171,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession // b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55) checkAnswer( - df.groupBy(session_window($"time", "10 seconds"), 'id) + df.groupBy(session_window($"time", "10 seconds"), Symbol("id")) .agg(count("*").as("counts"), sum("value").as("sum")) .orderBy($"session_window.start".asc) .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", @@ -381,6 +382,34 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession } } + test("SPARK-36465: filter out events with invalid gap duration.") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(session_window($"time", "x sec")) + .agg(count("*").as("counts")) + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"counts"), + Seq() + ) + + withTempTable { table => + checkAnswer( + 
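A compact session_window sketch matching the grouping pattern used throughout the suite above (SparkSession `spark` assumed; data is illustrative). Per SPARK-36465, rows whose gap duration evaluates to an invalid or non-positive interval are filtered out of the aggregation rather than failing it.

import org.apache.spark.sql.functions.{count, session_window, sum}
import spark.implicits._

val events = Seq(
  ("2016-03-27 19:39:30", 1, "a"),
  ("2016-03-27 19:39:34", 1, "a"),
  ("2016-03-27 19:39:56", 2, "a")
).toDF("time", "value", "id")

events
  .groupBy(session_window($"time".cast("timestamp"), "10 seconds"), $"id")
  .agg(count("*").as("counts"), sum("value").as("sum"))
  .orderBy($"session_window.start".asc)
  .show(truncate = false)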
spark.sql("select session_window(time, " + + """case when value = 1 then "2 seconds" when value = 2 then "invalid gap duration" """ + + s"""else "20 seconds" end), value from $table""") + .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType), + $"value"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:47", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:36", 1) + ) + ) + } + } + test("SPARK-36724: Support timestamp_ntz as a type of time column for SessionWindow") { val df = Seq((LocalDateTime.parse("2016-03-27T19:39:30"), 1, "a"), (LocalDateTime.parse("2016-03-27T19:39:25"), 2, "a")).toDF("time", "value", "id") @@ -406,4 +435,64 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession checkAnswer(aggDF, Seq(Row("2016-03-27 19:39:25", "2016-03-27 19:39:40", 2))) } + + test("SPARK-38227: 'start' and 'end' fields should be nullable") { + // We expect the fields in window struct as nullable since the dataType of SessionWindow + // defines them as nullable. The rule 'SessionWindowing' should respect the dataType. + val df1 = Seq( + ("hello", "2016-03-27 09:00:05", 1), + ("structured", "2016-03-27 09:00:32", 2)).toDF("id", "time", "value") + val df2 = Seq( + ("world", LocalDateTime.parse("2016-03-27T09:00:05"), 1), + ("spark", LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("id", "time", "value") + + val udf = spark.udf.register("gapDuration", (s: String) => { + if (s == "hello") { + "1 second" + } else if (s == "structured") { + // zero gap duration will be filtered out from aggregation + "0 second" + } else if (s == "world") { + // negative gap duration will be filtered out from aggregation + "-10 seconds" + } else { + "10 seconds" + } + }) + + def validateWindowColumnInSchema(schema: StructType, colName: String): Unit = { + schema.find(_.name == colName) match { + case Some(StructField(_, st: StructType, _, _)) => + assertFieldInWindowStruct(st, "start") + assertFieldInWindowStruct(st, "end") + + case _ => fail("Failed to find suitable window column from DataFrame!") + } + } + + def assertFieldInWindowStruct(windowType: StructType, fieldName: String): Unit = { + val field = windowType.fields.find(_.name == fieldName) + assert(field.isDefined, s"'$fieldName' field should exist in window struct") + assert(field.get.nullable, s"'$fieldName' field should be nullable") + } + + for { + df <- Seq(df1, df2) + nullable <- Seq(true, false) + } { + val dfWithDesiredNullability = new DataFrame(df.queryExecution, RowEncoder( + StructType(df.schema.fields.map(_.copy(nullable = nullable))))) + // session window without dynamic gap + val windowedProject = dfWithDesiredNullability + .select(session_window($"time", "10 seconds").as("session"), $"value") + val schema = windowedProject.queryExecution.optimizedPlan.schema + validateWindowColumnInSchema(schema, "session") + + // session window with dynamic gap + val windowedProject2 = dfWithDesiredNullability + .select(session_window($"time", udf($"id")).as("session"), $"value") + val schema2 = windowedProject2.queryExecution.optimizedPlan.schema + validateWindowColumnInSchema(schema2, "session") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index b19e4300b5af4..ca04adf642e15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ 
-341,7 +341,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { ).toDF("date", "timestamp", "decimal") val widenTypedRows = Seq( - (new Timestamp(2), 10.5D, "string") + (new Timestamp(2), 10.5D, "2021-01-01 00:00:00") ).toDF("date", "timestamp", "decimal") dates.union(widenTypedRows).collect() @@ -538,24 +538,25 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } test("union by name - type coercion") { - var df1 = Seq((1, "a")).toDF("c0", "c1") - var df2 = Seq((3, 1L)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) - - df1 = Seq((1, 1.0)).toDF("c0", "c1") - df2 = Seq((8L, 3.0)).toDF("c1", "c0") + var df1 = Seq((1, 1.0)).toDF("c0", "c1") + var df2 = Seq((8L, 3.0)).toDF("c1", "c0") checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) - - df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") - df2 = Seq(("a", 4.0)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) - - df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") - df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") - val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") - checkAnswer(df1.unionByName(df2.unionByName(df3)), - Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil - ) + if (!conf.ansiEnabled) { + df1 = Seq((1, "a")).toDF("c0", "c1") + df2 = Seq((3, 1L)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) + + df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") + df2 = Seq(("a", 4.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) + + df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") + df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") + val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") + checkAnswer(df1.unionByName(df2.unionByName(df3)), + Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil + ) + } } test("union by name - check case sensitivity") { @@ -804,7 +805,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { StructType(Seq(StructField("topLevelCol", nestedStructType2)))) val union = df1.unionByName(df2, allowMissingColumns = true) - assert(union.schema.toDDL == "`topLevelCol` STRUCT<`b`: STRING, `a`: STRING>") + assert(union.schema.toDDL == "topLevelCol STRUCT") checkAnswer(union, Row(Row("b", null)) :: Row(Row("b", "a")) :: Nil) } @@ -836,15 +837,15 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { StructType(Seq(StructField("topLevelCol", nestedStructType2)))) var unionDf = df1.unionByName(df2, true) - assert(unionDf.schema.toDDL == "`topLevelCol` " + - "STRUCT<`b`: STRUCT<`ba`: STRING, `bb`: STRING>, `a`: STRUCT<`aa`: STRING>>") + assert(unionDf.schema.toDDL == "topLevelCol " + + "STRUCT, a: STRUCT>") checkAnswer(unionDf, Row(Row(Row("ba", null), null)) :: Row(Row(Row(null, "bb"), Row("aa"))) :: Nil) unionDf = df2.unionByName(df1, true) - assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRUCT<`aa`: STRING>, " + - "`b`: STRUCT<`bb`: STRING, `ba`: STRING>>") + assert(unionDf.schema.toDDL == "topLevelCol STRUCT, " + + "b: STRUCT>") checkAnswer(unionDf, Row(Row(null, Row(null, "ba"))) :: Row(Row(Row("aa"), Row("bb", null))) :: Nil) @@ -1112,13 +1113,13 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { StructType(Seq(StructField("arr", arrayType2)))) var unionDf = df1.unionByName(df2) - assert(unionDf.schema.toDDL == "`arr` ARRAY>") + assert(unionDf.schema.toDDL == "arr 
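A sketch of the unionByName coercion split above (SparkSession `spark` assumed): purely numeric widening works in either mode, while the string/numeric combinations that moved inside the !conf.ansiEnabled guard rely on implicit string casts.

import spark.implicits._

val df1 = Seq((1, 1.0)).toDF("c0", "c1")
val df2 = Seq((8L, 3.0)).toDF("c1", "c0")

// Columns are matched by name, so c0 pairs INT with DOUBLE and c1 pairs DOUBLE with BIGINT;
// both widen to DOUBLE and the rows come back as (1.0, 1.0) and (3.0, 8.0).
df1.unionByName(df2).show()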
ARRAY>") checkAnswer(unionDf, Row(Seq(Row("ba", "bb"))) :: Row(Seq(Row("ba", "bb"))) :: Nil) unionDf = df2.unionByName(df1) - assert(unionDf.schema.toDDL == "`arr` ARRAY>") + assert(unionDf.schema.toDDL == "arr ARRAY>") checkAnswer(unionDf, Row(Seq(Row("bb", "ba"))) :: Row(Seq(Row("bb", "ba"))) :: Nil) @@ -1150,7 +1151,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } unionDf = df3.unionByName(df4, true) - assert(unionDf.schema.toDDL == "`arr` ARRAY>") + assert(unionDf.schema.toDDL == "arr ARRAY>") checkAnswer(unionDf, Row(Seq(Row("ba", null))) :: Row(Seq(Row(null, "bb"))) :: Nil) @@ -1160,7 +1161,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } unionDf = df4.unionByName(df3, true) - assert(unionDf.schema.toDDL == "`arr` ARRAY>") + assert(unionDf.schema.toDDL == "arr ARRAY>") checkAnswer(unionDf, Row(Seq(Row("bb", null))) :: Row(Seq(Row(null, "ba"))) :: Nil) @@ -1196,15 +1197,15 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { StructType(Seq(StructField("topLevelCol", nestedStructType2)))) var unionDf = df1.unionByName(df2) - assert(unionDf.schema.toDDL == "`topLevelCol` " + - "STRUCT<`b`: ARRAY>>") + assert(unionDf.schema.toDDL == "topLevelCol " + + "STRUCT>>") checkAnswer(unionDf, Row(Row(Seq(Row("ba", "bb")))) :: Row(Row(Seq(Row("ba", "bb")))) :: Nil) unionDf = df2.unionByName(df1) - assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" + - "`b`: ARRAY>>") + assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" + + "b: ARRAY>>") checkAnswer(unionDf, Row(Row(Seq(Row("bb", "ba")))) :: Row(Row(Seq(Row("bb", "ba")))) :: Nil) @@ -1240,8 +1241,8 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } unionDf = df3.unionByName(df4, true) - assert(unionDf.schema.toDDL == "`topLevelCol` " + - "STRUCT<`b`: ARRAY>>") + assert(unionDf.schema.toDDL == "topLevelCol " + + "STRUCT>>") checkAnswer(unionDf, Row(Row(Seq(Row("ba", null)))) :: Row(Row(Seq(Row(null, "bb")))) :: Nil) @@ -1251,8 +1252,8 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } unionDf = df4.unionByName(df3, true) - assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" + - "`b`: ARRAY>>") + assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" + + "b: ARRAY>>") checkAnswer(unionDf, Row(Row(Seq(Row("bb", null)))) :: Row(Row(Seq(Row(null, "ba")))) :: Nil) @@ -1292,15 +1293,15 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { StructType(Seq(StructField("topLevelCol", nestedStructType2)))) var unionDf = df1.unionByName(df2) - assert(unionDf.schema.toDDL == "`topLevelCol` " + - "STRUCT<`b`: ARRAY>>>") + assert(unionDf.schema.toDDL == "topLevelCol " + + "STRUCT>>>") checkAnswer(unionDf, Row(Row(Seq(Seq(Row("ba", "bb"))))) :: Row(Row(Seq(Seq(Row("ba", "bb"))))) :: Nil) unionDf = df2.unionByName(df1) - assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" + - "`b`: ARRAY>>>") + assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" + + "b: ARRAY>>>") checkAnswer(unionDf, Row(Row(Seq(Seq(Row("bb", "ba"))))) :: Row(Row(Seq(Seq(Row("bb", "ba"))))) :: Nil) @@ -1340,8 +1341,8 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } unionDf = df3.unionByName(df4, true) - assert(unionDf.schema.toDDL == "`topLevelCol` " + - "STRUCT<`b`: ARRAY>>>") + assert(unionDf.schema.toDDL == "topLevelCol " + + "STRUCT>>>") checkAnswer(unionDf, Row(Row(Seq(Seq(Row("ba", null))))) :: Row(Row(Seq(Seq(Row(null, "bb"))))) :: Nil) @@ -1351,8 +1352,8 @@ 
class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } unionDf = df4.unionByName(df3, true) - assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" + - "`b`: ARRAY>>>") + assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" + + "b: ARRAY>>>") checkAnswer(unionDf, Row(Row(Seq(Seq(Row("bb", null))))) :: Row(Row(Seq(Seq(Row(null, "ba"))))) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 7482d76207388..d4e482540161f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -86,7 +86,9 @@ class DataFrameSuite extends QueryTest test("access complex data") { assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1) - assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1) + if (!conf.ansiEnabled) { + assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1) + } assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1) } @@ -631,7 +633,19 @@ class DataFrameSuite extends QueryTest assert(df.schema.map(_.name) === Seq("key", "value", "newCol")) } - test("withColumns") { + test("withColumns: public API, with Map input") { + val df = testData.toDF().withColumns(Map( + "newCol1" -> (col("key") + 1), "newCol2" -> (col("key") + 2) + )) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1, key + 2) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2")) + } + + test("withColumns: internal method") { val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"), Seq(col("key") + 1, col("key") + 2)) checkAnswer( @@ -655,7 +669,7 @@ class DataFrameSuite extends QueryTest assert(err2.getMessage.contains("Found duplicate column(s)")) } - test("withColumns: case sensitive") { + test("withColumns: internal method, case sensitive") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"), Seq(col("key") + 1, col("key") + 2)) @@ -674,7 +688,7 @@ class DataFrameSuite extends QueryTest } } - test("withColumns: given metadata") { + test("withColumns: internal method, given metadata") { def buildMetadata(num: Int): Seq[Metadata] = { (0 until num).map { n => val builder = new MetadataBuilder @@ -928,29 +942,33 @@ class DataFrameSuite extends QueryTest def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) - val describeAllCols = person2.describe() - assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) - checkAnswer(describeAllCols, describeResult) - // All aggregate value should have been cast to string - describeAllCols.collect().foreach { row => - row.toSeq.foreach { value => - if (value != null) { - assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + Seq("true", "false").foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) { + val describeAllCols = person2.describe() + assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) + checkAnswer(describeAllCols, describeResult) + // All aggregate value should have been cast to string + describeAllCols.collect().foreach { row => + row.toSeq.foreach { value => + if (value != null) { + assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + } + } } - } - } - 
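A sketch of the new public withColumns overload taking a Map, as exercised by the "withColumns: public API, with Map input" test above (SparkSession `spark` assumed; the two-column input stands in for the suite's testData).

import org.apache.spark.sql.functions.col
import spark.implicits._

val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
  .withColumns(Map(
    "newCol1" -> (col("key") + 1),
    "newCol2" -> (col("key") + 2)))

// Schema becomes: key, value, newCol1, newCol2.
df.show()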
val describeOneCol = person2.describe("age") - assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) - checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) + val describeOneCol = person2.describe("age") + assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) + checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d) }) - val describeNoCol = person2.select().describe() - assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) - checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s)} ) + val describeNoCol = person2.select().describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s) }) - val emptyDescription = person2.limit(0).describe() - assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) - checkAnswer(emptyDescription, emptyDescribeResult) + val emptyDescription = person2.limit(0).describe() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) + checkAnswer(emptyDescription, emptyDescribeResult) + } + } } test("summary") { @@ -976,30 +994,34 @@ class DataFrameSuite extends QueryTest def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) - val summaryAllCols = person2.summary() - - assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height")) - checkAnswer(summaryAllCols, summaryResult) - // All aggregate value should have been cast to string - summaryAllCols.collect().foreach { row => - row.toSeq.foreach { value => - if (value != null) { - assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + Seq("true", "false").foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) { + val summaryAllCols = person2.summary() + + assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height")) + checkAnswer(summaryAllCols, summaryResult) + // All aggregate value should have been cast to string + summaryAllCols.collect().foreach { row => + row.toSeq.foreach { value => + if (value != null) { + assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + } + } } - } - } - val summaryOneCol = person2.select("age").summary() - assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age")) - checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d)} ) + val summaryOneCol = person2.select("age").summary() + assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age")) + checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d) }) - val summaryNoCol = person2.select().summary() - assert(getSchemaAsSeq(summaryNoCol) === Seq("summary")) - checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s)} ) + val summaryNoCol = person2.select().summary() + assert(getSchemaAsSeq(summaryNoCol) === Seq("summary")) + checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s) }) - val emptyDescription = person2.limit(0).summary() - assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) - checkAnswer(emptyDescription, emptySummaryResult) + val emptyDescription = person2.limit(0).summary() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) + checkAnswer(emptyDescription, emptySummaryResult) + } + } } test("SPARK-34165: Add count_distinct to summary") { @@ -1543,7 +1565,9 @@ class 
DataFrameSuite extends QueryTest test("SPARK-7133: Implement struct, array, and map field accessor") { assert(complexData.filter(complexData("a")(0) === 2).count() == 1) - assert(complexData.filter(complexData("m")("1") === 1).count() == 1) + if (!conf.ansiEnabled) { + assert(complexData.filter(complexData("m")("1") === 1).count() == 1) + } assert(complexData.filter(complexData("s")("key") === 1).count() == 1) assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1) assert(complexData.filter(complexData("a")(complexData("s")("key")) === 1).count() == 1) @@ -2438,8 +2462,10 @@ class DataFrameSuite extends QueryTest val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name")) checkAnswer(aggPlusSort1, aggPlusSort2.collect()) - val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0) - val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0) + val aggPlusFilter1 = + df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === "test1") + val aggPlusFilter2 = + df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === "test1") checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) } } @@ -3087,6 +3113,89 @@ class DataFrameSuite extends QueryTest assert(res.collect.length == 2) } + + test("SPARK-38285: Fix ClassCastException: GenericArrayData cannot be cast to InternalRow") { + withTempView("v1") { + val sqlText = + """ + |CREATE OR REPLACE TEMP VIEW v1 AS + |SELECT * FROM VALUES + |(array( + | named_struct('s', 'string1', 'b', array(named_struct('e', 'string2'))), + | named_struct('s', 'string4', 'b', array(named_struct('e', 'string5'))) + | ) + |) + |v1(o); + |""".stripMargin + sql(sqlText) + + val df = sql("SELECT eo.b.e FROM (SELECT explode(o) AS eo FROM v1)") + checkAnswer(df, Row(Seq("string2")) :: Row(Seq("string5")) :: Nil) + } + } + + test("SPARK-37865: Do not deduplicate union output columns") { + val df1 = Seq((1, 1), (1, 2)).toDF("a", "b") + val df2 = Seq((2, 2), (2, 3)).toDF("c", "d") + + def sqlQuery(cols1: Seq[String], cols2: Seq[String], distinct: Boolean): String = { + val union = if (distinct) { + "UNION" + } else { + "UNION ALL" + } + s""" + |SELECT ${cols1.mkString(",")} FROM VALUES (1, 1), (1, 2) AS t1(a, b) + |$union SELECT ${cols2.mkString(",")} FROM VALUES (2, 2), (2, 3) AS t2(c, d) + |""".stripMargin + } + + Seq( + (Seq("a", "a"), Seq("c", "d"), Seq(Row(1, 1), Row(1, 1), Row(2, 2), Row(2, 3))), + (Seq("a", "b"), Seq("c", "d"), Seq(Row(1, 1), Row(1, 2), Row(2, 2), Row(2, 3))), + (Seq("a", "b"), Seq("c", "c"), Seq(Row(1, 1), Row(1, 2), Row(2, 2), Row(2, 2))) + ).foreach { case (cols1, cols2, rows) => + // UNION ALL (non-distinct) + val df3 = df1.selectExpr(cols1: _*).union(df2.selectExpr(cols2: _*)) + checkAnswer(df3, rows) + + val t3 = sqlQuery(cols1, cols2, false) + checkAnswer(sql(t3), rows) + + // Avoid breaking change + var correctAnswer = rows.map(r => Row(r(0))) + checkAnswer(df3.select(df1.col("a")), correctAnswer) + checkAnswer(sql(s"select a from ($t3) t3"), correctAnswer) + + // This has always been broken + intercept[AnalysisException] { + df3.select(df2.col("d")).collect() + } + intercept[AnalysisException] { + sql(s"select d from ($t3) t3") + } + + // UNION (distinct) + val df4 = df3.distinct + checkAnswer(df4, rows.distinct) + + val t4 = sqlQuery(cols1, cols2, true) + checkAnswer(sql(t4), rows.distinct) + + // Avoid breaking change + correctAnswer = rows.distinct.map(r => Row(r(0))) + 
checkAnswer(df4.select(df1.col("a")), correctAnswer) + checkAnswer(sql(s"select a from ($t4) t4"), correctAnswer) + + // This has always been broken + intercept[AnalysisException] { + df4.select(df2.col("d")).collect() + } + intercept[AnalysisException] { + sql(s"select d from ($t4) t4") + } + } + } } case class GroupByKey(a: Int, b: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index c385d9f58cc84..bd39453f5120e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql import java.time.LocalDateTime +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Filter} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -490,4 +491,88 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { assert(attributeReference.dataType == tuple._2) } } + + test("No need to filter windows when windowDuration is multiple of slideDuration") { + val df1 = Seq( + ("2022-02-15 19:39:34", 1, "a"), + ("2022-02-15 19:39:56", 2, "a"), + ("2022-02-15 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(window($"time", "9 seconds", "3 seconds", "0 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value") + val df2 = Seq( + (LocalDateTime.parse("2022-02-15T19:39:34"), 1, "a"), + (LocalDateTime.parse("2022-02-15T19:39:56"), 2, "a"), + (LocalDateTime.parse("2022-02-15T19:39:27"), 4, "b")).toDF("time", "value", "id") + .select(window($"time", "9 seconds", "3 seconds", "0 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value") + + val df3 = Seq( + ("2022-02-15 19:39:34", 1, "a"), + ("2022-02-15 19:39:56", 2, "a"), + ("2022-02-15 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(window($"time", "9 seconds", "3 seconds", "-2 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value") + val df4 = Seq( + (LocalDateTime.parse("2022-02-15T19:39:34"), 1, "a"), + (LocalDateTime.parse("2022-02-15T19:39:56"), 2, "a"), + (LocalDateTime.parse("2022-02-15T19:39:27"), 4, "b")).toDF("time", "value", "id") + .select(window($"time", "9 seconds", "3 seconds", "2 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value") + + Seq(df1, df2, df3, df4).foreach { df => + val filter = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Filter]) + assert(filter.isDefined) + val exist = filter.get.constraints.filter(e => + e.toString.contains(">=") || e.toString.contains("<")) + assert(exist.isEmpty, "No need to filter windows " + + "when windowDuration is multiple of slideDuration") + } + } + + test("SPARK-38227: 'start' and 'end' fields should be nullable") { + // We expect the fields in window struct as nullable since the dataType of TimeWindow defines + // them as nullable. The rule 'TimeWindowing' should respect the dataType. 
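A sketch of the sliding-window shape covered by the new DataFrameTimeWindowingSuite test above (SparkSession `spark` assumed): when windowDuration is a multiple of slideDuration, every generated window is valid, so the optimized plan should not carry the extra >= / < filtering constraints.

import org.apache.spark.sql.functions.window
import spark.implicits._

val df = Seq(
  ("2022-02-15 19:39:34", 1, "a"),
  ("2022-02-15 19:39:56", 2, "a"),
  ("2022-02-15 19:39:27", 4, "b")
).toDF("time", "value", "id")
  .select(window($"time".cast("timestamp"), "9 seconds", "3 seconds", "0 second"), $"value")
  .orderBy($"window.start".asc, $"value".desc)
  .select("value")

// Inspect the optimized plan; the window generation should not be wrapped in range filters.
df.explain(true)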
+ val df1 = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2)).toDF("time", "value") + val df2 = Seq( + (LocalDateTime.parse("2016-03-27T09:00:05"), 1), + (LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("time", "value") + + def validateWindowColumnInSchema(schema: StructType, colName: String): Unit = { + schema.find(_.name == colName) match { + case Some(StructField(_, st: StructType, _, _)) => + assertFieldInWindowStruct(st, "start") + assertFieldInWindowStruct(st, "end") + + case _ => fail("Failed to find suitable window column from DataFrame!") + } + } + + def assertFieldInWindowStruct(windowType: StructType, fieldName: String): Unit = { + val field = windowType.fields.find(_.name == fieldName) + assert(field.isDefined, s"'$fieldName' field should exist in window struct") + assert(field.get.nullable, s"'$fieldName' field should be nullable") + } + + for { + df <- Seq(df1, df2) + nullable <- Seq(true, false) + } { + val dfWithDesiredNullability = new DataFrame(df.queryExecution, RowEncoder( + StructType(df.schema.fields.map(_.copy(nullable = nullable))))) + // tumbling windows + val windowedProject = dfWithDesiredNullability + .select(window($"time", "10 seconds").as("window"), $"value") + val schema = windowedProject.queryExecution.optimizedPlan.schema + validateWindowColumnInSchema(schema, "window") + + // sliding windows + val windowedProject2 = dfWithDesiredNullability + .select(window($"time", "10 seconds", "3 seconds").as("window"), + $"value") + val schema2 = windowedProject2.queryExecution.optimizedPlan.schema + validateWindowColumnInSchema(schema2, "window") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 1491c5a4f26b1..11b2309ee38eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -20,9 +20,12 @@ package org.apache.spark.sql import org.scalatest.matchers.must.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.optimizer.TransposeWindow +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.exchange.Exchange +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec} +import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -94,7 +97,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest } test("corr, covar_pop, stddev_pop functions in specific window") { - withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") { + withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true", + SQLConf.ANSI_ENABLED.key -> "false") { val df = Seq( ("a", "p1", 10.0, 20.0), ("b", "p1", 20.0, 10.0), @@ -147,7 +151,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest test("SPARK-13860: " + "corr, covar_pop, stddev_pop functions in specific window " + "LEGACY_STATISTICAL_AGGREGATE off") { - withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "false") { + 
withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "false", + SQLConf.ANSI_ENABLED.key -> "false") { val df = Seq( ("a", "p1", 10.0, 20.0), ("b", "p1", 20.0, 10.0), @@ -404,22 +409,24 @@ class DataFrameWindowFunctionsSuite extends QueryTest } test("numerical aggregate functions on string column") { - val df = Seq((1, "a", "b")).toDF("key", "value1", "value2") - checkAnswer( - df.select($"key", - var_pop("value1").over(), - variance("value1").over(), - stddev_pop("value1").over(), - stddev("value1").over(), - sum("value1").over(), - mean("value1").over(), - avg("value1").over(), - corr("value1", "value2").over(), - covar_pop("value1", "value2").over(), - covar_samp("value1", "value2").over(), - skewness("value1").over(), - kurtosis("value1").over()), - Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null))) + if (!conf.ansiEnabled) { + val df = Seq((1, "a", "b")).toDF("key", "value1", "value2") + checkAnswer( + df.select($"key", + var_pop("value1").over(), + variance("value1").over(), + stddev_pop("value1").over(), + stddev("value1").over(), + sum("value1").over(), + mean("value1").over(), + avg("value1").over(), + corr("value1", "value2").over(), + covar_pop("value1", "value2").over(), + covar_samp("value1", "value2").over(), + skewness("value1").over(), + kurtosis("value1").over()), + Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null))) + } } test("statistical functions") { @@ -1071,4 +1078,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest Row("a", 1, "x", "x"), Row("b", 0, null, null))) } + + test("SPARK-38237: require all cluster keys for child required distribution for window query") { + def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = { + expressions.flatMap { + case ref: AttributeReference => Some(ref.name) + } + } + + def isShuffleExecByRequirement( + plan: ShuffleExchangeExec, + desiredClusterColumns: Seq[String]): Boolean = plan match { + case ShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS) => + partitionExpressionsColumns(op.expressions) === desiredClusterColumns + case _ => false + } + + val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value") + val windowSpec = Window.partitionBy("key1", "key2").orderBy("value") + + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key -> "true") { + + val windowed = df + // repartition by subset of window partitionBy keys which satisfies ClusteredDistribution + .repartition($"key1") + .select( + lead($"key1", 1).over(windowSpec), + lead($"value", 1).over(windowSpec)) + + checkAnswer(windowed, Seq(Row("b", 4), Row(null, null), Row(null, null), Row(null, null))) + + val shuffleByRequirement = windowed.queryExecution.executedPlan.exists { + case w: WindowExec => + w.child.exists { + case s: ShuffleExchangeExec => isShuffleExecByRequirement(s, Seq("key1", "key2")) + case _ => false + } + case _ => false + } + + assert(shuffleByRequirement, "Can't find desired shuffle node from the query plan") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 009ccb9a45354..2f4098d7cc7eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -250,7 +250,7 @@ class DatasetCacheSuite extends QueryTest case i: 
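A sketch of the SPARK-38237 scenario tested above (SparkSession `spark` assumed): with REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION enabled, a pre-repartition on only key1 no longer satisfies the window's required distribution over (key1, key2), so Spark is expected to insert a shuffle on both keys.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lead
import org.apache.spark.sql.internal.SQLConf
import spark.implicits._

spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
spark.conf.set(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key, "true")

val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")

val windowed = df
  .repartition($"key1")   // satisfies only a subset of the window's cluster keys
  .select(lead($"key1", 1).over(windowSpec), lead($"value", 1).over(windowSpec))

// The physical plan should contain a hash-partitioned Exchange on (key1, key2) below the Window node.
windowed.explain()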
InMemoryRelation => i.cacheBuilder.cachedPlan } assert(df2LimitInnerPlan.isDefined && - df2LimitInnerPlan.get.find(_.isInstanceOf[InMemoryTableScanExec]).isEmpty) + !df2LimitInnerPlan.get.exists(_.isInstanceOf[InMemoryTableScanExec])) } test("SPARK-27739 Save stats from optimized plan") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 2ce0754a5d1e7..c846441e9e009 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -48,10 +48,12 @@ object TestForTypeAlias { type TwoInt = (Int, Int) type ThreeInt = (TwoInt, Int) type SeqOfTwoInt = Seq[TwoInt] + type IntArray = Array[Int] def tupleTypeAlias: TwoInt = (1, 1) def nestedTupleTypeAlias: ThreeInt = ((1, 1), 2) def seqOfTupleTypeAlias: SeqOfTwoInt = Seq((1, 1), (2, 2)) + def aliasedArrayInTuple: (Int, IntArray) = (1, Array(1)) } class DatasetSuite extends QueryTest @@ -1647,6 +1649,12 @@ class DatasetSuite extends QueryTest ("", Seq((1, 1), (2, 2)))) } + test("SPARK-38042: Dataset should work with a product containing an aliased array type") { + checkDataset( + Seq(1).toDS().map(_ => ("", TestForTypeAlias.aliasedArrayInTuple)), + ("", (1, Array(1)))) + } + test("Check RelationalGroupedDataset toString: Single data") { val kvDataset = (1 to 3).toDF("id").groupBy("id") val expected = "RelationalGroupedDataset: [" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 543f845aff735..fa246fa79b33c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -23,7 +23,7 @@ import java.time.{Instant, LocalDateTime, ZoneId} import java.util.{Locale, TimeZone} import java.util.concurrent.TimeUnit -import org.apache.spark.{SparkException, SparkUpgradeException} +import org.apache.spark.{SparkConf, SparkException, SparkUpgradeException} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{CEST, LA} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ @@ -35,6 +35,10 @@ import org.apache.spark.unsafe.types.CalendarInterval class DateFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ + // The test cases which throw exceptions under ANSI mode are covered by date.sql and + // datetime-parsing-invalid.sql in org.apache.spark.sql.SQLQueryTestSuite. 
+ override def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_ENABLED.key, "false") + test("function current_date") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") val d0 = DateTimeUtils.currentDate(ZoneId.systemDefault()) @@ -512,7 +516,7 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(null), Row(null), Row(null))) val e = intercept[SparkUpgradeException](df.select(to_date(col("s"), "yyyy-dd-aa")).collect()) assert(e.getCause.isInstanceOf[IllegalArgumentException]) - assert(e.getMessage.contains("You may get a different result due to the upgrading of Spark")) + assert(e.getMessage.contains("You may get a different result due to the upgrading to Spark")) // February val x1 = "2016-02-29" @@ -695,7 +699,7 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkUpgradeException](invalid.collect()) assert(e.getCause.isInstanceOf[IllegalArgumentException]) assert( - e.getMessage.contains("You may get a different result due to the upgrading of Spark")) + e.getMessage.contains("You may get a different result due to the upgrading to Spark")) } // February diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index 9cef2553b365b..61885169ece4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -207,10 +207,10 @@ abstract class DynamicPartitionPruningSuiteBase case _: ReusedExchangeExec => // reuse check ok. case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => // reuse check ok. case b: BroadcastExchangeLike => - val hasReuse = plan.find { + val hasReuse = plan.exists { case ReusedExchangeExec(_, e) => e eq b case _ => false - }.isDefined + } assert(hasReuse, s"$s\nshould have been reused in\n$plan") case a: AdaptiveSparkPlanExec => val broadcastQueryStage = collectFirst(a) { @@ -234,7 +234,7 @@ abstract class DynamicPartitionPruningSuiteBase case r: ReusedSubqueryExec => r.child case o => o } - assert(subquery.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined == isMainQueryAdaptive) + assert(subquery.exists(_.isInstanceOf[AdaptiveSparkPlanExec]) == isMainQueryAdaptive) } } @@ -344,12 +344,12 @@ abstract class DynamicPartitionPruningSuiteBase | ) """.stripMargin) - val found = df.queryExecution.executedPlan.find { + val found = df.queryExecution.executedPlan.exists { case BroadcastHashJoinExec(_, _, p: ExistenceJoin, _, _, _, _, _) => true case _ => false } - assert(found.isEmpty) + assert(!found) } } @@ -1153,7 +1153,8 @@ abstract class DynamicPartitionPruningSuiteBase test("join key with multiple references on the filtering plan") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName, + SQLConf.ANSI_ENABLED.key -> "false" // ANSI mode doesn't support "String + String" ) { // when enable AQE, the reusedExchange is inserted when executed. 
withTable("fact", "dim") { @@ -1482,6 +1483,51 @@ abstract class DynamicPartitionPruningSuiteBase checkAnswer(df, Row(1150, 1) :: Row(1130, 4) :: Row(1140, 4) :: Nil) } } + + test("SPARK-38148: Do not add dynamic partition pruning if there exists static partition " + + "pruning") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + Seq( + "f.store_id = 1" -> false, + "1 = f.store_id" -> false, + "f.store_id <=> 1" -> false, + "1 <=> f.store_id" -> false, + "f.store_id > 1" -> true, + "5 > f.store_id" -> true).foreach { case (condition, hasDPP) => + // partitioned table at left side + val df1 = sql( + s""" + |SELECT /*+ broadcast(s) */ * FROM fact_sk f + |JOIN dim_store s ON f.store_id = s.store_id AND $condition + """.stripMargin) + checkPartitionPruningPredicate(df1, false, withBroadcast = hasDPP) + + val df2 = sql( + s""" + |SELECT /*+ broadcast(s) */ * FROM fact_sk f + |JOIN dim_store s ON f.store_id = s.store_id + |WHERE $condition + """.stripMargin) + checkPartitionPruningPredicate(df2, false, withBroadcast = hasDPP) + + // partitioned table at right side + val df3 = sql( + s""" + |SELECT /*+ broadcast(s) */ * FROM dim_store s + |JOIN fact_sk f ON f.store_id = s.store_id AND $condition + """.stripMargin) + checkPartitionPruningPredicate(df3, false, withBroadcast = hasDPP) + + val df4 = sql( + s""" + |SELECT /*+ broadcast(s) */ * FROM dim_store s + |JOIN fact_sk f ON f.store_id = s.store_id + |WHERE $condition + """.stripMargin) + checkPartitionPruningPredicate(df4, false, withBroadcast = hasDPP) + } + } + } } abstract class DynamicPartitionPruningDataSourceSuiteBase @@ -1514,14 +1560,14 @@ abstract class DynamicPartitionPruningDataSourceSuiteBase } // search dynamic pruning predicates on the executed plan val plan = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan - val ret = plan.find { + val ret = plan.exists { case s: FileSourceScanExec => s.partitionFilters.exists { case _: DynamicPruningExpression => true case _ => false } case _ => false } - assert(ret.isDefined == false) + assert(!ret) } } } @@ -1561,10 +1607,10 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat val scanOption = find(plan) { case s: FileSourceScanExec => - s.output.exists(_.find(_.argString(maxFields = 100).contains("fid")).isDefined) + s.output.exists(_.exists(_.argString(maxFields = 100).contains("fid"))) case s: BatchScanExec => // we use f1 col for v2 tables due to schema pruning - s.output.exists(_.find(_.argString(maxFields = 100).contains("f1")).isDefined) + s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1"))) case _ => false } assert(scanOption.isDefined) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index 44d0445928b90..073b67e0472bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -106,7 +106,7 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)") } - test("optimized plan should show the rewritten aggregate expression") { + test("optimized plan should show the rewritten expression") { withTempView("test_agg") { sql( """ @@ -125,6 +125,13 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite "Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS 
some(v)#x, " + "any(v#x) AS any(v)#x]") } + + withTable("t") { + sql("CREATE TABLE t(col TIMESTAMP) USING parquet") + val df = sql("SELECT date_part('month', col) FROM t") + checkKeywordsExistsInExplain(df, + "Project [month(cast(col#x as date)) AS date_part(month, col)#x]") + } } test("explain inline tables cross-joins") { @@ -217,8 +224,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite // AND conjunction // OR disjunction // --------------------------------------------------------------------------------------- - checkKeywordsExistsInExplain(sql("select 'a' || 1 + 2"), - "Project [null AS (concat(a, 1) + 2)#x]") + checkKeywordsExistsInExplain(sql("select '1' || 1 + 2"), + "Project [13", " AS (concat(1, 1) + 2)#x") checkKeywordsExistsInExplain(sql("select 1 - 2 || 'b'"), "Project [-1b AS concat((1 - 2), b)#x]") checkKeywordsExistsInExplain(sql("select 2 * 4 + 3 || 'b'"), @@ -232,12 +239,11 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite } test("explain for these functions; use range to avoid constant folding") { - val df = sql("select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') " + + val df = sql("select ifnull(id, 1), nullif(id, 1), nvl(id, 1), nvl2(id, 1, 2) " + "from range(2)") checkKeywordsExistsInExplain(df, - "Project [cast(id#xL as string) AS ifnull(id, x)#x, " + - "id#xL AS nullif(id, x)#xL, cast(id#xL as string) AS nvl(id, x)#x, " + - "x AS nvl2(id, x, y)#x]") + "Project [id#xL AS ifnull(id, 1)#xL, if ((id#xL = 1)) null " + + "else id#xL AS nullif(id, 1)#xL, id#xL AS nvl(id, 1)#xL, 1 AS nvl2(id, 1, 2)#x]") } test("SPARK-26659: explain of DataWritingCommandExec should not contain duplicate cmd.nodeName") { @@ -594,7 +600,7 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit } test("SPARK-35884: Explain should only display one plan before AQE takes effect") { - val df = (0 to 10).toDF("id").where('id > 5) + val df = (0 to 10).toDF("id").where(Symbol("id") > 5) val modes = Seq(SimpleMode, ExtendedMode, CostMode, FormattedMode) modes.foreach { mode => checkKeywordsExistsInExplain(df, mode, "AdaptiveSparkPlan") @@ -609,7 +615,8 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit test("SPARK-35884: Explain formatted with subquery") { withTempView("t1", "t2") { - spark.range(100).select('id % 10 as "key", 'id as "value").createOrReplaceTempView("t1") + spark.range(100).select(Symbol("id") % 10 as "key", Symbol("id") as "value") + .createOrReplaceTempView("t1") spark.range(10).createOrReplaceTempView("t2") val query = """ @@ -723,6 +730,35 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit assert(inMemoryRelationNodeId != columnarToRowNodeId) } } + + test("SPARK-38232: Explain formatted does not collect subqueries under query stage in AQE") { + withTable("t") { + sql("CREATE TABLE t USING PARQUET AS SELECT 1 AS c") + val expected = + "Subquery:1 Hosting operator id = 2 Hosting Expression = Subquery subquery#x, [id=#x]" + val df = sql("SELECT count(s) FROM (SELECT (SELECT c FROM t) as s)") + df.collect() + withNormalizedExplain(df, FormattedMode) { output => + assert(output.contains(expected)) + } + } + } + + test("SPARK-38322: Support query stage show runtime statistics in formatted explain mode") { + val df = Seq(1, 2).toDF("c").distinct() + val statistics = "Statistics(sizeInBytes=32.0 B, rowCount=2)" + + checkKeywordsNotExistsInExplain( + df, + FormattedMode, + statistics) + + df.collect() + 
checkKeywordsExistsInExplain( + df, + FormattedMode, + statistics) + } } case class ExplainSingleData(id: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 518090877e633..bc7a7b2977aca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -534,6 +534,64 @@ class FileBasedDataSourceSuite extends QueryTest } } + test("SPARK-30362: test input metrics for DSV2") { + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + Seq("json", "orc", "parquet").foreach { format => + withTempPath { path => + val dir = path.getCanonicalPath + spark.range(0, 10).write.format(format).save(dir) + val df = spark.read.format(format).load(dir) + val bytesReads = new mutable.ArrayBuffer[Long]() + val recordsRead = new mutable.ArrayBuffer[Long]() + val bytesReadListener = new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead + recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead + } + } + sparkContext.addSparkListener(bytesReadListener) + try { + df.collect() + sparkContext.listenerBus.waitUntilEmpty() + assert(bytesReads.sum > 0) + assert(recordsRead.sum == 10) + } finally { + sparkContext.removeSparkListener(bytesReadListener) + } + } + } + } + } + + test("SPARK-37585: test input metrics for DSV2 with output limits") { + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + Seq("json", "orc", "parquet").foreach { format => + withTempPath { path => + val dir = path.getCanonicalPath + spark.range(0, 100).write.format(format).save(dir) + val df = spark.read.format(format).load(dir) + val bytesReads = new mutable.ArrayBuffer[Long]() + val recordsRead = new mutable.ArrayBuffer[Long]() + val bytesReadListener = new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead + recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead + } + } + sparkContext.addSparkListener(bytesReadListener) + try { + df.limit(10).collect() + sparkContext.listenerBus.waitUntilEmpty() + assert(bytesReads.sum > 0) + assert(recordsRead.sum > 0) + } finally { + sparkContext.removeSparkListener(bytesReadListener) + } + } + } + } + } + test("Do not use cache on overwrite") { Seq("", "orc").foreach { useV1SourceReaderList => withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1SourceReaderList) { @@ -778,9 +836,10 @@ class FileBasedDataSourceSuite extends QueryTest } assert(filterCondition.isDefined) // The partitions filters should be pushed down and no need to be reevaluated. 
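Several hunks in this patch (DynamicPartitionPruningSuite above, and the FileBasedDataSourceSuite and JoinSuite assertions that follow) replace Option-producing traversals such as find(...).isDefined or collectFirst {...}.isEmpty with the Boolean exists helper on plans and expressions. A small standalone sketch of the idiom; the helper name hasDynamicPruning is made up for illustration, the API calls are the ones the suites themselves use:

    import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
    import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}

    // Old style: plan.find { ... }.isDefined builds an Option only to test for presence.
    // New style: exists short-circuits and returns the Boolean directly.
    def hasDynamicPruning(plan: SparkPlan): Boolean = plan.exists {
      case scan: FileSourceScanExec =>
        scan.partitionFilters.exists(_.isInstanceOf[DynamicPruningExpression])
      case _ => false
    }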
- assert(filterCondition.get.collectFirst { - case a: AttributeReference if a.name == "p1" || a.name == "p2" => a - }.isEmpty) + assert(!filterCondition.get.exists { + case a: AttributeReference => a.name == "p1" || a.name == "p2" + case _ => false + }) val fileScan = df.queryExecution.executedPlan collectFirst { case BatchScanExec(_, f: FileScan, _) => f @@ -909,52 +968,57 @@ class FileBasedDataSourceSuite extends QueryTest // cases when value == MAX var v = Short.MaxValue - checkPushedFilters(format, df.where('id > v.toInt), Array(), noScan = true) - checkPushedFilters(format, df.where('id >= v.toInt), Array(sources.IsNotNull("id"), - sources.EqualTo("id", v))) - checkPushedFilters(format, df.where('id === v.toInt), Array(sources.IsNotNull("id"), - sources.EqualTo("id", v))) - checkPushedFilters(format, df.where('id <=> v.toInt), + checkPushedFilters(format, df.where(Symbol("id") > v.toInt), Array(), noScan = true) + checkPushedFilters(format, df.where(Symbol("id") >= v.toInt), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) + checkPushedFilters(format, df.where(Symbol("id") === v.toInt), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) + checkPushedFilters(format, df.where(Symbol("id") <=> v.toInt), Array(sources.EqualNullSafe("id", v))) - checkPushedFilters(format, df.where('id <= v.toInt), Array(sources.IsNotNull("id"))) - checkPushedFilters(format, df.where('id < v.toInt), Array(sources.IsNotNull("id"), - sources.Not(sources.EqualTo("id", v)))) + checkPushedFilters(format, df.where(Symbol("id") <= v.toInt), + Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(Symbol("id") < v.toInt), + Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) // cases when value > MAX var v1: Int = positiveInt - checkPushedFilters(format, df.where('id > v1), Array(), noScan = true) - checkPushedFilters(format, df.where('id >= v1), Array(), noScan = true) - checkPushedFilters(format, df.where('id === v1), Array(), noScan = true) - checkPushedFilters(format, df.where('id <=> v1), Array(), noScan = true) - checkPushedFilters(format, df.where('id <= v1), Array(sources.IsNotNull("id"))) - checkPushedFilters(format, df.where('id < v1), Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(Symbol("id") > v1), Array(), noScan = true) + checkPushedFilters(format, df.where(Symbol("id") >= v1), Array(), noScan = true) + checkPushedFilters(format, df.where(Symbol("id") === v1), Array(), noScan = true) + checkPushedFilters(format, df.where(Symbol("id") <=> v1), Array(), noScan = true) + checkPushedFilters(format, df.where(Symbol("id") <= v1), Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(Symbol("id") < v1), Array(sources.IsNotNull("id"))) // cases when value = MIN v = Short.MinValue - checkPushedFilters(format, df.where(lit(v.toInt) < 'id), Array(sources.IsNotNull("id"), - sources.Not(sources.EqualTo("id", v)))) - checkPushedFilters(format, df.where(lit(v.toInt) <= 'id), Array(sources.IsNotNull("id"))) - checkPushedFilters(format, df.where(lit(v.toInt) === 'id), Array(sources.IsNotNull("id"), + checkPushedFilters(format, df.where(lit(v.toInt) < Symbol("id")), + Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) + checkPushedFilters(format, df.where(lit(v.toInt) <= Symbol("id")), + Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(lit(v.toInt) === Symbol("id")), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) - checkPushedFilters(format, df.where(lit(v.toInt) <=> 'id), 
+ checkPushedFilters(format, df.where(lit(v.toInt) <=> Symbol("id")), Array(sources.EqualNullSafe("id", v))) - checkPushedFilters(format, df.where(lit(v.toInt) >= 'id), Array(sources.IsNotNull("id"), - sources.EqualTo("id", v))) - checkPushedFilters(format, df.where(lit(v.toInt) > 'id), Array(), noScan = true) + checkPushedFilters(format, df.where(lit(v.toInt) >= Symbol("id")), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) + checkPushedFilters(format, df.where(lit(v.toInt) > Symbol("id")), Array(), noScan = true) // cases when value < MIN v1 = negativeInt - checkPushedFilters(format, df.where(lit(v1) < 'id), Array(sources.IsNotNull("id"))) - checkPushedFilters(format, df.where(lit(v1) <= 'id), Array(sources.IsNotNull("id"))) - checkPushedFilters(format, df.where(lit(v1) === 'id), Array(), noScan = true) - checkPushedFilters(format, df.where(lit(v1) >= 'id), Array(), noScan = true) - checkPushedFilters(format, df.where(lit(v1) > 'id), Array(), noScan = true) + checkPushedFilters(format, df.where(lit(v1) < Symbol("id")), + Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(lit(v1) <= Symbol("id")), + Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(lit(v1) === Symbol("id")), Array(), noScan = true) + checkPushedFilters(format, df.where(lit(v1) >= Symbol("id")), Array(), noScan = true) + checkPushedFilters(format, df.where(lit(v1) > Symbol("id")), Array(), noScan = true) // cases when value is within range (MIN, MAX) - checkPushedFilters(format, df.where('id > 30), Array(sources.IsNotNull("id"), + checkPushedFilters(format, df.where(Symbol("id") > 30), Array(sources.IsNotNull("id"), sources.GreaterThan("id", 30))) - checkPushedFilters(format, df.where(lit(100) >= 'id), Array(sources.IsNotNull("id"), - sources.LessThanOrEqual("id", 100))) + checkPushedFilters(format, df.where(lit(100) >= Symbol("id")), + Array(sources.IsNotNull("id"), sources.LessThanOrEqual("id", 100))) } } } @@ -991,28 +1055,6 @@ class FileBasedDataSourceSuite extends QueryTest checkAnswer(df, Row("v1", "v2")) } } - - test("SPARK-36271: V1 insert should check schema field name too") { - withView("v") { - spark.range(1).createTempView("v") - withTempDir { dir => - val e = intercept[AnalysisException] { - sql("SELECT ID, IF(ID=1,1,0) FROM v").write.mode(SaveMode.Overwrite) - .format("parquet").save(dir.getCanonicalPath) - }.getMessage - assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s).")) - } - - withTempDir { dir => - val e = intercept[AnalysisException] { - sql("SELECT NAMED_STRUCT('(IF((ID = 1), 1, 0))', IF(ID=1,ID,0)) AS col1 FROM v") - .write.mode(SaveMode.Overwrite) - .format("parquet").save(dir.getCanonicalPath) - }.getMessage - assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s).")) - } - } - } } object TestingUDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 14b59ba23d09f..ce98fd27350a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -85,10 +85,11 @@ trait FileScanSuiteBase extends SharedSparkSession { val options = new CaseInsensitiveStringMap(ImmutableMap.copyOf(optionsMap)) val optionsNotEqual = new CaseInsensitiveStringMap(ImmutableMap.copyOf(ImmutableMap.of("key2", "value2"))) - val partitionFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0))) - val partitionFiltersNotEqual 
= Seq(And(IsNull('data.int), LessThan('data.int, 1))) - val dataFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0))) - val dataFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1))) + val partitionFilters = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 0))) + val partitionFiltersNotEqual = Seq(And(IsNull(Symbol("data").int), + LessThan(Symbol("data").int, 1))) + val dataFilters = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 0))) + val dataFiltersNotEqual = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 1))) scanBuilders.foreach { case (name, scanBuilder, exclusions) => test(s"SPARK-33482: Test $name equals") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index d5c2d93055ba1..436ccb08294b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -357,6 +357,37 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { val df = Seq(1, 2, 3).toDF("v") checkAnswer(df.select(explode(array(min($"v"), max($"v")))), Row(1) :: Row(3) :: Nil) } + + test("SPARK-38528: generator in stream of aggregate expressions") { + val df = Seq(1, 2, 3).toDF("v") + checkAnswer( + df.select(Stream(explode(array(min($"v"), max($"v"))), sum($"v")): _*), + Row(1, 6) :: Row(3, 6) :: Nil) + } + + test("SPARK-37947: lateral view _outer()") { + checkAnswer( + sql("select * from values 1, 2 lateral view explode_outer(array()) a as b"), + Row(1, null) :: Row(2, null) :: Nil) + + checkAnswer( + sql("select * from values 1, 2 lateral view outer explode_outer(array()) a as b"), + Row(1, null) :: Row(2, null) :: Nil) + + withTempView("t1") { + sql( + """select * from values + |array(struct(0, 1), struct(3, 4)), + |array(struct(6, 7)), + |array(), + |null + |as tbl(arr) + """.stripMargin).createOrReplaceTempView("t1") + checkAnswer( + sql("select f1, f2 from t1 lateral view inline_outer(arr) as f1, f2"), + Row(0, 1) :: Row(3, 4) :: Row(6, 7) :: Row(null, null) :: Row(null, null) :: Nil) + } + } } case class EmptyGenerator() extends Generator with LeafLike[Expression] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 76b3324e3e1c5..86c8c8261e833 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -31,15 +31,16 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, Pyth import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.SparkUserDefinedFunction -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} /** - * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF and - * Scalar Pandas UDFs can be tested in SBT & Maven tests. + * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF, + * Scalar Pandas UDF and Grouped Aggregate Pandas UDF can be tested in SBT & Maven tests. * - * The available UDFs are special. It defines an UDF wrapped by cast. 
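The recurring 'col to Symbol("col") rewrites in the FileScanSuite hunk above (and in FileBasedDataSourceSuite, MathFunctionsSuite and JoinSuite elsewhere in this patch) address the Scala 2.13 deprecation of symbol literals; behavior is unchanged. A minimal sketch of the equivalent spellings, with an illustrative session and column name that are not taken from the patch:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("symbol-syntax").getOrCreate()
    import spark.implicits._

    val df = spark.range(10).toDF("id")

    // df.where('id > 5)          // symbol literal: deprecated under Scala 2.13
    df.where(Symbol("id") > 5)    // what the patch switches to; same implicit conversion
    df.where($"id" > 5)           // string interpolator alternative
    df.where(df("id") > 5)        // explicit column reference, no implicits needed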
So, the input column is - * casted into string, UDF returns strings as are, and then output column is casted back to - * the input column. In this way, UDF is virtually no-op. + * The available UDFs are special. For Scalar UDF, Python UDF and Scalar Pandas UDF, + * it defines an UDF wrapped by cast. So, the input column is casted into string, + * UDF returns strings as are, and then output column is casted back to the input column. + * In this way, UDF is virtually no-op. * * Note that, due to this implementation limitation, complex types such as map, array and struct * types do not work with this UDFs because they cannot be same after the cast roundtrip. @@ -69,6 +70,28 @@ import org.apache.spark.sql.types.{DataType, StringType} * df.select(expr("udf_name(id)") * df.select(pandasTestUDF(df("id"))) * }}} + * + * For Grouped Aggregate Pandas UDF, it defines an UDF that calculates the count using pandas. + * The UDF returns the count of the given column. In this way, UDF is virtually not no-op. + * + * To register Grouped Aggregate Pandas UDF in SQL: + * {{{ + * val groupedAggPandasTestUDF = TestGroupedAggPandasUDF(name = "udf_name") + * registerTestUDF(groupedAggPandasTestUDF, spark) + * }}} + * + * To use it in Scala API and SQL: + * {{{ + * sql("SELECT udf_name(1)") + * val df = Seq( + * (536361, "85123A", 2, 17850), + * (536362, "85123B", 4, 17850), + * (536363, "86123A", 6, 17851) + * ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") + * + * df.groupBy("CustomerID").agg(expr("udf_name(Quantity)")) + * df.groupBy("CustomerID").agg(groupedAggPandasTestUDF(df("Quantity"))) + * }}} */ object IntegratedUDFTestUtils extends SQLHelper { import scala.sys.process._ @@ -190,6 +213,28 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } + private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestGroupedAggPandasUDFs) { + var binaryPandasFunc: Array[Byte] = null + withTempPath { path => + Process( + Seq( + pythonExec, + "-c", + "from pyspark.sql.types import IntegerType; " + + "from pyspark.serializers import CloudPickleSerializer; " + + s"f = open('$path', 'wb');" + + "f.write(CloudPickleSerializer().dumps((" + + "lambda x: x.agg('count'), IntegerType())))"), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! + binaryPandasFunc = Files.readAllBytes(path.toPath) + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } else { + throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") + } + // Make sure this map stays mutable - this map gets updated later in Python runners. private val workerEnv = new java.util.HashMap[String, String]() workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") @@ -209,6 +254,8 @@ object IntegratedUDFTestUtils extends SQLHelper { lazy val shouldTestScalarPandasUDFs: Boolean = isPythonAvailable && isPandasAvailable && isPyArrowAvailable + lazy val shouldTestGroupedAggPandasUDFs: Boolean = shouldTestScalarPandasUDFs + /** * A base trait for various UDFs defined in this object. */ @@ -333,6 +380,46 @@ object IntegratedUDFTestUtils extends SQLHelper { val prettyName: String = "Scalar Pandas UDF" } + /** + * A Grouped Aggregate Pandas UDF that takes one column, executes the + * Python native function calculating the count of the column using pandas. 
+ * + * Virtually equivalent to: + * + * {{{ + * import pandas as pd + * from pyspark.sql.functions import pandas_udf + * + * df = spark.createDataFrame( + * [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) + * + * @pandas_udf("double") + * def pandas_count(v: pd.Series) -> int: + * return v.count() + * + * count_col = pandas_count(df['v']) + * }}} + */ + case class TestGroupedAggPandasUDF(name: String) extends TestUDF { + private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( + name = name, + func = PythonFunction( + command = pandasGroupedAggFunc, + envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], + pythonIncludes = List.empty[String].asJava, + pythonExec = pythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, + accumulator = null), + dataType = IntegerType, + pythonEvalType = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + udfDeterministic = true) + + def apply(exprs: Column*): Column = udf(exprs: _*) + + val prettyName: String = "Grouped Aggregate Pandas UDF" + } + /** * A Scala UDF that takes one column, casts into string, executes the * Scala native function, and casts back to the type of input column. @@ -387,6 +474,7 @@ object IntegratedUDFTestUtils extends SQLHelper { def registerTestUDF(testUDF: TestUDF, session: SparkSession): Unit = testUDF match { case udf: TestPythonUDF => session.udf.registerPython(udf.name, udf.udf) case udf: TestScalarPandasUDF => session.udf.registerPython(udf.name, udf.udf) + case udf: TestGroupedAggPandasUDF => session.udf.registerPython(udf.name, udf.udf) case udf: TestScalaUDF => session.udf.register(udf.name, udf.udf) case other => throw new RuntimeException(s"Unknown UDF class [${other.getClass}]") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 77493afe43145..4a8421a221194 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -183,7 +183,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan test("inner join where, one match per row") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { checkAnswer( - upperCaseData.join(lowerCaseData).where('n === 'N), + upperCaseData.join(lowerCaseData).where(Symbol("n") === 'N), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -404,8 +404,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan test("full outer join") { withTempView("`left`", "`right`") { - upperCaseData.where('N <= 4).createOrReplaceTempView("`left`") - upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") + upperCaseData.where(Symbol("N") <= 4).createOrReplaceTempView("`left`") + upperCaseData.where(Symbol("N") >= 3).createOrReplaceTempView("`right`") val left = UnresolvedRelation(TableIdentifier("left")) val right = UnresolvedRelation(TableIdentifier("right")) @@ -623,7 +623,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan testData.createOrReplaceTempView("B") testData2.createOrReplaceTempView("C") testData3.createOrReplaceTempView("D") - upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") + upperCaseData.where(Symbol("N") >= 3).createOrReplaceTempView("`right`") val cartesianQueries = Seq( /** The following should error out since there is no explicit cross join */ "SELECT * FROM testData inner join testData2", @@ -1074,8 +1074,8 @@ 
class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan val df = left.crossJoin(right).where(pythonTestUDF(left("a")) === right.col("c")) // Before optimization, there is a logical Filter operator. - val filterInAnalysis = df.queryExecution.analyzed.find(_.isInstanceOf[Filter]) - assert(filterInAnalysis.isDefined) + val filterInAnalysis = df.queryExecution.analyzed.exists(_.isInstanceOf[Filter]) + assert(filterInAnalysis) // Filter predicate was pushdown as join condition. So there is no Filter exec operator. val filterExec = find(df.queryExecution.executedPlan)(_.isInstanceOf[FilterExec]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 06babab122fd2..e18c087a26279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -403,7 +403,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { test("SPARK-24709: infers schemas of json strings and pass them to from_json") { val in = Seq("""{"a": [1, 2, 3]}""").toDS() - val out = in.select(from_json('value, schema_of_json("""{"a": [1]}""")) as "parsed") + val out = in.select(from_json(Symbol("value"), schema_of_json("""{"a": [1]}""")) as "parsed") val expected = StructType(StructField( "parsed", StructType(StructField( @@ -417,7 +417,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { test("infers schemas using options") { val df = spark.range(1) .select(schema_of_json(lit("{a:1}"), Map("allowUnquotedFieldNames" -> "true").asJava)) - checkAnswer(df, Seq(Row("STRUCT<`a`: BIGINT>"))) + checkAnswer(df, Seq(Row("STRUCT<a: BIGINT>"))) } test("from_json - array of primitive types") { @@ -697,14 +697,14 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { val input = regexp_replace(lit("""{"item_id": 1, "item_price": 0.1}"""), "item_", "") checkAnswer( spark.range(1).select(schema_of_json(input)), - Seq(Row("STRUCT<`id`: BIGINT, `price`: DOUBLE>"))) + Seq(Row("STRUCT<id: BIGINT, price: DOUBLE>"))) } test("SPARK-31065: schema_of_json - null and empty strings as strings") { Seq("""{"id": null}""", """{"id": ""}""").foreach { input => checkAnswer( spark.range(1).select(schema_of_json(input)), - Seq(Row("STRUCT<`id`: STRING>"))) + Seq(Row("STRUCT<id: STRING>"))) } } @@ -716,7 +716,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { schema_of_json( lit("""{"id": "a", "drop": {"drop": null}}"""), options.asJava)), - Seq(Row("STRUCT<`id`: STRING>"))) + Seq(Row("STRUCT<id: STRING>"))) // Array of structs checkAnswer( spark.range(1).select( schema_of_json( lit("""[{"id": "a", "drop": {"drop": null}}]"""), options.asJava)), - Seq(Row("ARRAY<STRUCT<`id`: STRING>>"))) + Seq(Row("ARRAY<STRUCT<id: STRING>>"))) // Other types are not affected.
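The expected strings in the JsonFunctionsSuite hunk above change because schema_of_json now prints field names without backticks unless quoting is actually required, as the removed lines with back-quoted names suggest. A quick way to see the new format, assuming only a local SparkSession (the JSON literal below is illustrative):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{lit, schema_of_json}

    val spark = SparkSession.builder().master("local[*]").appName("schema-of-json").getOrCreate()

    // Previously printed:  STRUCT<`id`: BIGINT, `price`: DOUBLE>
    // After this change:   STRUCT<id: BIGINT, price: DOUBLE>
    spark.range(1)
      .select(schema_of_json(lit("""{"id": 1, "price": 0.1}""")))
      .show(truncate = false)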
checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index ce25a8869c8b8..ab52cb98208f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -47,12 +47,12 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { c: Column => Column, f: T => U): Unit = { checkAnswer( - doubleData.select(c('a)), + doubleData.select(c(Symbol("a"))), (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) ) checkAnswer( - doubleData.select(c('b)), + doubleData.select(c(Symbol("b"))), (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) ) @@ -65,13 +65,13 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { checkAnswer( - nnDoubleData.select(c('a)), + nnDoubleData.select(c(Symbol("a"))), (1 to 10).map(n => Row(f(n * 0.1))) ) if (f(-1) === StrictMath.log1p(-1)) { checkAnswer( - nnDoubleData.select(c('b)), + nnDoubleData.select(c(Symbol("b"))), (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null) ) } @@ -87,12 +87,12 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { d: (Column, Double) => Column, f: (Double, Double) => Double): Unit = { checkAnswer( - nnDoubleData.select(c('a, 'a)), + nnDoubleData.select(c('a, Symbol("a"))), nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) ) checkAnswer( - nnDoubleData.select(c('a, 'b)), + nnDoubleData.select(c('a, Symbol("b"))), nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) ) @@ -109,7 +109,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) checkAnswer( - nullDoubles.select(c('a, 'a)).orderBy('a.asc), + nullDoubles.select(c('a, Symbol("a"))).orderBy(Symbol("a").asc), Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) ) } @@ -255,7 +255,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("factorial") { val df = (0 to 5).map(i => (i, i)).toDF("a", "b") checkAnswer( - df.select(factorial('a)), + df.select(factorial(Symbol("a"))), Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) ) checkAnswer( @@ -268,16 +268,24 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { testOneToOneMathFunction(rint, math.rint) } - test("round/bround") { + test("round/bround/ceil/floor") { val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") checkAnswer( - df.select(round('a), round('a, -1), round('a, -2)), + df.select(round(Symbol("a")), round('a, -1), round('a, -2)), Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) ) checkAnswer( - df.select(bround('a), bround('a, -1), bround('a, -2)), + df.select(bround(Symbol("a")), bround('a, -1), bround('a, -2)), Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600)) ) + checkAnswer( + df.select(ceil('a), ceil('a, lit(-1)), ceil('a, lit(-2))), + Seq(Row(5, 10, 100), Row(55, 60, 100), Row(555, 560, 600)) + ) + checkAnswer( + df.select(floor('a), floor('a, lit(-1)), floor('a, lit(-2))), + Seq(Row(5, 0, 0), Row(55, 50, 0), Row(555, 550, 500)) + ) withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") { val pi = "3.1415" @@ -293,6 +301,18 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(BigDecimal("0E3"), 
BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) + checkAnswer( + sql(s"SELECT ceil($pi), ceil($pi, -3), ceil($pi, -2), ceil($pi, -1), " + + s"ceil($pi, 0), ceil($pi, 1), ceil($pi, 2), ceil($pi, 3)"), + Seq(Row(BigDecimal(4), BigDecimal("1E3"), BigDecimal("1E2"), BigDecimal("1E1"), + BigDecimal(4), BigDecimal("3.2"), BigDecimal("3.15"), BigDecimal("3.142"))) + ) + checkAnswer( + sql(s"SELECT floor($pi), floor($pi, -3), floor($pi, -2), floor($pi, -1), " + + s"floor($pi, 0), floor($pi, 1), floor($pi, 2), floor($pi, 3)"), + Seq(Row(BigDecimal(3), BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), + BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.141"))) + ) } val bdPi: BigDecimal = BigDecimal(31415925L, 7) @@ -307,21 +327,46 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"), Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null)) ) + checkAnswer( + sql(s"SELECT ceil($bdPi, 7), ceil($bdPi, 8), ceil($bdPi, 9), ceil($bdPi, 10), " + + s"ceil($bdPi, 100), ceil($bdPi, 6), ceil(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null)) + ) + + checkAnswer( + sql(s"SELECT floor($bdPi, 7), floor($bdPi, 8), floor($bdPi, 9), floor($bdPi, 10), " + + s"floor($bdPi, 100), floor($bdPi, 6), floor(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null)) + ) } - test("round/bround with data frame from a local Seq of Product") { + test("round/bround/ceil/floor with data frame from a local Seq of Product") { val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value") checkAnswer( - df.withColumn("value_rounded", round('value)), + df.withColumn("value_rounded", round(Symbol("value"))), Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) ) checkAnswer( - df.withColumn("value_brounded", bround('value)), + df.withColumn("value_brounded", bround(Symbol("value"))), Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) ) + checkAnswer( + df + .withColumn("value_ceil", ceil('value)) + .withColumn("value_ceil1", ceil('value, lit(0))) + .withColumn("value_ceil2", ceil('value, lit(1))), + Seq(Row(BigDecimal("5.9"), BigDecimal("6"), BigDecimal("6"), BigDecimal("5.9"))) + ) + checkAnswer( + df + .withColumn("value_floor", floor('value)) + .withColumn("value_floor1", floor('value, lit(0))) + .withColumn("value_floor2", floor('value, lit(1))), + Seq(Row(BigDecimal("5.9"), BigDecimal("5"), BigDecimal("5"), BigDecimal("5.9"))) + ) } - test("round/bround with table columns") { + test("round/bround/ceil/floor with table columns") { withTable("t") { Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t") checkAnswer( @@ -330,6 +375,24 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer( sql("select i, bround(i) from t"), Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + checkAnswer( + sql("select i, ceil(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + checkAnswer( + sql("select i, ceil(i, 0) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + checkAnswer( + sql("select i, ceil(i, 1) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("5.9")))) + checkAnswer( + sql("select i, floor(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("5")))) + checkAnswer( + sql("select i, floor(i, 0) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("5")))) + checkAnswer( + sql("select i, floor(i, 1) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("5.9")))) } } @@ 
-360,10 +423,10 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("hex") { val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d") - checkAnswer(data.select(hex('a)), Seq(Row("1C"))) - checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4"))) - checkAnswer(data.select(hex('c)), Seq(Row("177828FED4"))) - checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F"))) + checkAnswer(data.select(hex(Symbol("a"))), Seq(Row("1C"))) + checkAnswer(data.select(hex(Symbol("b"))), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.select(hex(Symbol("c"))), Seq(Row("177828FED4"))) + checkAnswer(data.select(hex(Symbol("d"))), Seq(Row("68656C6C6F"))) checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C"))) checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4"))) checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4"))) @@ -373,8 +436,8 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("unhex") { val data = Seq(("1C", "737472696E67")).toDF("a", "b") - checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte))) - checkAnswer(data.select(unhex('b)), Row("string".getBytes(StandardCharsets.UTF_8))) + checkAnswer(data.select(unhex(Symbol("a"))), Row(Array[Byte](28.toByte))) + checkAnswer(data.select(unhex(Symbol("b"))), Row("string".getBytes(StandardCharsets.UTF_8))) checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes(StandardCharsets.UTF_8))) checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala index 262a6920d29c6..a0207e9b01920 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec, ValidateRequirements} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.tags.ExtendedSQLTest // scalastyle:off line.size.limit @@ -82,8 +83,18 @@ trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { def goldenFilePath: String + private val approvedAnsiPlans: Seq[String] = Seq( + "q83", + "q83.sf100" + ) + private def getDirForTest(name: String): File = { - new File(goldenFilePath, name) + val goldenFileName = if (SQLConf.get.ansiEnabled && approvedAnsiPlans.contains(name)) { + name + ".ansi" + } else { + name + } + new File(goldenFilePath, goldenFileName) } private def isApproved( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 31569d82b4dc9..06f94c62d9c25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -207,11 +207,12 @@ abstract class QueryTest extends PlanTest { */ def assertCached(query: Dataset[_], cachedName: String, storageLevel: StorageLevel): Unit = { val planWithCaching = query.queryExecution.withCachedData - val matched = planWithCaching.collectFirst { case cached: InMemoryRelation => - val cacheBuilder = cached.cacheBuilder - cachedName == cacheBuilder.tableName.get && - (storageLevel == cacheBuilder.storageLevel) - 
}.getOrElse(false) + val matched = planWithCaching.exists { + case cached: InMemoryRelation => + val cacheBuilder = cached.cacheBuilder + cachedName == cacheBuilder.tableName.get && (storageLevel == cacheBuilder.storageLevel) + case _ => false + } assert(matched, s"Expected query plan to hit cache $cachedName with storage " + s"level $storageLevel, but it doesn't.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala index 739b4052ee90d..8883e9be1937e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala @@ -59,13 +59,13 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared val q5 = df1.selectExpr("IF(l > 1 AND null, 5, 1) AS out") checkAnswer(q5, Row(1) :: Row(1) :: Nil) q5.queryExecution.executedPlan.foreach { p => - assert(p.expressions.forall(e => e.find(_.isInstanceOf[If]).isEmpty)) + assert(p.expressions.forall(e => !e.exists(_.isInstanceOf[If]))) } val q6 = df1.selectExpr("CASE WHEN (l > 2 AND null) THEN 3 ELSE 2 END") checkAnswer(q6, Row(2) :: Row(2) :: Nil) q6.queryExecution.executedPlan.foreach { p => - assert(p.expressions.forall(e => e.find(_.isInstanceOf[CaseWhen]).isEmpty)) + assert(p.expressions.forall(e => !e.exists(_.isInstanceOf[CaseWhen]))) } checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true)) @@ -75,10 +75,10 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") { def assertNoLiteralNullInPlan(df: DataFrame): Unit = { df.queryExecution.executedPlan.foreach { p => - assert(p.expressions.forall(_.find { + assert(p.expressions.forall(!_.exists { case Literal(null, BooleanType) => true case _ => false - }.isEmpty)) + })) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala index 2f56fbaf7f821..fad01db82ca0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala @@ -286,20 +286,43 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils { } else { SQLConf.StoreAssignmentPolicy.values } + + def shouldThrowException(policy: SQLConf.StoreAssignmentPolicy.Value): Boolean = policy match { + case SQLConf.StoreAssignmentPolicy.ANSI | SQLConf.StoreAssignmentPolicy.STRICT => + true + case SQLConf.StoreAssignmentPolicy.LEGACY => + false + } + testingPolicies.foreach { policy => - withSQLConf( - SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) { + withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) { withTable("t") { sql("create table t(a int, b string) using parquet partitioned by (a)") - policy match { - case SQLConf.StoreAssignmentPolicy.ANSI | SQLConf.StoreAssignmentPolicy.STRICT => - val errorMsg = intercept[NumberFormatException] { - sql("insert into t partition(a='ansi') values('ansi')") - }.getMessage - assert(errorMsg.contains("invalid input syntax for type numeric: ansi")) - case SQLConf.StoreAssignmentPolicy.LEGACY => + if (shouldThrowException(policy)) { + val errorMsg = intercept[NumberFormatException] { sql("insert into t partition(a='ansi') values('ansi')") - 
checkAnswer(sql("select * from t"), Row("ansi", null) :: Nil) + }.getMessage + assert(errorMsg.contains("invalid input syntax for type numeric: ansi")) + } else { + sql("insert into t partition(a='ansi') values('ansi')") + checkAnswer(sql("select * from t"), Row("ansi", null) :: Nil) + } + } + } + } + } + + test("SPARK-38228: legacy store assignment should not fail on error under ANSI mode") { + // DS v2 doesn't support the legacy policy + if (format != "foo") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf( + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString, + SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withTable("t") { + sql("create table t(a int) using parquet") + sql("insert into t values('ansi')") + checkAnswer(spark.table("t"), Row(null)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d7f18ee801d72..c28dde9cea09a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -69,8 +69,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ") val queryCoalesce = sql("select coalesce(null, 1, '1') from src ") - checkAnswer(queryCaseWhen, Row("1.0") :: Nil) - checkAnswer(queryCoalesce, Row("1") :: Nil) + if (!conf.ansiEnabled) { + checkAnswer(queryCaseWhen, Row("1.0") :: Nil) + checkAnswer(queryCoalesce, Row("1") :: Nil) + } } } @@ -393,10 +395,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark testCodeGen( "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", Row(100, 1, 50.5, 300, 100) :: Nil) - // Aggregate with Code generation handling all null values - testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(null, null, 0) :: Nil) + // Aggregate with Code generation handling all null values. + // If ANSI mode is on, there will be an error since 'a' cannot converted as Numeric. + // Here we simply test it when ANSI mode is off. + if (!conf.ansiEnabled) { + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(null, null, 0) :: Nil) + } } finally { spark.catalog.dropTempView("testData3x") } @@ -488,9 +494,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Seq(Row(Timestamp.valueOf("1969-12-31 16:00:00.001")), Row(Timestamp.valueOf("1969-12-31 16:00:00.002")))) - checkAnswer(sql( - "SELECT time FROM timestamps WHERE time='123'"), - Nil) + if (!conf.ansiEnabled) { + checkAnswer(sql( + "SELECT time FROM timestamps WHERE time='123'"), + Nil) + } } } @@ -939,9 +947,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) // Column type mismatches are not allowed, forcing a type coercion. - checkAnswer( - sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), - ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) + // When ANSI mode is on, the String input will be cast as Int in the following Union, which will + // cause a runtime error. Here we simply test the case when ANSI mode is off. 
+ if (!conf.ansiEnabled) { + checkAnswer( + sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), + ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) + } // Column type mismatches where a coercion is not possible, in this case between integer // and array types, trigger a TreeNodeException. intercept[AnalysisException] { @@ -1038,32 +1050,35 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Row(Row(3, true), Map("C3" -> null)) :: Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil) - checkAnswer( - sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), - Row(1, null) :: - Row(2, null) :: - Row(3, null) :: - Row(4, 2147483644) :: Nil) - - // The value of a MapType column can be a mutable map. - val rowRDD3 = unparsedStrings.map { r => - val values = r.split(",").map(_.trim) - val v4 = try values(3).toInt catch { - case _: NumberFormatException => null + // If ANSI mode is on, there will be an error "Key D4 does not exist". + if (!conf.ansiEnabled) { + checkAnswer( + sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) + + // The value of a MapType column can be a mutable map. + val rowRDD3 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), + scala.collection.mutable.Map(values(1) -> v4)) } - Row(Row(values(0).toInt, values(2).toBoolean), - scala.collection.mutable.Map(values(1) -> v4)) - } - val df3 = spark.createDataFrame(rowRDD3, schema2) - df3.createOrReplaceTempView("applySchema3") + val df3 = spark.createDataFrame(rowRDD3, schema2) + df3.createOrReplaceTempView("applySchema3") - checkAnswer( - sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), - Row(1, null) :: - Row(2, null) :: - Row(3, null) :: - Row(4, 2147483644) :: Nil) + checkAnswer( + sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) + } } } @@ -1403,22 +1418,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("SPARK-7952: fix the equality check between boolean and numeric types") { - withTempView("t") { - // numeric field i, boolean field j, result of i = j, result of i <=> j - Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( - (1, true, true, true), - (0, false, true, true), - (2, true, false, false), - (2, false, false, false), - (null, true, null, false), - (null, false, null, false), - (0, null, null, false), - (1, null, null, false), - (null, null, null, true) - ).toDF("i", "b", "r1", "r2").createOrReplaceTempView("t") - - checkAnswer(sql("select i = b from t"), sql("select r1 from t")) - checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) + // If ANSI mode is on, Spark disallows comparing Int with Boolean. 
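The comment above points at a second ANSI difference, implicit type coercion rather than cast failures: with ANSI off, 1 = true is silently coerced, while ANSI coercion rules reject the comparison during analysis. A minimal sketch under the same illustrative session setup, with the exact error text elided:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("ansi-coercion").getOrCreate()

    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT 1 = true").show()   // legacy coercion: returns true

    spark.conf.set("spark.sql.ansi.enabled", "true")
    // Now fails analysis with a data type mismatch, because ANSI coercion does not
    // implicitly convert between numeric and boolean types.
    // spark.sql("SELECT 1 = true").show()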
+ if (!conf.ansiEnabled) { + withTempView("t") { + // numeric field i, boolean field j, result of i = j, result of i <=> j + Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( + (1, true, true, true), + (0, false, true, true), + (2, true, false, false), + (2, false, false, false), + (null, true, null, false), + (null, false, null, false), + (0, null, null, false), + (1, null, null, false), + (null, null, null, true) + ).toDF("i", "b", "r1", "r2").createOrReplaceTempView("t") + + checkAnswer(sql("select i = b from t"), sql("select r1 from t")) + checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) + } } } @@ -3048,15 +3066,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val df = spark.read.format(format).load(dir.getCanonicalPath) checkPushedFilters( format, - df.where(('id < 2 and 's.contains("foo")) or ('id > 10 and 's.contains("bar"))), + df.where((Symbol("id") < 2 and Symbol("s").contains("foo")) or + (Symbol("id") > 10 and Symbol("s").contains("bar"))), Array(sources.Or(sources.LessThan("id", 2), sources.GreaterThan("id", 10)))) checkPushedFilters( format, - df.where('s.contains("foo") or ('id > 10 and 's.contains("bar"))), + df.where(Symbol("s").contains("foo") or + (Symbol("id") > 10 and Symbol("s").contains("bar"))), Array.empty) checkPushedFilters( format, - df.where('id < 2 and not('id > 10 and 's.contains("bar"))), + df.where(Symbol("id") < 2 and not(Symbol("id") > 10 and Symbol("s").contains("bar"))), Array(sources.IsNotNull("id"), sources.LessThan("id", 2))) } } @@ -3137,16 +3157,20 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(sql("select * from t1 where d >= '2000-01-01'"), Row(result)) checkAnswer(sql("select * from t1 where d >= '2000-01-02'"), Nil) checkAnswer(sql("select * from t1 where '2000' >= d"), Row(result)) - checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil) + if (!conf.ansiEnabled) { + checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil) + } withSQLConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING.key -> "true") { checkAnswer(sql("select * from t1 where d < '2000'"), Nil) checkAnswer(sql("select * from t1 where d < '2001'"), Row(result)) - checkAnswer(sql("select * from t1 where d < '2000-1-1'"), Row(result)) checkAnswer(sql("select * from t1 where d <= '1999'"), Nil) checkAnswer(sql("select * from t1 where d >= '2000'"), Row(result)) - checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result)) - checkAnswer(sql("select to_date('2000-01-01') > '1'"), Row(true)) + if (!conf.ansiEnabled) { + checkAnswer(sql("select * from t1 where d < '2000-1-1'"), Row(result)) + checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result)) + checkAnswer(sql("select to_date('2000-01-01') > '1'"), Row(true)) + } } } } @@ -3179,17 +3203,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(sql("select * from t1 where d >= '2000-01-01 01:10:00.000'"), Row(result)) checkAnswer(sql("select * from t1 where d >= '2000-01-02 01:10:00.000'"), Nil) checkAnswer(sql("select * from t1 where '2000' >= d"), Nil) - checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil) + if (!conf.ansiEnabled) { + checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil) + } withSQLConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING.key -> "true") { checkAnswer(sql("select * from t1 where d < '2000'"), Nil) checkAnswer(sql("select * from t1 where d < '2001'"), Row(result)) - checkAnswer(sql("select * from t1 
where d <= '2000-1-1'"), Row(result)) checkAnswer(sql("select * from t1 where d <= '2000-01-02'"), Row(result)) checkAnswer(sql("select * from t1 where d <= '1999'"), Nil) checkAnswer(sql("select * from t1 where d >= '2000'"), Row(result)) - checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result)) - checkAnswer(sql("select to_timestamp('2000-01-01 01:10:00') > '1'"), Row(true)) + if (!conf.ansiEnabled) { + checkAnswer(sql("select * from t1 where d <= '2000-1-1'"), Row(result)) + checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result)) + checkAnswer(sql("select to_timestamp('2000-01-01 01:10:00') > '1'"), Row(true)) + } } sql("DROP VIEW t1") } @@ -3254,28 +3282,31 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("SPARK-29213: FilterExec should not throw NPE") { - withTempView("t1", "t2", "t3") { - sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1") - sql("SELECT * FROM VALUES 0, CAST(NULL AS BIGINT)") - .as[java.lang.Long] - .map(identity) - .toDF("x") - .createOrReplaceTempView("t2") - sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3") - sql( - """ - |SELECT t1.x - |FROM t1 - |LEFT JOIN ( - | SELECT x FROM ( - | SELECT x FROM t2 - | UNION ALL - | SELECT SUBSTR(x,5) x FROM t3 - | ) a - | WHERE LENGTH(x)>0 - |) t3 - |ON t1.x=t3.x + // Under ANSI mode, casting string '' as numeric will cause runtime error + if (!conf.ansiEnabled) { + withTempView("t1", "t2", "t3") { + sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1") + sql("SELECT * FROM VALUES 0, CAST(NULL AS BIGINT)") + .as[java.lang.Long] + .map(identity) + .toDF("x") + .createOrReplaceTempView("t2") + sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3") + sql( + """ + |SELECT t1.x + |FROM t1 + |LEFT JOIN ( + | SELECT x FROM ( + | SELECT x FROM t2 + | UNION ALL + | SELECT SUBSTR(x,5) x FROM t3 + | ) a + | WHERE LENGTH(x)>0 + |) t3 + |ON t1.x=t3.x """.stripMargin).collect() + } } } @@ -3295,7 +3326,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("CREATE TEMPORARY VIEW tc AS SELECT * FROM VALUES(CAST(1 AS DOUBLE)) AS tc(id)") sql("CREATE TEMPORARY VIEW td AS SELECT * FROM VALUES(CAST(1 AS FLOAT)) AS td(id)") sql("CREATE TEMPORARY VIEW te AS SELECT * FROM VALUES(CAST(1 AS BIGINT)) AS te(id)") - sql("CREATE TEMPORARY VIEW tf AS SELECT * FROM VALUES(CAST(1 AS DECIMAL(38, 38))) AS tf(id)") val df1 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tb)") checkAnswer(df1, Row(new java.math.BigDecimal(1))) val df2 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tc)") @@ -3304,8 +3334,12 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(df3, Row(new java.math.BigDecimal(1))) val df4 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM te)") checkAnswer(df4, Row(new java.math.BigDecimal(1))) - val df5 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tf)") - checkAnswer(df5, Array.empty[Row]) + if (!conf.ansiEnabled) { + sql( + "CREATE TEMPORARY VIEW tf AS SELECT * FROM VALUES(CAST(1 AS DECIMAL(38, 38))) AS tf(id)") + val df5 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tf)") + checkAnswer(df5, Array.empty[Row]) + } } } @@ -4243,6 +4277,57 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(df3, df4) } } + + test("SPARK-27442: Spark support read/write parquet file with invalid char in field name") { + withTempDir { 
dir => + Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (2, 4, 6, 8, 10, 12, 14, 16, 18, 20)) + .toDF("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a") + .repartition(1) + .write.mode(SaveMode.Overwrite).parquet(dir.getAbsolutePath) + val df = spark.read.parquet(dir.getAbsolutePath) + checkAnswer(df, + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) :: + Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20) :: Nil) + assert(df.schema.names.sameElements( + Array("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a"))) + checkAnswer(df.select("`max(t)`", "`a b`", "`{`", "`.`", "`a.b`"), + Row(1, 6, 7, 8, 9) :: Row(2, 12, 14, 16, 18) :: Nil) + checkAnswer(df.where("`a.b` > 10"), + Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20) :: Nil) + } + } + + test("SPARK-37965: Spark support read/write orc file with invalid char in field name") { + withTempDir { dir => + Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), (2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22)) + .toDF("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a", ",") + .repartition(1) + .write.mode(SaveMode.Overwrite).orc(dir.getAbsolutePath) + val df = spark.read.orc(dir.getAbsolutePath) + checkAnswer(df, + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) :: + Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22) :: Nil) + assert(df.schema.names.sameElements( + Array("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a", ","))) + checkAnswer(df.select("`max(t)`", "`a b`", "`{`", "`.`", "`a.b`"), + Row(1, 6, 7, 8, 9) :: Row(2, 12, 14, 16, 18) :: Nil) + checkAnswer(df.where("`a.b` > 10"), + Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22) :: Nil) + } + } + + test("SPARK-38173: Quoted column cannot be recognized correctly " + + "when quotedRegexColumnNames is true") { + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "true") { + checkAnswer( + sql( + """ + |SELECT `(C3)?+.+`,T.`C1` * `C2` AS CC + |FROM (SELECT 3 AS C1,2 AS C2,1 AS C3) T + |""".stripMargin), + Row(3, 2, 6) :: Nil) + } + } } case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 7a5684ef3ffbc..d6a7c69018f90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -388,6 +388,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper localSparkSession.conf.set(SQLConf.TIMESTAMP_TYPE.key, TimestampTypes.TIMESTAMP_NTZ.toString) case _ => + localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, false) } if (configSet.nonEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala deleted file mode 100644 index 13983120955fb..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala +++ /dev/null @@ -1,267 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.sources.SimpleInsertSource -import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} -import org.apache.spark.util.Utils - -class SimpleShowCreateTableSuite extends ShowCreateTableSuite with SharedSparkSession - -abstract class ShowCreateTableSuite extends QueryTest with SQLTestUtils { - import testImplicits._ - - test("data source table with user specified schema") { - withTable("ddl_test") { - val jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile - - sql( - s"""CREATE TABLE ddl_test ( - | a STRING, - | b STRING, - | `extra col` ARRAY, - | `` STRUCT> - |) - |USING json - |OPTIONS ( - | PATH '$jsonFilePath' - |) - """.stripMargin - ) - - checkCreateTable("ddl_test") - } - } - - test("data source table CTAS") { - withTable("ddl_test") { - sql( - s"""CREATE TABLE ddl_test - |USING json - |AS SELECT 1 AS a, "foo" AS b - """.stripMargin - ) - - checkCreateTable("ddl_test") - } - } - - test("partitioned data source table") { - withTable("ddl_test") { - sql( - s"""CREATE TABLE ddl_test - |USING json - |PARTITIONED BY (b) - |AS SELECT 1 AS a, "foo" AS b - """.stripMargin - ) - - checkCreateTable("ddl_test") - } - } - - test("bucketed data source table") { - withTable("ddl_test") { - sql( - s"""CREATE TABLE ddl_test - |USING json - |CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS - |AS SELECT 1 AS a, "foo" AS b - """.stripMargin - ) - - checkCreateTable("ddl_test") - } - } - - test("partitioned bucketed data source table") { - withTable("ddl_test") { - sql( - s"""CREATE TABLE ddl_test - |USING json - |PARTITIONED BY (c) - |CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS - |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c - """.stripMargin - ) - - checkCreateTable("ddl_test") - } - } - - test("data source table with a comment") { - withTable("ddl_test") { - sql( - s"""CREATE TABLE ddl_test - |USING json - |COMMENT 'This is a comment' - |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c - """.stripMargin - ) - - checkCreateTable("ddl_test") - } - } - - test("data source table with table properties") { - withTable("ddl_test") { - sql( - s"""CREATE TABLE ddl_test - |USING json - |TBLPROPERTIES ('a' = '1') - |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c - """.stripMargin - ) - - checkCreateTable("ddl_test") - } - } - - test("data source table using Dataset API") { - withTable("ddl_test") { - spark - .range(3) - .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd, 'id as 'e) - .write - .mode("overwrite") - .partitionBy("a", "b") - .bucketBy(2, "c", "d") - .saveAsTable("ddl_test") - - checkCreateTable("ddl_test") - } - } - - test("temp view") { - val viewName = "spark_28383" - withTempView(viewName) { - sql(s"CREATE TEMPORARY VIEW $viewName AS SELECT 1 AS a") - val ex = intercept[AnalysisException] { - sql(s"SHOW CREATE TABLE $viewName") - } - assert(ex.getMessage.contains( - s"$viewName is a temp view. 
'SHOW CREATE TABLE' expects a table or permanent view.")) - } - - withGlobalTempView(viewName) { - sql(s"CREATE GLOBAL TEMPORARY VIEW $viewName AS SELECT 1 AS a") - val globalTempViewDb = spark.sessionState.catalog.globalTempViewManager.database - val ex = intercept[AnalysisException] { - sql(s"SHOW CREATE TABLE $globalTempViewDb.$viewName") - } - assert(ex.getMessage.contains( - s"$globalTempViewDb.$viewName is a temp view. " + - "'SHOW CREATE TABLE' expects a table or permanent view.")) - } - } - - test("SPARK-24911: keep quotes for nested fields") { - withTable("t1") { - val createTable = "CREATE TABLE `t1` (`a` STRUCT<`b`: STRING>)" - sql(s"$createTable USING json") - val shownDDL = getShowDDL("SHOW CREATE TABLE t1") - assert(shownDDL == "CREATE TABLE `default`.`t1` ( `a` STRUCT<`b`: STRING>) USING json") - - checkCreateTable("t1") - } - } - - test("SPARK-36012: Add NULL flag when SHOW CREATE TABLE") { - val t = "SPARK_36012" - withTable(t) { - sql( - s""" - |CREATE TABLE $t ( - | a bigint NOT NULL, - | b bigint - |) - |USING ${classOf[SimpleInsertSource].getName} - """.stripMargin) - val showDDL = getShowDDL(s"SHOW CREATE TABLE $t") - assert(showDDL == s"CREATE TABLE `default`.`$t` ( `a` BIGINT NOT NULL," + - s" `b` BIGINT) USING ${classOf[SimpleInsertSource].getName}") - } - } - - test("SPARK-37494: Unify v1 and v2 option output") { - withTable("ddl_test") { - sql( - s"""CREATE TABLE ddl_test ( - | a STRING - |) - |USING json - |TBLPROPERTIES ( - | 'b' = '1', - | 'a' = '2') - |OPTIONS ( - | k4 'v4', - | `k3` 'v3', - | 'k5' 'v5', - | 'k1' = 'v1', - | k2 = 'v2' - |) - """.stripMargin - ) - val expected = "CREATE TABLE `default`.`ddl_test` ( `a` STRING) USING json" + - " OPTIONS ( 'k1' = 'v1', 'k2' = 'v2', 'k3' = 'v3', 'k4' = 'v4', 'k5' = 'v5')" + - " TBLPROPERTIES ( 'a' = '2', 'b' = '1')" - assert(getShowDDL("SHOW CREATE TABLE ddl_test") == expected) - } - } - - protected def getShowDDL(showCreateTableSql: String): String = { - sql(showCreateTableSql).head().getString(0).split("\n").map(_.trim).mkString(" ") - } - - protected def checkCreateTable(table: String, serde: Boolean = false): Unit = { - checkCreateTableOrView(TableIdentifier(table, Some("default")), "TABLE", serde) - } - - protected def checkCreateView(table: String, serde: Boolean = false): Unit = { - checkCreateTableOrView(TableIdentifier(table, Some("default")), "VIEW", serde) - } - - protected def checkCreateTableOrView( - table: TableIdentifier, - checkType: String, - serde: Boolean): Unit = { - val db = table.database.getOrElse("default") - val expected = spark.sharedState.externalCatalog.getTable(db, table.table) - val shownDDL = if (serde) { - sql(s"SHOW CREATE TABLE ${table.quotedString} AS SERDE").head().getString(0) - } else { - sql(s"SHOW CREATE TABLE ${table.quotedString}").head().getString(0) - } - - sql(s"DROP $checkType ${table.quotedString}") - - try { - sql(shownDDL) - val actual = spark.sharedState.externalCatalog.getTable(db, table.table) - checkCatalogTables(expected, actual) - } finally { - sql(s"DROP $checkType IF EXISTS ${table.table}") - } - } - - protected def checkCatalogTables(expected: CatalogTable, actual: CatalogTable): Unit = { - assert(CatalogTable.normalize(actual) == CatalogTable.normalize(expected)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 4994968fdd6ba..3577812ac6f37 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -725,37 +725,32 @@ class BrokenColumnarAdd( lhs = left.columnarEval(batch) rhs = right.columnarEval(batch) - if (lhs == null || rhs == null) { - ret = null - } else if (lhs.isInstanceOf[ColumnVector] && rhs.isInstanceOf[ColumnVector]) { - val l = lhs.asInstanceOf[ColumnVector] - val r = rhs.asInstanceOf[ColumnVector] - val result = new OnHeapColumnVector(batch.numRows(), dataType) - ret = result - - for (i <- 0 until batch.numRows()) { - result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add - } - } else if (rhs.isInstanceOf[ColumnVector]) { - val l = lhs.asInstanceOf[Long] - val r = rhs.asInstanceOf[ColumnVector] - val result = new OnHeapColumnVector(batch.numRows(), dataType) - ret = result - - for (i <- 0 until batch.numRows()) { - result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add - } - } else if (lhs.isInstanceOf[ColumnVector]) { - val l = lhs.asInstanceOf[ColumnVector] - val r = rhs.asInstanceOf[Long] - val result = new OnHeapColumnVector(batch.numRows(), dataType) - ret = result - - for (i <- 0 until batch.numRows()) { - result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add - } - } else { - ret = nullSafeEval(lhs, rhs) + (lhs, rhs) match { + case (null, null) => + ret = null + case (l: ColumnVector, r: ColumnVector) => + val result = new OnHeapColumnVector(batch.numRows(), dataType) + ret = result + + for (i <- 0 until batch.numRows()) { + result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add + } + case (l: Long, r: ColumnVector) => + val result = new OnHeapColumnVector(batch.numRows(), dataType) + ret = result + + for (i <- 0 until batch.numRows()) { + result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add + } + case (l: ColumnVector, r: Long) => + val result = new OnHeapColumnVector(batch.numRows(), dataType) + ret = result + + for (i <- 0 until batch.numRows()) { + result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add + } + case (l, r) => + ret = nullSafeEval(l, r) } } finally { if (lhs != null && lhs.isInstanceOf[ColumnVector]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 9f8000a08f7af..c37309d97acae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -27,8 +27,9 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeTestUtils -import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneUTC +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, PST, UTC} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, TimeZoneUTC} import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -406,9 +407,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared withTable("TBL1", "TBL") { import 
org.apache.spark.sql.functions._ val df = spark.range(1000L).select('id, - 'id * 2 as "FLD1", - 'id * 12 as "FLD2", - lit("aaa") + 'id as "fld3") + Symbol("id") * 2 as "FLD1", + Symbol("id") * 12 as "FLD2", + lit(null).cast(DoubleType) + Symbol("id") as "fld3") df.write .mode(SaveMode.Overwrite) .bucketBy(10, "id", "FLD1", "FLD2") @@ -424,7 +425,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared |WHERE t1.fld3 IN (-123.23,321.23) """.stripMargin) df2.createTempView("TBL2") - sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").queryExecution.executedPlan + sql("SELECT * FROM tbl2 WHERE fld3 IN (0,1) ").queryExecution.executedPlan } } } @@ -470,7 +471,89 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } - def getStatAttrNames(tableName: String): Set[String] = { + private def checkDescTimestampColStats( + tableName: String, + timestampColumn: String, + expectedMinTimestamp: String, + expectedMaxTimestamp: String): Unit = { + + def extractColumnStatsFromDesc(statsName: String, rows: Array[Row]): String = { + rows.collect { + case r: Row if r.getString(0) == statsName => + r.getString(1) + }.head + } + + val descTsCol = sql(s"DESC FORMATTED $tableName $timestampColumn").collect() + assert(extractColumnStatsFromDesc("min", descTsCol) == expectedMinTimestamp) + assert(extractColumnStatsFromDesc("max", descTsCol) == expectedMaxTimestamp) + } + + test("SPARK-38140: describe column stats (min, max) for timestamp column: desc results should " + + "be consistent with the written value if writing and desc happen in the same time zone") { + + val zoneIdAndOffsets = + Seq((UTC, "+0000"), (PST, "-0800"), (getZoneId("Asia/Hong_Kong"), "+0800")) + + zoneIdAndOffsets.foreach { case (zoneId, offset) => + withDefaultTimeZone(zoneId) { + val table = "insert_desc_same_time_zone" + val tsCol = "timestamp_typed_col" + withTable(table) { + val minTimestamp = "make_timestamp(2022, 1, 1, 0, 0, 1.123456)" + val maxTimestamp = "make_timestamp(2022, 1, 3, 0, 0, 2.987654)" + sql(s"CREATE TABLE $table ($tsCol Timestamp) USING parquet") + sql(s"INSERT INTO $table VALUES $minTimestamp, $maxTimestamp") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR ALL COLUMNS") + + checkDescTimestampColStats( + tableName = table, + timestampColumn = tsCol, + expectedMinTimestamp = "2022-01-01 00:00:01.123456 " + offset, + expectedMaxTimestamp = "2022-01-03 00:00:02.987654 " + offset) + } + } + } + } + + test("SPARK-38140: describe column stats (min, max) for timestamp column: desc should show " + + "different results if writing in UTC and desc in other time zones") { + + val table = "insert_desc_diff_time_zones" + val tsCol = "timestamp_typed_col" + + withDefaultTimeZone(UTC) { + withTable(table) { + val minTimestamp = "make_timestamp(2022, 1, 1, 0, 0, 1.123456)" + val maxTimestamp = "make_timestamp(2022, 1, 3, 0, 0, 2.987654)" + sql(s"CREATE TABLE $table ($tsCol Timestamp) USING parquet") + sql(s"INSERT INTO $table VALUES $minTimestamp, $maxTimestamp") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR ALL COLUMNS") + + checkDescTimestampColStats( + tableName = table, + timestampColumn = tsCol, + expectedMinTimestamp = "2022-01-01 00:00:01.123456 +0000", + expectedMaxTimestamp = "2022-01-03 00:00:02.987654 +0000") + + TimeZone.setDefault(DateTimeUtils.getTimeZone("PST")) + checkDescTimestampColStats( + tableName = table, + timestampColumn = tsCol, + expectedMinTimestamp = "2021-12-31 16:00:01.123456 -0800", + expectedMaxTimestamp = "2022-01-02 
16:00:02.987654 -0800") + + TimeZone.setDefault(DateTimeUtils.getTimeZone("Asia/Hong_Kong")) + checkDescTimestampColStats( + tableName = table, + timestampColumn = tsCol, + expectedMinTimestamp = "2022-01-01 08:00:01.123456 +0800", + expectedMaxTimestamp = "2022-01-03 08:00:02.987654 +0800") + } + } + } + + private def getStatAttrNames(tableName: String): Set[String] = { val queryStats = spark.table(tableName).queryExecution.optimizedPlan.stats.attributeStats queryStats.map(_._1.name).toSet } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 30a6600c31765..2f118f236e2c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -112,9 +112,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { val df = Seq[(String, String, String, Int)](("hello", "world", null, 15)) .toDF("a", "b", "c", "d") - checkAnswer( - df.selectExpr("elt(0, a, b, c)", "elt(1, a, b, c)", "elt(4, a, b, c)"), - Row(null, "hello", null)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkAnswer( + df.selectExpr("elt(0, a, b, c)", "elt(1, a, b, c)", "elt(4, a, b, c)"), + Row(null, "hello", null)) + } // check implicit type cast checkAnswer( @@ -383,9 +385,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { Row("host", "/file;param", "query;p2", null, "http", "/file;param?query;p2", "user:pass@host", "user:pass", null)) - testUrl( - "inva lid://user:pass@host/file;param?query;p2", - Row(null, null, null, null, null, null, null, null, null)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + testUrl( + "inva lid://user:pass@host/file;param?query;p2", + Row(null, null, null, null, null, null, null, null, null)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index a376c9ce1b09b..92c373a33fb24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1956,4 +1956,66 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark assert(!nonDeterministicQueryPlan.deterministic) } + test("SPARK-38132: Not IN subquery correctness checks") { + val t = "test_table" + withTable(t) { + Seq[(Integer, Integer)]( + (1, 1), + (2, 2), + (3, 3), + (4, null), + (null, 0)) + .toDF("c1", "c2").write.saveAsTable(t) + val df = spark.table(t) + + checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t)) = true"), Seq.empty) + checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) = true"), + Row(4, null) :: Nil) + checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t)) <=> true"), Seq.empty) + checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) <=> true"), + Row(4, null) :: Nil) + checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t)) != false"), Seq.empty) + checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) != false"), + Row(4, null) :: Nil) + checkAnswer(df.where(s"NOT((c1 NOT IN (SELECT c2 FROM $t)) <=> false)"), Seq.empty) + checkAnswer(df.where(s"NOT((c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) <=> false)"), + Row(4, null) :: Nil) + } + } + + test("SPARK-38155: disallow distinct aggregate in lateral subqueries") { + withTempView("t1", "t2") { + Seq((0, 1)).toDF("c1", 
"c2").createOrReplaceTempView("t1") + Seq((1, 2), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + assert(intercept[AnalysisException] { + sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)") + }.getMessage.contains("Correlated column is not allowed in predicate")) + } + } + + test("SPARK-38180: allow safe cast expressions in correlated equality conditions") { + withTempView("t1", "t2") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((0, 2), (0, 3)).toDF("c1", "c2").createOrReplaceTempView("t2") + checkAnswer(sql( + """ + |SELECT (SELECT SUM(c2) FROM t2 WHERE c1 = a) + |FROM (SELECT CAST(c1 AS DOUBLE) a FROM t1) + |""".stripMargin), + Row(5) :: Row(null) :: Nil) + checkAnswer(sql( + """ + |SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS STRING) = a) + |FROM (SELECT CAST(c1 AS STRING) a FROM t1) + |""".stripMargin), + Row(5) :: Row(null) :: Nil) + assert(intercept[AnalysisException] { + sql( + """ + |SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a) + |FROM (SELECT CAST(c1 AS SHORT) a FROM t1) + |""".stripMargin) + }.getMessage.contains("Correlated column is not allowed in predicate")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index d100cad89fcc1..e651459394fd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -424,7 +424,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { ("N", Integer.valueOf(3), null)).toDF("a", "b", "c") val udf1 = udf((a: String, b: Int, c: Any) => a + b + c) - val df = input.select(udf1('a, 'b, 'c)) + val df = input.select(udf1(Symbol("a"), 'b, 'c)) checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null"))) // test Java UDF. Java UDF can't have primitive inputs, as it's generic typed. 
@@ -554,7 +554,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { spark.udf.register("buildLocalDateInstantType", udf((d: LocalDate, i: Instant) => LocalDateInstantType(d, i))) checkAnswer(df.selectExpr(s"buildLocalDateInstantType(d, i) as di") - .select('di.cast(StringType)), + .select(Symbol("di").cast(StringType)), Row(s"{$expectedDate, $expectedInstant}") :: Nil) // test null cases @@ -584,7 +584,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { spark.udf.register("buildTimestampInstantType", udf((t: Timestamp, i: Instant) => TimestampInstantType(t, i))) checkAnswer(df.selectExpr("buildTimestampInstantType(t, i) as ti") - .select('ti.cast(StringType)), + .select(Symbol("ti").cast(StringType)), Row(s"{$expectedTimestamp, $expectedInstant}")) // test null cases diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala index 26ec6eeb6c2dc..2c361299b173d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala @@ -191,18 +191,21 @@ class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSess } test("SPARK-36607: Support BooleanType in UnwrapCastInBinaryComparison") { - withTable(t) { - Seq(Some(true), Some(false), None).toDF().write.saveAsTable(t) - val df = spark.table(t) - - checkAnswer(df.where("value = -1"), Seq.empty) - checkAnswer(df.where("value = 0"), Row(false)) - checkAnswer(df.where("value = 1"), Row(true)) - checkAnswer(df.where("value = 2"), Seq.empty) - checkAnswer(df.where("value <=> -1"), Seq.empty) - checkAnswer(df.where("value <=> 0"), Row(false)) - checkAnswer(df.where("value <=> 1"), Row(true)) - checkAnswer(df.where("value <=> 2"), Seq.empty) + // If ANSI mode is on, Spark disallows comparing Int with Boolean. 
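The comment above and the `if (!conf.ansiEnabled)` guard that follows it are one of two gating styles this patch applies to assertions that only hold with ANSI mode off; the other style pins ANSI_ENABLED to false with withSQLConf, as in the StringFunctionsSuite hunks earlier. A minimal sketch of both styles, with an illustrative suite name that is not part of the patch:

    import org.apache.spark.sql.{QueryTest, Row}
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.test.SharedSparkSession

    // Illustrative suite; the assertions mirror hunks elsewhere in this patch.
    class AnsiGatingSketchSuite extends QueryTest with SharedSparkSession {
      import testImplicits._

      test("int/boolean comparison is only exercised with ANSI mode off") {
        // Style 1: skip the body when the session already runs in ANSI mode.
        if (!conf.ansiEnabled) {
          val df = Seq(Some(true), Some(false), None).toDF("value")
          checkAnswer(df.where("value = 1"), Row(true))
        }
      }

      test("out-of-range elt index returns null only with ANSI mode off") {
        // Style 2: pin the conf locally so the expectation is well defined.
        withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
          checkAnswer(sql("SELECT elt(4, 'a', 'b')"), Row(null))
        }
      }
    }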
+ if (!conf.ansiEnabled) { + withTable(t) { + Seq(Some(true), Some(false), None).toDF().write.saveAsTable(t) + val df = spark.table(t) + + checkAnswer(df.where("value = -1"), Seq.empty) + checkAnswer(df.where("value = 0"), Row(false)) + checkAnswer(df.where("value = 1"), Row(true)) + checkAnswer(df.where("value = 2"), Seq.empty) + checkAnswer(df.where("value <=> -1"), Seq.empty) + checkAnswer(df.where("value <=> 0"), Row(false)) + checkAnswer(df.where("value <=> 1"), Row(true)) + checkAnswer(df.where("value <=> 2"), Seq.empty) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index cc52b6d8a14a7..729312c3e5912 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -82,14 +82,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque } test("register user type: MyDenseVector for MyLabeledPoint") { - val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } + val labels: RDD[Double] = pointsRDD.select(Symbol("label")).rdd.map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) val features: RDD[TestUDT.MyDenseVector] = - pointsRDD.select('features).rdd.map { case Row(v: TestUDT.MyDenseVector) => v } + pointsRDD.select(Symbol("features")).rdd.map { case Row(v: TestUDT.MyDenseVector) => v } val featuresArrays: Array[TestUDT.MyDenseVector] = features.collect() assert(featuresArrays.size === 2) assert(featuresArrays.contains(new TestUDT.MyDenseVector(Array(0.1, 1.0)))) @@ -137,8 +137,9 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque val df = Seq((1, vec)).toDF("int", "vec") assert(vec === df.collect()(0).getAs[TestUDT.MyDenseVector](1)) assert(vec === df.take(1)(0).getAs[TestUDT.MyDenseVector](1)) - checkAnswer(df.limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) - checkAnswer(df.orderBy('int).limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) + checkAnswer(df.limit(1).groupBy(Symbol("int")).agg(first(Symbol("vec"))), Row(1, vec)) + checkAnswer(df.orderBy(Symbol("int")).limit(1).groupBy(Symbol("int")) + .agg(first(Symbol("vec"))), Row(1, vec)) } test("UDTs with JSON") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 3edc4b9502064..98d95e48f5447 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} @@ -97,7 +95,7 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): InMemoryTable = { + properties: java.util.Map[String, String]): InMemoryTable = { new InMemoryTable(name, schema, partitions, properties) } @@ -210,7 +208,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio 
verifyTable(t1, df) // Check that appends are by name - df.select('data, 'id).write.format(v2Format).mode("append").saveAsTable(t1) + df.select(Symbol("data"), Symbol("id")).write.format(v2Format).mode("append").saveAsTable(t1) verifyTable(t1, df.union(df)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index dd810a70d1585..03dcfcf7ddc7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -93,7 +93,7 @@ class DataSourceV2DataFrameSuite assert(spark.table(t1).count() === 0) // appends are by name not by position - df.select('data, 'id).write.mode("append").saveAsTable(t1) + df.select(Symbol("data"), Symbol("id")).write.mode("append").saveAsTable(t1) checkAnswer(spark.table(t1), df) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index 3277cd69a0e93..92a5c552108b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -17,25 +17,26 @@ package org.apache.spark.sql.connector -import java.util import java.util.Collections -import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen} -import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddMismatchMagic, JavaLongAddStaticMagic} +import test.org.apache.spark.sql.connector.catalog.functions._ +import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd._ +import test.org.apache.spark.sql.connector.catalog.functions.JavaRandomAdd._ import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._ import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN} import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces} import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction, _} +import org.apache.spark.sql.execution.ProjectExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { - private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String] private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = { catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn) @@ -428,6 +429,31 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { } } + test("SPARK-37957: pass deterministic flag when creating V2 function expression") { + def checkDeterministic(df: DataFrame): Unit = { + val result = df.queryExecution.executedPlan.find(_.isInstanceOf[ProjectExec]) + assert(result.isDefined, s"Expect to find ProjectExec") + 
assert(!result.get.asInstanceOf[ProjectExec].projectList.exists(_.deterministic), + "Expect expressions in projectList to be non-deterministic") + } + + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + Seq(new JavaRandomAddDefault, new JavaRandomAddMagic, + new JavaRandomAddStaticMagic).foreach { fn => + addFunction(Identifier.of(Array("ns"), "rand_add"), new JavaRandomAdd(fn)) + checkDeterministic(sql("SELECT testcat.ns.rand_add(42)")) + } + + // A function call is non-deterministic if one of its arguments is non-deterministic + Seq(new JavaLongAddDefault(true), new JavaLongAddMagic(true), + new JavaLongAddStaticMagic(true)).foreach { fn => + addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(fn)) + addFunction(Identifier.of(Array("ns"), "rand_add"), + new JavaRandomAdd(new JavaRandomAddDefault)) + checkDeterministic(sql("SELECT testcat.ns.add(10, testcat.ns.rand_add(42))")) + } + } + private case class StrLen(impl: BoundFunction) extends UnboundFunction { override def description(): String = """strlen: returns the length of the input string diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 3667a10f132ad..b64ed080d8bf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -410,9 +410,12 @@ class DataSourceV2SQLSuite test("SPARK-36850: CreateTableAsSelect partitions can be specified using " + "PARTITIONED BY and/or CLUSTERED BY") { val identifier = "testcat.table_name" + val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"), + (3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4") + df.createOrReplaceTempView("source_table") withTable(identifier) { spark.sql(s"CREATE TABLE $identifier USING foo PARTITIONED BY (id) " + - s"CLUSTERED BY (data) INTO 4 BUCKETS AS SELECT * FROM source") + s"CLUSTERED BY (data1, data2, data3, data4) INTO 4 BUCKETS AS SELECT * FROM source_table") val describe = spark.sql(s"DESCRIBE $identifier") val part1 = describe .filter("col_name = 'Part 0'") @@ -421,18 +424,22 @@ class DataSourceV2SQLSuite val part2 = describe .filter("col_name = 'Part 1'") .select("data_type").head.getString(0) - assert(part2 === "bucket(4, data)") + assert(part2 === "bucket(4, data1, data2, data3, data4)") } } test("SPARK-36850: ReplaceTableAsSelect partitions can be specified using " + "PARTITIONED BY and/or CLUSTERED BY") { val identifier = "testcat.table_name" + val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"), + (3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4") + df.createOrReplaceTempView("source_table") withTable(identifier) { spark.sql(s"CREATE TABLE $identifier USING foo " + "AS SELECT id FROM source") spark.sql(s"REPLACE TABLE $identifier USING foo PARTITIONED BY (id) " + - s"CLUSTERED BY (data) INTO 4 BUCKETS AS SELECT * FROM source") + s"CLUSTERED BY (data1, data2) SORTED by (data3, data4) INTO 4 BUCKETS " + + s"AS SELECT * FROM source_table") val describe = spark.sql(s"DESCRIBE $identifier") val part1 = describe .filter("col_name = 'Part 0'") @@ -441,7 +448,7 @@ class DataSourceV2SQLSuite val part2 = describe .filter("col_name = 'Part 1'") .select("data_type").head.getString(0) - assert(part2 === "bucket(4, data)") + assert(part2 
=== "sorted_bucket(data1, data2, 4, data3, data4)") } } @@ -1479,18 +1486,21 @@ class DataSourceV2SQLSuite test("create table using - with sorted bucket") { val identifier = "testcat.table_name" withTable(identifier) { - sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" + - s" CLUSTERED BY (b) SORTED by (a) INTO 4 BUCKETS") - val table = getTableMetadata(identifier) + sql(s"CREATE TABLE $identifier (a int, b string, c int, d int, e int, f int) USING" + + s" $v2Source PARTITIONED BY (a, b) CLUSTERED BY (c, d) SORTED by (e, f) INTO 4 BUCKETS") val describe = spark.sql(s"DESCRIBE $identifier") val part1 = describe .filter("col_name = 'Part 0'") .select("data_type").head.getString(0) - assert(part1 === "c") + assert(part1 === "a") val part2 = describe .filter("col_name = 'Part 1'") .select("data_type").head.getString(0) - assert(part2 === "bucket(4, b, a)") + assert(part2 === "b") + val part3 = describe + .filter("col_name = 'Part 2'") + .select("data_type").head.getString(0) + assert(part3 === "sorted_bucket(c, d, 4, e, f)") } } @@ -1854,109 +1864,6 @@ class DataSourceV2SQLSuite } } - test("SPARK-33898: SHOW CREATE TABLE AS SERDE") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") - val e = intercept[AnalysisException] { - sql(s"SHOW CREATE TABLE $t AS SERDE") - } - assert(e.message.contains(s"SHOW CREATE TABLE AS SERDE is not supported for v2 tables.")) - } - } - - test("SPARK-33898: SHOW CREATE TABLE") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - sql( - s""" - |CREATE TABLE $t ( - | a bigint NOT NULL, - | b bigint, - | c bigint, - | `extra col` ARRAY, - | `` STRUCT> - |) - |USING foo - |OPTIONS ( - | from = 0, - | to = 1, - | via = 2) - |COMMENT 'This is a comment' - |TBLPROPERTIES ('prop1' = '1', 'prop2' = '2', 'prop3' = 3, 'prop4' = 4) - |PARTITIONED BY (a) - |LOCATION 'file:/tmp' - """.stripMargin) - val showDDL = getShowCreateDDL(s"SHOW CREATE TABLE $t") - assert(showDDL === Array( - "CREATE TABLE testcat.ns1.ns2.tbl (", - "`a` BIGINT NOT NULL,", - "`b` BIGINT,", - "`c` BIGINT,", - "`extra col` ARRAY,", - "`` STRUCT<`x`: INT, `y`: ARRAY>)", - "USING foo", - "OPTIONS(", - "'from' = '0',", - "'to' = '1',", - "'via' = '2')", - "PARTITIONED BY (a)", - "COMMENT 'This is a comment'", - "LOCATION 'file:/tmp'", - "TBLPROPERTIES (", - "'prop1' = '1',", - "'prop2' = '2',", - "'prop3' = '3',", - "'prop4' = '4')" - )) - } - } - - test("SPARK-33898: SHOW CREATE TABLE WITH AS SELECT") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - sql( - s""" - |CREATE TABLE $t - |USING foo - |AS SELECT 1 AS a, "foo" AS b - """.stripMargin) - val showDDL = getShowCreateDDL(s"SHOW CREATE TABLE $t") - assert(showDDL === Array( - "CREATE TABLE testcat.ns1.ns2.tbl (", - "`a` INT,", - "`b` STRING)", - "USING foo" - )) - } - } - - test("SPARK-33898: SHOW CREATE TABLE PARTITIONED BY Transforms") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - sql( - s""" - |CREATE TABLE $t (a INT, b STRING, ts TIMESTAMP) USING foo - |PARTITIONED BY ( - | a, - | bucket(16, b), - | years(ts), - | months(ts), - | days(ts), - | hours(ts)) - """.stripMargin) - val showDDL = getShowCreateDDL(s"SHOW CREATE TABLE $t") - assert(showDDL === Array( - "CREATE TABLE testcat.ns1.ns2.tbl (", - "`a` INT,", - "`b` STRING,", - "`ts` TIMESTAMP)", - "USING foo", - "PARTITIONED BY (a, bucket(16, b), years(ts), months(ts), days(ts), hours(ts))" - )) - } - } - test("CACHE/UNCACHE TABLE") { val t = "testcat.ns1.ns2.tbl" withTable(t) { @@ 
-2866,8 +2773,9 @@ class DataSourceV2SQLSuite val properties = table.properties assert(properties.get(TableCatalog.PROP_PROVIDER) == "parquet") assert(properties.get(TableCatalog.PROP_COMMENT) == "This is a comment") - assert(properties.get(TableCatalog.PROP_LOCATION) == "file:/tmp") + assert(properties.get(TableCatalog.PROP_LOCATION) == "file:///tmp") assert(properties.containsKey(TableCatalog.PROP_OWNER)) + assert(properties.get(TableCatalog.PROP_EXTERNAL) == "true") assert(properties.get(s"${TableCatalog.OPTION_PREFIX}from") == "0") assert(properties.get(s"${TableCatalog.OPTION_PREFIX}to") == "1") assert(properties.get("prop1") == "1") @@ -2901,10 +2809,6 @@ class DataSourceV2SQLSuite assert(ex.getErrorClass == expectedErrorClass) assert(ex.messageParameters.sameElements(expectedErrorMessageParameters)) } - - private def getShowCreateDDL(showCreateTableSql: String): Array[String] = { - sql(showCreateTableSql).head().getString(0).split("\n").map(_.trim) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index fd3c69eff5652..8f37e42b167be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector import java.io.File -import java.util import java.util.OptionalLong import test.org.apache.spark.sql.connector._ @@ -80,8 +79,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) - checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) + checkAnswer(df.select(Symbol("j")), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter(Symbol("i") > 5), (6 until 10).map(i => Row(i, -i))) } } } @@ -92,7 +91,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - val q1 = df.select('j) + val q1 = df.select(Symbol("j")) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2]) { val batch = getBatch(q1) @@ -104,7 +103,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(batch.requiredSchema.fieldNames === Seq("j")) } - val q2 = df.filter('i > 3) + val q2 = df.filter(Symbol("i") > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { val batch = getBatch(q2) @@ -116,7 +115,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(batch.requiredSchema.fieldNames === Seq("i", "j")) } - val q3 = df.select('i).filter('i > 6) + val q3 = df.select(Symbol("i")).filter(Symbol("i") > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { val batch = getBatch(q3) @@ -128,16 +127,16 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(batch.requiredSchema.fieldNames === Seq("i")) } - val q4 = df.select('j).filter('j < -10) + val q4 = df.select(Symbol("j")).filter(Symbol("j") < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2]) { val batch = getBatch(q4) - // 'j < 10 is not supported by the testing data source. 
+ // Symbol("j") < 10 is not supported by the testing data source. assert(batch.filters.isEmpty) assert(batch.requiredSchema.fieldNames === Seq("j")) } else { val batch = getJavaBatch(q4) - // 'j < 10 is not supported by the testing data source. + // Symbol("j") < 10 is not supported by the testing data source. assert(batch.filters.isEmpty) assert(batch.requiredSchema.fieldNames === Seq("j")) } @@ -152,7 +151,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - val q1 = df.select('j) + val q1 = df.select(Symbol("j")) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { val batch = getBatchWithV2Filter(q1) @@ -164,7 +163,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(batch.requiredSchema.fieldNames === Seq("j")) } - val q2 = df.filter('i > 3) + val q2 = df.filter(Symbol("i") > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { val batch = getBatchWithV2Filter(q2) @@ -176,7 +175,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(batch.requiredSchema.fieldNames === Seq("i", "j")) } - val q3 = df.select('i).filter('i > 6) + val q3 = df.select(Symbol("i")).filter(Symbol("i") > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { val batch = getBatchWithV2Filter(q3) @@ -188,16 +187,16 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(batch.requiredSchema.fieldNames === Seq("i")) } - val q4 = df.select('j).filter('j < -10) + val q4 = df.select(Symbol("j")).filter(Symbol("j") < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { val batch = getBatchWithV2Filter(q4) - // 'j < 10 is not supported by the testing data source. + // Symbol("j") < 10 is not supported by the testing data source. assert(batch.filters.isEmpty) assert(batch.requiredSchema.fieldNames === Seq("j")) } else { val batch = getJavaBatchWithV2Filter(q4) - // 'j < 10 is not supported by the testing data source. + // Symbol("j") < 10 is not supported by the testing data source. 
assert(batch.filters.isEmpty) assert(batch.requiredSchema.fieldNames === Seq("j")) } @@ -210,8 +209,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 90).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 90).map(i => Row(-i))) - checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i))) + checkAnswer(df.select(Symbol("j")), (0 until 90).map(i => Row(-i))) + checkAnswer(df.filter(Symbol("i") > 50), (51 until 90).map(i => Row(i, -i))) } } } @@ -235,12 +234,12 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS "supports external metadata") { withTempDir { dir => val cls = classOf[SupportsExternalMetadataWritableDataSource].getName - spark.range(10).select('id as 'i, -'id as 'j).write.format(cls) - .option("path", dir.getCanonicalPath).mode("append").save() + spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls).option("path", dir.getCanonicalPath).mode("append").save() val schema = new StructType().add("i", "long").add("j", "long") checkAnswer( spark.read.format(cls).option("path", dir.getCanonicalPath).schema(schema).load(), - spark.range(10).select('id, -'id)) + spark.range(10).select(Symbol("id"), -Symbol("id"))) } } @@ -251,25 +250,25 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val df = spark.read.format(cls.getName).load() checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) - val groupByColA = df.groupBy('i).agg(sum('j)) + val groupByColA = df.groupBy(Symbol("i")).agg(sum(Symbol("j"))) checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) assert(collectFirst(groupByColA.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColAB = df.groupBy('i, 'j).agg(count("*")) + val groupByColAB = df.groupBy(Symbol("i"), Symbol("j")).agg(count("*")) checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) assert(collectFirst(groupByColAB.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColB = df.groupBy('j).agg(sum('i)) + val groupByColB = df.groupBy(Symbol("j")).agg(sum(Symbol("i"))) checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) assert(collectFirst(groupByColB.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.isDefined) - val groupByAPlusB = df.groupBy('i + 'j).agg(count("*")) + val groupByAPlusB = df.groupBy(Symbol("i") + Symbol("j")).agg(count("*")) checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) assert(collectFirst(groupByAPlusB.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e @@ -307,37 +306,43 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val path = file.getCanonicalPath assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) - spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) + spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName) .option("path", path).mode("append").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), - spark.range(10).select('id, -'id)) + spark.range(10).select(Symbol("id"), -Symbol("id"))) // default save mode is ErrorIfExists intercept[AnalysisException] { - spark.range(10).select('id as 'i, -'id as 
'j).write.format(cls.getName) + spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName) .option("path", path).save() } - spark.range(10).select('id as 'i, -'id as 'j).write.mode("append").format(cls.getName) + spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.mode("append").format(cls.getName) .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), - spark.range(10).union(spark.range(10)).select('id, -'id)) + spark.range(10).union(spark.range(10)).select(Symbol("id"), -Symbol("id"))) - spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) + spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName) .option("path", path).mode("overwrite").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), - spark.range(5).select('id, -'id)) + spark.range(5).select(Symbol("id"), -Symbol("id"))) val e = intercept[AnalysisException] { - spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) + spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName) .option("path", path).mode("ignore").save() } assert(e.message.contains("please use Append or Overwrite modes instead")) val e2 = intercept[AnalysisException] { - spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) + spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName) .option("path", path).mode("error").save() } assert(e2.getMessage.contains("please use Append or Overwrite modes instead")) @@ -354,7 +359,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } // this input data will fail to read middle way. 
- val input = spark.range(15).select(failingUdf('id).as('i)).select('i, -'i as 'j) + val input = spark.range(15).select(failingUdf(Symbol("id")).as(Symbol("i"))) + .select(Symbol("i"), -Symbol("i") as Symbol("j")) val e3 = intercept[SparkException] { input.write.format(cls.getName).option("path", path).mode("overwrite").save() } @@ -374,11 +380,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) val numPartition = 6 - spark.range(0, 10, 1, numPartition).select('id as 'i, -'id as 'j).write.format(cls.getName) + spark.range(0, 10, 1, numPartition) + .select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName) .mode("append").option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), - spark.range(10).select('id, -'id)) + spark.range(10).select(Symbol("id"), -Symbol("id"))) assert(SimpleCounter.getCounter == numPartition, "method onDataWriterCommit should be called as many as the number of partitions") @@ -395,7 +403,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS test("SPARK-23301: column pruning with arbitrary expressions") { val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() - val q1 = df.select('i + 1) + val q1 = df.select(Symbol("i") + 1) checkAnswer(q1, (1 until 11).map(i => Row(i))) val batch1 = getBatch(q1) assert(batch1.requiredSchema.fieldNames === Seq("i")) @@ -406,14 +414,14 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(batch2.requiredSchema.isEmpty) // 'j === 1 can't be pushed down, but we should still be able do column pruning - val q3 = df.filter('j === -1).select('j * 2) + val q3 = df.filter(Symbol("j") === -1).select(Symbol("j") * 2) checkAnswer(q3, Row(-2)) val batch3 = getBatch(q3) assert(batch3.filters.isEmpty) assert(batch3.requiredSchema.fieldNames === Seq("j")) // column pruning should work with other operators. 
- val q4 = df.sort('i).limit(1).select('i + 1) + val q4 = df.sort(Symbol("i")).limit(1).select(Symbol("i") + 1) checkAnswer(q4, Row(1)) val batch4 = getBatch(q4) assert(batch4.requiredSchema.fieldNames === Seq("i")) @@ -435,7 +443,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() checkCanonicalizedOutput(df, 2, 2) - checkCanonicalizedOutput(df.select('i), 2, 1) + checkCanonicalizedOutput(df.select(Symbol("i")), 2, 1) } test("SPARK-25425: extra options should override sessions options during reading") { @@ -474,7 +482,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS withTempView("t1") { val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load() Seq(2, 3).toDF("a").createTempView("t1") - val df = t2.where("i < (select max(a) from t1)").select('i) + val df = t2.where("i < (select max(a) from t1)").select(Symbol("i")) val subqueries = stripAQEPlan(df.queryExecution.executedPlan).collect { case p => p.subqueries }.flatten @@ -493,8 +501,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() - val q1 = df.select('i).filter('i > 6) - val q2 = df.select('i).filter('i > 5) + val q1 = df.select(Symbol("i")).filter(Symbol("i") > 6) + val q2 = df.select(Symbol("i")).filter(Symbol("i") > 5) val scan1 = getScanExec(q1) val scan2 = getScanExec(q2) assert(!scan1.equals(scan2)) @@ -507,7 +515,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS withClue(cls.getName) { val df = spark.read.format(cls.getName).load() // before SPARK-33267 below query just threw NPE - df.select('i).where("i in (1, null)").collect() + df.select(Symbol("i")).where("i in (1, null)").collect() } } } @@ -552,7 +560,7 @@ abstract class SimpleBatchTable extends Table with SupportsRead { override def name(): String = this.getClass.toString - override def capabilities(): util.Set[TableCapability] = util.EnumSet.of(BATCH_READ) + override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ) } abstract class SimpleScanBuilder extends ScanBuilder @@ -575,7 +583,7 @@ trait TestingV2Source extends TableProvider { override def getTable( schema: StructType, partitioning: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { getTable(new CaseInsensitiveStringMap(properties)) } @@ -792,7 +800,7 @@ class SchemaRequiredDataSource extends TableProvider { override def getTable( schema: StructType, partitioning: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val userGivenSchema = schema new SimpleBatchTable { override def schema(): StructType = userGivenSchema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala index 5156bd40bee69..cfc8b2cc84524 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala @@ -184,7 +184,7 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession { val df = 
spark.read.format(format).load(path.getCanonicalPath) checkAnswer(df, inputData.toDF()) assert( - df.queryExecution.executedPlan.find(_.isInstanceOf[FileSourceScanExec]).isDefined) + df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec])) } } finally { spark.listenerManager.unregister(listener) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala index 0dee48fbb5b92..fc98cfd5138e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala @@ -282,22 +282,6 @@ trait InsertIntoSQLOnlyTests } } - test("InsertInto: IF PARTITION NOT EXISTS not supported") { - val t1 = s"${catalogAndNamespace}tbl" - withTableAndData(t1) { view => - sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") - - val exc = intercept[AnalysisException] { - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 1) IF NOT EXISTS SELECT * FROM $view") - } - - verifyTable(t1, spark.emptyDataFrame) - assert(exc.getMessage.contains("Cannot write, IF NOT EXISTS is not supported for table")) - assert(exc.getMessage.contains(t1)) - assert(exc.getErrorClass == "IF_PARTITION_NOT_EXISTS_UNSUPPORTED") - } - } - test("InsertInto: overwrite - dynamic clause - static mode") { withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { val t1 = s"${catalogAndNamespace}tbl" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala index 094667001b6c3..e3d61a846fdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} @@ -61,7 +59,7 @@ class TestLocalScanCatalog extends BasicInMemoryTableCatalog { ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val table = new TestLocalScanTable(ident.toString) tables.put(ident, table) table @@ -76,8 +74,8 @@ object TestLocalScanTable { class TestLocalScanTable(override val name: String) extends Table with SupportsRead { override def schema(): StructType = TestLocalScanTable.schema - override def capabilities(): util.Set[TableCapability] = - util.EnumSet.of(TableCapability.BATCH_READ) + override def capabilities(): java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new TestLocalScanBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index 99c322a7155f2..64c893ed74fdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector import 
java.io.{BufferedReader, InputStreamReader, IOException} -import java.util import scala.collection.JavaConverters._ @@ -138,8 +137,8 @@ class SimpleWritableDataSource extends TestingV2Source { new MyWriteBuilder(path, info) } - override def capabilities(): util.Set[TableCapability] = - util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE) + override def capabilities(): java.util.Set[TableCapability] = + java.util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE) } override def getTable(options: CaseInsensitiveStringMap): Table = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 9cb524c2c3822..473f679b4b99d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -75,7 +75,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with saveMode: SaveMode, withCatalogOption: Option[String], partitionBy: Seq[String]): Unit = { - val df = spark.range(10).withColumn("part", 'id % 5) + val df = spark.range(10).withColumn("part", Symbol("id") % 5) val dfw = df.write.format(format).mode(saveMode).option("name", "t1") withCatalogOption.foreach(cName => dfw.option("catalog", cName)) dfw.partitionBy(partitionBy: _*).save() @@ -140,7 +140,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with test("Ignore mode if table exists - session catalog") { sql(s"create table t1 (id bigint) using $format") - val df = spark.range(10).withColumn("part", 'id % 5) + val df = spark.range(10).withColumn("part", Symbol("id") % 5) val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") dfw.save() @@ -152,7 +152,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with test("Ignore mode if table exists - testcat catalog") { sql(s"create table $catalogName.t1 (id bigint) using $format") - val df = spark.range(10).withColumn("part", 'id % 5) + val df = spark.range(10).withColumn("part", Symbol("id") % 5) val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") dfw.option("catalog", catalogName).save() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index a12065ec0ab2a..5f2e0b28aeccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} @@ -215,8 +213,8 @@ private case object TestRelation extends LeafNode with NamedRelation { private case class CapabilityTable(_capabilities: TableCapability*) extends Table { override def name(): String = "capability_test_table" override def schema(): StructType = TableCapabilityCheckSuite.schema - override def capabilities(): util.Set[TableCapability] = { - val set = util.EnumSet.noneOf(classOf[TableCapability]) + override def capabilities(): java.util.Set[TableCapability] = { + val set = 
java.util.EnumSet.noneOf(classOf[TableCapability]) _capabilities.foreach(set.add) set } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala index bf2749d1afc53..0a0aaa8021996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connector -import java.util import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicBoolean @@ -35,7 +34,7 @@ import org.apache.spark.sql.types.StructType */ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends DelegatingCatalogExtension { - protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() + protected val tables: java.util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() private val tableCreated: AtomicBoolean = new AtomicBoolean(false) @@ -48,7 +47,7 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): T + properties: java.util.Map[String, String]): T override def loadTable(ident: Identifier): Table = { if (tables.containsKey(ident)) { @@ -69,12 +68,12 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val key = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY val propsWithLocation = if (properties.containsKey(key)) { // Always set a location so that CREATE EXTERNAL TABLE won't fail with LOCATION not specified. if (!properties.containsKey(TableCatalog.PROP_LOCATION)) { - val newProps = new util.HashMap[String, String]() + val newProps = new java.util.HashMap[String, String]() newProps.putAll(properties) newProps.put(TableCatalog.PROP_LOCATION, "file:/abc") newProps diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala index ff1bd29808637..c5be222645b19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext} import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} @@ -104,7 +102,7 @@ class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog { ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { // To simplify the test implementation, only support fixed schema. 
if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) { throw new UnsupportedOperationException @@ -129,8 +127,8 @@ class TableWithV1ReadFallback(override val name: String) extends Table with Supp override def schema(): StructType = V1ReadFallbackCatalog.schema - override def capabilities(): util.Set[TableCapability] = { - util.EnumSet.of(TableCapability.BATCH_READ) + override def capabilities(): java.util.Set[TableCapability] = { + java.util.EnumSet.of(TableCapability.BATCH_READ) } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 9fbaf7890f8f8..992c46cc6cdb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import scala.collection.JavaConverters._ import scala.collection.mutable @@ -223,7 +221,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): InMemoryTableWithV1Fallback = { + properties: java.util.Map[String, String]): InMemoryTableWithV1Fallback = { val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties) InMemoryV1Provider.tables.put(name, t) tables.put(Identifier.of(Array("default"), name), t) @@ -321,7 +319,7 @@ class InMemoryTableWithV1Fallback( override val name: String, override val schema: StructType, override val partitioning: Array[Transform], - override val properties: util.Map[String, String]) + override val properties: java.util.Map[String, String]) extends Table with SupportsWrite with SupportsRead { @@ -331,7 +329,7 @@ class InMemoryTableWithV1Fallback( } } - override def capabilities: util.Set[TableCapability] = util.EnumSet.of( + override def capabilities: java.util.Set[TableCapability] = java.util.EnumSet.of( TableCapability.BATCH_READ, TableCapability.V1_BATCH_WRITE, TableCapability.OVERWRITE_BY_FILTER, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index 15a25c2680722..fcb25751db8d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -46,7 +46,7 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { Seq("ID", "iD").foreach { ref => - val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, None, Map.empty, None, None, None, false) val plan = CreateTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), @@ -70,7 +70,7 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => - val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, 
None, Map.empty, None, None, None, false) val plan = CreateTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), @@ -95,7 +95,7 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { Seq("ID", "iD").foreach { ref => - val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, None, Map.empty, None, None, None, false) val plan = ReplaceTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), @@ -119,7 +119,7 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => - val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + val tableSpec = TableSpec(Map.empty, None, Map.empty, None, None, None, false) val plan = ReplaceTableAsSelect( UnresolvedDBObjectName(Array("table_name"), isNamespace = false), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index db4a9c153c0ff..5f8684a144778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -21,7 +21,7 @@ import java.util.Collections import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, QueryTest} +import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning} @@ -33,7 +33,10 @@ import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.apache.spark.sql.util.QueryExecutionListener @@ -42,6 +45,7 @@ class WriteDistributionAndOrderingSuite extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + import testImplicits._ before { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) @@ -52,6 +56,7 @@ class WriteDistributionAndOrderingSuite spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat") } + private val microBatchPrefix = "micro_batch_" private val namespace = Array("ns1") private val ident = Identifier.of(namespace, "test_table") private val tableNameAsString = "testcat." 
+ ident.toString @@ -74,6 +79,18 @@ class WriteDistributionAndOrderingSuite checkOrderedDistributionAndSortWithSameExprs("overwriteDynamic") } + test("ordered distribution and sort with same exprs: micro-batch append") { + checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "append") + } + + test("ordered distribution and sort with same exprs: micro-batch update") { + checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "update") + } + + test("ordered distribution and sort with same exprs: micro-batch complete") { + checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "complete") + } + test("ordered distribution and sort with same exprs with numPartitions: append") { checkOrderedDistributionAndSortWithSameExprs("append", Some(10)) } @@ -86,6 +103,18 @@ class WriteDistributionAndOrderingSuite checkOrderedDistributionAndSortWithSameExprs("overwriteDynamic", Some(10)) } + test("ordered distribution and sort with same exprs with numPartitions: micro-batch append") { + checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "append", Some(10)) + } + + test("ordered distribution and sort with same exprs with numPartitions: micro-batch update") { + checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "update", Some(10)) + } + + test("ordered distribution and sort with same exprs with numPartitions: micro-batch complete") { + checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "complete", Some(10)) + } + private def checkOrderedDistributionAndSortWithSameExprs(command: String): Unit = { checkOrderedDistributionAndSortWithSameExprs(command, None) } @@ -129,6 +158,18 @@ class WriteDistributionAndOrderingSuite checkClusteredDistributionAndSortWithSameExprs("overwriteDynamic") } + test("clustered distribution and sort with same exprs: micro-batch append") { + checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "append") + } + + test("clustered distribution and sort with same exprs: micro-batch update") { + checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "update") + } + + test("clustered distribution and sort with same exprs: micro-batch complete") { + checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "complete") + } + test("clustered distribution and sort with same exprs with numPartitions: append") { checkClusteredDistributionAndSortWithSameExprs("append", Some(10)) } @@ -141,6 +182,18 @@ class WriteDistributionAndOrderingSuite checkClusteredDistributionAndSortWithSameExprs("overwriteDynamic", Some(10)) } + test("clustered distribution and sort with same exprs with numPartitions: micro-batch append") { + checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "append", Some(10)) + } + + test("clustered distribution and sort with same exprs with numPartitions: micro-batch update") { + checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "update", Some(10)) + } + + test("clustered distribution and sort with same exprs with numPartitions: micro-batch complete") { + checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "complete", Some(10)) + } + private def checkClusteredDistributionAndSortWithSameExprs(command: String): Unit = { checkClusteredDistributionAndSortWithSameExprs(command, None) } @@ -193,6 +246,18 @@ class WriteDistributionAndOrderingSuite checkClusteredDistributionAndSortWithExtendedExprs("overwriteDynamic") } + test("clustered distribution and sort with extended exprs: micro-batch append") { + 
checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "append") + } + + test("clustered distribution and sort with extended exprs: micro-batch update") { + checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "update") + } + + test("clustered distribution and sort with extended exprs: micro-batch complete") { + checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "complete") + } + test("clustered distribution and sort with extended exprs with numPartitions: append") { checkClusteredDistributionAndSortWithExtendedExprs("append", Some(10)) } @@ -206,6 +271,21 @@ class WriteDistributionAndOrderingSuite checkClusteredDistributionAndSortWithExtendedExprs("overwriteDynamic", Some(10)) } + test("clustered distribution and sort with extended exprs with numPartitions: " + + "micro-batch append") { + checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "append", Some(10)) + } + + test("clustered distribution and sort with extended exprs with numPartitions: " + + "micro-batch update") { + checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "update", Some(10)) + } + + test("clustered distribution and sort with extended exprs with numPartitions: " + + "micro-batch complete") { + checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "complete", Some(10)) + } + private def checkClusteredDistributionAndSortWithExtendedExprs(command: String): Unit = { checkClusteredDistributionAndSortWithExtendedExprs(command, None) } @@ -258,6 +338,18 @@ class WriteDistributionAndOrderingSuite checkUnspecifiedDistributionAndLocalSort("overwriteDynamic") } + test("unspecified distribution and local sort: micro-batch append") { + checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "append") + } + + test("unspecified distribution and local sort: micro-batch update") { + checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "update") + } + + test("unspecified distribution and local sort: micro-batch complete") { + checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "complete") + } + test("unspecified distribution and local sort with numPartitions: append") { checkUnspecifiedDistributionAndLocalSort("append", Some(10)) } @@ -270,6 +362,18 @@ class WriteDistributionAndOrderingSuite checkUnspecifiedDistributionAndLocalSort("overwriteDynamic", Some(10)) } + test("unspecified distribution and local sort with numPartitions: micro-batch append") { + checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "append", Some(10)) + } + + test("unspecified distribution and local sort with numPartitions: micro-batch update") { + checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "update", Some(10)) + } + + test("unspecified distribution and local sort with numPartitions: micro-batch complete") { + checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "complete", Some(10)) + } + private def checkUnspecifiedDistributionAndLocalSort(command: String): Unit = { checkUnspecifiedDistributionAndLocalSort(command, None) } @@ -316,6 +420,18 @@ class WriteDistributionAndOrderingSuite checkUnspecifiedDistributionAndNoSort("overwriteDynamic") } + test("unspecified distribution and no sort: micro-batch append") { + checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "append") + } + + test("unspecified distribution and no sort: micro-batch update") { + checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "update") + } + + test("unspecified distribution and no sort: micro-batch complete") { + 
checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "complete") + } + test("unspecified distribution and no sort with numPartitions: append") { checkUnspecifiedDistributionAndNoSort("append", Some(10)) } @@ -328,6 +444,18 @@ class WriteDistributionAndOrderingSuite checkUnspecifiedDistributionAndNoSort("overwriteDynamic", Some(10)) } + test("unspecified distribution and no sort with numPartitions: micro-batch append") { + checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "append", Some(10)) + } + + test("unspecified distribution and no sort with numPartitions: micro-batch update") { + checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "update", Some(10)) + } + + test("unspecified distribution and no sort with numPartitions: micro-batch complete") { + checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "complete", Some(10)) + } + private def checkUnspecifiedDistributionAndNoSort(command: String): Unit = { checkUnspecifiedDistributionAndNoSort(command, None) } @@ -677,7 +805,95 @@ class WriteDistributionAndOrderingSuite writeCommand = command) } + test("continuous mode does not support write distribution and ordering") { + val ordering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val distribution = Distributions.ordered(ordering) + + catalog.createTable(ident, schema, Array.empty, emptyProps, distribution, ordering, None) + + withTempDir { checkpointDir => + val inputData = ContinuousMemoryStream[(Long, String)] + val inputDF = inputData.toDF().toDF("id", "data") + + val writer = inputDF + .writeStream + .trigger(Trigger.Continuous(100)) + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .outputMode("append") + + val analysisException = intercept[AnalysisException] { + val query = writer.toTable(tableNameAsString) + + inputData.addData((1, "a"), (2, "b")) + + query.processAllAvailable() + query.stop() + } + + assert(analysisException.message.contains("Sinks cannot request distribution and ordering")) + } + } + + test("continuous mode allows unspecified distribution and empty ordering") { + catalog.createTable(ident, schema, Array.empty, emptyProps) + + withTempDir { checkpointDir => + val inputData = ContinuousMemoryStream[(Long, String)] + val inputDF = inputData.toDF().toDF("id", "data") + + val writer = inputDF + .writeStream + .trigger(Trigger.Continuous(100)) + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .outputMode("append") + + val query = writer.toTable(tableNameAsString) + + inputData.addData((1, "a"), (2, "b")) + + query.processAllAvailable() + query.stop() + + checkAnswer(spark.table(tableNameAsString), Row(1, "a") :: Row(2, "b") :: Nil) + } + } + private def checkWriteRequirements( + tableDistribution: Distribution, + tableOrdering: Array[SortOrder], + tableNumPartitions: Option[Int], + expectedWritePartitioning: physical.Partitioning, + expectedWriteOrdering: Seq[catalyst.expressions.SortOrder], + writeTransform: DataFrame => DataFrame = df => df, + writeCommand: String, + expectAnalysisException: Boolean = false): Unit = { + + if (writeCommand.startsWith(microBatchPrefix)) { + checkMicroBatchWriteRequirements( + tableDistribution, + tableOrdering, + tableNumPartitions, + expectedWritePartitioning, + expectedWriteOrdering, + writeTransform, + outputMode = writeCommand.stripPrefix(microBatchPrefix), + expectAnalysisException) + } else { + checkBatchWriteRequirements( + tableDistribution, + tableOrdering, + tableNumPartitions, + expectedWritePartitioning, + 
expectedWriteOrdering,
+        writeTransform,
+        writeCommand,
+        expectAnalysisException)
+    }
+  }
+
+  private def checkBatchWriteRequirements(
       tableDistribution: Distribution,
       tableOrdering: Array[SortOrder],
       tableNumPartitions: Option[Int],
@@ -712,15 +928,84 @@ class WriteDistributionAndOrderingSuite
     }
   }
 
+  private def checkMicroBatchWriteRequirements(
+      tableDistribution: Distribution,
+      tableOrdering: Array[SortOrder],
+      tableNumPartitions: Option[Int],
+      expectedWritePartitioning: physical.Partitioning,
+      expectedWriteOrdering: Seq[catalyst.expressions.SortOrder],
+      writeTransform: DataFrame => DataFrame = df => df,
+      outputMode: String = "append",
+      expectAnalysisException: Boolean = false): Unit = {
+
+    catalog.createTable(ident, schema, Array.empty, emptyProps, tableDistribution,
+      tableOrdering, tableNumPartitions)
+
+    withTempDir { checkpointDir =>
+      val inputData = MemoryStream[(Long, String)]
+      val inputDF = inputData.toDF().toDF("id", "data")
+
+      val queryDF = outputMode match {
+        case "append" | "update" =>
+          inputDF
+        case "complete" =>
+          // add an aggregate for complete mode
+          inputDF
+            .groupBy("id")
+            .agg(Map("data" -> "count"))
+            .select($"id", $"count(data)".cast("string").as("data"))
+      }
+
+      val writer = writeTransform(queryDF)
+        .writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .outputMode(outputMode)
+
+      def executeCommand(): SparkPlan = execute {
+        val query = writer.toTable(tableNameAsString)
+
+        inputData.addData((1, "a"), (2, "b"))
+
+        query.processAllAvailable()
+        query.stop()
+      }
+
+      if (expectAnalysisException) {
+        val streamingQueryException = intercept[StreamingQueryException] {
+          executeCommand()
+        }
+        val cause = streamingQueryException.cause
+        assert(cause.getMessage.contains("number of partitions can't be specified"))
+
+      } else {
+        val executedPlan = executeCommand()
+
+        checkPartitioningAndOrdering(
+          executedPlan,
+          expectedWritePartitioning,
+          expectedWriteOrdering,
+          // there is an extra shuffle for groupBy in complete mode
+          maxNumShuffles = if (outputMode != "complete") 1 else 2)
+
+        val expectedRows = outputMode match {
+          case "append" | "update" => Row(1, "a") :: Row(2, "b") :: Nil
+          case "complete" => Row(1, "1") :: Row(2, "1") :: Nil
+        }
+        checkAnswer(spark.table(tableNameAsString), expectedRows)
+      }
+    }
+  }
+
   private def checkPartitioningAndOrdering(
       plan: SparkPlan,
       partitioning: physical.Partitioning,
-      ordering: Seq[catalyst.expressions.SortOrder]): Unit = {
+      ordering: Seq[catalyst.expressions.SortOrder],
+      maxNumShuffles: Int = 1): Unit = {
     val sorts = collect(plan) { case s: SortExec => s }
     assert(sorts.size <= 1, "must be at most one sort")
     val shuffles = collect(plan) { case s: ShuffleExchangeLike => s }
-    assert(shuffles.size <= 1, "must be at most one shuffle")
+    assert(shuffles.size <= maxNumShuffles, s"must be at most $maxNumShuffles shuffles")
     val actualPartitioning = plan.outputPartitioning
     val expectedPartitioning = partitioning match {
@@ -730,6 +1015,9 @@ class WriteDistributionAndOrderingSuite
       case p: physical.HashPartitioning =>
         val resolvedExprs = p.expressions.map(resolveAttrs(_, plan))
         p.copy(expressions = resolvedExprs)
+      case _: UnknownPartitioning =>
+        // don't check partitioning if no particular one is expected
+        actualPartitioning
       case other => other
     }
     assert(actualPartitioning == expectedPartitioning, "partitioning must match")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala new file mode 100644 index 0000000000000..bfea3f535dd94 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.errors + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2Provider} +import org.apache.spark.sql.test.SharedSparkSession + +class QueryCompilationErrorsDSv2Suite + extends QueryTest + with SharedSparkSession + with DatasourceV2SQLBase { + + test("UNSUPPORTED_FEATURE: IF PARTITION NOT EXISTS not supported by INSERT") { + val v2Format = classOf[FakeV2Provider].getName + val tbl = "testcat.ns1.ns2.tbl" + + withTable(tbl) { + val view = "tmp_view" + val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") + df.createOrReplaceTempView(view) + withTempView(view) { + sql(s"CREATE TABLE $tbl (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + + val e = intercept[AnalysisException] { + sql(s"INSERT OVERWRITE TABLE $tbl PARTITION (id = 1) IF NOT EXISTS SELECT * FROM $view") + } + + checkAnswer(spark.table(tbl), spark.emptyDataFrame) + assert(e.getMessage === "The feature is not supported: " + + s"IF NOT EXISTS for the table '$tbl' by INSERT INTO.") + assert(e.getErrorClass === "UNSUPPORTED_FEATURE") + assert(e.getSqlState === "0A000") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala new file mode 100644 index 0000000000000..d5cbfc844ccdd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.errors + +import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest} +import org.apache.spark.sql.functions.{grouping, grouping_id, sum} +import org.apache.spark.sql.test.SharedSparkSession + +case class StringLongClass(a: String, b: Long) + +case class StringIntClass(a: String, b: Int) + +case class ComplexClass(a: Long, b: StringLongClass) + +class QueryCompilationErrorsSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("CANNOT_UP_CAST_DATATYPE: invalid upcast data type") { + val msg1 = intercept[AnalysisException] { + sql("select 'value1' as a, 1L as b").as[StringIntClass] + }.message + assert(msg1 === + s""" + |Cannot up cast b from bigint to int. + |The type path of the target object is: + |- field (class: "scala.Int", name: "b") + |- root class: "org.apache.spark.sql.errors.StringIntClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + + val msg2 = intercept[AnalysisException] { + sql("select 1L as a," + + " named_struct('a', 'value1', 'b', cast(1.0 as decimal(38,18))) as b") + .as[ComplexClass] + }.message + assert(msg2 === + s""" + |Cannot up cast b.`b` from decimal(38,18) to bigint. + |The type path of the target object is: + |- field (class: "scala.Long", name: "b") + |- field (class: "org.apache.spark.sql.errors.StringLongClass", name: "b") + |- root class: "org.apache.spark.sql.errors.ComplexClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + } + + test("UNSUPPORTED_GROUPING_EXPRESSION: filter with grouping/grouping_Id expression") { + val df = Seq( + (536361, "85123A", 2, 17850), + (536362, "85123B", 4, 17850), + (536363, "86123A", 6, 17851) + ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") + Seq("grouping", "grouping_id").foreach { grouping => + val errMsg = intercept[AnalysisException] { + df.groupBy("CustomerId").agg(Map("Quantity" -> "max")) + .filter(s"$grouping(CustomerId)=17850") + } + assert(errMsg.message === + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + assert(errMsg.errorClass === Some("UNSUPPORTED_GROUPING_EXPRESSION")) + } + } + + test("UNSUPPORTED_GROUPING_EXPRESSION: Sort with grouping/grouping_Id expression") { + val df = Seq( + (536361, "85123A", 2, 17850), + (536362, "85123B", 4, 17850), + (536363, "86123A", 6, 17851) + ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") + Seq(grouping("CustomerId"), grouping_id("CustomerId")).foreach { grouping => + val errMsg = intercept[AnalysisException] { + df.groupBy("CustomerId").agg(Map("Quantity" -> "max")). 
+ sort(grouping) + } + assert(errMsg.errorClass === Some("UNSUPPORTED_GROUPING_EXPRESSION")) + assert(errMsg.message === + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } + } + + test("ILLEGAL_SUBSTRING: the argument_index of string format is invalid") { + val e = intercept[AnalysisException] { + sql("select format_string('%0$s', 'Hello')") + } + assert(e.errorClass === Some("ILLEGAL_SUBSTRING")) + assert(e.message === + "The argument_index of string format cannot contain position 0$.") + } + + test("CANNOT_USE_MIXTURE: Using aggregate function with grouped aggregate pandas UDF") { + import IntegratedUDFTestUtils._ + + val df = Seq( + (536361, "85123A", 2, 17850), + (536362, "85123B", 4, 17850), + (536363, "86123A", 6, 17851) + ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") + val e = intercept[AnalysisException] { + val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf") + df.groupBy("CustomerId") + .agg(pandasTestUDF(df("Quantity")), sum(df("Quantity"))).collect() + } + + assert(e.errorClass === Some("CANNOT_USE_MIXTURE")) + assert(e.message === + "Cannot use a mixture of aggregate function and group aggregate pandas UDF") + } + + test("UNSUPPORTED_FEATURE: Using Python UDF with unsupported join condition") { + import IntegratedUDFTestUtils._ + + val df1 = Seq( + (536361, "85123A", 2, 17850), + (536362, "85123B", 4, 17850), + (536363, "86123A", 6, 17851) + ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") + val df2 = Seq( + ("Bob", 17850), + ("Alice", 17850), + ("Tom", 17851) + ).toDF("CustomerName", "CustomerID") + + val e = intercept[AnalysisException] { + val pythonTestUDF = TestPythonUDF(name = "python_udf") + df1.join( + df2, pythonTestUDF(df1("CustomerID") === df2("CustomerID")), "leftouter").collect() + } + + assert(e.errorClass === Some("UNSUPPORTED_FEATURE")) + assert(e.getSqlState === "0A000") + assert(e.message === + "The feature is not supported: " + + "Using PythonUDF in join condition of join type LeftOuter is not supported") + } + + test("UNSUPPORTED_FEATURE: Using pandas UDF aggregate expression with pivot") { + import IntegratedUDFTestUtils._ + + val df = Seq( + (536361, "85123A", 2, 17850), + (536362, "85123B", 4, 17850), + (536363, "86123A", 6, 17851) + ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") + + val e = intercept[AnalysisException] { + val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf") + df.groupBy(df("CustomerID")).pivot(df("CustomerID")).agg(pandasTestUDF(df("Quantity"))) + } + + assert(e.errorClass === Some("UNSUPPORTED_FEATURE")) + assert(e.getSqlState === "0A000") + assert(e.message === + "The feature is not supported: " + + "Pandas UDF aggregate expressions don't support pivot.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index 13f44a21499d2..9268be43ba490 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -17,11 +17,20 @@ package org.apache.spark.sql.errors -import org.apache.spark.{SparkException, SparkRuntimeException} +import org.apache.spark.{SparkArithmeticException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.execution.datasources.orc.OrcTest +import 
org.apache.spark.sql.execution.datasources.parquet.ParquetTest +import org.apache.spark.sql.functions.{lit, lower, struct, sum} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy.EXCEPTION import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{StructType, TimestampType} +import org.apache.spark.sql.util.ArrowUtils + +class QueryExecutionErrorsSuite extends QueryTest + with ParquetTest with OrcTest with SharedSparkSession { -class QueryExecutionErrorsSuite extends QueryTest with SharedSparkSession { import testImplicits._ private def getAesInputs(): (DataFrame, DataFrame) = { @@ -41,16 +50,17 @@ class QueryExecutionErrorsSuite extends QueryTest with SharedSparkSession { (df1, df2) } - test("INVALID_AES_KEY_LENGTH: invalid key lengths in AES functions") { + test("INVALID_PARAMETER_VALUE: invalid key lengths in AES functions") { val (df1, df2) = getAesInputs() def checkInvalidKeyLength(df: => DataFrame): Unit = { val e = intercept[SparkException] { df.collect }.getCause.asInstanceOf[SparkRuntimeException] - assert(e.getErrorClass === "INVALID_AES_KEY_LENGTH") - assert(e.getSqlState === "42000") - assert(e.getMessage.contains( - "The key length of aes_encrypt/aes_decrypt should be one of 16, 24 or 32 bytes")) + assert(e.getErrorClass === "INVALID_PARAMETER_VALUE") + assert(e.getSqlState === "22023") + assert(e.getMessage.matches( + "The value of parameter\\(s\\) 'key' in the aes_encrypt/aes_decrypt function is invalid: " + + "expects a binary value with 16, 24 or 32 bytes, but got \\d+ bytes.")) } // Encryption failure - invalid key length @@ -71,7 +81,26 @@ class QueryExecutionErrorsSuite extends QueryTest with SharedSparkSession { } } - test("UNSUPPORTED_AES_MODE: unsupported combinations of AES modes and padding") { + test("INVALID_PARAMETER_VALUE: AES decrypt failure - key mismatch") { + val (_, df2) = getAesInputs() + Seq( + ("value16", "1234567812345678"), + ("value24", "123456781234567812345678"), + ("value32", "12345678123456781234567812345678")).foreach { case (colName, key) => + val e = intercept[SparkException] { + df2.selectExpr(s"aes_decrypt(unbase64($colName), binary('$key'), 'ECB')").collect + }.getCause.asInstanceOf[SparkRuntimeException] + assert(e.getErrorClass === "INVALID_PARAMETER_VALUE") + assert(e.getSqlState === "22023") + assert(e.getMessage === + "The value of parameter(s) 'expr, key' in the aes_encrypt/aes_decrypt function " + + "is invalid: Detail message: " + + "Given final block not properly padded. 
" + + "Such issues can arise if a bad key is used during decryption.") + } + } + + test("UNSUPPORTED_FEATURE: unsupported combinations of AES modes and padding") { val key16 = "abcdefghijklmnop" val key32 = "abcdefghijklmnop12345678ABCDEFGH" val (df1, df2) = getAesInputs() @@ -79,9 +108,10 @@ class QueryExecutionErrorsSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkException] { df.collect }.getCause.asInstanceOf[SparkRuntimeException] - assert(e.getErrorClass === "UNSUPPORTED_AES_MODE") + assert(e.getErrorClass === "UNSUPPORTED_FEATURE") assert(e.getSqlState === "0A000") - assert(e.getMessage.matches("""The AES mode \w+ with the padding \w+ is not supported""")) + assert(e.getMessage.matches("""The feature is not supported: AES-\w+ with the padding \w+""" + + " by the aes_encrypt/aes_decrypt function.")) } // Unsupported AES mode and padding in encrypt @@ -94,18 +124,158 @@ class QueryExecutionErrorsSuite extends QueryTest with SharedSparkSession { checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 'ECB', 'None')")) } - test("AES_CRYPTO_ERROR: AES decrypt failure - key mismatch") { - val (_, df2) = getAesInputs() - Seq( - ("value16", "1234567812345678"), - ("value24", "123456781234567812345678"), - ("value32", "12345678123456781234567812345678")).foreach { case (colName, key) => + test("UNSUPPORTED_FEATURE: unsupported types (map and struct) in lit()") { + def checkUnsupportedTypeInLiteral(v: Any): Unit = { + val e1 = intercept[SparkRuntimeException] { lit(v) } + assert(e1.getErrorClass === "UNSUPPORTED_FEATURE") + assert(e1.getSqlState === "0A000") + assert(e1.getMessage.matches("""The feature is not supported: literal for '.+' of .+\.""")) + } + checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) + checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) + + val e2 = intercept[SparkRuntimeException] { + trainingSales + .groupBy($"sales.year") + .pivot(struct(lower(trainingSales("sales.course")), trainingSales("training"))) + .agg(sum($"sales.earnings")) + .collect() + } + assert(e2.getMessage === "The feature is not supported: pivoting by the value" + + """ '[dotnet,Dummies]' of the column data type 'struct'.""") + } + + test("UNSUPPORTED_FEATURE: unsupported pivot operations") { + val e1 = intercept[SparkUnsupportedOperationException] { + trainingSales + .groupBy($"sales.year") + .pivot($"sales.course") + .pivot($"training") + .agg(sum($"sales.earnings")) + .collect() + } + assert(e1.getErrorClass === "UNSUPPORTED_FEATURE") + assert(e1.getSqlState === "0A000") + assert(e1.getMessage === "The feature is not supported: Repeated pivots.") + + val e2 = intercept[SparkUnsupportedOperationException] { + trainingSales + .rollup($"sales.year") + .pivot($"training") + .agg(sum($"sales.earnings")) + .collect() + } + assert(e2.getErrorClass === "UNSUPPORTED_FEATURE") + assert(e2.getSqlState === "0A000") + assert(e2.getMessage === "The feature is not supported: Pivot not after a groupBy.") + } + + test("INCONSISTENT_BEHAVIOR_CROSS_VERSION: " + + "compatibility with Spark 2.4/3.2 in reading/writing dates") { + + // Fail to read ancient datetime values. 
+ withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_READ.key -> EXCEPTION.toString) { + val fileName = "before_1582_date_v2_4_5.snappy.parquet" + val filePath = getResourceParquetFilePath("test-data/" + fileName) val e = intercept[SparkException] { - df2.selectExpr(s"aes_decrypt(unbase64($colName), binary('$key'), 'ECB')").collect - }.getCause.asInstanceOf[SparkRuntimeException] - assert(e.getErrorClass === "AES_CRYPTO_ERROR") - assert(e.getSqlState === null) - assert(e.getMessage.contains("AES crypto operation failed")) + spark.read.parquet(filePath).collect() + }.getCause.asInstanceOf[SparkUpgradeException] + + val format = "Parquet" + val config = SQLConf.PARQUET_REBASE_MODE_IN_READ.key + val option = "datetimeRebaseMode" + assert(e.getErrorClass === "INCONSISTENT_BEHAVIOR_CROSS_VERSION") + assert(e.getMessage === + "You may get a different result due to the upgrading to Spark >= 3.0: " + + s""" + |reading dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z + |from $format files can be ambiguous, as the files may be written by + |Spark 2.x or legacy versions of Hive, which uses a legacy hybrid calendar + |that is different from Spark 3.0+'s Proleptic Gregorian calendar. + |See more details in SPARK-31404. You can set the SQL config '$config' or + |the datasource option '$option' to 'LEGACY' to rebase the datetime values + |w.r.t. the calendar difference during reading. To read the datetime values + |as it is, set the SQL config '$config' or the datasource option '$option' + |to 'CORRECTED'. + |""".stripMargin) + } + + // Fail to write ancient datetime values. + withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> EXCEPTION.toString) { + withTempPath { dir => + val df = Seq(java.sql.Date.valueOf("1001-01-01")).toDF("dt") + val e = intercept[SparkException] { + df.write.parquet(dir.getCanonicalPath) + }.getCause.getCause.getCause.asInstanceOf[SparkUpgradeException] + + val format = "Parquet" + val config = SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key + assert(e.getErrorClass === "INCONSISTENT_BEHAVIOR_CROSS_VERSION") + assert(e.getMessage === + "You may get a different result due to the upgrading to Spark >= 3.0: " + + s""" + |writing dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z + |into $format files can be dangerous, as the files may be read by Spark 2.x + |or legacy versions of Hive later, which uses a legacy hybrid calendar that + |is different from Spark 3.0+'s Proleptic Gregorian calendar. See more + |details in SPARK-31404. You can set $config to 'LEGACY' to rebase the + |datetime values w.r.t. the calendar difference during writing, to get maximum + |interoperability. Or set $config to 'CORRECTED' to write the datetime values + |as it is, if you are 100% sure that the written files will only be read by + |Spark 3.0+ or other systems that use Proleptic Gregorian calendar. 
+ |""".stripMargin) + } + } + } + + test("UNSUPPORTED_OPERATION: timeZoneId not specified while converting TimestampType to Arrow") { + val schema = new StructType().add("value", TimestampType) + val e = intercept[SparkUnsupportedOperationException] { + ArrowUtils.toArrowSchema(schema, null) + } + + assert(e.getErrorClass === "UNSUPPORTED_OPERATION") + assert(e.getMessage === "The operation is not supported: " + + "timestamp must supply timeZoneId parameter while converting to ArrowType") + } + + test("UNSUPPORTED_OPERATION - SPARK-36346: can't read Timestamp as TimestampNTZ") { + withTempPath { file => + sql("select timestamp_ltz'2019-03-21 00:02:03'").write.orc(file.getCanonicalPath) + withAllNativeOrcReaders { + val e = intercept[SparkException] { + spark.read.schema("time timestamp_ntz").orc(file.getCanonicalPath).collect() + }.getCause.asInstanceOf[SparkUnsupportedOperationException] + + assert(e.getErrorClass === "UNSUPPORTED_OPERATION") + assert(e.getMessage === "The operation is not supported: " + + "Unable to convert timestamp of Orc to data type 'timestamp_ntz'") + } + } + } + + test("UNSUPPORTED_OPERATION - SPARK-38504: can't read TimestampNTZ as TimestampLTZ") { + withTempPath { file => + sql("select timestamp_ntz'2019-03-21 00:02:03'").write.orc(file.getCanonicalPath) + withAllNativeOrcReaders { + val e = intercept[SparkException] { + spark.read.schema("time timestamp_ltz").orc(file.getCanonicalPath).collect() + }.getCause.asInstanceOf[SparkUnsupportedOperationException] + + assert(e.getErrorClass === "UNSUPPORTED_OPERATION") + assert(e.getMessage === "The operation is not supported: " + + "Unable to convert timestamp ntz of Orc to data type 'timestamp_ltz'") + } + } + } + + test("DATETIME_OVERFLOW: timestampadd() overflows its input timestamp") { + val e = intercept[SparkArithmeticException] { + sql("select timestampadd(YEAR, 1000000, timestamp'2022-03-09 01:02:03')").collect() } + assert(e.getErrorClass === "DATETIME_OVERFLOW") + assert(e.getSqlState === "22008") + assert(e.getMessage === + "Datetime operation overflow: add 1000000 YEAR to '2022-03-09T09:02:03Z'.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala new file mode 100644 index 0000000000000..e7f62d28e8efa --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.errors + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.test.SharedSparkSession + +class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession { + def validateParsingError( + sqlText: String, + errorClass: String, + sqlState: String, + message: String): Unit = { + val e = intercept[ParseException] { + sql(sqlText) + } + assert(e.getErrorClass === errorClass) + assert(e.getSqlState === sqlState) + assert(e.getMessage === message) + } + + test("UNSUPPORTED_FEATURE: LATERAL join with NATURAL join not supported") { + validateParsingError( + sqlText = "SELECT * FROM t1 NATURAL JOIN LATERAL (SELECT c1 + c2 AS c2)", + errorClass = "UNSUPPORTED_FEATURE", + sqlState = "0A000", + message = + """ + |The feature is not supported: LATERAL join with NATURAL join.(line 1, pos 14) + | + |== SQL == + |SELECT * FROM t1 NATURAL JOIN LATERAL (SELECT c1 + c2 AS c2) + |--------------^^^ + |""".stripMargin) + } + + test("UNSUPPORTED_FEATURE: LATERAL join with USING join not supported") { + validateParsingError( + sqlText = "SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c2) USING (c2)", + errorClass = "UNSUPPORTED_FEATURE", + sqlState = "0A000", + message = + """ + |The feature is not supported: LATERAL join with USING join.(line 1, pos 14) + | + |== SQL == + |SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c2) USING (c2) + |--------------^^^ + |""".stripMargin) + } + + test("UNSUPPORTED_FEATURE: Unsupported LATERAL join type") { + Seq( + ("RIGHT OUTER", "RightOuter"), + ("FULL OUTER", "FullOuter"), + ("LEFT SEMI", "LeftSemi"), + ("LEFT ANTI", "LeftAnti")).foreach { pair => + validateParsingError( + sqlText = s"SELECT * FROM t1 ${pair._1} JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3", + errorClass = "UNSUPPORTED_FEATURE", + sqlState = "0A000", + message = + s""" + |The feature is not supported: LATERAL join type '${pair._2}'.(line 1, pos 14) + | + |== SQL == + |SELECT * FROM t1 ${pair._1} JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3 + |--------------^^^ + |""".stripMargin) + } + } + + test("INVALID_SQL_SYNTAX: LATERAL can only be used with subquery") { + Seq( + "SELECT * FROM t1, LATERAL t2" -> 26, + "SELECT * FROM t1 JOIN LATERAL t2" -> 30, + "SELECT * FROM t1, LATERAL (t2 JOIN t3)" -> 26, + "SELECT * FROM t1, LATERAL (LATERAL t2)" -> 26, + "SELECT * FROM t1, LATERAL VALUES (0, 1)" -> 26, + "SELECT * FROM t1, LATERAL RANGE(0, 1)" -> 26 + ).foreach { case (sqlText, pos) => + validateParsingError( + sqlText = sqlText, + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + s""" + |Invalid SQL syntax: LATERAL can only be used with subquery.(line 1, pos $pos) + | + |== SQL == + |$sqlText + |${"-" * pos}^^^ + |""".stripMargin) + } + } + + test("UNSUPPORTED_FEATURE: NATURAL CROSS JOIN is not supported") { + validateParsingError( + sqlText = "SELECT * FROM a NATURAL CROSS JOIN b", + errorClass = "UNSUPPORTED_FEATURE", + sqlState = "0A000", + message = + """ + |The feature is not supported: NATURAL CROSS JOIN.(line 1, pos 14) + | + |== SQL == + |SELECT * FROM a NATURAL CROSS JOIN b + |--------------^^^ + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: redefine window") { + validateParsingError( + sqlText = "SELECT min(a) OVER win FROM t1 WINDOW win AS win, win AS win2", + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + """ + |Invalid SQL syntax: The definition of window 'win' is repetitive.(line 1, pos 31) + | + |== SQL == + |SELECT min(a) 
OVER win FROM t1 WINDOW win AS win, win AS win2 + |-------------------------------^^^ + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: invalid window reference") { + validateParsingError( + sqlText = "SELECT min(a) OVER win FROM t1 WINDOW win AS win", + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + """ + |Invalid SQL syntax: Window reference 'win' is not a window specification.(line 1, pos 31) + | + |== SQL == + |SELECT min(a) OVER win FROM t1 WINDOW win AS win + |-------------------------------^^^ + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: window reference cannot be resolved") { + validateParsingError( + sqlText = "SELECT min(a) OVER win FROM t1 WINDOW win AS win2", + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + """ + |Invalid SQL syntax: Cannot resolve window reference 'win2'.(line 1, pos 31) + | + |== SQL == + |SELECT min(a) OVER win FROM t1 WINDOW win AS win2 + |-------------------------------^^^ + |""".stripMargin) + } + + test("UNSUPPORTED_FEATURE: TRANSFORM does not support DISTINCT/ALL") { + validateParsingError( + sqlText = "SELECT TRANSFORM(DISTINCT a) USING 'a' FROM t", + errorClass = "UNSUPPORTED_FEATURE", + sqlState = "0A000", + message = + """ + |The feature is not supported: """.stripMargin + + """TRANSFORM does not support DISTINCT/ALL in inputs(line 1, pos 17) + | + |== SQL == + |SELECT TRANSFORM(DISTINCT a) USING 'a' FROM t + |-----------------^^^ + |""".stripMargin) + } + + test("UNSUPPORTED_FEATURE: In-memory mode does not support TRANSFORM with serde") { + validateParsingError( + sqlText = "SELECT TRANSFORM(a) ROW FORMAT SERDE " + + "'org.apache.hadoop.hive.serde2.OpenCSVSerde' USING 'a' FROM t", + errorClass = "UNSUPPORTED_FEATURE", + sqlState = "0A000", + message = + """ + |The feature is not supported: """.stripMargin + + """TRANSFORM with serde is only supported in hive mode(line 1, pos 0) + | + |== SQL == + |SELECT TRANSFORM(a) ROW FORMAT SERDE """.stripMargin + + """'org.apache.hadoop.hive.serde2.OpenCSVSerde' USING 'a' FROM t + |^^^ + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: Too many arguments for transform") { + validateParsingError( + sqlText = "CREATE TABLE table(col int) PARTITIONED BY (years(col,col))", + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + """ + |Invalid SQL syntax: Too many arguments for transform years(line 1, pos 44) + | + |== SQL == + |CREATE TABLE table(col int) PARTITIONED BY (years(col,col)) + |--------------------------------------------^^^ + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: Invalid table value function name") { + validateParsingError( + sqlText = "SELECT * FROM ns.db.func()", + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + """ + |Invalid SQL syntax: Unsupported function name 'ns.db.func'(line 1, pos 14) + | + |== SQL == + |SELECT * FROM ns.db.func() + |--------------^^^ + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: Invalid scope in show functions") { + validateParsingError( + sqlText = "SHOW sys FUNCTIONS", + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + """ + |Invalid SQL syntax: SHOW sys FUNCTIONS not supported(line 1, pos 5) + | + |== SQL == + |SHOW sys FUNCTIONS + |-----^^^ + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: Invalid pattern in show functions") { + val errorDesc = + "Invalid pattern in SHOW FUNCTIONS: f1. 
It must be a string literal.(line 1, pos 21)" + + validateParsingError( + sqlText = "SHOW FUNCTIONS IN db f1", + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + s""" + |Invalid SQL syntax: $errorDesc + | + |== SQL == + |SHOW FUNCTIONS IN db f1 + |---------------------^^^ + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: Create function with both if not exists and replace") { + val sqlText = + """ + |CREATE OR REPLACE FUNCTION IF NOT EXISTS func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + |""".stripMargin + val errorDesc = + "CREATE FUNCTION with both IF NOT EXISTS and REPLACE is not allowed.(line 2, pos 0)" + + validateParsingError( + sqlText = sqlText, + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + s""" + |Invalid SQL syntax: $errorDesc + | + |== SQL == + | + |CREATE OR REPLACE FUNCTION IF NOT EXISTS func1 as + |^^^ + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: Create temporary function with if not exists") { + val sqlText = + """ + |CREATE TEMPORARY FUNCTION IF NOT EXISTS func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + |""".stripMargin + val errorDesc = + "It is not allowed to define a TEMPORARY function with IF NOT EXISTS.(line 2, pos 0)" + + validateParsingError( + sqlText = sqlText, + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + s""" + |Invalid SQL syntax: $errorDesc + | + |== SQL == + | + |CREATE TEMPORARY FUNCTION IF NOT EXISTS func1 as + |^^^ + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: Create temporary function with multi-part name") { + val sqlText = + """ + |CREATE TEMPORARY FUNCTION ns.db.func as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + |""".stripMargin + + validateParsingError( + sqlText = sqlText, + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + """ + |Invalid SQL syntax: Unsupported function name 'ns.db.func'(line 2, pos 0) + | + |== SQL == + | + |CREATE TEMPORARY FUNCTION ns.db.func as + |^^^ + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: Specifying database while creating temporary function") { + val sqlText = + """ + |CREATE TEMPORARY FUNCTION db.func as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + |""".stripMargin + val errorDesc = + "Specifying a database in CREATE TEMPORARY FUNCTION is not allowed: 'db'(line 2, pos 0)" + + validateParsingError( + sqlText = sqlText, + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + s""" + |Invalid SQL syntax: $errorDesc + | + |== SQL == + | + |CREATE TEMPORARY FUNCTION db.func as + |^^^ + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + |""".stripMargin) + } + + test("INVALID_SQL_SYNTAX: Drop temporary function requires a single part name") { + val errorDesc = + "DROP TEMPORARY FUNCTION requires a single part name but got: db.func(line 1, pos 0)" + + validateParsingError( + sqlText = "DROP TEMPORARY FUNCTION db.func", + errorClass = "INVALID_SQL_SYNTAX", + sqlState = "42000", + message = + s""" + |Invalid 
SQL syntax: $errorDesc + | + |== SQL == + |DROP TEMPORARY FUNCTION db.func + |^^^ + |""".stripMargin) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala index a33b9fad7ff4f..06fc2022c01ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala @@ -35,9 +35,9 @@ class AggregatingAccumulatorSuite extends SparkFunSuite with SharedSparkSession with ExpressionEvalHelper { - private val a = 'a.long - private val b = 'b.string - private val c = 'c.double + private val a = Symbol("a").long + private val b = Symbol("b").string + private val c = Symbol("c").double private val inputAttributes = Seq(a, b, c) private def str(s: String): UTF8String = UTF8String.fromString(s) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index f774c4504bb43..09a880a706b0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -133,8 +133,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU """.stripMargin) checkAnswer(query, identity, df.select( - 'a.cast("string"), - 'b.cast("string"), + Symbol("a").cast("string"), + Symbol("b").cast("string"), 'c.cast("string"), 'd.cast("string"), 'e.cast("string")).collect()) @@ -164,7 +164,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU 'b.cast("string").as("value")).collect()) checkAnswer( - df.select('a, 'b), + df.select(Symbol("a"), Symbol("b")), (child: SparkPlan) => createScriptTransformationExec( script = "cat", output = Seq( @@ -178,7 +178,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU 'b.cast("string").as("value")).collect()) checkAnswer( - df.select('a), + df.select(Symbol("a")), (child: SparkPlan) => createScriptTransformationExec( script = "cat", output = Seq( @@ -242,7 +242,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = child, ioschema = serde ), - df.select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j).collect()) + df.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("e"), + Symbol("f"), Symbol("g"), Symbol("h"), Symbol("i"), Symbol("j")).collect()) } } } @@ -282,7 +283,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = child, ioschema = defaultIOSchema ), - df.select('a, 'b, 'c, 'd, 'e).collect()) + df.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("e")).collect()) } } @@ -304,7 +305,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU |USING 'cat' AS (a timestamp, b date) |FROM v """.stripMargin) - checkAnswer(query, identity, df.select('a, 'b).collect()) + checkAnswer(query, identity, df.select(Symbol("a"), Symbol("b")).collect()) } } } @@ -379,7 +380,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) checkAnswer( - df.select('a, 'b), + df.select(Symbol("a"), Symbol("b")), (child: SparkPlan) => 
createScriptTransformationExec( script = "cat", output = Seq( @@ -452,10 +453,10 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU (Array(6, 7, 8), Array(Array(6, 7), Array(8)), Map("c" -> 3), Map("d" -> Array("e", "f"))) ).toDF("a", "b", "c", "d") - .select('a, 'b, 'c, 'd, - struct('a, 'b).as("e"), - struct('a, 'd).as("f"), - struct(struct('a, 'b), struct('a, 'd)).as("g") + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), + struct(Symbol("a"), Symbol("b")).as("e"), + struct(Symbol("a"), Symbol("d")).as("f"), + struct(struct(Symbol("a"), Symbol("b")), struct(Symbol("a"), Symbol("d"))).as("g") ) checkAnswer( @@ -483,7 +484,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = child, ioschema = defaultIOSchema ), - df.select('a, 'b, 'c, 'd, 'e, 'f, 'g).collect()) + df.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("e"), + Symbol("f"), Symbol("g")).collect()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala index 4ff96e6574cac..e4f17eb60108d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala @@ -26,9 +26,11 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper { test("basic") { val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator - val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) - val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) - val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) + val leftGrouped = GroupedIterator(leftInput, Seq(Symbol("i").int.at(0)), + Seq(Symbol("i").int, Symbol("s").string)) + val rightGrouped = GroupedIterator(rightInput, Seq(Symbol("i").int.at(0)), + Seq(Symbol("i").int, Symbol("l").long)) + val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq(Symbol("i").int)) val result = cogrouped.map { case (key, leftData, rightData) => @@ -52,7 +54,8 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") { val leftInput = Seq(create_row(2, "a")).iterator val rightInput = Seq(create_row(1, 2L)).iterator - val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) + val leftGrouped = GroupedIterator(leftInput, Seq(Symbol("i").int.at(0)), + Seq(Symbol("i").int, Symbol("s").string)) val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index 612cd6f0d891b..e29b7f579fa91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -18,13 +18,11 @@ package org.apache.spark.sql.execution import java.io.File -import scala.collection.mutable import 
scala.util.Random import org.apache.hadoop.fs.Path import org.apache.spark.SparkConf -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan @@ -215,33 +213,4 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest { } } } - - test("SPARK-30362: test input metrics for DSV2") { - withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { - Seq("json", "orc", "parquet").foreach { format => - withTempPath { path => - val dir = path.getCanonicalPath - spark.range(0, 10).write.format(format).save(dir) - val df = spark.read.format(format).load(dir) - val bytesReads = new mutable.ArrayBuffer[Long]() - val recordsRead = new mutable.ArrayBuffer[Long]() - val bytesReadListener = new SparkListener() { - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { - bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead - recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead - } - } - sparkContext.addSparkListener(bytesReadListener) - try { - df.collect() - sparkContext.listenerBus.waitUntilEmpty() - assert(bytesReads.sum > 0) - assert(recordsRead.sum == 10) - } finally { - sparkContext.removeSparkListener(bytesReadListener) - } - } - } - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala index b27a940c364a4..635c794338065 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala @@ -36,9 +36,9 @@ class DeprecatedWholeStageCodegenSuite extends QueryTest .groupByKey(_._1).agg(typed.sum(_._2)) val plan = ds.queryExecution.executedPlan - assert(plan.find(p => + assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec])) assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala index 4b2a2b439c89e..06c51cee02019 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -32,7 +32,7 @@ class GroupedIteratorSuite extends SparkFunSuite { val fromRow = encoder.createDeserializer() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(toRow), - Seq('i.int.at(0)), schema.toAttributes) + Seq(Symbol("i").int.at(0)), schema.toAttributes) val result = grouped.map { case (key, data) => @@ -59,7 +59,7 @@ class GroupedIteratorSuite extends SparkFunSuite { Row(3, 2L, "e")) val grouped = GroupedIterator(input.iterator.map(toRow), - Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes) + Seq(Symbol("i").int.at(0), Symbol("l").long.at(1)), schema.toAttributes) val result = grouped.map { case (key, data) => @@ -80,7 +80,7 @@ class GroupedIteratorSuite extends SparkFunSuite { val toRow = encoder.createSerializer() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, 
"c")) val grouped = GroupedIterator(input.iterator.map(toRow), - Seq('i.int.at(0)), schema.toAttributes) + Seq(Symbol("i").int.at(0)), schema.toAttributes) assert(grouped.length == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3bda5625471b3..c3c8959d6e1ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -59,18 +59,21 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { } test("count is partially aggregated") { - val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed + val query = testData.groupBy(Symbol("value")).agg(count(Symbol("key"))).queryExecution.analyzed testPartialAggregationPlan(query) } test("count distinct is partially aggregated") { - val query = testData.groupBy('value).agg(count_distinct('key)).queryExecution.analyzed + val query = testData.groupBy(Symbol("value")).agg(count_distinct(Symbol("key"))) + .queryExecution.analyzed testPartialAggregationPlan(query) } test("mixed aggregates are partially aggregated") { val query = - testData.groupBy('value).agg(count('value), count_distinct('key)).queryExecution.analyzed + testData.groupBy(Symbol("value")) + .agg(count(Symbol("value")), count_distinct(Symbol("key"))) + .queryExecution.analyzed testPartialAggregationPlan(query) } @@ -193,47 +196,49 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { } test("efficient terminal limit -> sort should use TakeOrderedAndProject") { - val query = testData.select('key, 'value).sort('key).limit(2) + val query = testData.select(Symbol("key"), Symbol("value")).sort(Symbol("key")).limit(2) val planned = query.queryExecution.executedPlan assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec]) - assert(planned.output === testData.select('key, 'value).logicalPlan.output) + assert(planned.output === testData.select(Symbol("key"), Symbol("value")).logicalPlan.output) } test("terminal limit -> project -> sort should use TakeOrderedAndProject") { - val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2) + val query = testData.select(Symbol("key"), Symbol("value")).sort(Symbol("key")) + .select(Symbol("value"), Symbol("key")).limit(2) val planned = query.queryExecution.executedPlan assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec]) - assert(planned.output === testData.select('value, 'key).logicalPlan.output) + assert(planned.output === testData.select(Symbol("value"), Symbol("key")).logicalPlan.output) } test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") { - val query = testData.select('value).limit(2) + val query = testData.select(Symbol("value")).limit(2) val planned = query.queryExecution.sparkPlan assert(planned.isInstanceOf[CollectLimitExec]) - assert(planned.output === testData.select('value).logicalPlan.output) + assert(planned.output === testData.select(Symbol("value")).logicalPlan.output) } test("TakeOrderedAndProject can appear in the middle of plans") { - val query = testData.select('key, 'value).sort('key).limit(2).filter('key === 3) + val query = testData.select(Symbol("key"), Symbol("value")) + .sort(Symbol("key")).limit(2).filter('key === 3) val planned = query.queryExecution.executedPlan - assert(planned.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) + 
assert(planned.exists(_.isInstanceOf[TakeOrderedAndProjectExec])) } test("CollectLimit can appear in the middle of a plan when caching is used") { - val query = testData.select('key, 'value).limit(2).cache() + val query = testData.select(Symbol("key"), Symbol("value")).limit(2).cache() val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } test("TakeOrderedAndProjectExec appears only when number of limit is below the threshold.") { withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1000") { - val query0 = testData.select('value).orderBy('key).limit(100) + val query0 = testData.select(Symbol("value")).orderBy(Symbol("key")).limit(100) val planned0 = query0.queryExecution.executedPlan - assert(planned0.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) + assert(planned0.exists(_.isInstanceOf[TakeOrderedAndProjectExec])) - val query1 = testData.select('value).orderBy('key).limit(2000) + val query1 = testData.select(Symbol("value")).orderBy(Symbol("key")).limit(2000) val planned1 = query1.queryExecution.executedPlan - assert(planned1.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty) + assert(!planned1.exists(_.isInstanceOf[TakeOrderedAndProjectExec])) } } @@ -432,7 +437,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { } test("EnsureRequirements should respect ClusteredDistribution's num partitioning") { - val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13)) + val distribution = ClusteredDistribution(Literal(1) :: Nil, requiredNumPartitions = Some(13)) // Number of partitions differ val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13) val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) @@ -706,14 +711,14 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { outputPlan match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, - DummySparkPlan(_, _, HashPartitioning(leftPartitioningExpressions, _), _, _), _), + ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), SortExec(_, _, ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) => assert(leftKeys === smjExec.leftKeys) assert(rightKeys === smjExec.rightKeys) - assert(leftPartitioningExpressions == Seq(exprA, exprB, exprA)) - assert(rightPartitioningExpressions == Seq(exprA, exprC, exprA)) + assert(leftKeys === leftPartitioningExpressions) + assert(rightKeys === rightPartitioningExpressions) case _ => fail(outputPlan.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index ecc448fe250d3..2c58b53969bcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -261,4 +261,22 @@ class QueryExecutionSuite extends SharedSparkSession { val cmdResultExec = projectQe.executedPlan.asInstanceOf[CommandResultExec] assert(cmdResultExec.commandPhysicalPlan.isInstanceOf[ShowTablesExec]) } + + test("SPARK-38198: check specify maxFields when call toFile method") { + withTempDir { dir => + val path = dir.getCanonicalPath + "/plans.txt" + // Define a dataset with 6 columns + val ds = spark.createDataset(Seq((0, 1, 2, 3, 4, 5), (6, 7, 8, 9, 10, 11))) + // `CodegenMode` and `FormattedMode` doesn't use the maxFields, so not tested 
in this case + Seq(SimpleMode.name, ExtendedMode.name, CostMode.name).foreach { modeName => + val maxFields = 3 + ds.queryExecution.debug.toFile(path, explainMode = Some(modeName), maxFields = maxFields) + Utils.tryWithResource(Source.fromFile(path)) { source => + val tableScan = source.getLines().filter(_.contains("LocalTableScan")) + assert(tableScan.exists(_.contains("more fields")), + s"Specify maxFields = $maxFields doesn't take effect when explainMode is $modeName") + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala index 751078d08fda9..21702b6cf582c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala @@ -51,7 +51,7 @@ abstract class RemoveRedundantSortsSuiteBase test("remove redundant sorts with limit") { withTempView("t") { - spark.range(100).select('id as "key").createOrReplaceTempView("t") + spark.range(100).select(Symbol("id") as "key").createOrReplaceTempView("t") val query = """ |SELECT key FROM @@ -64,8 +64,8 @@ abstract class RemoveRedundantSortsSuiteBase test("remove redundant sorts with broadcast hash join") { withTempView("t1", "t2") { - spark.range(1000).select('id as "key").createOrReplaceTempView("t1") - spark.range(1000).select('id as "key").createOrReplaceTempView("t2") + spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t1") + spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t2") val queryTemplate = """ |SELECT /*+ BROADCAST(%s) */ t1.key FROM @@ -100,8 +100,8 @@ abstract class RemoveRedundantSortsSuiteBase test("remove redundant sorts with sort merge join") { withTempView("t1", "t2") { - spark.range(1000).select('id as "key").createOrReplaceTempView("t1") - spark.range(1000).select('id as "key").createOrReplaceTempView("t2") + spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t1") + spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t2") val query = """ |SELECT /*+ MERGE(t1) */ t1.key FROM | (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1 @@ -123,8 +123,8 @@ abstract class RemoveRedundantSortsSuiteBase test("cached sorted data doesn't need to be re-sorted") { withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") { - val df = spark.range(1000).select('id as "key").sort('key.desc).cache() - val resorted = df.sort('key.desc) + val df = spark.range(1000).select(Symbol("id") as "key").sort(Symbol("key").desc).cache() + val resorted = df.sort(Symbol("key").desc) val sortedAsc = df.sort('key.asc) checkNumSorts(df, 0) checkNumSorts(resorted, 0) @@ -140,7 +140,7 @@ abstract class RemoveRedundantSortsSuiteBase test("SPARK-33472: shuffled join with different left and right side partition numbers") { withTempView("t1", "t2") { - spark.range(0, 100, 1, 2).select('id as "key").createOrReplaceTempView("t1") + spark.range(0, 100, 1, 2).select(Symbol("id") as "key").createOrReplaceTempView("t1") (0 to 100).toDF("key").createOrReplaceTempView("t2") val queryTemplate = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 1861d9cf045a1..68eb15b4ae097 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -593,7 +593,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { spark.range(10).write.saveAsTable("add_col") withView("v") { sql("CREATE VIEW v AS SELECT * FROM add_col") - spark.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col") + spark.range(10).select(Symbol("id"), 'id as Symbol("a")) + .write.mode("overwrite").saveAsTable("add_col") checkAnswer(sql("SELECT * FROM v"), spark.range(10).toDF()) } } @@ -765,7 +766,9 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { withTable("t") { Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t") withTempView("v1") { - sql("CREATE TEMPORARY VIEW v1 AS SELECT 1/0") + withSQLConf(ANSI_ENABLED.key -> "false") { + sql("CREATE TEMPORARY VIEW v1 AS SELECT 1/0") + } withSQLConf( USE_CURRENT_SQL_CONFIGS_FOR_VIEW.key -> "true", ANSI_ENABLED.key -> "true") { @@ -838,7 +841,9 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql("CREATE VIEW v2 (c1) AS SELECT c1 FROM t ORDER BY 1 ASC, c1 DESC") sql("CREATE VIEW v3 (c1, count) AS SELECT c1, count(c1) AS cnt FROM t GROUP BY 1") sql("CREATE VIEW v4 (a, count) AS SELECT c1 as a, count(c1) AS cnt FROM t GROUP BY a") - sql("CREATE VIEW v5 (c1) AS SELECT 1/0 AS invalid") + withSQLConf(ANSI_ENABLED.key -> "false") { + sql("CREATE VIEW v5 (c1) AS SELECT 1/0 AS invalid") + } withSQLConf(CASE_SENSITIVE.key -> "true") { checkAnswer(sql("SELECT * FROM v1"), Seq(Row(2), Row(3), Row(1))) @@ -903,4 +908,23 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } } } + + test("SPARK-37932: view join with same view") { + withTable("t") { + withView("v1") { + Seq((1, "test1"), (2, "test2"), (1, "test2")).toDF("id", "name") + .write.format("parquet").saveAsTable("t") + sql("CREATE VIEW v1 (id, name) AS SELECT id, name FROM t") + + checkAnswer( + sql("""SELECT l1.id FROM v1 l1 + |INNER JOIN ( + | SELECT id FROM v1 + | GROUP BY id HAVING COUNT(DISTINCT name) > 1 + | ) l2 ON l1.id = l2.id GROUP BY l1.name, l1.id; + |""".stripMargin), + Seq(Row(1), Row(1))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala index 730299c2f2c9b..316b1cfd5e842 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import scala.collection.JavaConverters._ -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.CatalogFunction import org.apache.spark.sql.catalyst.expressions.Expression @@ -45,10 +45,12 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { viewName: String, sqlText: String, columnNames: Seq[String] = Seq.empty, + others: Seq[String] = Seq.empty, replace: Boolean = false): String = { val replaceString = if (replace) "OR REPLACE" else "" val columnString = if (columnNames.nonEmpty) columnNames.mkString("(", ",", ")") else "" - sql(s"CREATE $replaceString $viewTypeString $viewName $columnString AS $sqlText") + val othersString = if (others.nonEmpty) others.mkString(" ") else "" + sql(s"CREATE $replaceString $viewTypeString $viewName 
$columnString $othersString AS $sqlText") formattedViewName(viewName) } @@ -117,11 +119,13 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { test("change SQLConf should not change view behavior - ansiEnabled") { withTable("t") { Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t") - val viewName = createView("v1", "SELECT 1/0 AS invalid", Seq("c1")) - withView(viewName) { - Seq("true", "false").foreach { flag => - withSQLConf(ANSI_ENABLED.key -> flag) { - checkViewOutput(viewName, Seq(Row(null))) + withSQLConf(ANSI_ENABLED.key -> "false") { + val viewName = createView("v1", "SELECT 1/0 AS invalid", Seq("c1")) + withView(viewName) { + Seq("true", "false").foreach { flag => + withSQLConf(ANSI_ENABLED.key -> flag) { + checkViewOutput(viewName, Seq(Row(null))) + } } } } @@ -406,6 +410,9 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { } abstract class TempViewTestSuite extends SQLViewTestSuite { + + def createOrReplaceDatasetView(df: DataFrame, viewName: String): Unit + test("SPARK-37202: temp view should capture the function registered by catalog API") { val funcName = "tempFunc" withUserDefinedFunction(funcName -> true) { @@ -421,6 +428,40 @@ abstract class TempViewTestSuite extends SQLViewTestSuite { } } } + + test("show create table does not support temp view") { + val viewName = "spark_28383" + withView(viewName) { + createView(viewName, "SELECT 1 AS a") + val ex = intercept[AnalysisException] { + sql(s"SHOW CREATE TABLE ${formattedViewName(viewName)}") + } + assert(ex.getMessage.contains( + s"$viewName is a temp view. 'SHOW CREATE TABLE' expects a table or permanent view.")) + } + } + + test("back compatibility: skip cyclic reference check if view is stored as logical plan") { + val viewName = formattedViewName("v") + withSQLConf(STORE_ANALYZED_PLAN_FOR_VIEW.key -> "false") { + withView(viewName) { + createOrReplaceDatasetView(sql("SELECT 1"), "v") + createOrReplaceDatasetView(sql(s"SELECT * FROM $viewName"), "v") + checkViewOutput(viewName, Seq(Row(1))) + } + } + withSQLConf(STORE_ANALYZED_PLAN_FOR_VIEW.key -> "true") { + withView(viewName) { + createOrReplaceDatasetView(sql("SELECT 1"), "v") + createOrReplaceDatasetView(sql(s"SELECT * FROM $viewName"), "v") + checkViewOutput(viewName, Seq(Row(1))) + + createView("v", "SELECT 2", replace = true) + createView("v", s"SELECT * FROM $viewName", replace = true) + checkViewOutput(viewName, Seq(Row(2))) + } + } + } } class LocalTempViewTestSuite extends TempViewTestSuite with SharedSparkSession { @@ -429,6 +470,9 @@ class LocalTempViewTestSuite extends TempViewTestSuite with SharedSparkSession { override protected def tableIdentifier(viewName: String): TableIdentifier = { TableIdentifier(viewName) } + override def createOrReplaceDatasetView(df: DataFrame, viewName: String): Unit = { + df.createOrReplaceTempView(viewName) + } } class GlobalTempViewTestSuite extends TempViewTestSuite with SharedSparkSession { @@ -440,6 +484,9 @@ class GlobalTempViewTestSuite extends TempViewTestSuite with SharedSparkSession override protected def tableIdentifier(viewName: String): TableIdentifier = { TableIdentifier(viewName, Some(db)) } + override def createOrReplaceDatasetView(df: DataFrame, viewName: String): Unit = { + df.createOrReplaceGlobalTempView(viewName) + } } class OneTableCatalog extends InMemoryCatalog { @@ -591,4 +638,52 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession { s" The view ${table.qualifiedName} may have been tampered with")) } } + + test("show 
create table for persisted simple view") { + val viewName = "v1" + Seq(true, false).foreach { serde => + withView(viewName) { + createView(viewName, "SELECT 1 AS a") + val expected = s"CREATE VIEW ${formattedViewName(viewName)} ( a) AS SELECT 1 AS a" + assert(getShowCreateDDL(formattedViewName(viewName), serde) == expected) + } + } + } + + test("show create table for persisted view with output columns") { + val viewName = "v1" + Seq(true, false).foreach { serde => + withView(viewName) { + createView(viewName, "SELECT 1 AS a, 2 AS b", Seq("a", "b COMMENT 'b column'")) + val expected = s"CREATE VIEW ${formattedViewName(viewName)}" + + s" ( a, b COMMENT 'b column') AS SELECT 1 AS a, 2 AS b" + assert(getShowCreateDDL(formattedViewName(viewName), serde) == expected) + } + } + } + + test("show create table for persisted simple view with table comment and properties") { + val viewName = "v1" + Seq(true, false).foreach { serde => + withView(viewName) { + createView(viewName, "SELECT 1 AS c1, '2' AS c2", Seq("c1 COMMENT 'bla'", "c2"), + Seq("COMMENT 'table comment'", "TBLPROPERTIES ( 'prop2' = 'value2', 'prop1' = 'value1')")) + + val expected = s"CREATE VIEW ${formattedViewName(viewName)} ( c1 COMMENT 'bla', c2)" + + " COMMENT 'table comment'" + + " TBLPROPERTIES ( 'prop1' = 'value1', 'prop2' = 'value2')" + + " AS SELECT 1 AS c1, '2' AS c2" + assert(getShowCreateDDL(formattedViewName(viewName), serde) == expected) + } + } + } + + def getShowCreateDDL(view: String, serde: Boolean = false): String = { + val result = if (serde) { + sql(s"SHOW CREATE TABLE $view AS SERDE") + } else { + sql(s"SHOW CREATE TABLE $view") + } + result.head().getString(0).split("\n").map(_.trim).mkString(" ") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala index 99856650fea1f..da05373125d31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala @@ -705,50 +705,50 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val smallPartitionFactor1 = ShufflePartitionsUtil.SMALL_PARTITION_FACTOR // merge the small partitions at the beginning/end - val sizeList1 = Seq[Long](15, 90, 15, 15, 15, 90, 15) + val sizeList1 = Array[Long](15, 90, 15, 15, 15, 90, 15) assert(ShufflePartitionsUtil.splitSizeListByTargetSize( sizeList1, targetSize, smallPartitionFactor1).toSeq == Seq(0, 2, 5)) // merge the small partitions in the middle - val sizeList2 = Seq[Long](30, 15, 90, 10, 90, 15, 30) + val sizeList2 = Array[Long](30, 15, 90, 10, 90, 15, 30) assert(ShufflePartitionsUtil.splitSizeListByTargetSize( sizeList2, targetSize, smallPartitionFactor1).toSeq == Seq(0, 2, 4, 5)) // merge small partitions if the partition itself is smaller than // targetSize * SMALL_PARTITION_FACTOR - val sizeList3 = Seq[Long](15, 1000, 15, 1000) + val sizeList3 = Array[Long](15, 1000, 15, 1000) assert(ShufflePartitionsUtil.splitSizeListByTargetSize( sizeList3, targetSize, smallPartitionFactor1).toSeq == Seq(0, 3)) // merge small partitions if the combined size is smaller than // targetSize * MERGED_PARTITION_FACTOR - val sizeList4 = Seq[Long](35, 75, 90, 20, 35, 25, 35) + val sizeList4 = Array[Long](35, 75, 90, 20, 35, 25, 35) assert(ShufflePartitionsUtil.splitSizeListByTargetSize( sizeList4, targetSize, smallPartitionFactor1).toSeq == Seq(0, 2, 3)) val 
smallPartitionFactor2 = 0.5 // merge last two partition if their size is not bigger than smallPartitionFactor * target - val sizeList5 = Seq[Long](50, 50, 40, 5) + val sizeList5 = Array[Long](50, 50, 40, 5) assert(ShufflePartitionsUtil.splitSizeListByTargetSize( sizeList5, targetSize, smallPartitionFactor2).toSeq == Seq(0)) - val sizeList6 = Seq[Long](40, 5, 50, 45) + val sizeList6 = Array[Long](40, 5, 50, 45) assert(ShufflePartitionsUtil.splitSizeListByTargetSize( sizeList6, targetSize, smallPartitionFactor2).toSeq == Seq(0)) // do not merge - val sizeList7 = Seq[Long](50, 50, 10, 40, 5) + val sizeList7 = Array[Long](50, 50, 10, 40, 5) assert(ShufflePartitionsUtil.splitSizeListByTargetSize( sizeList7, targetSize, smallPartitionFactor2).toSeq == Seq(0, 2)) - val sizeList8 = Seq[Long](10, 40, 5, 50, 50) + val sizeList8 = Array[Long](10, 40, 5, 50, 50) assert(ShufflePartitionsUtil.splitSizeListByTargetSize( sizeList8, targetSize, smallPartitionFactor2).toSeq == Seq(0, 3)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 812fdba8dda23..5fa7a4d0c71cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -44,13 +44,15 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { checkAnswer( input.toDF("a", "b", "c"), - (child: SparkPlan) => SortExec('a.asc :: 'b.asc :: Nil, global = true, child = child), + (child: SparkPlan) => SortExec(Symbol("a").asc :: Symbol("b").asc :: Nil, + global = true, child = child), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - (child: SparkPlan) => SortExec('b.asc :: 'a.asc :: Nil, global = true, child = child), + (child: SparkPlan) => SortExec(Symbol("b").asc :: Symbol("a").asc :: Nil, + global = true, child = child), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } @@ -59,9 +61,9 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF().selectExpr("NULL as a"), (child: SparkPlan) => - GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)), + GlobalLimitExec(10, SortExec(Symbol("a").asc :: Nil, global = true, child = child)), (child: SparkPlan) => - GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)), + GlobalLimitExec(10, ReferenceSort(Symbol("a").asc :: Nil, global = true, child)), sortAnswers = false ) } @@ -70,15 +72,15 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => - GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)), + GlobalLimitExec(10, SortExec(Symbol("a").asc :: Nil, global = true, child = child)), (child: SparkPlan) => - GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)), + GlobalLimitExec(10, ReferenceSort(Symbol("a").asc :: Nil, global = true, child)), sortAnswers = false ) } test("sorting does not crash for large inputs") { - val sortOrder = 'a.asc :: Nil + val sortOrder = Symbol("a").asc :: Nil val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), @@ -92,8 +94,8 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, 
"unsafe external sort") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => SortExec('a.asc :: Nil, global = true, child = child), - (child: SparkPlan) => ReferenceSort('a.asc :: Nil, global = true, child), + (child: SparkPlan) => SortExec(Symbol("a").asc :: Nil, global = true, child = child), + (child: SparkPlan) => ReferenceSort(Symbol("a").asc :: Nil, global = true, child), sortAnswers = false) } } @@ -106,7 +108,8 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { ) checkAnswer( input.toDF("a", "b", "c"), - (child: SparkPlan) => SortExec(Stream('a.asc, 'b.asc, 'c.asc), global = true, child = child), + (child: SparkPlan) => SortExec(Stream(Symbol("a").asc, 'b.asc, 'c.asc), + global = true, child = child), input.sortBy(t => (t._1, t._2, t._3)).map(Row.fromTuple), sortAnswers = false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index ba6dd170d89a9..fb8f2ea6d8db2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -47,7 +47,7 @@ class SparkSqlParserSuite extends AnalysisTest { } private def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parser.parsePlan)(sqlCommand, messages: _*) + interceptParseException(parser.parsePlan)(sqlCommand, messages: _*)() test("Checks if SET/RESET can parse all the configurations") { // Force to build static SQL configurations @@ -312,7 +312,7 @@ class SparkSqlParserSuite extends AnalysisTest { Seq(AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", StringType)()), - Project(Seq('a, 'b, 'c), + Project(Seq(Symbol("a"), Symbol("b"), Symbol("c")), UnresolvedRelation(TableIdentifier("testData"))), ioSchema)) @@ -336,9 +336,9 @@ class SparkSqlParserSuite extends AnalysisTest { UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), Literal(10)), Aggregate( - Seq('a), + Seq(Symbol("a")), Seq( - 'a, + Symbol("a"), UnresolvedAlias( UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), None), UnresolvedAlias( @@ -363,12 +363,12 @@ class SparkSqlParserSuite extends AnalysisTest { AttributeReference("c", StringType)()), WithWindowDefinition( Map("w" -> WindowSpecDefinition( - Seq('a), - Seq(SortOrder('b, Ascending, NullsFirst, Seq.empty)), + Seq(Symbol("a")), + Seq(SortOrder(Symbol("b"), Ascending, NullsFirst, Seq.empty)), UnspecifiedFrame)), Project( Seq( - 'a, + Symbol("a"), UnresolvedAlias( UnresolvedWindowExpression( UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), @@ -403,9 +403,9 @@ class SparkSqlParserSuite extends AnalysisTest { UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), Literal(10)), Aggregate( - Seq('a, 'myCol, 'myCol2), + Seq(Symbol("a"), Symbol("myCol"), Symbol("myCol2")), Seq( - 'a, + Symbol("a"), UnresolvedAlias( UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), None), UnresolvedAlias( @@ -415,7 +415,7 @@ class SparkSqlParserSuite extends AnalysisTest { UnresolvedGenerator( FunctionIdentifier("explode"), Seq(UnresolvedAttribute("myTable.myCol"))), - Nil, false, Option("mytable2"), Seq('myCol2), + Nil, false, Option("mytable2"), Seq(Symbol("myCol2")), Generate( UnresolvedGenerator( FunctionIdentifier("explode"), @@ -423,7 
+423,7 @@ class SparkSqlParserSuite extends AnalysisTest { Seq( UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), false)), false))), - Nil, false, Option("mytable"), Seq('myCol), + Nil, false, Option("mytable"), Seq(Symbol("myCol")), UnresolvedRelation(TableIdentifier("testData")))))), ioSchema)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala index c025670fb895e..3718b3a3c3378 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala @@ -49,7 +49,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols) val cols = (0 until numCols).map { idx => - from_json('value, schema).getField(s"col$idx") + from_json(Symbol("value"), schema).getField(s"col$idx") } Seq( @@ -88,7 +88,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols) val predicate = (0 until numCols).map { idx => - (from_json('value, schema).getField(s"col$idx") >= Literal(100000)).expr + (from_json(Symbol("value"), schema).getField(s"col$idx") >= Literal(100000)).expr }.asInstanceOf[Seq[Expression]].reduce(Or) Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 6ec5c6287eed1..ce48945e52c5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -58,7 +58,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSparkSession { private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan) val limit = 250 - val sortOrder = 'a.desc :: 'b.desc :: Nil + val sortOrder = Symbol("a").desc :: Symbol("b").desc :: Nil test("TakeOrderedAndProject.doExecute without project") { withClue(s"seed = $seed") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala index ffbdc3f64195f..73c4e4c3e1eb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala @@ -71,7 +71,7 @@ object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging { var spark: SparkSession = _ def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") spark = SparkSession.builder().getOrCreate() @@ -84,7 +84,7 @@ object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging { val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v") .groupBy(array(col("v"))).agg(count(col("*"))) val plan = df.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) + assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec])) val expectedAnswer = Row(Array(0), 7178) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 7332d49b942f8..ba511354f7a40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -37,16 +37,16 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) + assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec])) assert(df.collect() === Array(Row(2))) } test("HashAggregate should be included in WholeStageCodegen") { val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id"))) val plan = df.queryExecution.executedPlan - assert(plan.find(p => + assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec])) assert(df.collect() === Array(Row(9, 4.5))) } @@ -54,9 +54,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val df = spark.range(10).agg(max(col("id")), avg(col("id"))) withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") { val plan = df.queryExecution.executedPlan - assert(plan.find(p => + assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec])) assert(df.collect() === Array(Row(9, 4.5))) } } @@ -70,22 +70,22 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession // Array - explode var expDF = df.select($"name", explode($"knownLanguages"), $"properties") var plan = expDF.queryExecution.executedPlan - assert(plan.find { + assert(plan.exists { case stage: WholeStageCodegenExec => - stage.find(_.isInstanceOf[GenerateExec]).isDefined + stage.exists(_.isInstanceOf[GenerateExec]) case _ => !codegenEnabled.toBoolean - }.isDefined) + }) checkAnswer(expDF, Array(Row("James", "Java", Map("hair" -> "black", "eye" -> "brown")), Row("James", "Scala", Map("hair" -> "black", "eye" -> "brown")))) // Map - explode expDF = df.select($"name", $"knownLanguages", explode($"properties")) plan = expDF.queryExecution.executedPlan - assert(plan.find { + assert(plan.exists { case stage: WholeStageCodegenExec => - stage.find(_.isInstanceOf[GenerateExec]).isDefined + stage.exists(_.isInstanceOf[GenerateExec]) case _ => !codegenEnabled.toBoolean - }.isDefined) + }) checkAnswer(expDF, Array(Row("James", List("Java", "Scala"), "hair", "black"), Row("James", List("Java", "Scala"), "eye", "brown"))) @@ -93,33 +93,33 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession // Array - posexplode expDF = df.select($"name", posexplode($"knownLanguages")) plan = expDF.queryExecution.executedPlan - assert(plan.find { + assert(plan.exists { case stage: WholeStageCodegenExec => - stage.find(_.isInstanceOf[GenerateExec]).isDefined + stage.exists(_.isInstanceOf[GenerateExec]) case _ => !codegenEnabled.toBoolean - }.isDefined) + }) checkAnswer(expDF, Array(Row("James", 0, "Java"), Row("James", 1, "Scala"))) // Map - posexplode expDF = df.select($"name", posexplode($"properties")) plan = expDF.queryExecution.executedPlan - assert(plan.find { + 
assert(plan.exists { case stage: WholeStageCodegenExec => - stage.find(_.isInstanceOf[GenerateExec]).isDefined + stage.exists(_.isInstanceOf[GenerateExec]) case _ => !codegenEnabled.toBoolean - }.isDefined) + }) checkAnswer(expDF, Array(Row("James", 0, "hair", "black"), Row("James", 1, "eye", "brown"))) // Array - explode , selecting all columns expDF = df.select($"*", explode($"knownLanguages")) plan = expDF.queryExecution.executedPlan - assert(plan.find { + assert(plan.exists { case stage: WholeStageCodegenExec => - stage.find(_.isInstanceOf[GenerateExec]).isDefined + stage.exists(_.isInstanceOf[GenerateExec]) case _ => !codegenEnabled.toBoolean - }.isDefined) + }) checkAnswer(expDF, Array(Row("James", Seq("Java", "Scala"), Map("hair" -> "black", "eye" -> "brown"), "Java"), Row("James", Seq("Java", "Scala"), Map("hair" -> "black", "eye" -> "brown"), "Scala"))) @@ -127,11 +127,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession // Map - explode, selecting all columns expDF = df.select($"*", explode($"properties")) plan = expDF.queryExecution.executedPlan - assert(plan.find { + assert(plan.exists { case stage: WholeStageCodegenExec => - stage.find(_.isInstanceOf[GenerateExec]).isDefined + stage.exists(_.isInstanceOf[GenerateExec]) case _ => !codegenEnabled.toBoolean - }.isDefined) + }) checkAnswer(expDF, Array( Row("James", List("Java", "Scala"), @@ -143,9 +143,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession test("HashAggregate with grouping keys should be included in WholeStageCodegen") { val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2) val plan = df.queryExecution.executedPlan - assert(plan.find(p => + assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec])) assert(df.collect() === Array(Row(0, 1), Row(2, 1), Row(4, 1))) } @@ -154,9 +154,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val schema = new StructType().add("k", IntegerType).add("v", StringType) val smallDF = spark.createDataFrame(rdd, schema) val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id")) - assert(df.queryExecution.executedPlan.find(p => + assert(df.queryExecution.executedPlan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec])) assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) } @@ -434,9 +434,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession test("Sort should be included in WholeStageCodegen") { val df = spark.range(3, 0, -1).toDF().sort(col("id")) val plan = df.queryExecution.executedPlan - assert(plan.find(p => + assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec])) assert(df.collect() === Array(Row(1), Row(2), Row(3))) } @@ -445,27 +445,27 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val ds = spark.range(10).map(_.toString) val plan = ds.queryExecution.executedPlan - assert(plan.find(p => + assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && - 
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec])) assert(ds.collect() === 0.until(10).map(_.toString).toArray) } test("typed filter should be included in WholeStageCodegen") { val ds = spark.range(10).filter(_ % 2 == 0) val plan = ds.queryExecution.executedPlan - assert(plan.find(p => + assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec])) assert(ds.collect() === Array(0, 2, 4, 6, 8)) } test("back-to-back typed filter should be included in WholeStageCodegen") { val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) val plan = ds.queryExecution.executedPlan - assert(plan.find(p => + assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec])) assert(ds.collect() === Array(0, 6)) } @@ -517,10 +517,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .select("int") val plan = df.queryExecution.executedPlan - assert(plan.find(p => + assert(!plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.children(0) - .isInstanceOf[SortMergeJoinExec]).isEmpty) + .isInstanceOf[SortMergeJoinExec])) assert(df.collect() === Array(Row(1), Row(2))) } } @@ -573,7 +573,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession import testImplicits._ withTempPath { dir => val path = dir.getCanonicalPath - val df = spark.range(10).select(Seq.tabulate(201) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(10).select(Seq.tabulate(201) {i => (Symbol("id") + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "202", @@ -590,7 +590,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession test("Control splitting consume function by operators with config") { import testImplicits._ - val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(10).select(Seq.tabulate(2) {i => (Symbol("id") + i).as(s"c$i")} : _*) Seq(true, false).foreach { config => withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") { @@ -639,9 +639,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val df = spark.range(100) val join = df.join(df, "id") val plan = join.queryExecution.executedPlan - assert(plan.find(p => + assert(!plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isEmpty, + p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0), "codegen stage IDs should be preserved through ReuseExchange") checkAnswer(join, df.toDF) } @@ -653,9 +653,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") { // the same query run twice should produce identical code, which would imply a hit in // the generated code cache. 
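
Many of the hunks in these codegen suites make the same mechanical change: `plan.find(predicate).isDefined` becomes `plan.exists(predicate)`. A minimal sketch of the equivalence, assuming only a local SparkSession (the object name and app name below are illustrative, not part of the patch):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.WholeStageCodegenExec

object ExistsVsFindSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("exists-vs-find").getOrCreate()
    val plan = spark.range(10).filter("id = 1").queryExecution.executedPlan

    // Older style: traverse the plan, materialise the first match, then test for presence.
    val viaFind = plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined
    // Style used after this patch: ask the question directly and get a Boolean back.
    val viaExists = plan.exists(_.isInstanceOf[WholeStageCodegenExec])

    assert(viaFind == viaExists)
    spark.stop()
  }
}

Both forms walk the same plan tree; `exists` simply avoids wrapping the matching node in an Option before discarding it.
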
- val ds1 = spark.range(3).select('id + 2) + val ds1 = spark.range(3).select(Symbol("id") + 2) val code1 = genCode(ds1) - val ds2 = spark.range(3).select('id + 2) + val ds2 = spark.range(3).select(Symbol("id") + 2) val code2 = genCode(ds2) // same query shape as above, deliberately assert(code1 == code2, "Should produce same code") } @@ -700,10 +700,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession // BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id()) .join(baseTable, "idx") - assert(distinctWithId.queryExecution.executedPlan.collectFirst { + assert(distinctWithId.queryExecution.executedPlan.exists { case WholeStageCodegenExec( ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true - }.isDefined) + case _ => false + }) checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0))) // BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate @@ -711,10 +712,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val groupByWithId = baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id()) .join(baseTable, "idx") - assert(groupByWithId.queryExecution.executedPlan.collectFirst { + assert(groupByWithId.queryExecution.executedPlan.exists { case WholeStageCodegenExec( ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true - }.isDefined) + case _ => false + }) checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0))) } } @@ -740,11 +742,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession // HashAggregateExec supports WholeStageCodegen and it's the parent of // LocalTableScanExec so LocalTableScanExec should be within a WholeStageCodegen domain. 
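
Two hunks just above trade `collectFirst { ... }.isDefined` for `exists { ... }` and add a `case _ => false` branch while doing so. The reason is type-level: `collectFirst` takes a PartialFunction and skips non-matching nodes, while `exists` takes a total function, so the fallback case is needed to return false instead of throwing a MatchError. A standalone Scala sketch with hypothetical node classes (not Spark's) shows the difference:

sealed trait Node
case class Leaf(n: Int) extends Node
case class Branch(left: Node, right: Node) extends Node

val nodes: Seq[Node] = Seq(Leaf(1), Branch(Leaf(2), Leaf(3)))

// PartialFunction: unmatched elements are simply skipped.
val viaCollectFirst = nodes.collectFirst { case Branch(Leaf(_), _) => true }.isDefined

// Total function: the default branch makes the match safe for every element.
val viaExists = nodes.exists {
  case Branch(Leaf(_), _) => true
  case _                  => false
}

assert(viaCollectFirst == viaExists)
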
assert( - executedPlan.find { + executedPlan.exists { case WholeStageCodegenExec( - HashAggregateExec(_, _, _, _, _, _, _: LocalTableScanExec)) => true + HashAggregateExec(_, _, _, _, _, _, _, _, _: LocalTableScanExec)) => true case _ => false - }.isDefined, + }, "LocalTableScanExec should be within a WholeStageCodegen domain.") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index a29989cc06c7c..76741dc4d08e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec} +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.noop.NoopDataSource import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec @@ -126,6 +127,12 @@ class AdaptiveQueryExecSuite } } + private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = { + collect(plan) { + case agg: BaseAggregateExec => agg + } + } + private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = { collect(plan) { case l: CollectLimitExec => l @@ -187,6 +194,29 @@ class AdaptiveQueryExecSuite } } + test("Change broadcast join to merge join") { + withTable("t1", "t2") { + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10000", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + sql("CREATE TABLE t1 USING PARQUET AS SELECT 1 c1") + sql("CREATE TABLE t2 USING PARQUET AS SELECT 1 c1") + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT * FROM ( + | SELECT distinct c1 from t1 + | ) tmp1 JOIN ( + | SELECT distinct c1 from t2 + | ) tmp2 ON tmp1.c1 = tmp2.c1 + |""".stripMargin) + assert(findTopLevelBroadcastHashJoin(plan).size == 1) + assert(findTopLevelBroadcastHashJoin(adaptivePlan).isEmpty) + assert(findTopLevelSortMergeJoin(adaptivePlan).size == 1) + } + } + } + test("Reuse the parallelism of coalesced shuffle in local shuffle read") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", @@ -250,11 +280,12 @@ class AdaptiveQueryExecSuite SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { - val df1 = spark.range(10).withColumn("a", 'id) - val df2 = spark.range(10).withColumn("b", 'id) + val df1 = spark.range(10).withColumn("a", Symbol("id")) + val df2 = spark.range(10).withColumn("b", Symbol("id")) withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") - .groupBy('a).count() + val testDf = df1.where(Symbol("a") > 10) + .join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer") + .groupBy(Symbol("a")).count() 
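
The bulk of this patch is the same substitution over and over: the Scala symbol literal `'a`, which is deprecated in Scala 2.13 and removed in Scala 3, is spelled out as `Symbol("a")`. In the DataFrame API the symbol still reaches a Column through the existing implicit conversion, so behaviour is unchanged. A hedged sketch, assuming a local SparkSession and the usual implicits import:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[1]").appName("symbol-literal-sketch").getOrCreate()
import spark.implicits._  // brings symbolToColumn and the $"..." interpolator into scope

val df = spark.range(5).toDF("id")

df.select(Symbol("id")).show()  // explicit Symbol, same implicit conversion the old 'id used
df.select($"id").show()         // string interpolator alternative
df.select(col("id")).show()     // functions.col alternative

spark.stop()

All three forms select the same column; `Symbol(...)` is presumably preferred here because it is a one-for-one textual replacement that leaves the surrounding test DSL untouched.
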
checkAnswer(testDf, Seq()) val plan = testDf.queryExecution.executedPlan assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) @@ -266,8 +297,9 @@ class AdaptiveQueryExecSuite } withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { - val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") - .groupBy('a).count() + val testDf = df1.where(Symbol("a") > 10) + .join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer") + .groupBy(Symbol("a")).count() checkAnswer(testDf, Seq()) val plan = testDf.queryExecution.executedPlan assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) @@ -651,6 +683,41 @@ class AdaptiveQueryExecSuite } } } + test("SPARK-37753: Allow changing outer join to broadcast join even if too many empty" + + " partitions on broadcast side") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { + // `testData` is small enough to be broadcast but has empty partition ratio over the config. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM (select * from testData where value = '1') td" + + " right outer join testData2 ON key = a") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + } + } + } + + test("SPARK-37753: Inhibit broadcast in left outer join when there are many empty" + + " partitions on outer/left side") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { + // `testData` is small enough to be broadcast but has empty partition ratio over the config. 
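
The two SPARK-37753 tests added here exercise how AQE weighs empty partitions when deciding whether an outer join may be replaced with a broadcast hash join. Outside the test harness the same behaviour is driven by session configs; a hedged sketch of setting them, with values copied from the tests and purely illustrative:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

val spark = SparkSession.builder().master("local[2]").appName("aqe-broadcast-sketch").getOrCreate()

spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
// Ratio AQE consults when judging whether a mostly-empty side should block a broadcast join.
spark.conf.set(SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key, "0.5")
// Size threshold below which a side is considered broadcastable.
spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "80")

spark.stop()

Under these settings the first test asserts that the adaptive plan still contains a broadcast hash join for the right outer join, while the second asserts that the left outer join keeps its sort-merge join.
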
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM (select * from testData where value = '1') td" + + " left outer join testData2 ON key = a") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.isEmpty) + } + } + } test("SPARK-29906: AQE should not introduce extra shuffle for outermost limit") { var numStages = 0 @@ -721,17 +788,17 @@ class AdaptiveQueryExecSuite spark .range(0, 1000, 1, 10) .select( - when('id < 250, 249) - .when('id >= 750, 1000) - .otherwise('id).as("key1"), - 'id as "value1") + when(Symbol("id") < 250, 249) + .when(Symbol("id") >= 750, 1000) + .otherwise(Symbol("id")).as("key1"), + Symbol("id") as "value1") .createOrReplaceTempView("skewData1") spark .range(0, 1000, 1, 10) .select( - when('id < 250, 249) - .otherwise('id).as("key2"), - 'id as "value2") + when(Symbol("id") < 250, 249) + .otherwise(Symbol("id")).as("key2"), + Symbol("id") as "value2") .createOrReplaceTempView("skewData2") def checkSkewJoin( @@ -966,17 +1033,17 @@ class AdaptiveQueryExecSuite spark .range(0, 1000, 1, 10) .select( - when('id < 250, 249) - .when('id >= 750, 1000) - .otherwise('id).as("key1"), - 'id as "value1") + when(Symbol("id") < 250, 249) + .when(Symbol("id") >= 750, 1000) + .otherwise(Symbol("id")).as("key1"), + Symbol("id") as "value1") .createOrReplaceTempView("skewData1") spark .range(0, 1000, 1, 10) .select( - when('id < 250, 249) - .otherwise('id).as("key2"), - 'id as "value2") + when(Symbol("id") < 250, 249) + .otherwise(Symbol("id")).as("key2"), + Symbol("id") as "value2") .createOrReplaceTempView("skewData2") val (_, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM skewData1 join skewData2 ON key1 = key2") @@ -1054,7 +1121,7 @@ class AdaptiveQueryExecSuite test("AQE should set active session during execution") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val df = spark.range(10).select(sum('id)) + val df = spark.range(10).select(sum(Symbol("id"))) assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) SparkSession.setActiveSession(null) checkAnswer(df, Seq(Row(45))) @@ -1081,7 +1148,7 @@ class AdaptiveQueryExecSuite SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { try { spark.experimental.extraStrategies = TestStrategy :: Nil - val df = spark.range(10).groupBy('id).count() + val df = spark.range(10).groupBy(Symbol("id")).count() df.collect() } finally { spark.experimental.extraStrategies = Nil @@ -1537,7 +1604,7 @@ class AdaptiveQueryExecSuite test("SPARK-33494: Do not use local shuffle read for repartition") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val df = spark.table("testData").repartition('key) + val df = spark.table("testData").repartition(Symbol("key")) df.collect() // local shuffle read breaks partitioning and shouldn't be used for repartition operation // which is specified by users. @@ -1616,28 +1683,28 @@ class AdaptiveQueryExecSuite | SELECT * FROM testData WHERE key = 1 |) |RIGHT OUTER JOIN testData2 - |ON value = b + |ON CAST(value AS INT) = b """.stripMargin) withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { // Repartition with no partition num specified. - checkBHJ(df.repartition('b), + checkBHJ(df.repartition(Symbol("b")), // The top shuffle from repartition is optimized out. 
optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true) // Repartition with default partition num (5 in test env) specified. - checkBHJ(df.repartition(5, 'b), + checkBHJ(df.repartition(5, Symbol("b")), // The top shuffle from repartition is optimized out // The final plan must have 5 partitions, no optimization can be made to the probe side. optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false) // Repartition with non-default partition num specified. - checkBHJ(df.repartition(4, 'b), + checkBHJ(df.repartition(4, Symbol("b")), // The top shuffle from repartition is not optimized out optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) // Repartition by col and project away the partition cols - checkBHJ(df.repartition('b).select('key), + checkBHJ(df.repartition(Symbol("b")).select(Symbol("key")), // The top shuffle from repartition is not optimized out optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) } @@ -1649,23 +1716,23 @@ class AdaptiveQueryExecSuite SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0", SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { // Repartition with no partition num specified. - checkSMJ(df.repartition('b), + checkSMJ(df.repartition(Symbol("b")), // The top shuffle from repartition is optimized out. optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true) // Repartition with default partition num (5 in test env) specified. - checkSMJ(df.repartition(5, 'b), + checkSMJ(df.repartition(5, Symbol("b")), // The top shuffle from repartition is optimized out. // The final plan must have 5 partitions, can't do coalesced read. optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false) // Repartition with non-default partition num specified. - checkSMJ(df.repartition(4, 'b), + checkSMJ(df.repartition(4, Symbol("b")), // The top shuffle from repartition is not optimized out. optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) // Repartition by col and project away the partition cols - checkSMJ(df.repartition('b).select('key), + checkSMJ(df.repartition(Symbol("b")).select(Symbol("key")), // The top shuffle from repartition is not optimized out. 
optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) } @@ -2243,6 +2310,15 @@ class AdaptiveQueryExecSuite """.stripMargin) assert(findTopLevelLimit(origin2).size == 1) assert(findTopLevelLimit(adaptive2).isEmpty) + + // The strategy of Eliminate Limits batch should be fixedPoint + val (origin3, adaptive3) = runAdaptiveAndVerifyResult( + """ + |SELECT * FROM (SELECT c1 + c2 FROM (SELECT DISTINCT * FROM v LIMIT 10086)) LIMIT 20 + """.stripMargin + ) + assert(findTopLevelLimit(origin3).size == 1) + assert(findTopLevelLimit(adaptive3).isEmpty) } } } @@ -2409,6 +2485,96 @@ class AdaptiveQueryExecSuite } } } + + test("SPARK-37652: optimize skewed join through union") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 3 as key1", "id as value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 1 as key2", "id as value2") + .createOrReplaceTempView("skewData2") + + def checkSkewJoin(query: String, joinNums: Int, optimizeSkewJoinNums: Int): Unit = { + val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query) + val joins = findTopLevelSortMergeJoin(innerAdaptivePlan) + val optimizeSkewJoins = joins.filter(_.isSkewJoin) + assert(joins.size == joinNums && optimizeSkewJoins.size == optimizeSkewJoinNums) + } + + // skewJoin union skewJoin + checkSkewJoin( + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "UNION ALL SELECT key2 FROM skewData1 JOIN skewData2 ON key1 = key2", 2, 2) + + // skewJoin union aggregate + checkSkewJoin( + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "UNION ALL SELECT key2 FROM skewData2 GROUP BY key2", 1, 1) + + // skewJoin1 union (skewJoin2 join aggregate) + // skewJoin2 will lead to extra shuffles, but skew1 cannot be optimized + checkSkewJoin( + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 UNION ALL " + + "SELECT key1 from (SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2) tmp1 " + + "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2", 3, 0) + } + } + } + + test("SPARK-38162: Optimize one row plan in AQE Optimizer") { + withTempView("v") { + spark.sparkContext.parallelize( + (1 to 4).map(i => TestData(i, i.toString)), 2) + .toDF("c1", "c2").createOrReplaceTempView("v") + + // remove sort + val (origin1, adaptive1) = runAdaptiveAndVerifyResult( + """ + |SELECT * FROM v where c1 = 1 order by c1, c2 + |""".stripMargin) + assert(findTopLevelSort(origin1).size == 1) + assert(findTopLevelSort(adaptive1).isEmpty) + + // convert group only aggregate to project + val (origin2, adaptive2) = runAdaptiveAndVerifyResult( + """ + |SELECT distinct c1 FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1) + |""".stripMargin) + assert(findTopLevelAggregate(origin2).size == 2) + assert(findTopLevelAggregate(adaptive2).isEmpty) + + // remove distinct in aggregate + val (origin3, adaptive3) = runAdaptiveAndVerifyResult( + """ + |SELECT sum(distinct c1) FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1) + |""".stripMargin) + assert(findTopLevelAggregate(origin3).size == 4) + assert(findTopLevelAggregate(adaptive3).size == 2) + + // do not optimize if the aggregate is inside query stage + val (origin4, adaptive4) = runAdaptiveAndVerifyResult( + """ + |SELECT 
distinct c1 FROM v where c1 = 1 + |""".stripMargin) + assert(findTopLevelAggregate(origin4).size == 2) + assert(findTopLevelAggregate(adaptive4).size == 2) + + val (origin5, adaptive5) = runAdaptiveAndVerifyResult( + """ + |SELECT sum(distinct c1) FROM v where c1 = 1 + |""".stripMargin) + assert(findTopLevelAggregate(origin5).size == 4) + assert(findTopLevelAggregate(adaptive5).size == 4) + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index a5ac2d5aa70c9..e876e9d6ff20c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1377,7 +1377,7 @@ class ArrowConvertersSuite extends SharedSparkSession { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, ctx) val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) var count = 0 @@ -1398,7 +1398,7 @@ class ArrowConvertersSuite extends SharedSparkSession { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, ctx) // Write batches to Arrow stream format as a byte array val out = new ByteArrayOutputStream() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala index 361deb0d3e3b6..45d50b5e11a90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala @@ -16,6 +16,9 @@ */ package org.apache.spark.sql.execution.benchmark +import org.apache.parquet.column.ParquetProperties +import org.apache.parquet.hadoop.ParquetOutputFormat + import org.apache.spark.sql.internal.SQLConf /** @@ -53,7 +56,16 @@ object BuiltInDataSourceWriteBenchmark extends DataSourceWriteBenchmark { formats.foreach { format => runBenchmark(s"$format writer benchmark") { - runDataSourceBenchmark(format) + if (format.equals("Parquet")) { + ParquetProperties.WriterVersion.values().foreach { + writeVersion => + withSQLConf(ParquetOutputFormat.WRITER_VERSION -> writeVersion.toString) { + runDataSourceBenchmark("Parquet", Some(writeVersion.toString)) + } + } + } else { + runDataSourceBenchmark(format) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index 31cee48c1787d..7c9fa58d77f42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -24,11 +24,11 @@ import scala.util.Random import org.apache.parquet.column.ParquetProperties import 
org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, TestUtils} import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader} +import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnVector @@ -78,7 +78,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv") saveAsJsonTable(testDf, dir.getCanonicalPath + "/json") - saveAsParquetTable(testDf, dir.getCanonicalPath + "/parquet") + saveAsParquetV1Table(testDf, dir.getCanonicalPath + "/parquetV1") saveAsParquetV2Table(testDf, dir.getCanonicalPath + "/parquetV2") saveAsOrcTable(testDf, dir.getCanonicalPath + "/orc") } @@ -93,9 +93,9 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { spark.read.json(dir).createOrReplaceTempView("jsonTable") } - private def saveAsParquetTable(df: DataFrameWriter[Row], dir: String): Unit = { + private def saveAsParquetV1Table(df: DataFrameWriter[Row], dir: String): Unit = { df.mode("overwrite").option("compression", "snappy").parquet(dir) - spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + spark.read.parquet(dir).createOrReplaceTempView("parquetV1Table") } private def saveAsParquetV2Table(df: DataFrameWriter[Row], dir: String): Unit = { @@ -111,6 +111,8 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { spark.read.orc(dir).createOrReplaceTempView("orcTable") } + private def withParquetVersions(f: String => Unit): Unit = Seq("V1", "V2").foreach(f) + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { // Benchmarks running through spark sql. val sqlBenchmark = new Benchmark( @@ -125,7 +127,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { output = output) withTempPath { dir => - withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { import spark.implicits._ spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") @@ -144,13 +146,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { spark.sql(s"select $query from jsonTable").noop() } - sqlBenchmark.addCase("SQL Parquet Vectorized") { _ => - spark.sql(s"select $query from parquetTable").noop() + withParquetVersions { version => + sqlBenchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ => + spark.sql(s"select $query from parquet${version}Table").noop() + } } - sqlBenchmark.addCase("SQL Parquet MR") { _ => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql(s"select $query from parquetTable").noop() + withParquetVersions { version => + sqlBenchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"select $query from parquet${version}Table").noop() + } } } @@ -166,79 +172,93 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { sqlBenchmark.run() - // Driving the parquet reader in batch mode directly. 
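Both benchmarks now exercise the two Parquet data page layouts. The writer version is selected through the Hadoop key parquet.writer.version (ParquetOutputFormat.WRITER_VERSION), which the benchmarks set via withSQLConf so that it reaches the Hadoop configuration used by the Parquet writer. A rough sketch of writing the same data with both layouts follows; the output paths are made-up placeholders.

import org.apache.parquet.column.ParquetProperties
import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.spark.sql.SparkSession

object ParquetPageVersionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("parquet-page-version")
      .getOrCreate()
    val df = spark.range(1000).toDF("id")

    // WriterVersion.values() yields PARQUET_1_0 and PARQUET_2_0. The session
    // conf entry is copied into the Hadoop conf used for the write.
    ParquetProperties.WriterVersion.values().foreach { version =>
      spark.conf.set(ParquetOutputFormat.WRITER_VERSION, version.toString)
      // Hypothetical output location, one directory per data page version.
      df.write.mode("overwrite").parquet(s"/tmp/parquet-page-version-demo/$version")
    }

    spark.stop()
  }
}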
- val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - parquetReaderBenchmark.addCase("ParquetReader Vectorized") { _ => - var longSum = 0L - var doubleSum = 0.0 - val aggregateValue: (ColumnVector, Int) => Unit = dataType match { - case BooleanType => (col: ColumnVector, i: Int) => if (col.getBoolean(i)) longSum += 1L - case ByteType => (col: ColumnVector, i: Int) => longSum += col.getByte(i) - case ShortType => (col: ColumnVector, i: Int) => longSum += col.getShort(i) - case IntegerType => (col: ColumnVector, i: Int) => longSum += col.getInt(i) - case LongType => (col: ColumnVector, i: Int) => longSum += col.getLong(i) - case FloatType => (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i) - case DoubleType => (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i) - } + withParquetVersions { version => + // Driving the parquet reader in batch mode directly. + val files = TestUtils.listDirectory(new File(dir, s"parquet$version")) + parquetReaderBenchmark.addCase(s"ParquetReader Vectorized: DataPage$version") { _ => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (ColumnVector, Int) => Unit = dataType match { + case BooleanType => + (col: ColumnVector, i: Int) => if (col.getBoolean(i)) longSum += 1L + case ByteType => + (col: ColumnVector, i: Int) => longSum += col.getByte(i) + case ShortType => + (col: ColumnVector, i: Int) => longSum += col.getShort(i) + case IntegerType => + (col: ColumnVector, i: Int) => longSum += col.getInt(i) + case LongType => + (col: ColumnVector, i: Int) => longSum += col.getLong(i) + case FloatType => + (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i) + case DoubleType => + (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i) + } - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - val col = batch.column(0) - while (reader.nextBatch()) { - val numRows = batch.numRows() - var i = 0 - while (i < numRows) { - if (!col.isNullAt(i)) aggregateValue(col, i) - i += 1 + files.foreach { p => + val reader = new VectorizedParquetRecordReader( + enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + val col = batch.column(0) + while (reader.nextBatch()) { + val numRows = batch.numRows() + var i = 0 + while (i < numRows) { + if (!col.isNullAt(i)) aggregateValue(col, i) + i += 1 + } } + } finally { + reader.close() } - } finally { - reader.close() } } } - // Decoding in vectorized but having the reader return rows. 
- parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => - var longSum = 0L - var doubleSum = 0.0 - val aggregateValue: (InternalRow) => Unit = dataType match { - case BooleanType => (col: InternalRow) => if (col.getBoolean(0)) longSum += 1L - case ByteType => (col: InternalRow) => longSum += col.getByte(0) - case ShortType => (col: InternalRow) => longSum += col.getShort(0) - case IntegerType => (col: InternalRow) => longSum += col.getInt(0) - case LongType => (col: InternalRow) => longSum += col.getLong(0) - case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0) - case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0) - } + withParquetVersions { version => + // Driving the parquet reader in batch mode directly. + val files = TestUtils.listDirectory(new File(dir, s"parquet$version")) + // Decoding in vectorized but having the reader return rows. + parquetReaderBenchmark + .addCase(s"ParquetReader Vectorized -> Row: DataPage$version") { _ => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (InternalRow) => Unit = dataType match { + case BooleanType => (col: InternalRow) => if (col.getBoolean(0)) longSum += 1L + case ByteType => (col: InternalRow) => longSum += col.getByte(0) + case ShortType => (col: InternalRow) => longSum += col.getShort(0) + case IntegerType => (col: InternalRow) => longSum += col.getInt(0) + case LongType => (col: InternalRow) => longSum += col.getLong(0) + case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0) + case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0) + } - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val it = batch.rowIterator() - while (it.hasNext) { - val record = it.next() - if (!record.isNullAt(0)) aggregateValue(record) + files.foreach { p => + val reader = new VectorizedParquetRecordReader( + enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val it = batch.rowIterator() + while (it.hasNext) { + val record = it.next() + if (!record.isNullAt(0)) aggregateValue(record) + } + } + } finally { + reader.close() } } - } finally { - reader.close() } - } } - - parquetReaderBenchmark.run() } + + parquetReaderBenchmark.run() } } @@ -246,7 +266,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { val benchmark = new Benchmark("Int and String Scan", values, output = output) withTempPath { dir => - withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { import spark.implicits._ spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") @@ -262,13 +282,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { spark.sql("select sum(c1), sum(length(c2)) from jsonTable").noop() } - benchmark.addCase("SQL Parquet Vectorized") { _ => - spark.sql("select sum(c1), sum(length(c2)) from parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ => + spark.sql(s"select sum(c1), sum(length(c2)) from parquet${version}Table").noop() + } } - benchmark.addCase("SQL Parquet MR") { _ => - 
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(c1), sum(length(c2)) from parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"select sum(c1), sum(length(c2)) from parquet${version}Table").noop() + } } } @@ -291,7 +315,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { val benchmark = new Benchmark("Repeated String", values, output = output) withTempPath { dir => - withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { import spark.implicits._ spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") @@ -307,13 +331,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { spark.sql("select sum(length(c1)) from jsonTable").noop() } - benchmark.addCase("SQL Parquet Vectorized") { _ => - spark.sql("select sum(length(c1)) from parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ => + spark.sql(s"select sum(length(c1)) from parquet${version}Table").noop() + } } - benchmark.addCase("SQL Parquet MR") { _ => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(length(c1)) from parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"select sum(length(c1)) from parquet${version}Table").noop() + } } } @@ -336,7 +364,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { val benchmark = new Benchmark("Partitioned Table", values, output = output) withTempPath { dir => - withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { import spark.implicits._ spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") @@ -350,13 +378,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { spark.sql("select sum(id) from jsonTable").noop() } - benchmark.addCase("Data column - Parquet Vectorized") { _ => - spark.sql("select sum(id) from parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"Data column - Parquet Vectorized: DataPage$version") { _ => + spark.sql(s"select sum(id) from parquet${version}Table").noop() + } } - benchmark.addCase("Data column - Parquet MR") { _ => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(id) from parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"Data column - Parquet MR: DataPage$version") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"select sum(id) from parquet${version}Table").noop() + } } } @@ -378,13 +410,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { spark.sql("select sum(p) from jsonTable").noop() } - benchmark.addCase("Partition column - Parquet Vectorized") { _ => - spark.sql("select sum(p) from parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"Partition column - Parquet Vectorized: DataPage$version") { _ => + spark.sql(s"select sum(p) from parquet${version}Table").noop() + } } - benchmark.addCase("Partition 
column - Parquet MR") { _ => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(p) from parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"Partition column - Parquet MR: DataPage$version") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"select sum(p) from parquet${version}Table").noop() + } } } @@ -406,13 +442,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { spark.sql("select sum(p), sum(id) from jsonTable").noop() } - benchmark.addCase("Both columns - Parquet Vectorized") { _ => - spark.sql("select sum(p), sum(id) from parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"Both columns - Parquet Vectorized: DataPage$version") { _ => + spark.sql(s"select sum(p), sum(id) from parquet${version}Table").noop() + } } - benchmark.addCase("Both columns - Parquet MR") { _ => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(p), sum(id) from parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"Both columns - Parquet MR: DataPage$version") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"select sum(p), sum(id) from parquet${version}Table").noop() + } } } @@ -437,7 +477,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output) withTempPath { dir => - withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { spark.range(values).createOrReplaceTempView("t1") prepareTable( @@ -456,39 +496,45 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { "not NULL and c2 is not NULL").noop() } - benchmark.addCase("SQL Parquet Vectorized") { _ => - spark.sql("select sum(length(c2)) from parquetTable where c1 is " + - "not NULL and c2 is not NULL").noop() + withParquetVersions { version => + benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ => + spark.sql(s"select sum(length(c2)) from parquet${version}Table where c1 is " + + "not NULL and c2 is not NULL").noop() + } } - benchmark.addCase("SQL Parquet MR") { _ => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(length(c2)) from parquetTable where c1 is " + - "not NULL and c2 is not NULL").noop() + withParquetVersions { version => + benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"select sum(length(c2)) from parquet${version}Table where c1 is " + + "not NULL and c2 is not NULL").noop() + } } } - val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - benchmark.addCase("ParquetReader Vectorized") { num => - var sum = 0 - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val rowIterator = batch.rowIterator() - while (rowIterator.hasNext) { - val row = rowIterator.next() - val 
value = row.getUTF8String(0) - if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() + withParquetVersions { version => + val files = TestUtils.listDirectory(new File(dir, s"parquet$version")) + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + benchmark.addCase(s"ParquetReader Vectorized: DataPage$version") { _ => + var sum = 0 + files.foreach { p => + val reader = new VectorizedParquetRecordReader( + enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val rowIterator = batch.rowIterator() + while (rowIterator.hasNext) { + val row = rowIterator.next() + val value = row.getUTF8String(0) + if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() + } } + } finally { + reader.close() } - } finally { - reader.close() } } } @@ -517,7 +563,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { output = output) withTempPath { dir => - withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { import spark.implicits._ val middle = width / 2 val selectExpr = (1 to width).map(i => s"value as c$i") @@ -534,13 +580,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { spark.sql(s"SELECT sum(c$middle) FROM jsonTable").noop() } - benchmark.addCase("SQL Parquet Vectorized") { _ => - spark.sql(s"SELECT sum(c$middle) FROM parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ => + spark.sql(s"SELECT sum(c$middle) FROM parquet${version}Table").noop() + } } - benchmark.addCase("SQL Parquet MR") { _ => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql(s"SELECT sum(c$middle) FROM parquetTable").noop() + withParquetVersions { version => + benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM parquet${version}Table").noop() + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala index 405d60794ede0..77e26048e0425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala @@ -66,7 +66,7 @@ trait DataSourceWriteBenchmark extends SqlBasedBenchmark { } } - def runDataSourceBenchmark(format: String): Unit = { + def runDataSourceBenchmark(format: String, extraInfo: Option[String] = None): Unit = { val tableInt = "tableInt" val tableDouble = "tableDouble" val tableIntString = "tableIntString" @@ -75,7 +75,12 @@ trait DataSourceWriteBenchmark extends SqlBasedBenchmark { withTempTable(tempTable) { spark.range(numRows).createOrReplaceTempView(tempTable) withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { - val benchmark = new Benchmark(s"$format writer benchmark", numRows, output = output) + val writerName = extraInfo match { + case Some(extra) => s"$format($extra)" + case _ => format + } + val benchmark = + new Benchmark(s"$writerName writer benchmark", 
numRows, output = output) writeNumeric(tableInt, format, benchmark, "Int") writeNumeric(tableDouble, format, benchmark, "Double") writeIntString(tableIntString, format, benchmark) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala index 849c41307245e..787fdc7b59d67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ -44,7 +44,7 @@ object JoinBenchmark extends SqlBasedBenchmark { val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v")) codegenBenchmark("Join w long", N) { val df = spark.range(N).join(dim, (col("id") % M) === col("k")) - assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) df.noop() } } @@ -55,7 +55,7 @@ object JoinBenchmark extends SqlBasedBenchmark { val dim = broadcast(spark.range(M).selectExpr("cast(id/10 as long) as k")) codegenBenchmark("Join w long duplicated", N) { val df = spark.range(N).join(dim, (col("id") % M) === col("k")) - assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) df.noop() } } @@ -70,7 +70,7 @@ object JoinBenchmark extends SqlBasedBenchmark { val df = spark.range(N).join(dim2, (col("id") % M).cast(IntegerType) === col("k1") && (col("id") % M).cast(IntegerType) === col("k2")) - assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) df.noop() } } @@ -84,7 +84,7 @@ object JoinBenchmark extends SqlBasedBenchmark { codegenBenchmark("Join w 2 longs", N) { val df = spark.range(N).join(dim3, (col("id") % M) === col("k1") && (col("id") % M) === col("k2")) - assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) df.noop() } } @@ -98,7 +98,7 @@ object JoinBenchmark extends SqlBasedBenchmark { codegenBenchmark("Join w 2 longs duplicated", N) { val df = spark.range(N).join(dim4, (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) - assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) df.noop() } } @@ -109,7 +109,7 @@ object JoinBenchmark extends SqlBasedBenchmark { val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v")) codegenBenchmark("outer join w long", N) { val df = spark.range(N).join(dim, (col("id") % M) === col("k"), "left") - assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) df.noop() } } @@ -120,7 +120,7 @@ object JoinBenchmark extends SqlBasedBenchmark { val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v")) codegenBenchmark("semi join w long", N) { val df = spark.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi") - assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + 
assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) df.noop() } } @@ -131,7 +131,7 @@ object JoinBenchmark extends SqlBasedBenchmark { val df1 = spark.range(N).selectExpr(s"id * 2 as k1") val df2 = spark.range(N).selectExpr(s"id * 3 as k2") val df = df1.join(df2, col("k1") === col("k2")) - assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[SortMergeJoinExec])) df.noop() } } @@ -144,7 +144,7 @@ object JoinBenchmark extends SqlBasedBenchmark { val df2 = spark.range(N) .selectExpr(s"(id * 15485867) % ${N*10} as k2") val df = df1.join(df2, col("k1") === col("k2")) - assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[SortMergeJoinExec])) df.noop() } } @@ -159,7 +159,7 @@ object JoinBenchmark extends SqlBasedBenchmark { val df1 = spark.range(N).selectExpr(s"id as k1") val df2 = spark.range(N / 3).selectExpr(s"id * 3 as k2") val df = df1.join(df2, col("k1") === col("k2")) - assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[ShuffledHashJoinExec]).isDefined) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[ShuffledHashJoinExec])) df.noop() } } @@ -172,8 +172,7 @@ object JoinBenchmark extends SqlBasedBenchmark { val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v")) codegenBenchmark("broadcast nested loop join", N) { val df = spark.range(N).join(dim) - assert(df.queryExecution.sparkPlan.find( - _.isInstanceOf[BroadcastNestedLoopJoinExec]).isDefined) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastNestedLoopJoinExec])) df.noop() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala index e9bdff5853a51..31d5fd9ffdffe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala @@ -49,7 +49,7 @@ object RangeBenchmark extends SqlBasedBenchmark { } benchmark.addCase("filter after range", numIters = 4) { _ => - spark.range(N).filter('id % 100 === 0).noop() + spark.range(N).filter(Symbol("id") % 100 === 0).noop() } benchmark.addCase("count after range", numIters = 4) { _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 2cf12dd92f64c..120ddf469f4a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -152,7 +152,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession { } test("projection") { - val logicalPlan = testData.select('value, 'key).logicalPlan + val logicalPlan = testData.select(Symbol("value"), Symbol("key")).logicalPlan val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan val scan = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5), MEMORY_ONLY, plan, None, logicalPlan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala 
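The JoinBenchmark assertions above switch from plan.find(...).isDefined to plan.exists(...), which expresses the same check without materialising an intermediate Option; TreeNode offers exists directly, so the rewrite applies to plans as well as plain collections. A tiny self-contained illustration (the sample data is made up):

object ExistsVsFindDemo {
  def main(args: Array[String]): Unit = {
    val nodes: Seq[Any] = Seq(1, "a", 2.5)
    // Both expressions answer "is there at least one String?"; exists is the
    // direct form and skips the Option allocated by find.
    val viaFind   = nodes.find(_.isInstanceOf[String]).isDefined
    val viaExists = nodes.exists(_.isInstanceOf[String])
    assert(viaFind == viaExists)
  }
}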
index 6370939cef6a2..1803ec046930b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala @@ -126,4 +126,14 @@ trait AlterTableRenameSuiteBase extends QueryTest with DDLCommandTestUtils { spark.sessionState.catalogManager.reset() } } + + test("SPARK-37963: preserve partition info") { + withNamespaceAndTable("ns", "dst_tbl") { dst => + val src = dst.replace("dst", "src") + sql(s"CREATE TABLE $src (i int, j int) $defaultUsing partitioned by (j)") + sql(s"insert into table $src partition(j=2) values (1)") + sql(s"ALTER TABLE $src RENAME TO ns.dst_tbl") + checkAnswer(spark.table(dst), Row(1, 2)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala index c3cb96814a506..69a208b942429 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala @@ -107,5 +107,5 @@ class CreateNamespaceParserSuite extends AnalysisTest { } private def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parsePlan)(sqlCommand, messages: _*) + interceptParseException(parsePlan)(sqlCommand, messages: _*)() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 4d24b262fa03a..1053cb9f2a772 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -46,7 +46,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { } private def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parser.parsePlan)(sqlCommand, messages: _*) + interceptParseException(parser.parsePlan)(sqlCommand, messages: _*)() private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = { val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null) @@ -288,12 +288,12 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val s = ScriptTransformation("func", Seq.empty, p, null) compareTransformQuery("select transform(a, b) using 'func' from e where f < 10", - s.copy(child = p.copy(child = p.child.where('f < 10)), - output = Seq('key.string, 'value.string))) + s.copy(child = p.copy(child = p.child.where(Symbol("f") < 10)), + output = Seq(Symbol("key").string, Symbol("value").string))) compareTransformQuery("map a, b using 'func' as c, d from e", - s.copy(output = Seq('c.string, 'd.string))) + s.copy(output = Seq(Symbol("c").string, Symbol("d").string))) compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e", - s.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + s.copy(output = Seq(Symbol("c").int, Symbol("d").decimal(10, 0)))) } test("use backticks in output of Script Transform") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 00d1ed2cbc680..c3d1126dc07f5 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -115,7 +115,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { }.getMessage assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) - spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1") + spark.range(1).select('id as Symbol("a"), 'id as Symbol("b")).write.saveAsTable("t1") e = intercept[AnalysisException] { sql("CREATE TABLE t STORED AS parquet SELECT a, b from t1") }.getMessage @@ -1374,7 +1374,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("CREATE TABLE t USING parquet SELECT 1 as a, 1 as b") checkAnswer(spark.table("t"), Row(1, 1) :: Nil) - spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1") + spark.range(1).select('id as Symbol("a"), 'id as Symbol("b")).write.saveAsTable("t1") sql("CREATE TABLE t2 USING parquet SELECT a, b from t1") checkAnswer(spark.table("t2"), spark.table("t1")) } @@ -2103,57 +2103,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("create temporary function with if not exists") { - withUserDefinedFunction("func1" -> true) { - val sql1 = - """ - |CREATE TEMPORARY FUNCTION IF NOT EXISTS func1 as - |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', - |JAR '/path/to/jar2' - """.stripMargin - val e = intercept[AnalysisException] { - sql(sql1) - }.getMessage - assert(e.contains("It is not allowed to define a TEMPORARY function with IF NOT EXISTS")) - } - } - - test("create function with both if not exists and replace") { - withUserDefinedFunction("func1" -> false) { - val sql1 = - """ - |CREATE OR REPLACE FUNCTION IF NOT EXISTS func1 as - |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', - |JAR '/path/to/jar2' - """.stripMargin - val e = intercept[AnalysisException] { - sql(sql1) - }.getMessage - assert(e.contains("CREATE FUNCTION with both IF NOT EXISTS and REPLACE is not allowed")) - } - } - - test("create temporary function by specifying a database") { - val dbName = "mydb" - withDatabase(dbName) { - sql(s"CREATE DATABASE $dbName") - sql(s"USE $dbName") - withUserDefinedFunction("func1" -> true) { - val sql1 = - s""" - |CREATE TEMPORARY FUNCTION $dbName.func1 as - |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', - |JAR '/path/to/jar2' - """.stripMargin - val e = intercept[AnalysisException] { - sql(sql1) - }.getMessage - assert(e.contains(s"Specifying a database in CREATE TEMPORARY FUNCTION " + - s"is not allowed: '$dbName'")) - } - } - } - Seq(true, false).foreach { caseSensitive => test(s"alter table add columns with existing column name - caseSensitive $caseSensitive") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 5862acff70ab1..5399f9674377a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -249,9 +249,7 @@ class PlanResolutionSuite extends AnalysisTest { } test("create table - partitioned by transforms") { - val transforms = Seq( - "bucket(16, b)", "years(ts)", "months(ts)", "days(ts)", "hours(ts)", "foo(a, 'bar', 34)", - 
"bucket(32, b), days(ts)") + val transforms = Seq("years(ts)", "months(ts)", "days(ts)", "hours(ts)", "foo(a, 'bar', 34)") transforms.foreach { transform => val query = s""" @@ -259,12 +257,30 @@ class PlanResolutionSuite extends AnalysisTest { |PARTITIONED BY ($transform) """.stripMargin - val ae = intercept[AnalysisException] { + val ae = intercept[UnsupportedOperationException] { parseAndResolve(query) } - assert(ae.message - .contains(s"Transforms cannot be converted to partition columns: $transform")) + assert(ae.getMessage + .contains(s"Unsupported partition transform: $transform")) + } + } + + test("create table - partitioned by multiple bucket transforms") { + val transforms = Seq("bucket(32, b), sorted_bucket(b, 32, c)") + transforms.foreach { transform => + val query = + s""" + |CREATE TABLE my_tab(a INT, b STRING, c String) USING parquet + |PARTITIONED BY ($transform) + """.stripMargin + + val ae = intercept[UnsupportedOperationException] { + parseAndResolve(query) + } + + assert(ae.getMessage + .contains("Multiple bucket transforms are not supported.")) } } @@ -1734,7 +1750,7 @@ class PlanResolutionSuite extends AnalysisTest { interceptParseException(parsePlan)( "CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)", - "extraneous input ':'") + "extraneous input ':'")() } test("create hive table - table file format") { @@ -1859,7 +1875,7 @@ class PlanResolutionSuite extends AnalysisTest { test("Duplicate clauses - create hive table") { def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parsePlan)(sqlCommand, messages: _*) + interceptParseException(parsePlan)(sqlCommand, messages: _*)() def createTableHeader(duplicateClause: String): String = { s"CREATE TABLE my_tab(a INT, b STRING) STORED AS parquet $duplicateClause $duplicateClause" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowCreateTableParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowCreateTableParserSuite.scala new file mode 100644 index 0000000000000..ab7c6e4dec568 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowCreateTableParserSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedTableOrView} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.plans.logical.ShowCreateTable + +class ShowCreateTableParserSuite extends AnalysisTest { + test("show create table") { + comparePlans( + parsePlan("SHOW CREATE TABLE a.b.c"), + ShowCreateTable( + UnresolvedTableOrView(Seq("a", "b", "c"), "SHOW CREATE TABLE", allowTempView = false))) + + comparePlans( + parsePlan("SHOW CREATE TABLE a.b.c AS SERDE"), + ShowCreateTable( + UnresolvedTableOrView(Seq("a", "b", "c"), "SHOW CREATE TABLE", allowTempView = false), + asSerde = true)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowCreateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowCreateTableSuiteBase.scala new file mode 100644 index 0000000000000..7bc076561f448 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowCreateTableSuiteBase.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.sources.SimpleInsertSource +import org.apache.spark.util.Utils + +/** + * This base suite contains unified tests for the `SHOW CREATE TABLE` command that check V1 and V2 + * table catalogs. 
The tests that cannot run for all supported catalogs are located in more + * specific test suites: + * + * - V2 table catalog tests: `org.apache.spark.sql.execution.command.v2.ShowCreateTableSuite` + * - V1 table catalog tests: + * `org.apache.spark.sql.execution.command.v1.ShowCreateTableSuiteBase` + * - V1 In-Memory catalog: `org.apache.spark.sql.execution.command.v1.ShowCreateTableSuite` + * - V1 Hive External catalog: +* `org.apache.spark.sql.hive.execution.command.ShowCreateTableSuite` + */ +trait ShowCreateTableSuiteBase extends QueryTest with DDLCommandTestUtils { + override val command = "SHOW CREATE TABLE" + protected def ns: String = "ns1" + protected def table: String = "tbl" + protected def fullName: String + + test("SPARK-36012: add null flag when show create table") { + withNamespaceAndTable(ns, table) { t => + sql( + s""" + |CREATE TABLE $t ( + | a bigint NOT NULL, + | b bigint + |) + |USING ${classOf[SimpleInsertSource].getName} + """.stripMargin) + val showDDL = getShowCreateDDL(t) + assert(showDDL(0) == s"CREATE TABLE $fullName (") + assert(showDDL(1) == "a BIGINT NOT NULL,") + assert(showDDL(2) == "b BIGINT)") + assert(showDDL(3) == s"USING ${classOf[SimpleInsertSource].getName}") + } + } + + test("data source table with user specified schema") { + withNamespaceAndTable(ns, table) { t => + val jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile + sql( + s"""CREATE TABLE $t ( + | a STRING, + | b STRING, + | `extra col` ARRAY, + | `` STRUCT> + |) + |USING json + |OPTIONS ( + | PATH '$jsonFilePath' + |) + """.stripMargin + ) + val showDDL = getShowCreateDDL(t) + assert(showDDL(0) == s"CREATE TABLE $fullName (") + assert(showDDL(1) == "a STRING,") + assert(showDDL(2) == "b STRING,") + assert(showDDL(3) == "`extra col` ARRAY,") + assert(showDDL(4) == "`` STRUCT>)") + assert(showDDL(5) == "USING json") + assert(showDDL(6).startsWith("LOCATION 'file:") && showDDL(6).endsWith("sample.json'")) + } + } + + test("SPARK-24911: keep quotes for nested fields") { + withNamespaceAndTable(ns, table) { t => + sql( + s""" + |CREATE TABLE $t ( + | `a` STRUCT<`b`: STRING> + |) + |USING json + """.stripMargin) + val showDDL = getShowCreateDDL(t) + assert(showDDL(0) == s"CREATE TABLE $fullName (") + assert(showDDL(1) == "a STRUCT)") + assert(showDDL(2) == "USING json") + } + } + + test("SPARK-37494: Unify v1 and v2 option output") { + withNamespaceAndTable(ns, table) { t => + sql( + s"""CREATE TABLE $t ( + | a STRING + |) + |USING json + |TBLPROPERTIES ( + | 'b' = '1', + | 'a' = '2') + |OPTIONS ( + | k4 'v4', + | `k3` 'v3', + | 'k5' 'v5', + | 'k1' = 'v1', + | k2 = 'v2' + |) + """.stripMargin + ) + val expected = s"CREATE TABLE $fullName ( a STRING) USING json" + + " OPTIONS ( 'k1' = 'v1', 'k2' = 'v2', 'k3' = 'v3', 'k4' = 'v4', 'k5' = 'v5')" + + " TBLPROPERTIES ( 'a' = '2', 'b' = '1')" + assert(getShowCreateDDL(t).mkString(" ") == expected) + } + } + + test("data source table CTAS") { + withNamespaceAndTable(ns, table) { t => + sql( + s"""CREATE TABLE $t + |USING json + |AS SELECT 1 AS a, "foo" AS b + """.stripMargin + ) + val expected = s"CREATE TABLE $fullName ( a INT, b STRING) USING json" + assert(getShowCreateDDL(t).mkString(" ") == expected) + } + } + + test("partitioned data source table") { + withNamespaceAndTable(ns, table) { t => + sql( + s"""CREATE TABLE $t + |USING json + |PARTITIONED BY (b) + |AS SELECT 1 AS a, "foo" AS b + """.stripMargin + ) + val expected = s"CREATE TABLE $fullName ( a INT, b STRING) USING json PARTITIONED BY (b)" + 
assert(getShowCreateDDL(t).mkString(" ") == expected) + } + } + + test("data source table with a comment") { + withNamespaceAndTable(ns, table) { t => + sql( + s"""CREATE TABLE $t + |USING json + |COMMENT 'This is a comment' + |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c + """.stripMargin + ) + val expected = s"CREATE TABLE $fullName ( a INT, b STRING, c DECIMAL(2,1)) USING json" + + s" COMMENT 'This is a comment'" + assert(getShowCreateDDL(t).mkString(" ") == expected) + } + } + + test("data source table with table properties") { + withNamespaceAndTable(ns, table) { t => + sql( + s"""CREATE TABLE $t + |USING json + |TBLPROPERTIES ('a' = '1') + |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c + """.stripMargin + ) + val expected = s"CREATE TABLE $fullName ( a INT, b STRING, c DECIMAL(2,1)) USING json" + + s" TBLPROPERTIES ( 'a' = '1')" + assert(getShowCreateDDL(t).mkString(" ") == expected) + } + } + + def getShowCreateDDL(table: String, serde: Boolean = false): Array[String] = { + val result = if (serde) { + sql(s"SHOW CREATE TABLE $table AS SERDE") + } else { + sql(s"SHOW CREATE TABLE $table") + } + result.head().getString(0).split("\n").map(_.trim) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DropNamespaceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DropNamespaceSuite.scala index 24e51317575d3..174ac970be6bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DropNamespaceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DropNamespaceSuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.command * - V1 In-Memory catalog: `org.apache.spark.sql.execution.command.v1.DropNamespaceSuite` * - V1 Hive External catalog: `org.apache.spark.sql.hive.execution.command.DropNamespaceSuite` */ -trait DropNamespaceSuiteBase extends command.DropNamespaceSuiteBase { +trait DropNamespaceSuiteBase extends command.DropNamespaceSuiteBase + with command.TestsV1AndV2Commands { override protected def builtinTopNamespaces: Seq[String] = Seq("default") override protected def namespaceAlias(): String = "database" @@ -41,4 +42,6 @@ trait DropNamespaceSuiteBase extends command.DropNamespaceSuiteBase { } } -class DropNamespaceSuite extends DropNamespaceSuiteBase with CommandSuiteBase +class DropNamespaceSuite extends DropNamespaceSuiteBase with CommandSuiteBase { + override def commandVersion: String = super[DropNamespaceSuiteBase].commandVersion +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala new file mode 100644 index 0000000000000..ee8aa424d5c26 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v1 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.command + +/** + * This base suite contains unified tests for the `SHOW CREATE TABLE` command that checks V1 + * table catalogs. The tests that cannot run for all V1 catalogs are located in more + * specific test suites: + * + * - V1 In-Memory catalog: `org.apache.spark.sql.execution.command.v1.ShowCreateTableSuite` + * - V1 Hive External catalog: + * `org.apache.spark.sql.hive.execution.command.ShowCreateTableSuite` + */ +trait ShowCreateTableSuiteBase extends command.ShowCreateTableSuiteBase + with command.TestsV1AndV2Commands { + override def fullName: String = s"$ns.$table" + + test("show create table[simple]") { + // todo After SPARK-37517 unify the testcase both v1 and v2 + withNamespaceAndTable(ns, table) { t => + sql( + s""" + |CREATE TABLE $t ( + | a bigint NOT NULL, + | b bigint, + | c bigint, + | `extraCol` ARRAY, + | `` STRUCT> + |) + |using parquet + |OPTIONS ( + | from = 0, + | to = 1, + | via = 2) + |COMMENT 'This is a comment' + |TBLPROPERTIES ('prop1' = '1', 'prop2' = '2', 'prop3' = 3, 'prop4' = 4) + |PARTITIONED BY (a) + |LOCATION 'file:/tmp' + """.stripMargin) + val showDDL = getShowCreateDDL(t) + assert(showDDL === Array( + s"CREATE TABLE $fullName (", + "b BIGINT,", + "c BIGINT,", + "extraCol ARRAY,", + "`` STRUCT>,", + "a BIGINT NOT NULL)", + "USING parquet", + "OPTIONS (", + "'from' = '0',", + "'to' = '1',", + "'via' = '2')", + "PARTITIONED BY (a)", + "COMMENT 'This is a comment'", + "LOCATION 'file:/tmp'", + "TBLPROPERTIES (", + "'prop1' = '1',", + "'prop2' = '2',", + "'prop3' = '3',", + "'prop4' = '4')" + )) + } + } + + test("bucketed data source table") { + withNamespaceAndTable(ns, table) { t => + sql( + s"""CREATE TABLE $t + |USING json + |CLUSTERED BY (a) INTO 2 BUCKETS + |AS SELECT 1 AS a, "foo" AS b + """.stripMargin + ) + val expected = s"CREATE TABLE $fullName ( a INT, b STRING) USING json" + + s" CLUSTERED BY (a) INTO 2 BUCKETS" + assert(getShowCreateDDL(t).mkString(" ") == expected) + } + } + + test("sort bucketed data source table") { + withNamespaceAndTable(ns, table) { t => + sql( + s"""CREATE TABLE $t + |USING json + |CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS + |AS SELECT 1 AS a, "foo" AS b + """.stripMargin + ) + val expected = s"CREATE TABLE $fullName ( a INT, b STRING) USING json" + + s" CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS" + assert(getShowCreateDDL(t).mkString(" ") == expected) + } + } + + test("partitioned bucketed data source table") { + withNamespaceAndTable(ns, table) { t => + sql( + s"""CREATE TABLE $t + |USING json + |PARTITIONED BY (c) + |CLUSTERED BY (a) INTO 2 BUCKETS + |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c + """.stripMargin + ) + val expected = s"CREATE TABLE $fullName ( a INT, b STRING, c DECIMAL(2,1)) USING json" + + s" PARTITIONED BY (c) CLUSTERED BY (a) INTO 2 BUCKETS" + assert(getShowCreateDDL(t).mkString(" ") == expected) + } + } + + test("partitioned sort bucketed data source table") { + withNamespaceAndTable(ns, table) { t => + sql( + s"""CREATE TABLE $t + 
|USING json + |PARTITIONED BY (c) + |CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS + |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c + """.stripMargin + ) + val expected = s"CREATE TABLE $fullName ( a INT, b STRING, c DECIMAL(2,1)) USING json" + + s" PARTITIONED BY (c) CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS" + assert(getShowCreateDDL(t).mkString(" ") == expected) + } + } + + test("show create table as serde can't work on data source table") { + withNamespaceAndTable(ns, table) { t => + sql( + s""" + |CREATE TABLE $t ( + | c1 STRING COMMENT 'bla', + | c2 STRING + |) + |USING orc + """.stripMargin + ) + + val cause = intercept[AnalysisException] { + getShowCreateDDL(t, true) + } + + assert(cause.getMessage.contains("Use `SHOW CREATE TABLE` without `AS SERDE` instead")) + } + } +} + +/** + * The class contains tests for the `SHOW CREATE TABLE` command to check V1 In-Memory + * table catalog. + */ +class ShowCreateTableSuite extends ShowCreateTableSuiteBase with CommandSuiteBase { + override def commandVersion: String = super[ShowCreateTableSuiteBase].commandVersion +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowCreateTableSuite.scala new file mode 100644 index 0000000000000..7c506812079ec --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowCreateTableSuite.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.command + +/** + * The class contains tests for the `SHOW CREATE TABLE` command to check V2 table catalogs. 
+ */ +class ShowCreateTableSuite extends command.ShowCreateTableSuiteBase with CommandSuiteBase { + override def fullName: String = s"$catalog.$ns.$table" + + test("SPARK-33898: show create table as serde") { + withNamespaceAndTable(ns, table) { t => + spark.sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + val e = intercept[AnalysisException] { + sql(s"SHOW CREATE TABLE $t AS SERDE") + } + assert(e.message.contains(s"SHOW CREATE TABLE AS SERDE is not supported for v2 tables.")) + } + } + + test("SPARK-33898: show create table[CTAS]") { + // Does not work with Hive; the output order also differs between v2 and v1/Hive + withNamespaceAndTable(ns, table) { t => + sql( + s"""CREATE TABLE $t + |$defaultUsing + |PARTITIONED BY (a) + |COMMENT 'This is a comment' + |TBLPROPERTIES ('a' = '1') + |AS SELECT 1 AS a, "foo" AS b + """.stripMargin + ) + val showDDL = getShowCreateDDL(t, false) + assert(showDDL === Array( + s"CREATE TABLE $t (", + "a INT,", + "b STRING)", + defaultUsing, + "PARTITIONED BY (a)", + "COMMENT 'This is a comment'", + "TBLPROPERTIES (", + "'a' = '1')" + )) + } + } + + test("SPARK-33898: show create table[simple]") { + // TODO: After SPARK-37517, we can move the test case to base to test for v2/v1/hive + val db = "ns1" + val table = "tbl" + withNamespaceAndTable(db, table) { t => + sql( + s""" + |CREATE TABLE $t ( + | a bigint NOT NULL, + | b bigint, + | c bigint, + | `extraCol` ARRAY<INT>, + | `<another>` STRUCT<x: INT, y: ARRAY<BOOLEAN>> + |) + |$defaultUsing + |OPTIONS ( + | from = 0, + | to = 1, + | via = 2) + |COMMENT 'This is a comment' + |TBLPROPERTIES ('prop1' = '1', 'prop2' = '2', 'prop3' = 3, 'prop4' = 4) + |PARTITIONED BY (a) + |LOCATION '/tmp' + """.stripMargin) + val showDDL = getShowCreateDDL(t, false) + assert(showDDL === Array( + s"CREATE TABLE $t (", + "a BIGINT NOT NULL,", + "b BIGINT,", + "c BIGINT,", + "extraCol ARRAY<INT>,", + "`<another>` STRUCT<x: INT, y: ARRAY<BOOLEAN>>)", + defaultUsing, + "OPTIONS (", + "'from' = '0',", + "'to' = '1',", + "'via' = '2')", + "PARTITIONED BY (a)", + "COMMENT 'This is a comment'", + "LOCATION 'file:/tmp'", + "TBLPROPERTIES (", + "'prop1' = '1',", + "'prop2' = '2',", + "'prop3' = '3',", + "'prop4' = '4')" + )) + } + } + + test("SPARK-33898: show create table[multi-partition]") { + withNamespaceAndTable(ns, table) { t => + sql( + s""" + |CREATE TABLE $t (a INT, b STRING, ts TIMESTAMP) $defaultUsing + |PARTITIONED BY ( + | a, + | bucket(16, b), + | years(ts), + | months(ts), + | days(ts), + | hours(ts)) + """.stripMargin) + // V1 transforms cannot be converted to partition columns (bucket, years, ...)
+ val showDDL = getShowCreateDDL(t, false) + assert(showDDL === Array( + s"CREATE TABLE $t (", + "a INT,", + "b STRING,", + "ts TIMESTAMP)", + defaultUsing, + "PARTITIONED BY (a, years(ts), months(ts), days(ts), hours(ts))", + "CLUSTERED BY (b)", + "INTO 16 BUCKETS" + )) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 37fe3c205e5d8..ef6d6f4a2968a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -26,12 +26,12 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT class DataSourceStrategySuite extends PlanTest with SharedSparkSession { val attrInts = Seq( - 'cint.int, + Symbol("cint").int, Symbol("c.int").int, - GetStructField('a.struct(StructType( + GetStructField(Symbol("a").struct(StructType( StructField("cstr", StringType, nullable = true) :: StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None), - GetStructField('a.struct(StructType( + GetStructField(Symbol("a").struct(StructType( StructField("c.int", IntegerType, nullable = true) :: StructField("cstr", StringType, nullable = true) :: Nil)), 0, None), GetStructField(Symbol("a.b").struct(StructType( @@ -40,7 +40,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None), GetStructField(Symbol("a.b").struct(StructType( StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None), - GetStructField(GetStructField('a.struct(StructType( + GetStructField(GetStructField(Symbol("a").struct(StructType( StructField("cstr1", StringType, nullable = true) :: StructField("b", StructType(StructField("cint", IntegerType, nullable = true) :: StructField("cstr2", StringType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None) @@ -55,12 +55,12 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { )) val attrStrs = Seq( - 'cstr.string, + Symbol("cstr").string, Symbol("c.str").string, - GetStructField('a.struct(StructType( + GetStructField(Symbol("a").struct(StructType( StructField("cint", IntegerType, nullable = true) :: StructField("cstr", StringType, nullable = true) :: Nil)), 1, None), - GetStructField('a.struct(StructType( + GetStructField(Symbol("a").struct(StructType( StructField("c.str", StringType, nullable = true) :: StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None), GetStructField(Symbol("a.b").struct(StructType( @@ -69,7 +69,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { StructField("cstr", StringType, nullable = true) :: Nil)), 2, None), GetStructField(Symbol("a.b").struct(StructType( StructField("c.str", StringType, nullable = true) :: Nil)), 0, None), - GetStructField(GetStructField('a.struct(StructType( + GetStructField(GetStructField(Symbol("a").struct(StructType( StructField("cint1", IntegerType, nullable = true) :: StructField("b", StructType(StructField("cstr", StringType, nullable = true) :: StructField("cint2", IntegerType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None) @@ -280,7 +280,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { }} test("SPARK-26865 DataSourceV2Strategy should push normalized filters") { - val attrInt = 'cint.int + val attrInt = 
Symbol("cint").int assertResult(Seq(IsNotNull(attrInt))) { DataSourceStrategy.normalizeExprs(Seq(IsNotNull(attrInt.withName("CiNt"))), Seq(attrInt)) } @@ -308,11 +308,11 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { } // `Abs(col)` can not be pushed down, so it returns `None` - assert(PushableColumnAndNestedColumn.unapply(Abs('col.int)) === None) + assert(PushableColumnAndNestedColumn.unapply(Abs(Symbol("col").int)) === None) } test("SPARK-36644: Push down boolean column filter") { - testTranslateFilter('col.boolean, Some(sources.EqualTo("col", true))) + testTranslateFilter(Symbol("col").boolean, Some(sources.EqualTo("col", true))) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala index 6ba3d2723412b..3034d4fe67c1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala @@ -143,7 +143,8 @@ class DataSourceSuite extends SharedSparkSession with PrivateMethodTester { test("Data source options should be propagated in method checkAndGlobPathIfNecessary") { val dataSourceOptions = Map("fs.defaultFS" -> "nonexistentFs://nonexistentFs") val dataSource = DataSource(spark, "parquet", Seq("/path3"), options = dataSourceOptions) - val checkAndGlobPathIfNecessary = PrivateMethod[Seq[Path]]('checkAndGlobPathIfNecessary) + val checkAndGlobPathIfNecessary = + PrivateMethod[Seq[Path]](Symbol("checkAndGlobPathIfNecessary")) val message = intercept[java.io.IOException] { dataSource invokePrivate checkAndGlobPathIfNecessary(false, false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala index f492fc653653e..c9e15f71524d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala @@ -39,12 +39,15 @@ class FileFormatWriterSuite test("SPARK-22252: FileFormatWriter should respect the input query schema") { withTable("t1", "t2", "t3", "t4") { - spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1") + spark.range(1).select(Symbol("id") as Symbol("col1"), Symbol("id") as Symbol("col2")) + .write.saveAsTable("t1") spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2") checkAnswer(spark.table("t2"), Row(0, 0)) // Test picking part of the columns when writing. 
- spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3") + spark.range(1) + .select(Symbol("id"), Symbol("id") as Symbol("col1"), Symbol("id") as Symbol("col2")) + .write.saveAsTable("t3") spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4") checkAnswer(spark.table("t4"), Row(0, 0)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index fcaf8df4f9a02..08ddc67cd6553 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -488,7 +488,7 @@ class FileIndexSuite extends SharedSparkSession { new Path("file")), Array(new BlockLocation())) ) when(dfs.listLocatedStatus(path)).thenReturn(new RemoteIterator[LocatedFileStatus] { - val iter = statuses.toIterator + val iter = statuses.iterator override def hasNext: Boolean = iter.hasNext override def next(): LocatedFileStatus = iter.next }) @@ -520,6 +520,18 @@ class FileIndexSuite extends SharedSparkSession { SQLConf.get.setConf(StaticSQLConf.METADATA_CACHE_TTL_SECONDS, previousValue) } } + + test("SPARK-38182: Fix NoSuchElementException if pushed filter does not contain any " + + "references") { + withTable("t") { + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification") { + + sql("CREATE TABLE t (c1 int) USING PARQUET") + assert(sql("SELECT * FROM t WHERE c1 = 1 AND 2 > 1").count() == 0) + } + } + } } object DeletionRaceFileSystem { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala index fffac885da5fc..410fc985dd3bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import java.text.SimpleDateFormat import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row} +import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -278,9 +279,21 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { } metadataColumnsTest("filter", schema) { (df, f0, _) => + val filteredDF = df.select("name", "age", METADATA_FILE_NAME) + .where(Column(METADATA_FILE_NAME) === f0(METADATA_FILE_NAME)) + + // check the filtered file + val partitions = filteredDF.queryExecution.sparkPlan.collectFirst { + case p: FileSourceScanExec => p.selectedPartitions + }.get + + assert(partitions.length == 1) // 1 partition + assert(partitions.head.files.length == 1) // 1 file in that partition + assert(partitions.head.files.head.getPath.toString == f0(METADATA_FILE_PATH)) // the file is f0 + + // check result checkAnswer( - df.select("name", "age", METADATA_FILE_NAME) - .where(Column(METADATA_FILE_NAME) === f0(METADATA_FILE_NAME)), + filteredDF, Seq( // _file_name == f0's name, so we will only have 1 row Row("jack", 24, f0(METADATA_FILE_NAME)) @@ -288,6 +301,36 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { ) } + metadataColumnsTest("filter on metadata and 
user data", schema) { (df, _, f1) => + + val filteredDF = df.select("name", "age", "info", + METADATA_FILE_NAME, METADATA_FILE_PATH, + METADATA_FILE_SIZE, METADATA_FILE_MODIFICATION_TIME) + // mix metadata column + user column + .where(Column(METADATA_FILE_NAME) === f1(METADATA_FILE_NAME) and Column("name") === "lily") + // only metadata columns + .where(Column(METADATA_FILE_PATH) === f1(METADATA_FILE_PATH)) + // only user column + .where("age == 31") + + // check the filtered file + val partitions = filteredDF.queryExecution.sparkPlan.collectFirst { + case p: FileSourceScanExec => p.selectedPartitions + }.get + + assert(partitions.length == 1) // 1 partition + assert(partitions.head.files.length == 1) // 1 file in that partition + assert(partitions.head.files.head.getPath.toString == f1(METADATA_FILE_PATH)) // the file is f1 + + // check result + checkAnswer( + filteredDF, + Seq(Row("lily", 31, Row(54321L, "ucb"), + f1(METADATA_FILE_NAME), f1(METADATA_FILE_PATH), + f1(METADATA_FILE_SIZE), f1(METADATA_FILE_MODIFICATION_TIME))) + ) + } + Seq(true, false).foreach { caseSensitive => metadataColumnsTest(s"upper/lower case when case " + s"sensitive is $caseSensitive", schemaWithNameConflicts) { (df, f0, f1) => @@ -384,4 +427,141 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { } } } + + metadataColumnsTest("prune metadata schema in projects", schema) { (df, f0, f1) => + val prunedDF = df.select("name", "age", "info.id", METADATA_FILE_NAME) + val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst { + case p: FileSourceScanExec => p.metadataColumns + }.get + assert(fileSourceScanMetaCols.size == 1) + assert(fileSourceScanMetaCols.head.name == "file_name") + + checkAnswer( + prunedDF, + Seq(Row("jack", 24, 12345L, f0(METADATA_FILE_NAME)), + Row("lily", 31, 54321L, f1(METADATA_FILE_NAME))) + ) + } + + metadataColumnsTest("prune metadata schema in filters", schema) { (df, f0, f1) => + val prunedDF = df.select("name", "age", "info.id") + .where(col(METADATA_FILE_PATH).contains("data/f0")) + + val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst { + case p: FileSourceScanExec => p.metadataColumns + }.get + assert(fileSourceScanMetaCols.size == 1) + assert(fileSourceScanMetaCols.head.name == "file_path") + + checkAnswer( + prunedDF, + Seq(Row("jack", 24, 12345L)) + ) + } + + metadataColumnsTest("prune metadata schema in projects and filters", schema) { (df, f0, f1) => + val prunedDF = df.select("name", "age", "info.id", METADATA_FILE_SIZE) + .where(col(METADATA_FILE_PATH).contains("data/f0")) + + val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst { + case p: FileSourceScanExec => p.metadataColumns + }.get + assert(fileSourceScanMetaCols.size == 2) + assert(fileSourceScanMetaCols.map(_.name).toSet == Set("file_size", "file_path")) + + checkAnswer( + prunedDF, + Seq(Row("jack", 24, 12345L, f0(METADATA_FILE_SIZE))) + ) + } + + metadataColumnsTest("write _metadata in parquet and read back", schema) { (df, f0, f1) => + // SPARK-38314: Selecting and then writing df containing hidden file + // metadata column `_metadata` into parquet files will still keep the internal `Attribute` + // metadata information of the column. It will then fail when read again. 
+ withTempDir { dir => + df.select("*", "_metadata") + .write.format("parquet").save(dir.getCanonicalPath + "/new-data") + + val newDF = spark.read.format("parquet").load(dir.getCanonicalPath + "/new-data") + + // SELECT * will have: name, age, info, _metadata of f0 and f1 + checkAnswer( + newDF.select("*"), + Seq( + Row("jack", 24, Row(12345L, "uom"), + Row(f0(METADATA_FILE_PATH), f0(METADATA_FILE_NAME), + f0(METADATA_FILE_SIZE), f0(METADATA_FILE_MODIFICATION_TIME))), + Row("lily", 31, Row(54321L, "ucb"), + Row(f1(METADATA_FILE_PATH), f1(METADATA_FILE_NAME), + f1(METADATA_FILE_SIZE), f1(METADATA_FILE_MODIFICATION_TIME))) + ) + ) + + // SELECT _metadata won't override the existing user data (_metadata of f0 and f1) + checkAnswer( + newDF.select("_metadata"), + Seq( + Row(Row(f0(METADATA_FILE_PATH), f0(METADATA_FILE_NAME), + f0(METADATA_FILE_SIZE), f0(METADATA_FILE_MODIFICATION_TIME))), + Row(Row(f1(METADATA_FILE_PATH), f1(METADATA_FILE_NAME), + f1(METADATA_FILE_SIZE), f1(METADATA_FILE_MODIFICATION_TIME))) + ) + ) + } + } + + metadataColumnsTest("file metadata in streaming", schema) { (df, _, _) => + withTempDir { dir => + df.coalesce(1).write.format("json").save(dir.getCanonicalPath + "/source/new-streaming-data") + + val stream = spark.readStream.format("json") + .schema(schema) + .load(dir.getCanonicalPath + "/source/new-streaming-data") + .select("*", "_metadata") + .writeStream.format("json") + .option("checkpointLocation", dir.getCanonicalPath + "/target/checkpoint") + .start(dir.getCanonicalPath + "/target/new-streaming-data") + + stream.processAllAvailable() + stream.stop() + + val newDF = spark.read.format("json") + .load(dir.getCanonicalPath + "/target/new-streaming-data") + + val sourceFile = new File(dir, "/source/new-streaming-data").listFiles() + .filter(_.getName.endsWith(".json")).head + val sourceFileMetadata = Map( + METADATA_FILE_PATH -> sourceFile.toURI.toString, + METADATA_FILE_NAME -> sourceFile.getName, + METADATA_FILE_SIZE -> sourceFile.length(), + METADATA_FILE_MODIFICATION_TIME -> new Timestamp(sourceFile.lastModified()) + ) + + // SELECT * will have: name, age, info, _metadata of /source/new-streaming-data + assert(newDF.select("*").columns.toSet == Set("name", "age", "info", "_metadata")) + // Verify the data is expected + checkAnswer( + newDF.select(col("name"), col("age"), col("info"), + col(METADATA_FILE_PATH), col(METADATA_FILE_NAME), + // since we are writing _metadata to a json file, + // we should explicitly cast the column to timestamp type + col(METADATA_FILE_SIZE), to_timestamp(col(METADATA_FILE_MODIFICATION_TIME))), + Seq( + Row( + "jack", 24, Row(12345L, "uom"), + sourceFileMetadata(METADATA_FILE_PATH), + sourceFileMetadata(METADATA_FILE_NAME), + sourceFileMetadata(METADATA_FILE_SIZE), + sourceFileMetadata(METADATA_FILE_MODIFICATION_TIME)), + Row( + "lily", 31, Row(54321L, "ucb"), + sourceFileMetadata(METADATA_FILE_PATH), + sourceFileMetadata(METADATA_FILE_NAME), + sourceFileMetadata(METADATA_FILE_SIZE), + sourceFileMetadata(METADATA_FILE_MODIFICATION_TIME)) + ) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index e4a41ba9e71be..47740c5274616 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -372,6 +372,34 @@ trait FileSourceAggregatePushDownSuite } } + test("aggregate not push down - MIN/MAX/COUNT with CASE WHEN") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + val selectAgg = sql( + """ + |SELECT + | min(CASE WHEN _1 < 0 THEN 0 ELSE _1 END), + | min(CASE WHEN _3 > 5 THEN 1 ELSE 0 END), + | max(CASE WHEN _1 < 0 THEN 0 ELSE _1 END), + | max(CASE WHEN NOT(_3 > 5) THEN 1 ELSE 0 END), + | count(CASE WHEN _1 < 0 AND _2 IS NOT NULL THEN 0 ELSE _1 END), + | count(CASE WHEN _3 != 5 OR _2 IS NULL THEN 1 ELSE 0 END) + |FROM t + """.stripMargin) + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + + checkAnswer(selectAgg, Seq(Row(0, 0, 9, 1, 6, 6))) + } + } + } + private def testPushDownForAllDataTypes( inputRows: Seq[Row], expectedMinWithAllTypes: Seq[Row], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 634016664dfb6..b14ccb089f449 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -60,7 +60,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre "file9" -> 1, "file10" -> 1)) - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => // 10 one byte files should fit in a single partition with 10 files. 
assert(partitions.size == 1, "when checking partitions") assert(partitions.head.files.size == 10, "when checking partition 1") @@ -83,7 +83,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "11", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => // 5 byte files should be laid out [(5, 5), (5)] assert(partitions.size == 2, "when checking partitions") assert(partitions(0).files.size == 2, "when checking partition 1") @@ -108,7 +108,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => // Files should be laid out [(0-10), (10-15, 4)] assert(partitions.size == 2, "when checking partitions") assert(partitions(0).files.size == 1, "when checking partition 1") @@ -141,7 +141,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] assert(partitions.size == 4, "when checking partitions") assert(partitions(0).files.size == 1, "when checking partition 1") @@ -359,7 +359,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre withSQLConf( SQLConf.FILES_MAX_PARTITION_BYTES.key -> "2", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0") { - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => assert(partitions.size == 2) assert(partitions(0).files.size == 1) assert(partitions(1).files.size == 2) @@ -375,7 +375,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre withSQLConf( SQLConf.FILES_MAX_PARTITION_BYTES.key -> "2", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0") { - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => assert(partitions.size == 3) assert(partitions(0).files.size == 1) assert(partitions(1).files.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala index 98d3d65befe60..bf14a7d91233b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala @@ -118,6 +118,28 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with Shared } } + test("SPARK-38357: data + partition filters with OR") { + // Force datasource v2 for parquet + withSQLConf((SQLConf.USE_V1_SOURCE_LIST.key, "")) { + withTempPath { dir => + spark.range(10).coalesce(1).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + assertPrunedPartitions("SELECT * FROM tmp WHERE (p = 0 AND id > 0) OR (p = 1 AND id = 2)", + 2, + "((tmp.p = 0) || (tmp.p = 1))") + 
assertPrunedPartitions("SELECT * FROM tmp WHERE p = 0 AND id > 0", + 1, + "(tmp.p = 0)") + assertPrunedPartitions("SELECT * FROM tmp WHERE p = 0", + 1, + "(tmp.p = 0)") + } + } + } + } + protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]] = { case scan: FileSourceScanExec => scan.partitionFilters } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 6fd966c42a067..2c227baa04fc2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -21,6 +21,7 @@ import java.io.File import org.scalactic.Equality +import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.SchemaPruningTest import org.apache.spark.sql.catalyst.expressions.Concat @@ -57,6 +58,9 @@ abstract class SchemaPruningSuite contactId: Int, employer: Employer) + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.ANSI_STRICT_INDEX_OPERATOR.key, "false") + val janeDoe = FullName("Jane", "X.", "Doe") val johnDoe = FullName("John", "Y.", "Doe") val susanSmith = FullName("Susan", "Z.", "Smith") @@ -569,7 +573,7 @@ abstract class SchemaPruningSuite Seq(Concat(Seq($"name.first", $"name.last")), Concat(Seq($"name.last", $"name.first"))) ), - Seq('a.string, 'b.string), + Seq(Symbol("a").string, Symbol("b").string), sql("select * from contacts").logicalPlan ).toDF() checkScan(query1, "struct>") @@ -586,7 +590,7 @@ abstract class SchemaPruningSuite val name = StructType.fromDDL("first string, middle string, last string") val query2 = Expand( Seq(Seq($"name", $"name.last")), - Seq('a.struct(name), 'b.string), + Seq(Symbol("a").struct(name), Symbol("b").string), sql("select * from contacts").logicalPlan ).toDF() checkScan(query2, "struct>") @@ -905,7 +909,7 @@ abstract class SchemaPruningSuite .createOrReplaceTempView("table") val read = spark.table("table") - val query = read.select(explode($"items").as('item)).select(count($"*")) + val query = read.select(explode($"items").as(Symbol("item"))).select(count($"*")) checkScan(query, "struct>>") checkAnswer(query, Row(2) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 7bbe371879d40..41b4f909ce958 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1621,38 +1621,30 @@ abstract class CSVSuite checkAnswer(df, Row("a", null, "a")) } - test("SPARK-21610: Corrupt records are not handled properly when creating a dataframe " + - "from a file") { - val columnNameOfCorruptRecord = "_corrupt_record" + test("SPARK-38523: referring to the corrupt record column") { val schema = new StructType() .add("a", IntegerType) .add("b", DateType) - .add(columnNameOfCorruptRecord, StringType) - // negative cases - val msg = intercept[AnalysisException] { - spark - .read - .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schema) - .csv(testFile(valueMalformedFile)) - .select(columnNameOfCorruptRecord) - .collect() - }.getMessage - assert(msg.contains("only include the internal corrupt 
record column")) - - // workaround - val df = spark + .add("corrRec", StringType) + val readback = spark .read - .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .option("columnNameOfCorruptRecord", "corrRec") .schema(schema) .csv(testFile(valueMalformedFile)) - .cache() - assert(df.filter($"_corrupt_record".isNotNull).count() == 1) - assert(df.filter($"_corrupt_record".isNull).count() == 1) checkAnswer( - df.select(columnNameOfCorruptRecord), - Row("0,2013-111_11 12:13:14") :: Row(null) :: Nil - ) + readback, + Row(0, null, "0,2013-111_11 12:13:14") :: + Row(1, Date.valueOf("1983-08-04"), null) :: Nil) + checkAnswer( + readback.filter($"corrRec".isNotNull), + Row(0, null, "0,2013-111_11 12:13:14")) + checkAnswer( + readback.select($"corrRec", $"b"), + Row("0,2013-111_11 12:13:14", null) :: + Row(null, Date.valueOf("1983-08-04")) :: Nil) + checkAnswer( + readback.filter($"corrRec".isNull && $"a" === 1), + Row(1, Date.valueOf("1983-08-04"), null) :: Nil) } test("SPARK-23846: schema inferring touches less data if samplingRatio < 1.0") { @@ -1836,7 +1828,7 @@ abstract class CSVSuite val idf = spark.read .schema(schema) .csv(path.getCanonicalPath) - .select('f15, 'f10, 'f5) + .select(Symbol("f15"), Symbol("f10"), Symbol("f5")) assert(idf.count() == 2) checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala index 7d277c1ffaffe..cfcddbaf0d92d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala @@ -63,6 +63,6 @@ class JdbcUtilsSuite extends SparkFunSuite { JdbcUtils.getCustomSchema(tableSchema, "c3 DATE. C2 STRING", caseInsensitive) === StructType(Seq(StructField("c3", DateType, false), StructField("C2", StringType, false))) } - assert(mismatchedInput.getMessage.contains("mismatched input '.' 
expecting")) + assert(mismatchedInput.getMessage.contains("Syntax error at or near '.'")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala index e4f6ccaa9a621..c741320d4220b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala @@ -263,7 +263,7 @@ object JsonBenchmark extends SqlBasedBenchmark { benchmark.addCase("from_json", iters) { _ => val schema = new StructType().add("a", IntegerType) - val from_json_ds = in.select(from_json('value, schema)) + val from_json_ds = in.select(from_json(Symbol("value"), schema)) from_json_ds.noop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index e9fe79a0641b9..703085dca66f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -130,6 +130,45 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSparkSession { Double.NegativeInfinity, Double.NegativeInfinity)) } + test("allowNonNumericNumbers on - quoted") { + val str = + """{"c0":"NaN", "c1":"+INF", "c2":"+Infinity", "c3":"Infinity", "c4":"-INF", + |"c5":"-Infinity"}""".stripMargin + val df = spark.read + .schema(new StructType() + .add("c0", "double") + .add("c1", "double") + .add("c2", "double") + .add("c3", "double") + .add("c4", "double") + .add("c5", "double")) + .option("allowNonNumericNumbers", true).json(Seq(str).toDS()) + checkAnswer( + df, + Row( + Double.NaN, + Double.PositiveInfinity, Double.PositiveInfinity, Double.PositiveInfinity, + Double.NegativeInfinity, Double.NegativeInfinity)) + } + + test("allowNonNumericNumbers off - quoted") { + val str = + """{"c0":"NaN", "c1":"+INF", "c2":"+Infinity", "c3":"Infinity", "c4":"-INF", + |"c5":"-Infinity"}""".stripMargin + val df = spark.read + .schema(new StructType() + .add("c0", "double") + .add("c1", "double") + .add("c2", "double") + .add("c3", "double") + .add("c4", "double") + .add("c5", "double")) + .option("allowNonNumericNumbers", false).json(Seq(str).toDS()) + checkAnswer( + df, + Row(null, null, null, null, null, null)) + } + test("allowBackslashEscapingAnyCharacter off") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" val df = spark.read.option("allowBackslashEscapingAnyCharacter", "false").json(Seq(str).toDS()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 3daad301c23b1..0897ad2ff3009 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -59,6 +59,9 @@ abstract class JsonSuite override protected def dataSourceFormat = "json" + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.ANSI_STRICT_INDEX_OPERATOR.key, "false") + test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any): Unit = { assert(expected.getClass == actual.getClass, @@ -452,12 +455,6 
@@ abstract class JsonSuite Row(null, 21474836570L, 1.1, 21474836470L, "92233720368547758070", null) :: Nil ) - // Number and Boolean conflict: resolve the type as number in this query. - checkAnswer( - sql("select num_bool - 10 from jsonTable where num_bool > 11"), - Row(2) - ) - // Widening to LongType checkAnswer( sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), @@ -482,17 +479,27 @@ abstract class JsonSuite Row(101.2) :: Row(21474836471.2) :: Nil ) - // Number and String conflict: resolve the type as number in this query. - checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 14d"), - Row(92233720368547758071.2) - ) + // The following tests are about type coercion instead of the JSON data source. + // Here we simply focus on the non-ANSI behavior. + if (!SQLConf.get.ansiEnabled) { + // Number and Boolean conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_bool - 10 from jsonTable where num_bool > 11"), + Row(2) + ) - // Number and String conflict: resolve the type as number in this query. - checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue) - ) + // Number and String conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str > 14d"), + Row(92233720368547758071.2) + ) + + // Number and String conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue) + ) + } // String and Boolean conflict: resolve the type as string. checkAnswer( @@ -1346,12 +1353,12 @@ abstract class JsonSuite } test("Dataset toJSON doesn't construct rdd") { - val containsRDD = spark.emptyDataFrame.toJSON.queryExecution.logical.find { + val containsRDDExists = spark.emptyDataFrame.toJSON.queryExecution.logical.exists { case ExternalRDD(_, _) => true case _ => false } - assert(containsRDD.isEmpty, "Expected logical plan of toJSON to not contain an RDD") + assert(!containsRDDExists, "Expected logical plan of toJSON to not contain an RDD") } test("JSONRelation equality test") { @@ -2021,13 +2028,19 @@ abstract class JsonSuite test("SPARK-18772: Parse special floats correctly") { val jsons = Seq( """{"a": "NaN"}""", + """{"a": "+INF"}""", + """{"a": "-INF"}""", """{"a": "Infinity"}""", + """{"a": "+Infinity"}""", """{"a": "-Infinity"}""") // positive cases val checks: Seq[Double => Boolean] = Seq( _.isNaN, _.isPosInfinity, + _.isNegInfinity, + _.isPosInfinity, + _.isPosInfinity, _.isNegInfinity) Seq(FloatType, DoubleType).foreach { dt => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala index 3cb8287f09b26..b892a9e155815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala @@ -90,7 +90,7 @@ class NoopStreamSuite extends StreamTest { .option("numPartitions", "1") .option("rowsPerSecond", "5") .load() - .select('value) + .select(Symbol("value")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopSuite.scala index b4073bedf5597..811953754953a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopSuite.scala @@ -42,7 +42,7 @@ class NoopSuite extends SharedSparkSession { withTempPath { dir => val path = dir.getCanonicalPath spark.range(numElems) - .select('id mod 10 as "key", 'id as "value") + .select(Symbol("id") mod 10 as "key", Symbol("id") as "value") .write .partitionBy("key") .parquet(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala index bfcef46339908..4ff9612ab4847 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala @@ -25,11 +25,11 @@ import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.orc.TypeDescription +import org.apache.spark.TestUtils import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -117,7 +117,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession { dataTypes.zip(constantValues).foreach { case (dt, v) => val schema = StructType(StructField("col1", IntegerType) :: StructField("pcol", dt) :: Nil) val partitionValues = new GenericInternalRow(Array(v)) - val file = new File(SpecificParquetRecordReaderBase.listDirectory(dir).get(0)) + val file = new File(TestUtils.listDirectory(dir).head) val fileSplit = new FileSplit(new Path(file.getCanonicalPath), 0L, file.length, Array.empty) val taskConf = sqlContext.sessionState.newHadoopConf() val orcFileSchema = TypeDescription.fromString(schema.simpleString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 038606b854d9e..49b7cfa9d3724 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File import java.nio.charset.StandardCharsets import java.sql.Timestamp -import java.time.{LocalDateTime, ZoneOffset} +import java.time.LocalDateTime import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -371,7 +371,7 @@ abstract class OrcQueryTest extends OrcTest { withTempPath { dir => val path = dir.getCanonicalPath - spark.range(0, 10).select('id as "Acol").write.orc(path) + spark.range(0, 10).select(Symbol("id") as "Acol").write.orc(path) spark.read.orc(path).schema("Acol") intercept[IllegalArgumentException] { spark.read.orc(path).schema("acol") 
@@ -416,19 +416,19 @@ abstract class OrcQueryTest extends OrcTest { s"No data was filtered for predicate: $pred") } - checkPredicate('a === 5, List(5).map(Row(_, null))) - checkPredicate('a <=> 5, List(5).map(Row(_, null))) - checkPredicate('a < 5, List(1, 3).map(Row(_, null))) - checkPredicate('a <= 5, List(1, 3, 5).map(Row(_, null))) - checkPredicate('a > 5, List(7, 9).map(Row(_, null))) - checkPredicate('a >= 5, List(5, 7, 9).map(Row(_, null))) - checkPredicate('a.isNull, List(null).map(Row(_, null))) - checkPredicate('b.isNotNull, List()) - checkPredicate('a.isin(3, 5, 7), List(3, 5, 7).map(Row(_, null))) - checkPredicate('a > 0 && 'a < 3, List(1).map(Row(_, null))) - checkPredicate('a < 1 || 'a > 8, List(9).map(Row(_, null))) - checkPredicate(!('a > 3), List(1, 3).map(Row(_, null))) - checkPredicate(!('a > 0 && 'a < 3), List(3, 5, 7, 9).map(Row(_, null))) + checkPredicate(Symbol("a") === 5, List(5).map(Row(_, null))) + checkPredicate(Symbol("a") <=> 5, List(5).map(Row(_, null))) + checkPredicate(Symbol("a") < 5, List(1, 3).map(Row(_, null))) + checkPredicate(Symbol("a") <= 5, List(1, 3, 5).map(Row(_, null))) + checkPredicate(Symbol("a") > 5, List(7, 9).map(Row(_, null))) + checkPredicate(Symbol("a") >= 5, List(5, 7, 9).map(Row(_, null))) + checkPredicate(Symbol("a").isNull, List(null).map(Row(_, null))) + checkPredicate(Symbol("b").isNotNull, List()) + checkPredicate(Symbol("a").isin(3, 5, 7), List(3, 5, 7).map(Row(_, null))) + checkPredicate(Symbol("a") > 0 && Symbol("a") < 3, List(1).map(Row(_, null))) + checkPredicate(Symbol("a") < 1 || Symbol("a") > 8, List(9).map(Row(_, null))) + checkPredicate(!(Symbol("a") > 3), List(1, 3).map(Row(_, null))) + checkPredicate(!(Symbol("a") > 0 && Symbol("a") < 3), List(3, 5, 7, 9).map(Row(_, null))) } } } @@ -734,10 +734,10 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession { withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { val readDf = spark.read.orc(path) - val vectorizationEnabled = readDf.queryExecution.executedPlan.find { + val vectorizationEnabled = readDf.queryExecution.executedPlan.exists { case scan @ (_: FileSourceScanExec | _: BatchScanExec) => scan.supportsColumnar case _ => false - }.isDefined + } assert(vectorizationEnabled) checkAnswer(readDf, df) } @@ -756,10 +756,10 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession { withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { val readDf = spark.read.orc(path) - val vectorizationEnabled = readDf.queryExecution.executedPlan.find { + val vectorizationEnabled = readDf.queryExecution.executedPlan.exists { case scan @ (_: FileSourceScanExec | _: BatchScanExec) => scan.supportsColumnar case _ => false - }.isDefined + } assert(vectorizationEnabled) checkAnswer(readDf, df) } @@ -783,10 +783,10 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession { withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true", SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> maxNumFields) { val scanPlan = spark.read.orc(path).queryExecution.executedPlan - assert(scanPlan.find { + assert(scanPlan.exists { case scan @ (_: FileSourceScanExec | _: BatchScanExec) => scan.supportsColumnar case _ => false - }.isDefined == vectorizedEnabled) + } == vectorizedEnabled) } } } @@ -803,55 +803,6 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession { } } } - - test("SPARK-36346: can't read TimestampLTZ as TimestampNTZ") { - val data = (1 to 10).map { i => - 
val ts = new Timestamp(i) - Row(ts) - } - val answer = (1 to 10).map { i => - // The second parameter is `nanoOfSecond`, while java.sql.Timestamp accepts milliseconds - // as input. So here we multiple the `nanoOfSecond` by NANOS_PER_MILLIS - val ts = LocalDateTime.ofEpochSecond(0, i * 1000000, ZoneOffset.UTC) - Row(ts) - } - val actualSchema = StructType(Seq(StructField("time", TimestampType, false))) - val providedSchema = StructType(Seq(StructField("time", TimestampNTZType, false))) - - withTempPath { file => - val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema) - df.write.orc(file.getCanonicalPath) - withAllNativeOrcReaders { - val msg = intercept[SparkException] { - spark.read.schema(providedSchema).orc(file.getCanonicalPath).collect() - }.getMessage - assert(msg.contains("Unable to convert timestamp of Orc to data type 'timestamp_ntz'")) - } - } - } - - test("SPARK-36346: read TimestampNTZ as TimestampLTZ") { - val data = (1 to 10).map { i => - // The second parameter is `nanoOfSecond`, while java.sql.Timestamp accepts milliseconds - // as input. So here we multiple the `nanoOfSecond` by NANOS_PER_MILLIS - val ts = LocalDateTime.ofEpochSecond(0, i * 1000000, ZoneOffset.UTC) - Row(ts) - } - val answer = (1 to 10).map { i => - val ts = new java.sql.Timestamp(i) - Row(ts) - } - val actualSchema = StructType(Seq(StructField("time", TimestampNTZType, false))) - val providedSchema = StructType(Seq(StructField("time", TimestampType, false))) - - withTempPath { file => - val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema) - df.write.orc(file.getCanonicalPath) - withAllNativeOrcReaders { - checkAnswer(spark.read.schema(providedSchema).orc(file.getCanonicalPath), answer) - } - } - } } class OrcV1QuerySuite extends OrcQuerySuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index 96932de3275bc..c36bfd9362466 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION * -> HiveOrcPartitionDiscoverySuite * -> OrcFilterSuite */ -abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with BeforeAndAfterAll { +trait OrcTest extends QueryTest with FileBasedDataSourceTest with BeforeAndAfterAll { val orcImp: String = "native" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index f545e88517700..f7100a53444aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.parquet.column.{Encoding, ParquetProperties} import org.apache.parquet.hadoop.ParquetOutputFormat +import org.apache.spark.TestUtils import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf @@ -50,12 +51,12 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess (1 :: 1000 :: Nil).foreach { n => { withTempPath { dir => 
List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) - val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + val file = TestUtils.listDirectory(dir).head val conf = sqlContext.conf val reader = new VectorizedParquetRecordReader( conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) - reader.initialize(file.asInstanceOf[String], null) + reader.initialize(file, null) val batch = reader.resultBatch() assert(reader.nextBatch()) assert(batch.numRows() == n) @@ -80,12 +81,12 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess withTempPath { dir => val data = List.fill(n)(NULL_ROW).toDF data.repartition(1).write.parquet(dir.getCanonicalPath) - val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + val file = TestUtils.listDirectory(dir).head val conf = sqlContext.conf val reader = new VectorizedParquetRecordReader( conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) - reader.initialize(file.asInstanceOf[String], null) + reader.initialize(file, null) val batch = reader.resultBatch() assert(reader.nextBatch()) assert(batch.numRows() == n) @@ -114,7 +115,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess // first page is dictionary encoded and the remaining two are plain encoded. val data = (0 until 512).flatMap(i => Seq.fill(3)(i.toString)) data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath) - val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head + val file = TestUtils.listDirectory(dir).head val conf = sqlContext.conf val reader = new VectorizedParquetRecordReader( @@ -182,4 +183,40 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess } } } + + test("parquet v2 pages - rle encoding for boolean value columns") { + val extraOptions = Map[String, String]( + ParquetOutputFormat.WRITER_VERSION -> ParquetProperties.WriterVersion.PARQUET_2_0.toString + ) + + val hadoopConf = spark.sessionState.newHadoopConfWithOptions(extraOptions) + withSQLConf( + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/test.parquet" + val size = 10000 + val data = (1 to size).map { i => (true, false, i % 2 == 1) } + + spark.createDataFrame(data) + .write.options(extraOptions).mode("overwrite").parquet(path) + + val blockMetadata = readFooter(new Path(path), hadoopConf).getBlocks.asScala.head + val columnChunkMetadataList = blockMetadata.getColumns.asScala + + // Verify that indeed rle encoding is used for each column + assert(columnChunkMetadataList.length === 3) + assert(columnChunkMetadataList.head.getEncodings.contains(Encoding.RLE)) + assert(columnChunkMetadataList(1).getEncodings.contains(Encoding.RLE)) + assert(columnChunkMetadataList(2).getEncodings.contains(Encoding.RLE)) + + val actual = spark.read.parquet(path).collect() + assert(actual.length == size) + assert(actual.map(_.getBoolean(0)).forall(_ == true)) + assert(actual.map(_.getBoolean(1)).forall(_ == false)) + val excepted = (1 to size).map { i => i % 2 == 1 } + assert(actual.map(_.getBoolean(2)).sameElements(excepted)) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala new file mode 100644 index 0000000000000..5e01d3f447c96 --- 
/dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, Metadata, MetadataBuilder, StringType, StructType} + +class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkSession { + + private def withId(id: Int): Metadata = + new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + + test("Parquet reads infer fields using field ids correctly") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", StringType, true, withId(0)) + .add("b", IntegerType, true, withId(1)) + + val readSchemaMixed = + new StructType() + .add("name", StringType, true) + .add("b", IntegerType, true, withId(1)) + + val readSchemaMixedHalfMatched = + new StructType() + .add("unmatched", StringType, true) + .add("b", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("random", IntegerType, true, withId(1)) + .add("name", StringType, true, withId(0)) + + val readData = Seq(Row("text", 100), Row("more", 200)) + val readDataHalfMatched = Seq(Row(null, 100), Row(null, 200)) + val writeData = Seq(Row(100, "text"), Row(200, "more")) + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllParquetReaders { + // read with schema + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), readData) + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) + .where("b < 50"), Seq.empty) + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) + .where("a >= 'oh'"), Row("text", 100) :: Nil) + // read with mixed field-id/name schema + checkAnswer(spark.read.schema(readSchemaMixed).parquet(dir.getCanonicalPath), readData) + checkAnswer(spark.read.schema(readSchemaMixedHalfMatched) + .parquet(dir.getCanonicalPath), readDataHalfMatched) + + // schema inference should pull into the schema with ids + val reader = spark.read.parquet(dir.getCanonicalPath) + assert(reader.schema == writeSchema) + checkAnswer(reader.where("name >= 'oh'"), Row(100, "text") :: Nil) + } + } + } + + test("absence of field ids") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("b", StringType, true, withId(2)) + .add("c", IntegerType, true, withId(3)) + + val 
writeSchema = + new StructType() + .add("a", IntegerType, true, withId(3)) + .add("randomName", StringType, true) + + val writeData = Seq(Row(100, "text"), Row(200, "more")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllParquetReaders { + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), + // 3 different cases for the 3 columns to read: + // - a: ID 1 is not found, but there is column with name `a`, still return null + // - b: ID 2 is not found, return null + // - c: ID 3 is found, read it + Row(null, null, 100) :: Row(null, null, 200) :: Nil) + } + } + } + + test("SPARK-38094: absence of field ids: reading nested schema") { + withTempDir { dir => + // now with nested schema/complex type + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("b", ArrayType(StringType), true, withId(2)) + .add("c", new StructType().add("c1", IntegerType, true, withId(6)), true, withId(3)) + .add("d", MapType(StringType, StringType), true, withId(4)) + .add("e", IntegerType, true, withId(5)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(5)) + .add("randomName", StringType, true) + + val writeData = Seq(Row(100, "text"), Row(200, "more")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllParquetReaders { + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), + // a, b, c, d all couldn't be found + Row(null, null, null, null, 100) :: Row(null, null, null, null, 200) :: Nil) + } + } + } + + test("multiple id matches") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("rand1", StringType, true, withId(2)) + .add("rand2", StringType, true, withId(1)) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllParquetReaders { + val cause = intercept[SparkException] { + spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Found duplicate field(s)")) + } + } + } + + test("read parquet file without ids") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true) + .add("rand1", StringType, true) + .add("rand2", StringType, true) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllParquetReaders { + Seq(readSchema, readSchema.add("b", StringType, true)).foreach { schema => + val cause = intercept[SparkException] { + spark.read.schema(schema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Parquet file schema doesn't contain any field Ids")) + val expectedValues = (1 to schema.length).map(_ => null) + withSQLConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key -> "true") { + checkAnswer( + spark.read.schema(schema).parquet(dir.getCanonicalPath), + Row(expectedValues: _*) :: Row(expectedValues: _*) :: Nil) + } + } + } + } + } + + test("global 
read/write flag should work correctly") { + withTempDir { dir => + val readSchema = + new StructType() + .add("some", IntegerType, true, withId(1)) + .add("other", StringType, true, withId(2)) + .add("name", StringType, true, withId(3)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("rand1", StringType, true, withId(2)) + .add("rand2", StringType, true, withId(3)) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + + val expectedResult = Seq(Row(null, null, null), Row(null, null, null)) + + withSQLConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key -> "false", + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") { + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllParquetReaders { + // no field id found exception + val cause = intercept[SparkException] { + spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Parquet file schema doesn't contain any field Ids")) + } + } + + withSQLConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key -> "true", + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "false") { + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllParquetReaders { + // ids are there, but we don't use id for matching, so no results would be returned + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), expectedResult) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala new file mode 100644 index 0000000000000..b3babdd3a0cff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala @@ -0,0 +1,528 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.parquet.schema.{MessageType, MessageTypeParser} + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class ParquetFieldIdSchemaSuite extends ParquetSchemaTest { + + private val FAKE_COLUMN_NAME = "_fake_name_" + private val UUID_REGEX = + "[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}".r + + private def withId(id: Int) = + new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + + private def testSchemaClipping( + testName: String, + parquetSchema: String, + catalystSchema: StructType, + expectedSchema: String, + caseSensitive: Boolean = true, + useFieldId: Boolean = true): Unit = { + test(s"Clipping with field id - $testName") { + val fileSchema = MessageTypeParser.parseMessageType(parquetSchema) + val actual = ParquetReadSupport.clipParquetSchema( + fileSchema, + catalystSchema, + caseSensitive = caseSensitive, + useFieldId = useFieldId) + + // each fake name should be uniquely generated + val fakeColumnNames = actual.getPaths.asScala.flatten.filter(_.startsWith(FAKE_COLUMN_NAME)) + assert( + fakeColumnNames.distinct == fakeColumnNames, "Should generate unique fake column names") + + // replace the random part of all fake names with a fixed id generator + val ids1 = (1 to 100).iterator + val actualNormalized = MessageTypeParser.parseMessageType( + UUID_REGEX.replaceAllIn(actual.toString, _ => ids1.next().toString) + ) + val ids2 = (1 to 100).iterator + val expectedNormalized = MessageTypeParser.parseMessageType( + FAKE_COLUMN_NAME.r.replaceAllIn(expectedSchema, _ => s"$FAKE_COLUMN_NAME${ids2.next()}") + ) + + try { + expectedNormalized.checkContains(actualNormalized) + actualNormalized.checkContains(expectedNormalized) + } catch { case cause: Throwable => + fail( + s"""Expected clipped schema: + |$expectedSchema + |Actual clipped schema: + |$actual + """.stripMargin, + cause) + } + checkEqual(actualNormalized, expectedNormalized) + // might be redundant but just to have some free tests for the utils + assert(ParquetReadSupport.containsFieldIds(fileSchema)) + assert(ParquetUtils.hasFieldIds(catalystSchema)) + } + } + + private def testSqlToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String): Unit = { + val converter = new SparkToParquetSchemaConverter( + writeLegacyParquetFormat = false, + outputTimestampType = SQLConf.ParquetOutputTimestampType.INT96, + useFieldId = true) + + test(s"sql => parquet: $testName") { + val actual = converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) + checkEqual(actual, expected) + } + } + + private def checkEqual(actual: MessageType, expected: MessageType): Unit = { + actual.checkContains(expected) + expected.checkContains(actual) + assert(actual.toString == expected.toString, + s""" + |Schema mismatch. 
+ |Expected schema: + |${expected.toString} + |Actual schema: + |${actual.toString} + """.stripMargin + ) + } + + test("check hasFieldIds for schema") { + val simpleSchemaMissingId = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true) + + assert(ParquetUtils.hasFieldIds(simpleSchemaMissingId)) + + val f01ElementType = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true, withId(8)) + + assert(ParquetUtils.hasFieldIds(f01ElementType)) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + assert(ParquetUtils.hasFieldIds(f0Type)) + + assert(ParquetUtils.hasFieldIds( + new StructType().add("f0", f0Type, nullable = false, withId(1)))) + + assert(!ParquetUtils.hasFieldIds(new StructType().add("f0", IntegerType, nullable = true))) + assert(!ParquetUtils.hasFieldIds(new StructType())); + } + + test("check getFieldId for schema") { + val schema = new StructType() + .add("overflowId", DoubleType, nullable = true, + new MetadataBuilder() + .putLong(ParquetUtils.FIELD_ID_METADATA_KEY, 12345678987654321L).build()) + .add("stringId", StringType, nullable = true, + new MetadataBuilder() + .putString(ParquetUtils.FIELD_ID_METADATA_KEY, "lol").build()) + .add("negativeId", LongType, nullable = true, withId(-20)) + .add("noId", LongType, nullable = true) + + assert(intercept[IllegalArgumentException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("noId")).get._2) + }.getMessage.contains("doesn't exist")) + + assert(intercept[IllegalArgumentException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("overflowId")).get._2) + }.getMessage.contains("must be a 32-bit integer")) + + assert(intercept[IllegalArgumentException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("stringId")).get._2) + }.getMessage.contains("must be a 32-bit integer")) + + // negative id allowed + assert(ParquetUtils.getFieldId(schema.findNestedField(Seq("negativeId")).get._2) == -20) + } + + test("check containsFieldIds for parquet schema") { + + // empty Parquet schema fails too + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + |} + """.stripMargin))) + + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 { + | optional int32 f00; + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 = 1 { + | optional int32 f00; + | optional binary f01; + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 { + | optional int32 f00 = 1; + | optional binary f01; + | } + |} + """.stripMargin))) + + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message spark_schema { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message spark_schema { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list = 1 { + | required binary element (UTF8); + | } + | } + 
| } + |} + """.stripMargin))) + } + + test("ID in Parquet Types is read as null when not set") { + val parquetSchemaString = + """message root { + | required group f0 { + | optional int32 f00; + | } + |} + """.stripMargin + + val parquetSchema = MessageTypeParser.parseMessageType(parquetSchemaString) + val f0 = parquetSchema.getFields().get(0) + assert(f0.getId() == null) + assert(f0.asGroupType().getFields.get(0).getId == null) + } + + testSqlToParquet( + "standard array", + sqlSchema = { + val f01ElementType = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true, withId(9)) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true, withId(5)) + + new StructType().add("f0", f0Type, nullable = false, withId(1)) + }, + parquetSchema = + """message spark_schema { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional double f010 = 7; + | optional int64 f012 = 9; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "simple nested struct", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | optional int32 f01 = 3; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType().add( + "g00", IntegerType, nullable = true, withId(2)) + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + .add("g1", IntegerType, nullable = true, withId(4)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 $FAKE_COLUMN_NAME = 4; + |} + """.stripMargin) + + testSchemaClipping( + "standard array", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional int32 f010 = 7; + | optional double f011 = 8; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("g011", DoubleType, nullable = true, withId(8)) + .add("g012", LongType, nullable = true, withId(9)) + + val f0Type = new StructType() + .add("g00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("g01", ArrayType(f01ElementType, containsNull = false), nullable = true, withId(5)) + + new StructType().add("g0", f0Type, nullable = false, withId(1)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional double f011 = 8; + | optional int64 $FAKE_COLUMN_NAME = 9; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map with complex key", + + parquetSchema = + """message root { + | required group f0 (MAP) = 3 { + | repeated group key_value = 1 { + | required group key = 2 { + | required int32 value_f0 = 4; + | required int64 value_f1 = 6; + | } + | required int32 value = 5; + | } + | } + |} + """.stripMargin, + 
+ catalystSchema = { + val keyType = + new StructType() + .add("value_g1", LongType, nullable = false, withId(6)) + .add("value_g2", DoubleType, nullable = false, withId(7)) + + val f0Type = MapType(keyType, IntegerType, valueContainsNull = false) + + new StructType().add("g0", f0Type, nullable = false, withId(3)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 (MAP) = 3 { + | repeated group key_value = 1 { + | required group key = 2 { + | required int64 value_f1 = 6; + | required double $FAKE_COLUMN_NAME = 7; + | } + | required int32 value = 5; + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "won't match field id if structure is different", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + .add("g00", IntegerType, nullable = true, withId(2)) + // parquet has id 3, but it won't be used because the structure is different + .add("g01", IntegerType, nullable = true, withId(3)) + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + }, + + // note that f1 is not picked up, even though its id is 3 + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | optional int32 $FAKE_COLUMN_NAME = 3; + | } + |} + """.stripMargin) + + testSchemaClipping( + "Complex type with multiple mismatches should work", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + .add("g00", IntegerType, nullable = true, withId(2)) + + new StructType() + .add("g0", f0Type, nullable = false, withId(999)) + .add("g1", IntegerType, nullable = true, withId(3)) + .add("g2", IntegerType, nullable = true, withId(888)) + }, + + expectedSchema = + s"""message spark_schema { + | required group $FAKE_COLUMN_NAME = 999 { + | optional int32 g00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 $FAKE_COLUMN_NAME = 888; + |} + """.stripMargin) + + testSchemaClipping( + "Should allow fall-back to name matching if id not found", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + | required group f4 = 5 { + | optional int32 f40 = 6; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + // nested f00 without id should also work + .add("f00", IntegerType, nullable = true) + + val f4Type = new StructType() + .add("g40", IntegerType, nullable = true, withId(6)) + + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + .add("g1", IntegerType, nullable = true, withId(3)) + // f2 without id should be matched using name matching + .add("f2", IntegerType, nullable = true) + // name is not matched + .add("g2", IntegerType, nullable = true) + // f4 without id will do name matching, but g40 will be matched using id + .add("f4", f4Type, nullable = true) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + | optional int32 g2; + | required group f4 = 5 { + | optional int32 f40 = 6; + | } + |} + """.stripMargin) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 9b554b626df85..d5180a393f61a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -1426,39 +1426,39 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - StringStartsWith") { withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => checkFilterPredicate( - '_1.startsWith("").asInstanceOf[Predicate], + Symbol("_1").startsWith("").asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => checkFilterPredicate( - '_1.startsWith(prefix).asInstanceOf[Predicate], + Symbol("_1").startsWith(prefix).asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], "2str2") } Seq("2S", "null", "2str22").foreach { prefix => checkFilterPredicate( - '_1.startsWith(prefix).asInstanceOf[Predicate], + Symbol("_1").startsWith(prefix).asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], Seq.empty[Row]) } checkFilterPredicate( - !'_1.startsWith("").asInstanceOf[Predicate], + !Symbol("_1").startsWith("").asInstanceOf[Predicate], classOf[Operators.Not], Seq().map(Row(_))) Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => checkFilterPredicate( - !'_1.startsWith(prefix).asInstanceOf[Predicate], + !Symbol("_1").startsWith(prefix).asInstanceOf[Predicate], classOf[Operators.Not], Seq("1str1", "3str3", "4str4").map(Row(_))) } Seq("2S", "null", "2str22").foreach { prefix => checkFilterPredicate( - !'_1.startsWith(prefix).asInstanceOf[Predicate], + !Symbol("_1").startsWith(prefix).asInstanceOf[Predicate], classOf[Operators.Not], Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) } @@ -1472,7 +1472,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared // SPARK-28371: make sure filter is null-safe. withParquetDataFrame(Seq(Tuple1[String](null))) { implicit df => checkFilterPredicate( - '_1.startsWith("blah").asInstanceOf[Predicate], + Symbol("_1").startsWith("blah").asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], Seq.empty[Row]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0966319f53fc7..99b2d9844ed1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -38,7 +38,7 @@ import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.metadata.CompressionCodecName.GZIP import org.apache.parquet.schema.{MessageType, MessageTypeParser} -import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException, TestUtils} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} @@ -187,11 +187,16 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession .range(1000) // Parquet doesn't allow column names with spaces, have to add an alias here. 
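Illustrative sketch (not part of this patch): the hunk below skips the (1, 1) precision/scale combination when ANSI mode is on. The snippet assumes a spark-shell session (so the spark variable is in scope); the literal 1.5 is just an example value that does not fit DECIMAL(1, 1).

    // With ANSI mode off, an overflowing decimal cast degrades to NULL.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT CAST(1.5 AS DECIMAL(1, 1)) AS d").show()  // prints a single NULL row
    // With ANSI mode on, the same cast raises a runtime error instead of returning NULL.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    try spark.sql("SELECT CAST(1.5 AS DECIMAL(1, 1)) AS d").show()
    catch { case e: Exception => println(s"ANSI cast failed: ${e.getMessage}") }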
// Minus 500 here so that negative decimals are also tested. - .select((('id - 500) / 100.0) cast decimal as 'dec) + .select(((Symbol("id") - 500) / 100.0) cast decimal as Symbol("dec")) .coalesce(1) } - val combinations = Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37)) + var combinations = Seq((5, 2), (1, 0), (18, 10), (18, 17), (19, 0), (38, 37)) + // If ANSI mode is on, the combination (1, 1) will cause a runtime error. Otherwise, the + // decimal RDD contains all null values and should be able to read back from Parquet. + if (!SQLConf.get.ansiEnabled) { + combinations = combinations++ Seq((1, 1)) + } for ((precision, scale) <- combinations) { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) @@ -797,7 +802,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession withTempPath { dir => val m2 = intercept[SparkException] { - val df = spark.range(1).select('id as 'a, 'id as 'b).coalesce(1) + val df = spark.range(1).select(Symbol("id") as Symbol("a"), Symbol("id") as Symbol("b")) + .coalesce(1) df.write.partitionBy("a").options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(m2.contains("Intentional exception for testing purposes")) @@ -863,7 +869,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("test-data/dec-in-i32.parquet"), - spark.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + spark.range(1 << 4).select(Symbol("id") % 10 cast DecimalType(5, 2) as Symbol("i32_dec"))) } } @@ -872,7 +878,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("test-data/dec-in-i64.parquet"), - spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + spark.range(1 << 4).select(Symbol("id") % 10 cast DecimalType(10, 2) as Symbol("i64_dec"))) } } @@ -881,7 +887,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("test-data/dec-in-fixed-len.parquet"), - spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) + spark.range(1 << 4) + .select(Symbol("id") % 10 cast DecimalType(10, 2) as Symbol("fixed_len_dec"))) } } @@ -928,7 +935,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession val data = (0 to 10).map(i => (i, (i + 'a').toChar.toString)) withTempPath { dir => spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) - val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); + val file = TestUtils.listDirectory(dir).head; { val conf = sqlContext.conf val reader = new VectorizedParquetRecordReader( @@ -1032,7 +1039,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession val vectorizedReader = new VectorizedParquetRecordReader( conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) val partitionValues = new GenericInternalRow(Array(v)) - val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) + val file = TestUtils.listDirectory(dir).head try { vectorizedReader.initialize(file, null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index bf37421331db6..f3751562c332e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -979,7 +979,8 @@ abstract class ParquetPartitionDiscoverySuite withTempPath { dir => withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { val path = dir.getCanonicalPath - val df = spark.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) + val df = spark.range(5).select(Symbol("id") as Symbol("a"), Symbol("id") as Symbol("b"), + Symbol("id") as Symbol("c")).coalesce(1) df.write.partitionBy("b", "c").parquet(path) checkAnswer(spark.read.parquet(path), df) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 057de2abdb9e0..654ab7fe36200 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -153,7 +153,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS (1, "2016-01-01 10:11:12.123456"), (2, null), (3, "1965-01-01 10:11:12.123456")) - .toDS().select('_1, $"_2".cast("timestamp")) + .toDS().select(Symbol("_1"), $"_2".cast("timestamp")) checkAnswer(sql("select * from ts"), expected) } } @@ -805,7 +805,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS test("SPARK-15804: write out the metadata to parquet file") { val df = Seq((1, "abc"), (2, "hello")).toDF("a", "b") val md = new MetadataBuilder().putString("key", "value").build() - val dfWithmeta = df.select('a, 'b.as("b", md)) + val dfWithmeta = df.select(Symbol("a"), Symbol("b").as("b", md)) withTempPath { dir => val path = dir.getCanonicalPath @@ -1027,7 +1027,7 @@ class ParquetV1QuerySuite extends ParquetQuerySuite { withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "10") { withTempPath { dir => val path = dir.getCanonicalPath - val df = spark.range(10).select(Seq.tabulate(11) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(10).select(Seq.tabulate(11) {i => (Symbol("id") + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) // do not return batch - whole stage codegen is disabled for wide table (>200 columns) @@ -1060,7 +1060,7 @@ class ParquetV2QuerySuite extends ParquetQuerySuite { withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "10") { withTempPath { dir => val path = dir.getCanonicalPath - val df = spark.range(10).select(Seq.tabulate(11) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(10).select(Seq.tabulate(11) {i => (Symbol("id") + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) // do not return batch - whole stage codegen is disabled for wide table (>200 columns) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala index 49251af54193f..dbf7f54f6ff90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala @@ -143,12 +143,12 @@ abstract class ParquetRebaseDatetimeSuite val df = Seq.tabulate(N)(rowFunc).toDF("dict", "plain") .select($"dict".cast(catalystType), $"plain".cast(catalystType)) withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> tsOutputType) { - checkDefaultLegacyRead(oldPath) + checkDefaultLegacyRead(oldPath) withSQLConf(inWriteConf -> CORRECTED.toString) { - df.write.mode("overwrite").parquet(path3_x) + df.write.mode("overwrite").parquet(path3_x) } withSQLConf(inWriteConf -> LEGACY.toString) { - df.write.parquet(path3_x_rebase) + df.write.parquet(path3_x_rebase) } } // For Parquet files written by Spark 3.0, we know the writer info and don't need the @@ -243,40 +243,41 @@ abstract class ParquetRebaseDatetimeSuite SQLConf.PARQUET_INT96_REBASE_MODE_IN_READ.key ) ).foreach { case (outType, tsStr, nonRebased, inWriteConf, inReadConf) => - // Ignore the default JVM time zone and use the session time zone instead of it in rebasing. - DateTimeTestUtils.withDefaultTimeZone(DateTimeTestUtils.JST) { - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> DateTimeTestUtils.LA.getId) { - withClue(s"output type $outType") { - withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> outType) { - withTempPath { dir => - val path = dir.getAbsolutePath - withSQLConf(inWriteConf -> LEGACY.toString) { - Seq.tabulate(N)(_ => tsStr).toDF("tsS") - .select($"tsS".cast("timestamp").as("ts")) - .repartition(1) - .write - .option("parquet.enable.dictionary", dictionaryEncoding) - .parquet(path) - } + // Ignore the default JVM time zone and use the session time zone instead of + // it in rebasing. + DateTimeTestUtils.withDefaultTimeZone(DateTimeTestUtils.JST) { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> DateTimeTestUtils.LA.getId) { + withClue(s"output type $outType") { + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> outType) { + withTempPath { dir => + val path = dir.getAbsolutePath + withSQLConf(inWriteConf -> LEGACY.toString) { + Seq.tabulate(N)(_ => tsStr).toDF("tsS") + .select($"tsS".cast("timestamp").as("ts")) + .repartition(1) + .write + .option("parquet.enable.dictionary", dictionaryEncoding) + .parquet(path) + } - withAllParquetReaders { - // The file metadata indicates if it needs rebase or not, so we can always get - // the correct result regardless of the "rebase mode" config. - runInMode(inReadConf, Seq(LEGACY, CORRECTED, EXCEPTION)) { options => - checkAnswer( - spark.read.options(options).parquet(path).select($"ts".cast("string")), - Seq.tabulate(N)(_ => Row(tsStr))) - } + withAllParquetReaders { + // The file metadata indicates if it needs rebase or not, so we can always get + // the correct result regardless of the "rebase mode" config. + runInMode(inReadConf, Seq(LEGACY, CORRECTED, EXCEPTION)) { options => + checkAnswer( + spark.read.options(options).parquet(path).select($"ts".cast("string")), + Seq.tabulate(N)(_ => Row(tsStr))) + } - // Force to not rebase to prove the written datetime values are rebased - // and we will get wrong result if we don't rebase while reading. - withSQLConf("spark.test.forceNoRebase" -> "true") { - checkAnswer( - spark.read.parquet(path).select($"ts".cast("string")), - Seq.tabulate(N)(_ => Row(nonRebased))) + // Force to not rebase to prove the written datetime values are rebased + // and we will get wrong result if we don't rebase while reading. 
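Illustrative sketch (not part of this patch): the reindented block above checks that LEGACY-written timestamps are rebased back correctly on read. The snippet assumes a spark-shell session (which imports spark.implicits._); the scratch path, the sample timestamp, and the exact config key shown are assumptions for the example.

    import java.sql.Timestamp
    val path = "/tmp/rebase-sketch"  // hypothetical scratch directory
    // LEGACY: rebase values from the Proleptic Gregorian calendar to the hybrid
    // Julian/Gregorian calendar before writing, and record that choice in the file metadata.
    spark.conf.set("spark.sql.parquet.datetimeRebaseModeInWrite", "LEGACY")
    Seq(Timestamp.valueOf("1001-01-01 01:02:03")).toDF("ts")
      .repartition(1).write.mode("overwrite").parquet(path)
    // On read, the recorded writer info drives the rebase, so the original value comes back
    // regardless of the reader-side rebase configuration.
    spark.read.parquet(path).show(false)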
+ withSQLConf("spark.test.forceNoRebase" -> "true") { + checkAnswer( + spark.read.parquet(path).select($"ts".cast("string")), + Seq.tabulate(N)(_ => Row(nonRebased))) + } } } - } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 272f12e138b68..d0228d7bdf9f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -944,7 +944,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { withTempPath { dir => val path = dir.getCanonicalPath spark.range(3).write.parquet(s"$path/p=1") - spark.range(3).select('id cast IntegerType as 'id).write.parquet(s"$path/p=2") + spark.range(3).select(Symbol("id") cast IntegerType as Symbol("id")) + .write.parquet(s"$path/p=2") val message = intercept[SparkException] { spark.read.option("mergeSchema", "true").parquet(path).schema @@ -2257,7 +2258,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { caseSensitive: Boolean): Unit = { test(s"Clipping - $testName") { val actual = ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive) + MessageTypeParser.parseMessageType(parquetSchema), + catalystSchema, + caseSensitive, + useFieldId = false) try { expectedSchema.checkContains(actual) @@ -2821,7 +2825,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } assertThrows[RuntimeException] { ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive = false) + MessageTypeParser.parseMessageType(parquetSchema), + catalystSchema, + caseSensitive = false, + useFieldId = false) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 47723166213dd..18690844d484c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -33,6 +33,7 @@ import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetM import org.apache.parquet.hadoop.util.HadoopInputFile import org.apache.parquet.schema.MessageType +import org.apache.spark.TestUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.datasources.FileBasedDataSourceTest import org.apache.spark.sql.internal.SQLConf @@ -164,9 +165,17 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { def withAllParquetReaders(code: => Unit): Unit = { // test the row-based reader - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false")(code) + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withClue("Parquet-mr reader") { + code + } + } // test the vectorized reader - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true")(code) + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + withClue("Vectorized reader") { + code + } + } } def withAllParquetWriters(code: => Unit): Unit = { @@ -179,7 +188,7 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { } def getMetaData(dir: java.io.File): Map[String, String] = { - val file = 
SpecificParquetRecordReaderBase.listDirectory(dir).get(0) + val file = TestUtils.listDirectory(dir).head val conf = new Configuration() val hadoopInputFile = HadoopInputFile.fromPath(new Path(file), conf) val parquetReadOptions = HadoopReadOptions.builder(conf).build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala index 143feebdd4994..0fb6fc58c400d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.BooleanType class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession { test("SPARK-36644: Push down boolean column filter") { - testTranslateFilter('col.boolean, + testTranslateFilter(Symbol("col").boolean, Some(new V2EqualTo(FieldReference("col"), LiteralValue(true, BooleanType)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala index 86f4dc467638f..bae793bb01214 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, NamespaceChange, TableCatalog, TableChange, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, NamespaceChange, SupportsNamespaces, TableCatalog, TableChange, V1Table} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -60,17 +60,18 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { super.beforeAll() val catalog = newCatalog() catalog.createNamespace(Array("db"), emptyProps) - catalog.createNamespace(Array("db2"), emptyProps) + catalog.createNamespace(Array("db2"), + Map(SupportsNamespaces.PROP_LOCATION -> "file:///db2.db").asJava) catalog.createNamespace(Array("ns"), emptyProps) catalog.createNamespace(Array("ns2"), emptyProps) } override protected def afterAll(): Unit = { val catalog = newCatalog() - catalog.dropNamespace(Array("db")) - catalog.dropNamespace(Array("db2")) - catalog.dropNamespace(Array("ns")) - catalog.dropNamespace(Array("ns2")) + catalog.dropNamespace(Array("db"), cascade = true) + catalog.dropNamespace(Array("db2"), cascade = true) + catalog.dropNamespace(Array("ns"), cascade = true) + catalog.dropNamespace(Array("ns2"), cascade = true) super.afterAll() } @@ -186,10 +187,17 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(t2.catalogTable.location === makeQualifiedPathWithWarehouse("db.db/relative/path")) catalog.dropTable(testIdent) - // absolute path + // 
absolute path without scheme properties.put(TableCatalog.PROP_LOCATION, "/absolute/path") val t3 = catalog.createTable(testIdent, schema, Array.empty, properties).asInstanceOf[V1Table] - assert(t3.catalogTable.location.toString === "file:/absolute/path") + assert(t3.catalogTable.location.toString === "file:///absolute/path") + catalog.dropTable(testIdent) + + // absolute path with scheme + properties.put(TableCatalog.PROP_LOCATION, "file:/absolute/path") + val t4 = catalog.createTable(testIdent, schema, Array.empty, properties).asInstanceOf[V1Table] + assert(t4.catalogTable.location.toString === "file:/absolute/path") + catalog.dropTable(testIdent) } test("tableExists") { @@ -685,10 +693,15 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { TableChange.setProperty(TableCatalog.PROP_LOCATION, "relative/path")).asInstanceOf[V1Table] assert(t2.catalogTable.location === makeQualifiedPathWithWarehouse("db.db/relative/path")) - // absolute path + // absolute path without scheme val t3 = catalog.alterTable(testIdent, TableChange.setProperty(TableCatalog.PROP_LOCATION, "/absolute/path")).asInstanceOf[V1Table] - assert(t3.catalogTable.location.toString === "file:/absolute/path") + assert(t3.catalogTable.location.toString === "file:///absolute/path") + + // absolute path with scheme + val t4 = catalog.alterTable(testIdent, TableChange.setProperty( + TableCatalog.PROP_LOCATION, "file:/absolute/path")).asInstanceOf[V1Table] + assert(t4.catalogTable.location.toString === "file:/absolute/path") } test("dropTable") { @@ -785,8 +798,7 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { private def filterV2TableProperties( properties: util.Map[String, String]): Map[String, String] = { properties.asScala.filter(kv => !CatalogV2Util.TABLE_RESERVED_PROPERTIES.contains(kv._1)) - .filter(!_._1.startsWith(TableCatalog.OPTION_PREFIX)) - .filter(_._1 != TableCatalog.PROP_EXTERNAL).toMap + .filter(!_._1.startsWith(TableCatalog.OPTION_PREFIX)).toMap } } @@ -811,7 +823,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.listNamespaces(Array()) === Array(testNs, defaultNs)) assert(catalog.listNamespaces(testNs) === Array()) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("listNamespaces: fail if missing namespace") { @@ -849,7 +861,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) checkMetadata(metadata.asScala, Map("property" -> "value")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("loadNamespaceMetadata: empty metadata") { @@ -864,7 +876,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) checkMetadata(metadata.asScala, emptyProps.asScala) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: basic behavior") { @@ -884,7 +896,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(metadata, Map("property" -> "value")) assert(expectedPath === metadata("location")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: initialize location") { @@ -900,7 +912,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(metadata, Map.empty) assert(expectedPath === metadata("location")) - catalog.dropNamespace(testNs) + 
catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: relative location") { @@ -917,7 +929,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(metadata, Map.empty) assert(expectedPath === metadata("location")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: fail if namespace already exists") { @@ -933,7 +945,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) checkMetadata(catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: fail nested namespace") { @@ -948,7 +960,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(exc.getMessage.contains("Invalid namespace name: db.nested")) - catalog.dropNamespace(Array("db")) + catalog.dropNamespace(Array("db"), cascade = false) } test("createTable: fail if namespace does not exist") { @@ -969,7 +981,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === false) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === false) } @@ -981,7 +993,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === true) assert(catalog.namespaceExists(testNs) === false) @@ -993,8 +1005,8 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { catalog.createNamespace(testNs, Map("property" -> "value").asJava) catalog.createTable(testIdent, schema, Array.empty, emptyProps) - val exc = intercept[IllegalStateException] { - catalog.dropNamespace(testNs) + val exc = intercept[AnalysisException] { + catalog.dropNamespace(testNs, cascade = false) } assert(exc.getMessage.contains(testNs.quoted)) @@ -1002,7 +1014,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value")) catalog.dropTable(testIdent) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: basic behavior") { @@ -1027,7 +1039,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: update namespace location") { @@ -1050,7 +1062,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { catalog.alterNamespace(testNs, NamespaceChange.setProperty("location", "relativeP")) assert(newRelativePath === spark.catalog.getDatabase(testNs(0)).locationUri) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: update namespace comment") { @@ -1065,7 +1077,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(newComment === spark.catalog.getDatabase(testNs(0)).description) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: fail if namespace doesn't exist") { @@ -1092,6 +1104,6 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { 
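Illustrative sketch (not part of this patch): the hunks around this point pass an explicit cascade flag to dropNamespace. A SQL-level sketch of the same distinction, assuming a spark-shell session; the namespace and table names are made up for the example.

    spark.sql("CREATE NAMESPACE IF NOT EXISTS sketch_ns")
    spark.sql("CREATE TABLE sketch_ns.t (id INT) USING parquet")
    // A plain (non-cascading) drop of a non-empty namespace fails analysis.
    try spark.sql("DROP NAMESPACE sketch_ns")
    catch { case e: org.apache.spark.sql.AnalysisException => println(e.getMessage) }
    // CASCADE drops the namespace together with the objects it contains.
    spark.sql("DROP NAMESPACE sketch_ns CASCADE")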
assert(exc.getMessage.contains(s"Cannot remove reserved property: $p")) } - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 046ff78ce9bd3..db99557466d95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -138,6 +138,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } + private def applyEnsureRequirementsWithSubsetKeys(plan: SparkPlan): SparkPlan = { + var res: SparkPlan = null + withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false") { + res = EnsureRequirements.apply(plan) + } + res + } + test("Successful compatibility check with HashShuffleSpec") { val plan1 = DummySparkPlan( outputPartitioning = HashPartitioning(exprA :: Nil, 5)) @@ -155,10 +163,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { case other => fail(other.toString) } - // should also work if both partition keys are subset of their corresponding cluster keys smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + // By default we can't eliminate shuffles if the partitions keys are subset of join keys. + assert(EnsureRequirements.apply(smjExec) + .collect { case s: ShuffleExchangeLike => s }.length == 2) + // with the config set, it should also work if both partition keys are subset of their + // corresponding cluster keys + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => @@ -169,7 +181,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { smjExec = SortMergeJoinExec( exprB :: exprA :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => @@ -186,7 +198,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { outputPartitioning = HashPartitioning(exprA :: exprC :: Nil, 5)) var smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => @@ -201,7 +213,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { outputPartitioning = HashPartitioning(exprA :: exprC :: exprA :: Nil, 5)) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), SortExec(_, _, 
DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => @@ -216,7 +228,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { outputPartitioning = HashPartitioning(exprA :: exprC :: exprA :: Nil, 5)) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprD :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => @@ -231,7 +243,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { outputPartitioning = HashPartitioning(exprA :: exprC :: Nil, 5)) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => @@ -249,7 +261,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { outputPartitioning = HashPartitioning(exprD :: Nil, 5)) var smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprC :: exprD :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), SortExec(_, _, ShuffleExchangeExec(p: HashPartitioning, _, _), _), _) => @@ -266,7 +278,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { outputPartitioning = HashPartitioning(exprD :: Nil, 10)) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprC :: exprD :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, ShuffleExchangeExec(p: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => @@ -283,7 +295,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { outputPartitioning = HashPartitioning(exprD :: Nil, 5)) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprC :: exprD :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), SortExec(_, _, ShuffleExchangeExec(p: HashPartitioning, _, _), _), _) => @@ -292,8 +304,6 @@ class EnsureRequirementsSuite extends SharedSparkSession { assert(p.expressions == Seq(exprC)) case other => fail(other.toString) } - - } } @@ -304,7 +314,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { outputPartitioning = HashPartitioning(exprA :: exprC :: exprB :: Nil, 5)) var smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), SortExec(_, _, ShuffleExchangeExec(p: HashPartitioning, _, _), _), _) => @@ -320,7 +330,7 @@ class EnsureRequirementsSuite 
extends SharedSparkSession { outputPartitioning = HashPartitioning(exprA :: exprC :: exprB :: Nil, 5)) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprD :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), SortExec(_, _, ShuffleExchangeExec(p: HashPartitioning, _, _), _), _) => @@ -403,13 +413,26 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // HashPartitioning(1) <-> RangePartitioning(10) - // Only RHS should be shuffled and be converted to HashPartitioning(1) <-> HashPartitioning(1) + // If the conf is not set, both sides should be shuffled and be converted to + // HashPartitioning(5) <-> HashPartitioning(5) + // If the conf is set, only RHS should be shuffled and be converted to + // HashPartitioning(1) <-> HashPartitioning(1) plan1 = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(exprA), 1)) plan2 = DummySparkPlan(outputPartitioning = RangePartitioning( Seq(SortOrder.apply(exprC, Ascending, sameOrderExpressions = Seq.empty)), 10)) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprC :: exprD :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, ShuffleExchangeExec(left: HashPartitioning, _, _), _), + SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => + assert(left.numPartitions == 5) + assert(left.expressions == Seq(exprA, exprB)) + assert(right.numPartitions == 5) + assert(right.expressions == Seq(exprC, exprD)) + case other => fail(other.toString) + } + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, DummySparkPlan(_, _, left: HashPartitioning, _, _), _), SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => @@ -446,7 +469,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { smjExec = SortMergeJoinExec( exprA :: exprB :: exprC :: exprD :: Nil, exprA :: exprB :: exprC :: exprD :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, DummySparkPlan(_, _, left: PartitioningCollection, _, _), _), SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => @@ -463,7 +486,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { smjExec = SortMergeJoinExec( exprA :: exprB :: exprC :: exprD :: Nil, exprA :: exprB :: exprC :: exprD :: Nil, Inner, None, plan1, plan2) - EnsureRequirements.apply(smjExec) match { + applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, ShuffleExchangeExec(left: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, right: PartitioningCollection, _, _), _), _) => @@ -482,17 +505,17 @@ class EnsureRequirementsSuite extends SharedSparkSession { // HashPartitioning(5) <-> HashPartitioning(5) // No shuffle should be inserted var plan1: SparkPlan = DummySparkPlan( - outputPartitioning = HashPartitioning(exprA :: Nil, 5)) + outputPartitioning = HashPartitioning(exprA :: exprB :: Nil, 5)) var plan2: SparkPlan = DummySparkPlan( - outputPartitioning = HashPartitioning(exprC :: Nil, 5)) + outputPartitioning = HashPartitioning(exprC :: exprD :: Nil, 5)) var smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, 
exprC :: exprD :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, DummySparkPlan(_, _, left: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, right: HashPartitioning, _, _), _), _) => - assert(left.expressions === Seq(exprA)) - assert(right.expressions === Seq(exprC)) + assert(left.expressions === Seq(exprA, exprB)) + assert(right.expressions === Seq(exprC, exprD)) case other => fail(other.toString) } @@ -521,15 +544,15 @@ class EnsureRequirementsSuite extends SharedSparkSession { outputPartitioning = RangePartitioning( Seq(SortOrder.apply(exprA, Ascending, sameOrderExpressions = Seq.empty)), 10)) plan2 = DummySparkPlan( - outputPartitioning = HashPartitioning(exprD :: Nil, 5)) + outputPartitioning = HashPartitioning(exprC :: exprD :: Nil, 5)) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprC :: exprD :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, ShuffleExchangeExec(left: HashPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, right: HashPartitioning, _, _), _), _) => - assert(left.expressions === Seq(exprB)) - assert(right.expressions === Seq(exprD)) + assert(left.expressions === Seq(exprA, exprB)) + assert(right.expressions === Seq(exprC, exprD)) assert(left.numPartitions == 5) assert(right.numPartitions == 5) case other => fail(other.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/ValidateRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/ValidateRequirementsSuite.scala new file mode 100644 index 0000000000000..6e2eba68d9262 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/ValidateRequirementsSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.exchange + +import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, SinglePartition} +import org.apache.spark.sql.execution.SortExec +import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.test.SharedSparkSession + +class ValidateRequirementsSuite extends PlanTest with SharedSparkSession { + + import testImplicits._ + + private def testValidate( + joinKeyIndices: Seq[Int], + leftPartitionKeyIndices: Seq[Int], + rightPartitionKeyIndices: Seq[Int], + leftPartitionNum: Int, + rightPartitionNum: Int, + success: Boolean): Unit = { + val table1 = + spark.range(10).select(Symbol("id") + 1 as Symbol("a1"), Symbol("id") + 2 as Symbol("b1"), + Symbol("id") + 3 as Symbol("c1")).queryExecution.executedPlan + val table2 = + spark.range(10).select(Symbol("id") + 1 as Symbol("a2"), Symbol("id") + 2 as Symbol("b2"), + Symbol("id") + 3 as Symbol("c2")).queryExecution.executedPlan + + val leftKeys = joinKeyIndices.map(table1.output) + val rightKeys = joinKeyIndices.map(table2.output) + val leftPartitioning = + HashPartitioning(leftPartitionKeyIndices.map(table1.output), leftPartitionNum) + val rightPartitioning = + HashPartitioning(rightPartitionKeyIndices.map(table2.output), rightPartitionNum) + val left = + SortExec(leftKeys.map(SortOrder(_, Ascending)), false, + ShuffleExchangeExec(leftPartitioning, table1)) + val right = + SortExec(rightKeys.map(SortOrder(_, Ascending)), false, + ShuffleExchangeExec(rightPartitioning, table2)) + + val plan = SortMergeJoinExec(leftKeys, rightKeys, Inner, None, left, right) + assert(ValidateRequirements.validate(plan) == success, plan) + } + + test("SMJ requirements satisfied with partial partition key") { + testValidate(Seq(0, 1, 2), Seq(1), Seq(1), 5, 5, true) + } + + test("SMJ requirements satisfied with different partition key order") { + testValidate(Seq(0, 1, 2), Seq(2, 0, 1), Seq(2, 0, 1), 5, 5, true) + } + + test("SMJ requirements not satisfied with unequal partition key order") { + testValidate(Seq(0, 1, 2), Seq(1, 0), Seq(0, 1), 5, 5, false) + } + + test("SMJ requirements not satisfied with unequal partition key length") { + testValidate(Seq(0, 1, 2), Seq(1), Seq(1, 2), 5, 5, false) + } + + test("SMJ requirements not satisfied with partition key missing from join key") { + testValidate(Seq(1, 2), Seq(1, 0), Seq(1, 0), 5, 5, false) + } + + test("SMJ requirements not satisfied with unequal partition number") { + testValidate(Seq(0, 1, 2), Seq(0, 1, 2), Seq(0, 1, 2), 12, 10, false) + } + + test("SMJ with HashPartitioning(1) and SinglePartition") { + val table1 = spark.range(10).queryExecution.executedPlan + val table2 = spark.range(10).queryExecution.executedPlan + val leftPartitioning = HashPartitioning(table1.output, 1) + val rightPartitioning = SinglePartition + val left = + SortExec(table1.output.map(SortOrder(_, Ascending)), false, + ShuffleExchangeExec(leftPartitioning, table1)) + val right = + SortExec(table2.output.map(SortOrder(_, Ascending)), false, + ShuffleExchangeExec(rightPartitioning, table2)) + + val plan = SortMergeJoinExec(table1.output, table2.output, Inner, None, left, right) + assert(ValidateRequirements.validate(plan), plan) + } + + private def testNestedJoin( + joinKeyIndices1: Seq[(Int, Int)], + joinKeyIndices2: Seq[(Int, Int)], + partNums: Seq[Int], + success: Boolean): Unit = { + val table1 = + 
spark.range(10).select(Symbol("id") + 1 as Symbol("a1"), Symbol("id") + 2 as Symbol("b1"), + Symbol("id") + 3 as Symbol("c1")).queryExecution.executedPlan + val table2 = + spark.range(10).select(Symbol("id") + 1 as Symbol("a2"), Symbol("id") + 2 as Symbol("b2"), + Symbol("id") + 3 as Symbol("c2")).queryExecution.executedPlan + val table3 = + spark.range(10).select(Symbol("id") + 1 as Symbol("a3"), Symbol("id") + 2 as Symbol("b3"), + Symbol("id") + 3 as Symbol("c3")).queryExecution.executedPlan + + val key1 = joinKeyIndices1.map(_._1).map(table1.output) + val key2 = joinKeyIndices1.map(_._2).map(table2.output) + val key3 = joinKeyIndices2.map(_._1).map(table3.output) + val key4 = joinKeyIndices2.map(_._2).map(table1.output ++ table2.output) + val partitioning1 = HashPartitioning(key1, partNums(0)) + val partitioning2 = HashPartitioning(key2, partNums(1)) + val partitioning3 = HashPartitioning(key3, partNums(2)) + val joinRel1 = + SortExec(key1.map(SortOrder(_, Ascending)), false, ShuffleExchangeExec(partitioning1, table1)) + val joinRel2 = + SortExec(key2.map(SortOrder(_, Ascending)), false, ShuffleExchangeExec(partitioning2, table2)) + val joinRel3 = + SortExec(key3.map(SortOrder(_, Ascending)), false, ShuffleExchangeExec(partitioning3, table3)) + + val plan = SortMergeJoinExec(key3, key4, Inner, None, + joinRel3, SortMergeJoinExec(key1, key2, Inner, None, joinRel1, joinRel2)) + assert(ValidateRequirements.validate(plan) == success, plan) + } + + test("ValidateRequirements should work bottom up") { + Seq(true, false).foreach { success => + testNestedJoin(Seq((0, 0)), Seq((0, 0)), Seq(5, if (success) 5 else 10, 5), success) + } + } + + test("PartitioningCollection exact match") { + testNestedJoin(Seq((0, 0), (1, 1)), Seq((0, 0), (1, 1)), Seq(5, 5, 5), true) + testNestedJoin(Seq((0, 0), (1, 1)), Seq((0, 3), (1, 4)), Seq(5, 5, 5), true) + } + + test("PartitioningCollection mismatch with different order") { + testNestedJoin(Seq((0, 0), (1, 1)), Seq((1, 1), (0, 0)), Seq(5, 5, 5), false) + testNestedJoin(Seq((0, 0), (1, 1)), Seq((1, 4), (0, 3)), Seq(5, 5, 5), false) + } + + test("PartitioningCollection mismatch with different set") { + testNestedJoin(Seq((1, 1)), Seq((2, 2), (1, 1)), Seq(5, 5, 5), false) + testNestedJoin(Seq((1, 1)), Seq((2, 5), (1, 4)), Seq(5, 5, 5), false) + } + + test("PartitioningCollection mismatch with key missing from required") { + testNestedJoin(Seq((2, 2), (1, 1)), Seq((2, 2)), Seq(5, 5, 5), false) + testNestedJoin(Seq((2, 2), (1, 1)), Seq((2, 5)), Seq(5, 5, 5), false) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index a8b4856261d83..256e942620272 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -402,13 +402,12 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils assert(b.buildSide === buildSide) case w: WholeStageCodegenExec => assert(w.children.head.getClass.getSimpleName === joinMethod) - if (w.children.head.isInstanceOf[BroadcastNestedLoopJoinExec]) { - assert( - w.children.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide === buildSide) - } else if (w.children.head.isInstanceOf[BroadcastHashJoinExec]) { - assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide) - } else { - fail() + w.children.head match { + case bnlj: 
BroadcastNestedLoopJoinExec => + assert(bnlj.buildSide === buildSide) + case bhj: BroadcastHashJoinExec => + assert(bhj.buildSide === buildSide) + case _ => fail() } } } @@ -416,8 +415,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils test("Broadcast timeout") { val timeout = 5 val slowUDF = udf({ x: Int => Thread.sleep(timeout * 1000); x }) - val df1 = spark.range(10).select($"id" as 'a) - val df2 = spark.range(5).select(slowUDF($"id") as 'a) + val df1 = spark.range(10).select($"id" as Symbol("a")) + val df2 = spark.range(5).select(slowUDF($"id") as Symbol("a")) val testDf = df1.join(broadcast(df2), "a") withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> timeout.toString) { if (!conf.adaptiveExecutionEnabled) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index b8ffc47d6ec3c..6c87178f267c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.map.BytesToBytesMap @@ -92,6 +93,9 @@ class HashedRelationSuite extends SharedSparkSession { assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)).toArray === data2) + // SPARK-38542: UnsafeHashedRelation should serialize numKeys out + assert(hashed2.keys().map(_.copy()).forall(_.numFields == 1)) + val os2 = new ByteArrayOutputStream() val out2 = new ObjectOutputStream(os2) hashed2.writeExternal(out2) @@ -610,20 +614,25 @@ class HashedRelationSuite extends SharedSparkSession { val keys = Seq(BoundReference(0, ByteType, false), BoundReference(1, IntegerType, false), BoundReference(2, ShortType, false)) - val packed = HashJoin.rewriteKeyExpr(keys) - val unsafeProj = UnsafeProjection.create(packed) - val packedKeys = unsafeProj(row) - - Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i, dt) => - val key = HashJoin.extractKeyExprAt(keys, i) - val proj = UnsafeProjection.create(key) - assert(proj(packedKeys).get(0, dt) == -i - 1) + // Rewriting and extracting key expressions should not cause an exception when ANSI mode is on.
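[Editor's note] The test above packs a (Byte, Int, Short) join key into a single long-typed key via HashJoin.rewriteKeyExpr and recovers each component with HashJoin.extractKeyExprAt, now exercised under both ANSI settings. As a rough, self-contained sketch of that packing idea only (the bit widths, shift amounts, and object name below are my own assumptions, not Spark's actual rewriteKeyExpr/extractKeyExprAt implementation):

object PackedKeySketch {
  // Pack a (Byte, Int, Short) key into a single Long: 8 + 32 + 16 = 56 bits used.
  def pack(b: Byte, i: Int, s: Short): Long =
    ((b & 0xFFL) << 48) | ((i & 0xFFFFFFFFL) << 16) | (s & 0xFFFFL)

  // Extract each component from its bit range and narrow it back to the original type.
  def unpackByte(packed: Long): Byte = ((packed >>> 48) & 0xFFL).toByte
  def unpackInt(packed: Long): Int = ((packed >>> 16) & 0xFFFFFFFFL).toInt
  def unpackShort(packed: Long): Short = (packed & 0xFFFFL).toShort
}

For example, PackedKeySketch.unpackInt(PackedKeySketch.pack(-1, -2, -3)) returns -2; the narrowing on extraction is what lets the negative values -1, -2 and -3 in the assertions above round-trip.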
+ Seq("false", "true").foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) { + val packed = HashJoin.rewriteKeyExpr(keys) + val unsafeProj = UnsafeProjection.create(packed) + val packedKeys = unsafeProj(row) + + Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i, dt) => + val key = HashJoin.extractKeyExprAt(keys, i) + val proj = UnsafeProjection.create(key) + assert(proj(packedKeys).get(0, dt) == -i - 1) + } + } } } test("EmptyHashedRelation override methods behavior test") { val buildKey = Seq(BoundReference(0, LongType, false)) - val hashed = HashedRelation(Seq.empty[InternalRow].toIterator, buildKey, 1, mm) + val hashed = HashedRelation(Seq.empty[InternalRow].iterator, buildKey, 1, mm) assert(hashed == EmptyHashedRelation) val key = InternalRow(1L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 0fd5c892e2c42..063f18622646c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -79,7 +79,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) Seq((0L, false), (1L, true)).foreach { case (nodeId, enableWholeStage) => - val df = person.filter('age < 25) + val df = person.filter(Symbol("age") < 25) testSparkPlanMetrics(df, 1, Map( nodeId -> (("Filter", Map( "number of output rows" -> 1L)))), @@ -94,7 +94,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // Filter(nodeId = 1) // Range(nodeId = 2) // TODO: update metrics in generated operators - val ds = spark.range(10).filter('id < 5) + val ds = spark.range(10).filter(Symbol("id") < 5) testSparkPlanMetricsWithPredicates(ds.toDF(), 1, Map( 0L -> (("WholeStageCodegen (1)", Map( "duration" -> { @@ -109,11 +109,11 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils val df = testData2.groupBy().count() // 2 partitions val expected1 = Seq( Map("number of output rows" -> 2L, - "avg hash probe bucket list iters" -> + "avg hash probes per key" -> aggregateMetricsPattern, "number of sort fallback tasks" -> 0L), Map("number of output rows" -> 1L, - "avg hash probe bucket list iters" -> + "avg hash probes per key" -> aggregateMetricsPattern, "number of sort fallback tasks" -> 0L)) val shuffleExpected1 = Map( @@ -128,14 +128,14 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils ) // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() + val df2 = testData2.groupBy(Symbol("a")).count() val expected2 = Seq( Map("number of output rows" -> 4L, - "avg hash probe bucket list iters" -> + "avg hash probes per key" -> aggregateMetricsPattern, "number of sort fallback tasks" -> 0L), Map("number of output rows" -> 3L, - "avg hash probe bucket list iters" -> + "avg hash probes per key" -> aggregateMetricsPattern, "number of sort fallback tasks" -> 0L)) @@ -176,7 +176,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // Exchange(nodeId = 5) // LocalTableScan(nodeId = 6) Seq(true, false).foreach { enableWholeStage => - val df = generateRandomBytesDF().repartition(2).groupBy('a).count() + val df = generateRandomBytesDF().repartition(2).groupBy(Symbol("a")).count() val nodeIds = if (enableWholeStage) { 
Set(4L, 1L) } else { @@ -184,7 +184,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get nodeIds.foreach { nodeId => - val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString + val probes = metrics(nodeId)._2("avg hash probes per key").toString if (!probes.contains("\n")) { // It's a single metrics value assert(probes.toDouble > 1.0) @@ -204,7 +204,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // Assume the execution plan is // ... -> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1) // -> ObjectHashAggregate(nodeId = 0) - val df = testData2.groupBy().agg(collect_set('a)) // 2 partitions + val df = testData2.groupBy().agg(collect_set(Symbol("a"))) // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> (("ObjectHashAggregate", Map("number of output rows" -> 2L))), 1L -> (("Exchange", Map( @@ -216,7 +216,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils ) // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).agg(collect_set('a)) + val df2 = testData2.groupBy(Symbol("a")).agg(collect_set(Symbol("a"))) testSparkPlanMetrics(df2, 1, Map( 2L -> (("ObjectHashAggregate", Map( "number of output rows" -> 4L, @@ -233,7 +233,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // 2 partitions and each partition contains 2 keys, with fallback to sort-based aggregation withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1") { - val df3 = testData2.groupBy('a).agg(collect_set('a)) + val df3 = testData2.groupBy(Symbol("a")).agg(collect_set(Symbol("a"))) testSparkPlanMetrics(df3, 1, Map( 2L -> (("ObjectHashAggregate", Map( "number of output rows" -> 4L, @@ -263,7 +263,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // LocalTableScan(nodeId = 3) // Because of SPARK-25267, ConvertToLocalRelation is disabled in the test cases of sql/core, // so Project here is not collapsed into LocalTableScan. - val df = Seq(1, 3, 2).toDF("id").sort('id) + val df = Seq(1, 3, 2).toDF("id").sort(Symbol("id")) testSparkPlanMetricsWithPredicates(df, 2, Map( 0L -> (("Sort", Map( "sort time" -> { @@ -281,7 +281,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("SortMergeJoin metrics") { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter(Symbol("a") < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.createOrReplaceTempView("testDataForJoin") withTempView("testDataForJoin") { // Assume the execution plan is @@ -314,7 +314,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("SortMergeJoin(outer) metrics") { // Because SortMergeJoin may skip different rows if the number of partitions is different, // this test should use the deterministic number of partitions. 
- val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter(Symbol("a") < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.createOrReplaceTempView("testDataForJoin") withTempView("testDataForJoin") { // Assume the execution plan is @@ -372,7 +372,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils val df = df1.join(df2, "key") testSparkPlanMetrics(df, 1, Map( nodeId1 -> (("ShuffledHashJoin", Map( - "number of output rows" -> 2L))), + "number of output rows" -> 2L, + "avg hash probes per key" -> aggregateMetricsPattern))), nodeId2 -> (("Exchange", Map( "shuffle records written" -> 2L, "records read" -> 2L))), @@ -401,7 +402,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils rightDf.hint("shuffle_hash"), $"key" === $"key2", joinType) testSparkPlanMetrics(df, 1, Map( nodeId -> (("ShuffledHashJoin", Map( - "number of output rows" -> rows)))), + "number of output rows" -> rows, + "avg hash probes per key" -> aggregateMetricsPattern)))), enableWholeStage ) } @@ -459,7 +461,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } test("BroadcastNestedLoopJoin metrics") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter(Symbol("a") < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.createOrReplaceTempView("testDataForJoin") withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { withTempView("testDataForJoin") { @@ -512,7 +514,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("CartesianProduct metrics") { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter(Symbol("a") < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.createOrReplaceTempView("testDataForJoin") withTempView("testDataForJoin") { // Assume the execution plan is @@ -547,7 +549,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("save metrics") { withTempPath { file => // person creates a temporary view. get the DF before listing previous execution IDs - val data = person.select('name) + val data = person.select(Symbol("name")) val previousExecutionIds = currentExecutionIds() // Assume the execution plan is // PhysicalRDD(nodeId = 0) @@ -704,7 +706,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { // A special query that only has one partition, so there is no shuffle and the entire query // can be whole-stage-codegened. 
- val df = spark.range(0, 1500, 1, 1).limit(10).groupBy('id).count().limit(1).filter('id >= 0) + val df = spark.range(0, 1500, 1, 1).limit(10).groupBy(Symbol("id")) + .count().limit(1).filter('id >= 0) df.collect() val plan = df.queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala index a508f923ffa13..f06e62b33b1a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming +import java.io.File + +import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -24,8 +27,9 @@ import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.connector.read.streaming import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.functions.{count, timestamp_seconds, window} -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.types.{LongType, StructType} +import org.apache.spark.util.Utils class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { @@ -40,8 +44,8 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { val df = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(df)( @@ -74,13 +78,34 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { ) } + test("SPARK-38033: SS cannot be started because the commitId and offsetId are inconsistent") { + val inputData = MemoryStream[Int] + val streamEvent = inputData.toDF().select("value") + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. 
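[Editor's note] The comment above describes a restore-then-restart pattern: copy a pre-built checkpoint resource into a throwaway directory so the bundled resource is never mutated between runs. A minimal sketch of that pattern as a reusable helper (the helper and object names are hypothetical; FileUtils and Utils are the same utilities the test imports):

import java.io.File

import org.apache.commons.io.FileUtils

import org.apache.spark.util.Utils

object CheckpointTestHelper {
  // Run a block against a private copy of a bundled checkpoint directory,
  // so every run starts from an identical, unmodified checkpoint state.
  def withRestoredCheckpoint(resource: File)(f: File => Unit): Unit = {
    val checkpointDir = Utils.createTempDir().getCanonicalFile
    FileUtils.copyDirectory(resource, checkpointDir)
    f(checkpointDir)
  }
}

Starting each run from an identical copy is what makes the expected "batch 3 doesn't exist" failure reproducible instead of passing only on the first run.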
+ FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + testStream(streamEvent) ( + AddData(inputData, 1, 2, 3, 4, 5, 6), + StartStream(Trigger.Once, checkpointLocation = checkpointDir.getAbsolutePath), + ExpectFailure[IllegalStateException] { e => + assert(e.getMessage.contains("batch 3 doesn't exist")) + } + ) + } + test("no-data-batch re-executed after restart should call V1 source.getBatch()") { val testSource = ReExecutedBatchTestSource(spark) val df = testSource.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long]) /** Reset this test source so that it appears to be a new source requiring initialization */ @@ -153,7 +178,6 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { ) } - case class ReExecutedBatchTestSource(spark: SparkSession) extends Source { @volatile var currentOffset = 0L @volatile var getBatchCallCount = 0 @@ -191,4 +215,3 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { } } } - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala index 5884380271f0e..11dbf9c2beaa1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala @@ -141,7 +141,7 @@ class ConsoleWriteSupportSuite extends StreamTest { .option("numPartitions", "1") .option("rowsPerSecond", "5") .load() - .select('value) + .select(Symbol("value")) val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() assert(query.isActive) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala index a0bd0fb582ca2..ce98e2e6a5bb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -160,9 +160,9 @@ class ForeachBatchSinkSuite extends StreamTest { var planAsserted = false val writer: (Dataset[T], Long) => Unit = { case (df, _) => - assert(df.queryExecution.executedPlan.find { p => + assert(!df.queryExecution.executedPlan.exists { p => p.isInstanceOf[SerializeFromObjectExec] - }.isEmpty, "Untyped Dataset should not introduce serialization on object!") + }, "Untyped Dataset should not introduce serialization on object!") planAsserted = true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 0fe339b93047a..46440c98226aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -165,8 +165,8 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with 
BeforeA val windowedAggregation = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"count".as[Long]) .map(_.toInt) .repartition(1) @@ -199,8 +199,8 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeA val windowedAggregation = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"count".as[Long]) .map(_.toInt) .repartition(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProviderSuite.scala index 449aea8256673..fe846acab28ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProviderSuite.scala @@ -60,7 +60,7 @@ class RatePerMicroBatchProviderSuite extends StreamTest { .format("rate-micro-batch") .option("rowsPerBatch", "10") .load() - .select('value) + .select(Symbol("value")) val clock = new StreamManualClock testStream(input)( @@ -97,7 +97,7 @@ class RatePerMicroBatchProviderSuite extends StreamTest { .format("rate-micro-batch") .option("rowsPerBatch", "10") .load() - .select('value) + .select(Symbol("value")) val clock = new StreamManualClock testStream(input)( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 6440e69e2ec23..2c1bb41302c11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -83,7 +83,7 @@ class RateStreamProviderSuite extends StreamTest { .format("rate") .option("rowsPerSecond", "10") .load() - .select('value) + .select(Symbol("value")) var streamDuration = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala index d4792301a1ce5..0678cfc38660e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala @@ -67,7 +67,7 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest { val inputData = MemoryStream[Int] val query = inputData.toDS().toDF("value") - .select('value) + .select(Symbol("value")) .groupBy($"value") .agg(count("*")) .writeStream @@ -119,7 +119,7 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest { def startQuery(): StreamingQuery = { inputData.toDS().toDF("value") - .select('value) + .select(Symbol("value")) .groupBy($"value") .agg(count("*")) .writeStream @@ 
-156,7 +156,7 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest { SQLConf.STATE_STORE_ROCKSDB_FORMAT_VERSION.key -> "100") { val inputData = MemoryStream[Int] val query = inputData.toDS().toDF("value") - .select('value) + .select(Symbol("value")) .groupBy($"value") .agg(count("*")) .writeStream @@ -179,7 +179,7 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest { val inputData = MemoryStream[Int] val query = inputData.toDS().toDF("value") - .select('value) + .select(Symbol("value")) .groupBy($"value") .agg(count("*")) .writeStream diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala index a9cc90ca45ce8..1539341359337 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala @@ -63,6 +63,8 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { private val valueSchema65535Bytes = new StructType() .add(StructField("v" * (65535 - 87), IntegerType, nullable = true)) + // Checks on adding/removing (nested) field. + test("adding field to key should fail") { val fieldAddedKeySchema = keySchema.add(StructField("newKey", IntegerType)) verifyException(keySchema, valueSchema, fieldAddedKeySchema, valueSchema) @@ -107,6 +109,8 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { verifyException(keySchema, valueSchema, keySchema, newValueSchema) } + // Checks on changing type of (nested) field. + test("changing the type of field in key should fail") { val typeChangedKeySchema = StructType(keySchema.map(_.copy(dataType = TimestampType))) verifyException(keySchema, valueSchema, typeChangedKeySchema, valueSchema) @@ -129,28 +133,59 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { verifyException(keySchema, valueSchema, keySchema, newValueSchema) } - test("changing the nullability of nullable to non-nullable in key should fail") { + // Checks on changing nullability of (nested) field. + // Note that these tests have different format of the test name compared to others, since it was + // misleading to understand the assignment as the opposite way. 
+ + test("storing non-nullable column into nullable column in key should be allowed") { val nonNullChangedKeySchema = StructType(keySchema.map(_.copy(nullable = false))) - verifyException(keySchema, valueSchema, nonNullChangedKeySchema, valueSchema) + verifySuccess(keySchema, valueSchema, nonNullChangedKeySchema, valueSchema) } - test("changing the nullability of nullable to non-nullable in value should fail") { + test("storing non-nullable column into nullable column in value schema should be allowed") { val nonNullChangedValueSchema = StructType(valueSchema.map(_.copy(nullable = false))) - verifyException(keySchema, valueSchema, keySchema, nonNullChangedValueSchema) + verifySuccess(keySchema, valueSchema, keySchema, nonNullChangedValueSchema) } - test("changing the nullability of nullable to nonnullable in nested field in key should fail") { + test("storing non-nullable into nullable in nested field in key should be allowed") { val typeChangedNestedSchema = StructType(structSchema.map(_.copy(nullable = false))) val newKeySchema = applyNewSchemaToNestedFieldInKey(typeChangedNestedSchema) - verifyException(keySchema, valueSchema, newKeySchema, valueSchema) + verifySuccess(keySchema, valueSchema, newKeySchema, valueSchema) } - test("changing the nullability of nullable to nonnullable in nested field in value should fail") { + test("storing non-nullable into nullable in nested field in value should be allowed") { val typeChangedNestedSchema = StructType(structSchema.map(_.copy(nullable = false))) val newValueSchema = applyNewSchemaToNestedFieldInValue(typeChangedNestedSchema) - verifyException(keySchema, valueSchema, keySchema, newValueSchema) + verifySuccess(keySchema, valueSchema, keySchema, newValueSchema) + } + + test("storing nullable column into non-nullable column in key should fail") { + val nonNullChangedKeySchema = StructType(keySchema.map(_.copy(nullable = false))) + verifyException(nonNullChangedKeySchema, valueSchema, keySchema, valueSchema) + } + + test("storing nullable column into non-nullable column in value schema should fail") { + val nonNullChangedValueSchema = StructType(valueSchema.map(_.copy(nullable = false))) + verifyException(keySchema, nonNullChangedValueSchema, keySchema, valueSchema) + } + + test("storing nullable column into non-nullable column in nested field in key should fail") { + val typeChangedNestedSchema = StructType(structSchema.map(_.copy(nullable = false))) + val newKeySchema = applyNewSchemaToNestedFieldInKey(typeChangedNestedSchema) + verifyException(newKeySchema, valueSchema, keySchema, valueSchema) } + test("storing nullable column into non-nullable column in nested field in value should fail") { + val typeChangedNestedSchema = StructType(structSchema.map(_.copy(nullable = false))) + val newValueSchema = applyNewSchemaToNestedFieldInValue(typeChangedNestedSchema) + verifyException(keySchema, newValueSchema, keySchema, valueSchema) + } + + // Checks on changing name of (nested) field. + // Changing the name is allowed since it may be possible Spark can make relevant changes from + // operators/functions by chance. This opens a risk that end users swap two fields having same + // data type, but there is no way to address both. 
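[Editor's note] The nullability cases above pin down a direction: writing non-nullable data into a previously stored nullable column is allowed, while writing nullable data into a stored non-nullable column must be rejected. A compact sketch of that rule as I read it from these tests (the object and method names, and the recursion over nested structs, are my own illustration, not the checker's actual code; it also assumes both schemas have the same number of fields):

import org.apache.spark.sql.types.StructType

object NullabilityRuleSketch {
  // A stored field can accept incoming data only if it is at least as nullable:
  // either the stored side is nullable, or both sides are non-nullable.
  def compatible(stored: StructType, incoming: StructType): Boolean =
    stored.fields.zip(incoming.fields).forall { case (s, n) =>
      val nestedOk = (s.dataType, n.dataType) match {
        case (st: StructType, nt: StructType) => compatible(st, nt)
        case _ => true
      }
      nestedOk && (s.nullable || !n.nullable)
    }
}

Under this reading, the verifySuccess cases (stored nullable, incoming non-nullable) pass and the verifyException cases (stored non-nullable, incoming nullable) fail.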
+ test("changing the name of field in key should be allowed") { val newName: StructField => StructField = f => f.copy(name = f.name + "_new") val fieldNameChangedKeySchema = StructType(keySchema.map(newName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 601b62bd81007..dde925bb2d96f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1017,6 +1017,64 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] } } + // This test illustrates state store iterator behavior differences leading to SPARK-38320. + testWithAllCodec("SPARK-38320 - state store iterator behavior differences") { + val ROCKSDB_STATE_STORE = "RocksDBStateStore" + val dir = newDir() + val storeId = StateStoreId(dir, 0L, 1) + var version = 0L + + tryWithProviderResource(newStoreProvider(storeId)) { provider => + val store = provider.getStore(version) + logInfo(s"Running SPARK-38320 test with state store ${store.getClass.getName}") + + val itr1 = store.iterator() // itr1 is created before any writes to the store. + put(store, "1", 11, 100) + put(store, "2", 22, 200) + val itr2 = store.iterator() // itr2 is created in the middle of the writes. + put(store, "1", 11, 101) // Overwrite row (1, 11) + put(store, "3", 33, 300) + val itr3 = store.iterator() // itr3 is created after all writes. + + val expected = Set(("1", 11) -> 101, ("2", 22) -> 200, ("3", 33) -> 300) // The final state. + // Itr1 does not see any updates - original state of the store (SPARK-38320) + assert(rowPairsToDataSet(itr1) === Set.empty[Set[((String, Int), Int)]]) + assert(rowPairsToDataSet(itr2) === expected) + assert(rowPairsToDataSet(itr3) === expected) + + version = store.commit() + } + + // Reload the store from the commited version and repeat the above test. + tryWithProviderResource(newStoreProvider(storeId)) { provider => + assert(version > 0) + val store = provider.getStore(version) + + val itr1 = store.iterator() // itr1 is created before any writes to the store. + put(store, "3", 33, 301) // Overwrite row (3, 33) + put(store, "4", 44, 400) + val itr2 = store.iterator() // itr2 is created in the middle of the writes. + put(store, "4", 44, 401) // Overwrite row (4, 44) + put(store, "5", 55, 500) + val itr3 = store.iterator() // itr3 is created after all writes. + + // The final state. 
+ val expected = Set( + ("1", 11) -> 101, ("2", 22) -> 200, ("3", 33) -> 301, ("4", 44) -> 401, ("5", 55) -> 500) + if (store.getClass.getName contains ROCKSDB_STATE_STORE) { + // RocksDB itr1 does not see any updates - original state of the store (SPARK-38320) + assert(rowPairsToDataSet(itr1) === Set( + ("1", 11) -> 101, ("2", 22) -> 200, ("3", 33) -> 300)) + } else { + assert(rowPairsToDataSet(itr1) === expected) + } + assert(rowPairsToDataSet(itr2) === expected) + assert(rowPairsToDataSet(itr3) === expected) + + version = store.commit() + } + } + test("StateStore.get") { quietly { val dir = newDir() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 8eaeefccc5ec3..9b5b532d3ecdc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -878,7 +878,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils val oldCount = statusStore.executionsList().size val cls = classOf[CustomMetricsDataSource].getName - spark.range(10).select('id as 'i, -'id as 'j).write.format(cls) + spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls) .option("path", dir.getCanonicalPath).mode("append").save() // Wait until the new execution is started and being tracked. @@ -919,7 +920,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils try { val cls = classOf[CustomMetricsDataSource].getName - spark.range(0, 10, 1, 2).select('id as 'i, -'id as 'j).write.format(cls) + spark.range(0, 10, 1, 2).select(Symbol("id") as Symbol("i"), -'id as Symbol("j")) + .write.format(cls) .option("path", dir.getCanonicalPath).mode("append").save() // Wait until the new execution is started and being tracked. 
@@ -933,6 +935,7 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils statusStore.executionsList().last.metricValues != null) } + spark.sparkContext.listenerBus.waitUntilEmpty() assert(bytesWritten.sum == 246) assert(recordsWritten.sum == 20) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index cdf41ed651d4e..4cf2376a3fccd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.ColumnAccessor import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarArray -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { private def withVector( @@ -605,5 +605,14 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } } + + test("SPARK-38018: ColumnVectorUtils.populate to handle CalendarIntervalType correctly") { + val vector = new OnHeapColumnVector(5, CalendarIntervalType) + val row = new SpecificInternalRow(Array(CalendarIntervalType)) + val interval = new CalendarInterval(3, 5, 1000000) + row.setInterval(0, interval) + ColumnVectorUtils.populate(vector, row, 0) + assert(vector.getInterval(0) === interval) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 738f2281c9a65..0395798d9e7ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.vectorized import java.nio.ByteBuffer import java.nio.ByteOrder import java.nio.charset.StandardCharsets +import java.time.LocalDateTime import java.util import java.util.NoSuchElementException @@ -1591,10 +1592,21 @@ class ColumnarBatchSuite extends SparkFunSuite { )) :: StructField("int_to_int", MapType(IntegerType, IntegerType)) :: StructField("binary", BinaryType) :: + StructField("ts_ntz", TimestampNTZType) :: Nil) var mapBuilder = new ArrayBasedMapBuilder(IntegerType, IntegerType) mapBuilder.put(1, 10) mapBuilder.put(20, null) + + val tsString1 = "2015-01-01 23:50:59.123" + val ts1 = DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(tsString1)) + val tsNTZ1 = + DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse(tsString1.replace(" ", "T"))) + val tsString2 = "1880-01-05 12:45:21.321" + val ts2 = DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(tsString2)) + val tsNTZ2 = + DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse(tsString2.replace(" ", "T"))) + val row1 = new GenericInternalRow(Array[Any]( UTF8String.fromString("a string"), true, @@ -1606,12 +1618,13 @@ class ColumnarBatchSuite extends SparkFunSuite { 0.75D, Decimal("1234.23456"), DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("2015-01-01")), - DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123")), + ts1, new CalendarInterval(1, 0, 0), 
new GenericArrayData(Array(1, 2, 3, 4, null)), new GenericInternalRow(Array[Any](5.asInstanceOf[Any], 10)), mapBuilder.build(), - "Spark SQL".getBytes() + "Spark SQL".getBytes(), + tsNTZ1 )) mapBuilder = new ArrayBasedMapBuilder(IntegerType, IntegerType) @@ -1628,12 +1641,13 @@ class ColumnarBatchSuite extends SparkFunSuite { Double.PositiveInfinity, Decimal("0.01000"), DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("1875-12-12")), - DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("1880-01-05 12:45:21.321")), + ts2, new CalendarInterval(-10, -50, -100), new GenericArrayData(Array(5, 10, -100)), new GenericInternalRow(Array[Any](20.asInstanceOf[Any], null)), mapBuilder.build(), - "Parquet".getBytes() + "Parquet".getBytes(), + tsNTZ2 )) val row3 = new GenericInternalRow(Array[Any]( @@ -1652,6 +1666,7 @@ class ColumnarBatchSuite extends SparkFunSuite { null, null, null, + null, null )) @@ -1716,10 +1731,8 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(columns(9).isNullAt(2)) assert(columns(10).dataType() == TimestampType) - assert(columns(10).getLong(0) == - DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"))) - assert(columns(10).getLong(1) == - DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("1880-01-05 12:45:21.321"))) + assert(columns(10).getLong(0) == ts1) + assert(columns(10).getLong(1) == ts2) assert(columns(10).isNullAt(2)) assert(columns(11).dataType() == CalendarIntervalType) @@ -1777,6 +1790,11 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(new String(columns(15).getBinary(0)) == "Spark SQL") assert(new String(columns(15).getBinary(1)) == "Parquet") assert(columns(15).isNullAt(2)) + + assert(columns(16).dataType() == TimestampNTZType) + assert(columns(16).getLong(0) == tsNTZ1) + assert(columns(16).getLong(1) == tsNTZ2) + assert(columns(16).isNullAt(2)) } finally { batch.close() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ConstantColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ConstantColumnVectorSuite.scala new file mode 100644 index 0000000000000..2bee643df4eff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ConstantColumnVectorSuite.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.vectorized + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarMap} +import org.apache.spark.unsafe.types.UTF8String + +class ConstantColumnVectorSuite extends SparkFunSuite { + + private def testVector(name: String, size: Int, dt: DataType) + (f: ConstantColumnVector => Unit): Unit = { + test(name) { + val vector = new ConstantColumnVector(size, dt) + f(vector) + vector.close() + } + } + + testVector("null", 10, IntegerType) { vector => + vector.setNull() + assert(vector.hasNull) + assert(vector.numNulls() == 10) + (0 until 10).foreach { i => + assert(vector.isNullAt(i)) + } + + vector.setNotNull() + assert(!vector.hasNull) + assert(vector.numNulls() == 0) + (0 until 10).foreach { i => + assert(!vector.isNullAt(i)) + } + } + + testVector("boolean", 10, BooleanType) { vector => + vector.setBoolean(true) + (0 until 10).foreach { i => + assert(vector.getBoolean(i)) + } + } + + testVector("byte", 10, ByteType) { vector => + vector.setByte(3.toByte) + (0 until 10).foreach { i => + assert(vector.getByte(i) == 3.toByte) + } + } + + testVector("short", 10, ShortType) { vector => + vector.setShort(3.toShort) + (0 until 10).foreach { i => + assert(vector.getShort(i) == 3.toShort) + } + } + + testVector("int", 10, IntegerType) { vector => + vector.setInt(3) + (0 until 10).foreach { i => + assert(vector.getInt(i) == 3) + } + } + + testVector("long", 10, LongType) { vector => + vector.setLong(3L) + (0 until 10).foreach { i => + assert(vector.getLong(i) == 3L) + } + } + + testVector("float", 10, FloatType) { vector => + vector.setFloat(3.toFloat) + (0 until 10).foreach { i => + assert(vector.getFloat(i) == 3.toFloat) + } + } + + testVector("double", 10, DoubleType) { vector => + vector.setDouble(3.toDouble) + (0 until 10).foreach { i => + assert(vector.getDouble(i) == 3.toDouble) + } + } + + testVector("array", 10, ArrayType(IntegerType)) { vector => + // create an vector with constant array: [0, 1, 2, 3, 4] + val arrayVector = new OnHeapColumnVector(5, IntegerType) + (0 until 5).foreach { i => + arrayVector.putInt(i, i) + } + val columnarArray = new ColumnarArray(arrayVector, 0, 5) + + vector.setArray(columnarArray) + + (0 until 10).foreach { i => + assert(vector.getArray(i) == columnarArray) + assert(vector.getArray(i).toIntArray === Array(0, 1, 2, 3, 4)) + } + } + + testVector("map", 10, MapType(IntegerType, BooleanType)) { vector => + // create an vector with constant map: + // [(0, true), (1, false), (2, true), (3, false), (4, true)] + val keys = new OnHeapColumnVector(5, IntegerType) + val values = new OnHeapColumnVector(5, BooleanType) + + (0 until 5).foreach { i => + keys.putInt(i, i) + values.putBoolean(i, i % 2 == 0) + } + + val columnarMap = new ColumnarMap(keys, values, 0, 5) + vector.setMap(columnarMap) + + (0 until 10).foreach { i => + assert(vector.getMap(i) == columnarMap) + assert(vector.getMap(i).keyArray().toIntArray === Array(0, 1, 2, 3, 4)) + assert(vector.getMap(i).valueArray().toBooleanArray === + Array(true, false, true, false, true)) + } + } + + testVector("decimal", 10, DecimalType(10, 0)) { vector => + val decimal = Decimal(100L) + vector.setDecimal(decimal, 10) + (0 until 10).foreach { i => + assert(vector.getDecimal(i, 10, 0) == decimal) + } + } + + testVector("utf8string", 10, StringType) { vector => + vector.setUtf8String(UTF8String.fromString("hello")) + (0 until 10).foreach { i => + assert(vector.getUTF8String(i) == 
UTF8String.fromString("hello")) + } + } + + testVector("binary", 10, BinaryType) { vector => + vector.setBinary("hello".getBytes("utf8")) + (0 until 10).foreach { i => + assert(vector.getBinary(i) === "hello".getBytes("utf8")) + } + } + + testVector("struct", 10, + new StructType() + .add(StructField("name", StringType)) + .add(StructField("age", IntegerType))) { vector => + + val nameVector = new ConstantColumnVector(10, StringType) + nameVector.setUtf8String(UTF8String.fromString("jack")) + vector.setChild(0, nameVector) + + val ageVector = new ConstantColumnVector(10, IntegerType) + ageVector.setInt(27) + vector.setChild(1, ageVector) + + + assert(vector.getChild(0) == nameVector) + assert(vector.getChild(1) == ageVector) + (0 until 10).foreach { i => + assert(vector.getChild(0).getUTF8String(i) == UTF8String.fromString("jack")) + assert(vector.getChild(1).getInt(i) == 27) + } + + // another API + (0 until 10).foreach { i => + assert(vector.getStruct(i).get(0, StringType) == UTF8String.fromString("jack")) + assert(vector.getStruct(i).get(1, IntegerType) == 27) + } + } + + testVector("calendar interval", 10, CalendarIntervalType) { vector => + val monthsVector = new ConstantColumnVector(10, IntegerType) + monthsVector.setInt(3) + val daysVector = new ConstantColumnVector(10, IntegerType) + daysVector.setInt(25) + val microsecondsVector = new ConstantColumnVector(10, LongType) + microsecondsVector.setLong(12345L) + + vector.setChild(0, monthsVector) + vector.setChild(1, daysVector) + vector.setChild(2, microsecondsVector) + + (0 until 10).foreach { i => + assert(vector.getChild(0).getInt(i) == 3) + assert(vector.getChild(1).getInt(i) == 25) + assert(vector.getChild(2).getLong(i) == 12345L) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index dde463dd395f7..0d1ab5ef77b64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -81,7 +81,7 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { withTempPath { path => val pathString = path.getCanonicalPath - spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).select(Symbol("id").as("ID")).write.json(pathString) spark.range(10).write.mode("append").json(pathString) assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) } @@ -139,9 +139,10 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { Seq(true) .toDF() .mapPartitions { _ => - TaskContext.get.getLocalProperty(confKey) == confValue match { - case true => Iterator(true) - case false => Iterator.empty + if (TaskContext.get.getLocalProperty(confKey) == confValue) { + Iterator(true) + } else { + Iterator.empty } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index abde486b2db2b..a589d4ee3e3c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -469,7 +469,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { if (i == 0) { assert(zone === "Z") } else { - assert(zone === String.format("%+03d:00", new Integer(i))) + assert(zone === 
String.format("%+03d:00", Integer.valueOf(i))) } } val e2 = intercept[ParseException](sql("set time zone interval 19 hours")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index f4b18f1adfdec..d32e958c7ca2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -776,30 +776,36 @@ class JDBCSuite extends QueryTest val compileFilter = PrivateMethod[Option[String]](Symbol("compileFilter")) def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f, JdbcDialects.get("jdbc:")) getOrElse("") - assert(doCompileFilter(EqualTo("col0", 3)) === """"col0" = 3""") - assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === """(NOT ("col1" = 'abc'))""") - assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def"))) - === """("col0" = 0) AND ("col1" = 'def')""") - assert(doCompileFilter(Or(EqualTo("col0", 2), EqualTo("col1", "ghi"))) - === """("col0" = 2) OR ("col1" = 'ghi')""") - assert(doCompileFilter(LessThan("col0", 5)) === """"col0" < 5""") - assert(doCompileFilter(LessThan("col3", - Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col3" < '1995-11-21 00:00:00.0'""") - assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) - === """"col4" < '1983-08-04'""") - assert(doCompileFilter(LessThanOrEqual("col0", 5)) === """"col0" <= 5""") - assert(doCompileFilter(GreaterThan("col0", 3)) === """"col0" > 3""") - assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === """"col0" >= 3""") - assert(doCompileFilter(In("col1", Array("jkl"))) === """"col1" IN ('jkl')""") - assert(doCompileFilter(In("col1", Array.empty)) === - """CASE WHEN "col1" IS NULL THEN NULL ELSE FALSE END""") - assert(doCompileFilter(Not(In("col1", Array("mno", "pqr")))) - === """(NOT ("col1" IN ('mno', 'pqr')))""") - assert(doCompileFilter(IsNull("col1")) === """"col1" IS NULL""") - assert(doCompileFilter(IsNotNull("col1")) === """"col1" IS NOT NULL""") - assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def"))) - === """((NOT ("col0" != 'abc' OR "col0" IS NULL OR 'abc' IS NULL) """ + Seq(("col0", "col1"), ("`col0`", "`col1`")).foreach { case(col0, col1) => + assert(doCompileFilter(EqualTo(col0, 3)) === """"col0" = 3""") + assert(doCompileFilter(Not(EqualTo(col1, "abc"))) === """(NOT ("col1" = 'abc'))""") + assert(doCompileFilter(And(EqualTo(col0, 0), EqualTo(col1, "def"))) + === """("col0" = 0) AND ("col1" = 'def')""") + assert(doCompileFilter(Or(EqualTo(col0, 2), EqualTo(col1, "ghi"))) + === """("col0" = 2) OR ("col1" = 'ghi')""") + assert(doCompileFilter(LessThan(col0, 5)) === """"col0" < 5""") + assert(doCompileFilter(LessThan(col0, + Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col0" < '1995-11-21 00:00:00.0'""") + assert(doCompileFilter(LessThan(col0, Date.valueOf("1983-08-04"))) + === """"col0" < '1983-08-04'""") + assert(doCompileFilter(LessThanOrEqual(col0, 5)) === """"col0" <= 5""") + assert(doCompileFilter(GreaterThan(col0, 3)) === """"col0" > 3""") + assert(doCompileFilter(GreaterThanOrEqual(col0, 3)) === """"col0" >= 3""") + assert(doCompileFilter(In(col1, Array("jkl"))) === """"col1" IN ('jkl')""") + assert(doCompileFilter(In(col1, Array.empty)) === + """CASE WHEN "col1" IS NULL THEN NULL ELSE FALSE END""") + assert(doCompileFilter(Not(In(col1, Array("mno", "pqr")))) + === """(NOT ("col1" IN ('mno', 'pqr')))""") + assert(doCompileFilter(IsNull(col1)) 
=== """"col1" IS NULL""") + assert(doCompileFilter(IsNotNull(col1)) === """"col1" IS NOT NULL""") + assert(doCompileFilter(And(EqualNullSafe(col0, "abc"), EqualTo(col1, "def"))) + === """((NOT ("col0" != 'abc' OR "col0" IS NULL OR 'abc' IS NULL) """ + """OR ("col0" IS NULL AND 'abc' IS NULL))) AND ("col1" = 'def')""") + } + val e = intercept[AnalysisException] { + doCompileFilter(EqualTo("col0.nested", 3)) + }.getMessage + assert(e.contains("Filter push down does not support nested column: col0.nested")) } test("Dialect unregister") { @@ -1008,14 +1014,16 @@ class JDBCSuite extends QueryTest val defaultQuery = s"TRUNCATE TABLE $table" val postgresQuery = s"TRUNCATE TABLE ONLY $table" val teradataQuery = s"DELETE FROM $table ALL" + val db2Query = s"TRUNCATE TABLE $table IMMEDIATE" - Seq(mysql, db2, h2, derby).foreach{ dialect => + Seq(mysql, h2, derby).foreach{ dialect => assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) } assert(postgres.getTruncateQuery(table) == postgresQuery) assert(oracle.getTruncateQuery(table) == defaultQuery) assert(teradata.getTruncateQuery(table) == teradataQuery) + assert(db2.getTruncateQuery(table) == db2Query) } test("SPARK-22880: Truncate table with CASCADE by jdbc dialect") { @@ -1034,13 +1042,15 @@ class JDBCSuite extends QueryTest val postgresQuery = s"TRUNCATE TABLE ONLY $table CASCADE" val oracleQuery = s"TRUNCATE TABLE $table CASCADE" val teradataQuery = s"DELETE FROM $table ALL" + val db2Query = s"TRUNCATE TABLE $table IMMEDIATE" - Seq(mysql, db2, h2, derby).foreach{ dialect => + Seq(mysql, h2, derby).foreach{ dialect => assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) } assert(postgres.getTruncateQuery(table, Some(true)) == postgresQuery) assert(oracle.getTruncateQuery(table, Some(true)) == oracleQuery) assert(teradata.getTruncateQuery(table, Some(true)) == teradataQuery) + assert(db2.getTruncateQuery(table, Some(true)) == db2Query) } test("Test DataFrame.where for Date and Timestamp") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index c5e1a6ace7029..85ccf828873d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager} import java.util.Properties -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.{lit, sum, udf} +import org.apache.spark.sql.functions.{avg, count, lit, sum, udf} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -91,6 +92,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel // scalastyle:on conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (1)").executeUpdate() conn.prepareStatement("INSERT INTO 
\"test\".\"person\" VALUES (2)").executeUpdate() + conn.prepareStatement( + """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate() + conn.prepareStatement( + """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" INTEGER)""").executeUpdate() } } @@ -316,7 +321,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("show tables") { checkAnswer(sql("SHOW TABLES IN h2.test"), Seq(Row("test", "people", false), Row("test", "empty_table", false), - Row("test", "employee", false), Row("test", "dept", false), Row("test", "person", false))) + Row("test", "employee", false), Row("test", "dept", false), Row("test", "person", false), + Row("test", "view1", false), Row("test", "view2", false))) } test("SQL API: create table as select") { @@ -806,17 +812,81 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(query, Seq(Row(29000.0))) } - test("scan with aggregate push-down: SUM(CASE WHEN) with group by") { - val df = - sql("SELECT SUM(CASE WHEN SALARY > 0 THEN 1 ELSE 0 END) FROM h2.test.employee GROUP BY DEPT") - checkAggregateRemoved(df, false) + test("scan with aggregate push-down: aggregate function with CASE WHEN") { + val df = sql( + """ + |SELECT + | COUNT(CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY > 8000 AND SALARY <= 13000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY > 11000 OR SALARY < 10000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY >= 12000 OR SALARY < 9000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY >= 12000 OR NOT(SALARY >= 9000) THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 8000) AND SALARY >= 8000 THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 8000) OR SALARY > 8000 THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 8000) AND NOT(SALARY < 8000) THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY != 0) OR NOT(SALARY < 8000) THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 8000 AND SALARY > 8000) THEN 0 ELSE SALARY END), + | MIN(CASE WHEN NOT(SALARY > 8000 OR SALARY IS NULL) THEN SALARY ELSE 0 END), + | SUM(CASE WHEN NOT(SALARY > 8000 AND SALARY IS NOT NULL) THEN SALARY ELSE 0 END), + | SUM(CASE WHEN SALARY > 10000 THEN 2 WHEN SALARY > 8000 THEN 1 END), + | AVG(CASE WHEN NOT(SALARY > 8000 OR SALARY IS NOT NULL) THEN SALARY ELSE 0 END) + |FROM h2.test.employee GROUP BY DEPT + """.stripMargin) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedFilters: [], " + "PushedAggregates: [COUNT(CASE WHEN ((SALARY) > (8000.00)) AND ((SALARY) < (10000.00))" + + " THEN SALARY ELSE 0.00 END), C..., " + + "PushedFilters: [], " + + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } - checkAnswer(df, Seq(Row(1), Row(2), Row(2))) + checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 2, 0d), + Row(2, 2, 2, 2, 2, 0d, 10000d, 0d, 10000d, 10000d, 0d, 0d, 2, 0d), + Row(2, 2, 2, 2, 2, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 3, 0d))) + } + + test("scan with aggregate push-down: aggregate function with binary arithmetic") { + Seq(false, true).foreach { ansiMode => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { + val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee") + checkAggregateRemoved(df, ansiMode) + val expected_plan_fragment = if (ansiMode) { + "PushedAggregates: [SUM((2147483647) + (DEPT))], " + + 
"PushedFilters: [], PushedGroupByColumns: []" + } else { + "PushedFilters: []" + } + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + if (ansiMode) { + val e = intercept[SparkException] { + checkAnswer(df, Seq(Row(-10737418233L))) + } + assert(e.getMessage.contains( + "org.h2.jdbc.JdbcSQLDataException: Numeric value out of range: \"2147483648\"")) + } else { + checkAnswer(df, Seq(Row(-10737418233L))) + } + } + } + } + + test("scan with aggregate push-down: aggregate function with UDF") { + val df = spark.table("h2.test.employee") + val decrease = udf { (x: Double, y: Double) => x - y } + val query = df.select(sum(decrease($"SALARY", $"BONUS")).as("value")) + checkAggregateRemoved(query, false) + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: []" + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + checkAnswer(query, Seq(Row(47100.0))) } test("scan with aggregate push-down: partition columns with multi group by columns") { @@ -874,4 +944,92 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row(2))) // scalastyle:on } + + test("scan with aggregate push-down: complete push-down SUM, AVG, COUNT") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "1") + .table("h2.test.employee") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "1") + .table("h2.test.employee") + .groupBy($"name") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df2, Seq( + Row("alex", 12000.00, 12000.000000, 1), + Row("amy", 10000.00, 10000.000000, 1), + Row("cathy", 9000.00, 9000.000000, 1), + Row("david", 10000.00, 10000.000000, 1), + Row("jen", 12000.00, 12000.000000, 1))) + } + + test("scan with aggregate push-down: partial push-down SUM, AVG, COUNT") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + 
.option("numPartitions", "2") + .table("h2.test.employee") + .groupBy($"name") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df2, Seq( + Row("alex", 12000.00, 12000.000000, 1), + Row("amy", 10000.00, 10000.000000, 1), + Row("cathy", 9000.00, 9000.000000, 1), + Row("david", 10000.00, 10000.000000, 1), + Row("jen", 12000.00, 12000.000000, 1))) + } + + test("SPARK-37895: JDBC push down with delimited special identifiers") { + val df = sql( + """SELECT h2.test.view1.`|col1`, h2.test.view1.`|col2`, h2.test.view2.`|col3` + |FROM h2.test.view1 LEFT JOIN h2.test.view2 + |ON h2.test.view1.`|col1` = h2.test.view2.`|col1`""".stripMargin) + checkAnswer(df, Seq.empty[Row]) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index be9d1b0e179fe..18039db2ca744 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -282,7 +282,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti withTable("bucketed_table") { val numBuckets = NumBucketsForPruningNullDf val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) - val naNDF = nullDF.selectExpr("i", "cast(if(isnull(j), 'NaN', j) as double) as j", "k") + val naNDF = nullDF.selectExpr("i", "try_cast(if(isnull(j), 'NaN', j) as double) as j", "k") // json does not support predicate push-down, and thus json is used here naNDF.write .format("json") @@ -463,18 +463,18 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti // check existence of shuffle assert( - joinOperator.left.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined == shuffleLeft, + joinOperator.left.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleLeft, s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}") assert( - joinOperator.right.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined == shuffleRight, + joinOperator.right.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleRight, s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}") // check existence of sort assert( - joinOperator.left.find(_.isInstanceOf[SortExec]).isDefined == sortLeft, + joinOperator.left.exists(_.isInstanceOf[SortExec]) == sortLeft, s"expected sort in the left child to be $sortLeft but found\n${joinOperator.left}") assert( - joinOperator.right.find(_.isInstanceOf[SortExec]).isDefined == sortRight, + joinOperator.right.exists(_.isInstanceOf[SortExec]) == sortRight, s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}") // check the output partitioning @@ -678,7 +678,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) assert( - aggregated.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty) + !aggregated.queryExecution.executedPlan.exists(_.isInstanceOf[ShuffleExchangeExec])) } } @@ -719,7 +719,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti df1.groupBy("i", 
"j").agg(max("k")).sort("i", "j")) assert( - aggregated.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty) + !aggregated.queryExecution.executedPlan.exists(_.isInstanceOf[ShuffleExchangeExec])) } } @@ -773,8 +773,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti // join predicates is a super set of child's partitioning columns val bucketedTableTestSpec1 = - BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), - numPartitions = 1, expectedShuffle = false) + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1) testBucketing( bucketedTableTestSpecLeft = bucketedTableTestSpec1, bucketedTableTestSpecRight = bucketedTableTestSpec1, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index 81ce979ef0b62..1b1f3714dc701 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -36,7 +36,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with override def beforeAll(): Unit = { super.beforeAll() - targetAttributes = Seq('a.int, 'd.int, 'b.int, 'c.int) + targetAttributes = Seq(Symbol("a").int, Symbol("d").int, Symbol("b").int, Symbol("c").int) targetPartitionSchema = new StructType() .add("b", IntegerType) .add("c", IntegerType) @@ -74,7 +74,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with caseSensitive) { intercept[AssertionError] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int), providedPartitions = Map("b" -> None, "c" -> None), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -85,7 +85,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Missing columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int), + sourceAttributes = Seq(Symbol("e").int), providedPartitions = Map("b" -> Some("1"), "c" -> None), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -96,7 +96,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Missing partitioning columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int), providedPartitions = Map("b" -> Some("1")), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -105,7 +105,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Missing partitioning columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int, 'g.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int, Symbol("g").int), providedPartitions = Map("b" -> Some("1")), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -114,7 +114,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Wrong partitioning columns. 
intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int), providedPartitions = Map("b" -> Some("1"), "d" -> None), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -125,7 +125,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Wrong partitioning columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int), providedPartitions = Map("b" -> Some("1"), "d" -> Some("2")), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -134,7 +134,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Wrong partitioning columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int), + sourceAttributes = Seq(Symbol("e").int), providedPartitions = Map("b" -> Some("1"), "c" -> Some("3"), "d" -> Some("2")), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -144,7 +144,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Wrong partitioning columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int), providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index b553e6ed566b5..1fb4737c45a61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSparkSession @@ -807,13 +808,15 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { Seq((1, 2), (1, 3)).toDF("i", "part") .write.partitionBy("part").mode("overwrite") - .option("partitionOverwriteMode", "dynamic").parquet(path.getAbsolutePath) + .option(DataSourceUtils.PARTITION_OVERWRITE_MODE, PartitionOverwriteMode.DYNAMIC.toString) + .parquet(path.getAbsolutePath) checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) Seq((1, 2), (1, 3)).toDF("i", "part") .write.partitionBy("part").mode("overwrite") - .option("partitionOverwriteMode", "static").parquet(path.getAbsolutePath) + .option(DataSourceUtils.PARTITION_OVERWRITE_MODE, PartitionOverwriteMode.STATIC.toString) + .parquet(path.getAbsolutePath) checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 2) :: Row(1, 3) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index de54b38627443..8f263f042cf9f 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -109,6 +109,19 @@ case class AllDataTypesScan( } } +class LegacyTimestampSource extends RelationProvider { + override def createRelation(ctx: SQLContext, parameters: Map[String, String]): BaseRelation = { + new BaseRelation() with TableScan { + override val sqlContext: SQLContext = ctx + override val schema: StructType = StructType(StructField("col", TimestampType) :: Nil) + override def buildScan(): RDD[Row] = { + sqlContext.sparkContext.parallelize( + Row(java.sql.Timestamp.valueOf("2022-03-08 12:13:14")) :: Nil) + } + } + } +} + class TableScanSuite extends DataSourceTest with SharedSparkSession { protected override lazy val sql = spark.sql _ @@ -359,7 +372,7 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { val schemaNotMatch = intercept[Exception] { sql( s""" - |CREATE $tableType relationProviderWithSchema (i int) + |CREATE $tableType relationProviderWithSchema (i string) |USING org.apache.spark.sql.sources.SimpleScanSource |OPTIONS ( | From '1', @@ -420,4 +433,18 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { val comments = planned.schema.fields.map(_.getComment().getOrElse("NO_COMMENT")).mkString(",") assert(comments === "SN,SA,NO_COMMENT") } + + test("SPARK-38437: accept java.sql.Timestamp even when Java 8 API is enabled") { + val tableName = "relationProviderWithLegacyTimestamps" + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { + withTable (tableName) { + sql(s""" + |CREATE TABLE $tableName (col TIMESTAMP) + |USING org.apache.spark.sql.sources.LegacyTimestampSource""".stripMargin) + checkAnswer( + spark.table(tableName), + Row(java.sql.Timestamp.valueOf("2022-03-08 12:13:14").toInstant) :: Nil) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala new file mode 100644 index 0000000000000..d3e9a08509b0b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.Range +import org.apache.spark.sql.connector.read.streaming +import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, SparkDataStream} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.{ContinuousMemoryStream, ContinuousMemoryStreamOffset} +import org.apache.spark.sql.types.{LongType, StructType} + +class AcceptsLatestSeenOffsetSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("DataSource V1 source with micro-batch is not supported") { + val testSource = new TestSource(spark) + val df = testSource.toDF() + + /** Add data to this test source by incrementing its available offset */ + def addData(numNewRows: Int): StreamAction = new AddData { + override def addData( + query: Option[StreamExecution]): (SparkDataStream, streaming.Offset) = { + testSource.incrementAvailableOffset(numNewRows) + (testSource, testSource.getOffset.get) + } + } + + addData(10) + val query = df.writeStream.format("console").start() + val exc = intercept[StreamingQueryException] { + query.processAllAvailable() + } + assert(exc.getMessage.contains( + "AcceptsLatestSeenOffset is not supported with DSv1 streaming source")) + } + + test("DataSource V2 source with micro-batch") { + val inputData = new TestMemoryStream[Long](0, spark.sqlContext) + val df = inputData.toDF().select("value") + + /** Add data to this test source by incrementing its available offset */ + def addData(values: Array[Long]): StreamAction = new AddData { + override def addData( + query: Option[StreamExecution]): (SparkDataStream, streaming.Offset) = { + (inputData, inputData.addData(values)) + } + } + + testStream(df)( + StartStream(), + addData((1L to 10L).toArray), + ProcessAllAvailable(), + Execute("latest seen offset should be null") { _ => + // this verifies that the callback method is not called for the new query + assert(inputData.latestSeenOffset === null) + }, + StopStream, + + StartStream(), + addData((11L to 20L).toArray), + ProcessAllAvailable(), + Execute("latest seen offset should be 0") { _ => + assert(inputData.latestSeenOffset === LongOffset(0)) + }, + StopStream, + + Execute("mark last batch as incomplete") { q => + // Delete the last committed batch from the commit log to signify that the last batch + // (a no-data batch) did not complete and has to be re-executed on restart. + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purgeAfter(commit - 1) + }, + StartStream(), + addData((21L to 30L).toArray), + ProcessAllAvailable(), + Execute("latest seen offset should be 1") { _ => + assert(inputData.latestSeenOffset === LongOffset(1)) + } + ) + } + + test("DataSource V2 source with micro-batch - rollback of microbatch 0") { + // Test case: when the query is restarted, we expect the execution to call `latestSeenOffset` + // first. Later as part of the execution, execution may call `initialOffset` if the previous + // run of the query had no committed batches. 
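+    // The expected ordering is enforced by TestMemoryStream (defined below in this file):
+    // once its assertInitialOffsetIsCalledAfterLatestOffsetSeen flag is set, initialOffset
+    // fails the test if setLatestSeenOffset has not been invoked first.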
+ val inputData = new TestMemoryStream[Long](0, spark.sqlContext) + val df = inputData.toDF().select("value") + + /** Add data to this test source by incrementing its available offset */ + def addData(values: Array[Long]): StreamAction = new AddData { + override def addData( + query: Option[StreamExecution]): (SparkDataStream, streaming.Offset) = { + (inputData, inputData.addData(values)) + } + } + + testStream(df)( + StartStream(), + addData((1L to 10L).toArray), + ProcessAllAvailable(), + Execute("latest seen offset should be null") { _ => + // this verifies that the callback method is not called for the new query + assert(inputData.latestSeenOffset === null) + }, + StopStream, + + Execute("mark last batch as incomplete") { q => + // Delete the last committed batch from the commit log to signify that the last batch + // (a no-data batch) did not complete and has to be re-executed on restart. + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purgeAfter(commit - 1) + }, + + Execute("reset flag initial offset called flag") { q => + inputData.assertInitialOffsetIsCalledAfterLatestOffsetSeen = true + }, + StartStream(), + addData((11L to 20L).toArray), + ProcessAllAvailable(), + Execute("latest seen offset should be 0") { _ => + assert(inputData.latestSeenOffset === LongOffset(0)) + }, + StopStream + ) + } + + test("DataSource V2 source with continuous mode") { + val inputData = new TestContinuousMemoryStream[Long](0, spark.sqlContext, 1) + val df = inputData.toDF().select("value") + + /** Add data to this test source by incrementing its available offset */ + def addData(values: Array[Long]): StreamAction = new AddData { + override def addData( + query: Option[StreamExecution]): (SparkDataStream, streaming.Offset) = { + (inputData, inputData.addData(values)) + } + } + + testStream(df)( + StartStream(trigger = Trigger.Continuous("1 hour")), + addData((1L to 10L).toArray), + AwaitEpoch(0), + Execute { _ => + assert(inputData.latestSeenOffset === null) + }, + IncrementEpoch(), + StopStream, + + StartStream(trigger = Trigger.Continuous("1 hour")), + addData((11L to 20L).toArray), + AwaitEpoch(2), + Execute { _ => + assert(inputData.latestSeenOffset === ContinuousMemoryStreamOffset(Map(0 -> 10))) + }, + IncrementEpoch(), + StopStream, + + StartStream(trigger = Trigger.Continuous("1 hour")), + addData((21L to 30L).toArray), + AwaitEpoch(3), + Execute { _ => + assert(inputData.latestSeenOffset === ContinuousMemoryStreamOffset(Map(0 -> 20))) + } + ) + } + + class TestSource(spark: SparkSession) extends Source with AcceptsLatestSeenOffset { + + @volatile var currentOffset = 0L + + override def getOffset: Option[Offset] = { + if (currentOffset <= 0) None else Some(LongOffset(currentOffset)) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + if (currentOffset == 0) currentOffset = getOffsetValue(end) + val plan = Range( + start.map(getOffsetValue).getOrElse(0L) + 1L, getOffsetValue(end) + 1L, 1, None, + isStreaming = true) + Dataset.ofRows(spark, plan) + } + + def incrementAvailableOffset(numNewRows: Int): Unit = { + currentOffset = currentOffset + numNewRows + } + + override def setLatestSeenOffset(offset: streaming.Offset): Unit = { + assert(false, "This method should not be called!") + } + + def reset(): Unit = { + currentOffset = 0L + } + + def toDF(): DataFrame = Dataset.ofRows(spark, StreamingExecutionRelation(this, spark)) + override def schema: StructType = new StructType().add("value", LongType) + override def stop(): Unit = {} + 
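+    // Offsets restored from the offset log arrive as SerializedOffset, while offsets created
+    // in-memory are LongOffsets; both forms are normalized to their underlying Long value here.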
private def getOffsetValue(offset: Offset): Long = { + offset match { + case s: SerializedOffset => LongOffset(s).offset + case l: LongOffset => l.offset + case _ => throw new IllegalArgumentException("incorrect offset type: " + offset) + } + } + } + + class TestMemoryStream[A : Encoder]( + _id: Int, + _sqlContext: SQLContext, + _numPartitions: Option[Int] = None) + extends MemoryStream[A](_id, _sqlContext, _numPartitions) + with AcceptsLatestSeenOffset { + + @volatile var latestSeenOffset: streaming.Offset = null + + // Flag to assert the sequence of calls in following scenario: + // When the query is restarted, we expect the execution to call `latestSeenOffset` first. + // Later as part of the execution, execution may call `initialOffset` if the previous + // run of the query had no committed batches. + @volatile var assertInitialOffsetIsCalledAfterLatestOffsetSeen: Boolean = false + + override def setLatestSeenOffset(offset: streaming.Offset): Unit = { + latestSeenOffset = offset + } + + override def initialOffset: streaming.Offset = { + if (assertInitialOffsetIsCalledAfterLatestOffsetSeen && latestSeenOffset == null) { + fail("Expected the latest seen offset to be set.") + } + super.initialOffset + } + } + + class TestContinuousMemoryStream[A : Encoder]( + _id: Int, + _sqlContext: SQLContext, + _numPartitions: Int = 2) + extends ContinuousMemoryStream[A](_id, _sqlContext, _numPartitions) + with AcceptsLatestSeenOffset { + + @volatile var latestSeenOffset: streaming.Offset = _ + + override def setLatestSeenOffset(offset: streaming.Offset): Unit = { + latestSeenOffset = offset + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index a81bd3bd060d3..3d315be636741 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -133,8 +133,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val inputData1 = MemoryStream[Int] val aggWithoutWatermark = inputData1.toDF() .withColumn("eventTime", timestamp_seconds($"value")) - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(aggWithoutWatermark, outputMode = Complete)( @@ -151,8 +151,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val aggWithWatermark = inputData2.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(aggWithWatermark)( @@ -174,8 +174,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val aggWithWatermark = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) 
.select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) // Unlike the ProcessingTime trigger, Trigger.Once only runs one trigger every time @@ -229,8 +229,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val aggWithWatermark = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) @@ -291,8 +291,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val windowedAggregation = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(windowedAggregation)( @@ -316,8 +316,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val windowedAggregation = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(windowedAggregation, OutputMode.Update)( @@ -346,8 +346,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val aggWithWatermark = input.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "2 years 5 months") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) def monthsSinceEpoch(date: Date): Int = { @@ -378,8 +378,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val df = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(df)( @@ -413,17 +413,17 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val firstDf = first.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .select('value) + .select(Symbol("value")) val second = MemoryStream[Int] val secondDf = second.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "5 seconds") - .select('value) + .select(Symbol("value")) withTempDir { checkpointDir => - val unionWriter = firstDf.union(secondDf).agg(sum('value)) + val unionWriter = firstDf.union(secondDf).agg(sum(Symbol("value"))) .writeStream .option("checkpointLocation", checkpointDir.getCanonicalPath) .format("memory") @@ 
-490,8 +490,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val windowedAggregation = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) // No eviction when asked to compute complete results. @@ -516,7 +516,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") .groupBy($"eventTime") - .agg(count("*") as 'count) + .agg(count("*") as Symbol("count")) .select($"eventTime".cast("long").as[Long], $"count".as[Long]) testStream(windowedAggregation)( @@ -587,7 +587,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val groupEvents = input .withWatermark("eventTime", "2 seconds") .groupBy("symbol", "eventTime") - .agg(count("price") as 'count) + .agg(count("price") as Symbol("count")) .select("symbol", "eventTime", "count") val q = groupEvents.writeStream .outputMode("append") @@ -606,14 +606,14 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val aliasWindow = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .select(window($"eventTime", "5 seconds") as 'aliasWindow) + .select(window($"eventTime", "5 seconds") as Symbol("aliasWindow")) // Check the eventTime metadata is kept in the top level alias. assert(aliasWindow.logicalPlan.output.exists( _.metadata.contains(EventTimeWatermark.delayKey))) val windowedAggregation = aliasWindow - .groupBy('aliasWindow) - .agg(count("*") as 'count) + .groupBy(Symbol("aliasWindow")) + .agg(count("*") as Symbol("count")) .select($"aliasWindow".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(windowedAggregation)( @@ -636,8 +636,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val windowedAggregation = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(windowedAggregation)( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala new file mode 100644 index 0000000000000..f1578ae5df97d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.io.File + +import org.apache.commons.io.FileUtils + +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update +import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, MemoryStream} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.GroupStateTimeout.ProcessingTimeTimeout +import org.apache.spark.sql.streaming.util.{StatefulOpClusteredDistributionTestHelper, StreamManualClock} +import org.apache.spark.util.Utils + +class FlatMapGroupsWithStateDistributionSuite extends StreamTest + with StatefulOpClusteredDistributionTestHelper { + + import testImplicits._ + + test("SPARK-38204: flatMapGroupsWithState should require StatefulOpClusteredDistribution " + + "from children - with initial state") { + // function will return -1 on timeout and returns count of the state otherwise + val stateFunc = + (key: (String, String), values: Iterator[(String, String, Long)], + state: GroupState[RunningCount]) => { + + if (state.hasTimedOut) { + state.remove() + Iterator((key, "-1")) + } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, String, Long)] + val initialState = Seq(("c", "c", new RunningCount(2))) + .toDS() + .repartition($"_2") + .groupByKey(a => (a._1, a._2)).mapValues(_._3) + val result = + inputData.toDF().toDF("key1", "key2", "time") + .selectExpr("key1", "key2", "timestamp_seconds(time) as timestamp") + .withWatermark("timestamp", "10 second") + .as[(String, String, Long)] + .repartition($"_1") + .groupByKey(x => (x._1, x._2)) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout(), initialState)(stateFunc) + .select($"_1._1".as("key1"), $"_1._2".as("key2"), $"_2".as("cnt")) + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", "a", 1L)), + AdvanceManualClock(1 * 1000), // a and c are processed here for the first time. 
+ CheckNewAnswer(("a", "a", "1"), ("c", "c", "2")), + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val flatMapGroupsWithStateExecs = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsWithStateExec => f + } + + assert(flatMapGroupsWithStateExecs.length === 1) + assert(requireStatefulOpClusteredDistribution( + flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", "_2")), numPartitions)) + assert(hasDesiredHashPartitioningInChildren( + flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", "_2")), numPartitions)) + } + ) + } + + test("SPARK-38204: flatMapGroupsWithState should require StatefulOpClusteredDistribution " + + "from children - without initial state") { + // function will return -1 on timeout and returns count of the state otherwise + val stateFunc = + (key: (String, String), values: Iterator[(String, String, Long)], + state: GroupState[RunningCount]) => { + + if (state.hasTimedOut) { + state.remove() + Iterator((key, "-1")) + } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, String, Long)] + val result = + inputData.toDF().toDF("key1", "key2", "time") + .selectExpr("key1", "key2", "timestamp_seconds(time) as timestamp") + .withWatermark("timestamp", "10 second") + .as[(String, String, Long)] + .repartition($"_1") + .groupByKey(x => (x._1, x._2)) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout())(stateFunc) + .select($"_1._1".as("key1"), $"_1._2".as("key2"), $"_2".as("cnt")) + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", "a", 1L)), + AdvanceManualClock(1 * 1000), // a is processed here for the first time. 
+ CheckNewAnswer(("a", "a", "1")), + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val flatMapGroupsWithStateExecs = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsWithStateExec => f + } + + assert(flatMapGroupsWithStateExecs.length === 1) + assert(requireStatefulOpClusteredDistribution( + flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", "_2")), numPartitions)) + assert(hasDesiredHashPartitioningInChildren( + flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", "_2")), numPartitions)) + } + ) + } + + test("SPARK-38204: flatMapGroupsWithState should require ClusteredDistribution " + + "from children if the query starts from checkpoint in 3.2.x - with initial state") { + // function will return -1 on timeout and returns count of the state otherwise + val stateFunc = + (key: (String, String), values: Iterator[(String, String, Long)], + state: GroupState[RunningCount]) => { + + if (state.hasTimedOut) { + state.remove() + Iterator((key, "-1")) + } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, String, Long)] + val initialState = Seq(("c", "c", new RunningCount(2))) + .toDS() + .repartition($"_2") + .groupByKey(a => (a._1, a._2)).mapValues(_._3) + val result = + inputData.toDF().toDF("key1", "key2", "time") + .selectExpr("key1", "key2", "timestamp_seconds(time) as timestamp") + .withWatermark("timestamp", "10 second") + .as[(String, String, Long)] + .repartition($"_1") + .groupByKey(x => (x._1, x._2)) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout(), initialState)(stateFunc) + .select($"_1._1".as("key1"), $"_1._2".as("key2"), $"_2".as("cnt")) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate1-repartition/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(("a", "a", 1L)) + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), + checkpointLocation = checkpointDir.getAbsolutePath, + triggerClock = clock, + additionalConfs = Map(SQLConf.STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "true")), + + // scalastyle:off line.size.limit + /* + Note: The checkpoint was generated using the following input in Spark version 3.2.0 + AddData(inputData, ("a", "a", 1L)), + AdvanceManualClock(1 * 1000), // a and c are processed here for the first time. + CheckNewAnswer(("a", "a", "1"), ("c", "c", "2")), + + Note2: The following is the physical plan of the query in Spark version 3.2.0. 
+ + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@253dd5ad, org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy$$Lambda$2214/0x0000000840ead440@6ede0d42 + +- *(6) Project [_1#58._1 AS key1#63, _1#58._2 AS key2#64, _2#59 AS cnt#65] + +- *(6) SerializeFromObject [if (isnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)) null else named_struct(_1, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)._1, true, false), _2, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)._2, true, false)) AS _1#58, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#59] + +- FlatMapGroupsWithState org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite$$Lambda$1067/0x0000000840770440@3f2e51a9, newInstance(class scala.Tuple2), newInstance(class scala.Tuple3), newInstance(class org.apache.spark.sql.streaming.RunningCount), [_1#52, _2#53], [_1#22, _2#23], [key1#29, key2#30, timestamp#35-T10000ms], [count#25L], obj#57: scala.Tuple2, state info [ checkpoint = file:/tmp/streaming.metadata-d4f0d156-78b5-4129-97fb-361241ab03d8/state, runId = eb107298-692d-4336-bb76-6b11b34a0753, opId = 0, ver = 0, numPartitions = 5], class[count[0]: bigint], 2, Update, ProcessingTimeTimeout, 1000, 0, true + :- *(3) Sort [_1#52 ASC NULLS FIRST, _2#53 ASC NULLS FIRST], false, 0 + : +- Exchange hashpartitioning(_1#52, _2#53, 5), ENSURE_REQUIREMENTS, [id=#78] + : +- AppendColumns org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite$$Lambda$1751/0x0000000840ccc040@41d4c0d8, newInstance(class scala.Tuple3), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#52, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#53] + : +- *(2) Project [key1#29, key2#30, timestamp#35-T10000ms] + : +- Exchange hashpartitioning(_1#3, 5), REPARTITION_BY_COL, [id=#73] + : +- EventTimeWatermark timestamp#35: timestamp, 10 seconds + : +- *(1) Project [_1#3 AS key1#29, _2#4 AS key2#30, timestamp_seconds(_3#5L) AS timestamp#35, _1#3] + : +- MicroBatchScan[_1#3, _2#4, _3#5L] MemoryStreamDataSource + +- *(5) Sort [_1#22 ASC NULLS FIRST, _2#23 ASC NULLS FIRST], false, 0 + +- Exchange hashpartitioning(_1#22, _2#23, 5), ENSURE_REQUIREMENTS, [id=#85] + +- *(4) Project [count#25L, _1#22, _2#23] + +- AppendColumns org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite$$Lambda$1686/0x0000000840c9b840@6bb881d0, newInstance(class scala.Tuple3), [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.streaming.RunningCount, true])).count AS count#25L] + +- AppendColumns org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite$$Lambda$1681/0x0000000840c98840@11355c7b, newInstance(class scala.Tuple3), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#22, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#23] + +- 
Exchange hashpartitioning(_1#9, 5), REPARTITION_BY_COL, [id=#43] + +- LocalTableScan [_1#9, _2#10, _3#11] + */ + // scalastyle:on line.size.limit + + AddData(inputData, ("a", "b", 1L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "b", "1")), + + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val flatMapGroupsWithStateExecs = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsWithStateExec => f + } + + assert(flatMapGroupsWithStateExecs.length === 1) + assert(requireClusteredDistribution(flatMapGroupsWithStateExecs.head, + Seq(Seq("_1", "_2"), Seq("_1", "_2")), Some(numPartitions))) + assert(hasDesiredHashPartitioningInChildren( + flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", "_2")), numPartitions)) + } + ) + } + + test("SPARK-38204: flatMapGroupsWithState should require ClusteredDistribution " + + "from children if the query starts from checkpoint in 3.2.x - without initial state") { + // function will return -1 on timeout and returns count of the state otherwise + val stateFunc = + (key: (String, String), values: Iterator[(String, String, Long)], + state: GroupState[RunningCount]) => { + + if (state.hasTimedOut) { + state.remove() + Iterator((key, "-1")) + } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, String, Long)] + val result = + inputData.toDF().toDF("key1", "key2", "time") + .selectExpr("key1", "key2", "timestamp_seconds(time) as timestamp") + .withWatermark("timestamp", "10 second") + .as[(String, String, Long)] + .repartition($"_1") + .groupByKey(x => (x._1, x._2)) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout())(stateFunc) + .select($"_1._1".as("key1"), $"_1._2".as("key2"), $"_2".as("cnt")) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-3.2.0-flatmapgroupswithstate2-repartition/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(("a", "a", 1L)) + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), + checkpointLocation = checkpointDir.getAbsolutePath, + triggerClock = clock, + additionalConfs = Map(SQLConf.STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "true")), + + // scalastyle:off line.size.limit + /* + Note: The checkpoint was generated using the following input in Spark version 3.2.0 + AddData(inputData, ("a", "a", 1L)), + AdvanceManualClock(1 * 1000), // a is processed here for the first time. 
+ CheckNewAnswer(("a", "a", "1")), + + Note2: The following is the physical plan of the query in Spark version 3.2.0 (convenience for checking backward compatibility) + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@20732f1b, org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy$$Lambda$2205/0x0000000840ea5440@48e6c016 + +- *(5) Project [_1#39._1 AS key1#44, _1#39._2 AS key2#45, _2#40 AS cnt#46] + +- *(5) SerializeFromObject [if (isnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)) null else named_struct(_1, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)._1, true, false), _2, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)._2, true, false)) AS _1#39, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#40] + +- FlatMapGroupsWithState org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite$$Lambda$1065/0x0000000840770040@240e41f8, newInstance(class scala.Tuple2), newInstance(class scala.Tuple3), newInstance(class scala.Tuple2), [_1#32, _2#33], [_1#32, _2#33], [key1#9, key2#10, timestamp#15-T10000ms], [key1#9, key2#10, timestamp#15-T10000ms], obj#37: scala.Tuple2, state info [ checkpoint = file:/tmp/spark-6619d285-b0ca-42ab-8284-723a564e13b6/state, runId = b3383a6c-9976-483c-a463-7fc9e9ae3e1a, opId = 0, ver = 0, numPartitions = 5], class[count[0]: bigint], 2, Update, ProcessingTimeTimeout, 1000, 0, false + :- *(3) Sort [_1#32 ASC NULLS FIRST, _2#33 ASC NULLS FIRST], false, 0 + : +- Exchange hashpartitioning(_1#32, _2#33, 5), ENSURE_REQUIREMENTS, [id=#62] + : +- AppendColumns org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite$$Lambda$1709/0x0000000840ca7040@351810cb, newInstance(class scala.Tuple3), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#32, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#33] + : +- *(2) Project [key1#9, key2#10, timestamp#15-T10000ms] + : +- Exchange hashpartitioning(_1#3, 5), REPARTITION_BY_COL, [id=#57] + : +- EventTimeWatermark timestamp#15: timestamp, 10 seconds + : +- *(1) Project [_1#3 AS key1#9, _2#4 AS key2#10, timestamp_seconds(_3#5L) AS timestamp#15, _1#3] + : +- MicroBatchScan[_1#3, _2#4, _3#5L] MemoryStreamDataSource + +- *(4) !Sort [_1#32 ASC NULLS FIRST, _2#33 ASC NULLS FIRST], false, 0 + +- !Exchange hashpartitioning(_1#32, _2#33, 5), ENSURE_REQUIREMENTS, [id=#46] + +- LocalTableScan , [count#38L] + */ + // scalastyle:on line.size.limit + + AddData(inputData, ("a", "b", 1L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "b", "1")), + + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val flatMapGroupsWithStateExecs = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsWithStateExec => f + } + + assert(flatMapGroupsWithStateExecs.length === 1) + assert(requireClusteredDistribution(flatMapGroupsWithStateExecs.head, + Seq(Seq("_1", "_2"), Seq("_1", "_2")), Some(numPartitions))) + assert(hasDesiredHashPartitioningInChildren( + flatMapGroupsWithStateExecs.head, 
Seq(Seq("_1", "_2"), Seq("_1", "_2")), numPartitions)) + } + ) + } + + test("SPARK-38204: flatMapGroupsWithState should require ClusteredDistribution " + + "from children if the query starts from checkpoint in prior to 3.2") { + // function will return -1 on timeout and returns count of the state otherwise + val stateFunc = + (key: (String, String), values: Iterator[(String, String, Long)], + state: GroupState[RunningCount]) => { + + if (state.hasTimedOut) { + state.remove() + Iterator((key, "-1")) + } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, String, Long)] + val result = + inputData.toDF().toDF("key1", "key2", "time") + .selectExpr("key1", "key2", "timestamp_seconds(time) as timestamp") + .withWatermark("timestamp", "10 second") + .as[(String, String, Long)] + .repartition($"_1") + .groupByKey(x => (x._1, x._2)) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout())(stateFunc) + .select($"_1._1".as("key1"), $"_1._2".as("key2"), $"_2".as("cnt")) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-3.1.0-flatmapgroupswithstate-repartition/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(("a", "a", 1L)) + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), + checkpointLocation = checkpointDir.getAbsolutePath, + triggerClock = clock, + additionalConfs = Map(SQLConf.STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "true")), + + // scalastyle:off line.size.limit + /* + Note: The checkpoint was generated using the following input in Spark version 3.2.0 + AddData(inputData, ("a", "a", 1L)), + AdvanceManualClock(1 * 1000), // a is processed here for the first time. + CheckNewAnswer(("a", "a", "1")), + + Note2: The following plans are the physical plans of the query in older Spark versions + The physical plans around FlatMapGroupsWithStateExec are quite similar, especially + shuffles being injected are same. That said, verifying with checkpoint being built with + Spark 3.1.0 would verify the following versions as well. + + A. 
Spark 3.1.0 + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@4505821b + +- *(3) Project [_1#38._1 AS key1#43, _1#38._2 AS key2#44, _2#39 AS cnt#45] + +- *(3) SerializeFromObject [if (isnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)) null else named_struct(_1, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)._1, true, false), _2, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)._2, true, false)) AS _1#38, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#39] + +- FlatMapGroupsWithState org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite$$Lambda$1035/0x0000000840721840@64351072, newInstance(class scala.Tuple2), newInstance(class scala.Tuple3), [_1#32, _2#33], [key1#9, key2#10, timestamp#15-T10000ms], obj#37: scala.Tuple2, state info [ checkpoint = file:/tmp/spark-56397379-d014-48e0-a002-448c0621cfe8/state, runId = 4f9a129f-2b0c-4838-9d26-18171d94be7d, opId = 0, ver = 0, numPartitions = 5], class[count[0]: bigint], 2, Update, ProcessingTimeTimeout, 1000, 0 + +- *(2) Sort [_1#32 ASC NULLS FIRST, _2#33 ASC NULLS FIRST], false, 0 + +- Exchange hashpartitioning(_1#32, _2#33, 5), ENSURE_REQUIREMENTS, [id=#54] + +- AppendColumns org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite$$Lambda$1594/0x0000000840bc8840@857c80d, newInstance(class scala.Tuple3), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#32, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#33] + +- Exchange hashpartitioning(key1#9, 5), REPARTITION, [id=#52] + +- EventTimeWatermark timestamp#15: timestamp, 10 seconds + +- *(1) Project [_1#3 AS key1#9, _2#4 AS key2#10, timestamp_seconds(_3#5L) AS timestamp#15] + +- *(1) Project [_1#3, _2#4, _3#5L] + +- MicroBatchScan[_1#3, _2#4, _3#5L] MemoryStreamDataSource + + B. 
Spark 3.0.0 + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@32ae8206 + +- *(3) Project [_1#38._1 AS key1#43, _1#38._2 AS key2#44, _2#39 AS cnt#45] + +- *(3) SerializeFromObject [if (isnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)) null else named_struct(_1, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)._1, true, false), _2, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1)._2, true, false)) AS _1#38, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#39] + +- FlatMapGroupsWithState org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite$$Lambda$972/0x0000000840721c40@3e8c825d, newInstance(class scala.Tuple2), newInstance(class scala.Tuple3), [_1#32, _2#33], [key1#9, key2#10, timestamp#15-T10000ms], obj#37: scala.Tuple2, state info [ checkpoint = file:/tmp/spark-dcd6753e-54c7-481c-aa21-f7fc677a29a4/state, runId = 4854d427-436c-4f4e-9e1d-577bcd9cc890, opId = 0, ver = 0, numPartitions = 5], class[count[0]: bigint], 2, Update, ProcessingTimeTimeout, 1000, 0 + +- *(2) Sort [_1#32 ASC NULLS FIRST, _2#33 ASC NULLS FIRST], false, 0 + +- Exchange hashpartitioning(_1#32, _2#33, 5), true, [id=#54] + +- AppendColumns org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite$$Lambda$1477/0x0000000840bb6040@627623e, newInstance(class scala.Tuple3), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#32, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#33] + +- Exchange hashpartitioning(key1#9, 5), false, [id=#52] + +- EventTimeWatermark timestamp#15: timestamp, 10 seconds + +- *(1) Project [_1#3 AS key1#9, _2#4 AS key2#10, cast(_3#5L as timestamp) AS timestamp#15] + +- *(1) Project [_1#3, _2#4, _3#5L] + +- MicroBatchScan[_1#3, _2#4, _3#5L] MemoryStreamDataSource + + C. 
Spark 2.4.0 + *(3) Project [_1#32._1 AS key1#35, _1#32._2 AS key2#36, _2#33 AS cnt#37] + +- *(3) SerializeFromObject [if (isnull(assertnotnull(input[0, scala.Tuple2, true])._1)) null else named_struct(_1, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true])._1)._1, true, false), _2, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true])._1)._2, true, false)) AS _1#32, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#33] + +- FlatMapGroupsWithState , newInstance(class scala.Tuple2), newInstance(class scala.Tuple3), [_1#26, _2#27], [key1#9, key2#10, timestamp#15-T10000ms], obj#31: scala.Tuple2, state info [ checkpoint = file:/tmp/spark-634482c9-a55a-4f4e-b352-babec98fb4fc/state, runId = dd65fff0-d901-4e0b-a1ad-8c09b69f33ba, opId = 0, ver = 0, numPartitions = 5], class[count[0]: bigint], 2, Update, ProcessingTimeTimeout, 1000, 0 + +- *(2) Sort [_1#26 ASC NULLS FIRST, _2#27 ASC NULLS FIRST], false, 0 + +- Exchange hashpartitioning(_1#26, _2#27, 5) + +- AppendColumns , newInstance(class scala.Tuple3), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#26, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#27] + +- Exchange hashpartitioning(key1#9, 5) + +- EventTimeWatermark timestamp#15: timestamp, interval 10 seconds + +- *(1) Project [_1#56 AS key1#9, _2#57 AS key2#10, cast(_3#58L as timestamp) AS timestamp#15] + +- *(1) Project [_1#56, _2#57, _3#58L] + +- *(1) ScanV2 MemoryStreamDataSource$[_1#56, _2#57, _3#58L] + */ + // scalastyle:on line.size.limit + + AddData(inputData, ("a", "b", 1L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "b", "1")), + + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val flatMapGroupsWithStateExecs = query.lastExecution.executedPlan.collect { + case f: FlatMapGroupsWithStateExec => f + } + + assert(flatMapGroupsWithStateExecs.length === 1) + assert(requireClusteredDistribution(flatMapGroupsWithStateExecs.head, + Seq(Seq("_1", "_2"), Seq("_1", "_2")), Some(numPartitions))) + assert(hasDesiredHashPartitioningInChildren( + flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", "_2")), numPartitions)) + } + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index d34b2b8e9f7b1..9d34ceea8dd47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.streaming import java.io.File -import java.sql.Date +import java.sql.{Date, Timestamp} import org.apache.commons.io.FileUtils import org.scalatest.exceptions.TestFailedException @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming._ -import 
org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, StateStore} +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, RocksDBStateStoreProvider, StateStore} import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -427,9 +427,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { timeoutConf: GroupStateTimeout, procTime: Long, watermarkPresent: Boolean): GroupState[Int] = { - val eventTimeWatermarkMs = watermarkPresent match { - case true => Optional.of(1000L) - case false => Optional.empty[Long] + val eventTimeWatermarkMs = if (watermarkPresent) { + Optional.of(1000L) + } else { + Optional.empty[Long] } TestGroupState.create[Int]( Optional.of(1000), timeoutConf, procTime, eventTimeWatermarkMs, hasTimedOut = false) @@ -1519,6 +1520,50 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { ) } + test("SPARK-38320 - flatMapGroupsWithState state with data should not timeout") { + withTempDir { dir => + withSQLConf( + (SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false"), + (SQLConf.CHECKPOINT_LOCATION.key -> dir.getCanonicalPath), + (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName)) { + + val inputData = MemoryStream[Timestamp] + val stateFunc = (key: Int, values: Iterator[Timestamp], state: GroupState[Int]) => { + // Should never timeout. All batches should have data and even if a timeout is set, + // it should get cleared when the key receives data per contract. + require(!state.hasTimedOut, "The state should not have timed out!") + // Set state and timeout once, only on the first call. The timeout should get cleared + // in the subsequent batch which has data for the key. + if (!state.exists) { + state.update(0) + state.setTimeoutTimestamp(500) // Timeout at 500 milliseconds. + } + 0 + } + + val query = inputData.toDS() + .withWatermark("value", "0 seconds") + .groupByKey(_ => 0) // Always the same key: 0. + .mapGroupsWithState(GroupStateTimeout.EventTimeTimeout())(stateFunc) + .writeStream + .format("console") + .outputMode("update") + .start() + + try { + // 2 batches. Records are routed to the same key 0. The first batch sets timeout on + // the key, the second batch with data should clear the timeout. + (1 to 2).foreach {i => + inputData.addData(new Timestamp(i * 1000)) + query.processAllAvailable() + } + } finally { + query.stop() + } + } + } + } + testWithAllStateVersions("mapGroupsWithState - initial state - null key") { val mapGroupsWithStateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index e89197b5ff26c..71e8ae74fe207 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -216,7 +216,7 @@ class StreamSuite extends StreamTest { query.processAllAvailable() // Parquet write page-level CRC checksums will change the file size and // affect the data order when reading these files. Please see PARQUET-1746 for details. 
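For reference, the rewrite of 'a to Symbol("a") just below, like the many similar rewrites in the streaming test suites later in this diff, swaps Scala's symbol-literal syntax (deprecated since Scala 2.13) for an explicit Symbol(...) call; both spellings reach the same Column through the implicits in scope. A minimal sketch of the equivalent column references, assuming a spark-shell session where `spark` and its implicits are available (illustration only, not part of the patch):

    import spark.implicits._

    val df = Seq((1L, "x"), (2L, "y")).toDF("a", "b")

    // All three sorts resolve to the same column; only the first uses the deprecated
    // symbol-literal syntax that the patch is moving away from.
    df.sort('a).show()
    df.sort(Symbol("a")).show()
    df.sort($"a").show()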
- val outputDf = spark.read.parquet(outputDir.getAbsolutePath).sort('a).as[Long] + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).sort(Symbol("a")).as[Long] checkDataset[Long](outputDf, (0L to 10L).toArray: _*) } finally { query.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index ff182b524be70..2bb43ec930760 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -528,8 +528,10 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], "Use either SystemClock or StreamManualClock to start the stream") - if (triggerClock.isInstanceOf[StreamManualClock]) { - manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + triggerClock match { + case clock: StreamManualClock => + manualClockExpectedTime = clock.getTimeMillis() + case _ => } val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala new file mode 100644 index 0000000000000..615434f2edad9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.io.File + +import org.apache.commons.io.FileUtils +import org.scalatest.Assertions + +import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.execution.streaming.{MemoryStream, StateStoreRestoreExec, StateStoreSaveExec} +import org.apache.spark.sql.functions.count +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.OutputMode.Update +import org.apache.spark.sql.streaming.util.StatefulOpClusteredDistributionTestHelper +import org.apache.spark.util.Utils + +class StreamingAggregationDistributionSuite extends StreamTest + with StatefulOpClusteredDistributionTestHelper with Assertions { + + import testImplicits._ + + test("SPARK-38204: streaming aggregation should require StatefulOpClusteredDistribution " + + "from children") { + + val input = MemoryStream[Int] + val df1 = input.toDF().select('value as 'key1, 'value * 2 as 'key2, 'value * 3 as 'value) + val agg = df1.repartition('key1).groupBy('key1, 'key2).agg(count('*)) + + testStream(agg, OutputMode.Update())( + AddData(input, 1, 1, 2, 3, 4), + CheckAnswer((1, 2, 2), (2, 4, 1), (3, 6, 1), (4, 8, 1)), + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + // verify state store restore/save + val stateStoreOps = query.lastExecution.executedPlan.collect { + case s: StateStoreRestoreExec => s + case s: StateStoreSaveExec => s + } + + assert(stateStoreOps.nonEmpty) + stateStoreOps.foreach { stateOp => + assert(requireStatefulOpClusteredDistribution(stateOp, Seq(Seq("key1", "key2")), + numPartitions)) + assert(hasDesiredHashPartitioningInChildren(stateOp, Seq(Seq("key1", "key2")), + numPartitions)) + } + + // verify aggregations in between, except partial aggregation + val allAggregateExecs = query.lastExecution.executedPlan.collect { + case a: BaseAggregateExec => a + } + + val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { + _.requiredChildDistribution.head != UnspecifiedDistribution + } + + // We expect single partial aggregation - remaining agg execs should have child producing + // expected output partitioning. + assert(allAggregateExecs.length - 1 === aggregateExecsWithoutPartialAgg.length) + + // For aggregate execs, we make sure output partitioning of the children is same as + // we expect, HashPartitioning with clustering keys & number of partitions. + aggregateExecsWithoutPartialAgg.foreach { aggr => + assert(hasDesiredHashPartitioningInChildren(aggr, Seq(Seq("key1", "key2")), + numPartitions)) + } + } + ) + } + + test("SPARK-38204: streaming aggregation should require ClusteredDistribution " + + "from children if the query starts from checkpoint in prior to 3.3") { + + val inputData = MemoryStream[Int] + val df1 = inputData.toDF().select('value as 'key1, 'value * 2 as 'key2, 'value * 3 as 'value) + val agg = df1.repartition('key1).groupBy('key1, 'key2).agg(count('*)) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-3.2.0-streaming-aggregate-with-repartition/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. 
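The two comment lines above describe the compatibility-test pattern used by every new suite in this patch: a checkpoint written by an older Spark release is bundled as a test resource, copied into a scratch directory, and the query is restarted on top of the copy (the patch's own FileUtils.copyDirectory call continues right after this note). A rough standalone sketch of the same copy-then-restart idea, assuming a spark-shell session and a rate source standing in for the suite's MemoryStream (hypothetical paths and query, not the patch's code):

    import java.nio.file.Files
    import org.apache.commons.io.FileUtils
    import org.apache.spark.sql.streaming.Trigger

    val originalCkpt = Files.createTempDirectory("ckpt-original").toFile
    val scratchCkpt = Files.createTempDirectory("ckpt-copy").toFile

    // First run: produce a checkpoint (stands in for the bundled version-3.x resource).
    val firstRun = spark.readStream.format("rate").load()
      .writeStream.format("console")
      .option("checkpointLocation", originalCkpt.getAbsolutePath)
      .trigger(Trigger.Once())
      .start()
    firstRun.awaitTermination()

    // Copy before restarting so reruns never mutate the pristine checkpoint.
    FileUtils.copyDirectory(originalCkpt, scratchCkpt)

    // Restart the same query shape against the copy; offsets and state resume from it.
    val secondRun = spark.readStream.format("rate").load()
      .writeStream.format("console")
      .option("checkpointLocation", scratchCkpt.getAbsolutePath)
      .trigger(Trigger.Once())
      .start()
    secondRun.awaitTermination()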
+ FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(3) + inputData.addData(3, 2) + + testStream(agg, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "true")), + + // scalastyle:off line.size.limit + /* + Note: The checkpoint was generated using the following input in Spark version 3.2.0 + AddData(inputData, 3), + CheckLastBatch((3, 6, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 6, 2), (2, 4, 1)) + + Note2: The following plans are the physical plans of the query in older Spark versions + The physical plans around StateStoreRestore and StateStoreSave are quite similar, + especially shuffles being injected are same. That said, verifying with checkpoint being + built with Spark 3.2.0 would verify the following versions as well. + + A. Spark 3.2.0 + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@61a581c0, org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy$$Lambda$1968/1468582588@325b0006 + +- *(4) HashAggregate(keys=[key1#3, key2#4], functions=[count(1)], output=[key1#3, key2#4, count(1)#13L]) + +- StateStoreSave [key1#3, key2#4], state info [ checkpoint = file:/blabla/state, runId = 2bd7d18c-73b2-49a2-b2aa-1835162f9186, opId = 0, ver = 1, numPartitions = 5], Update, 0, 2 + +- *(3) HashAggregate(keys=[key1#3, key2#4], functions=[merge_count(1)], output=[key1#3, key2#4, count#47L]) + +- StateStoreRestore [key1#3, key2#4], state info [ checkpoint = file:/blabla/state, runId = 2bd7d18c-73b2-49a2-b2aa-1835162f9186, opId = 0, ver = 1, numPartitions = 5], 2 + +- *(2) HashAggregate(keys=[key1#3, key2#4], functions=[merge_count(1)], output=[key1#3, key2#4, count#47L]) + +- *(2) HashAggregate(keys=[key1#3, key2#4], functions=[partial_count(1)], output=[key1#3, key2#4, count#47L]) + +- Exchange hashpartitioning(key1#3, 5), REPARTITION_BY_COL, [id=#220] + +- *(1) Project [value#1 AS key1#3, (value#1 * 2) AS key2#4] + +- MicroBatchScan[value#1] MemoryStreamDataSource + + B. Spark 3.1.0 + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@53602363 + +- *(4) HashAggregate(keys=[key1#3, key2#4], functions=[count(1)], output=[key1#3, key2#4, count(1)#13L]) + +- StateStoreSave [key1#3, key2#4], state info [ checkpoint = file:/tmp/spark-178e9eaf-b527-499c-8eb6-c9e734f9fdfc/state, runId = 9c7e8635-41ab-4141-9f46-7ab473c58560, opId = 0, ver = 1, numPartitions = 5], Update, 0, 2 + +- *(3) HashAggregate(keys=[key1#3, key2#4], functions=[merge_count(1)], output=[key1#3, key2#4, count#47L]) + +- StateStoreRestore [key1#3, key2#4], state info [ checkpoint = file:/tmp/spark-178e9eaf-b527-499c-8eb6-c9e734f9fdfc/state, runId = 9c7e8635-41ab-4141-9f46-7ab473c58560, opId = 0, ver = 1, numPartitions = 5], 2 + +- *(2) HashAggregate(keys=[key1#3, key2#4], functions=[merge_count(1)], output=[key1#3, key2#4, count#47L]) + +- *(2) HashAggregate(keys=[key1#3, key2#4], functions=[partial_count(1)], output=[key1#3, key2#4, count#47L]) + +- Exchange hashpartitioning(key1#3, 5), REPARTITION, [id=#222] + +- *(1) Project [value#1 AS key1#3, (value#1 * 2) AS key2#4] + +- *(1) Project [value#1] + +- MicroBatchScan[value#1] MemoryStreamDataSource + + C. 
Spark 3.0.0 + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@33379044 + +- *(4) HashAggregate(keys=[key1#3, key2#4], functions=[count(1)], output=[key1#3, key2#4, count(1)#13L]) + +- StateStoreSave [key1#3, key2#4], state info [ checkpoint = file:/tmp/spark-83497e04-657c-4cad-b532-f433b1532302/state, runId = 1a650994-486f-4f32-92d9-f7c05d49d0a0, opId = 0, ver = 1, numPartitions = 5], Update, 0, 2 + +- *(3) HashAggregate(keys=[key1#3, key2#4], functions=[merge_count(1)], output=[key1#3, key2#4, count#47L]) + +- StateStoreRestore [key1#3, key2#4], state info [ checkpoint = file:/tmp/spark-83497e04-657c-4cad-b532-f433b1532302/state, runId = 1a650994-486f-4f32-92d9-f7c05d49d0a0, opId = 0, ver = 1, numPartitions = 5], 2 + +- *(2) HashAggregate(keys=[key1#3, key2#4], functions=[merge_count(1)], output=[key1#3, key2#4, count#47L]) + +- *(2) HashAggregate(keys=[key1#3, key2#4], functions=[partial_count(1)], output=[key1#3, key2#4, count#47L]) + +- Exchange hashpartitioning(key1#3, 5), false, [id=#104] + +- *(1) Project [value#1 AS key1#3, (value#1 * 2) AS key2#4] + +- *(1) Project [value#1] + +- MicroBatchScan[value#1] MemoryStreamDataSource + + D. Spark 2.4.0 + *(4) HashAggregate(keys=[key1#3, key2#4], functions=[count(1)], output=[key1#3, key2#4, count(1)#13L]) + +- StateStoreSave [key1#3, key2#4], state info [ checkpoint = file:/tmp/spark-c4fd5b1f-18e0-4433-ac7a-00df93464b49/state, runId = 89bfe27b-da33-4a75-9f36-97717c137b2a, opId = 0, ver = 1, numPartitions = 5], Update, 0, 2 + +- *(3) HashAggregate(keys=[key1#3, key2#4], functions=[merge_count(1)], output=[key1#3, key2#4, count#42L]) + +- StateStoreRestore [key1#3, key2#4], state info [ checkpoint = file:/tmp/spark-c4fd5b1f-18e0-4433-ac7a-00df93464b49/state, runId = 89bfe27b-da33-4a75-9f36-97717c137b2a, opId = 0, ver = 1, numPartitions = 5], 2 + +- *(2) HashAggregate(keys=[key1#3, key2#4], functions=[merge_count(1)], output=[key1#3, key2#4, count#42L]) + +- *(2) HashAggregate(keys=[key1#3, key2#4], functions=[partial_count(1)], output=[key1#3, key2#4, count#42L]) + +- Exchange hashpartitioning(key1#3, 5) + +- *(1) Project [value#47 AS key1#3, (value#47 * 2) AS key2#4] + +- *(1) Project [value#47] + +- *(1) ScanV2 MemoryStreamDataSource$[value#47] + */ + // scalastyle:on line.size.limit + + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 6, 3), (2, 4, 2), (1, 2, 1)), + + Execute { query => + val executedPlan = query.lastExecution.executedPlan + assert(!executedPlan.conf.getConf(SQLConf.STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION)) + + val numPartitions = query.lastExecution.numStateStores + + // verify state store restore/save + val stateStoreOps = executedPlan.collect { + case s: StateStoreRestoreExec => s + case s: StateStoreSaveExec => s + } + + assert(stateStoreOps.nonEmpty) + stateStoreOps.foreach { stateOp => + assert(requireClusteredDistribution(stateOp, Seq(Seq("key1", "key2")), + Some(numPartitions))) + assert(hasDesiredHashPartitioningInChildren(stateOp, Seq(Seq("key1")), + numPartitions)) + } + + // verify aggregations in between, except partial aggregation + val allAggregateExecs = executedPlan.collect { + case a: BaseAggregateExec => a + } + + val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { + _.requiredChildDistribution.head != UnspecifiedDistribution + } + + // We expect single partial aggregation - remaining agg execs should have child producing + // expected output partitioning. 
+ assert(allAggregateExecs.length - 1 === aggregateExecsWithoutPartialAgg.length) + + // For aggregate execs, we make sure output partitioning of the children is same as + // we expect, HashPartitioning with sub-clustering keys & number of partitions. + aggregateExecsWithoutPartialAgg.foreach { aggr => + assert(requireClusteredDistribution(aggr, Seq(Seq("key1", "key2")), + Some(numPartitions))) + assert(hasDesiredHashPartitioningInChildren(aggr, Seq(Seq("key1")), + numPartitions)) + } + } + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 77334ad64c3ce..64dffe7f571ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.BlockRDD import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange @@ -109,7 +110,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { val aggregated = inputData.toDF() - .select($"*", explode($"_2") as 'value) + .select($"*", explode($"_2") as Symbol("value")) .groupBy($"_1") .agg(size(collect_set($"value"))) .as[(Int, Int)] @@ -190,8 +191,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { val aggWithWatermark = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) implicit class RichStreamExecution(query: StreamExecution) { @@ -413,13 +414,13 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { inputDataOne.toDF() .groupBy($"value") .agg(count("*")) - .where('value >= current_timestamp().cast("long") - 10L) + .where(Symbol("value") >= current_timestamp().cast("long") - 10L) val inputDataTwo = MemoryStream[Long] val aggregatedTwo = inputDataTwo.toDF() .groupBy($"value") .agg(count("*")) - .where('value >= localtimestamp().cast(TimestampType).cast("long") - 10L) + .where(Symbol("value") >= localtimestamp().cast(TimestampType).cast("long") - 10L) Seq((inputDataOne, aggregatedOne), (inputDataTwo, aggregatedTwo)).foreach { x => val inputData = x._1 @@ -475,7 +476,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { val inputData = MemoryStream[Long] val aggregated = inputData.toDF() - .select(to_utc_timestamp(from_unixtime('value * SECONDS_PER_DAY), tz)) + .select(to_utc_timestamp(from_unixtime(Symbol("value") * SECONDS_PER_DAY), tz)) .toDF("value") .groupBy($"value") .agg(count("*")) @@ -522,12 +523,12 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { val streamInput = MemoryStream[Int] val batchDF = Seq(1, 2, 3, 4, 5) .toDF("value") - .withColumn("parity", 
'value % 2) - .groupBy('parity) - .agg(count("*") as 'joinValue) + .withColumn("parity", Symbol("value") % 2) + .groupBy(Symbol("parity")) + .agg(count("*") as Symbol("joinValue")) val joinDF = streamInput .toDF() - .join(batchDF, 'value === 'parity) + .join(batchDF, Symbol("value") === Symbol("parity")) // make sure we're planning an aggregate in the first place assert(batchDF.queryExecution.optimizedPlan match { case _: Aggregate => true }) @@ -542,8 +543,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { /** * This method verifies certain properties in the SparkPlan of a streaming aggregation. * First of all, it checks that the child of a `StateStoreRestoreExec` creates the desired - * data distribution, where the child could be an Exchange, or a `HashAggregateExec` which already - * provides the expected data distribution. + * data distribution, where the child is a `HashAggregateExec` which already provides + * the expected data distribution. * * The second thing it checks that the child provides the expected number of partitions. * @@ -552,7 +553,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { */ private def checkAggregationChain( se: StreamExecution, - expectShuffling: Boolean, expectedPartition: Int): Boolean = { val executedPlan = se.lastExecution.executedPlan val restore = executedPlan @@ -560,12 +560,17 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { .head restore.child match { case node: UnaryExecNode => - assert(node.outputPartitioning.numPartitions === expectedPartition, - "Didn't get the expected number of partitions.") - if (expectShuffling) { - assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}") - } else { - assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle") + node.outputPartitioning match { + case HashPartitioning(_, numPartitions) => + assert(numPartitions === expectedPartition, + "Didn't get the expected number of partitions.") + + // below case should only applied to no grouping key which leads to AllTuples + case SinglePartition if expectedPartition == 1 => // OK + + case p => + fail("Expected a hash partitioning for child output partitioning, but has " + + s"$p instead.") } case _ => fail("Expected no shuffling") @@ -605,12 +610,12 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { AddBlockData(inputSource, Seq(1)), CheckLastBatch(1), AssertOnQuery("Verify no shuffling") { se => - checkAggregationChain(se, expectShuffling = false, 1) + checkAggregationChain(se, 1) }, AddBlockData(inputSource), // create an empty trigger CheckLastBatch(1), AssertOnQuery("Verify that no exchange is required") { se => - checkAggregationChain(se, expectShuffling = false, 1) + checkAggregationChain(se, 1) }, AddBlockData(inputSource, Seq(2, 3)), CheckLastBatch(3), @@ -639,7 +644,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { def createDf(partitions: Int): Dataset[(Long, Long)] = { spark.readStream .format((new MockSourceProvider).getClass.getCanonicalName) - .load().coalesce(partitions).groupBy('a % 1).count().as[(Long, Long)] + .load().coalesce(partitions).groupBy(Symbol("a") % 1).count().as[(Long, Long)] } testStream(createDf(1), Complete())( @@ -647,10 +652,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { AddBlockData(inputSource, Seq(1)), CheckLastBatch((0L, 1L)), AssertOnQuery("Verify addition of exchange operator") { se => - 
checkAggregationChain( - se, - expectShuffling = true, - spark.sessionState.conf.numShufflePartitions) + checkAggregationChain(se, spark.sessionState.conf.numShufflePartitions) }, StopStream ) @@ -661,10 +663,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { AddBlockData(inputSource, Seq(2), Seq(3), Seq(4)), CheckLastBatch((0L, 4L)), AssertOnQuery("Verify no exchange added") { se => - checkAggregationChain( - se, - expectShuffling = false, - spark.sessionState.conf.numShufflePartitions) + checkAggregationChain(se, spark.sessionState.conf.numShufflePartitions) }, AddBlockData(inputSource), CheckLastBatch((0L, 4L)), @@ -677,7 +676,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { testWithAllStateVersions("SPARK-22230: last should change with new batches") { val input = MemoryStream[Int] - val aggregated = input.toDF().agg(last('value)) + val aggregated = input.toDF().agg(last(Symbol("value"))) testStream(aggregated, OutputMode.Complete())( AddData(input, 1, 2, 3), CheckLastBatch(3), @@ -766,7 +765,11 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { } testQuietlyWithAllStateVersions("changing schema of state when restarting query", - (SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key, "false")) { + (SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key, "false"), + // Since we only do the check in partition 0 and other partitions still may fail with + // different errors, we change the number of shuffle partitions to 1 to make the test + // result to be deterministic. + (SQLConf.SHUFFLE_PARTITIONS.key, "1")) { withTempDir { tempDir => val (inputData, aggregated) = prepareTestForChangingSchemaOfState(tempDir) @@ -790,7 +793,11 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { testQuietlyWithAllStateVersions("changing schema of state when restarting query -" + " schema check off", (SQLConf.STATE_SCHEMA_CHECK_ENABLED.key, "false"), - (SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key, "false")) { + (SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key, "false"), + // Since we only do the check in partition 0 and other partitions still may fail with + // different errors, we change the number of shuffle partitions to 1 to make the test + // result to be deterministic. 
+ (SQLConf.SHUFFLE_PARTITIONS.key, "1")) { withTempDir { tempDir => val (inputData, aggregated) = prepareTestForChangingSchemaOfState(tempDir) @@ -845,8 +852,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { val aggWithWatermark = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) inputData.reset() // reset the input to clear any data from prev test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationDistributionSuite.scala new file mode 100644 index 0000000000000..8dbdb3620688e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationDistributionSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.io.File + +import org.apache.commons.io.FileUtils + +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StatefulOpClusteredDistributionTestHelper +import org.apache.spark.util.Utils + +class StreamingDeduplicationDistributionSuite extends StreamTest + with StatefulOpClusteredDistributionTestHelper { + + import testImplicits._ + + test("SPARK-38204: streaming deduplication should require StatefulOpClusteredDistribution " + + "from children") { + + val input = MemoryStream[Int] + val df1 = input.toDF().select('value as 'key1, 'value * 2 as 'key2, 'value * 3 as 'value) + val dedup = df1.repartition('key1).dropDuplicates("key1", "key2") + + testStream(dedup, OutputMode.Update())( + AddData(input, 1, 1, 2, 3, 4), + CheckAnswer((1, 2, 3), (2, 4, 6), (3, 6, 9), (4, 8, 12)), + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val dedupExecs = query.lastExecution.executedPlan.collect { + case d: StreamingDeduplicateExec => d + } + + assert(dedupExecs.length === 1) + assert(requireStatefulOpClusteredDistribution( + dedupExecs.head, Seq(Seq("key1", "key2")), numPartitions)) + assert(hasDesiredHashPartitioningInChildren( + dedupExecs.head, Seq(Seq("key1", "key2")), numPartitions)) + } + ) + } + + test("SPARK-38204: streaming deduplication should require ClusteredDistribution " + + "from children if the query starts from checkpoint in prior to 3.3") { + + val inputData = MemoryStream[Int] + val df1 = inputData.toDF().select('value as 'key1, 'value * 2 as 'key2, 'value * 3 as 'value) + val dedup = df1.repartition('key1).dropDuplicates("key1", "key2") + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-3.2.0-deduplication-with-repartition/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(1, 1, 2) + inputData.addData(3, 4) + + testStream(dedup, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "true")), + + // scalastyle:off line.size.limit + /* + Note: The checkpoint was generated using the following input in Spark version 3.2.0 + AddData(inputData, 1, 1, 2), + CheckLastBatch((1, 2, 3), (2, 4, 6)), + AddData(inputData, 3, 4), + CheckLastBatch((3, 6, 9), (4, 8, 12)) + + Note2: The following plans are the physical plans of the query in older Spark versions + The physical plans around StreamingDeduplicate are quite similar, especially shuffles + being injected are same. That said, verifying with checkpoint being built with + Spark 3.2.0 would verify the following versions as well. + + A. 
Spark 3.2.0 + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@76467fb2, org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy$$Lambda$1900/1334867523@32b72162 + +- StreamingDeduplicate [key1#3, key2#4], state info [ checkpoint = file:/blabla/state, runId = bf82c05e-4031-4421-89e0-28fd9127eb5b, opId = 0, ver = 1, numPartitions = 5], 0 + +- Exchange hashpartitioning(key1#3, 5), REPARTITION_BY_COL, [id=#115] + +- *(1) Project [value#1 AS key1#3, (value#1 * 2) AS key2#4, (value#1 * 3) AS value#5] + +- MicroBatchScan[value#1] MemoryStreamDataSource + + B. Spark 3.1.0 + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@133d8337 + +- StreamingDeduplicate [key1#3, key2#4], state info [ checkpoint = file:/tmp/spark-c0b73191-75ec-4a54-89b7-368fbbc4b2a8/state, runId = 9b2baaee-1147-4faf-98b4-3c3d8ee34966, opId = 0, ver = 1, numPartitions = 5], 0 + +- Exchange hashpartitioning(key1#3, 5), REPARTITION, [id=#117] + +- *(1) Project [value#1 AS key1#3, (value#1 * 2) AS key2#4, (value#1 * 3) AS value#5] + +- *(1) Project [value#1] + +- MicroBatchScan[value#1] MemoryStreamDataSource + + C. Spark 3.0.0 + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@bb06c00 + +- StreamingDeduplicate [key1#3, key2#4], state info [ checkpoint = file:/tmp/spark-6f8a96c7-2af5-4952-a1b4-c779766334ef/state, runId = 9a208eb0-d915-46dd-a0fd-23b1df82b951, opId = 0, ver = 1, numPartitions = 5], 0 + +- Exchange hashpartitioning(key1#3, 5), false, [id=#57] + +- *(1) Project [value#1 AS key1#3, (value#1 * 2) AS key2#4, (value#1 * 3) AS value#5] + +- *(1) Project [value#1] + +- MicroBatchScan[value#1] MemoryStreamDataSource + + D. Spark 2.4.0 + StreamingDeduplicate [key1#3, key2#4], state info [ checkpoint = file:/tmp/spark-d8a684a0-5623-4739-85e8-e45b99768aa7/state, runId = 85bd75bd-3d45-4d42-aeac-9e45fc559ee9, opId = 0, ver = 1, numPartitions = 5], 0 + +- Exchange hashpartitioning(key1#3, 5) + +- *(1) Project [value#37 AS key1#3, (value#37 * 2) AS key2#4, (value#37 * 3) AS value#5] + +- *(1) Project [value#37] + +- *(1) ScanV2 MemoryStreamDataSource$[value#37] + */ + // scalastyle:on line.size.limit + + AddData(inputData, 2, 3, 4, 5), + CheckLastBatch((5, 10, 15)), + Execute { query => + val executedPlan = query.lastExecution.executedPlan + assert(!executedPlan.conf.getConf(SQLConf.STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION)) + + val numPartitions = query.lastExecution.numStateStores + + val dedupExecs = executedPlan.collect { + case d: StreamingDeduplicateExec => d + } + + assert(dedupExecs.length === 1) + assert(requireClusteredDistribution( + dedupExecs.head, Seq(Seq("key1", "key2")), Some(numPartitions))) + assert(hasDesiredHashPartitioningInChildren( + dedupExecs.head, Seq(Seq("key1")), numPartitions)) + } + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index aa03da6c5843f..c1908d95f39e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -146,8 +146,8 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest { .withWatermark("eventTime", "10 seconds") .dropDuplicates() .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + 
.groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(windowedaggregate)( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index a24e76f81b4aa..29caaf7289d6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.File +import java.lang.{Integer => JInteger} import java.sql.Timestamp import java.util.{Locale, UUID} @@ -28,6 +29,8 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} @@ -53,9 +56,9 @@ abstract class StreamingJoinSuite val input = MemoryStream[Int] val df = input.toDF .select( - 'value as "key", + Symbol("value") as "key", timestamp_seconds($"value") as s"${prefix}Time", - ('value * multiplier) as s"${prefix}Value") + (Symbol("value") * multiplier) as s"${prefix}Value") .withWatermark(s"${prefix}Time", "10 seconds") (input, df) @@ -66,13 +69,16 @@ abstract class StreamingJoinSuite val (input1, df1) = setupStream("left", 2) val (input2, df2) = setupStream("right", 3) - val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue) - val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue) + val windowed1 = df1 + .select(Symbol("key"), window(Symbol("leftTime"), "10 second"), Symbol("leftValue")) + val windowed2 = df2 + .select(Symbol("key"), window(Symbol("rightTime"), "10 second"), Symbol("rightValue")) val joined = windowed1.join(windowed2, Seq("key", "window"), joinType) val select = if (joinType == "left_semi") { - joined.select('key, $"window.end".cast("long"), 'leftValue) + joined.select(Symbol("key"), $"window.end".cast("long"), Symbol("leftValue")) } else { - joined.select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + joined.select(Symbol("key"), $"window.end".cast("long"), Symbol("leftValue"), + Symbol("rightValue")) } (input1, input2, select) @@ -84,25 +90,29 @@ abstract class StreamingJoinSuite val (leftInput, df1) = setupStream("left", 2) val (rightInput, df2) = setupStream("right", 3) // Use different schemas to ensure the null row is being generated from the correct side. 
- val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) - val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + val left = df1.select(Symbol("key"), window(Symbol("leftTime"), "10 second"), + Symbol("leftValue")) + val right = df2.select(Symbol("key"), window(Symbol("rightTime"), "10 second"), + Symbol("rightValue").cast("string")) val joined = left.join( right, left("key") === right("key") && left("window") === right("window") - && 'leftValue > 4, + && Symbol("leftValue") > 4, joinType) val select = if (joinType == "left_semi") { - joined.select(left("key"), left("window.end").cast("long"), 'leftValue) + joined.select(left("key"), left("window.end").cast("long"), Symbol("leftValue")) } else if (joinType == "left_outer") { - joined.select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + joined.select(left("key"), left("window.end").cast("long"), Symbol("leftValue"), + Symbol("rightValue")) } else if (joinType == "right_outer") { - joined.select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + joined.select(right("key"), right("window.end").cast("long"), Symbol("leftValue"), + Symbol("rightValue")) } else { - joined.select(left("key"), left("window.end").cast("long"), 'leftValue, - right("key"), right("window.end").cast("long"), 'rightValue) + joined.select(left("key"), left("window.end").cast("long"), Symbol("leftValue"), + right("key"), right("window.end").cast("long"), Symbol("rightValue")) } (leftInput, rightInput, select) @@ -114,25 +124,29 @@ abstract class StreamingJoinSuite val (leftInput, df1) = setupStream("left", 2) val (rightInput, df2) = setupStream("right", 3) // Use different schemas to ensure the null row is being generated from the correct side. 
- val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) - val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + val left = df1.select(Symbol("key"), window(Symbol("leftTime"), "10 second"), + Symbol("leftValue")) + val right = df2.select(Symbol("key"), window(Symbol("rightTime"), "10 second"), + Symbol("rightValue").cast("string")) val joined = left.join( right, left("key") === right("key") && left("window") === right("window") - && 'rightValue.cast("int") > 7, + && Symbol("rightValue").cast("int") > 7, joinType) val select = if (joinType == "left_semi") { - joined.select(left("key"), left("window.end").cast("long"), 'leftValue) + joined.select(left("key"), left("window.end").cast("long"), Symbol("leftValue")) } else if (joinType == "left_outer") { - joined.select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + joined.select(left("key"), left("window.end").cast("long"), Symbol("leftValue"), + Symbol("rightValue")) } else if (joinType == "right_outer") { - joined.select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + joined.select(right("key"), right("window.end").cast("long"), Symbol("leftValue"), + Symbol("rightValue")) } else { - joined.select(left("key"), left("window.end").cast("long"), 'leftValue, - right("key"), right("window.end").cast("long"), 'rightValue) + joined.select(left("key"), left("window.end").cast("long"), Symbol("leftValue"), + right("key"), right("window.end").cast("long"), Symbol("rightValue")) } (leftInput, rightInput, select) @@ -149,12 +163,13 @@ abstract class StreamingJoinSuite val rightInput = MemoryStream[(Int, Int)] val df1 = leftInput.toDF.toDF("leftKey", "time") - .select('leftKey, timestamp_seconds($"time") as "leftTime", ('leftKey * 2) as "leftValue") + .select(Symbol("leftKey"), timestamp_seconds($"time") as "leftTime", + (Symbol("leftKey") * 2) as "leftValue") .withWatermark("leftTime", watermark) val df2 = rightInput.toDF.toDF("rightKey", "time") - .select('rightKey, timestamp_seconds($"time") as "rightTime", - ('rightKey * 3) as "rightValue") + .select(Symbol("rightKey"), timestamp_seconds($"time") as "rightTime", + (Symbol("rightKey") * 3) as "rightValue") .withWatermark("rightTime", watermark) val joined = @@ -165,9 +180,10 @@ abstract class StreamingJoinSuite joinType) val select = if (joinType == "left_semi") { - joined.select('leftKey, 'leftTime.cast("int")) + joined.select(Symbol("leftKey"), Symbol("leftTime").cast("int")) } else { - joined.select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + joined.select(Symbol("leftKey"), Symbol("rightKey"), Symbol("leftTime").cast("int"), + Symbol("rightTime").cast("int")) } (leftInput, rightInput, select) @@ -214,8 +230,8 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] - val df1 = input1.toDF.select('value as "key", ('value * 2) as "leftValue") - val df2 = input2.toDF.select('value as "key", ('value * 3) as "rightValue") + val df1 = input1.toDF.select(Symbol("value") as "key", (Symbol("value") * 2) as "leftValue") + val df2 = input2.toDF.select(Symbol("value") as "key", (Symbol("value") * 3) as "rightValue") val joined = df1.join(df2, "key") testStream(joined)( @@ -244,17 +260,17 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input2 = MemoryStream[Int] val df1 = input1.toDF - .select('value as "key", timestamp_seconds($"value") as "timestamp", - ('value * 2) as "leftValue") - 
.select('key, window('timestamp, "10 second"), 'leftValue) + .select(Symbol("value") as "key", timestamp_seconds($"value") as "timestamp", + (Symbol("value") * 2) as "leftValue") + .select(Symbol("key"), window(Symbol("timestamp"), "10 second"), Symbol("leftValue")) val df2 = input2.toDF - .select('value as "key", timestamp_seconds($"value") as "timestamp", - ('value * 3) as "rightValue") - .select('key, window('timestamp, "10 second"), 'rightValue) + .select(Symbol("value") as "key", timestamp_seconds($"value") as "timestamp", + (Symbol("value") * 3) as "rightValue") + .select(Symbol("key"), window(Symbol("timestamp"), "10 second"), Symbol("rightValue")) val joined = df1.join(df2, Seq("key", "window")) - .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + .select(Symbol("key"), $"window.end".cast("long"), Symbol("leftValue"), Symbol("rightValue")) testStream(joined)( AddData(input1, 1), @@ -285,18 +301,18 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input2 = MemoryStream[Int] val df1 = input1.toDF - .select('value as "key", timestamp_seconds($"value") as "timestamp", - ('value * 2) as "leftValue") + .select(Symbol("value") as "key", timestamp_seconds($"value") as "timestamp", + (Symbol("value") * 2) as "leftValue") .withWatermark("timestamp", "10 seconds") - .select('key, window('timestamp, "10 second"), 'leftValue) + .select(Symbol("key"), window(Symbol("timestamp"), "10 second"), Symbol("leftValue")) val df2 = input2.toDF - .select('value as "key", timestamp_seconds($"value") as "timestamp", - ('value * 3) as "rightValue") - .select('key, window('timestamp, "10 second"), 'rightValue) + .select(Symbol("value") as "key", timestamp_seconds($"value") as "timestamp", + (Symbol("value") * 3) as "rightValue") + .select(Symbol("key"), window(Symbol("timestamp"), "10 second"), Symbol("rightValue")) val joined = df1.join(df2, Seq("key", "window")) - .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + .select(Symbol("key"), $"window.end".cast("long"), Symbol("leftValue"), Symbol("rightValue")) testStream(joined)( AddData(input1, 1), @@ -336,17 +352,18 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val rightInput = MemoryStream[(Int, Int)] val df1 = leftInput.toDF.toDF("leftKey", "time") - .select('leftKey, timestamp_seconds($"time") as "leftTime", ('leftKey * 2) as "leftValue") + .select(Symbol("leftKey"), timestamp_seconds($"time") as "leftTime", + (Symbol("leftKey") * 2) as "leftValue") .withWatermark("leftTime", "10 seconds") val df2 = rightInput.toDF.toDF("rightKey", "time") - .select('rightKey, timestamp_seconds($"time") as "rightTime", - ('rightKey * 3) as "rightValue") + .select(Symbol("rightKey"), timestamp_seconds($"time") as "rightTime", + (Symbol("rightKey") * 3) as "rightValue") .withWatermark("rightTime", "10 seconds") val joined = df1.join(df2, expr("leftKey = rightKey AND leftTime < rightTime - interval 5 seconds")) - .select('leftKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + .select(Symbol("leftKey"), Symbol("leftTime").cast("int"), Symbol("rightTime").cast("int")) testStream(joined)( AddData(leftInput, (1, 5)), @@ -395,12 +412,13 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val rightInput = MemoryStream[(Int, Int)] val df1 = leftInput.toDF.toDF("leftKey", "time") - .select('leftKey, timestamp_seconds($"time") as "leftTime", ('leftKey * 2) as "leftValue") + .select(Symbol("leftKey"), timestamp_seconds($"time") as "leftTime", + (Symbol("leftKey") * 2) as "leftValue") 
.withWatermark("leftTime", "20 seconds") val df2 = rightInput.toDF.toDF("rightKey", "time") - .select('rightKey, timestamp_seconds($"time") as "rightTime", - ('rightKey * 3) as "rightValue") + .select(Symbol("rightKey"), timestamp_seconds($"time") as "rightTime", + (Symbol("rightKey") * 3) as "rightValue") .withWatermark("rightTime", "30 seconds") val condition = expr( @@ -429,7 +447,8 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { // drop state where rightTime < eventTime - 5 val joined = - df1.join(df2, condition).select('leftKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + df1.join(df2, condition).select(Symbol("leftKey"), Symbol("leftTime").cast("int"), + Symbol("rightTime").cast("int")) testStream(joined)( // If leftTime = 20, then it match only with rightTime = [15, 30] @@ -476,8 +495,10 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] - val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue") - val df2 = input2.toDF.select('value as "rightKey", ('value * 3) as "rightValue") + val df1 = input1.toDF + .select(Symbol("value") as "leftKey", (Symbol("value") * 2) as "leftValue") + val df2 = input2.toDF + .select(Symbol("value") as "rightKey", (Symbol("value") * 3) as "rightValue") val joined = df1.join(df2, expr("leftKey < rightKey")) val e = intercept[Exception] { val q = joined.writeStream.format("memory").queryName("test").start() @@ -491,8 +512,8 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input = MemoryStream[Int] val df = input.toDF val join = - df.select('value % 5 as "key", 'value).join( - df.select('value % 5 as "key", 'value), "key") + df.select(Symbol("value") % 5 as "key", Symbol("value")).join( + df.select(Symbol("value") % 5 as "key", Symbol("value")), "key") testStream(join)( AddData(input, 1, 2), @@ -556,9 +577,11 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input2 = MemoryStream[Int] val input3 = MemoryStream[Int] - val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue") - val df2 = input2.toDF.select('value as "middleKey", ('value * 3) as "middleValue") - val df3 = input3.toDF.select('value as "rightKey", ('value * 5) as "rightValue") + val df1 = input1.toDF.select(Symbol("value") as "leftKey", (Symbol("value") * 2) as "leftValue") + val df2 = input2.toDF + .select(Symbol("value") as "middleKey", (Symbol("value") * 3) as "middleValue") + val df3 = input3.toDF + .select(Symbol("value") as "rightKey", (Symbol("value") * 5) as "rightValue") val joined = df1.join(df2, expr("leftKey = middleKey")).join(df3, expr("rightKey = middleKey")) @@ -569,13 +592,16 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { CheckNewAnswer((5, 10, 5, 15, 5, 25))) } - test("streaming join should require HashClusteredDistribution from children") { + test("streaming join should require StatefulOpClusteredDistribution from children") { val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] - val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b) - val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b) - val joined = df1.join(df2, Seq("a", "b")).select('a) + val df1 = input1.toDF + .select(Symbol("value") as Symbol("a"), Symbol("value") * 2 as Symbol("b")) + val df2 = input2.toDF + .select(Symbol("value") as Symbol("a"), Symbol("value") * 2 as Symbol("b")) + .repartition(Symbol("b")) + val joined = df1.join(df2, Seq("a", "b")).select(Symbol("a")) testStream(joined)( 
AddData(input1, 1.to(1000): _*), @@ -583,9 +609,21 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { CheckAnswer(1.to(1000): _*), Execute { query => // Verify the query plan + def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = { + expressions.flatMap { + case ref: AttributeReference => Some(ref.name) + } + } + + val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS) + assert(query.lastExecution.executedPlan.collect { case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, - _: ShuffleExchangeExec, _: ShuffleExchangeExec) => j + ShuffleExchangeExec(opA: HashPartitioning, _, _), + ShuffleExchangeExec(opB: HashPartitioning, _, _)) + if partitionExpressionsColumns(opA.expressions) === Seq("a", "b") + && partitionExpressionsColumns(opB.expressions) === Seq("a", "b") + && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j }.size == 1) }) } @@ -652,18 +690,18 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input2 = MemoryStream[Int] val df1 = input1.toDF - .select('value as "key", timestamp_seconds($"value") as "timestamp", - ('value * 2) as "leftValue") + .select(Symbol("value") as "key", timestamp_seconds($"value") as "timestamp", + (Symbol("value") * 2) as "leftValue") .withWatermark("timestamp", "10 seconds") - .select('key, window('timestamp, "10 second"), 'leftValue) + .select(Symbol("key"), window(Symbol("timestamp"), "10 second"), Symbol("leftValue")) val df2 = input2.toDF - .select('value as "key", timestamp_seconds($"value") as "timestamp", - ('value * 3) as "rightValue") - .select('key, window('timestamp, "10 second"), 'rightValue) + .select(Symbol("value") as "key", timestamp_seconds($"value") as "timestamp", + (Symbol("value") * 3) as "rightValue") + .select(Symbol("key"), window(Symbol("timestamp"), "10 second"), Symbol("rightValue")) val joined = df1.join(df2, Seq("key", "window")) - .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + .select(Symbol("key"), $"window.end".cast("long"), Symbol("leftValue"), Symbol("rightValue")) testStream(joined)( StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "3")), @@ -688,6 +726,53 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { total = Seq(2), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(0))) ) } + + test("joining non-nullable left join key with nullable right join key") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[JInteger] + + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined)( + AddData(input1, 1, 5), + AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) + ) + } + + test("joining nullable left join key with non-nullable right join key") { + val input1 = MemoryStream[JInteger] + val input2 = MemoryStream[Int] + + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined)( + AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + AddData(input2, 1, 5), + CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) + ) + } + + test("joining nullable left join key with nullable right join key") { + val input1 = MemoryStream[JInteger] + val input2 = MemoryStream[JInteger] + + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined)( + AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + 
AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), null), + CheckNewAnswer( + Row(JInteger.valueOf(1), JInteger.valueOf(1), JInteger.valueOf(2), JInteger.valueOf(3)), + Row(JInteger.valueOf(5), JInteger.valueOf(5), JInteger.valueOf(10), JInteger.valueOf(15)), + Row(null, null, null, null)) + ) + } + + private def testForJoinKeyNullability(left: DataFrame, right: DataFrame): DataFrame = { + val df1 = left.selectExpr("value as leftKey", "value * 2 as leftValue") + val df2 = right.selectExpr("value as rightKey", "value * 3 as rightValue") + + df1.join(df2, expr("leftKey <=> rightKey")) + .select("leftKey", "rightKey", "leftValue", "rightValue") + } } @@ -862,15 +947,19 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { val (leftInput, simpleLeftDf) = setupStream("left", 2) val (rightInput, simpleRightDf) = setupStream("right", 3) - val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 'leftValue) - val right = simpleRightDf.select('key, window('rightTime, "10 second"), 'rightValue) + val left = simpleLeftDf + .select(Symbol("key"), window(Symbol("leftTime"), "10 second"), Symbol("leftValue")) + val right = simpleRightDf + .select(Symbol("key"), window(Symbol("rightTime"), "10 second"), Symbol("rightValue")) val joined = left.join( right, left("key") === right("key") && left("window") === right("window") && - 'leftValue > 10 && ('rightValue < 300 || 'rightValue > 1000), + Symbol("leftValue") > 10 && + (Symbol("rightValue") < 300 || Symbol("rightValue") > 1000), "left_outer") - .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + .select(left("key"), left("window.end").cast("long"), Symbol("leftValue"), + Symbol("rightValue")) testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys @@ -1061,9 +1150,9 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { val input1 = MemoryStream[Int](desiredPartitionsForInput1) val df1 = input1.toDF .select( - 'value as "key", - 'value as "leftValue", - 'value as "rightValue") + Symbol("value") as "key", + Symbol("value") as "leftValue", + Symbol("value") as "rightValue") val (input2, df2) = setupStream("left", 2) val (input3, df3) = setupStream("right", 3) @@ -1071,7 +1160,7 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { .join(df3, df2("key") === df3("key") && df2("leftTime") === df3("rightTime"), "inner") - .select(df2("key"), 'leftValue, 'rightValue) + .select(df2("key"), Symbol("leftValue"), Symbol("rightValue")) (input1, input2, input3, df1.union(joined)) } @@ -1154,6 +1243,116 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { CheckNewAnswer(expectedOutput.head, expectedOutput.tail: _*) ) } + + test("left-outer: joining non-nullable left join key with nullable right join key") { + val input1 = MemoryStream[(Int, Int)] + val input2 = MemoryStream[(JInteger, Int)] + + val joined = testForLeftOuterJoinKeyNullability(input1.toDF(), input2.toDF()) + + testStream(joined)( + AddData(input1, (1, 1), (1, 2), (1, 3), (1, 4), (1, 5)), + AddData(input2, + (JInteger.valueOf(1), 3), + (JInteger.valueOf(1), 4), + (JInteger.valueOf(1), 5), + (JInteger.valueOf(1), 6) + ), + CheckNewAnswer( + Row(1, 1, 3, 3, 10, 6, 9), + Row(1, 1, 4, 4, 10, 8, 12), + Row(1, 1, 5, 5, 10, 10, 15)), + AddData(input1, (1, 21)), + // right-null join + AddData(input2, (JInteger.valueOf(1), 22)), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer( + Row(1, null, 1, null, 10, 2, null), + Row(1, null, 2, null, 10, 4, null) + ) + ) + } + + 
test("left-outer: joining nullable left join key with non-nullable right join key") { + val input1 = MemoryStream[(JInteger, Int)] + val input2 = MemoryStream[(Int, Int)] + + val joined = testForLeftOuterJoinKeyNullability(input1.toDF(), input2.toDF()) + + testStream(joined)( + AddData(input1, + (JInteger.valueOf(1), 1), + (null, 2), + (JInteger.valueOf(1), 3), + (JInteger.valueOf(1), 4), + (JInteger.valueOf(1), 5)), + AddData(input2, (1, 3), (1, 4), (1, 5), (1, 6)), + CheckNewAnswer( + Row(1, 1, 3, 3, 10, 6, 9), + Row(1, 1, 4, 4, 10, 8, 12), + Row(1, 1, 5, 5, 10, 10, 15)), + // right-null join + AddData(input1, (JInteger.valueOf(1), 21)), + AddData(input2, (1, 22)), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer( + Row(1, null, 1, null, 10, 2, null), + Row(null, null, 2, null, 10, 4, null) + ) + ) + } + + test("left-outer: joining nullable left join key with nullable right join key") { + val input1 = MemoryStream[(JInteger, Int)] + val input2 = MemoryStream[(JInteger, Int)] + + val joined = testForLeftOuterJoinKeyNullability(input1.toDF(), input2.toDF()) + + testStream(joined)( + AddData(input1, + (JInteger.valueOf(1), 1), + (null, 2), + (JInteger.valueOf(1), 3), + (null, 4), + (JInteger.valueOf(1), 5)), + AddData(input2, + (JInteger.valueOf(1), 3), + (null, 4), + (JInteger.valueOf(1), 5), + (JInteger.valueOf(1), 6)), + CheckNewAnswer( + Row(1, 1, 3, 3, 10, 6, 9), + Row(null, null, 4, 4, 10, 8, 12), + Row(1, 1, 5, 5, 10, 10, 15)), + // right-null join + AddData(input1, (JInteger.valueOf(1), 21)), + AddData(input2, (JInteger.valueOf(1), 22)), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer( + Row(1, null, 1, null, 10, 2, null), + Row(null, null, 2, null, 10, 4, null) + ) + ) + } + + private def testForLeftOuterJoinKeyNullability(left: DataFrame, right: DataFrame): DataFrame = { + val df1 = left + .selectExpr("_1 as leftKey1", "_2 as leftKey2", "timestamp_seconds(_2) as leftTime", + "_2 * 2 as leftValue") + .withWatermark("leftTime", "10 seconds") + val df2 = right + .selectExpr( + "_1 as rightKey1", "_2 as rightKey2", "timestamp_seconds(_2) as rightTime", + "_2 * 3 as rightValue") + .withWatermark("rightTime", "10 seconds") + + val windowed1 = df1.select(Symbol("leftKey1"), Symbol("leftKey2"), + window(Symbol("leftTime"), "10 second").as(Symbol("leftWindow")), Symbol("leftValue")) + val windowed2 = df2.select(Symbol("rightKey1"), Symbol("rightKey2"), + window(Symbol("rightTime"), "10 second").as(Symbol("rightWindow")), Symbol("rightValue")) + windowed1.join(windowed2, + expr("leftKey1 <=> rightKey1 AND leftKey2 = rightKey2 AND leftWindow = rightWindow"), + "left_outer" + ).select(Symbol("leftKey1"), Symbol("rightKey1"), Symbol("leftKey2"), Symbol("rightKey2"), + $"leftWindow.end".cast("long"), Symbol("leftValue"), Symbol("rightValue")) + } } class StreamingFullOuterJoinSuite extends StreamingJoinSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index 96f7efeef98e6..cc66ce856732a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -201,6 +201,10 @@ class StreamingQueryManagerSuite extends StreamTest { // After that query is stopped, awaitAnyTerm should throw exception eventually(Timeout(streamingTimeout)) { require(!q3.isActive) } // wait for query to stop + // When 
`isActive` becomes `false`, `StreamingQueryManager` may not receive the error yet. + // Hence, call `stop` to wait until the thread of `q3` exits so that we can ensure + // `StreamingQueryManager` has already received the error. + q3.stop() testAwaitAnyTermination( ExpectException[SparkException], awaitTimeout = 100.milliseconds, @@ -217,6 +221,10 @@ class StreamingQueryManagerSuite extends StreamTest { require(!q4.isActive) val q5 = stopRandomQueryAsync(10.milliseconds, withError = true) eventually(Timeout(streamingTimeout)) { require(!q5.isActive) } + // When `isActive` becomes `false`, `StreamingQueryManager` may not receive the error yet. + // Hence, call `stop` to wait until the thread of `q5` exits so that we can ensure + // `StreamingQueryManager` has already received the error. + q5.stop() // After q5 terminates with exception, awaitAnyTerm should start throwing exception testAwaitAnyTermination(ExpectException[SparkException], awaitTimeout = 2.seconds) } @@ -447,9 +455,9 @@ class StreamingQueryManagerSuite extends StreamTest { /** Stop a random active query either with `stop()` or with an error */ private def stopRandomQueryAsync(stopAfter: Span, withError: Boolean): StreamingQuery = { - + // scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global - + // scalastyle:on executioncontextglobal val activeQueries = spark.streams.active val queryToStop = activeQueries(Random.nextInt(activeQueries.length)) Future { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 99fcef109a07c..7bc4288b2c1c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -237,7 +237,7 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { val inputData = MemoryStream[Int] val query = inputData.toDS().toDF("value") - .select('value) + .select(Symbol("value")) .groupBy($"value") .agg(count("*")) .writeStream diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 54bed5c966d1f..84060733e865c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -860,8 +860,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char").where("char = 'A'") val otherDf = stream.toDF().toDF("num", "numSq") .join(broadcast(baseDf), "num") - .groupBy('char) - .agg(sum('numSq)) + .groupBy(Symbol("char")) + .agg(sum(Symbol("numSq"))) testStream(otherDf, OutputMode.Complete())( AddData(stream, (1, 1), (2, 4)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowDistributionSuite.scala new file mode 100644 index 0000000000000..bb7b9804105fa --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowDistributionSuite.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.io.File + +import org.apache.commons.io.FileUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.execution.streaming.{MemoryStream, SessionWindowStateStoreRestoreExec, SessionWindowStateStoreSaveExec} +import org.apache.spark.sql.functions.{count, session_window} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StatefulOpClusteredDistributionTestHelper +import org.apache.spark.util.Utils + +class StreamingSessionWindowDistributionSuite extends StreamTest + with StatefulOpClusteredDistributionTestHelper with Logging { + + import testImplicits._ + + test("SPARK-38204: session window aggregation should require StatefulOpClusteredDistribution " + + "from children") { + + withSQLConf( + // exclude partial merging session to simplify test + SQLConf.STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION.key -> "false") { + + val inputData = MemoryStream[(String, String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("userId"), $"_3".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .withWatermark("eventTime", "30 seconds") + .selectExpr("explode(split(value, ' ')) AS sessionId", "userId", "eventTime") + + val sessionUpdates = events + .repartition($"userId") + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId, 'userId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "userId", "CAST(session.start AS LONG)", + "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Append())( + AddData(inputData, + ("hello world spark streaming", "key1", 40L), + ("world hello structured streaming", "key2", 41L) + ), + + // skip checking the result, since we focus to verify the physical plan + ProcessAllAvailable(), + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val operators = query.lastExecution.executedPlan.collect { + case s: SessionWindowStateStoreRestoreExec => s + case s: SessionWindowStateStoreSaveExec => s + } + + assert(operators.nonEmpty) + operators.foreach { stateOp => + assert(requireStatefulOpClusteredDistribution(stateOp, Seq(Seq("sessionId", "userId")), + numPartitions)) + assert(hasDesiredHashPartitioningInChildren(stateOp, Seq(Seq("sessionId", "userId")), + numPartitions)) + } + + // Verify aggregations in between, except partial aggregation. + // This includes MergingSessionsExec. 
+ val allAggregateExecs = query.lastExecution.executedPlan.collect { + case a: BaseAggregateExec => a + } + + val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { + _.requiredChildDistribution.head != UnspecifiedDistribution + } + + // We expect single partial aggregation since we disable partial merging sessions. + // Remaining agg execs should have child producing expected output partitioning. + assert(allAggregateExecs.length - 1 === aggregateExecsWithoutPartialAgg.length) + + // For aggregate execs, we make sure output partitioning of the children is same as + // we expect, HashPartitioning with clustering keys & number of partitions. + aggregateExecsWithoutPartialAgg.foreach { aggr => + assert(hasDesiredHashPartitioningInChildren(aggr, Seq(Seq("sessionId", "userId")), + numPartitions)) + } + } + ) + } + } + + test("SPARK-38204: session window aggregation should require ClusteredDistribution " + + "from children if the query starts from checkpoint in 3.2") { + + withSQLConf( + // exclude partial merging session to simplify test + SQLConf.STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION.key -> "false") { + + val inputData = MemoryStream[(String, String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("userId"), $"_3".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .withWatermark("eventTime", "30 seconds") + .selectExpr("explode(split(value, ' ')) AS sessionId", "userId", "eventTime") + + val sessionUpdates = events + .repartition($"userId") + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId, 'userId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "userId", "CAST(session.start AS LONG)", + "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-3.2.0-session-window-with-repartition/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData( + ("hello world spark streaming", "key1", 40L), + ("world hello structured streaming", "key2", 41L)) + + testStream(sessionUpdates, OutputMode.Append())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "true")), + + // scalastyle:off line.size.limit + /* + Note: The checkpoint was generated using the following input in Spark version 3.2.0 + AddData(inputData, + ("hello world spark streaming", "key1", 40L), + ("world hello structured streaming", "key2", 41L)), + // skip checking the result, since we focus to verify the physical plan + ProcessAllAvailable() + + Note2: The following is the physical plan of the query in Spark version 3.2.0. 
+ + WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@6649ee50, org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy$$Lambda$2209/0x0000000840ebd440@f9f45c6 + +- *(3) HashAggregate(keys=[session_window#33-T30000ms, sessionId#21, userId#10], functions=[count(1)], output=[sessionId#21, userId#10, CAST(session.start AS BIGINT)#43L, CAST(session.end AS BIGINT)#44L, durationMs#38L, numEvents#32L]) + +- SessionWindowStateStoreSave [sessionId#21, userId#10], session_window#33: struct, state info [ checkpoint = file:/tmp/spark-f8a951f5-c7c1-43b0-883d-9b893d672ee5/state, runId = 92681f36-1f0d-434e-8492-897e4e988bb3, opId = 0, ver = 1, numPartitions = 5], Append, 11000, 1 + +- MergingSessions List(ClusteredDistribution(ArrayBuffer(sessionId#21, userId#10),None)), [session_window#33-T30000ms, sessionId#21, userId#10], session_window#33: struct, [merge_count(1)], [count(1)#30L], 3, [session_window#33-T30000ms, sessionId#21, userId#10, count#58L] + +- SessionWindowStateStoreRestore [sessionId#21, userId#10], session_window#33: struct, state info [ checkpoint = file:/tmp/spark-f8a951f5-c7c1-43b0-883d-9b893d672ee5/state, runId = 92681f36-1f0d-434e-8492-897e4e988bb3, opId = 0, ver = 1, numPartitions = 5], 11000, 1 + +- *(2) Sort [sessionId#21 ASC NULLS FIRST, userId#10 ASC NULLS FIRST, session_window#33-T30000ms ASC NULLS FIRST], false, 0 + +- *(2) HashAggregate(keys=[session_window#33-T30000ms, sessionId#21, userId#10], functions=[partial_count(1)], output=[session_window#33-T30000ms, sessionId#21, userId#10, count#58L]) + +- *(2) Project [named_struct(start, precisetimestampconversion(precisetimestampconversion(eventTime#15-T30000ms, TimestampType, LongType), LongType, TimestampType), end, precisetimestampconversion(precisetimestampconversion(eventTime#15-T30000ms + 10 seconds, TimestampType, LongType), LongType, TimestampType)) AS session_window#33-T30000ms, sessionId#21, userId#10] + +- Exchange hashpartitioning(userId#10, 5), REPARTITION_BY_COL, [id=#372] + +- *(1) Project [sessionId#21, userId#10, eventTime#15-T30000ms] + +- *(1) Generate explode(split(value#9, , -1)), [userId#10, eventTime#15-T30000ms], false, [sessionId#21] + +- *(1) Filter (precisetimestampconversion(precisetimestampconversion(eventTime#15-T30000ms + 10 seconds, TimestampType, LongType), LongType, TimestampType) > precisetimestampconversion(precisetimestampconversion(eventTime#15-T30000ms, TimestampType, LongType), LongType, TimestampType)) + +- EventTimeWatermark eventTime#15: timestamp, 30 seconds + +- LocalTableScan , [value#9, userId#10, eventTime#15] + */ + // scalastyle:on line.size.limit + + AddData(inputData, ("spark streaming", "key1", 25L)), + // skip checking the result, since we focus to verify the physical plan + ProcessAllAvailable(), + + Execute { query => + val numPartitions = query.lastExecution.numStateStores + + val operators = query.lastExecution.executedPlan.collect { + case s: SessionWindowStateStoreRestoreExec => s + case s: SessionWindowStateStoreSaveExec => s + } + + assert(operators.nonEmpty) + operators.foreach { stateOp => + assert(requireClusteredDistribution(stateOp, Seq(Seq("sessionId", "userId")), + Some(numPartitions))) + assert(hasDesiredHashPartitioningInChildren(stateOp, Seq(Seq("userId")), + numPartitions)) + } + + // Verify aggregations in between, except partial aggregation. + // This includes MergingSessionsExec. 
+ val allAggregateExecs = query.lastExecution.executedPlan.collect { + case a: BaseAggregateExec => a + } + + val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { + _.requiredChildDistribution.head != UnspecifiedDistribution + } + + // We expect single partial aggregation since we disable partial merging sessions. + // Remaining agg execs should have child producing expected output partitioning. + assert(allAggregateExecs.length - 1 === aggregateExecsWithoutPartialAgg.length) + + // For aggregate execs, we make sure output partitioning of the children is same as + // we expect, HashPartitioning with sub-clustering keys & number of partitions. + aggregateExecsWithoutPartialAgg.foreach { aggr => + assert(hasDesiredHashPartitioningInChildren(aggr, Seq(Seq("userId")), + numPartitions)) + } + } + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala index e82b9df93dd7d..d0f3a87acbc29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -417,7 +417,7 @@ class StreamingSessionWindowSuite extends StreamTest .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") events - .groupBy(sessionWindow as 'session, 'sessionId) + .groupBy(sessionWindow as Symbol("session"), Symbol("sessionId")) .agg(count("*").as("numEvents")) .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", @@ -429,8 +429,8 @@ class StreamingSessionWindowSuite extends StreamTest .selectExpr("*") .withColumn("eventTime", $"value".cast("timestamp")) .withWatermark("eventTime", "10 seconds") - .groupBy(session_window($"eventTime", "5 seconds") as 'session) - .agg(count("*") as 'count, sum("value") as 'sum) + .groupBy(session_window($"eventTime", "5 seconds") as Symbol("session")) + .agg(count("*") as Symbol("count"), sum("value") as Symbol("sum")) .select($"session".getField("start").cast("long").as[Long], $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala index 0c7348b91c0a7..cb4410d9da92d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala @@ -128,7 +128,7 @@ class TriggerAvailableNowSuite extends FileStreamSourceTest { .option("maxFilesPerTrigger", 1) .text(src.getCanonicalPath) - val df2 = testSource.toDF + val df2 = testSource.toDF.selectExpr("cast(value as string)") def startQuery(): StreamingQuery = { df1.union(df2).writeStream diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 0e2fcfbd46356..5893c3da09812 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -257,7 +257,7 @@ class ContinuousSuite extends ContinuousSuiteBase { .option("numPartitions", "2") 
.option("rowsPerSecond", "2") .load() - .select('value) + .select(Symbol("value")) val query = df.writeStream .format("memory") @@ -306,7 +306,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { .option("numPartitions", "5") .option("rowsPerSecond", "500") .load() - .select('value) + .select(Symbol("value")) testStream(df)( StartStream(longContinuousTrigger), @@ -326,7 +326,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { .option("numPartitions", "5") .option("rowsPerSecond", "500") .load() - .select('value) + .select(Symbol("value")) testStream(df)( StartStream(Trigger.Continuous(2012)), @@ -345,7 +345,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { .option("numPartitions", "5") .option("rowsPerSecond", "500") .load() - .select('value) + .select(Symbol("value")) testStream(df)( StartStream(Trigger.Continuous(1012)), @@ -436,7 +436,7 @@ class ContinuousEpochBacklogSuite extends ContinuousSuiteBase { .option("numPartitions", "2") .option("rowsPerSecond", "500") .load() - .select('value) + .select(Symbol("value")) testStream(df)( StartStream(Trigger.Continuous(1)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index fc78527af381e..c40ba02fd0dd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -553,7 +553,10 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { val createArray = udf { (length: Long) => for (i <- 1 to length.toInt) yield i.toString } - spark.range(4).select(createArray('id + 1) as 'ex, 'id, 'id % 4 as 'part).coalesce(1).write + spark.range(4) + .select(createArray(Symbol("id") + 1) as Symbol("ex"), Symbol("id"), + Symbol("id") % 4 as Symbol("part")) + .coalesce(1).write .partitionBy("part", "id") .mode("overwrite") .parquet(src.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala index 62e944c96ef9a..61b3ec26a4d20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala @@ -162,11 +162,9 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { spark.sql(s"CREATE TABLE $tblName (data int) USING $v2Source") // Check the StreamingRelationV2 has been replaced by StreamingRelation - val plan = spark.readStream.option("path", tempDir.getCanonicalPath).table(tblName) - .queryExecution.analyzed.collectFirst { - case d: StreamingRelationV2 => d - } - assert(plan.isEmpty) + val exists = spark.readStream.option("path", tempDir.getCanonicalPath).table(tblName) + .queryExecution.analyzed.exists(_.isInstanceOf[StreamingRelationV2]) + assert(!exists) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala index 246fa1f7c9184..78ade6a1eef36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala @@ -103,7 +103,7 @@ class 
StreamingQueryPageSuite extends SharedSparkSession with BeforeAndAfter { when(summary.isActive).thenReturn(true) when(summary.name).thenReturn("query") when(summary.id).thenReturn(id) - when(summary.runId).thenReturn(id) + when(summary.runId).thenReturn(id.toString) when(summary.startTimestamp).thenReturn(1L) when(summary.exception).thenReturn(None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListenerSuite.scala index 91c55d5598a6b..1d1b51354f8d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatusListenerSuite.scala @@ -23,13 +23,16 @@ import java.util.{Date, UUID} import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkConf +import org.apache.spark.internal.config.History.{HYBRID_STORE_DISK_BACKEND, HybridStoreDiskBackend} import org.apache.spark.sql.catalyst.util.DateTimeUtils.getTimeZone import org.apache.spark.sql.execution.ui.StreamingQueryStatusStore import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.streaming.{StreamingQueryListener, StreamingQueryProgress, StreamTest} import org.apache.spark.sql.streaming -import org.apache.spark.status.ElementTrackingStore -import org.apache.spark.util.kvstore.InMemoryStore +import org.apache.spark.status.{ElementTrackingStore, KVUtils} +import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} class StreamingQueryStatusListenerSuite extends StreamTest { @@ -48,7 +51,7 @@ class StreamingQueryStatusListenerSuite extends StreamTest { // result checking assert(queryStore.allQueryUIData.count(_.summary.isActive) == 1) assert(queryStore.allQueryUIData.filter(_.summary.isActive).exists(uiData => - uiData.summary.runId == runId && uiData.summary.name.equals("test"))) + uiData.summary.runId == runId.toString && uiData.summary.name.equals("test"))) // handle query progress event val progress = mock(classOf[StreamingQueryProgress], RETURNS_SMART_NULLS) @@ -64,7 +67,7 @@ class StreamingQueryStatusListenerSuite extends StreamTest { // result checking val activeQuery = - queryStore.allQueryUIData.filter(_.summary.isActive).find(_.summary.runId == runId) + queryStore.allQueryUIData.filter(_.summary.isActive).find(_.summary.runId == runId.toString) assert(activeQuery.isDefined) assert(activeQuery.get.summary.isActive) assert(activeQuery.get.recentProgress.length == 1) @@ -81,7 +84,8 @@ class StreamingQueryStatusListenerSuite extends StreamTest { listener.onQueryTerminated(terminateEvent) assert(!queryStore.allQueryUIData.filterNot(_.summary.isActive).head.summary.isActive) - assert(queryStore.allQueryUIData.filterNot(_.summary.isActive).head.summary.runId == runId) + assert( + queryStore.allQueryUIData.filterNot(_.summary.isActive).head.summary.runId == runId.toString) assert(queryStore.allQueryUIData.filterNot(_.summary.isActive).head.summary.id == id) } @@ -110,10 +114,12 @@ class StreamingQueryStatusListenerSuite extends StreamTest { // result checking assert(queryStore.allQueryUIData.count(_.summary.isActive) == 1) assert(queryStore.allQueryUIData.filterNot(_.summary.isActive).length == 1) - assert(queryStore.allQueryUIData.filter(_.summary.isActive).exists(_.summary.runId == runId1)) + 
assert(queryStore.allQueryUIData.filter(_.summary.isActive).exists( + _.summary.runId == runId1.toString)) assert(queryStore.allQueryUIData.filter(_.summary.isActive).exists(uiData => - uiData.summary.runId == runId1 && uiData.summary.id == id)) - assert(queryStore.allQueryUIData.filterNot(_.summary.isActive).head.summary.runId == runId0) + uiData.summary.runId == runId1.toString && uiData.summary.id == id)) + assert( + queryStore.allQueryUIData.filterNot(_.summary.isActive).head.summary.runId == runId0.toString) assert(queryStore.allQueryUIData.filterNot(_.summary.isActive).head.summary.id == id) } @@ -210,4 +216,52 @@ class StreamingQueryStatusListenerSuite extends StreamTest { addQueryProgress() checkQueryProcessData(5) } + + test("SPARK-38056: test writing StreamingQueryData to an in-memory store") { + testStreamingQueryData(new InMemoryStore()) + } + + test("SPARK-38056: test writing StreamingQueryData to a LevelDB store") { + assume(!Utils.isMacOnAppleSilicon) + val conf = new SparkConf() + .set(HYBRID_STORE_DISK_BACKEND, HybridStoreDiskBackend.LEVELDB.toString) + val testDir = Utils.createTempDir() + val kvStore = KVUtils.open(testDir, getClass.getName, conf) + try { + testStreamingQueryData(kvStore) + } finally { + kvStore.close() + Utils.deleteRecursively(testDir) + } + } + + test("SPARK-38056: test writing StreamingQueryData to a RocksDB store") { + assume(!Utils.isMacOnAppleSilicon) + val conf = new SparkConf() + .set(HYBRID_STORE_DISK_BACKEND, HybridStoreDiskBackend.ROCKSDB.toString) + val testDir = Utils.createTempDir() + val kvStore = KVUtils.open(testDir, getClass.getName, conf) + try { + testStreamingQueryData(kvStore) + } finally { + kvStore.close() + Utils.deleteRecursively(testDir) + } + } + + private def testStreamingQueryData(kvStore: KVStore): Unit = { + val id = UUID.randomUUID() + val testData = new StreamingQueryData( + "some-query", + id, + id.toString, + isActive = false, + None, + 1L, + None + ) + val store = new ElementTrackingStore(kvStore, sparkConf) + store.write(testData) + store.close(closeParent = false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala new file mode 100644 index 0000000000000..f2684b8c39cd9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.streaming.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, StatefulOpClusteredDistribution} +import org.apache.spark.sql.execution.SparkPlan + +trait StatefulOpClusteredDistributionTestHelper extends SparkFunSuite { + protected def requireClusteredDistribution( + plan: SparkPlan, + desiredClusterColumns: Seq[Seq[String]], + desiredNumPartitions: Option[Int]): Boolean = { + assert(plan.requiredChildDistribution.length === desiredClusterColumns.length) + plan.requiredChildDistribution.zip(desiredClusterColumns).forall { + case (d: ClusteredDistribution, clusterColumns: Seq[String]) + if partitionExpressionsColumns(d.clustering) == clusterColumns && + d.requiredNumPartitions == desiredNumPartitions => true + + case _ => false + } + } + + protected def requireStatefulOpClusteredDistribution( + plan: SparkPlan, + desiredClusterColumns: Seq[Seq[String]], + desiredNumPartitions: Int): Boolean = { + assert(plan.requiredChildDistribution.length === desiredClusterColumns.length) + plan.requiredChildDistribution.zip(desiredClusterColumns).forall { + case (d: StatefulOpClusteredDistribution, clusterColumns: Seq[String]) + if partitionExpressionsColumns(d.expressions) == clusterColumns && + d._requiredNumPartitions == desiredNumPartitions => true + + case _ => false + } + } + + protected def hasDesiredHashPartitioning( + plan: SparkPlan, + desiredClusterColumns: Seq[String], + desiredNumPartitions: Int): Boolean = { + plan.outputPartitioning match { + case HashPartitioning(expressions, numPartitions) + if partitionExpressionsColumns(expressions) == desiredClusterColumns && + numPartitions == desiredNumPartitions => true + + case _ => false + } + } + + protected def hasDesiredHashPartitioningInChildren( + plan: SparkPlan, + desiredClusterColumns: Seq[Seq[String]], + desiredNumPartitions: Int): Boolean = { + plan.children.zip(desiredClusterColumns).forall { case (child, clusterColumns) => + hasDesiredHashPartitioning(child, clusterColumns, desiredNumPartitions) + } + } + + private def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = { + expressions.flatMap { + case ref: AttributeReference => Some(ref.name) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 41d11568750cc..dabd9c001eb3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -32,7 +32,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.parquet.schema.Type.Repetition import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, TestUtils} import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} @@ -42,7 +42,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Ove import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.{DataSourceUtils, HadoopFsRelation, LogicalRelation} import 
org.apache.spark.sql.execution.datasources.noop.NoopDataSource -import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -536,12 +535,31 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with .option("TO", "10") .format("org.apache.spark.sql.sources.SimpleScanSource") + val answerDf = spark.range(1, 11).toDF() + // when users do not specify the schema - checkAnswer(dfReader.load(), spark.range(1, 11).toDF()) + checkAnswer(dfReader.load(), answerDf) + + // same base schema, differing metadata and nullability + val fooBarMetadata = new MetadataBuilder().putString("foo", "bar").build() + val nullableAndMetadataCases = Seq( + (false, fooBarMetadata), + (false, Metadata.empty), + (true, fooBarMetadata), + (true, Metadata.empty)) + nullableAndMetadataCases.foreach { case (nullable, metadata) => + val inputSchema = new StructType() + .add("i", IntegerType, nullable = nullable, metadata = metadata) + checkAnswer(dfReader.schema(inputSchema).load(), answerDf) + } // when users specify a wrong schema - val inputSchema = new StructType().add("s", IntegerType, nullable = false) - val e = intercept[AnalysisException] { dfReader.schema(inputSchema).load() } + var inputSchema = new StructType().add("s", IntegerType, nullable = false) + var e = intercept[AnalysisException] { dfReader.schema(inputSchema).load() } + assert(e.getMessage.contains("The user-specified schema doesn't match the actual schema")) + + inputSchema = new StructType().add("i", StringType, nullable = true) + e = intercept[AnalysisException] { dfReader.schema(inputSchema).load() } assert(e.getMessage.contains("The user-specified schema doesn't match the actual schema")) } @@ -745,7 +763,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with withTempPath { dir => val path = dir.getAbsolutePath df.write.mode("overwrite").parquet(path) - val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) + val file = TestUtils.listDirectory(dir).head val hadoopInputFile = HadoopInputFile.fromPath(new Path(file), new Configuration()) val f = ParquetFileReader.open(hadoopInputFile) @@ -862,7 +880,8 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with val createArray = udf { (length: Long) => for (i <- 1 to length.toInt) yield i.toString } - spark.range(4).select(createArray('id + 1) as 'ex, 'id, 'id % 4 as 'part).coalesce(1).write + spark.range(4).select(createArray(Symbol("id") + 1) as Symbol("ex"), + Symbol("id"), Symbol("id") % 4 as Symbol("part")).coalesce(1).write .partitionBy("part", "id") .mode("overwrite") .parquet(src.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 47a6f3617da63..fb3d38f3b7b18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -61,7 +61,13 @@ private[sql] object TestSQLContext { val overrideConfs: Map[String, String] = Map( // Fewer shuffle partitions to speed up testing. 
- SQLConf.SHUFFLE_PARTITIONS.key -> "5") + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + // Enable parquet read field id for tests to ensure correctness + // By default, if Spark schema doesn't contain the `parquet.field.id` metadata, + // the underlying matching mechanism should behave exactly like name matching + // which is the existing behavior. Therefore, turning this on ensures that we didn't + // introduce any regression for such mixed matching mode. + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") } private[sql] class TestSQLSessionStateBuilder( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala similarity index 92% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala index 60f1b32a41f05..25beda99cd654 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized +package org.apache.spark.sql.vectorized import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ @@ -23,7 +23,6 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowColumnVectorSuite extends SparkFunSuite { @@ -431,4 +430,35 @@ class ArrowColumnVectorSuite extends SparkFunSuite { columnVector.close() allocator.close() } + + test ("SPARK-38086: subclassing") { + class ChildArrowColumnVector(vector: ValueVector, n: Int) + extends ArrowColumnVector(vector: ValueVector) { + + override def getValueVector: ValueVector = accessor.vector + override def getInt(rowId: Int): Int = accessor.getInt(rowId) + n + } + + val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true, null) + .createVector(allocator).asInstanceOf[IntVector] + vector.allocateNew() + + (0 until 10).foreach { i => + vector.setSafe(i, i) + } + + val columnVector = new ChildArrowColumnVector(vector, 1) + assert(columnVector.dataType === IntegerType) + assert(!columnVector.hasNull) + + val intVector = columnVector.getValueVector.asInstanceOf[IntVector] + (0 until 10).foreach { i => + assert(columnVector.getInt(i) === i + 1) + assert(intVector.get(i) === i) + } + + columnVector.close() + allocator.close() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala index baa04ada8b5d1..11201aadf67f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala @@ -152,7 +152,7 @@ class SqlResourceSuite extends SparkFunSuite with PrivateMethodTester { import SqlResourceSuite._ val sqlResource = new SqlResource() - val prepareExecutionData = PrivateMethod[ExecutionData]('prepareExecutionData) + val prepareExecutionData = 
PrivateMethod[ExecutionData](Symbol("prepareExecutionData")) test("Prepare ExecutionData when details = false and planDescription = false") { val executionData = @@ -196,7 +196,7 @@ class SqlResourceSuite extends SparkFunSuite with PrivateMethodTester { } test("Parse wholeStageCodegenId from nodeName") { - val getWholeStageCodegenId = PrivateMethod[Option[Long]]('getWholeStageCodegenId) + val getWholeStageCodegenId = PrivateMethod[Option[Long]](Symbol("getWholeStageCodegenId")) val wholeStageCodegenId = sqlResource invokePrivate getWholeStageCodegenId(WHOLE_STAGE_CODEGEN_1) assert(wholeStageCodegenId == Some(1)) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java index 2fabf70c0f274..ca0fbe7eb67a9 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java @@ -267,7 +267,7 @@ private static StringLayout initLayout(OperationLog.LoggingLevel loggingMode) { Appender ap = entry.getValue(); if (ap.getClass().equals(ConsoleAppender.class)) { Layout l = ap.getLayout(); - if (l.getClass().equals(StringLayout.class)) { + if (l instanceof StringLayout) { layout = (StringLayout) l; break; } diff --git a/sql/hive-thriftserver/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin b/sql/hive-thriftserver/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin index 96d990372ee4c..75feb9da53a93 100644 --- a/sql/hive-thriftserver/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin +++ b/sql/hive-thriftserver/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2HistoryServerPlugin diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index e17b74873395e..4c26e93606083 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -527,7 +527,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { // string, the origin implementation from Hive will not drop the trailing semicolon as expected, // hence we refined this function a little bit. 
// Note: [SPARK-33100] Ignore a semicolon inside a bracketed comment in spark-sql. - private def splitSemiColon(line: String): JList[String] = { + private[hive] def splitSemiColon(line: String): JList[String] = { var insideSingleQuote = false var insideDoubleQuote = false var insideSimpleComment = false @@ -613,7 +613,17 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { isStatement = statementInProgress(index) } - if (beginIndex < line.length()) { + // Check whether the last char is the end of a nested bracketed comment. + val endOfBracketedComment = leavingBracketedComment && bracketedCommentLevel == 1 + // Spark SQL supports simple comments and nested bracketed comments in a query body, + // but if Spark SQL receives a comment alone, it throws a parser exception. + // In the Spark SQL CLI, if there is a completed comment at the end of the whole query, + // since the CLI uses `;` to split the query, the comment would be passed + // to the backend engine, which would throw an exception. The CLI should ignore such a comment. + // If there is an uncompleted statement or an uncompleted bracketed comment at the end, + // the CLI should still pass that part to the backend engine, which may throw an exception + // with a clear error message. + if (!endOfBracketedComment && (isStatement || insideBracketedComment)) { ret.add(line.substring(beginIndex)) } ret diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index bc4b64c287e6c..d2c0235a23f21 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -138,7 +138,7 @@ private[thriftserver] trait ReflectedCompositeService { this: AbstractService => serviceStartCount += 1 } // Emulating `AbstractService.start` - val startTime = new java.lang.Long(System.currentTimeMillis()) + val startTime = java.lang.Long.valueOf(System.currentTimeMillis()) setAncestorField(this, 3, "startTime", startTime) invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.INITED) invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.STARTED) @@ -147,7 +147,7 @@ private[thriftserver] trait ReflectedCompositeService { this: AbstractService => case NonFatal(e) => logError(s"Error starting services $getName", e) invoke(classOf[CompositeService], this, "stop", - classOf[Int] -> new Integer(serviceStartCount)) + classOf[Int] -> Integer.valueOf(serviceStartCount)) throw HiveThriftServerErrors.failedToStartServiceError(getName, e) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2Listener.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2Listener.scala index 4cf672e3d9d9e..7b2da6970fb86 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2Listener.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2Listener.scala @@ -93,7 +93,7 @@ private[thriftserver] class HiveThriftServer2Listener( val execList = executionList.values().asScala.filter(_.groupId == groupId).toSeq if (execList.nonEmpty) { execList.foreach { exec => - exec.jobId += jobId.toString + exec.jobId += jobId updateLiveStore(exec) } } else { @@
-105,7 +105,7 @@ private[thriftserver] class HiveThriftServer2Listener( storeExecInfo.foreach { exec => val liveExec = getOrCreateExecution(exec.execId, exec.statement, exec.sessionId, exec.startTimestamp, exec.userName) - liveExec.jobId += jobId.toString + liveExec.jobId += jobId updateStoreWithTriggerEnabled(liveExec) executionList.remove(liveExec.execId) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 54a40e3990f09..d0378efd646e3 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -232,7 +232,7 @@ private[ui] class SqlStatsPagedTable( def jobLinks(jobData: Seq[String]): Seq[Node] = { jobData.map { jobId => - [{jobId.toString}] + [{jobId}] } } diff --git a/sql/hive-thriftserver/src/test/resources/log4j2.properties b/sql/hive-thriftserver/src/test/resources/log4j2.properties index 58e18af0a8e6c..939335bf3ac8d 100644 --- a/sql/hive-thriftserver/src/test/resources/log4j2.properties +++ b/sql/hive-thriftserver/src/test/resources/log4j2.properties @@ -33,8 +33,8 @@ appender.console.filter.1.a.type = ThresholdFilter appender.console.filter.1.a.level = warn # SPARK-34128: Suppress undesirable TTransportException warnings, due to THRIFT-4805 -appender.console.filter.1.b.type = MarkerFilter -appender.console.filter.1.b.marker = Thrift error occurred during processing of message +appender.console.filter.1.b.type = RegexFilter +appender.console.filter.1.b.regex = .*Thrift error occurred during processing of message.* appender.console.filter.1.b.onMatch = deny appender.console.filter.1.b.onMismatch = neutral @@ -47,8 +47,8 @@ appender.file.layout.pattern = %d{HH:mm:ss.SSS} %t %p %c{1}: %m%n appender.file.filter.1.type = Filters -appender.file.filter.1.a.type = MarkerFilter -appender.file.filter.1.a.marker = Thrift error occurred during processing of message +appender.file.filter.1.a.type = RegexFilter +appender.file.filter.1.a.regex = .*Thrift error occurred during processing of message.* appender.file.filter.1.a.onMatch = deny appender.file.filter.1.a.onMismatch = neutral diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 234fb89b01a83..2f0fd858ba206 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -22,17 +22,23 @@ import java.nio.charset.StandardCharsets import java.sql.Timestamp import java.util.Date +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.Promise import scala.concurrent.duration._ +import org.apache.hadoop.hive.cli.CliSessionState import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.session.SessionState import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.ProcessTestUtils.ProcessOutputCapturer -import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.sql.hive.HiveUtils import 
org.apache.spark.sql.hive.HiveUtils._ +import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.hive.test.HiveTestJars import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.util.{ThreadUtils, Utils} @@ -549,22 +555,22 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { ) } - test("AnalysisException with root cause will be printStacktrace") { + test("SparkException with root cause will be printStacktrace") { // If it is not in silent mode, will print the stacktrace runCliWithin( 1.minute, extraArgs = Seq("--hiveconf", "hive.session.silent=false", - "-e", "select date_sub(date'2011-11-11', '1.2');"), - errorResponses = Seq("NumberFormatException"))( - ("", "Error in query: The second argument of 'date_sub' function needs to be an integer."), - ("", "NumberFormatException: invalid input syntax for type numeric: 1.2")) + "-e", "select from_json('a', 'a INT', map('mode', 'FAILFAST'));"), + errorResponses = Seq("JsonParseException"))( + ("", "SparkException: Malformed records are detected in record parsing"), + ("", "JsonParseException: Unrecognized token 'a'")) // If it is in silent mode, will print the error message only runCliWithin( 1.minute, extraArgs = Seq("--conf", "spark.hive.session.silent=true", - "-e", "select date_sub(date'2011-11-11', '1.2');"), - errorResponses = Seq("AnalysisException"))( - ("", "Error in query: The second argument of 'date_sub' function needs to be an integer.")) + "-e", "select from_json('a', 'a INT', map('mode', 'FAILFAST'));"), + errorResponses = Seq("SparkException"))( + ("", "SparkException: Malformed records are detected in record parsing")) } test("SPARK-30808: use Java 8 time API in Thrift SQL CLI by default") { @@ -624,7 +630,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { test("SPARK-37555: spark-sql should pass last unclosed comment to backend") { runCliWithin(2.minute)( // Only unclosed comment. - "/* SELECT /*+ HINT() 4; */;".stripMargin -> "mismatched input ';'", + "/* SELECT /*+ HINT() 4; */;".stripMargin -> "Syntax error at or near ';'", // Unclosed nested bracketed comment. "/* SELECT /*+ HINT() 4; */ SELECT 1;".stripMargin -> "1", // Unclosed comment with query. 
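For context on the error-message test rewrite above: presumably because its failure depends on the ANSI setting, the old probe query select date_sub(date'2011-11-11', '1.2') was replaced with from_json in FAILFAST mode, which fails the same way whether or not ANSI is enabled, raising a SparkException whose root cause is Jackson's JsonParseException. A minimal, hypothetical Scala sketch of that failure outside the CLI, assuming only a local SparkSession; the object name, app name, and catch-and-print handling are illustrative and not part of this patch:

import org.apache.spark.sql.SparkSession

object FromJsonFailFastSketch {
  def main(args: Array[String]): Unit = {
    // Local session purely for illustration; the real suites drive spark-sql and Thrift clients.
    val spark = SparkSession.builder().master("local[1]").appName("failfast-sketch").getOrCreate()
    try {
      // FAILFAST makes malformed JSON raise an error instead of yielding a NULL struct.
      spark.sql("SELECT from_json('a', 'a INT', map('mode', 'FAILFAST'))").show()
    } catch {
      case e: Exception =>
        // Expect a SparkException whose message/cause chain mentions
        // "Malformed records are detected in record parsing" and
        // "JsonParseException: Unrecognized token 'a'".
        println(e)
    } finally {
      spark.stop()
    }
  }
}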
@@ -638,4 +644,40 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { runCliWithin(2.minute, errorResponses = Seq("ParseException"))( "delete jar dummy.jar;" -> "missing 'FROM' at 'jar'(line 1, pos 7)") } + + test("SPARK-37906: Spark SQL CLI should not pass final comment") { + val sparkConf = new SparkConf(loadDefaults = true) + .setMaster("local-cluster[1,1,1024]") + .setAppName("SPARK-37906") + val sparkContext = new SparkContext(sparkConf) + SparkSQLEnv.sparkContext = sparkContext + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val extraConfigs = HiveUtils.formatTimeVarsForHiveClient(hadoopConf) + val cliConf = HiveClientImpl.newHiveConf(sparkConf, hadoopConf, extraConfigs) + val sessionState = new CliSessionState(cliConf) + SessionState.setCurrentSessionState(sessionState) + val cli = new SparkSQLCLIDriver + Seq("SELECT 1; --comment" -> Seq("SELECT 1"), + "SELECT 1; /* comment */" -> Seq("SELECT 1"), + "SELECT 1; /* comment" -> Seq("SELECT 1", " /* comment"), + "SELECT 1; /* comment select 1;" -> Seq("SELECT 1", " /* comment select 1;"), + "/* This is a comment without end symbol SELECT 1;" -> + Seq("/* This is a comment without end symbol SELECT 1;"), + "SELECT 1; --comment\n" -> Seq("SELECT 1"), + "SELECT 1; /* comment */\n" -> Seq("SELECT 1"), + "SELECT 1; /* comment\n" -> Seq("SELECT 1", " /* comment\n"), + "SELECT 1; /* comment select 1;\n" -> Seq("SELECT 1", " /* comment select 1;\n"), + "/* This is a comment without end symbol SELECT 1;\n" -> + Seq("/* This is a comment without end symbol SELECT 1;\n"), + "/* comment */ SELECT 1;" -> Seq("/* comment */ SELECT 1"), + "SELECT /* comment */ 1;" -> Seq("SELECT /* comment */ 1"), + "-- comment " -> Seq(), + "-- comment \nSELECT 1" -> Seq("-- comment \nSELECT 1"), + "/* comment */ " -> Seq() + ).foreach { case (query, ret) => + assert(cli.splitSemiColon(query).asScala === ret) + } + sessionState.close() + SparkSQLEnv.stop() + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala index 851b8e48684de..daf410556f5b8 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala @@ -30,6 +30,7 @@ import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.apache.spark.sql.catalyst.util.NumberConverter +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.unsafe.types.UTF8String class SparkThriftServerProtocolVersionsSuite extends HiveThriftServer2TestBase { @@ -298,9 +299,11 @@ class SparkThriftServerProtocolVersionsSuite extends HiveThriftServer2TestBase { assert(metaData.getPrecision(1) === Int.MaxValue) assert(metaData.getScale(1) === 0) } - testExecuteStatementWithProtocolVersion(version, "SELECT cast(49960 as binary)") { rs => - assert(rs.next()) - assert(rs.getString(1) === UTF8String.fromBytes(NumberConverter.toBinary(49960)).toString) + if (!SQLConf.get.ansiEnabled) { + testExecuteStatementWithProtocolVersion(version, "SELECT cast(49960 as binary)") { rs => + assert(rs.next()) + assert(rs.getString(1) === UTF8String.fromBytes(NumberConverter.toBinary(49960)).toString) + } } testExecuteStatementWithProtocolVersion(version, 
"SELECT cast(null as binary)") { rs => assert(rs.next()) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index ad527a2571898..b5cfa04bab581 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -56,7 +56,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { } test("Full stack traces as error message for jdbc or thrift client") { - val sql = "select date_sub(date'2011-11-11', '1.2')" + val sql = "select from_json('a', 'a INT', map('mode', 'FAILFAST'))" withCLIServiceClient() { client => val sessionHandle = client.openSession(user, "") @@ -67,24 +67,18 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { sql, confOverlay) } - - assert(e.getMessage - .contains("The second argument of 'date_sub' function needs to be an integer.")) - assert(!e.getMessage.contains("" + - "java.lang.NumberFormatException: invalid input syntax for type numeric: 1.2")) - assert(e.getSQLState == "22023") + assert(e.getMessage.contains("JsonParseException: Unrecognized token 'a'")) + assert(!e.getMessage.contains( + "SparkException: Malformed records are detected in record parsing")) } withJdbcStatement { statement => val e = intercept[SQLException] { statement.executeQuery(sql) } - assert(e.getMessage - .contains("The second argument of 'date_sub' function needs to be an integer.")) - assert(e.getMessage.contains("[SECOND_FUNCTION_ARGUMENT_NOT_INTEGER]")) - assert(e.getMessage.contains("" + - "java.lang.NumberFormatException: invalid input syntax for type numeric: 1.2")) - assert(e.getSQLState == "22023") + assert(e.getMessage.contains("JsonParseException: Unrecognized token 'a'")) + assert(e.getMessage.contains( + "SparkException: Malformed records are detected in record parsing")) } } diff --git a/sql/hive/benchmarks/OrcReadBenchmark-jdk11-results.txt b/sql/hive/benchmarks/OrcReadBenchmark-jdk11-results.txt index 3f9e63f9b8f2d..f9ab5dd5d51ae 100644 --- a/sql/hive/benchmarks/OrcReadBenchmark-jdk11-results.txt +++ b/sql/hive/benchmarks/OrcReadBenchmark-jdk11-results.txt @@ -2,221 +2,221 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1064 1070 9 14.8 67.6 1.0X -Native ORC Vectorized 237 326 73 66.3 15.1 4.5X -Hive built-in ORC 1232 1330 139 12.8 78.3 0.9X +Hive built-in ORC 1137 1138 1 13.8 72.3 1.0X +Native ORC MR 962 982 17 16.3 61.2 1.2X +Native ORC Vectorized 225 298 65 69.9 14.3 5.1X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 947 1056 155 16.6 60.2 1.0X -Native ORC Vectorized 232 311 56 67.7 14.8 4.1X -Hive built-in ORC 1317 1330 19 11.9 83.7 0.7X +Hive built-in ORC 1250 1253 4 12.6 79.5 1.0X +Native ORC MR 1038 1135 136 15.1 66.0 1.2X +Native ORC Vectorized 232 307 47 67.9 14.7 5.4X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 964 1070 150 16.3 61.3 1.0X -Native ORC Vectorized 275 304 32 57.2 17.5 3.5X -Hive built-in ORC 1328 1336 11 11.8 84.4 0.7X +Hive built-in ORC 1360 1399 55 11.6 86.5 1.0X +Native ORC MR 1047 1107 85 15.0 66.5 1.3X +Native ORC Vectorized 273 291 20 57.7 17.3 5.0X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1006 1066 84 15.6 64.0 1.0X -Native ORC Vectorized 342 353 12 46.0 21.7 2.9X -Hive built-in ORC 1361 1386 36 11.6 86.5 0.7X +Hive built-in ORC 1381 1425 62 11.4 87.8 1.0X +Native ORC MR 1136 1138 4 13.9 72.2 1.2X +Native ORC Vectorized 336 377 31 46.8 21.4 4.1X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1020 1026 8 15.4 64.8 1.0X -Native ORC Vectorized 352 381 23 44.7 22.4 2.9X -Hive built-in ORC 1457 1457 0 10.8 92.7 0.7X +Hive built-in ORC 1425 1425 1 11.0 90.6 1.0X +Native ORC MR 1090 1093 4 14.4 69.3 1.3X +Native ORC Vectorized 349 381 47 45.1 22.2 4.1X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1036 1056 28 15.2 65.9 1.0X -Native ORC Vectorized 387 403 15 40.6 24.6 2.7X -Hive built-in ORC 1409 1417 11 11.2 89.6 0.7X +Hive built-in ORC 1434 1477 61 11.0 91.2 1.0X +Native ORC MR 1116 1125 12 14.1 71.0 1.3X +Native ORC Vectorized 366 388 18 43.0 23.2 3.9X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Int and String Scan: Best Time(ms) Avg Time(ms) 
Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1993 2094 144 5.3 190.0 1.0X -Native ORC Vectorized 1290 1348 83 8.1 123.0 1.5X -Hive built-in ORC 2336 2426 127 4.5 222.8 0.9X +Hive built-in ORC 2442 2543 143 4.3 232.8 1.0X +Native ORC MR 2030 2048 25 5.2 193.6 1.2X +Native ORC Vectorized 1261 1266 8 8.3 120.2 1.9X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Data column - Native ORC MR 1369 1384 22 11.5 87.0 1.0X -Data column - Native ORC Vectorized 406 428 20 38.7 25.8 3.4X -Data column - Hive built-in ORC 1444 1527 118 10.9 91.8 0.9X -Partition column - Native ORC MR 745 796 45 21.1 47.4 1.8X -Partition column - Native ORC Vectorized 70 96 28 223.2 4.5 19.4X -Partition column - Hive built-in ORC 1035 1063 39 15.2 65.8 1.3X -Both columns - Native ORC MR 1245 1306 86 12.6 79.2 1.1X -Both columns - Native ORC Vectorized 385 424 35 40.9 24.5 3.6X -Both columns - Hive built-in ORC 1481 1566 120 10.6 94.2 0.9X +Data column - Hive built-in ORC 1615 1617 3 9.7 102.7 1.0X +Data column - Native ORC MR 1330 1373 61 11.8 84.6 1.2X +Data column - Native ORC Vectorized 343 404 83 45.8 21.8 4.7X +Partition column - Hive built-in ORC 1087 1099 18 14.5 69.1 1.5X +Partition column - Native ORC MR 912 922 12 17.2 58.0 1.8X +Partition column - Native ORC Vectorized 67 94 33 234.6 4.3 24.1X +Both columns - Hive built-in ORC 1743 1748 7 9.0 110.8 0.9X +Both columns - Native ORC MR 1454 1459 6 10.8 92.5 1.1X +Both columns - Native ORC Vectorized 354 414 57 44.4 22.5 4.6X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1102 1261 224 9.5 105.1 1.0X -Native ORC Vectorized 216 260 55 48.5 20.6 5.1X -Hive built-in ORC 1299 1427 181 8.1 123.9 0.8X +Hive built-in ORC 1331 1342 16 7.9 126.9 1.0X +Native ORC MR 901 910 12 11.6 85.9 1.5X +Native ORC Vectorized 228 291 72 45.9 21.8 5.8X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1632 1653 30 6.4 155.6 1.0X -Native ORC Vectorized 689 698 8 15.2 65.7 2.4X -Hive built-in ORC 2224 2254 43 4.7 212.1 0.7X +Hive built-in ORC 2295 2298 4 4.6 218.9 1.0X +Native ORC MR 1711 1743 46 6.1 163.1 1.3X +Native ORC Vectorized 686 692 8 15.3 65.4 3.3X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1516 1555 54 6.9 144.6 1.0X -Native ORC Vectorized 782 801 19 13.4 74.6 1.9X -Hive built-in ORC 2023 2110 123 5.2 192.9 0.7X +Hive built-in ORC 2045 2107 88 5.1 195.0 1.0X +Native ORC MR 1577 1585 11 6.6 150.4 1.3X +Native ORC Vectorized 801 804 5 13.1 76.4 2.6X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 879 931 48 11.9 83.8 1.0X -Native ORC Vectorized 250 342 85 42.0 23.8 3.5X -Hive built-in ORC 1204 1219 20 8.7 114.9 0.7X +Hive built-in ORC 1254 1261 10 8.4 119.6 1.0X +Native ORC MR 944 962 15 11.1 90.1 1.3X +Native ORC Vectorized 262 334 103 40.1 25.0 4.8X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 159 192 24 6.6 151.4 1.0X -Native ORC Vectorized 85 116 32 12.3 81.0 1.9X -Hive built-in ORC 790 853 99 1.3 753.9 0.2X +Hive built-in ORC 954 1002 68 1.1 909.8 1.0X +Native ORC MR 149 188 30 7.0 141.9 6.4X +Native ORC Vectorized 83 108 30 12.7 78.7 11.6X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 161 196 40 6.5 153.9 1.0X -Native ORC Vectorized 110 139 28 9.6 104.6 1.5X -Hive built-in ORC 1549 1585 51 0.7 1476.8 0.1X +Hive built-in ORC 1939 1994 78 0.5 1848.9 1.0X +Native ORC MR 187 259 57 5.6 178.2 10.4X +Native ORC Vectorized 117 193 46 9.0 111.2 16.6X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 300 
columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 201 221 14 5.2 191.8 1.0X -Native ORC Vectorized 135 163 23 7.8 128.6 1.5X -Hive built-in ORC 2166 2172 8 0.5 2065.6 0.1X +Hive built-in ORC 2759 2827 96 0.4 2631.6 1.0X +Native ORC MR 328 368 50 3.2 312.5 8.4X +Native ORC Vectorized 149 210 68 7.0 141.9 18.5X ================================================================================================ Struct scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 10 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 473 522 41 2.2 451.4 1.0X -Native ORC Vectorized 234 351 58 4.5 222.9 2.0X -Hive built-in ORC 472 601 116 2.2 449.8 1.0X +Hive built-in ORC 681 696 17 1.5 649.0 1.0X +Native ORC MR 484 497 9 2.2 461.7 1.4X +Native ORC Vectorized 303 371 59 3.5 289.3 2.2X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 100 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Native ORC MR 3238 3394 221 0.3 3087.5 1.0X -Native ORC Vectorized 2724 2844 169 0.4 2598.2 1.2X -Hive built-in ORC 3898 3934 52 0.3 3717.0 0.8X +Hive built-in ORC 3762 4091 465 0.3 3588.1 1.0X +Native ORC MR 3503 3577 104 0.3 3340.7 1.1X +Native ORC Vectorized 2296 2415 168 0.5 2189.9 1.6X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 300 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Native ORC MR 10723 10890 236 0.1 10226.4 1.0X -Native ORC Vectorized 9966 10091 177 0.1 9503.9 1.1X -Hive built-in ORC 12360 12482 172 0.1 11787.4 0.9X +Hive built-in ORC 11058 11109 72 0.1 10545.5 1.0X +Native ORC MR 11323 11354 44 0.1 10798.4 1.0X +Native ORC Vectorized 11246 11315 97 0.1 10725.2 1.0X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 600 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Native ORC MR 24875 25382 717 0.0 23722.6 1.0X -Native ORC Vectorized 22763 22830 95 0.0 21708.5 1.1X -Hive built-in ORC 27783 28079 419 0.0 26496.0 0.9X +Hive built-in ORC 25265 29571 441 0.0 24094.4 1.0X +Native ORC MR 26980 27178 280 0.0 25730.4 0.9X +Native ORC Vectorized 26603 26976 527 0.0 25370.3 0.9X 
================================================================================================ Nested Struct scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Nested Struct Scan with 10 Elements, 10 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 4175 4184 12 0.3 3982.0 1.0X -Native ORC Vectorized 1476 1483 9 0.7 1407.9 2.8X -Hive built-in ORC 4128 4150 31 0.3 3936.6 1.0X +Hive built-in ORC 4354 4453 140 0.2 4152.1 1.0X +Native ORC MR 3674 4025 497 0.3 3503.4 1.2X +Native ORC Vectorized 1000 1014 21 1.0 953.4 4.4X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Nested Struct Scan with 30 Elements, 10 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 9819 9945 178 0.1 9364.0 1.0X -Native ORC Vectorized 3771 3809 54 0.3 3596.0 2.6X -Hive built-in ORC 11067 11090 32 0.1 10554.8 0.9X +Hive built-in ORC 11727 11762 50 0.1 11183.8 1.0X +Native ORC MR 8861 8862 1 0.1 8450.8 1.3X +Native ORC Vectorized 2441 2497 79 0.4 2327.9 4.8X -OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Nested Struct Scan with 10 Elements, 30 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 10779 10781 3 0.1 10279.7 1.0X -Native ORC Vectorized 7162 7392 325 0.1 6830.7 1.5X -Hive built-in ORC 8417 8553 192 0.1 8027.5 1.3X +Hive built-in ORC 9604 9616 17 0.1 9159.4 1.0X +Native ORC MR 9501 9535 47 0.1 9061.0 1.0X +Native ORC Vectorized 4418 4582 232 0.2 4213.6 2.2X diff --git a/sql/hive/benchmarks/OrcReadBenchmark-jdk17-results.txt b/sql/hive/benchmarks/OrcReadBenchmark-jdk17-results.txt index 836b563063fa7..b24cef4ef4953 100644 --- a/sql/hive/benchmarks/OrcReadBenchmark-jdk17-results.txt +++ b/sql/hive/benchmarks/OrcReadBenchmark-jdk17-results.txt @@ -2,221 +2,221 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 803 838 38 19.6 51.1 1.0X -Native ORC Vectorized 147 173 21 107.1 9.3 5.5X -Hive built-in ORC 1098 1115 23 14.3 69.8 0.7X +Hive built-in ORC 933 962 48 16.9 59.3 1.0X +Native ORC MR 864 910 76 18.2 54.9 1.1X +Native ORC Vectorized 144 172 22 108.9 9.2 6.5X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit 
Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 856 927 81 18.4 54.4 1.0X -Native ORC Vectorized 136 161 15 115.3 8.7 6.3X -Hive built-in ORC 1188 1328 198 13.2 75.5 0.7X +Hive built-in ORC 1203 1301 139 13.1 76.5 1.0X +Native ORC MR 848 875 27 18.5 53.9 1.4X +Native ORC Vectorized 117 139 17 134.3 7.4 10.3X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 813 875 105 19.3 51.7 1.0X -Native ORC Vectorized 138 158 15 113.9 8.8 5.9X -Hive built-in ORC 1158 1158 0 13.6 73.6 0.7X +Hive built-in ORC 1252 1257 6 12.6 79.6 1.0X +Native ORC MR 873 939 92 18.0 55.5 1.4X +Native ORC Vectorized 127 146 17 124.0 8.1 9.9X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 839 844 7 18.8 53.3 1.0X -Native ORC Vectorized 180 207 30 87.4 11.4 4.7X -Hive built-in ORC 1358 1394 52 11.6 86.3 0.6X +Hive built-in ORC 1286 1299 19 12.2 81.8 1.0X +Native ORC MR 948 966 17 16.6 60.3 1.4X +Native ORC Vectorized 171 203 24 91.9 10.9 7.5X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 906 968 58 17.4 57.6 1.0X -Native ORC Vectorized 237 292 56 66.3 15.1 3.8X -Hive built-in ORC 1395 1416 30 11.3 88.7 0.6X +Hive built-in ORC 1234 1243 13 12.7 78.4 1.0X +Native ORC MR 1019 1048 41 15.4 64.8 1.2X +Native ORC Vectorized 219 235 15 71.8 13.9 5.6X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1041 1060 27 15.1 66.2 1.0X -Native ORC Vectorized 265 320 44 59.4 16.8 3.9X -Hive built-in ORC 1339 1374 49 11.7 85.2 0.8X +Hive built-in ORC 1304 1309 6 12.1 82.9 1.0X +Native ORC MR 1007 1022 22 15.6 64.0 1.3X +Native ORC Vectorized 253 274 16 62.2 16.1 5.2X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on 
Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 2091 2136 63 5.0 199.5 1.0X -Native ORC Vectorized 1253 1260 10 8.4 119.5 1.7X -Hive built-in ORC 2384 2391 9 4.4 227.4 0.9X +Hive built-in ORC 2178 2250 102 4.8 207.7 1.0X +Native ORC MR 1816 1821 7 5.8 173.2 1.2X +Native ORC Vectorized 1003 1025 31 10.5 95.6 2.2X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Data column - Native ORC MR 1549 1631 116 10.2 98.5 1.0X -Data column - Native ORC Vectorized 295 346 45 53.3 18.8 5.3X -Data column - Hive built-in ORC 1851 1896 64 8.5 117.7 0.8X -Partition column - Native ORC MR 850 868 19 18.5 54.1 1.8X -Partition column - Native ORC Vectorized 54 67 9 288.7 3.5 28.4X -Partition column - Hive built-in ORC 1131 1174 60 13.9 71.9 1.4X -Both columns - Native ORC MR 1069 1077 10 14.7 68.0 1.4X -Both columns - Native ORC Vectorized 208 226 18 75.6 13.2 7.4X -Both columns - Hive built-in ORC 1811 1812 1 8.7 115.2 0.9X +Data column - Hive built-in ORC 1442 1449 9 10.9 91.7 1.0X +Data column - Native ORC MR 1171 1186 20 13.4 74.5 1.2X +Data column - Native ORC Vectorized 179 197 20 87.8 11.4 8.1X +Partition column - Hive built-in ORC 1022 1045 32 15.4 65.0 1.4X +Partition column - Native ORC MR 848 887 43 18.5 53.9 1.7X +Partition column - Native ORC Vectorized 54 64 8 293.9 3.4 26.9X +Both columns - Hive built-in ORC 1513 1548 50 10.4 96.2 1.0X +Both columns - Native ORC MR 1189 1204 21 13.2 75.6 1.2X +Both columns - Native ORC Vectorized 197 225 24 79.7 12.6 7.3X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 825 830 5 12.7 78.6 1.0X -Native ORC Vectorized 199 207 10 52.8 18.9 4.2X -Hive built-in ORC 1206 1210 6 8.7 115.0 0.7X +Hive built-in ORC 1259 1271 17 8.3 120.1 1.0X +Native ORC MR 842 864 21 12.5 80.3 1.5X +Native ORC Vectorized 187 199 13 56.2 17.8 6.7X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 
5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1542 1572 42 6.8 147.1 1.0X -Native ORC Vectorized 523 582 66 20.1 49.8 3.0X -Hive built-in ORC 2190 2190 0 4.8 208.9 0.7X +Hive built-in ORC 2140 2155 21 4.9 204.1 1.0X +Native ORC MR 1559 1563 6 6.7 148.7 1.4X +Native ORC Vectorized 512 535 34 20.5 48.9 4.2X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1490 1499 13 7.0 142.1 1.0X -Native ORC Vectorized 630 695 97 16.7 60.1 2.4X -Hive built-in ORC 2112 2121 13 5.0 201.4 0.7X +Hive built-in ORC 1880 1920 56 5.6 179.3 1.0X +Native ORC MR 1467 1484 24 7.1 139.9 1.3X +Native ORC Vectorized 608 624 11 17.2 58.0 3.1X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 815 830 23 12.9 77.7 1.0X -Native ORC Vectorized 225 249 26 46.5 21.5 3.6X -Hive built-in ORC 1247 1259 16 8.4 119.0 0.7X +Hive built-in ORC 1195 1209 20 8.8 113.9 1.0X +Native ORC MR 857 895 34 12.2 81.7 1.4X +Native ORC Vectorized 218 233 15 48.1 20.8 5.5X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 141 173 19 7.5 134.0 1.0X -Native ORC Vectorized 77 91 9 13.7 73.2 1.8X -Hive built-in ORC 758 776 16 1.4 722.9 0.2X +Hive built-in ORC 884 924 43 1.2 842.7 1.0X +Native ORC MR 122 145 18 8.6 116.7 7.2X +Native ORC Vectorized 67 82 13 15.7 63.9 13.2X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 190 232 29 5.5 181.4 1.0X -Native ORC Vectorized 118 149 41 8.9 112.7 1.6X -Hive built-in ORC 1537 1558 30 0.7 1465.7 0.1X +Hive built-in ORC 1473 1520 67 0.7 1404.6 1.0X +Native ORC MR 161 177 16 6.5 153.4 9.2X +Native ORC Vectorized 107 126 14 9.8 102.0 13.8X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 
5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 300 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 237 268 28 4.4 226.0 1.0X -Native ORC Vectorized 165 188 17 6.4 157.2 1.4X -Hive built-in ORC 2103 2171 96 0.5 2005.3 0.1X +Hive built-in ORC 1988 2050 87 0.5 1896.3 1.0X +Native ORC MR 210 237 27 5.0 199.9 9.5X +Native ORC Vectorized 149 166 16 7.0 142.0 13.4X ================================================================================================ Struct scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 10 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 278 294 12 3.8 265.5 1.0X -Native ORC Vectorized 213 246 41 4.9 202.9 1.3X -Hive built-in ORC 536 586 40 2.0 511.0 0.5X +Hive built-in ORC 477 498 14 2.2 454.9 1.0X +Native ORC MR 323 329 5 3.2 307.7 1.5X +Native ORC Vectorized 169 206 49 6.2 161.6 2.8X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 100 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Native ORC MR 2235 2244 13 0.5 2131.8 1.0X -Native ORC Vectorized 3154 3159 7 0.3 3007.6 0.7X -Hive built-in ORC 3740 4089 493 0.3 3567.0 0.6X +Hive built-in ORC 3006 3007 1 0.3 2867.0 1.0X +Native ORC MR 2469 2707 337 0.4 2354.2 1.2X +Native ORC Vectorized 1407 1422 22 0.7 1341.4 2.1X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 300 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Native ORC MR 7350 8577 1735 0.1 7009.2 1.0X -Native ORC Vectorized 7161 8481 1867 0.1 6829.0 1.0X -Hive built-in ORC 10307 10909 851 0.1 9829.6 0.7X +Hive built-in ORC 8820 8867 67 0.1 8411.4 1.0X +Native ORC MR 7301 7422 171 0.1 6962.8 1.2X +Native ORC Vectorized 7286 7300 20 0.1 6948.6 1.2X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 600 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Native ORC MR 15931 18238 NaN 0.1 15192.6 1.0X -Native ORC Vectorized 15192 16500 1851 0.1 14487.9 1.0X -Hive built-in ORC 29853 30027 247 0.0 28469.9 0.5X +Hive built-in ORC 24634 27218 NaN 0.0 23492.4 
1.0X +Native ORC MR 19304 19441 195 0.1 18409.3 1.3X +Native ORC Vectorized 19081 19091 14 0.1 18197.3 1.3X ================================================================================================ Nested Struct scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Nested Struct Scan with 10 Elements, 10 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 3399 3463 90 0.3 3241.5 1.0X -Native ORC Vectorized 1513 1630 166 0.7 1442.7 2.2X -Hive built-in ORC 3953 3960 10 0.3 3770.0 0.9X +Hive built-in ORC 4044 4112 96 0.3 3857.0 1.0X +Native ORC MR 4086 4092 8 0.3 3897.0 1.0X +Native ORC Vectorized 977 1007 43 1.1 931.5 4.1X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Nested Struct Scan with 30 Elements, 10 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 7667 7684 24 0.1 7311.9 1.0X -Native ORC Vectorized 3865 3881 22 0.3 3685.8 2.0X -Hive built-in ORC 11223 11246 32 0.1 10703.5 0.7X +Hive built-in ORC 10733 10785 73 0.1 10236.0 1.0X +Native ORC MR 7707 7707 0 0.1 7349.8 1.4X +Native ORC Vectorized 2260 2318 82 0.5 2155.3 4.7X -OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.1+12-LTS on Linux 5.11.0-1025-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Nested Struct Scan with 10 Elements, 30 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 9506 9633 181 0.1 9065.4 1.0X -Native ORC Vectorized 4170 4320 212 0.3 3976.4 2.3X -Hive built-in ORC 12756 13821 1506 0.1 12164.7 0.7X +Hive built-in ORC 7851 8136 403 0.1 7487.6 1.0X +Native ORC MR 9074 9180 150 0.1 8653.9 0.9X +Native ORC Vectorized 2485 2588 146 0.4 2369.7 3.2X diff --git a/sql/hive/benchmarks/OrcReadBenchmark-results.txt b/sql/hive/benchmarks/OrcReadBenchmark-results.txt index a08c34968c87b..137bfcc148927 100644 --- a/sql/hive/benchmarks/OrcReadBenchmark-results.txt +++ b/sql/hive/benchmarks/OrcReadBenchmark-results.txt @@ -2,221 +2,221 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1016 1068 74 15.5 64.6 1.0X -Native ORC Vectorized 220 252 33 71.4 14.0 4.6X -Hive built-in ORC 1274 1290 22 12.3 81.0 0.8X +Hive built-in ORC 1138 1191 76 13.8 72.3 1.0X +Native ORC MR 999 1115 164 15.7 
63.5 1.1X +Native ORC Vectorized 155 183 23 101.7 9.8 7.4X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1117 1142 36 14.1 71.0 1.0X -Native ORC Vectorized 157 189 20 100.4 10.0 7.1X -Hive built-in ORC 1369 1399 42 11.5 87.1 0.8X +Hive built-in ORC 1034 1056 30 15.2 65.8 1.0X +Native ORC MR 859 878 19 18.3 54.6 1.2X +Native ORC Vectorized 130 155 22 121.1 8.3 8.0X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1064 1189 177 14.8 67.6 1.0X -Native ORC Vectorized 179 204 25 87.9 11.4 5.9X -Hive built-in ORC 1454 1468 20 10.8 92.4 0.7X +Hive built-in ORC 1056 1081 35 14.9 67.1 1.0X +Native ORC MR 946 1015 96 16.6 60.2 1.1X +Native ORC Vectorized 152 173 25 103.5 9.7 6.9X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1070 1196 177 14.7 68.1 1.0X -Native ORC Vectorized 216 232 14 72.8 13.7 5.0X -Hive built-in ORC 1484 1533 69 10.6 94.4 0.7X +Hive built-in ORC 1619 1776 222 9.7 103.0 1.0X +Native ORC MR 913 1015 145 17.2 58.0 1.8X +Native ORC Vectorized 187 207 19 84.3 11.9 8.7X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1164 1181 24 13.5 74.0 1.0X -Native ORC Vectorized 264 290 24 59.6 16.8 4.4X -Hive built-in ORC 1536 1572 51 10.2 97.7 0.8X +Hive built-in ORC 1117 1138 30 14.1 71.0 1.0X +Native ORC MR 909 921 20 17.3 57.8 1.2X +Native ORC Vectorized 202 224 36 78.0 12.8 5.5X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1127 1174 67 14.0 71.7 1.0X -Native ORC Vectorized 285 302 17 55.2 18.1 4.0X -Hive built-in ORC 1571 1582 16 10.0 99.9 0.7X 
+Hive built-in ORC 1123 1124 2 14.0 71.4 1.0X +Native ORC MR 933 951 22 16.9 59.3 1.2X +Native ORC Vectorized 231 247 34 68.1 14.7 4.9X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 2329 2413 119 4.5 222.1 1.0X -Native ORC Vectorized 1274 1282 12 8.2 121.5 1.8X -Hive built-in ORC 2622 2692 99 4.0 250.0 0.9X +Hive built-in ORC 2149 2163 21 4.9 204.9 1.0X +Native ORC MR 1844 1863 27 5.7 175.9 1.2X +Native ORC Vectorized 1059 1071 18 9.9 101.0 2.0X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Data column - Native ORC MR 1304 1309 8 12.1 82.9 1.0X -Data column - Native ORC Vectorized 221 259 25 71.1 14.1 5.9X -Data column - Hive built-in ORC 1586 1606 28 9.9 100.8 0.8X -Partition column - Native ORC MR 868 889 29 18.1 55.2 1.5X -Partition column - Native ORC Vectorized 71 85 18 222.3 4.5 18.4X -Partition column - Hive built-in ORC 1210 1241 43 13.0 77.0 1.1X -Both columns - Native ORC MR 1397 1435 54 11.3 88.8 0.9X -Both columns - Native ORC Vectorized 236 257 22 66.5 15.0 5.5X -Both columns - Hive built-in ORC 1723 1726 4 9.1 109.6 0.8X +Data column - Hive built-in ORC 1218 1220 3 12.9 77.4 1.0X +Data column - Native ORC MR 1110 1113 4 14.2 70.6 1.1X +Data column - Native ORC Vectorized 185 205 19 85.1 11.7 6.6X +Partition column - Hive built-in ORC 884 897 18 17.8 56.2 1.4X +Partition column - Native ORC MR 701 745 71 22.4 44.6 1.7X +Partition column - Native ORC Vectorized 56 65 6 281.7 3.5 21.8X +Both columns - Hive built-in ORC 1206 1225 26 13.0 76.7 1.0X +Both columns - Native ORC MR 1103 1164 86 14.3 70.1 1.1X +Both columns - Native ORC Vectorized 201 240 47 78.4 12.8 6.1X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1074 1089 21 9.8 102.4 1.0X -Native ORC Vectorized 221 
254 33 47.5 21.0 4.9X -Hive built-in ORC 1435 1437 2 7.3 136.9 0.7X +Hive built-in ORC 1124 1136 17 9.3 107.2 1.0X +Native ORC MR 854 867 17 12.3 81.5 1.3X +Native ORC Vectorized 173 179 6 60.5 16.5 6.5X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1948 1964 21 5.4 185.8 1.0X -Native ORC Vectorized 666 687 31 15.7 63.5 2.9X -Hive built-in ORC 2454 2489 50 4.3 234.0 0.8X +Hive built-in ORC 1985 1985 0 5.3 189.3 1.0X +Native ORC MR 1557 1561 5 6.7 148.5 1.3X +Native ORC Vectorized 470 486 22 22.3 44.8 4.2X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 1744 1756 16 6.0 166.4 1.0X -Native ORC Vectorized 707 736 38 14.8 67.4 2.5X -Hive built-in ORC 2225 2259 48 4.7 212.2 0.8X +Hive built-in ORC 1857 1891 49 5.6 177.1 1.0X +Native ORC MR 1508 1518 14 7.0 143.8 1.2X +Native ORC Vectorized 646 660 11 16.2 61.6 2.9X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 996 1101 149 10.5 95.0 1.0X -Native ORC Vectorized 282 311 18 37.1 26.9 3.5X -Hive built-in ORC 1405 1420 20 7.5 134.0 0.7X +Hive built-in ORC 1066 1084 25 9.8 101.7 1.0X +Native ORC MR 834 851 14 12.6 79.6 1.3X +Native ORC Vectorized 242 269 36 43.3 23.1 4.4X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 153 180 17 6.8 146.2 1.0X -Native ORC Vectorized 85 99 18 12.3 81.4 1.8X -Hive built-in ORC 912 971 97 1.2 869.4 0.2X +Hive built-in ORC 912 1006 133 1.2 869.3 1.0X +Native ORC MR 125 144 19 8.4 119.4 7.3X +Native ORC Vectorized 74 83 14 14.2 70.3 12.4X 
-OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 254 272 15 4.1 242.5 1.0X -Native ORC Vectorized 122 138 15 8.6 116.6 2.1X -Hive built-in ORC 1772 1819 67 0.6 1689.5 0.1X +Hive built-in ORC 1502 1531 40 0.7 1432.7 1.0X +Native ORC MR 160 174 17 6.6 152.3 9.4X +Native ORC Vectorized 110 125 20 9.5 105.3 13.6X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 300 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 233 271 31 4.5 222.5 1.0X -Native ORC Vectorized 162 184 25 6.5 154.8 1.4X -Hive built-in ORC 2591 2602 16 0.4 2470.6 0.1X +Hive built-in ORC 2184 2191 9 0.5 2082.9 1.0X +Native ORC MR 215 233 19 4.9 204.6 10.2X +Native ORC Vectorized 160 172 18 6.5 152.7 13.6X ================================================================================================ Struct scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 10 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 369 415 54 2.8 351.7 1.0X -Native ORC Vectorized 201 214 9 5.2 191.3 1.8X -Hive built-in ORC 712 719 6 1.5 679.0 0.5X +Hive built-in ORC 513 558 70 2.0 489.3 1.0X +Native ORC MR 316 327 11 3.3 301.6 1.6X +Native ORC Vectorized 171 189 28 6.1 163.3 3.0X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 100 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Native ORC MR 2764 2834 99 0.4 2636.2 1.0X -Native ORC Vectorized 1651 1669 26 0.6 1574.2 1.7X -Hive built-in ORC 3957 3998 58 0.3 3774.0 0.7X +Hive built-in ORC 3081 3260 254 0.3 2938.2 1.0X +Native ORC MR 2552 2627 105 0.4 2434.1 1.2X +Native ORC Vectorized 1473 1610 193 0.7 1404.8 2.1X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 300 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------- -Native ORC MR 9368 11693 NaN 0.1 8934.4 1.0X -Native ORC Vectorized 9324 9737 585 0.1 8891.6 1.0X -Hive built-in ORC 13303 13665 512 0.1 12687.2 0.7X +Hive built-in ORC 9531 10232 991 0.1 9089.8 1.0X +Native ORC MR 9412 9496 119 0.1 8975.6 1.0X +Native ORC Vectorized 9434 9483 69 0.1 8997.0 1.0X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Struct Column Scan with 600 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Native ORC MR 32403 35146 NaN 0.0 30902.3 1.0X -Native ORC Vectorized 38268 39336 1511 0.0 36495.2 0.8X -Hive built-in ORC 47590 48669 1525 0.0 45385.7 0.7X +Hive built-in ORC 34314 35490 1663 0.0 32724.4 1.0X +Native ORC MR 36051 36191 197 0.0 34381.3 1.0X +Native ORC Vectorized 36014 37273 1780 0.0 34346.1 1.0X ================================================================================================ Nested Struct scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Nested Struct Scan with 10 Elements, 10 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 5127 5720 838 0.2 4889.8 1.0X -Native ORC Vectorized 1064 1067 4 1.0 1014.8 4.8X -Hive built-in ORC 4622 4647 36 0.2 4407.6 1.1X +Hive built-in ORC 3492 3768 390 0.3 3330.1 1.0X +Native ORC MR 3918 3932 20 0.3 3736.1 0.9X +Native ORC Vectorized 893 911 17 1.2 851.7 3.9X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Nested Struct Scan with 30 Elements, 10 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 11342 11343 2 0.1 10816.3 1.0X -Native ORC Vectorized 2889 2891 4 0.4 2755.1 3.9X -Hive built-in ORC 12754 12890 192 0.1 12163.6 0.9X +Hive built-in ORC 9499 10127 888 0.1 9058.7 1.0X +Native ORC MR 9227 9234 9 0.1 8799.9 1.0X +Native ORC Vectorized 2326 2389 89 0.5 2218.2 4.1X -OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1022-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1025-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Nested Struct Scan with 10 Elements, 30 Fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------ -Native ORC MR 12483 12602 167 0.1 11905.1 1.0X -Native ORC Vectorized 3522 3615 132 0.3 3358.5 3.5X -Hive built-in ORC 9775 9784 12 0.1 9322.4 
1.3X +Hive built-in ORC 8315 8552 335 0.1 7929.5 1.0X +Native ORC MR 11559 12147 832 0.1 11023.1 0.7X +Native ORC Vectorized 2808 2965 222 0.4 2678.2 3.0X diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 37efb2d1ba49e..bd323dc4b24e1 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone + private val originalAnsiMode = TestHive.conf.getConf(SQLConf.ANSI_ENABLED) private val originalCreateHiveTable = TestHive.conf.getConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT) @@ -56,6 +57,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) + // Hive doesn't follow ANSI Standard. + TestHive.setConf(SQLConf.ANSI_ENABLED, false) // Ensures that the table insertion behavior is consistent with Hive TestHive.setConf(SQLConf.STORE_ASSIGNMENT_POLICY, StoreAssignmentPolicy.LEGACY.toString) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests @@ -72,6 +75,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) + TestHive.setConf(SQLConf.ANSI_ENABLED, originalAnsiMode) TestHive.setConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT, originalCreateHiveTable) // For debugging dump some statistics about how much time was spent in various optimizer rules diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider index 2b0acc0305c49..eb7862b407c61 100644 --- a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider @@ -1 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.hive.security.HiveDelegationTokenProvider diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index e7d762fbebe76..bb06156b63339 100644 --- a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,2 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.apache.spark.sql.hive.orc.OrcFileFormat org.apache.spark.sql.hive.execution.HiveFileFormat \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 24e60529d227b..fefa032d35105 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -36,18 +36,17 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} -import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SourceOptions} import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ -import org.apache.spark.sql.types.{AnsiIntervalType, ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.types.{AnsiIntervalType, ArrayType, DataType, MapType, StructType, TimestampNTZType} /** * A persistent implementation of the system catalog using Hive. @@ -94,22 +93,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Run some code involving `client` in a [[synchronized]] block and wrap non-fatal + * Run some code involving `client` in a [[synchronized]] block and wrap certain * exceptions thrown in the process in [[AnalysisException]]. 
*/ - private def withClient[T](body: => T): T = withClientWrappingException { - body - } { - _ => None // Will fallback to default wrapping strategy in withClientWrappingException. - } - - /** - * Run some code involving `client` in a [[synchronized]] block and wrap non-fatal - * exceptions thrown in the process in [[AnalysisException]] using the given - * `wrapException` function. - */ - private def withClientWrappingException[T](body: => T) - (wrapException: Throwable => Option[AnalysisException]): T = synchronized { + private def withClient[T](body: => T): T = synchronized { try { body } catch { @@ -120,11 +107,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat case i: InvocationTargetException => i.getCause case o => o } - wrapException(e) match { - case Some(wrapped) => throw wrapped - case None => throw new AnalysisException( - e.getClass.getCanonicalName + ": " + e.getMessage, cause = Some(e)) - } + throw new AnalysisException( + e.getClass.getCanonicalName + ": " + e.getMessage, cause = Some(e)) } } @@ -204,32 +188,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def createDatabase( dbDefinition: CatalogDatabase, - ignoreIfExists: Boolean): Unit = withClientWrappingException { + ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) - } { exception => - if (exception.getClass.getName.equals( - "org.apache.hadoop.hive.metastore.api.AlreadyExistsException") - && exception.getMessage.contains( - s"Database ${dbDefinition.name} already exists")) { - Some(new DatabaseAlreadyExistsException(dbDefinition.name)) - } else { - None - } } override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { - try { - client.dropDatabase(db, ignoreIfNotExists, cascade) - } catch { - case NonFatal(exception) => - if (exception.getClass.getName.equals("org.apache.hadoop.hive.ql.metadata.HiveException") - && exception.getMessage.contains(s"Database $db is not empty.")) { - throw QueryCompilationErrors.cannotDropNonemptyDatabaseError(db) - } else throw exception - } + client.dropDatabase(db, ignoreIfNotExists, cascade) } /** @@ -801,8 +768,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val version: String = table.properties.getOrElse(CREATED_SPARK_VERSION, "2.2 or prior") // Restore Spark's statistics from information in Metastore. - val restoredStats = - statsFromProperties(table.properties, table.identifier.table, table.schema) + val restoredStats = statsFromProperties(table.properties, table.identifier.table) if (restoredStats.isDefined) { table = table.copy(stats = restoredStats) } @@ -1170,8 +1136,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat private def statsFromProperties( properties: Map[String, String], - table: String, - schema: StructType): Option[CatalogStatistics] = { + table: String): Option[CatalogStatistics] = { val statsProps = properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) if (statsProps.isEmpty) { @@ -1241,8 +1206,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Restore Spark's statistics from information in Metastore. // Note: partition-level statistics were introduced in 2.3. 
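The `withClient` helper simplified above keeps a single wrapping strategy: run the body under the catalog lock, unwrap reflective `InvocationTargetException`s, and re-throw anything non-fatal with the original exception's class name in the message. A minimal standalone sketch of that pattern, with a hypothetical exception type standing in for `AnalysisException`:

import java.lang.reflect.InvocationTargetException

import scala.util.control.NonFatal

object ClientCallWrapperSketch {
  // Stand-in for AnalysisException; the real helper wraps into that type.
  final class WrappedClientException(message: String, cause: Throwable)
    extends RuntimeException(message, cause)

  def withClient[T](body: => T): T = synchronized {
    try {
      body
    } catch {
      case NonFatal(e) =>
        // Reflective Hive calls surface failures as InvocationTargetException; report the cause.
        val cause = e match {
          case i: InvocationTargetException => i.getCause
          case other => other
        }
        throw new WrappedClientException(
          cause.getClass.getCanonicalName + ": " + cause.getMessage, cause)
    }
  }
}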
- val restoredStats = - statsFromProperties(partition.parameters, table.identifier.table, table.schema) + val restoredStats = statsFromProperties(partition.parameters, table.identifier.table) if (restoredStats.isDefined) { partition.copy( spec = restoredSpec, @@ -1306,7 +1270,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // treats dot as matching any single character and may return more partitions than we // expected. Here we do an extra filter to drop unexpected partitions. case Some(spec) if spec.exists(_._2.contains(".")) => - res.filter(p => isPartialPartitionSpec(spec, p.spec)) + res.filter(p => isPartialPartitionSpec(spec, toMetaStorePartitionSpec(p.spec))) case _ => res } } @@ -1461,6 +1425,7 @@ object HiveExternalCatalog { private[spark] def isHiveCompatibleDataType(dt: DataType): Boolean = dt match { case _: AnsiIntervalType => false + case _: TimestampNTZType => false case s: StructType => s.forall(f => isHiveCompatibleDataType(f.dataType)) case a: ArrayType => isHiveCompatibleDataType(a.elementType) case m: MapType => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index c905a52c4836b..12b570e818650 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._ +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.types._ /** @@ -156,6 +157,32 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } + def convertStorageFormat(storage: CatalogStorageFormat): CatalogStorageFormat = { + val serde = storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + + if (serde.contains("parquet")) { + val options = storage.properties + (ParquetOptions.MERGE_SCHEMA -> + SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) + storage.copy( + serde = None, + properties = options + ) + } else { + val options = storage.properties + if (SQLConf.get.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { + storage.copy( + serde = None, + properties = options + ) + } else { + storage.copy( + serde = None, + properties = options + ) + } + } + } + private def convertToLogicalRelation( relation: HiveTableRelation, options: Map[String, String], @@ -222,7 +249,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // Spark SQL's data source table now support static and dynamic partition insert. Source // table converted from Hive table should always use dynamic. 
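The comment above notes that relations converted from Hive tables are always pinned to dynamic partition-overwrite semantics (the change that follows swaps the hard-coded option key for the DataSourceUtils constant). As a hedged end-to-end illustration of what dynamic mode means for a partitioned table (table name, data, and app name are assumptions): only the partitions present in the incoming data are rewritten.

import org.apache.spark.sql.SparkSession

object DynamicOverwriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("dynamic-overwrite-sketch")
      .getOrCreate()
    import spark.implicits._

    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

    // Seed a table partitioned by dt with two partitions.
    Seq((1, "2024-01-01"), (2, "2024-01-02")).toDF("id", "dt")
      .write.partitionBy("dt").mode("overwrite").saveAsTable("dyn_overwrite_demo")

    // With dynamic mode, this overwrite only replaces the dt=2024-01-02 partition;
    // dt=2024-01-01 is left untouched.
    Seq((3, "2024-01-02")).toDF("id", "dt")
      .write.mode("overwrite").insertInto("dyn_overwrite_demo")

    spark.table("dyn_overwrite_demo").show()
    spark.stop()
  }
}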
- val enableDynamicPartition = hiveOptions.updated("partitionOverwriteMode", "dynamic") + val enableDynamicPartition = hiveOptions.updated(DataSourceUtils.PARTITION_OVERWRITE_MODE, + PartitionOverwriteMode.DYNAMIC.toString) val fsRelation = HadoopFsRelation( location = fileIndex, partitionSchema = partitionSchema, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 37970fbe532d4..d1e222794a526 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, ScriptTransformation, Statistics} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} +import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils, InsertIntoDataSourceDirCommand} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSourceStrategy} import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.hive.execution.HiveScriptTransformationExec @@ -184,6 +184,10 @@ object HiveAnalysis extends Rule[LogicalPlan] { * Relation conversion from metastore relations to data source relations for better performance * * - When writing to non-partitioned Hive-serde Parquet/Orc tables + * - When writing to partitioned Hive-serde Parquet/Orc tables when + * `spark.sql.hive.convertInsertingPartitionedTable` is true + * - When writing to directory with Hive-serde + * - When writing to non-partitioned Hive-serde Parquet/ORC tables using CTAS * - When scanning Hive-serde Parquet/ORC tables * * This rule must be run before all other DDL post-hoc resolution rules, i.e. 
@@ -196,11 +200,20 @@ case class RelationConversions( } private def isConvertible(tableMeta: CatalogTable): Boolean = { - val serde = tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + isConvertible(tableMeta.storage) + } + + private def isConvertible(storage: CatalogStorageFormat): Boolean = { + val serde = storage.serde.getOrElse("").toLowerCase(Locale.ROOT) serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } + private def convertProvider(storage: CatalogStorageFormat): String = { + val serde = storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + if (serde.contains("parquet")) "parquet" else "orc" + } + private val metastoreCatalog = sessionCatalog.metastoreCatalog override def apply(plan: LogicalPlan): LogicalPlan = { @@ -228,6 +241,16 @@ case class RelationConversions( DDLUtils.checkTableColumns(tableDesc.copy(schema = query.schema)) OptimizedCreateHiveTableAsSelectCommand( tableDesc, query, query.output.map(_.name), mode) + + // INSERT HIVE DIR + case InsertIntoDir(_, storage, provider, query, overwrite) + if query.resolved && DDLUtils.isHiveTable(provider) && + isConvertible(storage) && conf.getConf(HiveUtils.CONVERT_METASTORE_INSERT_DIR) => + val outputPath = new Path(storage.locationUri.get) + if (overwrite) DDLUtils.verifyNotReadPath(query, outputPath) + + InsertIntoDataSourceDirCommand(metastoreCatalog.convertStorageFormat(storage), + convertProvider(storage), query, overwrite) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 93a38e524ebdc..911cb98588d78 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -160,6 +160,15 @@ private[spark] object HiveUtils extends Logging { .booleanConf .createWithDefault(true) + val CONVERT_METASTORE_INSERT_DIR = buildConf("spark.sql.hive.convertMetastoreInsertDir") + .doc("When set to true, Spark will try to use built-in data source writer " + + "instead of Hive serde in INSERT OVERWRITE DIRECTORY. This flag is effective only if " + + "`spark.sql.hive.convertMetastoreParquet` or `spark.sql.hive.convertMetastoreOrc` is " + + "enabled respectively for Parquet and ORC formats") + .version("3.3.0") + .booleanConf + .createWithDefault(true) + val HIVE_METASTORE_SHARED_PREFIXES = buildStaticConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + "that is shared between Spark SQL and a specific version of Hive. 
An example of classes " + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 9c9a4fd2b3741..3dddca844750d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -49,7 +49,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException, NoSuchPartitionsException, NoSuchTableException, PartitionAlreadyExistsException, PartitionsAlreadyExistException} +import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, NoSuchDatabaseException, NoSuchPartitionException, NoSuchPartitionsException, NoSuchTableException, PartitionAlreadyExistsException, PartitionsAlreadyExistException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Expression @@ -332,14 +332,24 @@ private[hive] class HiveClientImpl( database: CatalogDatabase, ignoreIfExists: Boolean): Unit = withHiveState { val hiveDb = toHiveDatabase(database, Some(userName)) - shim.createDatabase(client, hiveDb, ignoreIfExists) + try { + shim.createDatabase(client, hiveDb, ignoreIfExists) + } catch { + case _: AlreadyExistsException => + throw new DatabaseAlreadyExistsException(database.name) + } } override def dropDatabase( name: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withHiveState { - shim.dropDatabase(client, name, true, ignoreIfNotExists, cascade) + try { + shim.dropDatabase(client, name, true, ignoreIfNotExists, cascade) + } catch { + case e: HiveException if e.getMessage.contains(s"Database $name is not empty") => + throw QueryCompilationErrors.cannotDropNonemptyDatabaseError(name) + } } override def alterDatabase(database: CatalogDatabase): Unit = withHiveState { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index c197b17224c9c..67bb72c187802 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -978,6 +978,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { val inSetThreshold = SQLConf.get.metastorePartitionPruningInSetThreshold object ExtractAttribute { + @scala.annotation.tailrec def unapply(expr: Expression): Option[Attribute] = { expr match { case attr: Attribute => Some(attr) @@ -1144,10 +1145,11 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { // Because there is no way to know whether the partition properties has timeZone, // client-side filtering cannot be used with TimeZoneAwareExpression. 
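A hedged usage sketch for the `spark.sql.hive.convertMetastoreInsertDir` flag documented a few hunks above: with the flag enabled (its default), an INSERT OVERWRITE DIRECTORY targeting a Parquet or ORC Hive SerDe is planned through Spark's built-in writer via the InsertIntoDataSourceDirCommand branch added to RelationConversions earlier in this patch. The output path and app name are assumptions; the HiveShim partition-pruning hunk continues below.

import org.apache.spark.sql.SparkSession

object InsertDirConversionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("insert-dir-conversion")
      .enableHiveSupport()
      .getOrCreate()

    // Already the default after this patch; set explicitly for clarity.
    spark.conf.set("spark.sql.hive.convertMetastoreInsertDir", "true")

    spark.sql(
      """
        |INSERT OVERWRITE DIRECTORY '/tmp/insert_dir_example'
        |STORED AS PARQUET
        |SELECT id, id * 2 AS doubled FROM range(10)
      """.stripMargin)

    spark.stop()
  }
}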
def hasTimeZoneAwareExpression(e: Expression): Boolean = { - e.collectFirst { - case cast: CastBase if cast.needsTimeZone => cast - case tz: TimeZoneAwareExpression if !tz.isInstanceOf[CastBase] => tz - }.isDefined + e.exists { + case cast: CastBase => cast.needsTimeZone + case tz: TimeZoneAwareExpression => !tz.isInstanceOf[CastBase] + case _ => false + } } if (!SQLConf.get.metastorePartitionPruningFastFallback || diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 828f9872eb159..15c172a6e75c2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -69,7 +69,7 @@ private[hive] object IsolatedClientLoader extends Logging { // If the error message contains hadoop, it is probably because the hadoop // version cannot be resolved. val fallbackVersion = if (VersionUtils.isHadoop3) { - "3.3.1" + "3.3.2" } else { "2.7.4" } @@ -316,12 +316,12 @@ private[hive] class IsolatedClientLoader( .asInstanceOf[HiveClient] } catch { case e: InvocationTargetException => - if (e.getCause().isInstanceOf[NoClassDefFoundError]) { - val cnf = e.getCause().asInstanceOf[NoClassDefFoundError] - throw QueryExecutionErrors.loadHiveClientCausesNoClassDefFoundError( - cnf, execJars, HiveUtils.HIVE_METASTORE_JARS.key, e) - } else { - throw e + e.getCause match { + case cnf: NoClassDefFoundError => + throw QueryExecutionErrors.loadHiveClientCausesNoClassDefFoundError( + cnf, execJars, HiveUtils.HIVE_METASTORE_JARS.key, e) + case _ => + throw e } } finally { Thread.currentThread.setContextClassLoader(origLoader) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala index b6b3cac4130a0..7dc1fbb433cd5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorUtils, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{JobConf, Reporter} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -106,6 +107,21 @@ class HiveFileFormat(fileSinkConf: FileSinkDesc) } } } + + override def supportFieldName(name: String): Boolean = { + fileSinkConf.getTableInfo.getOutputFileFormatClassName match { + case "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat" => + !name.matches(".*[ ,;{}()\n\t=].*") + case "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat" => + try { + TypeInfoUtils.getTypeInfoFromTypeString(s"struct<$name:int>") + true + } catch { + case _: IllegalArgumentException => false + } + case _ => true + } + } } class HiveOutputWriter( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 219b1a27f70a2..beb5583d81a60 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -64,7 +64,7 @@ private[hive] case class HiveScriptTransformationExec( outputSoi: StructObjectInspector, hadoopConf: Configuration): Iterator[InternalRow] = { new Iterator[InternalRow] with HiveInspectors { - var curLine: String = null + private var completed = false val scriptOutputStream = new DataInputStream(inputStream) val scriptOutputReader = @@ -78,6 +78,9 @@ private[hive] case class HiveScriptTransformationExec( lazy val unwrappers = outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor) override def hasNext: Boolean = { + if (completed) { + return false + } try { if (scriptOutputWritable == null) { scriptOutputWritable = reusedWritableObject @@ -85,6 +88,7 @@ private[hive] case class HiveScriptTransformationExec( if (scriptOutputReader != null) { if (scriptOutputReader.next(scriptOutputWritable) <= 0) { checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) + completed = true return false } } else { @@ -97,6 +101,7 @@ private[hive] case class HiveScriptTransformationExec( // there can be a lag between EOF being written out and the process // being terminated. So explicitly waiting for the process to be done. checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) + completed = true return false } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 1a5f47bf5aa7d..f6a85c4778bd1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -93,7 +93,7 @@ private[hive] object OrcFileOperator extends Logging { : Option[StructType] = { // Take the first file where we can open a valid reader if we can find one. Otherwise just // return None to indicate we can't infer the schema. 
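The `completed` flag added to HiveScriptTransformationExec above makes `hasNext` safe to call repeatedly after the script's output hits EOF, instead of touching the exhausted stream again. A small generic sketch of that latch pattern, with an assumed `readNext` callback standing in for the script output reader:

object LatchedIteratorSketch {
  // Generic latch-on-EOF iterator: once the source reports exhaustion, remember it so
  // repeated hasNext calls return false instead of reading the source again.
  def latchedIterator[T](readNext: () => Option[T]): Iterator[T] = new Iterator[T] {
    private var completed = false
    private var buffered: Option[T] = None

    override def hasNext: Boolean = {
      if (completed) return false
      if (buffered.isEmpty) {
        buffered = readNext()
        if (buffered.isEmpty) {
          completed = true
          return false
        }
      }
      true
    }

    override def next(): T = {
      if (!hasNext) throw new NoSuchElementException("end of output")
      val value = buffered.get
      buffered = None
      value
    }
  }

  def main(args: Array[String]): Unit = {
    val source = scala.collection.mutable.Queue(1, 2, 3)
    val it = latchedIterator(() => if (source.isEmpty) None else Some(source.dequeue()))
    println(it.toList)   // List(1, 2, 3)
    println(it.hasNext)  // false, and stays false on repeated calls
  }
}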
- paths.toIterator.map(getFileReader(_, conf, ignoreCorruptFiles)).collectFirst { + paths.iterator.map(getFileReader(_, conf, ignoreCorruptFiles)).collectFirst { case Some(reader) => val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala index 7690e1e9e1465..5778b259c7d5a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala @@ -207,23 +207,6 @@ class HiveParquetSourceSuite extends ParquetPartitioningTest with ParquetTest { } } - test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { - withTempDir { tempDir => - val filePath = new File(tempDir, "testParquet").getCanonicalPath - val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath - - val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") - val df2 = df.as("x").join(df.as("y"), $"x.str" === $"y.str").groupBy("y.str").max("y.int") - intercept[Throwable](df2.write.parquet(filePath)) - - val df3 = df2.toDF("str", "max_int") - df3.write.parquet(filePath2) - val df4 = read.parquet(filePath2) - checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) - assert(df4.columns === Array("str", "max_int")) - } - } - test("SPARK-25993 CREATE EXTERNAL TABLE with subdirectories") { Seq("true", "false").foreach { parquetConversion => withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> parquetConversion) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 90752e70e1b57..170cf4898f314 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -371,7 +371,7 @@ class HiveSparkSubmitSuite object SetMetastoreURLTest extends Logging { def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val sparkConf = new SparkConf(loadDefaults = true) val builder = SparkSession.builder() @@ -409,7 +409,7 @@ object SetMetastoreURLTest extends Logging { object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val sparkConf = new SparkConf(loadDefaults = true).set(UI_ENABLED, false) val providedExpectedWarehouseLocation = @@ -489,7 +489,7 @@ object SetWarehouseLocationTest extends Logging { // can load the jar defined with the function. object TemporaryHiveUDFTest extends Logging { def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val conf = new SparkConf() conf.set(UI_ENABLED, false) val sc = new SparkContext(conf) @@ -527,7 +527,7 @@ object TemporaryHiveUDFTest extends Logging { // can load the jar defined with the function. object PermanentHiveUDFTest1 extends Logging { def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val conf = new SparkConf() conf.set(UI_ENABLED, false) val sc = new SparkContext(conf) @@ -565,7 +565,7 @@ object PermanentHiveUDFTest1 extends Logging { // can load the jar defined with the function. 
object PermanentHiveUDFTest2 extends Logging { def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val conf = new SparkConf() conf.set(UI_ENABLED, false) val sc = new SparkContext(conf) @@ -600,7 +600,7 @@ object PermanentHiveUDFTest2 extends Logging { // We test if we can load user jars in both driver and executors when HiveContext is used. object SparkSubmitClassLoaderTest extends Logging { def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val conf = new SparkConf() val hiveWarehouseLocation = Utils.createTempDir() conf.set(UI_ENABLED, false) @@ -670,7 +670,7 @@ object SparkSubmitClassLoaderTest extends Logging { // We test if we can correctly set spark sql configurations when HiveContext is used. object SparkSQLConfTest extends Logging { def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") // We override the SparkConf to add spark.sql.hive.metastore.version and // spark.sql.hive.metastore.jars to the beginning of the conf entry array. // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but @@ -711,7 +711,7 @@ object SPARK_9757 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val hiveWarehouseLocation = Utils.createTempDir() val sparkContext = new SparkContext( @@ -760,7 +760,7 @@ object SPARK_11009 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val sparkContext = new SparkContext( new SparkConf() @@ -791,7 +791,7 @@ object SPARK_14244 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - TestUtils.configTestLog4j("INFO") + TestUtils.configTestLog4j2("INFO") val sparkContext = new SparkContext( new SparkConf() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index 177c227595162..9e29386475232 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -717,7 +717,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter """.stripMargin) }.getMessage - assert(e.contains("mismatched input 'ROW'")) + assert(e.contains("Syntax error at or near 'ROW'")) } } @@ -739,7 +739,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter """.stripMargin) }.getMessage - assert(e.contains("mismatched input 'ROW'")) + assert(e.contains("Syntax error at or near 'ROW'")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index dbe1b1234da99..16b5d6cf1bf8b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1429,14 +1429,15 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv | c10 ARRAY, | c11 MAP, | c12 MAP, - | c13 MAP + | c13 MAP, + | c14 TIMESTAMP_NTZ |) USING Parquet""".stripMargin) } val expectedMsg = "Hive incompatible types found: interval day to minute, " + 
"interval year to month, interval hour, interval month, " + "struct, " + "array, map, " + - "map. " + + "map, timestamp_ntz. " + "Persisting data source table `default`.`t` into Hive metastore in " + "Spark SQL specific format, which is NOT compatible with Hive." val actualMessages = logAppender.loggingEvents @@ -1467,7 +1468,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv StructField("c10", ArrayType(YearMonthIntervalType(YEAR))), StructField("c11", MapType(IntegerType, StringType)), StructField("c12", MapType(IntegerType, DayTimeIntervalType(DAY))), - StructField("c13", MapType(DayTimeIntervalType(MINUTE, SECOND), StringType))))) + StructField("c13", MapType(DayTimeIntervalType(MINUTE, SECOND), StringType)), + StructField("c14", TimestampNTZType)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala new file mode 100644 index 0000000000000..a23efd8ffd34d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -0,0 +1,1072 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.io.{ByteArrayOutputStream, File, PrintStream, PrintWriter} +import java.net.URI + +import org.apache.commons.lang3.{JavaVersion, SystemUtils} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.security.UserGroupInformation + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, NoSuchDatabaseException, NoSuchPermanentFunctionException, PartitionsAlreadyExistException} +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} +import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.test.TestHiveVersion +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.{MutableURLClassLoader, Utils} + +class HiveClientSuite(version: String, allVersions: Seq[String]) + extends HiveVersionSuite(version) { + + private var versionSpark: TestHiveVersion = null + + private val emptyDir = Utils.createTempDir().getCanonicalPath + + /** + * Drops table `tableName` after calling `f`. 
+ */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + versionSpark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + test("create client") { + client = null + System.gc() // Hack to avoid SEGV on some JVM versions. + val hadoopConf = new Configuration() + hadoopConf.set("test", "success") + client = buildClient(hadoopConf) + if (versionSpark != null) versionSpark.reset() + versionSpark = TestHiveVersion(client) + assert(versionSpark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.version.fullVersion.startsWith(version)) + } + + def table(database: String, tableName: String, + tableType: CatalogTableType = CatalogTableType.MANAGED): CatalogTable = { + CatalogTable( + identifier = TableIdentifier(tableName, Some(database)), + tableType = tableType, + schema = new StructType().add("key", "int"), + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = Some(classOf[TextInputFormat].getName), + outputFormat = Some(classOf[HiveIgnoreKeyTextOutputFormat[_, _]].getName), + serde = Some(classOf[LazySimpleSerDe].getName), + compressed = false, + properties = Map.empty + )) + } + + /////////////////////////////////////////////////////////////////////////// + // Database related API + /////////////////////////////////////////////////////////////////////////// + + private val tempDatabasePath = Utils.createTempDir().toURI + + test("createDatabase") { + val defaultDB = CatalogDatabase("default", "desc", new URI("loc"), Map()) + client.createDatabase(defaultDB, ignoreIfExists = true) + val tempDB = CatalogDatabase( + "temporary", description = "test create", tempDatabasePath, Map()) + client.createDatabase(tempDB, ignoreIfExists = true) + + intercept[DatabaseAlreadyExistsException] { + client.createDatabase(tempDB, ignoreIfExists = false) + } + } + + test("create/get/alter database should pick right user name as owner") { + if (version != "0.12") { + val currentUser = UserGroupInformation.getCurrentUser.getUserName + val ownerName = "SPARK_29425" + val db1 = "SPARK_29425_1" + val db2 = "SPARK_29425_2" + val ownerProps = Map("owner" -> ownerName) + + // create database with owner + val dbWithOwner = CatalogDatabase(db1, "desc", Utils.createTempDir().toURI, ownerProps) + client.createDatabase(dbWithOwner, ignoreIfExists = true) + val getDbWithOwner = client.getDatabase(db1) + assert(getDbWithOwner.properties("owner") === ownerName) + // alter database without owner + client.alterDatabase(getDbWithOwner.copy(properties = Map())) + assert(client.getDatabase(db1).properties("owner") === "") + + // create database without owner + val dbWithoutOwner = CatalogDatabase(db2, "desc", Utils.createTempDir().toURI, Map()) + client.createDatabase(dbWithoutOwner, ignoreIfExists = true) + val getDbWithoutOwner = client.getDatabase(db2) + assert(getDbWithoutOwner.properties("owner") === currentUser) + // alter database with owner + client.alterDatabase(getDbWithoutOwner.copy(properties = ownerProps)) + assert(client.getDatabase(db2).properties("owner") === ownerName) + } + } + + test("createDatabase with null description") { + withTempDir { tmpDir => + val dbWithNullDesc = + CatalogDatabase("dbWithNullDesc", description = null, tmpDir.toURI, Map()) + client.createDatabase(dbWithNullDesc, ignoreIfExists = true) + assert(client.getDatabase("dbWithNullDesc").description == "") + } + } + + test("setCurrentDatabase") { + client.setCurrentDatabase("default") + } + + test("getDatabase") { + // No exception 
should be thrown + client.getDatabase("default") + intercept[NoSuchDatabaseException](client.getDatabase("nonexist")) + } + + test("databaseExists") { + assert(client.databaseExists("default")) + assert(!client.databaseExists("nonexist")) + } + + test("listDatabases") { + assert(client.listDatabases("defau.*") == Seq("default")) + } + + test("alterDatabase") { + val database = client.getDatabase("temporary").copy(properties = Map("flag" -> "true")) + client.alterDatabase(database) + assert(client.getDatabase("temporary").properties.contains("flag")) + + // test alter database location + val tempDatabasePath2 = Utils.createTempDir().toURI + // Hive support altering database location since HIVE-8472. + if (version == "3.0" || version == "3.1") { + client.alterDatabase(database.copy(locationUri = tempDatabasePath2)) + val uriInCatalog = client.getDatabase("temporary").locationUri + assert("file" === uriInCatalog.getScheme) + assert(new Path(tempDatabasePath2.getPath).toUri.getPath === uriInCatalog.getPath, + "Failed to alter database location") + } else { + val e = intercept[AnalysisException] { + client.alterDatabase(database.copy(locationUri = tempDatabasePath2)) + } + assert(e.getMessage.contains("does not support altering database location")) + } + } + + test("dropDatabase") { + assert(client.databaseExists("temporary")) + + client.createTable(table("temporary", tableName = "tbl"), ignoreIfExists = false) + val ex = intercept[AnalysisException] { + client.dropDatabase("temporary", ignoreIfNotExists = false, cascade = false) + assert(false, "dropDatabase should throw HiveException") + } + assert(ex.message.contains("Cannot drop a non-empty database: temporary.")) + + client.dropDatabase("temporary", ignoreIfNotExists = false, cascade = true) + assert(!client.databaseExists("temporary")) + } + + /////////////////////////////////////////////////////////////////////////// + // Table related API + /////////////////////////////////////////////////////////////////////////// + + test("createTable") { + client.createTable(table("default", tableName = "src"), ignoreIfExists = false) + client.createTable(table("default", tableName = "temporary"), ignoreIfExists = false) + client.createTable(table("default", tableName = "view1", tableType = CatalogTableType.VIEW), + ignoreIfExists = false) + } + + test("loadTable") { + client.loadTable( + emptyDir, + tableName = "src", + replace = false, + isSrcLocal = false) + } + + test("tableExists") { + // No exception should be thrown + assert(client.tableExists("default", "src")) + assert(!client.tableExists("default", "nonexistent")) + } + + test("getTable") { + // No exception should be thrown + client.getTable("default", "src") + } + + test("getTableOption") { + assert(client.getTableOption("default", "src").isDefined) + } + + test("getTablesByName") { + assert(client.getTablesByName("default", Seq("src")).head + == client.getTableOption("default", "src").get) + } + + test("getTablesByName when multiple tables") { + assert(client.getTablesByName("default", Seq("src", "temporary")) + .map(_.identifier.table) == Seq("src", "temporary")) + } + + test("getTablesByName when some tables do not exist") { + assert(client.getTablesByName("default", Seq("src", "notexist")) + .map(_.identifier.table) == Seq("src")) + } + + test("getTablesByName when contains invalid name") { + // scalastyle:off + val name = "ç –" + // scalastyle:on + assert(client.getTablesByName("default", Seq("src", name)) + .map(_.identifier.table) == Seq("src")) + } + + test("getTablesByName when 
empty") { + assert(client.getTablesByName("default", Seq.empty).isEmpty) + } + + test("alterTable(table: CatalogTable)") { + val newTable = client.getTable("default", "src").copy(properties = Map("changed" -> "")) + client.alterTable(newTable) + assert(client.getTable("default", "src").properties.contains("changed")) + } + + test("alterTable - should respect the original catalog table's owner name") { + val ownerName = "SPARK-29405" + val originalTable = client.getTable("default", "src") + // mocking the owner is what we declared + val newTable = originalTable.copy(owner = ownerName) + client.alterTable(newTable) + assert(client.getTable("default", "src").owner === ownerName) + // mocking the owner is empty + val newTable2 = originalTable.copy(owner = "") + client.alterTable(newTable2) + assert(client.getTable("default", "src").owner === client.userName) + } + + test("alterTable(dbName: String, tableName: String, table: CatalogTable)") { + val newTable = client.getTable("default", "src").copy(properties = Map("changedAgain" -> "")) + client.alterTable("default", "src", newTable) + assert(client.getTable("default", "src").properties.contains("changedAgain")) + } + + test("alterTable - rename") { + val newTable = client.getTable("default", "src") + .copy(identifier = TableIdentifier("tgt", database = Some("default"))) + assert(!client.tableExists("default", "tgt")) + + client.alterTable("default", "src", newTable) + + assert(client.tableExists("default", "tgt")) + assert(!client.tableExists("default", "src")) + } + + test("alterTable - change database") { + val tempDB = CatalogDatabase( + "temporary", description = "test create", tempDatabasePath, Map()) + client.createDatabase(tempDB, ignoreIfExists = true) + + val newTable = client.getTable("default", "tgt") + .copy(identifier = TableIdentifier("tgt", database = Some("temporary"))) + assert(!client.tableExists("temporary", "tgt")) + + client.alterTable("default", "tgt", newTable) + + assert(client.tableExists("temporary", "tgt")) + assert(!client.tableExists("default", "tgt")) + } + + test("alterTable - change database and table names") { + val newTable = client.getTable("temporary", "tgt") + .copy(identifier = TableIdentifier("src", database = Some("default"))) + assert(!client.tableExists("default", "src")) + + client.alterTable("temporary", "tgt", newTable) + + assert(client.tableExists("default", "src")) + assert(!client.tableExists("temporary", "tgt")) + } + + test("listTables(database)") { + assert(client.listTables("default") === Seq("src", "temporary", "view1")) + } + + test("listTables(database, pattern)") { + assert(client.listTables("default", pattern = "src") === Seq("src")) + assert(client.listTables("default", pattern = "nonexist").isEmpty) + } + + test("listTablesByType(database, pattern, tableType)") { + assert(client.listTablesByType("default", pattern = "view1", + CatalogTableType.VIEW) === Seq("view1")) + assert(client.listTablesByType("default", pattern = "nonexist", + CatalogTableType.VIEW).isEmpty) + } + + test("dropTable") { + val versionsWithoutPurge = + if (allVersions.contains("0.14")) allVersions.takeWhile(_ != "0.14") else Nil + // First try with the purge option set. This should fail if the version is < 0.14, in which + // case we check the version and try without it. 
+ try { + client.dropTable("default", tableName = "temporary", ignoreIfNotExists = false, + purge = true) + assert(!versionsWithoutPurge.contains(version)) + } catch { + case _: UnsupportedOperationException => + assert(versionsWithoutPurge.contains(version)) + client.dropTable("default", tableName = "temporary", ignoreIfNotExists = false, + purge = false) + } + // Drop table with type CatalogTableType.VIEW. + try { + client.dropTable("default", tableName = "view1", ignoreIfNotExists = false, + purge = true) + assert(!versionsWithoutPurge.contains(version)) + } catch { + case _: UnsupportedOperationException => + client.dropTable("default", tableName = "view1", ignoreIfNotExists = false, + purge = false) + } + assert(client.listTables("default") === Seq("src")) + } + + /////////////////////////////////////////////////////////////////////////// + // Partition related API + /////////////////////////////////////////////////////////////////////////// + + private val storageFormat = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + properties = Map.empty) + + test("sql create partitioned table") { + val table = CatalogTable( + identifier = TableIdentifier("src_part", Some("default")), + tableType = CatalogTableType.MANAGED, + schema = new StructType().add("value", "int").add("key1", "int").add("key2", "int"), + partitionColumnNames = Seq("key1", "key2"), + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = Some(classOf[TextInputFormat].getName), + outputFormat = Some(classOf[HiveIgnoreKeyTextOutputFormat[_, _]].getName), + serde = Some(classOf[LazySimpleSerDe].getName), + compressed = false, + properties = Map.empty + )) + client.createTable(table, ignoreIfExists = false) + } + + val testPartitionCount = 2 + + test("createPartitions") { + val partitions = (1 to testPartitionCount).map { key2 => + CatalogTablePartition(Map("key1" -> "1", "key2" -> key2.toString), storageFormat) + } + client.createPartitions( + "default", "src_part", partitions, ignoreIfExists = true) + } + + test("getPartitionNames(catalogTable)") { + val partitionNames = (1 to testPartitionCount).map(key2 => s"key1=1/key2=$key2") + assert(partitionNames == client.getPartitionNames(client.getTable("default", "src_part"))) + } + + test("getPartitions(db, table, spec)") { + assert(testPartitionCount == + client.getPartitions("default", "src_part", None).size) + } + + test("getPartitionsByFilter") { + // Only one partition [1, 1] for key2 == 1 + val result = client.getPartitionsByFilter(client.getTable("default", "src_part"), + Seq(EqualTo(AttributeReference("key2", IntegerType)(), Literal(1)))) + + // Hive 0.12 doesn't support getPartitionsByFilter, it ignores the filter condition. 
+ if (version != "0.12") { + assert(result.size == 1) + } else { + assert(result.size == testPartitionCount) + } + } + + test("getPartition") { + // No exception should be thrown + client.getPartition("default", "src_part", Map("key1" -> "1", "key2" -> "2")) + } + + test("getPartitionOption(db: String, table: String, spec: TablePartitionSpec)") { + val partition = client.getPartitionOption( + "default", "src_part", Map("key1" -> "1", "key2" -> "2")) + assert(partition.isDefined) + } + + test("getPartitionOption(table: CatalogTable, spec: TablePartitionSpec)") { + val partition = client.getPartitionOption( + client.getTable("default", "src_part"), Map("key1" -> "1", "key2" -> "2")) + assert(partition.isDefined) + } + + test("getPartitions(db: String, table: String)") { + assert(testPartitionCount == client.getPartitions("default", "src_part", None).size) + } + + test("loadPartition") { + val partSpec = new java.util.LinkedHashMap[String, String] + partSpec.put("key1", "1") + partSpec.put("key2", "2") + + client.loadPartition( + emptyDir, + "default", + "src_part", + partSpec, + replace = false, + inheritTableSpecs = false, + isSrcLocal = false) + } + + test("loadDynamicPartitions") { + val partSpec = new java.util.LinkedHashMap[String, String] + partSpec.put("key1", "1") + partSpec.put("key2", "") // Dynamic partition + + client.loadDynamicPartitions( + emptyDir, + "default", + "src_part", + partSpec, + replace = false, + numDP = 1) + } + + test("renamePartitions") { + val oldSpec = Map("key1" -> "1", "key2" -> "1") + val newSpec = Map("key1" -> "1", "key2" -> "3") + client.renamePartitions("default", "src_part", Seq(oldSpec), Seq(newSpec)) + + // Checks the existence of the new partition (key1 = 1, key2 = 3) + assert(client.getPartitionOption("default", "src_part", newSpec).isDefined) + } + + test("alterPartitions") { + val spec = Map("key1" -> "1", "key2" -> "2") + val parameters = Map(StatsSetupConst.TOTAL_SIZE -> "0", StatsSetupConst.NUM_FILES -> "1") + val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) + val storage = storageFormat.copy( + locationUri = Some(newLocation), + // needed for 0.12 alter partitions + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val partition = CatalogTablePartition(spec, storage, parameters) + client.alterPartitions("default", "src_part", Seq(partition)) + assert(client.getPartition("default", "src_part", spec) + .storage.locationUri.contains(newLocation)) + assert(client.getPartition("default", "src_part", spec) + .parameters.get(StatsSetupConst.TOTAL_SIZE).contains("0")) + } + + test("dropPartitions") { + val spec = Map("key1" -> "1", "key2" -> "3") + val versionsWithoutPurge = + if (allVersions.contains("1.2")) allVersions.takeWhile(_ != "1.2") else Nil + // Similar to dropTable; try with purge set, and if it fails, make sure we're running + // with a version that is older than the minimum (1.2 in this case). 
+ try { + client.dropPartitions("default", "src_part", Seq(spec), ignoreIfNotExists = true, + purge = true, retainData = false) + assert(!versionsWithoutPurge.contains(version)) + } catch { + case _: UnsupportedOperationException => + assert(versionsWithoutPurge.contains(version)) + client.dropPartitions("default", "src_part", Seq(spec), ignoreIfNotExists = true, + purge = false, retainData = false) + } + + assert(client.getPartitionOption("default", "src_part", spec).isEmpty) + } + + test("createPartitions if already exists") { + val partitions = Seq(CatalogTablePartition( + Map("key1" -> "101", "key2" -> "102"), + storageFormat)) + try { + client.createPartitions("default", "src_part", partitions, ignoreIfExists = false) + val errMsg = intercept[PartitionsAlreadyExistException] { + client.createPartitions("default", "src_part", partitions, ignoreIfExists = false) + }.getMessage + assert(errMsg.contains("partitions already exists")) + } finally { + client.dropPartitions( + "default", + "src_part", + partitions.map(_.spec), + ignoreIfNotExists = true, + purge = false, + retainData = false) + } + } + + /////////////////////////////////////////////////////////////////////////// + // Function related API + /////////////////////////////////////////////////////////////////////////// + + def function(name: String, className: String): CatalogFunction = { + CatalogFunction( + FunctionIdentifier(name, Some("default")), className, Seq.empty[FunctionResource]) + } + + test("createFunction") { + val functionClass = "org.apache.spark.MyFunc1" + if (version == "0.12") { + // Hive 0.12 doesn't support creating permanent functions + intercept[AnalysisException] { + client.createFunction("default", function("func1", functionClass)) + } + } else { + client.createFunction("default", function("func1", functionClass)) + } + } + + test("functionExists") { + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + assert(!client.functionExists("default", "func1")) + } else { + assert(client.functionExists("default", "func1")) + } + } + + test("renameFunction") { + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + intercept[NoSuchPermanentFunctionException] { + client.renameFunction("default", "func1", "func2") + } + } else { + client.renameFunction("default", "func1", "func2") + assert(client.functionExists("default", "func2")) + } + } + + test("alterFunction") { + val functionClass = "org.apache.spark.MyFunc2" + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + intercept[NoSuchPermanentFunctionException] { + client.alterFunction("default", function("func2", functionClass)) + } + } else { + client.alterFunction("default", function("func2", functionClass)) + } + } + + test("getFunction") { + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + intercept[NoSuchPermanentFunctionException] { + client.getFunction("default", "func2") + } + } else { + // No exception should be thrown + val func = client.getFunction("default", "func2") + assert(func.className == "org.apache.spark.MyFunc2") + } + } + + test("getFunctionOption") { + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + assert(client.getFunctionOption("default", "func2").isEmpty) + } else { + assert(client.getFunctionOption("default", "func2").isDefined) + assert(client.getFunctionOption("default", "the_func_not_exists").isEmpty) + } + } + + test("listFunctions") { + if (version == 
"0.12") { + // Hive 0.12 doesn't allow customized permanent functions + assert(client.listFunctions("default", "fun.*").isEmpty) + } else { + assert(client.listFunctions("default", "fun.*").size == 1) + } + } + + test("dropFunction") { + if (version == "0.12") { + // Hive 0.12 doesn't support creating permanent functions + intercept[NoSuchPermanentFunctionException] { + client.dropFunction("default", "func2") + } + } else { + // No exception should be thrown + client.dropFunction("default", "func2") + assert(client.listFunctions("default", "fun.*").isEmpty) + } + } + + /////////////////////////////////////////////////////////////////////////// + // SQL related API + /////////////////////////////////////////////////////////////////////////// + + test("sql set command") { + client.runSqlHive("SET spark.sql.test.key=1") + } + + test("sql create index and reset") { + // HIVE-18448 Since Hive 3.0, INDEX is not supported. + if (version != "3.0" && version != "3.1") { + client.runSqlHive("CREATE TABLE indexed_table (key INT)") + client.runSqlHive("CREATE INDEX index_1 ON TABLE indexed_table(key) " + + "as 'COMPACT' WITH DEFERRED REBUILD") + } + } + + test("sql read hive materialized view") { + // HIVE-14249 Since Hive 2.3.0, materialized view is supported. + if (version == "2.3" || version == "3.0" || version == "3.1") { + // Since Hive 3.0(HIVE-19383), we can not run local MR by `client.runSqlHive` with JDK 11. + assume(version == "2.3" || !SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) + // Since HIVE-18394(Hive 3.1), "Create Materialized View" should default to rewritable ones + val disableRewrite = if (version == "2.3" || version == "3.0") "" else "DISABLE REWRITE" + client.runSqlHive("CREATE TABLE materialized_view_tbl (c1 INT)") + client.runSqlHive( + s"CREATE MATERIALIZED VIEW mv1 $disableRewrite AS SELECT * FROM materialized_view_tbl") + val e = intercept[AnalysisException](versionSpark.table("mv1").collect()).getMessage + assert(e.contains("Hive materialized view is not supported")) + } + } + + /////////////////////////////////////////////////////////////////////////// + // Miscellaneous API + /////////////////////////////////////////////////////////////////////////// + + test("version") { + assert(client.version.fullVersion.startsWith(version)) + } + + test("getConf") { + assert("success" === client.getConf("test", null)) + } + + test("setOut") { + client.setOut(new PrintStream(new ByteArrayOutputStream())) + } + + test("setInfo") { + client.setInfo(new PrintStream(new ByteArrayOutputStream())) + } + + test("setError") { + client.setError(new PrintStream(new ByteArrayOutputStream())) + } + + test("newSession") { + val newClient = client.newSession() + assert(newClient != null) + } + + test("withHiveState and addJar") { + val newClassPath = "." + client.addJar(newClassPath) + client.withHiveState { + // No exception should be thrown. + // withHiveState changes the classloader to MutableURLClassLoader + val classLoader = Thread.currentThread().getContextClassLoader + .asInstanceOf[MutableURLClassLoader] + + val urls = classLoader.getURLs + urls.contains(new File(newClassPath).toURI.toURL) + } + } + + test("reset") { + // Clears all database, tables, functions... 
+ client.reset() + assert(client.listTables("default").isEmpty) + } + + /////////////////////////////////////////////////////////////////////////// + // End-To-End tests + /////////////////////////////////////////////////////////////////////////// + + test("CREATE TABLE AS SELECT") { + withTable("tbl") { + versionSpark.sql("CREATE TABLE tbl AS SELECT 1 AS a") + assert(versionSpark.table("tbl").collect().toSeq == Seq(Row(1))) + val tableMeta = versionSpark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")) + val totalSize = tableMeta.stats.map(_.sizeInBytes) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(totalSize.isEmpty) + } else { + assert(totalSize.nonEmpty && totalSize.get > 0) + } + } + } + + test("CREATE Partitioned TABLE AS SELECT") { + withTable("tbl") { + versionSpark.sql( + """ + |CREATE TABLE tbl(c1 string) + |USING hive + |PARTITIONED BY (ds STRING) + """.stripMargin) + versionSpark.sql("INSERT OVERWRITE TABLE tbl partition (ds='2') SELECT '1'") + + assert(versionSpark.table("tbl").collect().toSeq == Seq(Row("1", "2"))) + val partMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + val totalSize = partMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val numFiles = partMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(totalSize.isEmpty && numFiles.isEmpty) + } else { + assert(totalSize.nonEmpty && numFiles.nonEmpty) + } + + versionSpark.sql( + """ + |ALTER TABLE tbl PARTITION (ds='2') + |SET SERDEPROPERTIES ('newKey' = 'vvv') + """.stripMargin) + val newPartMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + + val newTotalSize = newPartMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val newNumFiles = newPartMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(newTotalSize.isEmpty && newNumFiles.isEmpty) + } else { + assert(newTotalSize.nonEmpty && newNumFiles.nonEmpty) + } + } + } + + test("Delete the temporary staging directory and files after each insert") { + withTempDir { tmpDir => + withTable("tab") { + versionSpark.sql( + s""" + |CREATE TABLE tab(c1 string) + |location '${tmpDir.toURI.toString}' + """.stripMargin) + + (1 to 3).map { i => + versionSpark.sql(s"INSERT OVERWRITE TABLE tab SELECT '$i'") + } + def listFiles(path: File): List[String] = { + val dir = path.listFiles() + val folders = dir.filter(_.isDirectory).toList + val filePaths = dir.map(_.getName).toList + folders.flatMap(listFiles) ++: filePaths + } + // expect 2 files left: `.part-00000-random-uuid.crc` and `part-00000-random-uuid` + // 0.12, 0.13, 1.0 and 1.1 also has another two more files ._SUCCESS.crc and _SUCCESS + val metadataFiles = Seq("._SUCCESS.crc", "_SUCCESS") + assert(listFiles(tmpDir).filterNot(metadataFiles.contains).length == 2) + } + } + } + + test("SPARK-13709: reading partitioned Avro table with nested schema") { + withTempDir { dir => + val path = dir.toURI.toString + val tableName = "spark_13709" + val tempTableName = "spark_13709_temp" + + new File(dir.getAbsolutePath, tableName).mkdir() + new File(dir.getAbsolutePath, tempTableName).mkdir() + + val avroSchema = + """{ + | "name": "test_record", + | "type": "record", + | 
"fields": [ { + | "name": "f0", + | "type": "int" + | }, { + | "name": "f1", + | "type": { + | "type": "record", + | "name": "inner", + | "fields": [ { + | "name": "f10", + | "type": "int" + | }, { + | "name": "f11", + | "type": "double" + | } ] + | } + | } ] + |} + """.stripMargin + + withTable(tableName, tempTableName) { + // Creates the external partitioned Avro table to be tested. + versionSpark.sql( + s"""CREATE EXTERNAL TABLE $tableName + |PARTITIONED BY (ds STRING) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |LOCATION '$path/$tableName' + |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') + """.stripMargin + ) + + // Creates an temporary Avro table used to prepare testing Avro file. + versionSpark.sql( + s"""CREATE EXTERNAL TABLE $tempTableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |LOCATION '$path/$tempTableName' + |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') + """.stripMargin + ) + + // Generates Avro data. + versionSpark.sql(s"INSERT OVERWRITE TABLE $tempTableName SELECT 1, STRUCT(2, 2.5)") + + // Adds generated Avro data as a new partition to the testing table. + versionSpark.sql( + s"ALTER TABLE $tableName ADD PARTITION (ds = 'foo') LOCATION '$path/$tempTableName'") + + // The following query fails before SPARK-13709 is fixed. This is because when reading + // data from table partitions, Avro deserializer needs the Avro schema, which is defined + // in table property "avro.schema.literal". However, we only initializes the deserializer + // using partition properties, which doesn't include the wanted property entry. Merging + // two sets of properties solves the problem. + assert(versionSpark.sql(s"SELECT * FROM $tableName").collect() === + Array(Row(1, Row(2, 2.5D), "foo"))) + } + } + } + + test("CTAS for managed data source tables") { + withTable("t", "t1") { + versionSpark.range(1).write.saveAsTable("t") + assert(versionSpark.table("t").collect() === Array(Row(0))) + versionSpark.sql("create table t1 using parquet as select 2 as a") + assert(versionSpark.table("t1").collect() === Array(Row(2))) + } + } + + test("Decimal support of Avro Hive serde") { + val tableName = "tab1" + // TODO: add the other logical types. 
For details, see the link: + // https://avro.apache.org/docs/1.8.1/spec.html#Logical+Types + val avroSchema = + """{ + | "name": "test_record", + | "type": "record", + | "fields": [ { + | "name": "f0", + | "type": [ + | "null", + | { + | "precision": 38, + | "scale": 2, + | "type": "bytes", + | "logicalType": "decimal" + | } + | ] + | } ] + |} + """.stripMargin + + Seq(true, false).foreach { isPartitioned => + withTable(tableName) { + val partitionClause = if (isPartitioned) "PARTITIONED BY (ds STRING)" else "" + // Creates the (non-)partitioned Avro table + versionSpark.sql( + s""" + |CREATE TABLE $tableName + |$partitionClause + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') + """.stripMargin + ) + + val errorMsg = "Cannot safely cast 'f0': decimal(2,1) to binary" + + if (isPartitioned) { + val insertStmt = s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1.3" + if (version == "0.12" || version == "0.13") { + val e = intercept[AnalysisException](versionSpark.sql(insertStmt)).getMessage + assert(e.contains(errorMsg)) + } else { + versionSpark.sql(insertStmt) + assert(versionSpark.table(tableName).collect() === + versionSpark.sql("SELECT 1.30, 'a'").collect()) + } + } else { + val insertStmt = s"INSERT OVERWRITE TABLE $tableName SELECT 1.3" + if (version == "0.12" || version == "0.13") { + val e = intercept[AnalysisException](versionSpark.sql(insertStmt)).getMessage + assert(e.contains(errorMsg)) + } else { + versionSpark.sql(insertStmt) + assert(versionSpark.table(tableName).collect() === + versionSpark.sql("SELECT 1.30").collect()) + } + } + } + } + } + + test("read avro file containing decimal") { + val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") + val location = new File(url.getFile).toURI.toString + + val tableName = "tab1" + val avroSchema = + """{ + | "name": "test_record", + | "type": "record", + | "fields": [ { + | "name": "f0", + | "type": [ + | "null", + | { + | "precision": 38, + | "scale": 2, + | "type": "bytes", + | "logicalType": "decimal" + | } + | ] + | } ] + |} + """.stripMargin + withTable(tableName) { + versionSpark.sql( + s""" + |CREATE TABLE $tableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |LOCATION '$location' + |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') + """.stripMargin + ) + assert(versionSpark.table(tableName).collect() === + versionSpark.sql("SELECT 1.30").collect()) + } + } + + test("SPARK-17920: Insert into/overwrite avro table") { + // skipped because it's failed in the condition on Windows + assume(!(Utils.isWindows && version == "0.12")) + withTempDir { dir => + val avroSchema = + """ + |{ + | "name": "test_record", + | "type": "record", + | "fields": [{ + | "name": "f0", + | "type": [ + | "null", + | { + | "precision": 38, + | "scale": 2, + | "type": "bytes", + | "logicalType": "decimal" + | } + | ] + | }] + |} + """.stripMargin + val schemaFile = new File(dir, "avroDecimal.avsc") + Utils.tryWithResource(new PrintWriter(schemaFile)) { writer => + writer.write(avroSchema) + } + val schemaPath = 
schemaFile.toURI.toString + + val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") + val srcLocation = new File(url.getFile).toURI.toString + val destTableName = "tab1" + val srcTableName = "tab2" + + withTable(srcTableName, destTableName) { + versionSpark.sql( + s""" + |CREATE EXTERNAL TABLE $srcTableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |LOCATION '$srcLocation' + |TBLPROPERTIES ('avro.schema.url' = '$schemaPath') + """.stripMargin + ) + + versionSpark.sql( + s""" + |CREATE TABLE $destTableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |TBLPROPERTIES ('avro.schema.url' = '$schemaPath') + """.stripMargin + ) + versionSpark.sql( + s"""INSERT OVERWRITE TABLE $destTableName SELECT * FROM $srcTableName""") + val result = versionSpark.table(srcTableName).collect() + assert(versionSpark.table(destTableName).collect() === result) + versionSpark.sql( + s"""INSERT INTO TABLE $destTableName SELECT * FROM $srcTableName""") + assert(versionSpark.table(destTableName).collect().toSeq === result ++ result) + } + } + } + // TODO: add more tests. +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala new file mode 100644 index 0000000000000..b172c0dfedc9f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.net.URI + +import scala.collection.immutable.IndexedSeq + +import org.apache.hadoop.conf.Configuration +import org.scalatest.Suite + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.catalog.CatalogDatabase +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.tags.{ExtendedHiveTest, SlowHiveTest} + +/** + * A simple set of tests that call the methods of a [[HiveClient]], loading different version + * of hive from maven central. These tests are simple in that they are mostly just testing to make + * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality + * is not fully tested. 
+ */
+@SlowHiveTest
+@ExtendedHiveTest
+class HiveClientSuites extends SparkFunSuite with HiveClientVersions {
+
+  override protected val enableAutoThreadAudit = false
+
+  import HiveClientBuilder.buildClient
+
+  test("success sanity check") {
+    val badClient = buildClient(HiveUtils.builtinHiveVersion, new Configuration())
+    val db = CatalogDatabase("default", "desc", new URI("loc"), Map())
+    badClient.createDatabase(db, ignoreIfExists = true)
+  }
+
+  test("hadoop configuration preserved") {
+    val hadoopConf = new Configuration()
+    hadoopConf.set("test", "success")
+    val client = buildClient(HiveUtils.builtinHiveVersion, hadoopConf)
+    assert("success" === client.getConf("test", null))
+  }
+
+  test("override useless and side-effect hive configurations ") {
+    val hadoopConf = new Configuration()
+    // These hive flags should be reset by spark
+    hadoopConf.setBoolean("hive.cbo.enable", true)
+    hadoopConf.setBoolean("hive.session.history.enabled", true)
+    hadoopConf.set("hive.execution.engine", "tez")
+    val client = buildClient(HiveUtils.builtinHiveVersion, hadoopConf)
+    assert(!client.getConf("hive.cbo.enable", "true").toBoolean)
+    assert(!client.getConf("hive.session.history.enabled", "true").toBoolean)
+    assert(client.getConf("hive.execution.engine", "tez") === "mr")
+  }
+
+  private def getNestedMessages(e: Throwable): String = {
+    var causes = ""
+    var lastException = e
+    while (lastException != null) {
+      causes += lastException.toString + "\n"
+      lastException = lastException.getCause
+    }
+    causes
+  }
+
+  // It's actually pretty easy to mess things up and have all of your tests "pass" by accidentally
+  // connecting to an auto-populated, in-process metastore. Let's make sure we are getting the
+  // versions right by forcing a known compatibility failure.
+  // TODO: currently only works on mysql where we manually create the schema...
+ ignore("failure sanity check") { + val e = intercept[Throwable] { + val badClient = quietly { buildClient("13", new Configuration()) } + } + assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") + } + + override def nestedSuites: IndexedSeq[Suite] = { + versions.map(new HiveClientSuite(_, versions)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala index 5fef7d1d623ac..e9ab8edf9ad18 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala @@ -637,30 +637,32 @@ class HivePartitionFilteringSuite(version: String) } test("SPARK-35437: getPartitionsByFilter: relax cast if does not need timezone") { - // does not need time zone - Seq(("true", "20200104" :: Nil), ("false", dateStrValue)).foreach { - case (pruningFastFallbackEnabled, prunedPartition) => + if (!SQLConf.get.ansiEnabled) { + // does not need time zone + Seq(("true", "20200104" :: Nil), ("false", dateStrValue)).foreach { + case (pruningFastFallbackEnabled, prunedPartition) => + withSQLConf(pruningFastFallback -> pruningFastFallbackEnabled) { + testMetastorePartitionFiltering( + attr("datestr").cast(IntegerType) === 20200104, + dsValue, + hValue, + chunkValue, + dateValue, + prunedPartition) + } + } + + // need time zone + Seq("true", "false").foreach { pruningFastFallbackEnabled => withSQLConf(pruningFastFallback -> pruningFastFallbackEnabled) { testMetastorePartitionFiltering( - attr("datestr").cast(IntegerType) === 20200104, + attr("datestr").cast(DateType) === Date.valueOf("2020-01-01"), dsValue, hValue, chunkValue, dateValue, - prunedPartition) + dateStrValue) } - } - - // need time zone - Seq("true", "false").foreach { pruningFastFallbackEnabled => - withSQLConf(pruningFastFallback -> pruningFastFallbackEnabled) { - testMetastorePartitionFiltering( - attr("datestr").cast(DateType) === Date.valueOf("2020-01-01"), - dsValue, - hValue, - chunkValue, - dateValue, - dateStrValue) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index 02e9b7fb151fd..4cc51064cfdd3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -40,6 +40,7 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu // Since Hive 3.0, HIVE-19310 skipped `ensureDbInit` if `hive.in.test=false`. if (version == "3.0" || version == "3.1") { hadoopConf.set("hive.in.test", "true") + hadoopConf.set("hive.query.reexecution.enabled", "false") } HiveClientBuilder.buildClient( version, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala deleted file mode 100644 index 14b2a51bff8c0..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ /dev/null @@ -1,1157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.client - -import java.io.{ByteArrayOutputStream, File, PrintStream, PrintWriter} -import java.net.URI - -import org.apache.commons.lang3.{JavaVersion, SystemUtils} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat -import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.hadoop.mapred.TextInputFormat -import org.apache.hadoop.security.UserGroupInformation - -import org.apache.spark.SparkFunSuite -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Row} -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPermanentFunctionException, PartitionsAlreadyExistException} -import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} -import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} -import org.apache.spark.sql.hive.test.TestHiveVersion -import org.apache.spark.sql.types.IntegerType -import org.apache.spark.sql.types.StructType -import org.apache.spark.tags.{ExtendedHiveTest, SlowHiveTest} -import org.apache.spark.util.{MutableURLClassLoader, Utils} - -/** - * A simple set of tests that call the methods of a [[HiveClient]], loading different version - * of hive from maven central. These tests are simple in that they are mostly just testing to make - * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality - * is not fully tested. - */ -// TODO: Refactor this to `HiveClientSuite` and make it a subclass of `HiveVersionSuite` -@SlowHiveTest -@ExtendedHiveTest -class VersionsSuite extends SparkFunSuite with Logging { - - override protected val enableAutoThreadAudit = false - - import HiveClientBuilder.buildClient - - /** - * Drops table `tableName` after calling `f`. 
- */ - protected def withTable(tableNames: String*)(f: => Unit): Unit = { - try f finally { - tableNames.foreach { name => - versionSpark.sql(s"DROP TABLE IF EXISTS $name") - } - } - } - - test("success sanity check") { - val badClient = buildClient(HiveUtils.builtinHiveVersion, new Configuration()) - val db = new CatalogDatabase("default", "desc", new URI("loc"), Map()) - badClient.createDatabase(db, ignoreIfExists = true) - } - - test("hadoop configuration preserved") { - val hadoopConf = new Configuration() - hadoopConf.set("test", "success") - val client = buildClient(HiveUtils.builtinHiveVersion, hadoopConf) - assert("success" === client.getConf("test", null)) - } - - test("override useless and side-effect hive configurations ") { - val hadoopConf = new Configuration() - // These hive flags should be reset by spark - hadoopConf.setBoolean("hive.cbo.enable", true) - hadoopConf.setBoolean("hive.session.history.enabled", true) - hadoopConf.set("hive.execution.engine", "tez") - val client = buildClient(HiveUtils.builtinHiveVersion, hadoopConf) - assert(!client.getConf("hive.cbo.enable", "true").toBoolean) - assert(!client.getConf("hive.session.history.enabled", "true").toBoolean) - assert(client.getConf("hive.execution.engine", "tez") === "mr") - } - - private def getNestedMessages(e: Throwable): String = { - var causes = "" - var lastException = e - while (lastException != null) { - causes += lastException.toString + "\n" - lastException = lastException.getCause - } - causes - } - - private val emptyDir = Utils.createTempDir().getCanonicalPath - - // Its actually pretty easy to mess things up and have all of your tests "pass" by accidentally - // connecting to an auto-populated, in-process metastore. Let's make sure we are getting the - // versions right by forcing a known compatibility failure. - // TODO: currently only works on mysql where we manually create the schema... - ignore("failure sanity check") { - val e = intercept[Throwable] { - val badClient = quietly { buildClient("13", new Configuration()) } - } - assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") - } - - private val versions = if (SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) { - Seq("2.0", "2.1", "2.2", "2.3", "3.0", "3.1") - } else { - Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3", "3.0", "3.1") - } - - private var client: HiveClient = null - - private var versionSpark: TestHiveVersion = null - - versions.foreach { version => - test(s"$version: create client") { - client = null - System.gc() // Hack to avoid SEGV on some JVM versions. - val hadoopConf = new Configuration() - hadoopConf.set("test", "success") - // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and - // hive.metastore.schema.verification from false to true since 2.0 - // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3" || - version == "3.0" || version == "3.1") { - hadoopConf.set("datanucleus.schema.autoCreateAll", "true") - hadoopConf.set("hive.metastore.schema.verification", "false") - } - if (version == "3.0" || version == "3.1") { - // Since Hive 3.0, HIVE-19310 skipped `ensureDbInit` if `hive.in.test=false`. - hadoopConf.set("hive.in.test", "true") - // Since HIVE-17626(Hive 3.0.0), need to set hive.query.reexecution.enabled=false. 
- hadoopConf.set("hive.query.reexecution.enabled", "false") - } - client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) - if (versionSpark != null) versionSpark.reset() - versionSpark = TestHiveVersion(client) - assert(versionSpark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] - .client.version.fullVersion.startsWith(version)) - } - - def table(database: String, tableName: String, - tableType: CatalogTableType = CatalogTableType.MANAGED): CatalogTable = { - CatalogTable( - identifier = TableIdentifier(tableName, Some(database)), - tableType = tableType, - schema = new StructType().add("key", "int"), - storage = CatalogStorageFormat( - locationUri = None, - inputFormat = Some(classOf[TextInputFormat].getName), - outputFormat = Some(classOf[HiveIgnoreKeyTextOutputFormat[_, _]].getName), - serde = Some(classOf[LazySimpleSerDe].getName()), - compressed = false, - properties = Map.empty - )) - } - - /////////////////////////////////////////////////////////////////////////// - // Database related API - /////////////////////////////////////////////////////////////////////////// - - val tempDatabasePath = Utils.createTempDir().toURI - - test(s"$version: createDatabase") { - val defaultDB = CatalogDatabase("default", "desc", new URI("loc"), Map()) - client.createDatabase(defaultDB, ignoreIfExists = true) - val tempDB = CatalogDatabase( - "temporary", description = "test create", tempDatabasePath, Map()) - client.createDatabase(tempDB, ignoreIfExists = true) - - try { - client.createDatabase(tempDB, ignoreIfExists = false) - assert(false, "createDatabase should throw AlreadyExistsException") - } catch { - case ex: Throwable => - assert(ex.getClass.getName.equals( - "org.apache.hadoop.hive.metastore.api.AlreadyExistsException")) - assert(ex.getMessage.contains(s"Database ${tempDB.name} already exists")) - } - } - - test(s"$version: create/get/alter database should pick right user name as owner") { - if (version != "0.12") { - val currentUser = UserGroupInformation.getCurrentUser.getUserName - val ownerName = "SPARK_29425" - val db1 = "SPARK_29425_1" - val db2 = "SPARK_29425_2" - val ownerProps = Map("owner" -> ownerName) - - // create database with owner - val dbWithOwner = CatalogDatabase(db1, "desc", Utils.createTempDir().toURI, ownerProps) - client.createDatabase(dbWithOwner, ignoreIfExists = true) - val getDbWithOwner = client.getDatabase(db1) - assert(getDbWithOwner.properties("owner") === ownerName) - // alter database without owner - client.alterDatabase(getDbWithOwner.copy(properties = Map())) - assert(client.getDatabase(db1).properties("owner") === "") - - // create database without owner - val dbWithoutOwner = CatalogDatabase(db2, "desc", Utils.createTempDir().toURI, Map()) - client.createDatabase(dbWithoutOwner, ignoreIfExists = true) - val getDbWithoutOwner = client.getDatabase(db2) - assert(getDbWithoutOwner.properties("owner") === currentUser) - // alter database with owner - client.alterDatabase(getDbWithoutOwner.copy(properties = ownerProps)) - assert(client.getDatabase(db2).properties("owner") === ownerName) - } - } - - test(s"$version: createDatabase with null description") { - withTempDir { tmpDir => - val dbWithNullDesc = - CatalogDatabase("dbWithNullDesc", description = null, tmpDir.toURI, Map()) - client.createDatabase(dbWithNullDesc, ignoreIfExists = true) - assert(client.getDatabase("dbWithNullDesc").description == "") - } - } - - test(s"$version: setCurrentDatabase") { - client.setCurrentDatabase("default") - } 
- - test(s"$version: getDatabase") { - // No exception should be thrown - client.getDatabase("default") - intercept[NoSuchDatabaseException](client.getDatabase("nonexist")) - } - - test(s"$version: databaseExists") { - assert(client.databaseExists("default")) - assert(client.databaseExists("nonexist") == false) - } - - test(s"$version: listDatabases") { - assert(client.listDatabases("defau.*") == Seq("default")) - } - - test(s"$version: alterDatabase") { - val database = client.getDatabase("temporary").copy(properties = Map("flag" -> "true")) - client.alterDatabase(database) - assert(client.getDatabase("temporary").properties.contains("flag")) - - // test alter database location - val tempDatabasePath2 = Utils.createTempDir().toURI - // Hive support altering database location since HIVE-8472. - if (version == "3.0" || version == "3.1") { - client.alterDatabase(database.copy(locationUri = tempDatabasePath2)) - val uriInCatalog = client.getDatabase("temporary").locationUri - assert("file" === uriInCatalog.getScheme) - assert(new Path(tempDatabasePath2.getPath).toUri.getPath === uriInCatalog.getPath, - "Failed to alter database location") - } else { - val e = intercept[AnalysisException] { - client.alterDatabase(database.copy(locationUri = tempDatabasePath2)) - } - assert(e.getMessage.contains("does not support altering database location")) - } - } - - test(s"$version: dropDatabase") { - assert(client.databaseExists("temporary")) - client.dropDatabase("temporary", ignoreIfNotExists = false, cascade = true) - assert(client.databaseExists("temporary") == false) - } - - /////////////////////////////////////////////////////////////////////////// - // Table related API - /////////////////////////////////////////////////////////////////////////// - - test(s"$version: createTable") { - client.createTable(table("default", tableName = "src"), ignoreIfExists = false) - client.createTable(table("default", tableName = "temporary"), ignoreIfExists = false) - client.createTable(table("default", tableName = "view1", tableType = CatalogTableType.VIEW), - ignoreIfExists = false) - } - - test(s"$version: loadTable") { - client.loadTable( - emptyDir, - tableName = "src", - replace = false, - isSrcLocal = false) - } - - test(s"$version: tableExists") { - // No exception should be thrown - assert(client.tableExists("default", "src")) - assert(!client.tableExists("default", "nonexistent")) - } - - test(s"$version: getTable") { - // No exception should be thrown - client.getTable("default", "src") - } - - test(s"$version: getTableOption") { - assert(client.getTableOption("default", "src").isDefined) - } - - test(s"$version: getTablesByName") { - assert(client.getTablesByName("default", Seq("src")).head - == client.getTableOption("default", "src").get) - } - - test(s"$version: getTablesByName when multiple tables") { - assert(client.getTablesByName("default", Seq("src", "temporary")) - .map(_.identifier.table) == Seq("src", "temporary")) - } - - test(s"$version: getTablesByName when some tables do not exist") { - assert(client.getTablesByName("default", Seq("src", "notexist")) - .map(_.identifier.table) == Seq("src")) - } - - test(s"$version: getTablesByName when contains invalid name") { - // scalastyle:off - val name = "ç –" - // scalastyle:on - assert(client.getTablesByName("default", Seq("src", name)) - .map(_.identifier.table) == Seq("src")) - } - - test(s"$version: getTablesByName when empty") { - assert(client.getTablesByName("default", Seq.empty).isEmpty) - } - - test(s"$version: alterTable(table: 
CatalogTable)") { - val newTable = client.getTable("default", "src").copy(properties = Map("changed" -> "")) - client.alterTable(newTable) - assert(client.getTable("default", "src").properties.contains("changed")) - } - - test(s"$version: alterTable - should respect the original catalog table's owner name") { - val ownerName = "SPARK-29405" - val originalTable = client.getTable("default", "src") - // mocking the owner is what we declared - val newTable = originalTable.copy(owner = ownerName) - client.alterTable(newTable) - assert(client.getTable("default", "src").owner === ownerName) - // mocking the owner is empty - val newTable2 = originalTable.copy(owner = "") - client.alterTable(newTable2) - assert(client.getTable("default", "src").owner === client.userName) - } - - test(s"$version: alterTable(dbName: String, tableName: String, table: CatalogTable)") { - val newTable = client.getTable("default", "src").copy(properties = Map("changedAgain" -> "")) - client.alterTable("default", "src", newTable) - assert(client.getTable("default", "src").properties.contains("changedAgain")) - } - - test(s"$version: alterTable - rename") { - val newTable = client.getTable("default", "src") - .copy(identifier = TableIdentifier("tgt", database = Some("default"))) - assert(!client.tableExists("default", "tgt")) - - client.alterTable("default", "src", newTable) - - assert(client.tableExists("default", "tgt")) - assert(!client.tableExists("default", "src")) - } - - test(s"$version: alterTable - change database") { - val tempDB = CatalogDatabase( - "temporary", description = "test create", tempDatabasePath, Map()) - client.createDatabase(tempDB, ignoreIfExists = true) - - val newTable = client.getTable("default", "tgt") - .copy(identifier = TableIdentifier("tgt", database = Some("temporary"))) - assert(!client.tableExists("temporary", "tgt")) - - client.alterTable("default", "tgt", newTable) - - assert(client.tableExists("temporary", "tgt")) - assert(!client.tableExists("default", "tgt")) - } - - test(s"$version: alterTable - change database and table names") { - val newTable = client.getTable("temporary", "tgt") - .copy(identifier = TableIdentifier("src", database = Some("default"))) - assert(!client.tableExists("default", "src")) - - client.alterTable("temporary", "tgt", newTable) - - assert(client.tableExists("default", "src")) - assert(!client.tableExists("temporary", "tgt")) - } - - test(s"$version: listTables(database)") { - assert(client.listTables("default") === Seq("src", "temporary", "view1")) - } - - test(s"$version: listTables(database, pattern)") { - assert(client.listTables("default", pattern = "src") === Seq("src")) - assert(client.listTables("default", pattern = "nonexist").isEmpty) - } - - test(s"$version: listTablesByType(database, pattern, tableType)") { - assert(client.listTablesByType("default", pattern = "view1", - CatalogTableType.VIEW) === Seq("view1")) - assert(client.listTablesByType("default", pattern = "nonexist", - CatalogTableType.VIEW).isEmpty) - } - - test(s"$version: dropTable") { - val versionsWithoutPurge = - if (versions.contains("0.14")) versions.takeWhile(_ != "0.14") else Nil - // First try with the purge option set. This should fail if the version is < 0.14, in which - // case we check the version and try without it. 
- try { - client.dropTable("default", tableName = "temporary", ignoreIfNotExists = false, - purge = true) - assert(!versionsWithoutPurge.contains(version)) - } catch { - case _: UnsupportedOperationException => - assert(versionsWithoutPurge.contains(version)) - client.dropTable("default", tableName = "temporary", ignoreIfNotExists = false, - purge = false) - } - // Drop table with type CatalogTableType.VIEW. - try { - client.dropTable("default", tableName = "view1", ignoreIfNotExists = false, - purge = true) - assert(!versionsWithoutPurge.contains(version)) - } catch { - case _: UnsupportedOperationException => - client.dropTable("default", tableName = "view1", ignoreIfNotExists = false, - purge = false) - } - assert(client.listTables("default") === Seq("src")) - } - - /////////////////////////////////////////////////////////////////////////// - // Partition related API - /////////////////////////////////////////////////////////////////////////// - - val storageFormat = CatalogStorageFormat( - locationUri = None, - inputFormat = None, - outputFormat = None, - serde = None, - compressed = false, - properties = Map.empty) - - test(s"$version: sql create partitioned table") { - val table = CatalogTable( - identifier = TableIdentifier("src_part", Some("default")), - tableType = CatalogTableType.MANAGED, - schema = new StructType().add("value", "int").add("key1", "int").add("key2", "int"), - partitionColumnNames = Seq("key1", "key2"), - storage = CatalogStorageFormat( - locationUri = None, - inputFormat = Some(classOf[TextInputFormat].getName), - outputFormat = Some(classOf[HiveIgnoreKeyTextOutputFormat[_, _]].getName), - serde = Some(classOf[LazySimpleSerDe].getName()), - compressed = false, - properties = Map.empty - )) - client.createTable(table, ignoreIfExists = false) - } - - val testPartitionCount = 2 - - test(s"$version: createPartitions") { - val partitions = (1 to testPartitionCount).map { key2 => - CatalogTablePartition(Map("key1" -> "1", "key2" -> key2.toString), storageFormat) - } - client.createPartitions( - "default", "src_part", partitions, ignoreIfExists = true) - } - - test(s"$version: getPartitionNames(catalogTable)") { - val partitionNames = (1 to testPartitionCount).map(key2 => s"key1=1/key2=$key2") - assert(partitionNames == client.getPartitionNames(client.getTable("default", "src_part"))) - } - - test(s"$version: getPartitions(db, table, spec)") { - assert(testPartitionCount == - client.getPartitions("default", "src_part", None).size) - } - - test(s"$version: getPartitionsByFilter") { - // Only one partition [1, 1] for key2 == 1 - val result = client.getPartitionsByFilter(client.getTable("default", "src_part"), - Seq(EqualTo(AttributeReference("key2", IntegerType)(), Literal(1)))) - - // Hive 0.12 doesn't support getPartitionsByFilter, it ignores the filter condition. 
- if (version != "0.12") { - assert(result.size == 1) - } else { - assert(result.size == testPartitionCount) - } - } - - test(s"$version: getPartition") { - // No exception should be thrown - client.getPartition("default", "src_part", Map("key1" -> "1", "key2" -> "2")) - } - - test(s"$version: getPartitionOption(db: String, table: String, spec: TablePartitionSpec)") { - val partition = client.getPartitionOption( - "default", "src_part", Map("key1" -> "1", "key2" -> "2")) - assert(partition.isDefined) - } - - test(s"$version: getPartitionOption(table: CatalogTable, spec: TablePartitionSpec)") { - val partition = client.getPartitionOption( - client.getTable("default", "src_part"), Map("key1" -> "1", "key2" -> "2")) - assert(partition.isDefined) - } - - test(s"$version: getPartitions(db: String, table: String)") { - assert(testPartitionCount == client.getPartitions("default", "src_part", None).size) - } - - test(s"$version: loadPartition") { - val partSpec = new java.util.LinkedHashMap[String, String] - partSpec.put("key1", "1") - partSpec.put("key2", "2") - - client.loadPartition( - emptyDir, - "default", - "src_part", - partSpec, - replace = false, - inheritTableSpecs = false, - isSrcLocal = false) - } - - test(s"$version: loadDynamicPartitions") { - val partSpec = new java.util.LinkedHashMap[String, String] - partSpec.put("key1", "1") - partSpec.put("key2", "") // Dynamic partition - - client.loadDynamicPartitions( - emptyDir, - "default", - "src_part", - partSpec, - replace = false, - numDP = 1) - } - - test(s"$version: renamePartitions") { - val oldSpec = Map("key1" -> "1", "key2" -> "1") - val newSpec = Map("key1" -> "1", "key2" -> "3") - client.renamePartitions("default", "src_part", Seq(oldSpec), Seq(newSpec)) - - // Checks the existence of the new partition (key1 = 1, key2 = 3) - assert(client.getPartitionOption("default", "src_part", newSpec).isDefined) - } - - test(s"$version: alterPartitions") { - val spec = Map("key1" -> "1", "key2" -> "2") - val parameters = Map(StatsSetupConst.TOTAL_SIZE -> "0", StatsSetupConst.NUM_FILES -> "1") - val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) - val storage = storageFormat.copy( - locationUri = Some(newLocation), - // needed for 0.12 alter partitions - serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - val partition = CatalogTablePartition(spec, storage, parameters) - client.alterPartitions("default", "src_part", Seq(partition)) - assert(client.getPartition("default", "src_part", spec) - .storage.locationUri == Some(newLocation)) - assert(client.getPartition("default", "src_part", spec) - .parameters.get(StatsSetupConst.TOTAL_SIZE) == Some("0")) - } - - test(s"$version: dropPartitions") { - val spec = Map("key1" -> "1", "key2" -> "3") - val versionsWithoutPurge = - if (versions.contains("1.2")) versions.takeWhile(_ != "1.2") else Nil - // Similar to dropTable; try with purge set, and if it fails, make sure we're running - // with a version that is older than the minimum (1.2 in this case). 
- try { - client.dropPartitions("default", "src_part", Seq(spec), ignoreIfNotExists = true, - purge = true, retainData = false) - assert(!versionsWithoutPurge.contains(version)) - } catch { - case _: UnsupportedOperationException => - assert(versionsWithoutPurge.contains(version)) - client.dropPartitions("default", "src_part", Seq(spec), ignoreIfNotExists = true, - purge = false, retainData = false) - } - - assert(client.getPartitionOption("default", "src_part", spec).isEmpty) - } - - test(s"$version: createPartitions if already exists") { - val partitions = Seq(CatalogTablePartition( - Map("key1" -> "101", "key2" -> "102"), - storageFormat)) - try { - client.createPartitions("default", "src_part", partitions, ignoreIfExists = false) - val errMsg = intercept[PartitionsAlreadyExistException] { - client.createPartitions("default", "src_part", partitions, ignoreIfExists = false) - }.getMessage - assert(errMsg.contains("partitions already exists")) - } finally { - client.dropPartitions( - "default", - "src_part", - partitions.map(_.spec), - ignoreIfNotExists = true, - purge = false, - retainData = false) - } - } - - /////////////////////////////////////////////////////////////////////////// - // Function related API - /////////////////////////////////////////////////////////////////////////// - - def function(name: String, className: String): CatalogFunction = { - CatalogFunction( - FunctionIdentifier(name, Some("default")), className, Seq.empty[FunctionResource]) - } - - test(s"$version: createFunction") { - val functionClass = "org.apache.spark.MyFunc1" - if (version == "0.12") { - // Hive 0.12 doesn't support creating permanent functions - intercept[AnalysisException] { - client.createFunction("default", function("func1", functionClass)) - } - } else { - client.createFunction("default", function("func1", functionClass)) - } - } - - test(s"$version: functionExists") { - if (version == "0.12") { - // Hive 0.12 doesn't allow customized permanent functions - assert(client.functionExists("default", "func1") == false) - } else { - assert(client.functionExists("default", "func1")) - } - } - - test(s"$version: renameFunction") { - if (version == "0.12") { - // Hive 0.12 doesn't allow customized permanent functions - intercept[NoSuchPermanentFunctionException] { - client.renameFunction("default", "func1", "func2") - } - } else { - client.renameFunction("default", "func1", "func2") - assert(client.functionExists("default", "func2")) - } - } - - test(s"$version: alterFunction") { - val functionClass = "org.apache.spark.MyFunc2" - if (version == "0.12") { - // Hive 0.12 doesn't allow customized permanent functions - intercept[NoSuchPermanentFunctionException] { - client.alterFunction("default", function("func2", functionClass)) - } - } else { - client.alterFunction("default", function("func2", functionClass)) - } - } - - test(s"$version: getFunction") { - if (version == "0.12") { - // Hive 0.12 doesn't allow customized permanent functions - intercept[NoSuchPermanentFunctionException] { - client.getFunction("default", "func2") - } - } else { - // No exception should be thrown - val func = client.getFunction("default", "func2") - assert(func.className == "org.apache.spark.MyFunc2") - } - } - - test(s"$version: getFunctionOption") { - if (version == "0.12") { - // Hive 0.12 doesn't allow customized permanent functions - assert(client.getFunctionOption("default", "func2").isEmpty) - } else { - assert(client.getFunctionOption("default", "func2").isDefined) - assert(client.getFunctionOption("default", 
"the_func_not_exists").isEmpty) - } - } - - test(s"$version: listFunctions") { - if (version == "0.12") { - // Hive 0.12 doesn't allow customized permanent functions - assert(client.listFunctions("default", "fun.*").isEmpty) - } else { - assert(client.listFunctions("default", "fun.*").size == 1) - } - } - - test(s"$version: dropFunction") { - if (version == "0.12") { - // Hive 0.12 doesn't support creating permanent functions - intercept[NoSuchPermanentFunctionException] { - client.dropFunction("default", "func2") - } - } else { - // No exception should be thrown - client.dropFunction("default", "func2") - assert(client.listFunctions("default", "fun.*").size == 0) - } - } - - /////////////////////////////////////////////////////////////////////////// - // SQL related API - /////////////////////////////////////////////////////////////////////////// - - test(s"$version: sql set command") { - client.runSqlHive("SET spark.sql.test.key=1") - } - - test(s"$version: sql create index and reset") { - // HIVE-18448 Since Hive 3.0, INDEX is not supported. - if (version != "3.0" && version != "3.1") { - client.runSqlHive("CREATE TABLE indexed_table (key INT)") - client.runSqlHive("CREATE INDEX index_1 ON TABLE indexed_table(key) " + - "as 'COMPACT' WITH DEFERRED REBUILD") - } - } - - test(s"$version: sql read hive materialized view") { - // HIVE-14249 Since Hive 2.3.0, materialized view is supported. - if (version == "2.3" || version == "3.0" || version == "3.1") { - // Since Hive 3.0(HIVE-19383), we can not run local MR by `client.runSqlHive` with JDK 11. - assume(version == "2.3" || !SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) - // Since HIVE-18394(Hive 3.1), "Create Materialized View" should default to rewritable ones - val disableRewrite = if (version == "2.3" || version == "3.0") "" else "DISABLE REWRITE" - client.runSqlHive("CREATE TABLE materialized_view_tbl (c1 INT)") - client.runSqlHive( - s"CREATE MATERIALIZED VIEW mv1 $disableRewrite AS SELECT * FROM materialized_view_tbl") - val e = intercept[AnalysisException](versionSpark.table("mv1").collect()).getMessage - assert(e.contains("Hive materialized view is not supported")) - } - } - - /////////////////////////////////////////////////////////////////////////// - // Miscellaneous API - /////////////////////////////////////////////////////////////////////////// - - test(s"$version: version") { - assert(client.version.fullVersion.startsWith(version)) - } - - test(s"$version: getConf") { - assert("success" === client.getConf("test", null)) - } - - test(s"$version: setOut") { - client.setOut(new PrintStream(new ByteArrayOutputStream())) - } - - test(s"$version: setInfo") { - client.setInfo(new PrintStream(new ByteArrayOutputStream())) - } - - test(s"$version: setError") { - client.setError(new PrintStream(new ByteArrayOutputStream())) - } - - test(s"$version: newSession") { - val newClient = client.newSession() - assert(newClient != null) - } - - test(s"$version: withHiveState and addJar") { - val newClassPath = "." - client.addJar(newClassPath) - client.withHiveState { - // No exception should be thrown. - // withHiveState changes the classloader to MutableURLClassLoader - val classLoader = Thread.currentThread().getContextClassLoader - .asInstanceOf[MutableURLClassLoader] - - val urls = classLoader.getURLs() - urls.contains(new File(newClassPath).toURI.toURL) - } - } - - test(s"$version: reset") { - // Clears all database, tables, functions... 
- client.reset() - assert(client.listTables("default").isEmpty) - } - - /////////////////////////////////////////////////////////////////////////// - // End-To-End tests - /////////////////////////////////////////////////////////////////////////// - - test(s"$version: CREATE TABLE AS SELECT") { - withTable("tbl") { - versionSpark.sql("CREATE TABLE tbl AS SELECT 1 AS a") - assert(versionSpark.table("tbl").collect().toSeq == Seq(Row(1))) - val tableMeta = versionSpark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")) - val totalSize = tableMeta.stats.map(_.sizeInBytes) - // Except 0.12, all the following versions will fill the Hive-generated statistics - if (version == "0.12") { - assert(totalSize.isEmpty) - } else { - assert(totalSize.nonEmpty && totalSize.get > 0) - } - } - } - - test(s"$version: CREATE Partitioned TABLE AS SELECT") { - withTable("tbl") { - versionSpark.sql( - """ - |CREATE TABLE tbl(c1 string) - |USING hive - |PARTITIONED BY (ds STRING) - """.stripMargin) - versionSpark.sql("INSERT OVERWRITE TABLE tbl partition (ds='2') SELECT '1'") - - assert(versionSpark.table("tbl").collect().toSeq == Seq(Row("1", "2"))) - val partMeta = versionSpark.sessionState.catalog.getPartition( - TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters - val totalSize = partMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) - val numFiles = partMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) - // Except 0.12, all the following versions will fill the Hive-generated statistics - if (version == "0.12") { - assert(totalSize.isEmpty && numFiles.isEmpty) - } else { - assert(totalSize.nonEmpty && numFiles.nonEmpty) - } - - versionSpark.sql( - """ - |ALTER TABLE tbl PARTITION (ds='2') - |SET SERDEPROPERTIES ('newKey' = 'vvv') - """.stripMargin) - val newPartMeta = versionSpark.sessionState.catalog.getPartition( - TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters - - val newTotalSize = newPartMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) - val newNumFiles = newPartMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) - // Except 0.12, all the following versions will fill the Hive-generated statistics - if (version == "0.12") { - assert(newTotalSize.isEmpty && newNumFiles.isEmpty) - } else { - assert(newTotalSize.nonEmpty && newNumFiles.nonEmpty) - } - } - } - - test(s"$version: Delete the temporary staging directory and files after each insert") { - withTempDir { tmpDir => - withTable("tab") { - versionSpark.sql( - s""" - |CREATE TABLE tab(c1 string) - |location '${tmpDir.toURI.toString}' - """.stripMargin) - - (1 to 3).map { i => - versionSpark.sql(s"INSERT OVERWRITE TABLE tab SELECT '$i'") - } - def listFiles(path: File): List[String] = { - val dir = path.listFiles() - val folders = dir.filter(_.isDirectory).toList - val filePaths = dir.map(_.getName).toList - folders.flatMap(listFiles) ++: filePaths - } - // expect 2 files left: `.part-00000-random-uuid.crc` and `part-00000-random-uuid` - // 0.12, 0.13, 1.0 and 1.1 also has another two more files ._SUCCESS.crc and _SUCCESS - val metadataFiles = Seq("._SUCCESS.crc", "_SUCCESS") - assert(listFiles(tmpDir).filterNot(metadataFiles.contains).length == 2) - } - } - } - - test(s"$version: SPARK-13709: reading partitioned Avro table with nested schema") { - withTempDir { dir => - val path = dir.toURI.toString - val tableName = "spark_13709" - val tempTableName = "spark_13709_temp" - - new File(dir.getAbsolutePath, tableName).mkdir() - new File(dir.getAbsolutePath, tempTableName).mkdir() - - val avroSchema = - """{ - | "name": 
"test_record", - | "type": "record", - | "fields": [ { - | "name": "f0", - | "type": "int" - | }, { - | "name": "f1", - | "type": { - | "type": "record", - | "name": "inner", - | "fields": [ { - | "name": "f10", - | "type": "int" - | }, { - | "name": "f11", - | "type": "double" - | } ] - | } - | } ] - |} - """.stripMargin - - withTable(tableName, tempTableName) { - // Creates the external partitioned Avro table to be tested. - versionSpark.sql( - s"""CREATE EXTERNAL TABLE $tableName - |PARTITIONED BY (ds STRING) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' - |LOCATION '$path/$tableName' - |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') - """.stripMargin - ) - - // Creates an temporary Avro table used to prepare testing Avro file. - versionSpark.sql( - s"""CREATE EXTERNAL TABLE $tempTableName - |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' - |LOCATION '$path/$tempTableName' - |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') - """.stripMargin - ) - - // Generates Avro data. - versionSpark.sql(s"INSERT OVERWRITE TABLE $tempTableName SELECT 1, STRUCT(2, 2.5)") - - // Adds generated Avro data as a new partition to the testing table. - versionSpark.sql( - s"ALTER TABLE $tableName ADD PARTITION (ds = 'foo') LOCATION '$path/$tempTableName'") - - // The following query fails before SPARK-13709 is fixed. This is because when reading - // data from table partitions, Avro deserializer needs the Avro schema, which is defined - // in table property "avro.schema.literal". However, we only initializes the deserializer - // using partition properties, which doesn't include the wanted property entry. Merging - // two sets of properties solves the problem. - assert(versionSpark.sql(s"SELECT * FROM $tableName").collect() === - Array(Row(1, Row(2, 2.5D), "foo"))) - } - } - } - - test(s"$version: CTAS for managed data source tables") { - withTable("t", "t1") { - versionSpark.range(1).write.saveAsTable("t") - assert(versionSpark.table("t").collect() === Array(Row(0))) - versionSpark.sql("create table t1 using parquet as select 2 as a") - assert(versionSpark.table("t1").collect() === Array(Row(2))) - } - } - - test(s"$version: Decimal support of Avro Hive serde") { - val tableName = "tab1" - // TODO: add the other logical types. 
For details, see the link: - // https://avro.apache.org/docs/1.8.1/spec.html#Logical+Types - val avroSchema = - """{ - | "name": "test_record", - | "type": "record", - | "fields": [ { - | "name": "f0", - | "type": [ - | "null", - | { - | "precision": 38, - | "scale": 2, - | "type": "bytes", - | "logicalType": "decimal" - | } - | ] - | } ] - |} - """.stripMargin - - Seq(true, false).foreach { isPartitioned => - withTable(tableName) { - val partitionClause = if (isPartitioned) "PARTITIONED BY (ds STRING)" else "" - // Creates the (non-)partitioned Avro table - versionSpark.sql( - s""" - |CREATE TABLE $tableName - |$partitionClause - |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' - |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') - """.stripMargin - ) - - val errorMsg = "Cannot safely cast 'f0': decimal(2,1) to binary" - - if (isPartitioned) { - val insertStmt = s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1.3" - if (version == "0.12" || version == "0.13") { - val e = intercept[AnalysisException](versionSpark.sql(insertStmt)).getMessage - assert(e.contains(errorMsg)) - } else { - versionSpark.sql(insertStmt) - assert(versionSpark.table(tableName).collect() === - versionSpark.sql("SELECT 1.30, 'a'").collect()) - } - } else { - val insertStmt = s"INSERT OVERWRITE TABLE $tableName SELECT 1.3" - if (version == "0.12" || version == "0.13") { - val e = intercept[AnalysisException](versionSpark.sql(insertStmt)).getMessage - assert(e.contains(errorMsg)) - } else { - versionSpark.sql(insertStmt) - assert(versionSpark.table(tableName).collect() === - versionSpark.sql("SELECT 1.30").collect()) - } - } - } - } - } - - test(s"$version: read avro file containing decimal") { - val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val location = new File(url.getFile).toURI.toString - - val tableName = "tab1" - val avroSchema = - """{ - | "name": "test_record", - | "type": "record", - | "fields": [ { - | "name": "f0", - | "type": [ - | "null", - | { - | "precision": 38, - | "scale": 2, - | "type": "bytes", - | "logicalType": "decimal" - | } - | ] - | } ] - |} - """.stripMargin - withTable(tableName) { - versionSpark.sql( - s""" - |CREATE TABLE $tableName - |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' - |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' - |LOCATION '$location' - |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') - """.stripMargin - ) - assert(versionSpark.table(tableName).collect() === - versionSpark.sql("SELECT 1.30").collect()) - } - } - - test(s"$version: SPARK-17920: Insert into/overwrite avro table") { - // skipped because it's failed in the condition on Windows - assume(!(Utils.isWindows && version == "0.12")) - withTempDir { dir => - val avroSchema = - """ - |{ - | "name": "test_record", - | "type": "record", - | "fields": [{ - | "name": "f0", - | "type": [ - | "null", - | { - | "precision": 38, - | "scale": 2, - | "type": "bytes", - | "logicalType": "decimal" - | } - | ] - | }] - |} - """.stripMargin - val schemaFile = new File(dir, "avroDecimal.avsc") - Utils.tryWithResource(new PrintWriter(schemaFile)) { writer => - writer.write(avroSchema) - } - val 
schemaPath = schemaFile.toURI.toString - - val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val srcLocation = new File(url.getFile).toURI.toString - val destTableName = "tab1" - val srcTableName = "tab2" - - withTable(srcTableName, destTableName) { - versionSpark.sql( - s""" - |CREATE EXTERNAL TABLE $srcTableName - |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' - |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' - |LOCATION '$srcLocation' - |TBLPROPERTIES ('avro.schema.url' = '$schemaPath') - """.stripMargin - ) - - versionSpark.sql( - s""" - |CREATE TABLE $destTableName - |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' - |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' - |TBLPROPERTIES ('avro.schema.url' = '$schemaPath') - """.stripMargin - ) - versionSpark.sql( - s"""INSERT OVERWRITE TABLE $destTableName SELECT * FROM $srcTableName""") - val result = versionSpark.table(srcTableName).collect() - assert(versionSpark.table(destTableName).collect() === result) - versionSpark.sql( - s"""INSERT INTO TABLE $destTableName SELECT * FROM $srcTableName""") - assert(versionSpark.table(destTableName).collect().toSeq === result ++ result) - } - } - } - // TODO: add more tests. - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 014feb33df5ea..c4cef44b6cc90 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.execution.datasources.parquet.ParquetFooterReader import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METASTORE_PARQUET} import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton, TestHiveSparkSession} @@ -844,7 +844,7 @@ class HiveDDLSuite assert( catalog.getTableMetadata(TableIdentifier(tabName)).tableType == CatalogTableType.MANAGED) // The table property is case sensitive. 
Thus, external is allowed - sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('external' = 'TRUE')") + sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('External' = 'TRUE')") // The table type is not changed to external assert( catalog.getTableMetadata(TableIdentifier(tabName)).tableType == CatalogTableType.MANAGED) @@ -2926,54 +2926,42 @@ class HiveDDLSuite } } - test("SPARK-33844: Insert overwrite directory should check schema too") { - withView("v") { - spark.range(1).createTempView("v") - withTempPath { path => - val e = intercept[AnalysisException] { - spark.sql(s"INSERT OVERWRITE LOCAL DIRECTORY '${path.getCanonicalPath}' " + - s"STORED AS PARQUET SELECT ID, if(1=1, 1, 0), abs(id), '^-' FROM v") - }.getMessage - assert(e.contains("Column name \"(IF((1 = 1), 1, 0))\" contains invalid character(s). " + - "Please use alias to rename it.")) + test("SPARK-33844, 37969: Insert overwrite directory should check schema too") { + withSQLConf(HiveUtils.CONVERT_METASTORE_INSERT_DIR.key -> "false") { + withView("v") { + spark.range(1).createTempView("v") + withTempPath { path => + Seq("PARQUET", "ORC").foreach { format => + val e = intercept[SparkException] { + spark.sql(s"INSERT OVERWRITE LOCAL DIRECTORY '${path.getCanonicalPath}' " + + s"STORED AS $format SELECT ID, if(1=1, 1, 0), abs(id), '^-' FROM v") + }.getCause.getMessage + assert(e.contains("Column name \"(IF((1 = 1), 1, 0))\" contains" + + " invalid character(s). Please use alias to rename it.")) + } + } } } } test("SPARK-36201: Add check for inner field of parquet/orc schema") { - withView("v") { - spark.range(1).createTempView("v") - withTempPath { path => - val e = intercept[AnalysisException] { - spark.sql( - s""" - |INSERT OVERWRITE LOCAL DIRECTORY '${path.getCanonicalPath}' - |STORED AS PARQUET - |SELECT - |NAMED_STRUCT('ID', ID, 'IF(ID=1,ID,0)', IF(ID=1,ID,0), 'B', ABS(ID)) AS col1 - |FROM v + withSQLConf(HiveUtils.CONVERT_METASTORE_INSERT_DIR.key -> "false") { + withView("v") { + spark.range(1).createTempView("v") + withTempPath { path => + val e = intercept[SparkException] { + spark.sql( + s""" + |INSERT OVERWRITE LOCAL DIRECTORY '${path.getCanonicalPath}' + |STORED AS PARQUET + |SELECT + |NAMED_STRUCT('ID', ID, 'IF(ID=1,ID,0)', IF(ID=1,ID,0), 'B', ABS(ID)) AS col1 + |FROM v """.stripMargin) - }.getMessage - assert(e.contains("Column name \"IF(ID=1,ID,0)\" contains" + - " invalid character(s). Please use alias to rename it.")) - } - } - } - - test("SPARK-36312: ParquetWriteSupport should check inner field") { - withView("v") { - spark.range(1).createTempView("v") - withTempPath { path => - val e = intercept[AnalysisException] { - spark.sql( - """ - |SELECT - |NAMED_STRUCT('ID', ID, 'IF(ID=1,ID,0)', IF(ID=1,ID,0), 'B', ABS(ID)) AS col1 - |FROM v - |""".stripMargin).write.mode(SaveMode.Overwrite).parquet(path.toString) - }.getMessage - assert(e.contains("Column name \"IF(ID=1,ID,0)\" contains" + - " invalid character(s). Please use alias to rename it.")) + }.getCause.getMessage + assert(e.contains("Column name \"IF(ID=1,ID,0)\" contains invalid character(s). 
" + + "Please use alias to rename it.")) + } } } } @@ -3059,4 +3047,10 @@ class HiveDDLSuite assert(df1.schema.names.toSeq == Seq("A", "B")) } } + + test("SPARK-38216: Fail early if all the columns are partitioned columns") { + assertAnalysisError( + "CREATE TABLE tab (c1 int) PARTITIONED BY (c1) STORED AS PARQUET", + "Cannot use all columns for partition columns") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index ee091e89379c4..e80c41401227d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -236,17 +236,19 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd createQueryTest("no from clause", "SELECT 1, +1, -1") - createQueryTest("boolean = number", - """ - |SELECT - | 1 = true, 1L = true, 1Y = true, true = 1, true = 1L, true = 1Y, - | 0 = true, 0L = true, 0Y = true, true = 0, true = 0L, true = 0Y, - | 1 = false, 1L = false, 1Y = false, false = 1, false = 1L, false = 1Y, - | 0 = false, 0L = false, 0Y = false, false = 0, false = 0L, false = 0Y, - | 2 = true, 2L = true, 2Y = true, true = 2, true = 2L, true = 2Y, - | 2 = false, 2L = false, 2Y = false, false = 2, false = 2L, false = 2Y - |FROM src LIMIT 1 + if (!conf.ansiEnabled) { + createQueryTest("boolean = number", + """ + |SELECT + | 1 = true, 1L = true, 1Y = true, true = 1, true = 1L, true = 1Y, + | 0 = true, 0L = true, 0Y = true, true = 0, true = 0L, true = 0Y, + | 1 = false, 1L = false, 1Y = false, false = 1, false = 1L, false = 1Y, + | 0 = false, 0L = false, 0Y = false, false = 0, false = 0L, false = 0Y, + | 2 = true, 2L = true, 2Y = true, true = 2, true = 2L, true = 2Y, + | 2 = false, 2L = false, 2Y = false, false = 2, false = 2L, false = 2Y + |FROM src LIMIT 1 """.stripMargin) + } test("CREATE TABLE AS runs once") { sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() @@ -282,11 +284,13 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd createQueryTest("Constant Folding Optimization for AVG_SUM_COUNT", "SELECT AVG(0), SUM(0), COUNT(null), COUNT(value) FROM src GROUP BY key") - createQueryTest("Cast Timestamp to Timestamp in UDF", - """ - | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) - | FROM src LIMIT 1 + if (!conf.ansiEnabled) { + createQueryTest("Cast Timestamp to Timestamp in UDF", + """ + | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) + | FROM src LIMIT 1 """.stripMargin) + } createQueryTest("Date comparison test 1", """ @@ -516,8 +520,10 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd createQueryTest("Specify the udtf output", "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t") - createQueryTest("SPARK-9034 Reflect field names defined in GenericUDTF #1", - "SELECT col FROM (SELECT explode(array(key,value)) FROM src LIMIT 1) t") + if (!conf.ansiEnabled) { + createQueryTest("SPARK-9034 Reflect field names defined in GenericUDTF #1", + "SELECT col FROM (SELECT explode(array(key,value)) FROM src LIMIT 1) t") + } createQueryTest("SPARK-9034 Reflect field names defined in GenericUDTF #2", "SELECT key,value FROM (SELECT explode(map(key,value)) FROM src LIMIT 1) t") @@ -768,9 +774,11 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd 
test("SPARK-5367: resolve star expression in udf") { assert(sql("select concat(*) from src limit 5").collect().size == 5) - assert(sql("select array(*) from src limit 5").collect().size == 5) assert(sql("select concat(key, *) from src limit 5").collect().size == 5) - assert(sql("select array(key, *) from src limit 5").collect().size == 5) + if (!conf.ansiEnabled) { + assert(sql("select array(*) from src limit 5").collect().size == 5) + assert(sql("select array(key, *) from src limit 5").collect().size == 5) + } } test("Exactly once semantics for DDL and command statements") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index 189a5c5768f61..cbf5e640db468 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, HiveTableRelation} import org.apache.spark.sql.execution.SQLViewSuite -import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.{NullType, StructType} import org.apache.spark.tags.SlowHiveTest @@ -181,4 +181,43 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { } } } + + test("hive partitioned view is not supported") { + withTable("test") { + withView("v1") { + sql( + s""" + |CREATE TABLE test (c1 INT, c2 STRING) + |PARTITIONED BY ( + | p1 BIGINT COMMENT 'bla', + | p2 STRING ) + """.stripMargin) + + createRawHiveTable( + s""" + |CREATE VIEW v1 + |PARTITIONED ON (p1, p2) + |AS SELECT * from test + """.stripMargin + ) + + val cause = intercept[AnalysisException] { + sql("SHOW CREATE TABLE v1") + } + + assert(cause.getMessage.contains(" - partitioned view")) + + val causeForSpark = intercept[AnalysisException] { + sql("SHOW CREATE TABLE v1 AS SERDE") + } + + assert(causeForSpark.getMessage.contains(" - partitioned view")) + } + } + } + + private def createRawHiveTable(ddl: String): Unit = { + hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.runSqlHive(ddl) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index d6185ac487d65..d54265e53c126 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -621,4 +621,21 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T assert(e.contains("java.lang.ArithmeticException: long overflow")) } } + + test("SPARK-38075: ORDER BY with LIMIT should not add fake rows") { + withTempView("v") { + val df = Seq((1), (2), (3)).toDF("a") + df.createTempView("v") + checkAnswer(sql( + """ + |SELECT TRANSFORM(a) + | USING 'cat' AS (a) + |FROM v + |ORDER BY a + |LIMIT 10 + |""".stripMargin), + identity, + Row("1") :: Row("2") :: Row("3") :: Nil) + } + } } diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 3f75454b8d8da..9b0d7d9f674d9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo} import org.apache.spark.sql.execution.ProjectExec import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.tags.SlowHiveTest /** @@ -27,13 +28,21 @@ import org.apache.spark.tags.SlowHiveTest */ @SlowHiveTest class HiveTypeCoercionSuite extends HiveComparisonTest { - val baseTypes = Seq( + val baseTypes = if (SQLConf.get.ansiEnabled) { + Seq( ("1", "1"), ("1.0", "CAST(1.0 AS DOUBLE)"), - ("1L", "1L"), ("1S", "1S"), - ("1Y", "1Y"), - ("'1'", "'1'")) + ("1Y", "1Y")) + } else { + Seq( + ("1", "1"), + ("1.0", "CAST(1.0 AS DOUBLE)"), + ("1L", "1L"), + ("1S", "1S"), + ("1Y", "1Y"), + ("'1'", "'1'")) + } baseTypes.foreach { case (ni, si) => baseTypes.foreach { case (nj, sj) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1829f38fe5775..f2711db839913 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.TestUncaughtExceptionHandler import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} -import org.apache.spark.sql.execution.command.LoadDataCommand +import org.apache.spark.sql.execution.command.{InsertIntoDataSourceDirCommand, LoadDataCommand} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} @@ -2212,44 +2212,10 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi } } - test("SPARK-21912 Parquet table should not create invalid column names") { - Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name => - val source = "PARQUET" - withTable("t21912") { - val m = intercept[AnalysisException] { - sql(s"CREATE TABLE t21912(`col$name` INT) USING $source") - }.getMessage - assert(m.contains(s"contains invalid character(s)")) - - val m1 = intercept[AnalysisException] { - sql(s"CREATE TABLE t21912 STORED AS $source AS SELECT 1 `col$name`") - }.getMessage - assert(m1.contains(s"contains invalid character(s)")) - - val m2 = intercept[AnalysisException] { - sql(s"CREATE TABLE t21912 USING $source AS SELECT 1 `col$name`") - }.getMessage - assert(m2.contains(s"contains invalid character(s)")) - - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { - val m3 = intercept[AnalysisException] { - sql(s"CREATE TABLE t21912(`col$name` INT) USING hive OPTIONS (fileFormat '$source')") - }.getMessage - assert(m3.contains(s"contains invalid character(s)")) - } - - sql(s"CREATE TABLE t21912(`col` INT) USING $source") - val m4 = 
intercept[AnalysisException] { - sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)") - }.getMessage - assert(m4.contains(s"contains invalid character(s)")) - } - } - } - test("SPARK-32889: ORC table column name supports special characters") { - // " " "," is not allowed. - Seq("$", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name => + // "," is not allowed since cannot create a table having a column whose name + // contains commas in Hive metastore. + Seq("$", ";", "{", "}", "(", ")", "\n", "\t", "=", " ", "a b").foreach { name => val source = "ORC" Seq(s"CREATE TABLE t32889(`$name` INT) USING $source", s"CREATE TABLE t32889 STORED AS $source AS SELECT 1 `$name`", @@ -2688,6 +2654,46 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi } } } + + test("SPARK-38215: Hive Insert Dir should use data source if it is convertible") { + withTempView("p") { + Seq(1, 2, 3).toDF("id").createOrReplaceTempView("p") + + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted") { + Seq(true, false).foreach { isConvertedCtas => + withSQLConf(HiveUtils.CONVERT_METASTORE_INSERT_DIR.key -> s"$isConvertedCtas") { + withTempDir { dir => + val df = sql( + s""" + |INSERT OVERWRITE LOCAL DIRECTORY '${dir.getAbsolutePath}' + |STORED AS $format + |SELECT 1 + """.stripMargin) + val insertIntoDSDir = df.queryExecution.analyzed.collect { + case _: InsertIntoDataSourceDirCommand => true + }.headOption + val insertIntoHiveDir = df.queryExecution.analyzed.collect { + case _: InsertIntoHiveDirCommand => true + }.headOption + if (isConverted && isConvertedCtas) { + assert(insertIntoDSDir.nonEmpty) + assert(insertIntoHiveDir.isEmpty) + } else { + assert(insertIntoDSDir.isEmpty) + assert(insertIntoHiveDir.nonEmpty) + } + } + } + } + } + } + } + } + } } @SlowHiveTest diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/DropNamespaceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/DropNamespaceSuite.scala index cabebb9e11510..955fe332cf1d0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/DropNamespaceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/DropNamespaceSuite.scala @@ -25,4 +25,5 @@ import org.apache.spark.sql.execution.command.v1 */ class DropNamespaceSuite extends v1.DropNamespaceSuiteBase with CommandSuiteBase { override def isCasePreserving: Boolean = false + override def commandVersion: String = super[DropNamespaceSuiteBase].commandVersion } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowCreateTableSuite.scala similarity index 63% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowCreateTableSuite.scala index e3a1034ad4f1d..a7d5e7b083488 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowCreateTableSuite.scala @@ -15,77 +15,30 @@ * limitations under the License. 
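The SPARK-38215 test above decides which write path was planned by collecting command nodes from the analyzed plan instead of inspecting output files. A minimal sketch of that plan-inspection technique, assuming a Hive-enabled SparkSession on a Spark version where the INSERT OVERWRITE DIRECTORY conversion exists:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.command.InsertIntoDataSourceDirCommand

    object InsertDirPlanSketch {
      // Returns true if the Hive-style insert-directory statement was converted to the
      // data source write path (InsertIntoDataSourceDirCommand) during analysis.
      def usesDataSourceDirWrite(spark: SparkSession, dir: String): Boolean = {
        val df = spark.sql(
          s"""INSERT OVERWRITE LOCAL DIRECTORY '$dir'
             |STORED AS PARQUET
             |SELECT 1
             |""".stripMargin)
        df.queryExecution.analyzed.collectFirst {
          case _: InsertIntoDataSourceDirCommand => true
        }.nonEmpty
      }
    }

Checking the analyzed plan keeps the assertion independent of file layout, which varies between the Hive SerDe and native writers.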
*/ -package org.apache.spark.sql.hive +package org.apache.spark.sql.hive.execution.command -import org.apache.spark.sql.{AnalysisException, ShowCreateTableSuite} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} - -class HiveShowCreateTableSuite extends ShowCreateTableSuite with TestHiveSingleton { - - private var origCreateHiveTableConfig = false - - protected override def beforeAll(): Unit = { - super.beforeAll() - origCreateHiveTableConfig = - spark.conf.get(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT) - spark.conf.set(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT.key, true) - } - - protected override def afterAll(): Unit = { - spark.conf.set( - SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT.key, - origCreateHiveTableConfig) - super.afterAll() - } - - test("view") { - Seq(true, false).foreach { serde => - withView("v1") { - sql("CREATE VIEW v1 AS SELECT 1 AS a") - checkCreateView("v1", serde) - } - } - } - - test("view with output columns") { - Seq(true, false).foreach { serde => - withView("v1") { - sql("CREATE VIEW v1 (a, b COMMENT 'b column') AS SELECT 1 AS a, 2 AS b") - checkCreateView("v1", serde) - } - } - } - - test("view with table comment and properties") { - Seq(true, false).foreach { serde => - withView("v1") { - sql( - s""" - |CREATE VIEW v1 ( - | c1 COMMENT 'bla', - | c2 - |) - |COMMENT 'table comment' - |TBLPROPERTIES ( - | 'prop1' = 'value1', - | 'prop2' = 'value2' - |) - |AS SELECT 1 AS c1, '2' AS c2 - """.stripMargin - ) +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.util.escapeSingleQuotedString +import org.apache.spark.sql.execution.command.v1 +import org.apache.spark.sql.internal.HiveSerDe + +/** + * The class contains tests for the `SHOW CREATE TABLE` command to check V1 Hive external + * table catalog. 
+ */ +class ShowCreateTableSuite extends v1.ShowCreateTableSuiteBase with CommandSuiteBase { + override def commandVersion: String = super[ShowCreateTableSuiteBase].commandVersion - checkCreateView("v1", serde) - } - } + override def getShowCreateDDL(table: String, serde: Boolean = false): Array[String] = { + super.getShowCreateDDL(table, serde).filter(!_.startsWith("'transient_lastDdlTime'")) } test("simple hive table") { - withTable("t1") { + withNamespaceAndTable(ns, table) { t => sql( - s"""CREATE TABLE t1 ( + s"""CREATE TABLE $t ( | c1 INT COMMENT 'bla', | c2 STRING |) @@ -95,16 +48,21 @@ class HiveShowCreateTableSuite extends ShowCreateTableSuite with TestHiveSinglet |) """.stripMargin ) - - checkCreateTable("t1", serde = true) + val expected = s"CREATE TABLE $fullName ( c1 INT COMMENT 'bla', c2 STRING)" + + " ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'" + + " WITH SERDEPROPERTIES ( 'serialization.format' = '1')" + + " STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'" + + " OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'" + + " TBLPROPERTIES ( 'prop1' = 'value1', 'prop2' = 'value2'," + assert(getShowCreateDDL(t, true).mkString(" ") == expected) } } test("simple external hive table") { withTempDir { dir => - withTable("t1") { + withNamespaceAndTable(ns, table) { t => sql( - s"""CREATE TABLE t1 ( + s"""CREATE TABLE $t ( | c1 INT COMMENT 'bla', | c2 STRING |) @@ -115,16 +73,23 @@ class HiveShowCreateTableSuite extends ShowCreateTableSuite with TestHiveSinglet |) """.stripMargin ) - - checkCreateTable("t1", serde = true) + val expected = s"CREATE EXTERNAL TABLE $fullName ( c1 INT COMMENT 'bla', c2 STRING)" + + s" ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'" + + s" WITH SERDEPROPERTIES ( 'serialization.format' = '1')" + + s" STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'" + + s" OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'" + + s" LOCATION" + + s" '${escapeSingleQuotedString(CatalogUtils.URIToString(dir.toURI)).dropRight(1)}'" + + s" TBLPROPERTIES ( 'prop1' = 'value1', 'prop2' = 'value2'," + assert(getShowCreateDDL(t, true).mkString(" ") == expected) } } } test("partitioned hive table") { - withTable("t1") { + withNamespaceAndTable(ns, table) { t => sql( - s"""CREATE TABLE t1 ( + s"""CREATE TABLE $t ( | c1 INT COMMENT 'bla', | c2 STRING |) @@ -135,15 +100,21 @@ class HiveShowCreateTableSuite extends ShowCreateTableSuite with TestHiveSinglet |) """.stripMargin ) - - checkCreateTable("t1", serde = true) + val expected = s"CREATE TABLE $fullName ( c1 INT COMMENT 'bla', c2 STRING)" + + " COMMENT 'bla' PARTITIONED BY (p1 BIGINT COMMENT 'bla', p2 STRING)" + + " ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'" + + " WITH SERDEPROPERTIES ( 'serialization.format' = '1')" + + " STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'" + + " OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'" + + " TBLPROPERTIES (" + assert(getShowCreateDDL(t, true).mkString(" ") == expected) } } test("hive table with explicit storage info") { - withTable("t1") { + withNamespaceAndTable(ns, table) { t => sql( - s"""CREATE TABLE t1 ( + s"""CREATE TABLE $t ( | c1 INT COMMENT 'bla', | c2 STRING |) @@ -153,30 +124,44 @@ class HiveShowCreateTableSuite extends ShowCreateTableSuite with TestHiveSinglet |NULL DEFINED AS 'NaN' """.stripMargin ) - - checkCreateTable("t1", serde = true) + val expected = s"CREATE TABLE $fullName ( c1 INT 
COMMENT 'bla', c2 STRING)" + + " ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'" + + " WITH SERDEPROPERTIES (" + + " 'colelction.delim' = '@'," + + " 'mapkey.delim' = '#'," + + " 'serialization.format' = ','," + + " 'field.delim' = ',')" + + " STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'" + + " OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'" + + " TBLPROPERTIES (" + assert(getShowCreateDDL(t, true).mkString(" ") == expected) } } test("hive table with STORED AS clause") { - withTable("t1") { + withNamespaceAndTable(ns, table) { t => sql( - s"""CREATE TABLE t1 ( + s"""CREATE TABLE $t ( | c1 INT COMMENT 'bla', | c2 STRING |) |STORED AS PARQUET """.stripMargin ) - - checkCreateTable("t1", serde = true) + val expected = s"CREATE TABLE $fullName ( c1 INT COMMENT 'bla', c2 STRING)" + + " ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'" + + " WITH SERDEPROPERTIES ( 'serialization.format' = '1')" + + " STORED AS INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'" + + " OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'" + + " TBLPROPERTIES (" + assert(getShowCreateDDL(t, true).mkString(" ") == expected) } } test("hive table with serde info") { - withTable("t1") { + withNamespaceAndTable(ns, table) { t => sql( - s"""CREATE TABLE t1 ( + s"""CREATE TABLE $t ( | c1 INT COMMENT 'bla', | c2 STRING |) @@ -190,75 +175,39 @@ class HiveShowCreateTableSuite extends ShowCreateTableSuite with TestHiveSinglet | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin ) - - checkCreateTable("t1", serde = true) + val expected = s"CREATE TABLE $fullName ( c1 INT COMMENT 'bla', c2 STRING)" + + " ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'" + + " WITH SERDEPROPERTIES (" + + " 'mapkey.delim' = ','," + + " 'serialization.format' = '1'," + + " 'field.delim' = ',')" + + " STORED AS INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'" + + " OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'" + + " TBLPROPERTIES (" + assert(getShowCreateDDL(t, true).mkString(" ") == expected) } } test("hive bucketing is supported") { - withTable("t1") { + withNamespaceAndTable(ns, table) { t => sql( - s"""CREATE TABLE t1 (a INT, b STRING) + s"""CREATE TABLE $t (a INT, b STRING) |CLUSTERED BY (a) |SORTED BY (b) |INTO 2 BUCKETS """.stripMargin ) - checkCreateTable("t1", serde = true) - } - } - - test("hive partitioned view is not supported") { - withTable("t1") { - withView("v1") { - sql( - s""" - |CREATE TABLE t1 (c1 INT, c2 STRING) - |PARTITIONED BY ( - | p1 BIGINT COMMENT 'bla', - | p2 STRING ) - """.stripMargin) - - createRawHiveTable( - s""" - |CREATE VIEW v1 - |PARTITIONED ON (p1, p2) - |AS SELECT * from t1 - """.stripMargin - ) - - val cause = intercept[AnalysisException] { - sql("SHOW CREATE TABLE v1") - } - - assert(cause.getMessage.contains(" - partitioned view")) - - val causeForSpark = intercept[AnalysisException] { - sql("SHOW CREATE TABLE v1 AS SERDE") - } - - assert(causeForSpark.getMessage.contains(" - partitioned view")) - } + val expected = s"CREATE TABLE $fullName ( a INT, b STRING)" + + " CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS" + + " ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'" + + " WITH SERDEPROPERTIES ( 'serialization.format' = '1')" + + " STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'" + + " 
OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'" + + " TBLPROPERTIES (" + assert(getShowCreateDDL(t, true).mkString(" ") == expected) } } - test("SPARK-24911: keep quotes for nested fields in hive") { - withTable("t1") { - val createTable = "CREATE TABLE `t1` (`a` STRUCT<`b`: STRING>) USING hive" - sql(createTable) - val shownDDL = getShowDDL("SHOW CREATE TABLE t1") - assert(shownDDL.substring(0, shownDDL.indexOf(" USING")) == - "CREATE TABLE `default`.`t1` ( `a` STRUCT<`b`: STRING>)") - - checkCreateTable("t1", serde = true) - } - } - - private def createRawHiveTable(ddl: String): Unit = { - hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] - .client.runSqlHive(ddl) - } - private def checkCreateSparkTableAsHive(tableName: String): Unit = { val table = TableIdentifier(tableName, Some("default")) val db = table.database.get @@ -339,26 +288,6 @@ class HiveShowCreateTableSuite extends ShowCreateTableSuite with TestHiveSinglet } } - test("show create table as serde can't work on data source table") { - withTable("t1") { - sql( - s""" - |CREATE TABLE t1 ( - | c1 STRING COMMENT 'bla', - | c2 STRING - |) - |USING orc - """.stripMargin - ) - - val cause = intercept[AnalysisException] { - checkCreateTable("t1", serde = true) - } - - assert(cause.getMessage.contains("Use `SHOW CREATE TABLE` without `AS SERDE` instead")) - } - } - test("simple external hive table in Spark DDL") { withTempDir { dir => withTable("t1") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index 990b34cda33a3..61a9360684166 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -90,6 +90,10 @@ object OrcReadBenchmark extends SqlBasedBenchmark { prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").noop() + } + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(id) FROM nativeOrcTable").noop() @@ -100,10 +104,6 @@ object OrcReadBenchmark extends SqlBasedBenchmark { spark.sql("SELECT sum(id) FROM nativeOrcTable").noop() } - benchmark.addCase("Hive built-in ORC") { _ => - spark.sql("SELECT sum(id) FROM hiveOrcTable").noop() - } - benchmark.run() } } @@ -121,6 +121,10 @@ object OrcReadBenchmark extends SqlBasedBenchmark { dir, spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").noop() + } + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").noop() @@ -131,10 +135,6 @@ object OrcReadBenchmark extends SqlBasedBenchmark { spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").noop() } - benchmark.addCase("Hive built-in ORC") { _ => - spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").noop() - } - benchmark.run() } } @@ -150,6 +150,10 @@ object OrcReadBenchmark extends SqlBasedBenchmark { prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + benchmark.addCase("Data column - Hive built-in ORC") { _ => + 
spark.sql("SELECT sum(id) FROM hiveOrcTable").noop() + } + benchmark.addCase("Data column - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(id) FROM nativeOrcTable").noop() @@ -160,8 +164,8 @@ object OrcReadBenchmark extends SqlBasedBenchmark { spark.sql("SELECT sum(id) FROM nativeOrcTable").noop() } - benchmark.addCase("Data column - Hive built-in ORC") { _ => - spark.sql("SELECT sum(id) FROM hiveOrcTable").noop() + benchmark.addCase("Partition column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p) FROM hiveOrcTable").noop() } benchmark.addCase("Partition column - Native ORC MR") { _ => @@ -174,8 +178,8 @@ object OrcReadBenchmark extends SqlBasedBenchmark { spark.sql("SELECT sum(p) FROM nativeOrcTable").noop() } - benchmark.addCase("Partition column - Hive built-in ORC") { _ => - spark.sql("SELECT sum(p) FROM hiveOrcTable").noop() + benchmark.addCase("Both columns - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").noop() } benchmark.addCase("Both columns - Native ORC MR") { _ => @@ -188,10 +192,6 @@ object OrcReadBenchmark extends SqlBasedBenchmark { spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").noop() } - benchmark.addCase("Both columns - Hive built-in ORC") { _ => - spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").noop() - } - benchmark.run() } } @@ -206,6 +206,10 @@ object OrcReadBenchmark extends SqlBasedBenchmark { prepareTable(dir, spark.sql("SELECT CAST((id % 200) + 10000 as STRING) AS c1 FROM t1")) + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").noop() + } + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").noop() @@ -216,10 +220,6 @@ object OrcReadBenchmark extends SqlBasedBenchmark { spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").noop() } - benchmark.addCase("Hive built-in ORC") { _ => - spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").noop() - } - benchmark.run() } } @@ -240,6 +240,11 @@ object OrcReadBenchmark extends SqlBasedBenchmark { val benchmark = new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output) + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").noop() + } + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + @@ -252,11 +257,6 @@ object OrcReadBenchmark extends SqlBasedBenchmark { "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").noop() } - benchmark.addCase("Hive built-in ORC") { _ => - spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + - "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").noop() - } - benchmark.run() } } @@ -275,6 +275,10 @@ object OrcReadBenchmark extends SqlBasedBenchmark { prepareTable(dir, spark.sql("SELECT * FROM t1")) + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").noop() + } + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").noop() @@ -285,10 +289,6 @@ object OrcReadBenchmark extends SqlBasedBenchmark { spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").noop() } - benchmark.addCase("Hive built-in ORC") { _ => - spark.sql(s"SELECT 
sum(c$middle) FROM hiveOrcTable").noop() - } - benchmark.run() } } @@ -307,6 +307,10 @@ object OrcReadBenchmark extends SqlBasedBenchmark { prepareTable(dir, spark.sql("SELECT * FROM t1")) + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql(s"SELECT * FROM hiveOrcTable").noop() + } + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql(s"SELECT * FROM nativeOrcTable").noop() @@ -319,10 +323,6 @@ object OrcReadBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("Hive built-in ORC") { _ => - spark.sql(s"SELECT * FROM hiveOrcTable").noop() - } - benchmark.run() } } @@ -346,6 +346,10 @@ object OrcReadBenchmark extends SqlBasedBenchmark { prepareTable(dir, spark.sql("SELECT * FROM t1")) + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql(s"SELECT * FROM hiveOrcTable").noop() + } + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql(s"SELECT * FROM nativeOrcTable").noop() @@ -358,10 +362,6 @@ object OrcReadBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("Hive built-in ORC") { _ => - spark.sql(s"SELECT * FROM hiveOrcTable").noop() - } - benchmark.run() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index 2e6b86206a631..18e8401ee3d2b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -107,21 +107,6 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } } - test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { - withTempPath { dir => - intercept[AnalysisException] { - // Parquet doesn't allow field names with spaces. Here we are intentionally making an - // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger - // the bug. Please refer to spark-8079 for more details. 
- spark.range(1, 10) - .withColumnRenamed("id", "a b") - .write - .format("parquet") - .save(dir.getCanonicalPath) - } - } - } - test("SPARK-8604: Parquet data source should write summary file while doing appending") { withSQLConf( ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 8008a5c495e9d..282946dd8ef4b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -204,10 +204,12 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // If manual clock is being used for testing, then // either set the manual clock to the last checkpointed time, // or if the property is defined set it to that time - if (clock.isInstanceOf[ManualClock]) { - val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds - val jumpTime = ssc.sc.conf.get(StreamingConf.MANUAL_CLOCK_JUMP) - clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) + clock match { + case manualClock: ManualClock => + val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds + val jumpTime = ssc.sc.conf.get(StreamingConf.MANUAL_CLOCK_JUMP) + manualClock.setTime(lastTime + jumpTime) + case _ => // do nothing } val batchDuration = ssc.graph.batchDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 42d0e50a068ec..2c8e51e19d3e3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -294,7 +294,7 @@ private[ui] class StreamingPage(parent: StreamingTab) {if (hasStream) { - {generateInputDStreamsTable(jsCollector, minBatchTime, maxBatchTime, minRecordRate, maxRecordRate)} + {generateInputDStreamsTable(jsCollector, minBatchTime, maxBatchTime, minRecordRate)} }} @@ -340,8 +340,7 @@ private[ui] class StreamingPage(parent: StreamingTab) jsCollector: JsCollector, minX: Long, maxX: Long, - minY: Double, - maxY: Double): Seq[Node] = { + minY: Double): Seq[Node] = { val maxYCalculated = listener.receivedRecordRateWithBatchTime.values .flatMap { case streamAndRates => streamAndRates.map { case (_, recordRate) => recordRate } } .reduceOption[Double](math.max) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index e207dab7de068..3263f12a4e1ea 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -47,11 +47,11 @@ object RawTextHelper { i += 1 } } - map.toIterator.map { + map.iterator.map { case (k, v) => (k, v) } } - map.toIterator.map{case (k, v) => (k, v)} + map.iterator.map{case (k, v) => (k, v)} } /** @@ -89,7 +89,7 @@ object RawTextHelper { } } } - taken.toIterator + taken.iterator } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 4224cef1cbae1..8069e7915b1d1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ 
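The JobGenerator change above (and the StateMap change just below) replaces an isInstanceOf/asInstanceOf pair with a single pattern match, which removes the unchecked cast and makes additional cases easy to add. A generic, self-contained sketch of the refactor; ManualTestClock is an illustrative stand-in for Spark's package-private ManualClock:

    object ClockPatternMatchSketch {
      trait Clock { def currentTime(): Long }

      final class ManualTestClock(private var now: Long) extends Clock {
        override def currentTime(): Long = now
        def setTime(t: Long): Unit = { now = t }
      }

      // Before: if (clock.isInstanceOf[ManualTestClock]) clock.asInstanceOf[ManualTestClock].setTime(...)
      // After: one match expression, no cast.
      def restoreCheckpointTime(clock: Clock, lastTime: Long, jumpTime: Long): Unit = clock match {
        case manual: ManualTestClock => manual.setTime(lastTime + jumpTime)
        case _ => // real clocks advance on their own; nothing to restore
      }
    }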
-296,16 +296,17 @@ private[streaming] class OpenHashMapBasedStateMap[K, S]( var parentSessionLoopDone = false while(!parentSessionLoopDone) { val obj = inputStream.readObject() - if (obj.isInstanceOf[LimitMarker]) { - parentSessionLoopDone = true - val expectedCount = obj.asInstanceOf[LimitMarker].num - assert(expectedCount == newParentSessionStore.deltaMap.size) - } else { - val key = obj.asInstanceOf[K] - val state = inputStream.readObject().asInstanceOf[S] - val updateTime = inputStream.readLong() - newParentSessionStore.deltaMap.update( - key, StateInfo(state, updateTime, deleted = false)) + obj match { + case marker: LimitMarker => + parentSessionLoopDone = true + val expectedCount = marker.num + assert(expectedCount == newParentSessionStore.deltaMap.size) + case _ => + val key = obj.asInstanceOf[K] + val state = inputStream.readObject().asInstanceOf[S] + val updateTime = inputStream.readLong() + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) } } parentStateMap = newParentSessionStore diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index 8a57b0c58b228..41c4bf9e711d5 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -75,7 +75,6 @@ public void testInitialization() { Assert.assertNotNull(ssc.sparkContext()); } - @SuppressWarnings("unchecked") @Test public void testContextState() { List> inputData = Arrays.asList(Arrays.asList(1, 2, 3, 4)); @@ -89,7 +88,6 @@ public void testContextState() { Assert.assertEquals(StreamingContextState.STOPPED, ssc.getState()); } - @SuppressWarnings("unchecked") @Test public void testCount() { List> inputData = Arrays.asList( @@ -109,7 +107,6 @@ public void testCount() { assertOrderInvariantEquals(expected, result); } - @SuppressWarnings("unchecked") @Test public void testMap() { List> inputData = Arrays.asList( @@ -128,7 +125,6 @@ public void testMap() { assertOrderInvariantEquals(expected, result); } - @SuppressWarnings("unchecked") @Test public void testWindow() { List> inputData = Arrays.asList( @@ -150,7 +146,6 @@ public void testWindow() { assertOrderInvariantEquals(expected, result); } - @SuppressWarnings("unchecked") @Test public void testWindowWithSlideDuration() { List> inputData = Arrays.asList( @@ -175,7 +170,6 @@ public void testWindowWithSlideDuration() { assertOrderInvariantEquals(expected, result); } - @SuppressWarnings("unchecked") @Test public void testFilter() { List> inputData = Arrays.asList( @@ -194,7 +188,6 @@ public void testFilter() { assertOrderInvariantEquals(expected, result); } - @SuppressWarnings("unchecked") @Test public void testRepartitionMorePartitions() { List> inputData = Arrays.asList( @@ -214,7 +207,6 @@ public void testRepartitionMorePartitions() { } } - @SuppressWarnings("unchecked") @Test public void testRepartitionFewerPartitions() { List> inputData = Arrays.asList( @@ -233,7 +225,6 @@ public void testRepartitionFewerPartitions() { } } - @SuppressWarnings("unchecked") @Test public void testGlom() { List> inputData = Arrays.asList( @@ -252,7 +243,6 @@ public void testGlom() { Assert.assertEquals(expected, result); } - @SuppressWarnings("unchecked") @Test public void testMapPartitions() { List> inputData = Arrays.asList( @@ -291,7 +281,6 @@ public Integer call(Integer i1, Integer i2) { } } - @SuppressWarnings("unchecked") @Test public void 
testReduce() {
     List> inputData = Arrays.asList(
@@ -312,19 +301,16 @@ public void testReduce() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testReduceByWindowWithInverse() {
     testReduceByWindow(true);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testReduceByWindowWithoutInverse() {
     testReduceByWindow(false);
   }
 
-  @SuppressWarnings("unchecked")
   private void testReduceByWindow(boolean withInverse) {
     List> inputData = Arrays.asList(
       Arrays.asList(1,2,3),
@@ -354,7 +340,6 @@ private void testReduceByWindow(boolean withInverse) {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testQueueStream() {
     ssc.stop();
@@ -386,7 +371,6 @@ public void testQueueStream() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testTransform() {
     List> inputData = Arrays.asList(
@@ -408,7 +392,6 @@ public void testTransform() {
     assertOrderInvariantEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testVariousTransform() {
     // tests whether all variations of transform can be called from Java
@@ -495,7 +478,6 @@ public void testTransformWith() {
   }
 
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testVariousTransformWith() {
     // tests whether all variations of transformWith can be called from Java
@@ -593,7 +575,6 @@ public void testStreamingContextTransform(){
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testFlatMap() {
     List> inputData = Arrays.asList(
@@ -615,7 +596,6 @@ public void testFlatMap() {
     assertOrderInvariantEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testForeachRDD() {
     final LongAccumulator accumRdd = ssc.sparkContext().sc().longAccumulator();
@@ -641,7 +621,6 @@ public void testForeachRDD() {
     Assert.assertEquals(6, accumEle.value().intValue());
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testPairFlatMap() {
     List> inputData = Arrays.asList(
@@ -690,7 +669,6 @@ public void testPairFlatMap() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testUnion() {
     List> inputData1 = Arrays.asList(
@@ -737,7 +715,6 @@ public static void assertOrderInvariantEquals(
 
   // PairDStream Functions
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testPairFilter() {
     List> inputData = Arrays.asList(
@@ -759,7 +736,6 @@ public void testPairFilter() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   private final List>> stringStringKVStream = Arrays.asList(
       Arrays.asList(new Tuple2<>("california", "dodgers"),
                     new Tuple2<>("california", "giants"),
@@ -770,7 +746,6 @@ public void testPairFilter() {
                     new Tuple2<>("new york", "rangers"),
                     new Tuple2<>("new york", "islanders")));
 
-  @SuppressWarnings("unchecked")
   private final List>> stringIntKVStream = Arrays.asList(
       Arrays.asList(
           new Tuple2<>("california", 1),
@@ -783,7 +758,6 @@ public void testPairFilter() {
           new Tuple2<>("new york", 3),
           new Tuple2<>("new york", 1)));
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testPairMap() { // Maps pair -> pair of different type
     List>> inputData = stringIntKVStream;
@@ -811,7 +785,6 @@ public void testPairMap() { // Maps pair -> pair of different type
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testPairMapPartitions() { // Maps pair -> pair of different type
     List>> inputData = stringIntKVStream;
@@ -846,7 +819,6 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testPairMap2() { // Maps pair -> single
     List>> inputData = stringIntKVStream;
@@ -866,7 +838,6 @@ public void testPairMap2() { // Maps pair -> single
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair
     List>> inputData = Arrays.asList(
@@ -905,7 +876,6 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testPairGroupByKey() {
     List>> inputData = stringStringKVStream;
@@ -942,7 +912,6 @@ public void testPairGroupByKey() {
     }
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testPairReduceByKey() {
     List>> inputData = stringIntKVStream;
@@ -967,7 +936,6 @@ public void testPairReduceByKey() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testCombineByKey() {
     List>> inputData = stringIntKVStream;
@@ -993,7 +961,6 @@ public void testCombineByKey() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testCountByValue() {
     List> inputData = Arrays.asList(
@@ -1019,7 +986,6 @@ public void testCountByValue() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testGroupByKeyAndWindow() {
     List>> inputData = stringIntKVStream;
@@ -1067,7 +1033,6 @@ private static Tuple2> convert(Tuple2(tuple._1(), new HashSet<>(tuple._2()));
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testReduceByKeyAndWindow() {
     List>> inputData = stringIntKVStream;
@@ -1092,7 +1057,6 @@ public void testReduceByKeyAndWindow() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testUpdateStateByKey() {
     List>> inputData = stringIntKVStream;
@@ -1125,7 +1089,6 @@ public void testUpdateStateByKey() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testUpdateStateByKeyWithInitial() {
     List>> inputData = stringIntKVStream;
@@ -1165,7 +1128,6 @@ public void testUpdateStateByKeyWithInitial() {
     assertOrderInvariantEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testReduceByKeyAndWindowWithInverse() {
     List>> inputData = stringIntKVStream;
@@ -1225,7 +1187,6 @@ public void testCountByValueAndWindow() {
     Assert.assertEquals(expected, unorderedResult);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testPairTransform() {
     List>> inputData = Arrays.asList(
@@ -1264,7 +1225,6 @@ public void testPairTransform() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testPairToNormalRDDTransform() {
     List>> inputData = Arrays.asList(
@@ -1295,7 +1255,6 @@ public void testPairToNormalRDDTransform() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testMapValues() {
     List>> inputData = stringStringKVStream;
@@ -1323,7 +1282,6 @@ public void testMapValues() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testFlatMapValues() {
     List>> inputData = stringStringKVStream;
@@ -1364,7 +1322,6 @@ public void testFlatMapValues() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testCoGroup() {
     List>> stringStringKVStream1 = Arrays.asList(
@@ -1430,7 +1387,6 @@ public void testCoGroup() {
     }
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testJoin() {
     List>> stringStringKVStream1 = Arrays.asList(
@@ -1474,7 +1430,6 @@ public void testJoin() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testLeftOuterJoin() {
     List>> stringStringKVStream1 = Arrays.asList(
@@ -1507,7 +1462,6 @@ public void testLeftOuterJoin() {
     Assert.assertEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testCheckpointMasterRecovery() throws InterruptedException {
     List> inputData = Arrays.asList(
@@ -1543,7 +1497,6 @@ public void testCheckpointMasterRecovery() throws InterruptedException {
     Utils.deleteRecursively(tempDir);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testContextGetOrCreate() throws InterruptedException {
     ssc.stop();
@@ -1648,7 +1601,6 @@ public void testSocketString() {
         StorageLevel.MEMORY_ONLY());
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testTextFileStream() throws IOException {
     File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
@@ -1661,7 +1613,6 @@ public void testTextFileStream() throws IOException {
     assertOrderInvariantEquals(expected, result);
   }
 
-  @SuppressWarnings("unchecked")
   @Test
   public void testFileStream() throws IOException {
     File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
index 0576bf560f30e..dad324b53dd04 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
@@ -90,7 +90,7 @@ class DStreamClosureSuite extends SparkFunSuite with LocalStreamingContext with
     ds.filter { _ => return; true }
   }
   private def testMapPartitions(ds: DStream[Int]): Unit = expectCorrectException {
-    ds.mapPartitions { _ => return; Seq.empty.toIterator }
+    ds.mapPartitions { _ => return; Seq.empty.iterator }
   }
   private def testReduce(ds: DStream[Int]): Unit = expectCorrectException {
     ds.reduce { case (_, _) => return; 1 }
@@ -153,7 +153,7 @@ class DStreamClosureSuite extends SparkFunSuite with LocalStreamingContext with
   }
   private def testUpdateStateByKey(ds: DStream[(Int, Int)]): Unit = {
     val updateF1 = (_: Seq[Int], _: Option[Int]) => { return; Some(1) }
-    val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; Seq((1, 1)).toIterator }
+    val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; Seq((1, 1)).iterator }
     val updateF3 = (_: Time, _: Int, _: Seq[Int], _: Option[Int]) => {
       return
       Option(1)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 03182ae64db3d..174c3ca379363 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -365,7 +365,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
 
       // Setup data queued into the stream
       val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
-      val inputIterator = input.toIterator
+      val inputIterator = input.iterator
       for (i <- input.indices) {
         // Enqueue more than 1 item per tick but they should dequeue one at a time
         inputIterator.take(2).foreach { i =>
@@ -411,7 +411,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
       val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
 
       // Enqueue the first 3 items (one by one), they should be merged in the next batch
-      val inputIterator = input.toIterator
+      val inputIterator = input.iterator
       inputIterator.take(3).foreach { i =>
         queue.synchronized {
           queue += ssc.sparkContext.makeRDD(Seq(i))
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 3bcea1ab2c680..a3b5b38904a2e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -363,7 +363,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
 
     val blocks = data.grouped(10).toSeq
 
-    storeAndVerify(blocks.map { b => IteratorBlock(b.toIterator) })
+    storeAndVerify(blocks.map { b => IteratorBlock(b.iterator) })
     storeAndVerify(blocks.map { b => ArrayBufferBlock(new ArrayBuffer ++= b) })
     storeAndVerify(blocks.map { b => ByteBufferBlock(dataToByteBuffer(b).toByteBuffer) })
   }
@@ -372,7 +372,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
 
   private def testErrorHandling(receivedBlockHandler: ReceivedBlockHandler): Unit = {
     // Handle error in iterator (e.g. divide-by-zero error)
     intercept[Exception] {
-      val iterator = (10 to (-10, -1)).toIterator.map { _ / 0 }
+      val iterator = (10 to (-10, -1)).iterator.map { _ / 0 }
       receivedBlockHandler.storeBlock(StreamBlockId(1, 1), IteratorBlock(iterator))
     }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index b54d60aa29c4f..08121a38dc5d5 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -21,7 +21,9 @@ import java.util.concurrent.ConcurrentLinkedQueue
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.HashMap
+// scalastyle:off executioncontextglobal
 import scala.concurrent.ExecutionContext.Implicits.global
+// scalastyle:on executioncontextglobal
 import scala.concurrent.Future
 
 import org.mockito.Mockito.{mock, reset, verifyNoMoreInteractions}
diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
index a6fee8616df11..ef28095850bad 100644
--- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
@@ -69,7 +69,7 @@ object GenerateMIMAIgnore {
     /* Inner classes defined within a private[spark] class or object are effectively
      invisible, so we account for them as package private. */
     lazy val indirectlyPrivateSpark = {
-      val maybeOuter = className.toString.takeWhile(_ != '$')
+      val maybeOuter = className.takeWhile(_ != '$')
       if (maybeOuter != className) {
         isPackagePrivate(mirror.classSymbol(Class.forName(maybeOuter, false, classLoader))) ||
           isPackagePrivateModule(mirror.staticModule(maybeOuter))