diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 553a961109ab0..e51c896cc3f99 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -276,7 +276,7 @@ jobs: - name: Install Python packages (Python 3.11) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') run: | - python3.11 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'protobuf==5.28.3' + python3.11 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.28.3' python3.11 -m pip list # Run the tests. - name: Run tests @@ -702,13 +702,6 @@ jobs: run: ./dev/lint-java - name: Spark connect jvm client mima check run: ./dev/connect-jvm-client-mima-check - - name: Install Python linter dependencies for branch-3.4 - if: inputs.branch == 'branch-3.4' - run: | - # SPARK-44554: Copy from https://github.com/apache/spark/blob/a05c27e85829fe742c1828507a1fd180cdc84b54/.github/workflows/build_and_test.yml#L571-L578 - # Should delete this section after SPARK 3.4 EOL. - python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' - python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.48.1' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' - name: Install Python linter dependencies for branch-3.5 if: inputs.branch == 'branch-3.5' run: | @@ -717,7 +710,7 @@ jobs: python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.982' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.56.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' - name: Install Python dependencies for python linter and documentation generation - if: inputs.branch != 'branch-3.4' && inputs.branch != 'branch-3.5' + if: inputs.branch != 'branch-3.5' run: | # Should unpin 'sphinxcontrib-*' after upgrading sphinx>5 # See 'ipython_genutils' in SPARK-38517 @@ -725,7 +718,7 @@ jobs: python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ ipython ipython_genutils sphinx_plotly_directive numpy pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ - 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ + 'pandas-stubs==1.2.0.53' 'grpcio==1.67.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' python3.9 -m pip list - name: Python linter @@ -745,16 +738,16 @@ jobs: if: inputs.branch == 'branch-3.5' run: if test -f ./dev/connect-check-protos.py; then PATH=$PATH:$HOME/buf/bin PYTHON_EXECUTABLE=python3.9 ./dev/connect-check-protos.py; fi # Should delete this section after SPARK 3.5 EOL. 
- - name: Install JavaScript linter dependencies for branch-3.4, branch-3.5 - if: inputs.branch == 'branch-3.4' || inputs.branch == 'branch-3.5' + - name: Install JavaScript linter dependencies for branch-3.5 + if: inputs.branch == 'branch-3.5' run: | apt update apt-get install -y nodejs npm - name: JS linter run: ./dev/lint-js # Should delete this section after SPARK 3.5 EOL. - - name: Install R linter dependencies for branch-3.4, branch-3.5 - if: inputs.branch == 'branch-3.4' || inputs.branch == 'branch-3.5' + - name: Install R linter dependencies for branch-3.5 + if: inputs.branch == 'branch-3.5' run: | apt update apt-get install -y libcurl4-openssl-dev libgit2-dev libssl-dev libxml2-dev \ @@ -834,7 +827,7 @@ jobs: distribution: zulu java-version: ${{ inputs.java }} - name: Install Python dependencies for python linter and documentation generation - if: inputs.branch != 'branch-3.4' && inputs.branch != 'branch-3.5' + if: inputs.branch != 'branch-3.5' run: | # Should unpin 'sphinxcontrib-*' after upgrading sphinx>5 # See 'ipython_genutils' in SPARK-38517 @@ -845,8 +838,8 @@ jobs: 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' python3.9 -m pip list - - name: Install dependencies for documentation generation for branch-3.4, branch-3.5 - if: inputs.branch == 'branch-3.4' || inputs.branch == 'branch-3.5' + - name: Install dependencies for documentation generation for branch-3.5 + if: inputs.branch == 'branch-3.5' run: | # pandoc is required to generate PySpark APIs as well in nbsphinx. apt-get update -y @@ -1134,7 +1127,7 @@ jobs: export PVC_TESTS_VM_PATH=$PVC_TMP_DIR minikube mount ${PVC_TESTS_HOST_PATH}:${PVC_TESTS_VM_PATH} --gid=0 --uid=185 & kubectl create clusterrolebinding serviceaccounts-cluster-admin --clusterrole=cluster-admin --group=system:serviceaccounts || true - if [[ "${{ inputs.branch }}" == 'branch-3.5' || "${{ inputs.branch }}" == 'branch-3.4' ]]; then + if [[ "${{ inputs.branch }}" == 'branch-3.5' ]]; then kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.7.0/installer/volcano-development.yaml || true else kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.9.0/installer/volcano-development.yaml || true diff --git a/.github/workflows/build_branch34.yml b/.github/workflows/build_branch34.yml deleted file mode 100644 index deb6c42407970..0000000000000 --- a/.github/workflows/build_branch34.yml +++ /dev/null @@ -1,51 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# - -name: "Build (branch-3.4, Scala 2.13, Hadoop 3, JDK 8)" - -on: - schedule: - - cron: '0 9 * * *' - -jobs: - run-build: - permissions: - packages: write - name: Run - uses: ./.github/workflows/build_and_test.yml - if: github.repository == 'apache/spark' - with: - java: 8 - branch: branch-3.4 - hadoop: hadoop3 - envs: >- - { - "SCALA_PROFILE": "scala2.13", - "PYTHON_TO_TEST": "", - "ORACLE_DOCKER_IMAGE_NAME": "gvenzl/oracle-xe:21.3.0" - } - jobs: >- - { - "build": "true", - "sparkr": "true", - "tpcds-1g": "true", - "docker-integration-tests": "true", - "k8s-integration-tests": "true", - "lint" : "true" - } diff --git a/.github/workflows/build_branch34_python.yml b/.github/workflows/build_branch34_python.yml deleted file mode 100644 index c109ba2dc7922..0000000000000 --- a/.github/workflows/build_branch34_python.yml +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -name: "Build / Python-only (branch-3.4)" - -on: - schedule: - - cron: '0 9 * * *' - -jobs: - run-build: - permissions: - packages: write - name: Run - uses: ./.github/workflows/build_and_test.yml - if: github.repository == 'apache/spark' - with: - java: 8 - branch: branch-3.4 - hadoop: hadoop3 - envs: >- - { - "PYTHON_TO_TEST": "" - } - jobs: >- - { - "pyspark": "true", - "pyspark-pandas": "true" - } diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml index 22153fe2f980c..6965fb4968af3 100644 --- a/.github/workflows/maven_test.yml +++ b/.github/workflows/maven_test.yml @@ -178,7 +178,7 @@ jobs: - name: Install Python packages (Python 3.11) if: (contains(matrix.modules, 'sql#core')) || contains(matrix.modules, 'connect') run: | - python3.11 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'protobuf==5.28.3' + python3.11 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.28.3' python3.11 -m pip list # Run the tests. 
- name: Run tests diff --git a/.github/workflows/publish_snapshot.yml b/.github/workflows/publish_snapshot.yml index 1b5bd0ba61288..a5854d96a4d1a 100644 --- a/.github/workflows/publish_snapshot.yml +++ b/.github/workflows/publish_snapshot.yml @@ -28,7 +28,7 @@ on: description: 'list of branches to publish (JSON)' required: true # keep in sync with default value of strategy matrix 'branch' - default: '["master", "branch-3.5", "branch-3.4"]' + default: '["master", "branch-3.5"]' jobs: publish-snapshot: @@ -38,7 +38,7 @@ jobs: fail-fast: false matrix: # keep in sync with default value of workflow_dispatch input 'branch' - branch: ${{ fromJSON( inputs.branch || '["master", "branch-3.5", "branch-3.4"]' ) }} + branch: ${{ fromJSON( inputs.branch || '["master", "branch-3.5"]' ) }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -52,13 +52,13 @@ jobs: restore-keys: | snapshot-maven- - name: Install Java 8 for branch-3.x - if: matrix.branch == 'branch-3.5' || matrix.branch == 'branch-3.4' + if: matrix.branch == 'branch-3.5' uses: actions/setup-java@v4 with: distribution: temurin java-version: 8 - name: Install Java 17 - if: matrix.branch != 'branch-3.5' && matrix.branch != 'branch-3.4' + if: matrix.branch != 'branch-3.5' uses: actions/setup-java@v4 with: distribution: temurin diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index d67697eaea38b..97c8bbe562aff 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -1080,6 +1080,24 @@ public static UTF8String translate(final UTF8String input, return UTF8String.fromString(sb.toString()); } + /** + * Trims the `srcString` string from both ends of the string using the specified `trimString` + * characters, with respect to the UTF8_BINARY trim collation. String trimming is performed by + * first trimming the left side of the string, and then trimming the right side of the string. + * The method returns the trimmed string. If the `trimString` is null, the method returns null. + * + * @param srcString the input string to be trimmed from both ends of the string + * @param trimString the trim string characters to trim + * @param collationId the collation ID to use for string trim + * @return the trimmed string (for UTF8_BINARY collation) + */ + public static UTF8String binaryTrim( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + return binaryTrimRight(srcString.trimLeft(trimString), trimString, collationId); + } + /** * Trims the `srcString` string from both ends of the string using the specified `trimString` * characters, with respect to the UTF8_LCASE collation. 
String trimming is performed by @@ -1088,12 +1106,14 @@ public static UTF8String translate(final UTF8String input, * * @param srcString the input string to be trimmed from both ends of the string * @param trimString the trim string characters to trim + * @param collationId the collation ID to use for string trim * @return the trimmed string (for UTF8_LCASE collation) */ public static UTF8String lowercaseTrim( final UTF8String srcString, - final UTF8String trimString) { - return lowercaseTrimRight(lowercaseTrimLeft(srcString, trimString), trimString); + final UTF8String trimString, + final int collationId) { + return lowercaseTrimRight(lowercaseTrimLeft(srcString, trimString), trimString, collationId); } /** @@ -1121,7 +1141,8 @@ public static UTF8String trim( * the left side, until reaching a character whose lowercased code point is not in the hash set. * Finally, the method returns the substring from that position to the end of `srcString`. * If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned. - * + * Note: as currently only trimming collation supported is RTRIM, trimLeft is not modified + * to support other trim collations, this should be done in case of adding TRIM and LTRIM. * @param srcString the input string to be trimmed from the left end of the string * @param trimString the trim string characters to trim * @return the trimmed string (for UTF8_LCASE collation) @@ -1184,7 +1205,9 @@ public static UTF8String lowercaseTrimLeft( * character in `trimString`, until reaching a character that is not found in `trimString`. * Finally, the method returns the substring from that position to the end of `srcString`. * If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned. - * + * Note: as currently only trimming collation supported is RTRIM, trimLeft is not modified + * to support other trim collations, this should be done in case of adding TRIM and LTRIM + * collation. * @param srcString the input string to be trimmed from the left end of the string * @param trimString the trim string characters to trim * @param collationId the collation ID to use for string trimming @@ -1232,22 +1255,103 @@ public static UTF8String trimLeft( // Return the substring from the calculated position until the end of the string. return UTF8String.fromString(src.substring(charIndex)); } + /** + * Trims the `srcString` string from the right side using the specified `trimString` characters, + * with respect to the UTF8_BINARY trim collation. For UTF8_BINARY trim collation, the method has + * one special case to cover with respect to trimRight function for regular UTF8_Binary collation. + * Trailing spaces should be ignored in case of trim collation (rtrim for example) and if + * trimString does not contain spaces. In this case, the method trims the string from the right + * and after call of regular trim functions returns back trimmed spaces as those should not get + * removed. + * @param srcString the input string to be trimmed from the right end of the string + * @param trimString the trim string characters to trim + * @param collationId the collation ID to use for string trim + * @return the trimmed string (for UTF_BINARY collation) + */ + public static UTF8String binaryTrimRight( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + // Matching the default UTF8String behavior for null `trimString`. 
+ if (trimString == null) { + return null; + } + + // Create a hash set of code points for all characters of `trimString`. + HashSet trimChars = new HashSet<>(); + Iterator trimIter = trimString.codePointIterator(); + while (trimIter.hasNext()) trimChars.add(trimIter.next()); + + // Iterate over `srcString` from the right to find the first character that is not in the set. + int searchIndex = srcString.numChars(), codePoint, codePointBuffer = -1; + + // In cases of trim collation (rtrim for example) trailing spaces should be ignored. + // If trimString contains spaces this behaviour is not important as they would get trimmed + // anyway. However, if it is not the case they should be ignored and then appended after + // trimming other characters. + int lastNonSpaceByteIdx = srcString.numBytes(), lastNonSpaceCharacterIdx = srcString.numChars(); + if (!trimChars.contains(SpecialCodePointConstants.ASCII_SPACE) && + CollationFactory.ignoresSpacesInTrimFunctions( + collationId, /*isLTrim=*/ false, /*isRTrim=*/true)) { + while (lastNonSpaceByteIdx > 0 && + srcString.getByte(lastNonSpaceByteIdx - 1) == ' ') { + --lastNonSpaceByteIdx; + } + // In case of src string contains only spaces there is no need to do any trimming, since it's + // already checked that trim string does not contain any spaces. + if (lastNonSpaceByteIdx == 0) { + return srcString; + } + searchIndex = lastNonSpaceCharacterIdx = + srcString.numChars() - (srcString.numBytes() - lastNonSpaceByteIdx); + } + Iterator srcIter = srcString.reverseCodePointIterator(); + for (int i = lastNonSpaceCharacterIdx; i < srcString.numChars(); i++) { + srcIter.next(); + } + + while (srcIter.hasNext()) { + codePoint = srcIter.next(); + if (trimChars.contains(codePoint)) { + --searchIndex; + } + else { + break; + } + } + + // Return the substring from the start of the string to the calculated position and append + // trailing spaces if they were ignored + if (searchIndex == srcString.numChars()) { + return srcString; + } + if (lastNonSpaceCharacterIdx == srcString.numChars()) { + return srcString.substring(0, searchIndex); + } + return UTF8String.concat( + srcString.substring(0, searchIndex), + srcString.substring(lastNonSpaceCharacterIdx, srcString.numChars())); + } /** * Trims the `srcString` string from the right side using the specified `trimString` characters, * with respect to the UTF8_LCASE collation. For UTF8_LCASE, the method first creates a hash * set of lowercased code points in `trimString`, and then iterates over the `srcString` from * the right side, until reaching a character whose lowercased code point is not in the hash set. + * In case of UTF8_LCASE trim collation and when trimString does not contain spaces, trailing + * spaces should be ignored. However, after trimming function call they should be appended back. * Finally, the method returns the substring from the start of `srcString` until that position. * If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned. * * @param srcString the input string to be trimmed from the right end of the string * @param trimString the trim string characters to trim + * @param collationId the collation ID to use for string trim * @return the trimmed string (for UTF8_LCASE collation) */ public static UTF8String lowercaseTrimRight( final UTF8String srcString, - final UTF8String trimString) { + final UTF8String trimString, + final int collationId) { // Matching the default UTF8String behavior for null `trimString`. 
if (trimString == null) { return null; @@ -1260,7 +1364,32 @@ public static UTF8String lowercaseTrimRight( // Iterate over `srcString` from the right to find the first character that is not in the set. int searchIndex = srcString.numChars(), codePoint, codePointBuffer = -1; + + // In cases of trim collation (rtrim for example) trailing spaces should be ignored. + // If trimString contains spaces this behaviour is not important as they would get trimmed + // anyway. However, if it is not the case they should be ignored and then appended after + // trimming other characters. + int lastNonSpaceByteIdx = srcString.numBytes(), lastNonSpaceCharacterIdx = srcString.numChars(); + if (!trimChars.contains(SpecialCodePointConstants.ASCII_SPACE) && + CollationFactory.ignoresSpacesInTrimFunctions( + collationId, /*isLTrim=*/ false, /*isRTrim=*/true)) { + while (lastNonSpaceByteIdx > 0 && + srcString.getByte(lastNonSpaceByteIdx - 1) == ' ') { + --lastNonSpaceByteIdx; + } + // In case of src string contains only spaces there is no need to do any trimming, since it's + // already checked that trim string does not contain any spaces. + if (lastNonSpaceByteIdx == 0) { + return srcString; + } + searchIndex = lastNonSpaceCharacterIdx = + srcString.numChars() - (srcString.numBytes() - lastNonSpaceByteIdx); + } Iterator srcIter = srcString.reverseCodePointIterator(); + for (int i = lastNonSpaceCharacterIdx; i < srcString.numChars(); i++) { + srcIter.next(); + } + while (srcIter.hasNext()) { if (codePointBuffer != -1) { codePoint = codePointBuffer; @@ -1291,8 +1420,17 @@ public static UTF8String lowercaseTrimRight( } } - // Return the substring from the start of the string to the calculated position. - return searchIndex == srcString.numChars() ? srcString : srcString.substring(0, searchIndex); + // Return the substring from the start of the string to the calculated position and append + // trailing spaces if they were ignored + if (searchIndex == srcString.numChars()) { + return srcString; + } + if (lastNonSpaceCharacterIdx == srcString.numChars()) { + return srcString.substring(0, searchIndex); + } + return UTF8String.concat( + srcString.substring(0, searchIndex), + srcString.substring(lastNonSpaceCharacterIdx, srcString.numChars())); } /** @@ -1329,7 +1467,26 @@ public static UTF8String trimRight( String src = srcString.toValidString(); CharacterIterator target = new StringCharacterIterator(src); Collator collator = CollationFactory.fetchCollation(collationId).collator; - int charIndex = src.length(), longestMatchLen; + int charIndex = src.length(), longestMatchLen, lastNonSpacePosition = src.length(); + + // In cases of trim collation (rtrim for example) trailing spaces should be ignored. + // If trimString contains spaces this behaviour is not important as they would get trimmed + // anyway. However, if it is not the case they should be ignored and then appended after + // trimming other characters. + if (!trimChars.containsKey(SpecialCodePointConstants.ASCII_SPACE) && + CollationFactory.ignoresSpacesInTrimFunctions( + collationId, /*isLTrim=*/ false, /*isRTrim=*/true)) { + while (lastNonSpacePosition > 0 && src.charAt(lastNonSpacePosition - 1) == ' ') { + --lastNonSpacePosition; + } + // In case of src string contains only spaces there is no need to do any trimming, since it's + // already checked that trim string does not contain any spaces. 
+ if (lastNonSpacePosition == 0) { + return UTF8String.fromString(src); + } + charIndex = lastNonSpacePosition; + } + while (charIndex >= 0) { longestMatchLen = 0; for (String trim : trimChars.values()) { @@ -1357,8 +1514,18 @@ public static UTF8String trimRight( else charIndex -= longestMatchLen; } - // Return the substring from the start of the string until that position. - return UTF8String.fromString(src.substring(0, charIndex)); + // Return the substring from the start of the string until that position and append + // trailing spaces if they were ignored + if (charIndex == src.length()) { + return srcString; + } + if (lastNonSpacePosition == srcString.numChars()) { + return UTF8String.fromString(src.substring(0, charIndex)); + } + return UTF8String.fromString( + src.substring(0, charIndex) + + src.substring(lastNonSpacePosition) + ); } public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim, diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 1305d82bcd785..3117854a432b1 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -359,6 +359,23 @@ protected static UTF8String applyTrimmingPolicy(UTF8String s, int collationId) { return applyTrimmingPolicy(s, getSpaceTrimming(collationId)); } + /** + * Returns if leading/trailing spaces should be ignored in trim string expressions. This is + * needed because space trimming collation directly changes behaviour of trim functions. + */ + protected static boolean ignoresSpacesInTrimFunctions( + int collationId, + boolean isLTrim, + boolean isRTrim) { + if (isRTrim && getSpaceTrimming(collationId) == SpaceTrimming.RTRIM) { + return true; + } + + // In case of adding new trimming collations in the future (LTRIM and TRIM) here logic + // should be added. + return false; + } + /** * Utility function to trim spaces when collation uses space trimming. */ @@ -1200,6 +1217,24 @@ public static String[] getICULocaleNames() { return Collation.CollationSpecICU.ICULocaleNames; } + /** + * Applies trimming policy depending up on trim collation type. + */ + public static UTF8String applyTrimmingPolicy(UTF8String input, int collationId) { + return Collation.CollationSpec.applyTrimmingPolicy(input, collationId); + } + + /** + * Returns if leading/trailing spaces should be ignored in trim string expressions. This is needed + * because space trimming collation directly changes behaviour of trim functions. 
+ */ + public static boolean ignoresSpacesInTrimFunctions( + int collationId, + boolean isLTrim, + boolean isRTrim) { + return Collation.CollationSpec.ignoresSpacesInTrimFunctions(collationId, isLTrim, isRTrim); + } + public static UTF8String getCollationKey(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); if (collation.supportsSpaceTrimming) { diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 978b663cc25c9..135250e482b16 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -35,8 +35,11 @@ public final class CollationSupport { */ public static class StringSplitSQL { - public static UTF8String[] exec(final UTF8String s, final UTF8String d, final int collationId) { + public static UTF8String[] exec(final UTF8String s, UTF8String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsSpaceTrimming) { + d = CollationFactory.applyTrimmingPolicy(d, collationId); + } if (collation.isUtf8BinaryType) { return execBinary(s, d); } else if (collation.isUtf8LcaseType) { @@ -46,14 +49,11 @@ public static UTF8String[] exec(final UTF8String s, final UTF8String d, final in } } public static String genCode(final String s, final String d, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringSplitSQL.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", s, d); - } else if (collation.isUtf8LcaseType) { - return String.format(expr + "Lowercase(%s, %s)", s, d); } else { - return String.format(expr + "ICU(%s, %s, %d)", s, d, collationId); + return String.format(expr + "(%s, %s, %d)", s, d, collationId); } } public static UTF8String[] execBinary(final UTF8String string, final UTF8String delimiter) { @@ -69,8 +69,12 @@ public static UTF8String[] execICU(final UTF8String string, final UTF8String del } public static class Contains { - public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { + public static boolean exec(UTF8String l, UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsSpaceTrimming) { + l = CollationFactory.applyTrimmingPolicy(l, collationId); + r = CollationFactory.applyTrimmingPolicy(r, collationId); + } if (collation.isUtf8BinaryType) { return execBinary(l, r); } else if (collation.isUtf8LcaseType) { @@ -80,14 +84,11 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col } } public static String genCode(final String l, final String r, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Contains.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.isUtf8LcaseType) { - return String.format(expr + "Lowercase(%s, %s)", l, r); } else { - return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + return String.format(expr + "(%s, 
%s, %d)", l, r, collationId); } } public static boolean execBinary(final UTF8String l, final UTF8String r) { @@ -106,9 +107,14 @@ public static boolean execICU(final UTF8String l, final UTF8String r, } public static class StartsWith { - public static boolean exec(final UTF8String l, final UTF8String r, + public static boolean exec(UTF8String l, UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsSpaceTrimming) { + l = CollationFactory.applyTrimmingPolicy(l, collationId); + r = CollationFactory.applyTrimmingPolicy(r, collationId); + } + if (collation.isUtf8BinaryType) { return execBinary(l, r); } else if (collation.isUtf8LcaseType) { @@ -118,14 +124,11 @@ public static boolean exec(final UTF8String l, final UTF8String r, } } public static String genCode(final String l, final String r, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StartsWith.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.isUtf8LcaseType) { - return String.format(expr + "Lowercase(%s, %s)", l, r); } else { - return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + return String.format(expr + "(%s, %s, %d)", l, r, collationId); } } public static boolean execBinary(final UTF8String l, final UTF8String r) { @@ -144,8 +147,12 @@ public static boolean execICU(final UTF8String l, final UTF8String r, } public static class EndsWith { - public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { + public static boolean exec(UTF8String l, UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsSpaceTrimming) { + l = CollationFactory.applyTrimmingPolicy(l, collationId); + r = CollationFactory.applyTrimmingPolicy(r, collationId); + } if (collation.isUtf8BinaryType) { return execBinary(l, r); } else if (collation.isUtf8LcaseType) { @@ -155,14 +162,11 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col } } public static String genCode(final String l, final String r, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.EndsWith.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.isUtf8LcaseType) { - return String.format(expr + "Lowercase(%s, %s)", l, r); } else { - return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + return String.format(expr + "(%s, %s, %d)", l, r, collationId); } } public static boolean execBinary(final UTF8String l, final UTF8String r) { @@ -184,6 +188,7 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class Upper { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + // Space trimming does not affect the output of this expression. if (collation.isUtf8BinaryType) { return useICU ? 
execBinaryICU(v) : execBinary(v); } else if (collation.isUtf8LcaseType) { @@ -221,6 +226,7 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class Lower { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + // Space trimming does not affect the output of this expression. if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); } else if (collation.isUtf8LcaseType) { @@ -258,6 +264,7 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class InitCap { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + // Space trimming does not affect the output of this expression. if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); } else if (collation.isUtf8LcaseType) { @@ -295,17 +302,16 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class FindInSet { public static int exec(final UTF8String word, final UTF8String set, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.isUtf8BinaryType) { + // FindInSet does space trimming collation as comparison is space trimming collation aware + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return execBinary(word, set); } else { return execCollationAware(word, set, collationId); } } public static String genCode(final String word, final String set, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.FindInSet.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", word, set); } else { return String.format(expr + "CollationAware(%s, %s, %d)", word, set, collationId); @@ -321,9 +327,12 @@ public static int execCollationAware(final UTF8String word, final UTF8String set } public static class StringInstr { - public static int exec(final UTF8String string, final UTF8String substring, + public static int exec(final UTF8String string, UTF8String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsSpaceTrimming) { + substring = CollationFactory.applyTrimmingPolicy(substring, collationId); + } if (collation.isUtf8BinaryType) { return execBinary(string, substring); } else if (collation.isUtf8LcaseType) { @@ -334,14 +343,11 @@ public static int exec(final UTF8String string, final UTF8String substring, } public static String genCode(final String string, final String substring, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringInstr.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", string, substring); - } else if (collation.isUtf8LcaseType) { - return String.format(expr + "Lowercase(%s, %s)", string, substring); } else { - return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId); + return String.format(expr + "(%s, %s, %d)", string, substring, 
collationId); } } public static int execBinary(final UTF8String string, final UTF8String substring) { @@ -360,6 +366,7 @@ public static class StringReplace { public static UTF8String exec(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + // Space trimming does not affect the output of this expression. if (collation.isUtf8BinaryType) { return execBinary(src, search, replace); } else if (collation.isUtf8LcaseType) { @@ -395,9 +402,12 @@ public static UTF8String execICU(final UTF8String src, final UTF8String search, } public static class StringLocate { - public static int exec(final UTF8String string, final UTF8String substring, final int start, + public static int exec(final UTF8String string, UTF8String substring, final int start, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsSpaceTrimming) { + substring = CollationFactory.applyTrimmingPolicy(substring, collationId); + } if (collation.isUtf8BinaryType) { return execBinary(string, substring, start); } else if (collation.isUtf8LcaseType) { @@ -408,14 +418,11 @@ public static int exec(final UTF8String string, final UTF8String substring, fina } public static String genCode(final String string, final String substring, final int start, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringLocate.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); - } else if (collation.isUtf8LcaseType) { - return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start); } else { - return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); + return String.format(expr + "(%s, %s, %d, %d)", string, substring, start, collationId); } } public static int execBinary(final UTF8String string, final UTF8String substring, @@ -433,9 +440,12 @@ public static int execICU(final UTF8String string, final UTF8String substring, f } public static class SubstringIndex { - public static UTF8String exec(final UTF8String string, final UTF8String delimiter, + public static UTF8String exec(final UTF8String string, UTF8String delimiter, final int count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsSpaceTrimming) { + delimiter = CollationFactory.applyTrimmingPolicy(delimiter, collationId); + } if (collation.isUtf8BinaryType) { return execBinary(string, delimiter, count); } else if (collation.isUtf8LcaseType) { @@ -446,14 +456,11 @@ public static UTF8String exec(final UTF8String string, final UTF8String delimite } public static String genCode(final String string, final String delimiter, final String count, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.SubstringIndex.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s, %s)", string, delimiter, count); - } else if (collation.isUtf8LcaseType) { - return String.format(expr + "Lowercase(%s, %s, %s)", string, delimiter, count); } else { - return String.format(expr + "ICU(%s, 
%s, %s, %d)", string, delimiter, count, collationId); + return String.format(expr + "(%s, %s, %s, %d)", string, delimiter, count, collationId); } } public static UTF8String execBinary(final UTF8String string, final UTF8String delimiter, @@ -474,6 +481,7 @@ public static class StringTranslate { public static UTF8String exec(final UTF8String source, Map dict, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + // Space trimming does not affect the output of this expression. if (collation.isUtf8BinaryType) { return execBinary(source, dict); } else if (collation.isUtf8LcaseType) { @@ -503,10 +511,15 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.isUtf8BinaryType) { + if (collation.isUtf8BinaryType && !collation.supportsSpaceTrimming) { return execBinary(srcString, trimString); + } + + if (collation.isUtf8BinaryType) { + // special handling needed for utf8_binary_rtrim collation. + return execBinaryTrim(srcString, trimString, collationId); } else if (collation.isUtf8LcaseType) { - return execLowercase(srcString, trimString); + return execLowercase(srcString, trimString, collationId); } else { return execICU(srcString, trimString, collationId); } @@ -518,14 +531,11 @@ public static String genCode( final String srcString, final String trimString, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrim.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.isUtf8LcaseType) { - return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { - return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); + return String.format(expr + "(%s, %s, %d)", srcString, trimString, collationId); } } public static UTF8String execBinary( @@ -539,8 +549,9 @@ public static UTF8String execBinary( } public static UTF8String execLowercase( final UTF8String srcString, - final UTF8String trimString) { - return CollationAwareUTF8String.lowercaseTrim(srcString, trimString); + final UTF8String trimString, + final int collationId) { + return CollationAwareUTF8String.lowercaseTrim(srcString, trimString, collationId); } public static UTF8String execICU( final UTF8String srcString, @@ -548,6 +559,12 @@ public static UTF8String execICU( final int collationId) { return CollationAwareUTF8String.trim(srcString, trimString, collationId); } + public static UTF8String execBinaryTrim( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + return CollationAwareUTF8String.binaryTrim(srcString, trimString, collationId); + } } public static class StringTrimLeft { @@ -555,10 +572,12 @@ public static UTF8String exec(final UTF8String srcString) { return execBinary(srcString); } public static UTF8String exec( - final UTF8String srcString, - final UTF8String trimString, - final int collationId) { + final UTF8String srcString, + UTF8String trimString, + final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + // Space trimming does not affect the output of this expression as currently only supported + // space trimming is RTRIM. 
if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); } else if (collation.isUtf8LcaseType) { @@ -574,14 +593,11 @@ public static String genCode( final String srcString, final String trimString, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimLeft.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.isUtf8LcaseType) { - return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { - return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); + return String.format(expr + "(%s, %s, %d)", srcString, trimString, collationId); } } public static UTF8String execBinary(final UTF8String srcString) { @@ -614,10 +630,15 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.isUtf8BinaryType) { + if (collation.isUtf8BinaryType && !collation.supportsSpaceTrimming) { return execBinary(srcString, trimString); + } + + if (collation.isUtf8BinaryType) { + // special handling needed for utf8_binary_rtrim collation. + return execBinaryTrim(srcString, trimString, collationId); } else if (collation.isUtf8LcaseType) { - return execLowercase(srcString, trimString); + return execLowercase(srcString, trimString, collationId); } else { return execICU(srcString, trimString, collationId); } @@ -629,14 +650,11 @@ public static String genCode( final String srcString, final String trimString, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimRight.exec"; - if (collation.isUtf8BinaryType) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.isUtf8LcaseType) { - return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { - return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); + return String.format(expr + "(%s, %s, %d)", srcString, trimString, collationId); } } public static UTF8String execBinary(final UTF8String srcString) { @@ -649,8 +667,9 @@ public static UTF8String execBinary( } public static UTF8String execLowercase( final UTF8String srcString, - final UTF8String trimString) { - return CollationAwareUTF8String.lowercaseTrimRight(srcString, trimString); + final UTF8String trimString, + final int collationId) { + return CollationAwareUTF8String.lowercaseTrimRight(srcString, trimString, collationId); } public static UTF8String execICU( final UTF8String srcString, @@ -658,6 +677,12 @@ public static UTF8String execICU( final int collationId) { return CollationAwareUTF8String.trimRight(srcString, trimString, collationId); } + public static UTF8String execBinaryTrim( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + return CollationAwareUTF8String.binaryTrimRight(srcString, trimString, collationId); + } } // TODO: Add more collation-aware string expressions. 
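For reference, a minimal sketch of how the RTRIM-aware trim entry points added in CollationSupport behave for a UTF8_BINARY_RTRIM collation. This is illustrative only and not part of the patch; it assumes `CollationFactory.collationNameToId` resolves the "UTF8_BINARY_RTRIM" name, and the expected result follows the trailing-space handling described in `binaryTrimRight` above.

```scala
import org.apache.spark.sql.catalyst.util.{CollationFactory, CollationSupport}
import org.apache.spark.unsafe.types.UTF8String

// Assumption: collationNameToId resolves the RTRIM variant of UTF8_BINARY.
val rtrimId = CollationFactory.collationNameToId("UTF8_BINARY_RTRIM")

val src  = UTF8String.fromString("abcxx  ")  // two trailing spaces
val trim = UTF8String.fromString("x")

// trimString contains no space, so the trailing spaces are skipped, the 'x'
// characters before them are trimmed, and the ignored spaces are appended back.
val trimmed = CollationSupport.StringTrimRight.exec(src, trim, rtrimId)
assert(trimmed == UTF8String.fromString("abc  "))
```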
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 50a3cc6049ea4..caf8461b0b5d6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -141,6 +141,7 @@ private enum IsFullAscii { private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); public static final UTF8String ZERO_UTF8 = UTF8String.fromString("0"); + public static final UTF8String SPACE_UTF8 = UTF8String.fromString(" "); /** diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 6daaf2a4c6759..1b16432e63786 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -169,7 +169,13 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA ", true), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AaA ", true), CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false), - CollationTestCase("UNICODE_RTRIM", " ", " ", true) + CollationTestCase("UNICODE_RTRIM", " ", " ", true), + CollationTestCase("SR_CI", "cČć", "CčĆ", true), + CollationTestCase("SR_CI", "cCc", "CčĆ", false), + CollationTestCase("SR_CI_AI", "cCc", "CčĆ", true), + CollationTestCase("sr_Cyrl_CI", "цЧћ", "ЦчЋ", true), + CollationTestCase("sr_Cyrl_CI", "цЦц", "ЦчЋ", false), + CollationTestCase("sr_Cyrl_CI_AI", "цЦц", "ЦчЋ", false) ) checks.foreach(testCase => { @@ -229,7 +235,9 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa", 1), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1), - CollationTestCase("UNICODE_CI_RTRIM", " ", " ", 0) + CollationTestCase("UNICODE_CI_RTRIM", " ", " ", 0), + CollationTestCase("SR_CI_AI", "cČć", "ČćC", 0), + CollationTestCase("SR_CI", "cČć", "ČćC", -1) ) checks.foreach(testCase => { @@ -248,7 +256,10 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI", "abcde", "abcde", 5), CollationTestCase("UNICODE_CI", "abcde", "ABCDE", 5), CollationTestCase("UNICODE_CI", "abcde", "fgh", 0), - CollationTestCase("UNICODE_CI", "abcde", "FGH", 0) + CollationTestCase("UNICODE_CI", "abcde", "FGH", 0), + CollationTestCase("SR_CI_AI", "abcčċ", "CCC", 3), + CollationTestCase("SR_CI", "abcčċ", "C", 1), + CollationTestCase("SR", "abcčċ", "CCC", 0) ) checks.foreach(testCase => { @@ -285,7 +296,9 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig "UNICODE_CI", "UNICODE_AI", "UNICODE_CI_AI", - "UNICODE_AI_CI" + "UNICODE_AI_CI", + "DE_CI_AI", + "MT_CI" ).foreach(collationId => { val col1 = fetchCollation(collationId) val col2 = fetchCollation(collationId) diff --git a/common/utils/src/main/scala/org/apache/spark/unsafe/array/ByteArrayUtils.java b/common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java similarity index 100% rename from common/utils/src/main/scala/org/apache/spark/unsafe/array/ByteArrayUtils.java rename to 
common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index af4fa73fe4c42..553d7e43db9c4 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2419,7 +2419,8 @@ }, "INVALID_FRACTION_OF_SECOND" : { "message" : [ - "The fraction of sec must be zero. Valid range is [0, 60]. If necessary set to \"false\" to bypass this error." + "Valid range for seconds is [0, 60] (inclusive), but the provided value is . To avoid this error, use `try_make_timestamp`, which returns NULL on error.", + "If you do not want to use the session default timestamp version of this function, use `try_make_timestamp_ntz` or `try_make_timestamp_ltz`." ], "sqlState" : "22023" }, @@ -5765,11 +5766,6 @@ "The value of from-to unit must be a string." ] }, - "_LEGACY_ERROR_TEMP_0028" : { - "message" : [ - "Intervals FROM TO are not supported." - ] - }, "_LEGACY_ERROR_TEMP_0029" : { "message" : [ "Cannot mix year-month and day-time fields: ." diff --git a/connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/CustomDecimal.scala similarity index 100% rename from connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/CustomDecimal.scala diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 72f56f35bf935..a11de64ed61fe 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -1977,6 +1977,58 @@ class PlanGenerationTestSuite fn.col("b")) } + functionTest("try_make_timestamp with timezone") { + fn.try_make_timestamp( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b"), + fn.col("g")) + } + + functionTest("try_make_timestamp without timezone") { + fn.try_make_timestamp( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b")) + } + + functionTest("try_make_timestamp_ltz with timezone") { + fn.try_make_timestamp_ltz( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b"), + fn.col("g")) + } + + functionTest("try_make_timestamp_ltz without timezone") { + fn.try_make_timestamp_ltz( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b")) + } + + functionTest("try_make_timestamp_ntz") { + fn.try_make_timestamp_ntz( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b")) + } + functionTest("make_ym_interval years months") { fn.make_ym_interval(fn.col("a"), fn.col("a")) } diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml index 8c52576c3531f..e85481ef9e1c8 100644 --- a/connector/protobuf/pom.xml +++ b/connector/protobuf/pom.xml @@ -146,7 +146,7 @@ src/test/resources/protobuf - true + direct java diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala index febe1ac4bb4cf..1c4c00c03a470 100644 --- 
a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.File +import java.nio.file.{Files, Paths} import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Semaphore, TimeUnit} import scala.collection.mutable.ArrayBuffer @@ -377,20 +378,22 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS .set(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED, true) sc = new SparkContext(conf) TestUtils.waitUntilExecutorsUp(sc, 2, 60000) - val shuffleBlockUpdates = new ArrayBuffer[BlockId]() - var isDecommissionedExecutorRemoved = false + val shuffleBlockUpdates = new ConcurrentLinkedQueue[BlockId]() val execToDecommission = sc.getExecutorIds().head + val decommissionedExecutorLocalDir = sc.parallelize(1 to 100, 10).flatMap { _ => + if (SparkEnv.get.executorId == execToDecommission) { + SparkEnv.get.blockManager.getLocalDiskDirs + } else { + Array.empty[String] + } + }.collect().toSet + assert(decommissionedExecutorLocalDir.size == 1) sc.addSparkListener(new SparkListener { override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { if (blockUpdated.blockUpdatedInfo.blockId.isShuffle) { - shuffleBlockUpdates += blockUpdated.blockUpdatedInfo.blockId + shuffleBlockUpdates.add(blockUpdated.blockUpdatedInfo.blockId) } } - - override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { - assert(execToDecommission === executorRemoved.executorId) - isDecommissionedExecutorRemoved = true - } }) // Run a job to create shuffle data @@ -409,12 +412,13 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS ) eventually(timeout(1.minute), interval(10.milliseconds)) { - assert(isDecommissionedExecutorRemoved) + assert(Files.notExists(Paths.get(decommissionedExecutorLocalDir.head))) // Ensure there are shuffle data have been migrated assert(shuffleBlockUpdates.size >= 2) } val shuffleId = shuffleBlockUpdates + .asScala .find(_.isInstanceOf[ShuffleIndexBlockId]) .map(_.asInstanceOf[ShuffleIndexBlockId].shuffleId) .get diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index f70a1dec6e468..fd7c3dbaa61d6 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -102,7 +102,7 @@ RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.2' scipy coverage matp ARG BASIC_PIP_PKGS="numpy pyarrow>=15.0.0 six==1.16.0 pandas==2.2.2 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2 twine==3.4.1" # Python deps for Spark Connect -ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==5.28.3 googleapis-common-protos==1.65.0" +ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.28.3 googleapis-common-protos==1.65.0" # Install Python 3.10 packages RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 @@ -131,7 +131,7 @@ RUN python3.9 -m pip install --force $BASIC_PIP_PKGS unittest-xml-reporting $CON RUN python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 
'black==23.9.1' \ -'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ +'pandas-stubs==1.2.0.53' 'grpcio==1.67.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' RUN python3.9 -m pip list diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 747c4be50225f..f2aaf96f3e4f8 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -51,9 +51,9 @@ commons-math3/3.6.1//commons-math3-3.6.1.jar commons-pool/1.5.4//commons-pool-1.5.4.jar commons-text/1.12.0//commons-text-1.12.0.jar compress-lzf/1.1.2//compress-lzf-1.1.2.jar -curator-client/5.7.0//curator-client-5.7.0.jar -curator-framework/5.7.0//curator-framework-5.7.0.jar -curator-recipes/5.7.0//curator-recipes-5.7.0.jar +curator-client/5.7.1//curator-client-5.7.1.jar +curator-framework/5.7.1//curator-framework-5.7.1.jar +curator-recipes/5.7.1//curator-recipes-5.7.1.jar datanucleus-api-jdo/4.2.4//datanucleus-api-jdo-4.2.4.jar datanucleus-core/4.1.17//datanucleus-core-4.1.17.jar datanucleus-rdbms/4.1.19//datanucleus-rdbms-4.1.19.jar @@ -140,8 +140,8 @@ jersey-container-servlet/3.0.16//jersey-container-servlet-3.0.16.jar jersey-hk2/3.0.16//jersey-hk2-3.0.16.jar jersey-server/3.0.16//jersey-server-3.0.16.jar jettison/1.5.4//jettison-1.5.4.jar -jetty-util-ajax/11.0.23//jetty-util-ajax-11.0.23.jar -jetty-util/11.0.23//jetty-util-11.0.23.jar +jetty-util-ajax/11.0.24//jetty-util-ajax-11.0.24.jar +jetty-util/11.0.24//jetty-util-11.0.24.jar jjwt-api/0.12.6//jjwt-api-0.12.6.jar jline/2.14.6//jline-2.14.6.jar jline/3.26.3//jline-3.26.3.jar @@ -250,7 +250,7 @@ parquet-jackson/1.14.3//parquet-jackson-1.14.3.jar pickle/1.5//pickle-1.5.jar py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar -rocksdbjni/9.5.2//rocksdbjni-9.5.2.jar +rocksdbjni/9.7.3//rocksdbjni-9.7.3.jar scala-collection-compat_2.13/2.7.0//scala-collection-compat_2.13-2.7.0.jar scala-compiler/2.13.15//scala-compiler-2.13.15.jar scala-library/2.13.15//scala-library-2.13.15.jar @@ -278,6 +278,6 @@ xbean-asm9-shaded/4.26//xbean-asm9-shaded-4.26.jar xmlschema-core/2.3.1//xmlschema-core-2.3.1.jar xz/1.10//xz-1.10.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar -zookeeper-jute/3.9.2//zookeeper-jute-3.9.2.jar -zookeeper/3.9.2//zookeeper-3.9.2.jar +zookeeper-jute/3.9.3//zookeeper-jute-3.9.3.jar +zookeeper/3.9.3//zookeeper-3.9.3.jar zstd-jni/1.5.6-6//zstd-jni-1.5.6-6.jar diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 70efeecfac581..f61562afb1694 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -17,14 +17,14 @@ # Image for building and testing Spark branches. Based on Ubuntu 22.04. 
# See also in https://hub.docker.com/_/ubuntu -FROM ubuntu:jammy-20240227 +FROM ubuntu:jammy-20240911.1 LABEL org.opencontainers.image.authors="Apache Spark project " LABEL org.opencontainers.image.licenses="Apache-2.0" LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image" # Overwrite this label to avoid exposing the underlying Ubuntu OS version label LABEL org.opencontainers.image.version="" -ENV FULL_REFRESH_DATE 20241007 +ENV FULL_REFRESH_DATE 20241028 ENV DEBIAN_FRONTEND noninteractive ENV DEBCONF_NONINTERACTIVE_SEEN true @@ -96,7 +96,7 @@ RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.3' scipy coverage matp ARG BASIC_PIP_PKGS="numpy pyarrow>=15.0.0 six==1.16.0 pandas==2.2.3 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" # Python deps for Spark Connect -ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==5.28.3 googleapis-common-protos==1.65.0 graphviz==0.20.3" +ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.28.3 googleapis-common-protos==1.65.0 graphviz==0.20.3" # Install Python 3.10 packages RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 diff --git a/dev/protobuf-breaking-changes-check.sh b/dev/protobuf-breaking-changes-check.sh index 05d90be564aba..327e54be63e62 100755 --- a/dev/protobuf-breaking-changes-check.sh +++ b/dev/protobuf-breaking-changes-check.sh @@ -21,7 +21,7 @@ set -ex if [[ $# -gt 1 ]]; then echo "Illegal number of parameters." echo "Usage: ./dev/protobuf-breaking-changes-check.sh [branch]" - echo "the default branch is 'master', available options are 'master', 'branch-3.4', etc" + echo "the default branch is 'master'" exit -1 fi diff --git a/dev/requirements.txt b/dev/requirements.txt index 88456e876d271..9f8d040659000 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -58,8 +58,8 @@ black==23.9.1 py # Spark Connect (required) -grpcio>=1.62.0 -grpcio-status>=1.62.0 +grpcio>=1.67.0 +grpcio-status>=1.67.0 googleapis-common-protos>=1.65.0 # Spark Connect python proto generation plugin (optional) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 200ddc9a20f3d..500b41f7569a3 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -380,6 +380,10 @@ When ANSI mode is on, it throws exceptions for invalid operations. You can use t - `try_element_at`: identical to the function `element_at`, except that it returns `NULL` result instead of throwing an exception on array's index out of bound. - `try_to_timestamp`: identical to the function `to_timestamp`, except that it returns `NULL` result instead of throwing an exception on string parsing error. - `try_parse_url`: identical to the function `parse_url`, except that it returns `NULL` result instead of throwing an exception on url parsing error. + - `try_make_timestamp`: identical to the function `make_timestamp`, except that it returns `NULL` result instead of throwing an exception on error. + - `try_make_timestamp_ltz`: identical to the function `make_timestamp_ltz`, except that it returns `NULL` result instead of throwing an exception on error. + - `try_make_timestamp_ntz`: identical to the function `make_timestamp_ntz`, except that it returns `NULL` result instead of throwing an exception on error. 
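A minimal PySpark sketch of the NULL-on-error behavior documented in the list above; it is only an illustration, assuming a SparkSession bound to `spark` and a build that already ships the `try_make_timestamp*` functions, with made-up column names:

from pyspark.sql import functions as sf

# One valid row, and one row whose seconds value (61.5) falls outside [0, 60].
df = spark.createDataFrame(
    [(2024, 10, 28, 10, 30, 45.0), (2024, 10, 28, 10, 30, 61.5)],
    ["y", "mo", "d", "h", "mi", "s"],
)

# make_timestamp would fail under ANSI mode for the second row;
# the try_ variants return NULL for it instead of failing the query.
df.select(
    sf.try_make_timestamp("y", "mo", "d", "h", "mi", "s").alias("ts"),
    sf.try_make_timestamp_ntz("y", "mo", "d", "h", "mi", "s").alias("ts_ntz"),
).show()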
+ ### SQL Keywords (optional, disabled by default) diff --git a/pom.xml b/pom.xml index 086948aac7fa3..94c51f2d3563a 100644 --- a/pom.xml +++ b/pom.xml @@ -128,8 +128,8 @@ 4.28.3 3.11.4 ${hadoop.version} - 3.9.2 - 5.7.0 + 3.9.3 + 5.7.1 org.apache.hive core @@ -141,7 +141,7 @@ 1.14.3 2.0.2 shaded-protobuf - 11.0.23 + 11.0.24 5.0.0 4.0.1 @@ -294,7 +294,7 @@ 33.2.1-jre 1.0.2 - 1.62.2 + 1.67.1 1.1.4 6.0.53 @@ -726,7 +726,7 @@ org.rocksdb rocksdbjni - 9.5.2 + 9.7.3 ${leveldbjni.group} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b061ce96bc0fe..cbd0c11958dfc 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -91,7 +91,7 @@ object BuildCommons { // SPARK-41247: needs to be consistent with `protobuf.version` in `pom.xml`. val protoVersion = "4.28.3" // GRPC version used for Spark Connect. - val grpcVersion = "1.62.2" + val grpcVersion = "1.67.1" } object SparkBuild extends PomBuild { diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 4d777e0840dc7..d0dc285b5257c 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -208,8 +208,8 @@ Package Supported version Note ========================== ================= ========================== `pandas` >=2.0.0 Required for Spark Connect `pyarrow` >=10.0.0 Required for Spark Connect -`grpcio` >=1.62.0 Required for Spark Connect -`grpcio-status` >=1.62.0 Required for Spark Connect +`grpcio` >=1.67.0 Required for Spark Connect +`grpcio-status` >=1.67.0 Required for Spark Connect `googleapis-common-protos` >=1.65.0 Required for Spark Connect `graphviz` >=0.20 Optional for Spark Connect ========================== ================= ========================== diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index bf73fec58280d..b9df5691b82a9 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -301,6 +301,9 @@ Date and Timestamp Functions to_unix_timestamp to_utc_timestamp trunc + try_make_timestamp + try_make_timestamp_ltz + try_make_timestamp_ntz try_to_timestamp unix_date unix_micros diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py index 60da51caa20ae..d799af1216345 100755 --- a/python/packaging/classic/setup.py +++ b/python/packaging/classic/setup.py @@ -153,7 +153,7 @@ def _supports_symlinks(): _minimum_pandas_version = "2.0.0" _minimum_numpy_version = "1.21" _minimum_pyarrow_version = "10.0.0" -_minimum_grpc_version = "1.62.0" +_minimum_grpc_version = "1.67.0" _minimum_googleapis_common_protos_version = "1.65.0" diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index ae9fbccceb3e9..5aa0313631c04 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -816,6 +816,11 @@ "message": [ "Pipe function `` exited with error code ." ] + }, + "PLOT_INVALID_TYPE_COLUMN": { + "message": [ + "Column must be one of for plotting, got ." 
+ ] }, "PLOT_NOT_NUMERIC_COLUMN": { "message": [ diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index ab42f244a5a60..b8bd0e9bf7fdc 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -2570,15 +2570,23 @@ def locate(substr: str, str: "ColumnOrName", pos: int = 1) -> Column: locate.__doc__ = pysparkfuncs.locate.__doc__ -def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: - return _invoke_function("lpad", _to_col(col), lit(len), lit(pad)) +def lpad( + col: "ColumnOrName", + len: Union[Column, int], + pad: Union[Column, str], +) -> Column: + return _invoke_function_over_columns("lpad", col, lit(len), lit(pad)) lpad.__doc__ = pysparkfuncs.lpad.__doc__ -def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: - return _invoke_function("rpad", _to_col(col), lit(len), lit(pad)) +def rpad( + col: "ColumnOrName", + len: Union[Column, int], + pad: Union[Column, str], +) -> Column: + return _invoke_function_over_columns("rpad", col, lit(len), lit(pad)) rpad.__doc__ = pysparkfuncs.rpad.__doc__ @@ -3751,6 +3759,28 @@ def make_timestamp( make_timestamp.__doc__ = pysparkfuncs.make_timestamp.__doc__ +def try_make_timestamp( + years: "ColumnOrName", + months: "ColumnOrName", + days: "ColumnOrName", + hours: "ColumnOrName", + mins: "ColumnOrName", + secs: "ColumnOrName", + timezone: Optional["ColumnOrName"] = None, +) -> Column: + if timezone is not None: + return _invoke_function_over_columns( + "try_make_timestamp", years, months, days, hours, mins, secs, timezone + ) + else: + return _invoke_function_over_columns( + "try_make_timestamp", years, months, days, hours, mins, secs + ) + + +try_make_timestamp.__doc__ = pysparkfuncs.try_make_timestamp.__doc__ + + def make_timestamp_ltz( years: "ColumnOrName", months: "ColumnOrName", @@ -3773,6 +3803,28 @@ def make_timestamp_ltz( make_timestamp_ltz.__doc__ = pysparkfuncs.make_timestamp_ltz.__doc__ +def try_make_timestamp_ltz( + years: "ColumnOrName", + months: "ColumnOrName", + days: "ColumnOrName", + hours: "ColumnOrName", + mins: "ColumnOrName", + secs: "ColumnOrName", + timezone: Optional["ColumnOrName"] = None, +) -> Column: + if timezone is not None: + return _invoke_function_over_columns( + "try_make_timestamp_ltz", years, months, days, hours, mins, secs, timezone + ) + else: + return _invoke_function_over_columns( + "try_make_timestamp_ltz", years, months, days, hours, mins, secs + ) + + +try_make_timestamp_ltz.__doc__ = pysparkfuncs.try_make_timestamp_ltz.__doc__ + + def make_timestamp_ntz( years: "ColumnOrName", months: "ColumnOrName", @@ -3789,6 +3841,22 @@ def make_timestamp_ntz( make_timestamp_ntz.__doc__ = pysparkfuncs.make_timestamp_ntz.__doc__ +def try_make_timestamp_ntz( + years: "ColumnOrName", + months: "ColumnOrName", + days: "ColumnOrName", + hours: "ColumnOrName", + mins: "ColumnOrName", + secs: "ColumnOrName", +) -> Column: + return _invoke_function_over_columns( + "try_make_timestamp_ntz", years, months, days, hours, mins, secs + ) + + +try_make_timestamp_ntz.__doc__ = pysparkfuncs.try_make_timestamp_ntz.__doc__ + + def make_ym_interval( years: Optional["ColumnOrName"] = None, months: Optional["ColumnOrName"] = None, diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py index 12675747e0f92..7501aaf0a3a23 100644 --- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py +++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py @@ -34,51 +34,61 @@ 
def __init__(self, channel): "/spark.connect.SparkConnectService/ExecutePlan", request_serializer=spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString, + _registered_method=True, ) self.AnalyzePlan = channel.unary_unary( "/spark.connect.SparkConnectService/AnalyzePlan", request_serializer=spark_dot_connect_dot_base__pb2.AnalyzePlanRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.AnalyzePlanResponse.FromString, + _registered_method=True, ) self.Config = channel.unary_unary( "/spark.connect.SparkConnectService/Config", request_serializer=spark_dot_connect_dot_base__pb2.ConfigRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.ConfigResponse.FromString, + _registered_method=True, ) self.AddArtifacts = channel.stream_unary( "/spark.connect.SparkConnectService/AddArtifacts", request_serializer=spark_dot_connect_dot_base__pb2.AddArtifactsRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.AddArtifactsResponse.FromString, + _registered_method=True, ) self.ArtifactStatus = channel.unary_unary( "/spark.connect.SparkConnectService/ArtifactStatus", request_serializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesResponse.FromString, + _registered_method=True, ) self.Interrupt = channel.unary_unary( "/spark.connect.SparkConnectService/Interrupt", request_serializer=spark_dot_connect_dot_base__pb2.InterruptRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.InterruptResponse.FromString, + _registered_method=True, ) self.ReattachExecute = channel.unary_stream( "/spark.connect.SparkConnectService/ReattachExecute", request_serializer=spark_dot_connect_dot_base__pb2.ReattachExecuteRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString, + _registered_method=True, ) self.ReleaseExecute = channel.unary_unary( "/spark.connect.SparkConnectService/ReleaseExecute", request_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.FromString, + _registered_method=True, ) self.ReleaseSession = channel.unary_unary( "/spark.connect.SparkConnectService/ReleaseSession", request_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString, + _registered_method=True, ) self.FetchErrorDetails = channel.unary_unary( "/spark.connect.SparkConnectService/FetchErrorDetails", request_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.FromString, + _registered_method=True, ) @@ -220,6 +230,7 @@ def add_SparkConnectServiceServicer_to_server(servicer, server): "spark.connect.SparkConnectService", rpc_method_handlers ) server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers("spark.connect.SparkConnectService", rpc_method_handlers) # This class is part of an EXPERIMENTAL API. 
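Relatedly, a small hedged sketch of the `lpad`/`rpad` signature change shown earlier in this patch (the `len` and `pad` arguments now also accept Columns, so the padding can vary per row); it assumes a SparkSession bound to `spark` and a build that includes that change, and the column names are invented:

from pyspark.sql import functions as sf

df = spark.createDataFrame(
    [("7", 3, "0"), ("42", 5, "*")],
    ["s", "width", "fill"],
)

# Width and pad string are taken from other columns instead of Python literals.
df.select(
    "s",
    sf.lpad("s", sf.col("width"), sf.col("fill")).alias("lpadded"),
    sf.rpad("s", sf.col("width"), sf.col("fill")).alias("rpadded"),
).show()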
@@ -253,6 +264,7 @@ def ExecutePlan( wait_for_ready, timeout, metadata, + _registered_method=True, ) @staticmethod @@ -282,6 +294,7 @@ def AnalyzePlan( wait_for_ready, timeout, metadata, + _registered_method=True, ) @staticmethod @@ -311,6 +324,7 @@ def Config( wait_for_ready, timeout, metadata, + _registered_method=True, ) @staticmethod @@ -340,6 +354,7 @@ def AddArtifacts( wait_for_ready, timeout, metadata, + _registered_method=True, ) @staticmethod @@ -369,6 +384,7 @@ def ArtifactStatus( wait_for_ready, timeout, metadata, + _registered_method=True, ) @staticmethod @@ -398,6 +414,7 @@ def Interrupt( wait_for_ready, timeout, metadata, + _registered_method=True, ) @staticmethod @@ -427,6 +444,7 @@ def ReattachExecute( wait_for_ready, timeout, metadata, + _registered_method=True, ) @staticmethod @@ -456,6 +474,7 @@ def ReleaseExecute( wait_for_ready, timeout, metadata, + _registered_method=True, ) @staticmethod @@ -485,6 +504,7 @@ def ReleaseSession( wait_for_ready, timeout, metadata, + _registered_method=True, ) @staticmethod @@ -514,4 +534,5 @@ def FetchErrorDetails( wait_for_ready, timeout, metadata, + _registered_method=True, ) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 16e7cf052d6f1..810c6731de9a7 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -220,6 +220,38 @@ def lit(col: Any) -> Column: | false| Yes| | false| No| +-----------+--------+ + + Example 5: Creating literal columns from Numpy scalar. + + >>> from pyspark.sql import functions as sf + >>> import numpy as np # doctest: +SKIP + >>> spark.range(1).select( + ... sf.lit(np.bool_(True)), + ... sf.lit(np.int64(123)), + ... sf.lit(np.float64(0.456)), + ... sf.lit(np.str_("xyz")) + ... ).show() # doctest: +SKIP + +----+---+-----+---+ + |true|123|0.456|xyz| + +----+---+-----+---+ + |true|123|0.456|xyz| + +----+---+-----+---+ + + Example 6: Creating literal columns from Numpy ndarray. + + >>> from pyspark.sql import functions as sf + >>> import numpy as np # doctest: +SKIP + >>> spark.range(1).select( + ... sf.lit(np.array([True, False], np.bool_)), + ... sf.lit(np.array([], np.int8)), + ... sf.lit(np.array([1.5, 0.1], np.float64)), + ... sf.lit(np.array(["a", "b", "c"], np.str_)), + ... ).show() # doctest: +SKIP + +------------------+-------+-----------------+--------------------+ + |ARRAY(true, false)|ARRAY()|ARRAY(1.5D, 0.1D)|ARRAY('a', 'b', 'c')| + +------------------+-------+-----------------+--------------------+ + | [true, false]| []| [1.5, 0.1]| [a, b, c]| + +------------------+-------+-----------------+--------------------+ """ if isinstance(col, Column): return col @@ -272,7 +304,7 @@ def col(col: str) -> Column: Parameters ---------- - col : str + col : column name the name for the column Returns @@ -306,7 +338,7 @@ def asc(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name Target column to sort by in the ascending order. Returns @@ -318,9 +350,9 @@ def asc(col: "ColumnOrName") -> Column: -------- Example 1: Sort DataFrame by 'id' column in ascending order. 
- >>> from pyspark.sql.functions import asc + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([(4, 'B'), (3, 'A'), (2, 'C')], ['id', 'value']) - >>> df.sort(asc("id")).show() + >>> df.sort(sf.asc("id")).show() +---+-----+ | id|value| +---+-----+ @@ -331,9 +363,9 @@ def asc(col: "ColumnOrName") -> Column: Example 2: Use `asc` in `orderBy` function to sort the DataFrame. - >>> from pyspark.sql.functions import asc + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([(4, 'B'), (3, 'A'), (2, 'C')], ['id', 'value']) - >>> df.orderBy(asc("value")).show() + >>> df.orderBy(sf.asc("value")).show() +---+-----+ | id|value| +---+-----+ @@ -344,11 +376,11 @@ def asc(col: "ColumnOrName") -> Column: Example 3: Combine `asc` with `desc` to sort by multiple columns. - >>> from pyspark.sql.functions import asc, desc - >>> df = spark.createDataFrame([(2, 'A', 4), - ... (1, 'B', 3), - ... (3, 'A', 2)], ['id', 'group', 'value']) - >>> df.sort(asc("group"), desc("value")).show() + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame( + ... [(2, 'A', 4), (1, 'B', 3), (3, 'A', 2)], + ... ['id', 'group', 'value']) + >>> df.sort(sf.asc("group"), sf.desc("value")).show() +---+-----+-----+ | id|group|value| +---+-----+-----+ @@ -385,7 +417,7 @@ def desc(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name Target column to sort by in the descending order. Returns @@ -397,9 +429,9 @@ def desc(col: "ColumnOrName") -> Column: -------- Example 1: Sort DataFrame by 'id' column in descending order. - >>> from pyspark.sql.functions import desc + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([(4, 'B'), (3, 'A'), (2, 'C')], ['id', 'value']) - >>> df.sort(desc("id")).show() + >>> df.sort(sf.desc("id")).show() +---+-----+ | id|value| +---+-----+ @@ -410,9 +442,9 @@ def desc(col: "ColumnOrName") -> Column: Example 2: Use `desc` in `orderBy` function to sort the DataFrame. - >>> from pyspark.sql.functions import desc + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([(4, 'B'), (3, 'A'), (2, 'C')], ['id', 'value']) - >>> df.orderBy(desc("value")).show() + >>> df.orderBy(sf.desc("value")).show() +---+-----+ | id|value| +---+-----+ @@ -423,11 +455,11 @@ def desc(col: "ColumnOrName") -> Column: Example 3: Combine `asc` with `desc` to sort by multiple columns. - >>> from pyspark.sql.functions import asc, desc - >>> df = spark.createDataFrame([(2, 'A', 4), - ... (1, 'B', 3), - ... (3, 'A', 2)], ['id', 'group', 'value']) - >>> df.sort(desc("group"), asc("value")).show() + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame( + ... [(2, 'A', 4), (1, 'B', 3), (3, 'A', 2)], + ... ['id', 'group', 'value']) + >>> df.sort(sf.desc("group"), sf.asc("value")).show() +---+-----+-----+ | id|group|value| +---+-----+-----+ @@ -463,7 +495,7 @@ def sqrt(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -473,13 +505,19 @@ def sqrt(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(sqrt(lit(4))).show() - +-------+ - |SQRT(4)| - +-------+ - | 2.0| - +-------+ + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (-1), (0), (1), (4), (NULL) AS TAB(value)" + ... 
).select("*", sf.sqrt("value")).show() + +-----+-----------+ + |value|SQRT(value)| + +-----+-----------+ + | -1| NaN| + | 0| 0.0| + | 1| 1.0| + | 4| 2.0| + | NULL| NULL| + +-----+-----------+ """ return _invoke_function_over_columns("sqrt", col) @@ -494,8 +532,8 @@ def try_add(left: "ColumnOrName", right: "ColumnOrName") -> Column: Parameters ---------- - left : :class:`~pyspark.sql.Column` or str - right : :class:`~pyspark.sql.Column` or str + left : :class:`~pyspark.sql.Column` or column name + right : :class:`~pyspark.sql.Column` or column name Examples -------- @@ -504,49 +542,49 @@ def try_add(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [(1982, 15), (1990, 2)], ["birth", "age"] - ... ).select(sf.try_add("birth", "age")).show() - +-------------------+ - |try_add(birth, age)| - +-------------------+ - | 1997| - | 1992| - +-------------------+ + ... ).select("*", sf.try_add("birth", "age")).show() + +-----+---+-------------------+ + |birth|age|try_add(birth, age)| + +-----+---+-------------------+ + | 1982| 15| 1997| + | 1990| 2| 1992| + +-----+---+-------------------+ Example 2: Date plus Integer. >>> import pyspark.sql.functions as sf >>> spark.sql( ... "SELECT * FROM VALUES (DATE('2015-09-30')) AS TAB(date)" - ... ).select(sf.try_add("date", sf.lit(1))).show() - +----------------+ - |try_add(date, 1)| - +----------------+ - | 2015-10-01| - +----------------+ + ... ).select("*", sf.try_add("date", sf.lit(1))).show() + +----------+----------------+ + | date|try_add(date, 1)| + +----------+----------------+ + |2015-09-30| 2015-10-01| + +----------+----------------+ Example 3: Date plus Interval. >>> import pyspark.sql.functions as sf >>> spark.sql( - ... "SELECT * FROM VALUES (DATE('2015-09-30'), INTERVAL 1 YEAR) AS TAB(date, i)" - ... ).select(sf.try_add("date", "i")).show() - +----------------+ - |try_add(date, i)| - +----------------+ - | 2016-09-30| - +----------------+ + ... "SELECT * FROM VALUES (DATE('2015-09-30'), INTERVAL 1 YEAR) AS TAB(date, itvl)" + ... ).select("*", sf.try_add("date", "itvl")).show() + +----------+-----------------+-------------------+ + | date| itvl|try_add(date, itvl)| + +----------+-----------------+-------------------+ + |2015-09-30|INTERVAL '1' YEAR| 2016-09-30| + +----------+-----------------+-------------------+ Example 4: Interval plus Interval. >>> import pyspark.sql.functions as sf >>> spark.sql( - ... "SELECT * FROM VALUES (INTERVAL 1 YEAR, INTERVAL 2 YEAR) AS TAB(i, j)" - ... ).select(sf.try_add("i", "j")).show() - +-----------------+ - | try_add(i, j)| - +-----------------+ - |INTERVAL '3' YEAR| - +-----------------+ + ... "SELECT * FROM VALUES (INTERVAL 1 YEAR, INTERVAL 2 YEAR) AS TAB(itvl1, itvl2)" + ... ).select("*", sf.try_add("itvl1", "itvl2")).show() + +-----------------+-----------------+---------------------+ + | itvl1| itvl2|try_add(itvl1, itvl2)| + +-----------------+-----------------+---------------------+ + |INTERVAL '1' YEAR|INTERVAL '2' YEAR| INTERVAL '3' YEAR| + +-----------------+-----------------+---------------------+ Example 5: Overflow results in NULL when ANSI mode is on @@ -554,8 +592,7 @@ def try_add(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> origin = spark.conf.get("spark.sql.ansi.enabled") >>> spark.conf.set("spark.sql.ansi.enabled", "true") >>> try: - ... df = spark.range(1) - ... df.select(sf.try_add(sf.lit(sys.maxsize), sf.lit(sys.maxsize))).show() + ... 
spark.range(1).select(sf.try_add(sf.lit(sys.maxsize), sf.lit(sys.maxsize))).show() ... finally: ... spark.conf.set("spark.sql.ansi.enabled", origin) +-------------------------------------------------+ @@ -576,7 +613,7 @@ def try_avg(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name Examples -------- @@ -633,9 +670,9 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column: Parameters ---------- - left : :class:`~pyspark.sql.Column` or str + left : :class:`~pyspark.sql.Column` or column name dividend - right : :class:`~pyspark.sql.Column` or str + right : :class:`~pyspark.sql.Column` or column name divisor Examples @@ -645,29 +682,28 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [(6000, 15), (1990, 2), (1234, 0)], ["a", "b"] - ... ).select(sf.try_divide("a", "b")).show() - +----------------+ - |try_divide(a, b)| - +----------------+ - | 400.0| - | 995.0| - | NULL| - +----------------+ + ... ).select("*", sf.try_divide("a", "b")).show() + +----+---+----------------+ + | a| b|try_divide(a, b)| + +----+---+----------------+ + |6000| 15| 400.0| + |1990| 2| 995.0| + |1234| 0| NULL| + +----+---+----------------+ Example 2: Interval divided by Integer. >>> import pyspark.sql.functions as sf - >>> spark.range(4).select( - ... sf.try_divide(sf.make_interval(sf.lit(1)), "id") - ... ).show() - +--------------------------------------------------+ - |try_divide(make_interval(1, 0, 0, 0, 0, 0, 0), id)| - +--------------------------------------------------+ - | NULL| - | 1 years| - | 6 months| - | 4 months| - +--------------------------------------------------+ + >>> df = spark.range(4).select(sf.make_interval(sf.lit(1)).alias("itvl"), "id") + >>> df.select("*", sf.try_divide("itvl", "id")).show() + +-------+---+--------------------+ + | itvl| id|try_divide(itvl, id)| + +-------+---+--------------------+ + |1 years| 0| NULL| + |1 years| 1| 1 years| + |1 years| 2| 6 months| + |1 years| 3| 4 months| + +-------+---+--------------------+ Example 3: Exception during division, resulting in NULL when ANSI mode is on @@ -675,8 +711,7 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> origin = spark.conf.get("spark.sql.ansi.enabled") >>> spark.conf.set("spark.sql.ansi.enabled", "true") >>> try: - ... df = spark.range(1) - ... df.select(sf.try_divide(df.id, sf.lit(0))).show() + ... spark.range(1).select(sf.try_divide("id", sf.lit(0))).show() ... finally: ... spark.conf.set("spark.sql.ansi.enabled", origin) +-----------------+ @@ -698,9 +733,9 @@ def try_mod(left: "ColumnOrName", right: "ColumnOrName") -> Column: Parameters ---------- - left : :class:`~pyspark.sql.Column` or str + left : :class:`~pyspark.sql.Column` or column name dividend - right : :class:`~pyspark.sql.Column` or str + right : :class:`~pyspark.sql.Column` or column name divisor Examples @@ -710,14 +745,14 @@ def try_mod(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [(6000, 15), (3, 2), (1234, 0)], ["a", "b"] - ... ).select(sf.try_mod("a", "b")).show() - +-------------+ - |try_mod(a, b)| - +-------------+ - | 0| - | 1| - | NULL| - +-------------+ + ... 
).select("*", sf.try_mod("a", "b")).show() + +----+---+-------------+ + | a| b|try_mod(a, b)| + +----+---+-------------+ + |6000| 15| 0| + | 3| 2| 1| + |1234| 0| NULL| + +----+---+-------------+ Example 2: Exception during division, resulting in NULL when ANSI mode is on @@ -725,8 +760,7 @@ def try_mod(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> origin = spark.conf.get("spark.sql.ansi.enabled") >>> spark.conf.set("spark.sql.ansi.enabled", "true") >>> try: - ... df = spark.range(1) - ... df.select(sf.try_mod(df.id, sf.lit(0))).show() + ... spark.range(1).select(sf.try_mod("id", sf.lit(0))).show() ... finally: ... spark.conf.set("spark.sql.ansi.enabled", origin) +--------------+ @@ -748,9 +782,9 @@ def try_multiply(left: "ColumnOrName", right: "ColumnOrName") -> Column: Parameters ---------- - left : :class:`~pyspark.sql.Column` or str + left : :class:`~pyspark.sql.Column` or column name multiplicand - right : :class:`~pyspark.sql.Column` or str + right : :class:`~pyspark.sql.Column` or column name multiplier Examples @@ -760,30 +794,29 @@ def try_multiply(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [(6000, 15), (1990, 2)], ["a", "b"] - ... ).select(sf.try_multiply("a", "b")).show() - +------------------+ - |try_multiply(a, b)| - +------------------+ - | 90000| - | 3980| - +------------------+ + ... ).select("*", sf.try_multiply("a", "b")).show() + +----+---+------------------+ + | a| b|try_multiply(a, b)| + +----+---+------------------+ + |6000| 15| 90000| + |1990| 2| 3980| + +----+---+------------------+ Example 2: Interval multiplied by Integer. >>> import pyspark.sql.functions as sf - >>> spark.range(6).select( - ... sf.try_multiply(sf.make_interval(sf.lit(0), sf.lit(3)), "id") - ... ).show() - +----------------------------------------------------+ - |try_multiply(make_interval(0, 3, 0, 0, 0, 0, 0), id)| - +----------------------------------------------------+ - | 0 seconds| - | 3 months| - | 6 months| - | 9 months| - | 1 years| - | 1 years 3 months| - +----------------------------------------------------+ + >>> df = spark.range(6).select(sf.make_interval(sf.col("id"), sf.lit(3)).alias("itvl"), "id") + >>> df.select("*", sf.try_multiply("itvl", "id")).show() + +----------------+---+----------------------+ + | itvl| id|try_multiply(itvl, id)| + +----------------+---+----------------------+ + | 3 months| 0| 0 seconds| + |1 years 3 months| 1| 1 years 3 months| + |2 years 3 months| 2| 4 years 6 months| + |3 years 3 months| 3| 9 years 9 months| + |4 years 3 months| 4| 17 years| + |5 years 3 months| 5| 26 years 3 months| + +----------------+---+----------------------+ Example 3: Overflow results in NULL when ANSI mode is on @@ -791,8 +824,7 @@ def try_multiply(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> origin = spark.conf.get("spark.sql.ansi.enabled") >>> spark.conf.set("spark.sql.ansi.enabled", "true") >>> try: - ... df = spark.range(1) - ... df.select(sf.try_multiply(sf.lit(sys.maxsize), sf.lit(sys.maxsize))).show() + ... spark.range(1).select(sf.try_multiply(sf.lit(sys.maxsize), sf.lit(sys.maxsize))).show() ... finally: ... 
spark.conf.set("spark.sql.ansi.enabled", origin) +------------------------------------------------------+ @@ -814,8 +846,8 @@ def try_subtract(left: "ColumnOrName", right: "ColumnOrName") -> Column: Parameters ---------- - left : :class:`~pyspark.sql.Column` or str - right : :class:`~pyspark.sql.Column` or str + left : :class:`~pyspark.sql.Column` or column name + right : :class:`~pyspark.sql.Column` or column name Examples -------- @@ -824,49 +856,49 @@ def try_subtract(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [(1982, 15), (1990, 2)], ["birth", "age"] - ... ).select(sf.try_subtract("birth", "age")).show() - +------------------------+ - |try_subtract(birth, age)| - +------------------------+ - | 1967| - | 1988| - +------------------------+ + ... ).select("*", sf.try_subtract("birth", "age")).show() + +-----+---+------------------------+ + |birth|age|try_subtract(birth, age)| + +-----+---+------------------------+ + | 1982| 15| 1967| + | 1990| 2| 1988| + +-----+---+------------------------+ Example 2: Date minus Integer. >>> import pyspark.sql.functions as sf >>> spark.sql( ... "SELECT * FROM VALUES (DATE('2015-10-01')) AS TAB(date)" - ... ).select(sf.try_subtract("date", sf.lit(1))).show() - +---------------------+ - |try_subtract(date, 1)| - +---------------------+ - | 2015-09-30| - +---------------------+ + ... ).select("*", sf.try_subtract("date", sf.lit(1))).show() + +----------+---------------------+ + | date|try_subtract(date, 1)| + +----------+---------------------+ + |2015-10-01| 2015-09-30| + +----------+---------------------+ Example 3: Date minus Interval. >>> import pyspark.sql.functions as sf >>> spark.sql( - ... "SELECT * FROM VALUES (DATE('2015-09-30'), INTERVAL 1 YEAR) AS TAB(date, i)" - ... ).select(sf.try_subtract("date", "i")).show() - +---------------------+ - |try_subtract(date, i)| - +---------------------+ - | 2014-09-30| - +---------------------+ + ... "SELECT * FROM VALUES (DATE('2015-09-30'), INTERVAL 1 YEAR) AS TAB(date, itvl)" + ... ).select("*", sf.try_subtract("date", "itvl")).show() + +----------+-----------------+------------------------+ + | date| itvl|try_subtract(date, itvl)| + +----------+-----------------+------------------------+ + |2015-09-30|INTERVAL '1' YEAR| 2014-09-30| + +----------+-----------------+------------------------+ Example 4: Interval minus Interval. >>> import pyspark.sql.functions as sf >>> spark.sql( - ... "SELECT * FROM VALUES (INTERVAL 1 YEAR, INTERVAL 2 YEAR) AS TAB(i, j)" - ... ).select(sf.try_subtract("i", "j")).show() - +------------------+ - |try_subtract(i, j)| - +------------------+ - |INTERVAL '-1' YEAR| - +------------------+ + ... "SELECT * FROM VALUES (INTERVAL 1 YEAR, INTERVAL 2 YEAR) AS TAB(itvl1, itvl2)" + ... ).select("*", sf.try_subtract("itvl1", "itvl2")).show() + +-----------------+-----------------+--------------------------+ + | itvl1| itvl2|try_subtract(itvl1, itvl2)| + +-----------------+-----------------+--------------------------+ + |INTERVAL '1' YEAR|INTERVAL '2' YEAR| INTERVAL '-1' YEAR| + +-----------------+-----------------+--------------------------+ Example 5: Overflow results in NULL when ANSI mode is on @@ -874,8 +906,7 @@ def try_subtract(left: "ColumnOrName", right: "ColumnOrName") -> Column: >>> origin = spark.conf.get("spark.sql.ansi.enabled") >>> spark.conf.set("spark.sql.ansi.enabled", "true") >>> try: - ... df = spark.range(1) - ... 
df.select(sf.try_subtract(sf.lit(-sys.maxsize), sf.lit(sys.maxsize))).show() + ... spark.range(1).select(sf.try_subtract(sf.lit(-sys.maxsize), sf.lit(sys.maxsize))).show() ... finally: ... spark.conf.set("spark.sql.ansi.enabled", origin) +-------------------------------------------------------+ @@ -896,15 +927,14 @@ def try_sum(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name Examples -------- Example 1: Calculating the sum of values in a column >>> from pyspark.sql import functions as sf - >>> df = spark.range(10) - >>> df.select(sf.try_sum(df["id"])).show() + >>> spark.range(10).select(sf.try_sum("id")).show() +-----------+ |try_sum(id)| +-----------+ @@ -965,7 +995,7 @@ def abs(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name The target column or expression to compute the absolute value on. Returns @@ -975,57 +1005,46 @@ def abs(col: "ColumnOrName") -> Column: Examples -------- - Example 1: Compute the absolute value of a negative number + Example 1: Compute the absolute value of a long column >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([(1, -1), (2, -2), (3, -3)], ["id", "value"]) - >>> df.select(sf.abs(df.value)).show() - +----------+ - |abs(value)| - +----------+ - | 1| - | 2| - | 3| - +----------+ - - Example 2: Compute the absolute value of an expression - - >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([(1, 1), (2, -2), (3, 3)], ["id", "value"]) - >>> df.select(sf.abs(df.id - df.value)).show() - +-----------------+ - |abs((id - value))| - +-----------------+ - | 0| - | 4| - | 0| - +-----------------+ + >>> df = spark.createDataFrame([(-1,), (-2,), (-3,), (None,)], ["value"]) + >>> df.select("*", sf.abs(df.value)).show() + +-----+----------+ + |value|abs(value)| + +-----+----------+ + | -1| 1| + | -2| 2| + | -3| 3| + | NULL| NULL| + +-----+----------+ - Example 3: Compute the absolute value of a column with null values + Example 2: Compute the absolute value of a double column >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([(1, None), (2, -2), (3, None)], ["id", "value"]) - >>> df.select(sf.abs(df.value)).show() - +----------+ - |abs(value)| - +----------+ - | NULL| - | 2| - | NULL| - +----------+ + >>> df = spark.createDataFrame([(-1.5,), (-2.5,), (None,), (float("nan"),)], ["value"]) + >>> df.select("*", sf.abs(df.value)).show() + +-----+----------+ + |value|abs(value)| + +-----+----------+ + | -1.5| 1.5| + | -2.5| 2.5| + | NULL| NULL| + | NaN| NaN| + +-----+----------+ - Example 4: Compute the absolute value of a column with double values + Example 3: Compute the absolute value of an expression >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([(1, -1.5), (2, -2.5), (3, -3.5)], ["id", "value"]) - >>> df.select(sf.abs(df.value)).show() - +----------+ - |abs(value)| - +----------+ - | 1.5| - | 2.5| - | 3.5| - +----------+ + >>> df = spark.createDataFrame([(1, 1), (2, -2), (3, 3)], ["id", "value"]) + >>> df.select("*", sf.abs(df.id - df.value)).show() + +---+-----+-----------------+ + | id|value|abs((id - value))| + +---+-----+-----------------+ + | 1| 1| 0| + | 2| -2| 4| + | 3| 3| 0| + +---+-----+-----------------+ """ return _invoke_function_over_columns("abs", col) @@ -1042,7 +1061,7 @@ def mode(col: "ColumnOrName", deterministic: bool = False) -> Column: 
Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. deterministic : bool, optional if there are multiple equally-frequent results then return the lowest (defaults to false). @@ -1084,6 +1103,7 @@ def mode(col: "ColumnOrName", deterministic: bool = False) -> Column: +---------+ | 0| +---------+ + >>> df.select(sf.mode("col", True)).show() +---------------------------------------+ |mode() WITHIN GROUP (ORDER BY col DESC)| @@ -1108,7 +1128,7 @@ def max(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name The target column on which the maximum value is computed. Returns @@ -1213,7 +1233,7 @@ def min(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name The target column on which the minimum value is computed. Returns @@ -1309,10 +1329,10 @@ def max_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name The column representing the values to be returned. This could be the column instance or the column name as string. - ord : :class:`~pyspark.sql.Column` or str + ord : :class:`~pyspark.sql.Column` or column name The column that needs to be maximized. This could be the column instance or the column name as string. @@ -1395,10 +1415,10 @@ def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name The column representing the values that will be returned. This could be the column instance or the column name as string. - ord : :class:`~pyspark.sql.Column` or str + ord : :class:`~pyspark.sql.Column` or column name The column that needs to be minimized. This could be the column instance or the column name as string. @@ -1474,7 +1494,7 @@ def count(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -1542,7 +1562,7 @@ def sum(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -1600,7 +1620,7 @@ def avg(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -1648,7 +1668,7 @@ def mean(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -1692,7 +1712,7 @@ def median(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -1706,12 +1726,13 @@ def median(col: "ColumnOrName") -> Column: Examples -------- + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([ ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), ... ("Java", 2012, 22000), ("dotNET", 2012, 10000), ... ("dotNET", 2013, 48000), ("Java", 2013, 30000)], ... 
schema=("course", "year", "earnings")) - >>> df.groupby("course").agg(median("earnings")).show() + >>> df.groupby("course").agg(sf.median("earnings")).show() +------+----------------+ |course|median(earnings)| +------+----------------+ @@ -1751,7 +1772,7 @@ def sum_distinct(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -1822,26 +1843,26 @@ def product(col: "ColumnOrName") -> Column: Parameters ---------- - col : str, :class:`Column` + col : :class:`~pyspark.sql.Column` or column name column containing values to be multiplied together Returns ------- - :class:`~pyspark.sql.Column` + :class:`~pyspark.sql.Column` or column name the column for computed results. Examples -------- - >>> df = spark.range(1, 10).toDF('x').withColumn('mod3', col('x') % 3) - >>> prods = df.groupBy('mod3').agg(product('x').alias('product')) - >>> prods.orderBy('mod3').show() - +----+-------+ - |mod3|product| - +----+-------+ - | 0| 162.0| - | 1| 28.0| - | 2| 80.0| - +----+-------+ + >>> from pyspark.sql import functions as sf + >>> df = spark.sql("SELECT id % 3 AS mod3, id AS value FROM RANGE(10)") + >>> df.groupBy('mod3').agg(sf.product('value')).orderBy('mod3').show() + +----+--------------+ + |mod3|product(value)| + +----+--------------+ + | 0| 0.0| + | 1| 28.0| + | 2| 80.0| + +----+--------------+ """ return _invoke_function_over_columns("product", col) @@ -1859,7 +1880,7 @@ def acos(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name The target column or expression to compute the inverse cosine on. Returns @@ -1869,11 +1890,11 @@ def acos(col: "ColumnOrName") -> Column: Examples -------- - Example 1: Compute the inverse cosine of a column of numbers + Example 1: Compute the inverse cosine >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([(-1.0,), (-0.5,), (0.0,), (0.5,), (1.0,)], ["value"]) - >>> df.select("value", sf.acos("value")).show() + >>> df.select("*", sf.acos("value")).show() +-----+------------------+ |value| ACOS(value)| +-----+------------------+ @@ -1884,30 +1905,19 @@ def acos(col: "ColumnOrName") -> Column: | 1.0| 0.0| +-----+------------------+ - Example 2: Compute the inverse cosine of a column with null values + Example 2: Compute the inverse cosine of invalid values >>> from pyspark.sql import functions as sf - >>> from pyspark.sql.types import StructType, StructField, IntegerType - >>> schema = StructType([StructField("value", IntegerType(), True)]) - >>> df = spark.createDataFrame([(None,)], schema=schema) - >>> df.select(sf.acos(df.value)).show() - +-----------+ - |ACOS(value)| - +-----------+ - | NULL| - +-----------+ - - Example 3: Compute the inverse cosine of a column with values outside the valid range - - >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([(2,), (-2,)], ["value"]) - >>> df.select(sf.acos(df.value)).show() - +-----------+ - |ACOS(value)| - +-----------+ - | NaN| - | NaN| - +-----------+ + >>> spark.sql( + ... "SELECT * FROM VALUES (-2), (2), (NULL) AS TAB(value)" + ... 
).select("*", sf.acos("value")).show() + +-----+-----------+ + |value|ACOS(value)| + +-----+-----------+ + | -2| NaN| + | 2| NaN| + | NULL| NULL| + +-----+-----------+ """ return _invoke_function_over_columns("acos", col) @@ -1925,7 +1935,7 @@ def acosh(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name The target column or expression to compute the inverse hyperbolic cosine on. Returns @@ -1935,11 +1945,11 @@ def acosh(col: "ColumnOrName") -> Column: Examples -------- - Example 1: Compute the inverse hyperbolic cosine of a column of numbers + Example 1: Compute the inverse hyperbolic cosine >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([(1,), (2,)], ["value"]) - >>> df.select("value", sf.acosh(df.value)).show() + >>> df.select("*", sf.acosh(df.value)).show() +-----+------------------+ |value| ACOSH(value)| +-----+------------------+ @@ -1947,30 +1957,19 @@ def acosh(col: "ColumnOrName") -> Column: | 2|1.3169578969248...| +-----+------------------+ - Example 2: Compute the inverse hyperbolic cosine of a column with null values - - >>> from pyspark.sql import functions as sf - >>> from pyspark.sql.types import StructType, StructField, IntegerType - >>> schema = StructType([StructField("value", IntegerType(), True)]) - >>> df = spark.createDataFrame([(None,)], schema=schema) - >>> df.select(sf.acosh(df.value)).show() - +------------+ - |ACOSH(value)| - +------------+ - | NULL| - +------------+ - - Example 3: Compute the inverse hyperbolic cosine of a column with values less than 1 + Example 2: Compute the inverse hyperbolic cosine of invalid values >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([(0.5,), (-0.5,)], ["value"]) - >>> df.select(sf.acosh(df.value)).show() - +------------+ - |ACOSH(value)| - +------------+ - | NaN| - | NaN| - +------------+ + >>> spark.sql( + ... "SELECT * FROM VALUES (-0.5), (0.5), (NULL) AS TAB(value)" + ... ).select("*", sf.acosh("value")).show() + +-----+------------+ + |value|ACOSH(value)| + +-----+------------+ + | -0.5| NaN| + | 0.5| NaN| + | NULL| NULL| + +-----+------------+ """ return _invoke_function_over_columns("acosh", col) @@ -1987,7 +1986,7 @@ def asin(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -1997,14 +1996,32 @@ def asin(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(0,), (2,)]) - >>> df.select(asin(df.schema.fieldNames()[0])).show() - +--------+ - |ASIN(_1)| - +--------+ - | 0.0| - | NaN| - +--------+ + Example 1: Compute the inverse sine + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(-0.5,), (0.0,), (0.5,)], ["value"]) + >>> df.select("*", sf.asin(df.value)).show() + +-----+-------------------+ + |value| ASIN(value)| + +-----+-------------------+ + | -0.5|-0.5235987755982...| + | 0.0| 0.0| + | 0.5| 0.5235987755982...| + +-----+-------------------+ + + Example 2: Compute the inverse sine of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (-2), (2), (NULL) AS TAB(value)" + ... 
).select("*", sf.asin("value")).show() + +-----+-----------+ + |value|ASIN(value)| + +-----+-----------+ + | -2| NaN| + | 2| NaN| + | NULL| NULL| + +-----+-----------+ """ return _invoke_function_over_columns("asin", col) @@ -2021,7 +2038,7 @@ def asinh(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -2031,13 +2048,31 @@ def asinh(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(asinh(col("id"))).show() - +---------+ - |ASINH(id)| - +---------+ - | 0.0| - +---------+ + Example 1: Compute the inverse hyperbolic sine + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(-0.5,), (0.0,), (0.5,)], ["value"]) + >>> df.select("*", sf.asinh(df.value)).show() + +-----+--------------------+ + |value| ASINH(value)| + +-----+--------------------+ + | -0.5|-0.48121182505960...| + | 0.0| 0.0| + | 0.5| 0.48121182505960...| + +-----+--------------------+ + + Example 2: Compute the inverse hyperbolic sine of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.asinh("value")).show() + +-----+------------+ + |value|ASINH(value)| + +-----+------------+ + | NaN| NaN| + | NULL| NULL| + +-----+------------+ """ return _invoke_function_over_columns("asinh", col) @@ -2054,7 +2089,7 @@ def atan(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -2064,13 +2099,31 @@ def atan(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(atan(df.id)).show() - +--------+ - |ATAN(id)| - +--------+ - | 0.0| - +--------+ + Example 1: Compute the inverse tangent + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(-0.5,), (0.0,), (0.5,)], ["value"]) + >>> df.select("*", sf.atan(df.value)).show() + +-----+-------------------+ + |value| ATAN(value)| + +-----+-------------------+ + | -0.5|-0.4636476090008...| + | 0.0| 0.0| + | 0.5| 0.4636476090008...| + +-----+-------------------+ + + Example 2: Compute the inverse tangent of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.atan("value")).show() + +-----+-----------+ + |value|ATAN(value)| + +-----+-----------+ + | NaN| NaN| + | NULL| NULL| + +-----+-----------+ """ return _invoke_function_over_columns("atan", col) @@ -2087,7 +2140,7 @@ def atanh(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. 
Returns @@ -2097,14 +2150,33 @@ def atanh(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(0,), (2,)], schema=["numbers"]) - >>> df.select(atanh(df["numbers"])).show() - +--------------+ - |ATANH(numbers)| - +--------------+ - | 0.0| - | NaN| - +--------------+ + Example 1: Compute the inverse hyperbolic tangent + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(-0.5,), (0.0,), (0.5,)], ["value"]) + >>> df.select("*", sf.atanh(df.value)).show() + +-----+-------------------+ + |value| ATANH(value)| + +-----+-------------------+ + | -0.5|-0.5493061443340...| + | 0.0| 0.0| + | 0.5| 0.5493061443340...| + +-----+-------------------+ + + Example 2: Compute the inverse hyperbolic tangent of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (-2), (2), (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.atanh("value")).show() + +-----+------------+ + |value|ATANH(value)| + +-----+------------+ + | -2.0| NaN| + | 2.0| NaN| + | NaN| NaN| + | NULL| NULL| + +-----+------------+ """ return _invoke_function_over_columns("atanh", col) @@ -2121,7 +2193,7 @@ def cbrt(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -2131,13 +2203,31 @@ def cbrt(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(cbrt(lit(27))).show() - +--------+ - |CBRT(27)| - +--------+ - | 3.0| - +--------+ + Example 1: Compute the cube-root + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(-8,), (0,), (8,)], ["value"]) + >>> df.select("*", sf.cbrt(df.value)).show() + +-----+-----------+ + |value|CBRT(value)| + +-----+-----------+ + | -8| -2.0| + | 0| 0.0| + | 8| 2.0| + +-----+-----------+ + + Example 2: Compute the cube-root of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.cbrt("value")).show() + +-----+-----------+ + |value|CBRT(value)| + +-----+-----------+ + | NaN| NaN| + | NULL| NULL| + +-----+-----------+ """ return _invoke_function_over_columns("cbrt", col) @@ -2154,7 +2244,7 @@ def ceil(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Col Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name The target column or column name to compute the ceiling on. scale : :class:`~pyspark.sql.Column` or int, optional An optional parameter to control the rounding behavior. @@ -2208,7 +2298,7 @@ def ceiling(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name The target column or column name to compute the ceiling on. scale : :class:`~pyspark.sql.Column` or int An optional parameter to control the rounding behavior. 
@@ -2262,7 +2352,7 @@ def cos(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name angle in radians Returns @@ -2272,13 +2362,32 @@ def cos(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Compute the cosine + >>> from pyspark.sql import functions as sf - >>> spark.range(1).select(sf.cos(sf.pi())).show() - +---------+ - |COS(PI())| - +---------+ - | -1.0| - +---------+ + >>> spark.sql( + ... "SELECT * FROM VALUES (PI()), (PI() / 4), (PI() / 16) AS TAB(value)" + ... ).select("*", sf.cos("value")).show() + +-------------------+------------------+ + | value| COS(value)| + +-------------------+------------------+ + | 3.141592653589...| -1.0| + | 0.7853981633974...|0.7071067811865...| + |0.19634954084936...|0.9807852804032...| + +-------------------+------------------+ + + Example 2: Compute the cosine of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.cos("value")).show() + +-----+----------+ + |value|COS(value)| + +-----+----------+ + | NaN| NaN| + | NULL| NULL| + +-----+----------+ """ return _invoke_function_over_columns("cos", col) @@ -2295,7 +2404,7 @@ def cosh(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name hyperbolic angle Returns @@ -2305,9 +2414,31 @@ def cosh(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(cosh(lit(1))).first() - Row(COSH(1)=1.54308...) + Example 1: Compute the cosine + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ["value"]) + >>> df.select("*", sf.cosh(df.value)).show() + +-----+-----------------+ + |value| COSH(value)| + +-----+-----------------+ + | -1|1.543080634815...| + | 0| 1.0| + | 1|1.543080634815...| + +-----+-----------------+ + + Example 2: Compute the cosine of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.cosh("value")).show() + +-----+-----------+ + |value|COSH(value)| + +-----+-----------+ + | NaN| NaN| + | NULL| NULL| + +-----+-----------+ """ return _invoke_function_over_columns("cosh", col) @@ -2324,7 +2455,7 @@ def cot(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name angle in radians. Returns @@ -2334,13 +2465,32 @@ def cot(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Compute the cotangent + >>> from pyspark.sql import functions as sf - >>> spark.range(1).select(sf.cot(sf.pi() / 4)).show() - +------------------+ - | COT((PI() / 4))| - +------------------+ - |1.0000000000000...| - +------------------+ + >>> spark.sql( + ... "SELECT * FROM VALUES (PI() / 4), (PI() / 16) AS TAB(value)" + ... ).select("*", sf.cot("value")).show() + +-------------------+------------------+ + | value| COT(value)| + +-------------------+------------------+ + | 0.7853981633974...|1.0000000000000...| + |0.19634954084936...| 5.027339492125...| + +-------------------+------------------+ + + Example 2: Compute the cotangent of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (0.0), (FLOAT('NAN')), (NULL) AS TAB(value)" + ... 
).select("*", sf.cot("value")).show() + +-----+----------+ + |value|COT(value)| + +-----+----------+ + | 0.0| Infinity| + | NaN| NaN| + | NULL| NULL| + +-----+----------+ """ return _invoke_function_over_columns("cot", col) @@ -2357,7 +2507,7 @@ def csc(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name angle in radians. Returns @@ -2367,13 +2517,32 @@ def csc(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Compute the cosecant + >>> from pyspark.sql import functions as sf - >>> spark.range(1).select(sf.csc(sf.pi() / 2)).show() - +---------------+ - |CSC((PI() / 2))| - +---------------+ - | 1.0| - +---------------+ + >>> spark.sql( + ... "SELECT * FROM VALUES (PI() / 2), (PI() / 4) AS TAB(value)" + ... ).select("*", sf.csc("value")).show() + +------------------+------------------+ + | value| CSC(value)| + +------------------+------------------+ + |1.5707963267948...| 1.0| + |0.7853981633974...|1.4142135623730...| + +------------------+------------------+ + + Example 2: Compute the cosecant of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (0.0), (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.csc("value")).show() + +-----+----------+ + |value|CSC(value)| + +-----+----------+ + | 0.0| Infinity| + | NaN| NaN| + | NULL| NULL| + +-----+----------+ """ return _invoke_function_over_columns("csc", col) @@ -2386,7 +2555,8 @@ def e() -> Column: Examples -------- - >>> spark.range(1).select(e()).show() + >>> from pyspark.sql import functions as sf + >>> spark.range(1).select(sf.e()).show() +-----------------+ | E()| +-----------------+ @@ -2408,7 +2578,7 @@ def exp(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name column to calculate exponential for. Returns @@ -2418,13 +2588,33 @@ def exp(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(exp(lit(0))).show() - +------+ - |EXP(0)| - +------+ - | 1.0| - +------+ + Example 1: Compute the exponential + + >>> from pyspark.sql import functions as sf + >>> df = spark.sql("SELECT id AS value FROM RANGE(5)") + >>> df.select("*", sf.exp(df.value)).show() + +-----+------------------+ + |value| EXP(value)| + +-----+------------------+ + | 0| 1.0| + | 1|2.7182818284590...| + | 2| 7.38905609893...| + | 3|20.085536923187...| + | 4|54.598150033144...| + +-----+------------------+ + + Example 2: Compute the exponential of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.exp("value")).show() + +-----+----------+ + |value|EXP(value)| + +-----+----------+ + | NaN| NaN| + | NULL| NULL| + +-----+----------+ """ return _invoke_function_over_columns("exp", col) @@ -2441,7 +2631,7 @@ def expm1(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name column to calculate exponential for. Returns @@ -2451,9 +2641,33 @@ def expm1(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(expm1(lit(1))).first() - Row(EXPM1(1)=1.71828...) 
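Worth noting while `expm1` is touched here: the function exists for numerical reasons, because `exp(x) - 1` loses precision to cancellation when `x` is near zero while `expm1(x)` does not. A rough sketch, not part of the patch; exact digits depend on the platform's libm:

>>> from pyspark.sql import functions as sf
>>> spark.range(1).select(sf.expm1(sf.lit(1e-10)), sf.exp(sf.lit(1e-10)) - 1).show()  # doctest: +SKIP
# expm1(1e-10) keeps nearly full double precision (about 1.00000000005e-10),
# whereas exp(1e-10) - 1 retains only roughly half of the significant digits.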
+ Example 1: Compute the exponential minus one + + >>> from pyspark.sql import functions as sf + >>> df = spark.sql("SELECT id AS value FROM RANGE(5)") + >>> df.select("*", sf.expm1(df.value)).show() + +-----+------------------+ + |value| EXPM1(value)| + +-----+------------------+ + | 0| 0.0| + | 1| 1.718281828459...| + | 2| 6.38905609893...| + | 3|19.085536923187...| + | 4|53.598150033144...| + +-----+------------------+ + + Example 2: Compute the exponential minus one of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.expm1("value")).show() + +-----+------------+ + |value|EXPM1(value)| + +-----+------------+ + | NaN| NaN| + | NULL| NULL| + +-----+------------+ """ return _invoke_function_over_columns("expm1", col) @@ -2470,7 +2684,7 @@ def floor(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Co Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name The target column or column name to compute the floor on. scale : :class:`~pyspark.sql.Column` or int, optional An optional parameter to control the rounding behavior. @@ -2525,7 +2739,7 @@ def log(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name column to calculate natural logarithm for. Returns @@ -2535,6 +2749,8 @@ def log(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Compute the natural logarithm of E + >>> from pyspark.sql import functions as sf >>> spark.range(1).select(sf.log(sf.e())).show() +-------+ @@ -2542,6 +2758,21 @@ def log(col: "ColumnOrName") -> Column: +-------+ | 1.0| +-------+ + + Example 2: Compute the natural logarithm of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (-1), (0), (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.log("value")).show() + +-----+---------+ + |value|ln(value)| + +-----+---------+ + | -1.0| NULL| + | 0.0| NULL| + | NaN| NaN| + | NULL| NULL| + +-----+---------+ """ return _invoke_function_over_columns("log", col) @@ -2558,7 +2789,7 @@ def log10(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name column to calculate logarithm for. Returns @@ -2568,13 +2799,33 @@ def log10(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(log10(lit(100))).show() - +----------+ - |LOG10(100)| - +----------+ - | 2.0| - +----------+ + Example 1: Compute the logarithm in Base 10 + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(1,), (10,), (100,)], ["value"]) + >>> df.select("*", sf.log10(df.value)).show() + +-----+------------+ + |value|LOG10(value)| + +-----+------------+ + | 1| 0.0| + | 10| 1.0| + | 100| 2.0| + +-----+------------+ + + Example 2: Compute the logarithm in Base 10 of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (-1), (0), (FLOAT('NAN')), (NULL) AS TAB(value)" + ... 
).select("*", sf.log10("value")).show() + +-----+------------+ + |value|LOG10(value)| + +-----+------------+ + | -1.0| NULL| + | 0.0| NULL| + | NaN| NaN| + | NULL| NULL| + +-----+------------+ """ return _invoke_function_over_columns("log10", col) @@ -2582,7 +2833,7 @@ def log10(col: "ColumnOrName") -> Column: @_try_remote_functions def log1p(col: "ColumnOrName") -> Column: """ - Computes the natural logarithm of the "given value plus one". + Computes the natural logarithm of the given value plus one. .. versionadded:: 1.4.0 @@ -2591,7 +2842,7 @@ def log1p(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name column to calculate natural logarithm for. Returns @@ -2630,7 +2881,7 @@ def negative(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name column to calculate negative value for. Returns @@ -2641,14 +2892,15 @@ def negative(col: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.range(3).select(sf.negative("id")).show() - +------------+ - |negative(id)| - +------------+ - | 0| - | -1| - | -2| - +------------+ + >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ["value"]) + >>> df.select("*", sf.negative(df.value)).show() + +-----+---------------+ + |value|negative(value)| + +-----+---------------+ + | -1| 1| + | 0| 0| + | 1| -1| + +-----+---------------+ """ return _invoke_function_over_columns("negative", col) @@ -2664,7 +2916,8 @@ def pi() -> Column: Examples -------- - >>> spark.range(1).select(pi()).show() + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.pi()).show() +-----------------+ | PI()| +-----------------+ @@ -2683,7 +2936,7 @@ def positive(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name input value column. Returns @@ -2693,15 +2946,16 @@ def positive(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ['v']) - >>> df.select(positive("v").alias("p")).show() - +---+ - | p| - +---+ - | -1| - | 0| - | 1| - +---+ + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ["value"]) + >>> df.select("*", sf.positive(df.value)).show() + +-----+---------+ + |value|(+ value)| + +-----+---------+ + | -1| -1| + | 0| 0| + | 1| 1| + +-----+---------+ """ return _invoke_function_over_columns("positive", col) @@ -2719,7 +2973,7 @@ def rint(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. 
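One behavior the `rint` examples below do not cover: halfway cases are resolved with round-half-even ("banker's rounding"). A doctest-style sketch, assuming `rint` keeps its `java.lang.Math.rint` semantics (not verified against this branch):

>>> from pyspark.sql import functions as sf
>>> spark.range(1).select(sf.rint(sf.lit(2.5)), sf.rint(sf.lit(3.5))).show()  # doctest: +SKIP
# Both inputs are exactly halfway, yet they round to the nearest even integer:
# rint(2.5) -> 2.0 and rint(3.5) -> 4.0.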
Returns @@ -2729,15 +2983,15 @@ def rint(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(rint(lit(10.6))).show() + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.rint(sf.lit(10.6))).show() +----------+ |rint(10.6)| +----------+ | 11.0| +----------+ - >>> df.select(rint(lit(10.3))).show() + >>> spark.range(1).select(sf.rint(sf.lit(10.3))).show() +----------+ |rint(10.3)| +----------+ @@ -2759,7 +3013,7 @@ def sec(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name Angle in radians Returns @@ -2769,9 +3023,31 @@ def sec(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(sec(lit(1.5))).first() - Row(SEC(1.5)=14.13683...) + Example 1: Compute the secant + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (PI() / 4), (PI() / 16) AS TAB(value)" + ... ).select("*", sf.sec("value")).show() + +-------------------+------------------+ + | value| SEC(value)| + +-------------------+------------------+ + | 0.7853981633974...| 1.414213562373...| + |0.19634954084936...|1.0195911582083...| + +-------------------+------------------+ + + Example 2: Compute the secant of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.sec("value")).show() + +-----+----------+ + |value|SEC(value)| + +-----+----------+ + | NaN| NaN| + | NULL| NULL| + +-----+----------+ """ return _invoke_function_over_columns("sec", col) @@ -2788,7 +3064,7 @@ def signum(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -2801,13 +3077,15 @@ def signum(col: "ColumnOrName") -> Column: >>> import pyspark.sql.functions as sf >>> spark.range(1).select( ... sf.signum(sf.lit(-5)), - ... sf.signum(sf.lit(6)) + ... sf.signum(sf.lit(6)), + ... sf.signum(sf.lit(float('nan'))), + ... sf.signum(sf.lit(None)) ... ).show() - +----------+---------+ - |SIGNUM(-5)|SIGNUM(6)| - +----------+---------+ - | -1.0| 1.0| - +----------+---------+ + +----------+---------+-----------+------------+ + |SIGNUM(-5)|SIGNUM(6)|SIGNUM(NaN)|SIGNUM(NULL)| + +----------+---------+-----------+------------+ + | -1.0| 1.0| NaN| NULL| + +----------+---------+-----------+------------+ """ return _invoke_function_over_columns("signum", col) @@ -2824,7 +3102,7 @@ def sign(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -2837,13 +3115,15 @@ def sign(col: "ColumnOrName") -> Column: >>> import pyspark.sql.functions as sf >>> spark.range(1).select( ... sf.sign(sf.lit(-5)), - ... sf.sign(sf.lit(6)) + ... sf.sign(sf.lit(6)), + ... sf.sign(sf.lit(float('nan'))), + ... sf.sign(sf.lit(None)) ... 
).show() - +--------+-------+ - |sign(-5)|sign(6)| - +--------+-------+ - | -1.0| 1.0| - +--------+-------+ + +--------+-------+---------+----------+ + |sign(-5)|sign(6)|sign(NaN)|sign(NULL)| + +--------+-------+---------+----------+ + | -1.0| 1.0| NaN| NULL| + +--------+-------+---------+----------+ """ return _invoke_function_over_columns("sign", col) @@ -2860,7 +3140,7 @@ def sin(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name target column to compute on. Returns @@ -2870,13 +3150,32 @@ def sin(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Compute the sine + >>> from pyspark.sql import functions as sf - >>> spark.range(1).select(sf.sin(sf.pi() / 2)).show() - +---------------+ - |SIN((PI() / 2))| - +---------------+ - | 1.0| - +---------------+ + >>> spark.sql( + ... "SELECT * FROM VALUES (0.0), (PI() / 2), (PI() / 4) AS TAB(value)" + ... ).select("*", sf.sin("value")).show() + +------------------+------------------+ + | value| SIN(value)| + +------------------+------------------+ + | 0.0| 0.0| + |1.5707963267948...| 1.0| + |0.7853981633974...|0.7071067811865...| + +------------------+------------------+ + + Example 2: Compute the sine of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.sin("value")).show() + +-----+----------+ + |value|SIN(value)| + +-----+----------+ + | NaN| NaN| + | NULL| NULL| + +-----+----------+ """ return _invoke_function_over_columns("sin", col) @@ -2893,7 +3192,7 @@ def sinh(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name hyperbolic angle. Returns @@ -2904,9 +3203,31 @@ def sinh(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(sinh(lit(1.1))).first() - Row(SINH(1.1)=1.33564...) + Example 1: Compute the hyperbolic sine + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ["value"]) + >>> df.select("*", sf.sinh(df.value)).show() + +-----+-------------------+ + |value| SINH(value)| + +-----+-------------------+ + | -1|-1.1752011936438...| + | 0| 0.0| + | 1| 1.1752011936438...| + +-----+-------------------+ + + Example 2: Compute the hyperbolic sine of invalid values + + >>> from pyspark.sql import functions as sf + >>> spark.sql( + ... "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)" + ... ).select("*", sf.sinh("value")).show() + +-----+-----------+ + |value|SINH(value)| + +-----+-----------+ + | NaN| NaN| + | NULL| NULL| + +-----+-----------+ """ return _invoke_function_over_columns("sinh", col) @@ -2923,7 +3244,7 @@ def tan(col: "ColumnOrName") -> Column: Parameters ---------- - col : :class:`~pyspark.sql.Column` or str + col : :class:`~pyspark.sql.Column` or column name angle in radians Returns @@ -2933,13 +3254,32 @@ def tan(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Compute the tangent + >>> from pyspark.sql import functions as sf - >>> spark.range(1).select(sf.tan(sf.pi() / 4)).show() - +------------------+ - | TAN((PI() / 4))| - +------------------+ - |0.9999999999999...| - +------------------+ + >>> spark.sql( + ... "SELECT * FROM VALUES (0.0), (PI() / 4), (PI() / 6) AS TAB(value)" + ... 
).select("*", sf.tan("value")).show()
+    +------------------+------------------+
+    |             value|        TAN(value)|
+    +------------------+------------------+
+    |               0.0|               0.0|
+    |0.7853981633974...|0.9999999999999...|
+    |0.5235987755982...|0.5773502691896...|
+    +------------------+------------------+
+
+    Example 2: Compute the tangent of invalid values
+
+    >>> from pyspark.sql import functions as sf
+    >>> spark.sql(
+    ...     "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)"
+    ... ).select("*", sf.tan("value")).show()
+    +-----+----------+
+    |value|TAN(value)|
+    +-----+----------+
+    |  NaN|       NaN|
+    | NULL|      NULL|
+    +-----+----------+
     """
     return _invoke_function_over_columns("tan", col)


@@ -2956,7 +3296,7 @@ def tanh(col: "ColumnOrName") -> Column:

     Parameters
     ----------
-    col : :class:`~pyspark.sql.Column` or str
+    col : :class:`~pyspark.sql.Column` or column name
         hyperbolic angle

     Returns
@@ -2967,13 +3307,31 @@ def tanh(col: "ColumnOrName") -> Column:

     Examples
     --------
+    Example 1: Compute the hyperbolic tangent
+
     >>> from pyspark.sql import functions as sf
-    >>> spark.range(1).select(sf.tanh(sf.pi() / 2)).show()
-    +------------------+
-    |  TANH((PI() / 2))|
-    +------------------+
-    |0.9171523356672744|
-    +------------------+
+    >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ["value"])
+    >>> df.select("*", sf.tanh(df.value)).show()
+    +-----+-------------------+
+    |value|        TANH(value)|
+    +-----+-------------------+
+    |   -1|-0.7615941559557...|
+    |    0|                0.0|
+    |    1| 0.7615941559557...|
+    +-----+-------------------+
+
+    Example 2: Compute the hyperbolic tangent of invalid values
+
+    >>> from pyspark.sql import functions as sf
+    >>> spark.sql(
+    ...     "SELECT * FROM VALUES (FLOAT('NAN')), (NULL) AS TAB(value)"
+    ... ).select("*", sf.tanh("value")).show()
+    +-----+-----------+
+    |value|TANH(value)|
+    +-----+-----------+
+    |  NaN|        NaN|
+    | NULL|       NULL|
+    +-----+-----------+
     """
     return _invoke_function_over_columns("tanh", col)


@@ -5073,7 +5431,7 @@ def degrees(col: "ColumnOrName") -> Column:

     Parameters
     ----------
-    col : :class:`~pyspark.sql.Column` or str
+    col : :class:`~pyspark.sql.Column` or column name
         angle in radians

     Returns
@@ -5084,12 +5442,17 @@ def degrees(col: "ColumnOrName") -> Column:

     Examples
     --------
     >>> from pyspark.sql import functions as sf
-    >>> spark.range(1).select(sf.degrees(sf.pi())).show()
-    +-------------+
-    |DEGREES(PI())|
-    +-------------+
-    |        180.0|
-    +-------------+
+    >>> spark.sql(
+    ...     "SELECT * FROM VALUES (0.0), (PI()), (PI() / 2), (PI() / 4) AS TAB(value)"
+    ... ).select("*", sf.degrees("value")).show()
+    +------------------+--------------+
+    |             value|DEGREES(value)|
+    +------------------+--------------+
+    |               0.0|           0.0|
+    | 3.141592653589...|         180.0|
+    |1.5707963267948...|          90.0|
+    |0.7853981633974...|          45.0|
+    +------------------+--------------+
     """
     return _invoke_function_over_columns("degrees", col)


@@ -5107,7 +5470,7 @@ def radians(col: "ColumnOrName") -> Column:

     Parameters
     ----------
-    col : :class:`~pyspark.sql.Column` or str
+    col : :class:`~pyspark.sql.Column` or column name
         angle in degrees

     Returns
@@ -5117,9 +5480,18 @@ def radians(col: "ColumnOrName") -> Column:

     Examples
     --------
-    >>> df = spark.range(1)
-    >>> df.select(radians(lit(180))).first()
-    Row(RADIANS(180)=3.14159...)
+    >>> from pyspark.sql import functions as sf
+    >>> spark.sql(
+    ...     "SELECT * FROM VALUES (180), (90), (45), (0) AS TAB(value)"
+    ... ).select("*", sf.radians("value")).show()
+    +-----+------------------+
+    |value|    RADIANS(value)|
+    +-----+------------------+
+    |  180| 3.141592653589...|
+    |   90|1.5707963267948...|
+    |   45|0.7853981633974...|
+    |    0|               0.0|
+    +-----+------------------+
     """
     return _invoke_function_over_columns("radians", col)


@@ -12539,7 +12911,11 @@ def locate(substr: str, str: "ColumnOrName", pos: int = 1) -> Column:

 @_try_remote_functions
-def lpad(col: "ColumnOrName", len: int, pad: str) -> Column:
+def lpad(
+    col: "ColumnOrName",
+    len: Union[Column, int],
+    pad: Union[Column, str],
+) -> Column:
     """
     Left-pad the string column to width `len` with `pad`.

@@ -12550,13 +12926,20 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column:

     Parameters
     ----------
-    col : :class:`~pyspark.sql.Column` or str
+    col : :class:`~pyspark.sql.Column` or column name
         target column to work on.
-    len : int
+    len : :class:`~pyspark.sql.Column` or int
         length of the final string.
-    pad : str
+
+        .. versionchanged:: 4.0.0
+            `len` now accepts column.
+
+    pad : :class:`~pyspark.sql.Column` or literal string
         chars to prepend.

+        .. versionchanged:: 4.0.0
+            `pad` now accepts column.
+
     Returns
     -------
     :class:`~pyspark.sql.Column`
@@ -12564,17 +12947,41 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column:

     Examples
     --------
-    >>> df = spark.createDataFrame([('abcd',)], ['s',])
-    >>> df.select(lpad(df.s, 6, '#').alias('s')).collect()
-    [Row(s='##abcd')]
-    """
-    from pyspark.sql.classic.column import _to_java_column
+    Example 1: Pad with a literal string

-    return _invoke_function("lpad", _to_java_column(col), _enum_to_value(len), _enum_to_value(pad))
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.createDataFrame([('abcd',), ('xyz',), ('12',)], ['s',])
+    >>> df.select("*", sf.lpad(df.s, 6, '#')).show()
+    +----+-------------+
+    |   s|lpad(s, 6, #)|
+    +----+-------------+
+    |abcd|       ##abcd|
+    | xyz|       ###xyz|
+    |  12|       ####12|
+    +----+-------------+
+
+    Example 2: Pad with a bytes column
+
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.createDataFrame([('abcd',), ('xyz',), ('12',)], ['s',])
+    >>> df.select("*", sf.lpad(df.s, 6, sf.lit(b"\x75\x76"))).show()
+    +----+-------------------+
+    |   s|lpad(s, 6, X'7576')|
+    +----+-------------------+
+    |abcd|             uvabcd|
+    | xyz|             uvuxyz|
+    |  12|             uvuv12|
+    +----+-------------------+
+    """
+    return _invoke_function_over_columns("lpad", col, lit(len), lit(pad))


 @_try_remote_functions
-def rpad(col: "ColumnOrName", len: int, pad: str) -> Column:
+def rpad(
+    col: "ColumnOrName",
+    len: Union[Column, int],
+    pad: Union[Column, str],
+) -> Column:
     """
     Right-pad the string column to width `len` with `pad`.

@@ -12587,10 +12994,17 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column:
     ----------
     col : :class:`~pyspark.sql.Column` or str
         target column to work on.
-    len : int
+    len : :class:`~pyspark.sql.Column` or int
         length of the final string.
-    pad : str
-        chars to append.
+
+        .. versionchanged:: 4.0.0
+            `len` now accepts column.
+
+    pad : :class:`~pyspark.sql.Column` or literal string
+        chars to append.
+
+        .. versionchanged:: 4.0.0
+            `pad` now accepts column.
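Because `len` and `pad` now accept columns (for both `lpad` and `rpad`), the pad width can vary per row. A minimal sketch of that usage, not taken from the patch; the data and column names are made up and the output is indicative only:

>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([('abcd', 6), ('xyz', 5)], ['s', 'width'])
>>> df.select(sf.lpad(df.s, sf.col('width'), sf.lit('#'))).show()  # doctest: +SKIP
# Each row is padded to its own target width: 'abcd' becomes '##abcd', 'xyz' becomes '##xyz'.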
Returns ------- @@ -12599,13 +13013,33 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() - [Row(s='abcd##')] - """ - from pyspark.sql.classic.column import _to_java_column + Example 1: Pad with a literal string - return _invoke_function("rpad", _to_java_column(col), _enum_to_value(len), _enum_to_value(pad)) + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([('abcd',), ('xyz',), ('12',)], ['s',]) + >>> df.select("*", sf.rpad(df.s, 6, '#')).show() + +----+-------------+ + | s|rpad(s, 6, #)| + +----+-------------+ + |abcd| abcd##| + | xyz| xyz###| + | 12| 12####| + +----+-------------+ + + Example 2: Pad with a bytes column + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([('abcd',), ('xyz',), ('12',)], ['s',]) + >>> df.select("*", sf.rpad(df.s, 6, sf.lit(b"\x75\x76"))).show() + +----+-------------------+ + | s|rpad(s, 6, X'7576')| + +----+-------------------+ + |abcd| abcduv| + | xyz| xyzuvu| + | 12| 12uvuv| + +----+-------------------+ + """ + return _invoke_function_over_columns("rpad", col, lit(len), lit(pad)) @_try_remote_functions @@ -21314,6 +21748,108 @@ def make_timestamp( ) +@_try_remote_functions +def try_make_timestamp( + years: "ColumnOrName", + months: "ColumnOrName", + days: "ColumnOrName", + hours: "ColumnOrName", + mins: "ColumnOrName", + secs: "ColumnOrName", + timezone: Optional["ColumnOrName"] = None, +) -> Column: + """ + Try to create timestamp from years, months, days, hours, mins, secs and timezone fields. + The result data type is consistent with the value of configuration `spark.sql.timestampType`. + The function returns NULL on invalid inputs. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + years : :class:`~pyspark.sql.Column` or column name + The year to represent, from 1 to 9999 + months : :class:`~pyspark.sql.Column` or column name + The month-of-year to represent, from 1 (January) to 12 (December) + days : :class:`~pyspark.sql.Column` or column name + The day-of-month to represent, from 1 to 31 + hours : :class:`~pyspark.sql.Column` or column name + The hour-of-day to represent, from 0 to 23 + mins : :class:`~pyspark.sql.Column` or column name + The minute-of-hour to represent, from 0 to 59 + secs : :class:`~pyspark.sql.Column` or column name + The second-of-minute and its micro-fraction to represent, from 0 to 60. + The value can be either an integer like 13 , or a fraction like 13.123. + If the sec argument equals to 60, the seconds field is set + to 0 and 1 minute is added to the final timestamp. + timezone : :class:`~pyspark.sql.Column` or column name, optional + The time zone identifier. For example, CET, UTC and etc. + + Returns + ------- + :class:`~pyspark.sql.Column` + A new column that contains a timestamp or NULL in case of an error. + + Examples + -------- + + Example 1: Make timestamp from years, months, days, hours, mins and secs. + + >>> import pyspark.sql.functions as sf + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> df = spark.createDataFrame([[2014, 12, 28, 6, 30, 45.887, 'CET']], + ... ["year", "month", "day", "hour", "min", "sec", "timezone"]) + >>> df.select(sf.try_make_timestamp( + ... df.year, df.month, df.day, df.hour, df.min, df.sec, df.timezone) + ... 
).show(truncate=False) + +----------------------------------------------------+ + |try_make_timestamp(year, month, day, hour, min, sec)| + +----------------------------------------------------+ + |2014-12-27 21:30:45.887 | + +----------------------------------------------------+ + + Example 2: Make timestamp without timezone. + + >>> import pyspark.sql.functions as sf + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> df = spark.createDataFrame([[2014, 12, 28, 6, 30, 45.887, 'CET']], + ... ["year", "month", "day", "hour", "min", "sec", "timezone"]) + >>> df.select(sf.try_make_timestamp( + ... df.year, df.month, df.day, df.hour, df.min, df.sec) + ... ).show(truncate=False) + +----------------------------------------------------+ + |try_make_timestamp(year, month, day, hour, min, sec)| + +----------------------------------------------------+ + |2014-12-28 06:30:45.887 | + +----------------------------------------------------+ + >>> spark.conf.unset("spark.sql.session.timeZone") + + Example 3: Make timestamp with invalid input. + + >>> import pyspark.sql.functions as sf + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> df = spark.createDataFrame([[2014, 13, 28, 6, 30, 45.887, 'CET']], + ... ["year", "month", "day", "hour", "min", "sec", "timezone"]) + >>> df.select(sf.try_make_timestamp( + ... df.year, df.month, df.day, df.hour, df.min, df.sec) + ... ).show(truncate=False) + +----------------------------------------------------+ + |try_make_timestamp(year, month, day, hour, min, sec)| + +----------------------------------------------------+ + |NULL | + +----------------------------------------------------+ + >>> spark.conf.unset("spark.sql.session.timeZone") + """ + if timezone is not None: + return _invoke_function_over_columns( + "try_make_timestamp", years, months, days, hours, mins, secs, timezone + ) + else: + return _invoke_function_over_columns( + "try_make_timestamp", years, months, days, hours, mins, secs + ) + + @_try_remote_functions def make_timestamp_ltz( years: "ColumnOrName", @@ -21400,6 +21936,108 @@ def make_timestamp_ltz( ) +@_try_remote_functions +def try_make_timestamp_ltz( + years: "ColumnOrName", + months: "ColumnOrName", + days: "ColumnOrName", + hours: "ColumnOrName", + mins: "ColumnOrName", + secs: "ColumnOrName", + timezone: Optional["ColumnOrName"] = None, +) -> Column: + """ + Try to create the current timestamp with local time zone from years, months, days, hours, mins, + secs and timezone fields. + The function returns NULL on invalid inputs. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + years : :class:`~pyspark.sql.Column` or column name + The year to represent, from 1 to 9999 + months : :class:`~pyspark.sql.Column` or column name + The month-of-year to represent, from 1 (January) to 12 (December) + days : :class:`~pyspark.sql.Column` or column name + The day-of-month to represent, from 1 to 31 + hours : :class:`~pyspark.sql.Column` or column name + The hour-of-day to represent, from 0 to 23 + mins : :class:`~pyspark.sql.Column` or column name + The minute-of-hour to represent, from 0 to 59 + secs : :class:`~pyspark.sql.Column` or column name + The second-of-minute and its micro-fraction to represent, from 0 to 60. + The value can be either an integer like 13 , or a fraction like 13.123. + If the sec argument equals to 60, the seconds field is set + to 0 and 1 minute is added to the final timestamp. + timezone : :class:`~pyspark.sql.Column` or column name, optional + The time zone identifier. 
For example, CET, UTC and etc. + + Returns + ------- + :class:`~pyspark.sql.Column` + A new column that contains a current timestamp, or NULL in case of an error. + + Examples + -------- + + Example 1: Make the current timestamp from years, months, days, hours, mins and secs. + + >>> import pyspark.sql.functions as sf + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> df = spark.createDataFrame([[2014, 12, 28, 6, 30, 45.887, 'CET']], + ... ["year", "month", "day", "hour", "min", "sec", "timezone"]) + >>> df.select(sf.try_make_timestamp_ltz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec, df.timezone) + ... ).show(truncate=False) + +------------------------------------------------------------------+ + |try_make_timestamp_ltz(year, month, day, hour, min, sec, timezone)| + +------------------------------------------------------------------+ + |2014-12-27 21:30:45.887 | + +------------------------------------------------------------------+ + + Example 2: Make the current timestamp without timezone. + + >>> import pyspark.sql.functions as sf + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> df = spark.createDataFrame([[2014, 12, 28, 6, 30, 45.887, 'CET']], + ... ["year", "month", "day", "hour", "min", "sec", "timezone"]) + >>> df.select(sf.try_make_timestamp_ltz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec) + ... ).show(truncate=False) + +--------------------------------------------------------+ + |try_make_timestamp_ltz(year, month, day, hour, min, sec)| + +--------------------------------------------------------+ + |2014-12-28 06:30:45.887 | + +--------------------------------------------------------+ + >>> spark.conf.unset("spark.sql.session.timeZone") + + Example 3: Make the current timestamp with invalid input. + + >>> import pyspark.sql.functions as sf + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> df = spark.createDataFrame([[2014, 13, 28, 6, 30, 45.887, 'CET']], + ... ["year", "month", "day", "hour", "min", "sec", "timezone"]) + >>> df.select(sf.try_make_timestamp_ltz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec) + ... ).show(truncate=False) + +--------------------------------------------------------+ + |try_make_timestamp_ltz(year, month, day, hour, min, sec)| + +--------------------------------------------------------+ + |NULL | + +--------------------------------------------------------+ + >>> spark.conf.unset("spark.sql.session.timeZone") + """ + if timezone is not None: + return _invoke_function_over_columns( + "try_make_timestamp_ltz", years, months, days, hours, mins, secs, timezone + ) + else: + return _invoke_function_over_columns( + "try_make_timestamp_ltz", years, months, days, hours, mins, secs + ) + + @_try_remote_functions def make_timestamp_ntz( years: "ColumnOrName", @@ -21463,6 +22101,84 @@ def make_timestamp_ntz( ) +@_try_remote_functions +def try_make_timestamp_ntz( + years: "ColumnOrName", + months: "ColumnOrName", + days: "ColumnOrName", + hours: "ColumnOrName", + mins: "ColumnOrName", + secs: "ColumnOrName", +) -> Column: + """ + Try to create local date-time from years, months, days, hours, mins, secs fields. + The function returns NULL on invalid inputs. + + .. 
versionadded:: 4.0.0 + + Parameters + ---------- + years : :class:`~pyspark.sql.Column` or column name + The year to represent, from 1 to 9999 + months : :class:`~pyspark.sql.Column` or column name + The month-of-year to represent, from 1 (January) to 12 (December) + days : :class:`~pyspark.sql.Column` or column name + The day-of-month to represent, from 1 to 31 + hours : :class:`~pyspark.sql.Column` or column name + The hour-of-day to represent, from 0 to 23 + mins : :class:`~pyspark.sql.Column` or column name + The minute-of-hour to represent, from 0 to 59 + secs : :class:`~pyspark.sql.Column` or column name + The second-of-minute and its micro-fraction to represent, from 0 to 60. + The value can be either an integer like 13 , or a fraction like 13.123. + If the sec argument equals to 60, the seconds field is set + to 0 and 1 minute is added to the final timestamp. + + Returns + ------- + :class:`~pyspark.sql.Column` + A new column that contains a local date-time, or NULL in case of an error. + + Examples + -------- + + Example 1: Make local date-time from years, months, days, hours, mins, secs. + + >>> import pyspark.sql.functions as sf + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> df = spark.createDataFrame([[2014, 12, 28, 6, 30, 45.887]], + ... ["year", "month", "day", "hour", "min", "sec"]) + >>> df.select(sf.try_make_timestamp_ntz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec) + ... ).show(truncate=False) + +--------------------------------------------------------+ + |try_make_timestamp_ntz(year, month, day, hour, min, sec)| + +--------------------------------------------------------+ + |2014-12-28 06:30:45.887 | + +--------------------------------------------------------+ + >>> spark.conf.unset("spark.sql.session.timeZone") + + Example 2: Make local date-time with invalid input + + >>> import pyspark.sql.functions as sf + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> df = spark.createDataFrame([[2014, 13, 28, 6, 30, 45.887]], + ... ["year", "month", "day", "hour", "min", "sec"]) + >>> df.select(sf.try_make_timestamp_ntz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec) + ... ).show(truncate=False) + +--------------------------------------------------------+ + |try_make_timestamp_ntz(year, month, day, hour, min, sec)| + +--------------------------------------------------------+ + |NULL | + +--------------------------------------------------------+ + >>> spark.conf.unset("spark.sql.session.timeZone") + """ + return _invoke_function_over_columns( + "try_make_timestamp_ntz", years, months, days, hours, mins, secs + ) + + @_try_remote_functions def make_ym_interval( years: Optional["ColumnOrName"] = None, diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index 158d9130560aa..328ebe3488781 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -360,7 +360,7 @@ def pie(self, x: str, y: str, **kwargs: Any) -> "Figure": ) return self(kind="pie", x=x, y=y, **kwargs) - def box(self, column: Union[str, List[str]], **kwargs: Any) -> "Figure": + def box(self, column: Optional[Union[str, List[str]]] = None, **kwargs: Any) -> "Figure": """ Make a box plot of the DataFrame columns. @@ -374,8 +374,9 @@ def box(self, column: Union[str, List[str]], **kwargs: Any) -> "Figure": Parameters ---------- - column: str or list of str - Column name or list of names to be used for creating the boxplot. 
+        column: str or list of str, optional
+            Column name or list of names to be used for creating the box plot.
+            If None (default), all numeric columns will be used.
         **kwargs
             Extra arguments to `precision`: refer to a float that is used by
             pyspark to compute approximate statistics for building a boxplot.
@@ -399,6 +400,7 @@ def box(self, column: Union[str, List[str]], **kwargs: Any) -> "Figure":
         ... ]
         >>> columns = ["student", "math_score", "english_score"]
         >>> df = spark.createDataFrame(data, columns)
+        >>> df.plot.box()  # doctest: +SKIP
         >>> df.plot.box(column="math_score")  # doctest: +SKIP
         >>> df.plot.box(column=["math_score", "english_score"])  # doctest: +SKIP
         """
@@ -406,9 +408,9 @@ def box(self, column: Union[str, List[str]], **kwargs: Any) -> "Figure":

     def kde(
         self,
-        column: Union[str, List[str]],
         bw_method: Union[int, float],
-        ind: Union["np.ndarray", int, None] = None,
+        column: Optional[Union[str, List[str]]] = None,
+        ind: Optional[Union["np.ndarray", int]] = None,
         **kwargs: Any,
     ) -> "Figure":
         """
@@ -420,11 +422,12 @@ def kde(

         Parameters
         ----------
-        column: str or list of str
-            Column name or list of names to be used for creating the kde plot.
         bw_method : int or float
             The method used to calculate the estimator bandwidth.
             See KernelDensity in PySpark for more information.
+        column: str or list of str, optional
+            Column name or list of names to be used for creating the kde plot.
+            If None (default), all numeric columns will be used.
         ind : NumPy array or integer, optional
             Evaluation points for the estimated PDF. If None (default),
             1000 equally spaced points are used. If `ind` is a NumPy array, the
@@ -442,12 +445,15 @@ def kde(
         >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
         >>> columns = ["length", "width", "species"]
         >>> df = spark.createDataFrame(data, columns)
+        >>> df.plot.kde(bw_method=0.3)  # doctest: +SKIP
         >>> df.plot.kde(column=["length", "width"], bw_method=0.3)  # doctest: +SKIP
         >>> df.plot.kde(column="length", bw_method=0.3)  # doctest: +SKIP
         """
         return self(kind="kde", column=column, bw_method=bw_method, ind=ind, **kwargs)

-    def hist(self, column: Union[str, List[str]], bins: int = 10, **kwargs: Any) -> "Figure":
+    def hist(
+        self, column: Optional[Union[str, List[str]]] = None, bins: int = 10, **kwargs: Any
+    ) -> "Figure":
         """
         Draw one histogram of the DataFrame’s columns.

@@ -457,8 +463,9 @@ def hist(self, column: Union[str, List[str]], bins: int = 10, **kwargs: Any) ->

         Parameters
         ----------
-        column: str or list of str
-            Column name or list of names to be used for creating the histogram.
+        column: str or list of str, optional
+            Column name or list of names to be used for creating the histogram plot.
+            If None (default), all numeric columns will be used.
         bins : integer, default 10
             Number of histogram bins to be used.
**kwargs @@ -473,6 +480,7 @@ def hist(self, column: Union[str, List[str]], bins: int = 10, **kwargs: Any) -> >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)] >>> columns = ["length", "width", "species"] >>> df = spark.createDataFrame(data, columns) + >>> df.plot.hist(bins=4) # doctest: +SKIP >>> df.plot.hist(column=["length", "width"]) # doctest: +SKIP >>> df.plot.hist(column="length", bins=4) # doctest: +SKIP """ @@ -481,7 +489,7 @@ def hist(self, column: Union[str, List[str]], bins: int = 10, **kwargs: Any) -> class PySparkKdePlotBase: @staticmethod - def get_ind(sdf: "DataFrame", ind: Union["np.ndarray", int, None]) -> "np.ndarray": + def get_ind(sdf: "DataFrame", ind: Optional[Union["np.ndarray", int]]) -> "np.ndarray": require_minimum_numpy_version() import numpy as np diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py index ece5995bf2817..ceae4b999aa83 100644 --- a/python/pyspark/sql/plot/plotly.py +++ b/python/pyspark/sql/plot/plotly.py @@ -16,15 +16,16 @@ # import inspect -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, List, Optional, Union -from pyspark.errors import PySparkValueError +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql.plot import ( PySparkPlotAccessor, PySparkBoxPlotBase, PySparkKdePlotBase, PySparkHistogramPlotBase, ) +from pyspark.sql.types import NumericType if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -67,9 +68,7 @@ def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure": whis = kwargs.pop("whis", 1.5) # 'precision' is pyspark specific to control precision for approx_percentile precision = kwargs.pop("precision", 0.01) - colnames = kwargs.pop("column", None) - if isinstance(colnames, str): - colnames = [colnames] + colnames = process_column_param(kwargs.pop("column", None), data) # Plotly options boxpoints = kwargs.pop("boxpoints", "suspectedoutliers") @@ -142,9 +141,7 @@ def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure": kwargs["color"] = "names" bw_method = kwargs.pop("bw_method", None) - colnames = kwargs.pop("column", None) - if isinstance(colnames, str): - colnames = [colnames] + colnames = process_column_param(kwargs.pop("column", None), data) ind = PySparkKdePlotBase.get_ind(data.select(*colnames), kwargs.pop("ind", None)) kde_cols = [ @@ -177,13 +174,11 @@ def plot_histogram(data: "DataFrame", **kwargs: Any) -> "Figure": import plotly.graph_objs as go bins = kwargs.get("bins", 10) - colnames = kwargs.pop("column", None) - if isinstance(colnames, str): - colnames = [colnames] - data = data.select(*colnames) - bins = PySparkHistogramPlotBase.get_bins(data, bins) + colnames = process_column_param(kwargs.pop("column", None), data) + numeric_data = data.select(*colnames) + bins = PySparkHistogramPlotBase.get_bins(numeric_data, bins) assert len(bins) > 2, "the number of buckets must be higher than 2." - output_series = PySparkHistogramPlotBase.compute_hist(data, bins) + output_series = PySparkHistogramPlotBase.compute_hist(numeric_data, bins) prev = float("%.9f" % bins[0]) # to make it prettier, truncate. text_bins = [] for b in bins[1:]: @@ -214,3 +209,34 @@ def plot_histogram(data: "DataFrame", **kwargs: Any) -> "Figure": fig["layout"]["xaxis"]["title"] = "value" fig["layout"]["yaxis"]["title"] = "count" return fig + + +def process_column_param(column: Optional[Union[str, List[str]]], data: "DataFrame") -> List[str]: + """ + Processes the provided column parameter for a DataFrame. 
+ - If `column` is None, returns a list of numeric columns from the DataFrame. + - If `column` is a string, converts it to a list first. + - If `column` is a list, it checks if all specified columns exist in the DataFrame + and are of NumericType. + - Raises a PySparkTypeError if any column in the list is not present in the DataFrame + or is not of NumericType. + """ + if column is None: + return [ + field.name for field in data.schema.fields if isinstance(field.dataType, NumericType) + ] + if isinstance(column, str): + column = [column] + + for col in column: + field = next((f for f in data.schema.fields if f.name == col), None) + if not field or not isinstance(field.dataType, NumericType): + raise PySparkTypeError( + errorClass="PLOT_INVALID_TYPE_COLUMN", + messageParameters={ + "col_name": col, + "valid_types": NumericType.__name__, + "col_type": field.dataType.__class__.__name__ if field else "None", + }, + ) + return column diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index 4f24b4c463f51..95a706c9d9972 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -45,7 +45,7 @@ def sdf(self): @property def sdf2(self): - data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)] + data = [(5.1, 3.5, "0"), (4.9, 3.0, "0"), (7.0, 3.2, "1"), (6.4, 3.2, "1"), (5.9, 3.0, "2")] columns = ["length", "width", "species"] return self.spark.createDataFrame(data, columns) @@ -330,7 +330,7 @@ def test_pie_plot(self): def test_box_plot(self): fig = self.sdf4.plot.box(column="math_score") - expected_fig_data = { + expected_fig_data1 = { "boxpoints": "suspectedoutliers", "lowerfence": (5,), "mean": (50.0,), @@ -343,11 +343,11 @@ def test_box_plot(self): "x": [0], "type": "box", } - self._check_fig_data(fig["data"][0], **expected_fig_data) + self._check_fig_data(fig["data"][0], **expected_fig_data1) fig = self.sdf4.plot(kind="box", column=["math_score", "english_score"]) - self._check_fig_data(fig["data"][0], **expected_fig_data) - expected_fig_data = { + self._check_fig_data(fig["data"][0], **expected_fig_data1) + expected_fig_data2 = { "boxpoints": "suspectedoutliers", "lowerfence": (55,), "mean": (72.5,), @@ -361,7 +361,12 @@ def test_box_plot(self): "y": [[150, 15]], "type": "box", } - self._check_fig_data(fig["data"][1], **expected_fig_data) + self._check_fig_data(fig["data"][1], **expected_fig_data2) + + fig = self.sdf4.plot(kind="box") + self._check_fig_data(fig["data"][0], **expected_fig_data1) + self._check_fig_data(fig["data"][1], **expected_fig_data2) + with self.assertRaises(PySparkValueError) as pe: self.sdf4.plot.box(column="math_score", boxpoints=True) self.check_error( @@ -390,7 +395,7 @@ def test_box_plot(self): @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_kde_plot(self): fig = self.sdf4.plot.kde(column="math_score", bw_method=0.3, ind=5) - expected_fig_data = { + expected_fig_data1 = { "mode": "lines", "name": "math_score", "orientation": "v", @@ -398,11 +403,11 @@ def test_kde_plot(self): "yaxis": "y", "type": "scatter", } - self._check_fig_data(fig["data"][0], **expected_fig_data) + self._check_fig_data(fig["data"][0], **expected_fig_data1) fig = self.sdf4.plot.kde(column=["math_score", "english_score"], bw_method=0.3, ind=5) - self._check_fig_data(fig["data"][0], **expected_fig_data) - expected_fig_data = { + self._check_fig_data(fig["data"][0], **expected_fig_data1) + 
expected_fig_data2 = { "mode": "lines", "name": "english_score", "orientation": "v", @@ -410,7 +415,12 @@ def test_kde_plot(self): "yaxis": "y", "type": "scatter", } - self._check_fig_data(fig["data"][1], **expected_fig_data) + self._check_fig_data(fig["data"][1], **expected_fig_data2) + self.assertEqual(list(fig["data"][0]["x"]), list(fig["data"][1]["x"])) + + fig = self.sdf4.plot.kde(bw_method=0.3, ind=5) + self._check_fig_data(fig["data"][0], **expected_fig_data1) + self._check_fig_data(fig["data"][1], **expected_fig_data2) self.assertEqual(list(fig["data"][0]["x"]), list(fig["data"][1]["x"])) def test_hist_plot(self): @@ -423,23 +433,55 @@ def test_hist_plot(self): "type": "bar", } self._check_fig_data(fig["data"][0], **expected_fig_data) + fig = self.sdf2.plot.hist(column=["length", "width"], bins=4) - expected_fig_data = { + expected_fig_data1 = { "name": "length", "x": [3.5, 4.5, 5.5, 6.5], "y": [0, 1, 2, 2], "text": ("[3.0, 4.0)", "[4.0, 5.0)", "[5.0, 6.0)", "[6.0, 7.0]"), "type": "bar", } - self._check_fig_data(fig["data"][0], **expected_fig_data) - expected_fig_data = { + self._check_fig_data(fig["data"][0], **expected_fig_data1) + expected_fig_data2 = { "name": "width", "x": [3.5, 4.5, 5.5, 6.5], "y": [5, 0, 0, 0], "text": ("[3.0, 4.0)", "[4.0, 5.0)", "[5.0, 6.0)", "[6.0, 7.0]"), "type": "bar", } - self._check_fig_data(fig["data"][1], **expected_fig_data) + self._check_fig_data(fig["data"][1], **expected_fig_data2) + + fig = self.sdf2.plot.hist(bins=4) + self._check_fig_data(fig["data"][0], **expected_fig_data1) + self._check_fig_data(fig["data"][1], **expected_fig_data2) + + def test_process_column_param_errors(self): + with self.assertRaises(PySparkTypeError) as pe: + self.sdf4.plot.box(column="math_scor") + + self.check_error( + exception=pe.exception, + errorClass="PLOT_INVALID_TYPE_COLUMN", + messageParameters={ + "col_name": "math_scor", + "valid_types": "NumericType", + "col_type": "None", + }, + ) + + with self.assertRaises(PySparkTypeError) as pe: + self.sdf4.plot.box(column="student") + + self.check_error( + exception=pe.exception, + errorClass="PLOT_INVALID_TYPE_COLUMN", + messageParameters={ + "col_name": "student", + "valid_types": "NumericType", + "col_type": "StringType", + }, + ) class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index c00a0e7febf67..74e043ca1e6e8 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -347,6 +347,61 @@ def test_try_parse_url(self): actual = df.select(F.try_parse_url(df.url, df.part, df.key)).collect() self.assertEqual(actual, [Row(None)]) + def test_try_make_timestamp(self): + data = [(2024, 5, 22, 10, 30, 0)] + df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"]) + actual = df.select( + F.try_make_timestamp(df.year, df.month, df.day, df.hour, df.minute, df.second) + ).collect() + self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))]) + + data = [(2024, 13, 22, 10, 30, 0)] + df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"]) + actual = df.select( + F.try_make_timestamp(df.year, df.month, df.day, df.hour, df.minute, df.second) + ).collect() + self.assertEqual(actual, [Row(None)]) + + def test_try_make_timestamp_ltz(self): + # use local timezone here to avoid flakiness + data = [(2024, 5, 22, 10, 30, 0, 
datetime.datetime.now().astimezone().tzinfo.__str__())] + df = self.spark.createDataFrame( + data, ["year", "month", "day", "hour", "minute", "second", "timezone"] + ) + actual = df.select( + F.try_make_timestamp_ltz( + df.year, df.month, df.day, df.hour, df.minute, df.second, df.timezone + ) + ).collect() + self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30, 0))]) + + # use local timezone here to avoid flakiness + data = [(2024, 13, 22, 10, 30, 0, datetime.datetime.now().astimezone().tzinfo.__str__())] + df = self.spark.createDataFrame( + data, ["year", "month", "day", "hour", "minute", "second", "timezone"] + ) + actual = df.select( + F.try_make_timestamp_ltz( + df.year, df.month, df.day, df.hour, df.minute, df.second, df.timezone + ) + ).collect() + self.assertEqual(actual, [Row(None)]) + + def test_try_make_timestamp_ntz(self): + data = [(2024, 5, 22, 10, 30, 0)] + df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"]) + actual = df.select( + F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second) + ).collect() + self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))]) + + data = [(2024, 13, 22, 10, 30, 0)] + df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"]) + actual = df.select( + F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second) + ).collect() + self.assertEqual(actual, [Row(None)]) + def test_string_functions(self): string_functions = [ "upper", diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala index 907c46f583cf1..0ee1d7037d438 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala @@ -116,10 +116,10 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def arithmeticOverflowError( message: String, - hint: String = "", + suggestedFunc: String = "", context: QueryContext = null): ArithmeticException = { - val alternative = if (hint.nonEmpty) { - s" Use '$hint' to tolerate overflow and return NULL instead." + val alternative = if (suggestedFunc.nonEmpty) { + s" Use '$suggestedFunc' to tolerate overflow and return NULL instead." 
} else "" new SparkArithmeticException( errorClass = "ARITHMETIC_OVERFLOW", diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index a9e556fad0464..0fa6eb0434ab1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -284,9 +284,16 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { from: String, to: String, ctx: ParserRuleContext): Throwable = { + val intervalInput = ctx.getText() + val pattern = "'([^']*)'".r + val input = pattern.findFirstMatchIn(intervalInput) match { + case Some(m) => m.group(1) + case None => "" + } + new ParseException( - errorClass = "_LEGACY_ERROR_TEMP_0028", - messageParameters = Map("from" -> from, "to" -> to), + errorClass = "INVALID_INTERVAL_FORMAT.UNSUPPORTED_FROM_TO_EXPRESSION", + messageParameters = Map("input" -> input, "from" -> from, "to" -> to), ctx) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index d81b9c5060f68..d7b61468b43d7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -4075,8 +4075,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def lpad(str: Column, len: Int, pad: String): Column = - Column.fn("lpad", str, lit(len), lit(pad)) + def lpad(str: Column, len: Int, pad: String): Column = lpad(str, lit(len), lit(pad)) /** * Left-pad the binary column with pad to a byte length of len. If the binary column is longer @@ -4085,8 +4084,16 @@ object functions { * @group string_funcs * @since 3.3.0 */ - def lpad(str: Column, len: Int, pad: Array[Byte]): Column = - Column.fn("lpad", str, lit(len), lit(pad)) + def lpad(str: Column, len: Int, pad: Array[Byte]): Column = lpad(str, lit(len), lit(pad)) + + /** + * Left-pad the string column with pad to a length of len. If the string column is longer than + * len, the return value is shortened to len characters. + * + * @group string_funcs + * @since 4.0.0 + */ + def lpad(str: Column, len: Column, pad: Column): Column = Column.fn("lpad", str, len, pad) /** * Trim the spaces from left end for the specified string value. @@ -4263,8 +4270,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rpad(str: Column, len: Int, pad: String): Column = - Column.fn("rpad", str, lit(len), lit(pad)) + def rpad(str: Column, len: Int, pad: String): Column = rpad(str, lit(len), lit(pad)) /** * Right-pad the binary column with pad to a byte length of len. If the binary column is longer @@ -4273,8 +4279,16 @@ object functions { * @group string_funcs * @since 3.3.0 */ - def rpad(str: Column, len: Int, pad: Array[Byte]): Column = - Column.fn("rpad", str, lit(len), lit(pad)) + def rpad(str: Column, len: Int, pad: Array[Byte]): Column = rpad(str, lit(len), lit(pad)) + + /** + * Right-pad the string column with pad to a length of len. If the string column is longer than + * len, the return value is shortened to len characters. + * + * @group string_funcs + * @since 4.0.0 + */ + def rpad(str: Column, len: Column, pad: Column): Column = Column.fn("rpad", str, len, pad) /** * Repeats a string column n times, and returns it as a new string column. 
@@ -8105,6 +8119,41 @@ object functions { secs: Column): Column = Column.fn("make_timestamp", years, months, days, hours, mins, secs) + /** + * Try to create a timestamp from years, months, days, hours, mins, secs and timezone fields. + * The result data type is consistent with the value of configuration `spark.sql.timestampType`. + * The function returns NULL on invalid inputs. + * + * @group datetime_funcs + * @since 4.0.0 + */ + def try_make_timestamp( + years: Column, + months: Column, + days: Column, + hours: Column, + mins: Column, + secs: Column, + timezone: Column): Column = + Column.fn("try_make_timestamp", years, months, days, hours, mins, secs, timezone) + + /** + * Try to create a timestamp from years, months, days, hours, mins, and secs fields. The result + * data type is consistent with the value of configuration `spark.sql.timestampType`. The + * function returns NULL on invalid inputs. + * + * @group datetime_funcs + * @since 4.0.0 + */ + def try_make_timestamp( + years: Column, + months: Column, + days: Column, + hours: Column, + mins: Column, + secs: Column): Column = + Column.fn("try_make_timestamp", years, months, days, hours, mins, secs) + /** * Create the current timestamp with local time zone from years, months, days, hours, mins, secs * and timezone fields. If the configuration `spark.sql.ansi.enabled` is false, the function @@ -8140,6 +8189,39 @@ object functions { secs: Column): Column = Column.fn("make_timestamp_ltz", years, months, days, hours, mins, secs) + /** + * Try to create the current timestamp with local time zone from years, months, days, hours, + * mins, secs and timezone fields. The function returns NULL on invalid inputs. + * + * @group datetime_funcs + * @since 4.0.0 + */ + def try_make_timestamp_ltz( + years: Column, + months: Column, + days: Column, + hours: Column, + mins: Column, + secs: Column, + timezone: Column): Column = + Column.fn("try_make_timestamp_ltz", years, months, days, hours, mins, secs, timezone) + + /** + * Try to create the current timestamp with local time zone from years, months, days, hours, + * mins and secs fields. The function returns NULL on invalid inputs. + * + * @group datetime_funcs + * @since 4.0.0 + */ + def try_make_timestamp_ltz( + years: Column, + months: Column, + days: Column, + hours: Column, + mins: Column, + secs: Column): Column = + Column.fn("try_make_timestamp_ltz", years, months, days, hours, mins, secs) + /** * Create local date-time from years, months, days, hours, mins, secs fields. If the * configuration `spark.sql.ansi.enabled` is false, the function returns NULL on invalid inputs. @@ -8157,6 +8239,22 @@ object functions { secs: Column): Column = Column.fn("make_timestamp_ntz", years, months, days, hours, mins, secs) + /** + * Try to create a local date-time from years, months, days, hours, mins, secs fields. The + * function returns NULL on invalid inputs. + * + * @group datetime_funcs + * @since 4.0.0 + */ + def try_make_timestamp_ntz( + years: Column, + months: Column, + days: Column, + hours: Column, + mins: Column, + secs: Column): Column = + Column.fn("try_make_timestamp_ntz", years, months, days, hours, mins, secs) + /** * Make year-month interval from years, months. 
* diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java index ff6525acbe539..5411aa684ea5f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java @@ -19,20 +19,13 @@ import java.util.Arrays; import java.util.Comparator; -import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.SQLOrderingUtil; -import org.apache.spark.sql.types.ByteType$; -import org.apache.spark.sql.types.BooleanType$; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DoubleType$; -import org.apache.spark.sql.types.FloatType$; -import org.apache.spark.sql.types.IntegerType$; -import org.apache.spark.sql.types.LongType$; -import org.apache.spark.sql.types.ShortType$; public class ArrayExpressionUtils { - private static final Comparator booleanComp = (o1, o2) -> { + // comparator + // Boolean ascending nullable comparator + private static final Comparator booleanComp = (o1, o2) -> { if (o1 == null && o2 == null) { return 0; } else if (o1 == null) { @@ -40,11 +33,11 @@ public class ArrayExpressionUtils { } else if (o2 == null) { return 1; } - boolean c1 = (Boolean) o1, c2 = (Boolean) o2; - return c1 == c2 ? 0 : (c1 ? 1 : -1); + return o1.equals(o2) ? 0 : (o1 ? 1 : -1); }; - private static final Comparator byteComp = (o1, o2) -> { + // Byte ascending nullable comparator + private static final Comparator byteComp = (o1, o2) -> { if (o1 == null && o2 == null) { return 0; } else if (o1 == null) { @@ -52,11 +45,11 @@ public class ArrayExpressionUtils { } else if (o2 == null) { return 1; } - byte c1 = (Byte) o1, c2 = (Byte) o2; - return Byte.compare(c1, c2); + return Byte.compare(o1, o2); }; - private static final Comparator shortComp = (o1, o2) -> { + // Short ascending nullable comparator + private static final Comparator shortComp = (o1, o2) -> { if (o1 == null && o2 == null) { return 0; } else if (o1 == null) { @@ -64,11 +57,11 @@ public class ArrayExpressionUtils { } else if (o2 == null) { return 1; } - short c1 = (Short) o1, c2 = (Short) o2; - return Short.compare(c1, c2); + return Short.compare(o1, o2); }; - private static final Comparator integerComp = (o1, o2) -> { + // Integer ascending nullable comparator + private static final Comparator integerComp = (o1, o2) -> { if (o1 == null && o2 == null) { return 0; } else if (o1 == null) { @@ -76,11 +69,11 @@ public class ArrayExpressionUtils { } else if (o2 == null) { return 1; } - int c1 = (Integer) o1, c2 = (Integer) o2; - return Integer.compare(c1, c2); + return Integer.compare(o1, o2); }; - private static final Comparator longComp = (o1, o2) -> { + // Long ascending nullable comparator + private static final Comparator longComp = (o1, o2) -> { if (o1 == null && o2 == null) { return 0; } else if (o1 == null) { @@ -88,11 +81,11 @@ public class ArrayExpressionUtils { } else if (o2 == null) { return 1; } - long c1 = (Long) o1, c2 = (Long) o2; - return Long.compare(c1, c2); + return Long.compare(o1, o2); }; - private static final Comparator floatComp = (o1, o2) -> { + // Float ascending nullable comparator + private static final Comparator floatComp = (o1, o2) -> { if (o1 == null && o2 == null) { return 0; } else if (o1 == null) { @@ -100,11 +93,11 @@ public class ArrayExpressionUtils { } else if (o2 == null) { 
return 1; } - float c1 = (Float) o1, c2 = (Float) o2; - return SQLOrderingUtil.compareFloats(c1, c2); + return SQLOrderingUtil.compareFloats(o1, o2); }; - private static final Comparator doubleComp = (o1, o2) -> { + // Double ascending nullable comparator + private static final Comparator doubleComp = (o1, o2) -> { if (o1 == null && o2 == null) { return 0; } else if (o1 == null) { @@ -112,65 +105,104 @@ public class ArrayExpressionUtils { } else if (o2 == null) { return 1; } - double c1 = (Double) o1, c2 = (Double) o2; - return SQLOrderingUtil.compareDoubles(c1, c2); + return SQLOrderingUtil.compareDoubles(o1, o2); }; - public static int binarySearchNullSafe(ArrayData data, Boolean value) { - return Arrays.binarySearch(data.toObjectArray(BooleanType$.MODULE$), value, booleanComp); + // boolean + // boolean non-nullable + public static int binarySearch(boolean[] data, boolean value) { + int low = 0; + int high = data.length - 1; + + while (low <= high) { + int mid = (low + high) >>> 1; + boolean midVal = data[mid]; + + if (value == midVal) { + return mid; // key found + } else if (value) { + low = mid + 1; + } else { + high = mid - 1; + } + } + + return -(low + 1); // key not found. + } + + // Boolean nullable + public static int binarySearch(Boolean[] data, Boolean value) { + return Arrays.binarySearch(data, value, booleanComp); } - public static int binarySearch(ArrayData data, byte value) { - return Arrays.binarySearch(data.toByteArray(), value); + // byte + // byte non-nullable + public static int binarySearch(byte[] data, byte value) { + return Arrays.binarySearch(data, value); } - public static int binarySearchNullSafe(ArrayData data, Byte value) { - return Arrays.binarySearch(data.toObjectArray(ByteType$.MODULE$), value, byteComp); + // Byte nullable + public static int binarySearch(Byte[] data, Byte value) { + return Arrays.binarySearch(data, value, byteComp); } - public static int binarySearch(ArrayData data, short value) { - return Arrays.binarySearch(data.toShortArray(), value); + // short + // short non-nullable + public static int binarySearch(short[] data, short value) { + return Arrays.binarySearch(data, value); } - public static int binarySearchNullSafe(ArrayData data, Short value) { - return Arrays.binarySearch(data.toObjectArray(ShortType$.MODULE$), value, shortComp); + // Short nullable + public static int binarySearch(Short[] data, Short value) { + return Arrays.binarySearch(data, value, shortComp); } - public static int binarySearch(ArrayData data, int value) { - return Arrays.binarySearch(data.toIntArray(), value); + // int + // int non-nullable + public static int binarySearch(int[] data, int value) { + return Arrays.binarySearch(data, value); } - public static int binarySearchNullSafe(ArrayData data, Integer value) { - return Arrays.binarySearch(data.toObjectArray(IntegerType$.MODULE$), value, integerComp); + // Integer nullable + public static int binarySearch(Integer[] data, Integer value) { + return Arrays.binarySearch(data, value, integerComp); } - public static int binarySearch(ArrayData data, long value) { - return Arrays.binarySearch(data.toLongArray(), value); + // long + // long non-nullable + public static int binarySearch(long[] data, long value) { + return Arrays.binarySearch(data, value); } - public static int binarySearchNullSafe(ArrayData data, Long value) { - return Arrays.binarySearch(data.toObjectArray(LongType$.MODULE$), value, longComp); + // Long nullable + public static int binarySearch(Long[] data, Long value) { + return 
Arrays.binarySearch(data, value, longComp); } - public static int binarySearch(ArrayData data, float value) { - return Arrays.binarySearch(data.toFloatArray(), value); + // float + // float non-nullable + public static int binarySearch(float[] data, float value) { + return Arrays.binarySearch(data, value); } - public static int binarySearchNullSafe(ArrayData data, Float value) { - return Arrays.binarySearch(data.toObjectArray(FloatType$.MODULE$), value, floatComp); + // Float nullable + public static int binarySearch(Float[] data, Float value) { + return Arrays.binarySearch(data, value, floatComp); } - public static int binarySearch(ArrayData data, double value) { - return Arrays.binarySearch(data.toDoubleArray(), value); + // double + // double non-nullable + public static int binarySearch(double[] data, double value) { + return Arrays.binarySearch(data, value); } - public static int binarySearchNullSafe(ArrayData data, Double value) { - return Arrays.binarySearch(data.toObjectArray(DoubleType$.MODULE$), value, doubleComp); + // Double nullable + public static int binarySearch(Double[] data, Double value) { + return Arrays.binarySearch(data, value, doubleComp); } - public static int binarySearch( - DataType elementType, Comparator comp, ArrayData data, Object value) { - Object[] array = data.toObjectArray(elementType); - return Arrays.binarySearch(array, value, comp); + // Object + public static int binarySearch(Object[] data, Object value, Comparator comp) { + return Arrays.binarySearch(data, value, comp); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ToJavaArrayUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ToJavaArrayUtils.java new file mode 100644 index 0000000000000..ead138590ca50 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ToJavaArrayUtils.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
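The refactor above swaps the ArrayData-based entry points for plain overloads over Java arrays, primitive and boxed, with nulls-first comparators for the boxed case. A small sketch of how the new overloads resolve when called directly; only the wrapper object is invented here, the class and method names are the ones introduced in this file:

import org.apache.spark.sql.catalyst.expressions.ArrayExpressionUtils

object BinarySearchSketch {
  def main(args: Array[String]): Unit = {
    // Primitive overloads delegate straight to java.util.Arrays.binarySearch.
    println(ArrayExpressionUtils.binarySearch(Array(1, 3, 5, 7), 5))        // 2
    println(ArrayExpressionUtils.binarySearch(Array(2.0, 4.0, 8.0), 3.0))   // -2, i.e. -(insertion point) - 1

    // Boxed overloads use the nulls-first comparators, so a leading null is fine.
    val boxed: Array[Integer] = Array(null, Integer.valueOf(2), Integer.valueOf(4))
    println(ArrayExpressionUtils.binarySearch(boxed, Integer.valueOf(4)))   // 2

    // The boolean primitive overload is hand-rolled because java.util.Arrays
    // has no binarySearch(boolean[], boolean).
    println(ArrayExpressionUtils.binarySearch(Array(false, true), true))    // 1
  }
}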
+ */ +package org.apache.spark.sql.catalyst.expressions; + +import scala.reflect.ClassTag$; + +import org.apache.spark.sql.catalyst.util.ArrayData; + +import static org.apache.spark.sql.types.DataTypes.BooleanType; +import static org.apache.spark.sql.types.DataTypes.ByteType; +import static org.apache.spark.sql.types.DataTypes.DoubleType; +import static org.apache.spark.sql.types.DataTypes.FloatType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static org.apache.spark.sql.types.DataTypes.ShortType; + +public class ToJavaArrayUtils { + + // boolean + // boolean non-nullable + public static boolean[] toBooleanArray(ArrayData arrayData) { + return arrayData.toBooleanArray(); + } + + // Boolean nullable + public static Boolean[] toBoxedBooleanArray(ArrayData arrayData) { + return (Boolean[]) arrayData.toArray(BooleanType, + ClassTag$.MODULE$.apply(java.lang.Boolean.class)); + } + + // byte + // byte non-nullable + public static byte[] toByteArray(ArrayData arrayData) { + return arrayData.toByteArray(); + } + + // Byte nullable + public static Byte[] toBoxedByteArray(ArrayData arrayData) { + return (Byte[]) arrayData.toArray(ByteType, ClassTag$.MODULE$.apply(java.lang.Byte.class)); + } + + // short + // short non-nullable + public static short[] toShortArray(ArrayData arrayData) { + return arrayData.toShortArray(); + } + + // Short nullable + public static Short[] toBoxedShortArray(ArrayData arrayData) { + return (Short[]) arrayData.toArray(ShortType, ClassTag$.MODULE$.apply(java.lang.Short.class)); + } + + // int + // int non-nullable + public static int[] toIntegerArray(ArrayData arrayData) { + return arrayData.toIntArray(); + } + + // Integer nullable + public static Integer[] toBoxedIntegerArray(ArrayData arrayData) { + return (Integer[]) arrayData.toArray(IntegerType, + ClassTag$.MODULE$.apply(java.lang.Integer.class)); + } + + // long + // long non-nullable + public static long[] toLongArray(ArrayData arrayData) { + return arrayData.toLongArray(); + } + + // Long nullable + public static Long[] toBoxedLongArray(ArrayData arrayData) { + return (Long[]) arrayData.toArray(LongType, ClassTag$.MODULE$.apply(java.lang.Long.class)); + } + + // float + // float non-nullable + public static float[] toFloatArray(ArrayData arrayData) { + return arrayData.toFloatArray(); + } + + // Float nullable + public static Float[] toBoxedFloatArray(ArrayData arrayData) { + return (Float[]) arrayData.toArray(FloatType, ClassTag$.MODULE$.apply(java.lang.Float.class)); + } + + // double + // double non-nullable + public static double[] toDoubleArray(ArrayData arrayData) { + return arrayData.toDoubleArray(); + } + + // Double nullable + public static Double[] toBoxedDoubleArray(ArrayData arrayData) { + return (Double[]) arrayData.toArray(DoubleType, + ClassTag$.MODULE$.apply(java.lang.Double.class)); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9d7ea6148757d..6b64f493f4052 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3929,7 +3929,7 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper { object EliminateEventTimeWatermark extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsWithPruning( _.containsPattern(EVENT_TIME_WATERMARK)) { - case EventTimeWatermark(_, _, child) if child.resolved && !child.isStreaming => child + case EventTimeWatermark(_, _, _, child) if child.resolved && !child.isStreaming => child case UpdateEventTimeWatermarkColumn(_, _, child) if child.resolved && !child.isStreaming => child } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3836eabe6bec6..4ad0b81b8f269 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -665,10 +665,15 @@ object FunctionRegistry { expression[WindowTime]("window_time"), expression[MakeDate]("make_date"), expression[MakeTimestamp]("make_timestamp"), + expression[TryMakeTimestamp]("try_make_timestamp"), expression[MonthName]("monthname"), // We keep the 2 expression builders below to have different function docs. expressionBuilder("make_timestamp_ntz", MakeTimestampNTZExpressionBuilder, setAlias = true), expressionBuilder("make_timestamp_ltz", MakeTimestampLTZExpressionBuilder, setAlias = true), + expressionBuilder( + "try_make_timestamp_ntz", TryMakeTimestampNTZExpressionBuilder, setAlias = true), + expressionBuilder( + "try_make_timestamp_ltz", TryMakeTimestampLTZExpressionBuilder, setAlias = true), expression[MakeInterval]("make_interval"), expression[MakeDTInterval]("make_dt_interval"), expression[MakeYMInterval]("make_ym_interval"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala index 31c4f068a83eb..cddc519d0887e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala @@ -36,7 +36,7 @@ object ResolveUpdateEventTimeWatermarkColumn extends Rule[LogicalPlan] { _.containsPattern(UPDATE_EVENT_TIME_WATERMARK_COLUMN), ruleId) { case u: UpdateEventTimeWatermarkColumn if u.delay.isEmpty && u.childrenResolved => val existingWatermarkDelay = u.child.collect { - case EventTimeWatermark(_, delay, _) => delay + case EventTimeWatermark(_, _, delay, _) => delay } if (existingWatermarkDelay.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToJavaArray.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToJavaArray.scala new file mode 100644 index 0000000000000..861d7ff4024a3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToJavaArray.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
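Because the registrations above also expose the try_* builders through SQL, they can be exercised with spark.sql as well. A hedged sketch mirroring the documented examples (local session assumed, object name illustrative); it complements the earlier DSL sketch by showing the _ltz and _ntz variants:

import org.apache.spark.sql.SparkSession

object TryMakeTimestampSqlSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("try-make-ts-sql").getOrCreate()

    spark.sql(
      """SELECT
        |  try_make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887, 'CET') AS ts_cet,
        |  try_make_timestamp_ntz(2019, 6, 30, 23, 59, 60)            AS leap_second_rolls_over,
        |  try_make_timestamp_ntz(2024, 13, 22, 15, 30, 0)            AS invalid_month_is_null
        |""".stripMargin).show(truncate = false)

    spark.stop()
  }
}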
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.lang.reflect.{Array => JArray} + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} +import org.apache.spark.sql.errors.QueryErrorsBase +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * This expression converts data of `ArrayData` to an array of java type. + * + * NOTE: When the data type of expression is `ArrayType`, and the expression is foldable, + * the `ConstantFolding` can do constant folding optimization automatically, + * (avoiding frequent calls to `ArrayData.to{XXX}Array()`). + */ +case class ToJavaArray(array: Expression) + extends UnaryExpression + with NullIntolerant + with RuntimeReplaceable + with QueryErrorsBase { + + override def checkInputDataTypes(): TypeCheckResult = array.dataType match { + case ArrayType(_, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(0), + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(array), + "inputType" -> toSQLType(array.dataType)) + ) + } + + override def foldable: Boolean = array.foldable + + override def child: Expression = array + override def prettyName: String = "to_java_array" + + private def resultArrayElementNullable: Boolean = + array.dataType.asInstanceOf[ArrayType].containsNull + private def isPrimitiveType: Boolean = CodeGenerator.isPrimitiveType(elementType) + private def canPerformFast: Boolean = isPrimitiveType && !resultArrayElementNullable + + @transient lazy val elementType: DataType = + array.dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementObjectType = ObjectType(classOf[DataType]) + @transient private lazy val elementCls: Class[_] = { + if (canPerformFast) { + CodeGenerator.javaClass(elementType) + } else if (isPrimitiveType) { + Utils.classForName(s"java.lang.${CodeGenerator.boxedType(elementType)}") + } else { + classOf[Object] + } + } + @transient private lazy val returnCls = JArray.newInstance(elementCls, 0).getClass + + override def dataType: DataType = ObjectType(returnCls) + + override def replacement: Expression = { + if (isPrimitiveType) { + val funcNamePrefix = if (resultArrayElementNullable) "toBoxed" else "to" + val funcName = s"$funcNamePrefix${CodeGenerator.boxedType(elementType)}Array" + StaticInvoke( + classOf[ToJavaArrayUtils], + dataType, + funcName, + Seq(array), + Seq(array.dataType)) + } else { + Invoke( + array, + "toObjectArray", + dataType, + Seq(Literal(elementType, elementObjectType)), + Seq(elementObjectType)) + } + } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(array = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0d563530bcbcf..10e64626d1a1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1603,12 +1603,7 @@ case class ArrayBinarySearch(array: Expression, value: Expression) @transient private lazy val elementType: DataType = array.dataType.asInstanceOf[ArrayType].elementType - @transient private lazy val resultArrayElementNullable: Boolean = - array.dataType.asInstanceOf[ArrayType].containsNull - @transient private lazy val isPrimitiveType: Boolean = CodeGenerator.isPrimitiveType(elementType) - @transient private lazy val canPerformFastBinarySearch: Boolean = isPrimitiveType && - elementType != BooleanType && !resultArrayElementNullable @transient private lazy val comp: Comparator[Any] = new Comparator[Any] with Serializable { private val ordering = array.dataType match { @@ -1619,39 +1614,28 @@ case class ArrayBinarySearch(array: Expression, value: Expression) override def compare(o1: Any, o2: Any): Int = (o1, o2) match { case (null, null) => 0 - case (null, _) => 1 - case (_, null) => -1 + case (null, _) => -1 + case (_, null) => 1 case _ => ordering.compare(o1, o2) } } - @transient private lazy val elementObjectType = ObjectType(classOf[DataType]) - @transient private lazy val comparatorObjectType = ObjectType(classOf[Comparator[Object]]) - override def replacement: Expression = - if (canPerformFastBinarySearch) { - StaticInvoke( - classOf[ArrayExpressionUtils], - IntegerType, - "binarySearch", - Seq(array, value), - inputTypes) - } else if (isPrimitiveType) { - StaticInvoke( - classOf[ArrayExpressionUtils], - IntegerType, - "binarySearchNullSafe", - Seq(array, value), - inputTypes) + @transient private lazy val comparatorObjectType = ObjectType(classOf[Comparator[Object]]) + + override def replacement: Expression = { + val toJavaArray = ToJavaArray(array) + val (arguments, inputTypes) = if (isPrimitiveType) { + (Seq(toJavaArray, value), Seq(toJavaArray.dataType, value.dataType)) } else { - StaticInvoke( - classOf[ArrayExpressionUtils], - IntegerType, - "binarySearch", - Seq(Literal(elementType, elementObjectType), - Literal(comp, comparatorObjectType), - array, - value), - elementObjectType +: comparatorObjectType +: inputTypes) + (Seq(toJavaArray, value, Literal(comp, comparatorObjectType)), + Seq(toJavaArray.dataType, value.dataType, comparatorObjectType)) + } + StaticInvoke( + classOf[ArrayExpressionUtils], + IntegerType, + "binarySearch", + arguments, + inputTypes) } override def prettyName: String = "array_binary_search" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index d0c4a53e491d8..dd20418496ca4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -2561,6 +2561,53 @@ object MakeTimestampNTZExpressionBuilder extends ExpressionBuilder { } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(year, month, day, hour, min, sec) - Try to create local date-time from year, month, day, hour, min, sec fields. 
The function returns NULL on invalid inputs.", + arguments = """ + Arguments: + * year - the year to represent, from 1 to 9999 + * month - the month-of-year to represent, from 1 (January) to 12 (December) + * day - the day-of-month to represent, from 1 to 31 + * hour - the hour-of-day to represent, from 0 to 23 + * min - the minute-of-hour to represent, from 0 to 59 + * sec - the second-of-minute and its micro-fraction to represent, from + 0 to 60. If the sec argument equals to 60, the seconds field is set + to 0 and 1 minute is added to the final timestamp. + """, + examples = """ + Examples: + > SELECT _FUNC_(2014, 12, 28, 6, 30, 45.887); + 2014-12-28 06:30:45.887 + > SELECT _FUNC_(2019, 6, 30, 23, 59, 60); + 2019-07-01 00:00:00 + > SELECT _FUNC_(null, 7, 22, 15, 30, 0); + NULL + > SELECT _FUNC_(2024, 13, 22, 15, 30, 0); + NULL + """, + group = "datetime_funcs", + since = "4.0.0") +// scalastyle:on line.size.limit +object TryMakeTimestampNTZExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + if (numArgs == 6) { + MakeTimestamp( + expressions(0), + expressions(1), + expressions(2), + expressions(3), + expressions(4), + expressions(5), + dataType = TimestampNTZType, + failOnError = false) + } else { + throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(6), numArgs) + } + } +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(year, month, day, hour, min, sec[, timezone]) - Create the current timestamp with local time zone from year, month, day, hour, min, sec and timezone fields. If the configuration `spark.sql.ansi.enabled` is false, the function returns NULL on invalid inputs. Otherwise, it will throw an error instead.", @@ -2609,6 +2656,57 @@ object MakeTimestampLTZExpressionBuilder extends ExpressionBuilder { } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(year, month, day, hour, min, sec[, timezone]) - Try to create the current timestamp with local time zone from year, month, day, hour, min, sec and timezone fields. The function returns NULL on invalid inputs.", + arguments = """ + Arguments: + * year - the year to represent, from 1 to 9999 + * month - the month-of-year to represent, from 1 (January) to 12 (December) + * day - the day-of-month to represent, from 1 to 31 + * hour - the hour-of-day to represent, from 0 to 23 + * min - the minute-of-hour to represent, from 0 to 59 + * sec - the second-of-minute and its micro-fraction to represent, from + 0 to 60. If the sec argument equals to 60, the seconds field is set + to 0 and 1 minute is added to the final timestamp. + * timezone - the time zone identifier. For example, CET, UTC and etc. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(2014, 12, 28, 6, 30, 45.887); + 2014-12-28 06:30:45.887 + > SELECT _FUNC_(2014, 12, 28, 6, 30, 45.887, 'CET'); + 2014-12-27 21:30:45.887 + > SELECT _FUNC_(2019, 6, 30, 23, 59, 60); + 2019-07-01 00:00:00 + > SELECT _FUNC_(null, 7, 22, 15, 30, 0); + NULL + > SELECT _FUNC_(2024, 13, 22, 15, 30, 0); + NULL + """, + group = "datetime_funcs", + since = "4.0.0") +// scalastyle:on line.size.limit +object TryMakeTimestampLTZExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + if (numArgs == 6 || numArgs == 7) { + MakeTimestamp( + expressions(0), + expressions(1), + expressions(2), + expressions(3), + expressions(4), + expressions(5), + expressions.drop(6).lastOption, + dataType = TimestampType, + failOnError = false) + } else { + throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(6), numArgs) + } + } +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(year, month, day, hour, min, sec[, timezone]) - Create timestamp from year, month, day, hour, min, sec and timezone fields. The result data type is consistent with the value of configuration `spark.sql.timestampType`. If the configuration `spark.sql.ansi.enabled` is false, the function returns NULL on invalid inputs. Otherwise, it will throw an error instead.", @@ -2710,7 +2808,7 @@ case class MakeTimestamp( // This case of sec = 60 and nanos = 0 is supported for compatibility with PostgreSQL LocalDateTime.of(year, month, day, hour, min, 0, 0).plusMinutes(1) } else { - throw QueryExecutionErrors.invalidFractionOfSecondError() + throw QueryExecutionErrors.invalidFractionOfSecondError(secAndMicros) } } else { LocalDateTime.of(year, month, day, hour, min, seconds, nanos) @@ -2781,7 +2879,7 @@ case class MakeTimestamp( ldt = java.time.LocalDateTime.of( $year, $month, $day, $hour, $min, 0, 0).plusMinutes(1); } else { - throw QueryExecutionErrors.invalidFractionOfSecondError(); + throw QueryExecutionErrors.invalidFractionOfSecondError($secAndNanos); } } else { ldt = java.time.LocalDateTime.of($year, $month, $day, $hour, $min, seconds, nanos); @@ -2812,6 +2910,89 @@ case class MakeTimestamp( } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(year, month, day, hour, min, sec[, timezone]) - Try to create a timestamp from year, month, day, hour, min, sec and timezone fields. The result data type is consistent with the value of configuration `spark.sql.timestampType`. The function returns NULL on invalid inputs.", + arguments = """ + Arguments: + * year - the year to represent, from 1 to 9999 + * month - the month-of-year to represent, from 1 (January) to 12 (December) + * day - the day-of-month to represent, from 1 to 31 + * hour - the hour-of-day to represent, from 0 to 23 + * min - the minute-of-hour to represent, from 0 to 59 + * sec - the second-of-minute and its micro-fraction to represent, from 0 to 60. + The value can be either an integer like 13 , or a fraction like 13.123. + If the sec argument equals to 60, the seconds field is set + to 0 and 1 minute is added to the final timestamp. + * timezone - the time zone identifier. For example, CET, UTC and etc. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(2014, 12, 28, 6, 30, 45.887); + 2014-12-28 06:30:45.887 + > SELECT _FUNC_(2014, 12, 28, 6, 30, 45.887, 'CET'); + 2014-12-27 21:30:45.887 + > SELECT _FUNC_(2019, 6, 30, 23, 59, 60); + 2019-07-01 00:00:00 + > SELECT _FUNC_(2019, 6, 30, 23, 59, 1); + 2019-06-30 23:59:01 + > SELECT _FUNC_(null, 7, 22, 15, 30, 0); + NULL + > SELECT _FUNC_(2024, 13, 22, 15, 30, 0); + NULL + """, + group = "datetime_funcs", + since = "4.0.0") +// scalastyle:on line.size.limit +case class TryMakeTimestamp( + year: Expression, + month: Expression, + day: Expression, + hour: Expression, + min: Expression, + sec: Expression, + timezone: Option[Expression], + timeZoneId: Option[String], + replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { + + private def this( + year: Expression, + month: Expression, + day: Expression, + hour: Expression, + min: Expression, + sec: Expression, + timezone: Option[Expression]) = this(year, month, day, hour, min, sec, timezone, None, + MakeTimestamp(year, month, day, hour, min, sec, timezone, None, failOnError = false)) + + def this( + year: Expression, + month: Expression, + day: Expression, + hour: Expression, + min: Expression, + sec: Expression, + timezone: Expression) = this(year, month, day, hour, min, sec, Some(timezone)) + + def this( + year: Expression, + month: Expression, + day: Expression, + hour: Expression, + min: Expression, + sec: Expression) = this(year, month, day, hour, min, sec, None) + + override def prettyName: String = "try_make_timestamp" + + override def parameters: Seq[Expression] = Seq( + year, month, day, hour, min, sec) + + override protected def withNewChildInternal(newChild: Expression): TryMakeTimestamp = { + copy(replacement = newChild) + } +} + object DatePart { def parseExtractField( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 5f13d397d1bf9..f7509f124ab50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -166,7 +166,9 @@ case class CheckOverflowInSum( val value = child.eval(input) if (value == null) { if (nullOnOverflow) null - else throw QueryExecutionErrors.overflowInSumOfDecimalError(context) + else { + throw QueryExecutionErrors.overflowInSumOfDecimalError(context, suggestedFunc = "try_sum") + } } else { value.asInstanceOf[Decimal].toPrecision( dataType.precision, @@ -183,7 +185,7 @@ case class CheckOverflowInSum( val nullHandling = if (nullOnOverflow) { "" } else { - s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);" + s"""throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode, "try_sum");""" } // scalastyle:off line.size.limit val code = code""" @@ -270,7 +272,8 @@ case class DecimalDivideWithOverflowCheck( if (nullOnOverflow) { null } else { - throw QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull()) + throw QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull(), + suggestedFunc = "try_avg") } } else { val value2 = right.eval(input) @@ -286,7 +289,7 @@ case class DecimalDivideWithOverflowCheck( val nullHandling = if (nullOnOverflow) { "" } else { - s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);" + s"""throw 
QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode, "try_avg");""" } val eval1 = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9af63a754124c..7c198f05cf496 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -86,11 +86,29 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput // Returns true if we can trust all values of the given DataType can be serialized. private def trustedSerializable(dt: DataType): Boolean = { - // Right now we conservatively block all ObjectType (Java objects) regardless of - // serializability, because the type-level info with java.io.Serializable and - // java.io.Externalizable marker interfaces are not strong guarantees. + // Right now we conservatively block all ObjectType (Java objects) except for + // it's `cls` equal to `Array[JavaBoxedPrimitive]` & `JavaBoxedPrimitive` + // regardless of serializability, because the type-level info with java.io.Serializable + // and java.io.Externalizable marker interfaces are not strong guarantees. // This restriction can be relaxed in the future to expose more optimizations. - !dt.existsRecursively(_.isInstanceOf[ObjectType]) + !dt.existsRecursively { + case ObjectType(cls) if cls == classOf[Array[java.lang.Boolean]] => false + case ObjectType(cls) if cls == classOf[Array[java.lang.Byte]] => false + case ObjectType(cls) if cls == classOf[Array[java.lang.Short]] => false + case ObjectType(cls) if cls == classOf[Array[java.lang.Integer]] => false + case ObjectType(cls) if cls == classOf[Array[java.lang.Long]] => false + case ObjectType(cls) if cls == classOf[Array[java.lang.Float]] => false + case ObjectType(cls) if cls == classOf[Array[java.lang.Double]] => false + case ObjectType(cls) if cls == classOf[java.lang.Boolean] => false + case ObjectType(cls) if cls == classOf[java.lang.Byte] => false + case ObjectType(cls) if cls == classOf[java.lang.Short] => false + case ObjectType(cls) if cls == classOf[java.lang.Integer] => false + case ObjectType(cls) if cls == classOf[java.lang.Long] => false + case ObjectType(cls) if cls == classOf[java.lang.Float] => false + case ObjectType(cls) if cls == classOf[java.lang.Double] => false + case ObjectType(_) => true + case _ => false + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 4367920f939e4..2452da5d69682 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -82,10 +82,12 @@ case class ConcatWs(children: Seq[Expression]) /** The 1st child (separator) is str, and rest are either str or array of str. 
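To see the overflowInSumOfDecimalError change from the user side: under ANSI mode an overflowing decimal SUM raises ARITHMETIC_OVERFLOW, and with the suggestedFunc parameter above the message should now point at try_sum (try_avg for the divide path). A loose sketch; the exact message text depends on the error template, so it is only printed here, not asserted:

import org.apache.spark.sql.SparkSession

object DecimalSumOverflowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sum-overflow").getOrCreate()
    spark.conf.set("spark.sql.ansi.enabled", "true")

    val wide = "SELECT %s(col) FROM VALUES (cast(9e37 AS decimal(38, 0))), (cast(9e37 AS decimal(38, 0))) AS t(col)"

    // sum over decimal(38, 0) overflows here and, under ANSI, raises ARITHMETIC_OVERFLOW.
    try {
      spark.sql(wide.format("sum")).collect()
    } catch {
      case e: Exception => println(e.getMessage)
    }

    // try_sum tolerates the overflow and returns NULL instead.
    spark.sql(wide.format("try_sum")).show()

    spark.stop()
  }
}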
*/ override def inputTypes: Seq[AbstractDataType] = { val arrayOrStr = - TypeCollection(AbstractArrayType(StringTypeWithCollation), - StringTypeWithCollation + TypeCollection(AbstractArrayType( + StringTypeWithCollation(supportsTrimCollation = true)), + StringTypeWithCollation(supportsTrimCollation = true) ) - StringTypeWithCollation +: Seq.fill(children.size - 1)(arrayOrStr) + StringTypeWithCollation(supportsTrimCollation = true) +: + Seq.fill(children.size - 1)(arrayOrStr) } override def dataType: DataType = children.head.dataType @@ -436,7 +438,8 @@ trait String2StringExpression extends ImplicitCastInputTypes { def convert(v: UTF8String): UTF8String override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) protected override def nullSafeEval(input: Any): Any = convert(input.asInstanceOf[UTF8String]) @@ -518,7 +521,8 @@ abstract class StringPredicate extends BinaryExpression def compare(l: UTF8String, r: UTF8String): Boolean override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, StringTypeWithCollation) + Seq(StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true)) protected override def nullSafeEval(input1: Any, input2: Any): Any = compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String]) @@ -613,7 +617,9 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate CollationSupport.Contains.genCode(c1, c2, collationId)) } override def inputTypes : Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + Seq(StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true) + ) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight) } @@ -657,7 +663,11 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica } override def inputTypes : Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true) + ) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight) @@ -702,7 +712,11 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate } override def inputTypes : Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true) + ) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight) @@ -735,7 +749,8 @@ case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with Implic override lazy val replacement: Expression = Invoke(input, "isValid", BooleanType) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = 
true)) override def nodeName: String = "is_valid_utf8" @@ -782,7 +797,8 @@ case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with Impl override lazy val replacement: Expression = Invoke(input, "makeValid", input.dataType) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def nodeName: String = "make_valid_utf8" @@ -827,7 +843,8 @@ case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with Impli Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def nodeName: String = "validate_utf8" @@ -876,7 +893,8 @@ case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with Im Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def nodeName: String = "try_validate_utf8" @@ -932,7 +950,11 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp override def dataType: DataType = srcExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true) + ) override def first: Expression = srcExpr override def second: Expression = searchExpr override def third: Expression = replaceExpr @@ -1011,8 +1033,14 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: override def dataType: DataType = input.dataType override def inputTypes: Seq[AbstractDataType] = Seq( - TypeCollection(StringTypeWithCollation, BinaryType), - TypeCollection(StringTypeWithCollation, BinaryType), IntegerType, IntegerType) + TypeCollection( + StringTypeWithCollation(supportsTrimCollation = true), BinaryType + ), + TypeCollection( + StringTypeWithCollation(supportsTrimCollation = true), BinaryType + ), + IntegerType, + IntegerType) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() @@ -1180,7 +1208,10 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def dataType: DataType = srcExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true)) override def first: Expression = srcExpr override def second: Expression = matchingExpr override def third: Expression = replaceExpr @@ -1216,7 +1247,10 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, StringTypeWithCollation) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true) + ) override protected def 
nullSafeEval(word: Any, set: Any): Any = { CollationSupport.FindInSet. @@ -1245,7 +1279,7 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = srcStr +: trimStr.toSeq override def dataType: DataType = srcStr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq.fill(children.size)(StringTypeWithCollation) + Seq.fill(children.size)(StringTypeWithCollation(supportsTrimCollation = true)) final lazy val collationId: Int = srcStr.dataType.asInstanceOf[StringType].collationId @@ -1409,7 +1443,10 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) CollationSupport.StringTrim.exec(srcString, trimString, collationId) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true) + ) override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy( @@ -1519,7 +1556,10 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None CollationSupport.StringTrimLeft.exec(srcString, trimString, collationId) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true) + ) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimLeft = @@ -1582,7 +1622,10 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non CollationSupport.StringTrimRight.exec(srcString, trimString, collationId) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true) + ) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimRight = @@ -1618,7 +1661,10 @@ case class StringInstr(str: Expression, substr: Expression) override def right: Expression = substr override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true) + ) override def nullSafeEval(string: Any, sub: Any): Any = { CollationSupport.StringInstr. 
@@ -1666,7 +1712,11 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: override def dataType: DataType = strExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true), + IntegerType + ) override def first: Expression = strExpr override def second: Expression = delimExpr override def third: Expression = countExpr @@ -1724,7 +1774,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true), + IntegerType + ) override def eval(input: InternalRow): Any = { val s = start.eval(input) @@ -1850,7 +1904,11 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, IntegerType, StringTypeWithCollation) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + IntegerType, + StringTypeWithCollation(supportsTrimCollation = true) + ) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1930,7 +1988,11 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, IntegerType, StringTypeWithCollation) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + IntegerType, + StringTypeWithCollation(supportsTrimCollation = true) + ) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1975,7 +2037,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC override def dataType: DataType = children(0).dataType override def inputTypes: Seq[AbstractDataType] = - StringTypeWithCollation :: List.fill(children.size - 1)(AnyDataType) + StringTypeWithCollation(supportsTrimCollation = true) :: + List.fill(children.size - 1)(AnyDataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { @@ -2086,7 +2149,8 @@ case class InitCap(child: Expression) // Flag to indicate whether to use ICU instead of JVM case mappings for UTF8_BINARY collation. 
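The recurring StringTypeWithCollation(supportsTrimCollation = true) edits in this file let these string expressions pass type checking on trim-collated inputs instead of rejecting them. A loosely held sketch: the UTF8_LCASE_RTRIM collation name and the object name are assumptions about this branch, not taken from the diff:

import org.apache.spark.sql.SparkSession

object TrimCollationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("trim-collation").getOrCreate()

    // Both expressions below (Contains, InitCap) are among those updated above;
    // the point is that a trim-collated input now analyzes and runs rather than
    // failing the type check. UTF8_LCASE_RTRIM is an assumed trim-collation name.
    spark.sql(
      """SELECT
        |  contains('Spark SQL' COLLATE UTF8_LCASE_RTRIM, 'sql') AS has_sql,
        |  initcap('spark sql' COLLATE UTF8_LCASE_RTRIM)         AS title_cased
        |""".stripMargin).show()

    spark.stop()
  }
}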
private final lazy val useICU = SQLConf.get.getConf(SQLConf.ICU_CASE_MAPPINGS_ENABLED) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def dataType: DataType = child.dataType override def nullSafeEval(string: Any): Any = { @@ -2119,7 +2183,10 @@ case class StringRepeat(str: Expression, times: Expression) override def right: Expression = times override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, IntegerType) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + IntegerType + ) override def nullSafeEval(string: Any, n: Any): Any = { string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer]) @@ -2212,7 +2279,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeWithCollation, BinaryType), IntegerType, IntegerType) + Seq( + TypeCollection(StringTypeWithCollation(supportsTrimCollation = true), BinaryType), + IntegerType, + IntegerType + ) override def first: Expression = str override def second: Expression = pos @@ -2271,7 +2342,10 @@ case class Right(str: Expression, len: Expression) extends RuntimeReplaceable ) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, IntegerType) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + IntegerType + ) override def left: Expression = str override def right: Expression = len override protected def withNewChildrenInternal( @@ -2302,7 +2376,12 @@ case class Left(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = Substring(str, Literal(1), len) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(StringTypeWithCollation, BinaryType), IntegerType) + Seq( + TypeCollection( + StringTypeWithCollation(supportsTrimCollation = true), + BinaryType) + , IntegerType + ) } override def left: Expression = str @@ -2338,7 +2417,12 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeWithCollation, BinaryType)) + Seq( + TypeCollection( + StringTypeWithCollation(supportsTrimCollation = true), + BinaryType + ) + ) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numChars @@ -2373,8 +2457,12 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeWithCollation, BinaryType)) - + Seq( + TypeCollection( + StringTypeWithCollation(supportsTrimCollation = true), + BinaryType + ) + ) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numBytes * 8 case BinaryType => value.asInstanceOf[Array[Byte]].length * 8 @@ -2412,7 +2500,12 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - 
Seq(TypeCollection(StringTypeWithCollation, BinaryType)) + Seq( + TypeCollection( + StringTypeWithCollation(supportsTrimCollation = true), + BinaryType + ) + ) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numBytes @@ -2473,8 +2566,16 @@ case class Levenshtein( override def inputTypes: Seq[AbstractDataType] = threshold match { case Some(_) => - Seq(StringTypeWithCollation, StringTypeWithCollation, IntegerType) - case _ => Seq(StringTypeWithCollation, StringTypeWithCollation) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), + IntegerType + ) + case _ => + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true) + ) } override def children: Seq[Expression] = threshold match { @@ -2599,7 +2700,8 @@ case class SoundEx(child: Expression) override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() @@ -2629,7 +2731,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) protected override def nullSafeEval(string: Any): Any = { // only pick the first character to reduce the `toString` cost @@ -2774,7 +2877,8 @@ case class UnBase64(child: Expression, failOnError: Boolean = false) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BinaryType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) def this(expr: Expression) = this(expr, false) @@ -2953,8 +3057,10 @@ case class StringDecode( this(bin, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction) override val dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = - Seq(BinaryType, StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = Seq( + BinaryType, + StringTypeWithCollation(supportsTrimCollation = true) + ) override def prettyName: String = "decode" override def toString: String = s"$prettyName($bin, $charset)" @@ -2963,7 +3069,13 @@ case class StringDecode( SQLConf.get.defaultStringType, "decode", Seq(bin, charset, Literal(legacyCharsets), Literal(legacyErrorAction)), - Seq(BinaryType, StringTypeWithCollation, BooleanType, BooleanType)) + Seq( + BinaryType, + StringTypeWithCollation(supportsTrimCollation = true), + BooleanType, + BooleanType + ) + ) override def children: Seq[Expression] = Seq(bin, charset) override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = @@ -3020,7 +3132,10 @@ case class Encode( override def dataType: DataType = BinaryType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, StringTypeWithCollation) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + 
StringTypeWithCollation(supportsTrimCollation = true) + ) override lazy val replacement: Expression = StaticInvoke( classOf[Encode], @@ -3030,8 +3145,8 @@ case class Encode( str, charset, Literal(legacyCharsets, BooleanType), Literal(legacyErrorAction, BooleanType) ), Seq( - StringTypeWithCollation, - StringTypeWithCollation, + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), BooleanType, BooleanType)) @@ -3118,7 +3233,7 @@ case class ToBinary( override def children: Seq[Expression] = expr +: format.toSeq override def inputTypes: Seq[AbstractDataType] = - children.map(_ => StringTypeWithCollation) + children.map(_ => StringTypeWithCollation(supportsTrimCollation = true)) override def checkInputDataTypes(): TypeCheckResult = { def isValidFormat: Boolean = { @@ -3205,7 +3320,12 @@ case class FormatNumber(x: Expression, d: Expression) override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(NumericType, TypeCollection(IntegerType, StringTypeWithCollation)) + Seq( + NumericType, + TypeCollection(IntegerType, + StringTypeWithCollation(supportsTrimCollation = true) + ) + ) private val defaultFormat = "#,###,###,###,###,###,##0" @@ -3412,7 +3532,9 @@ case class Sentences( override def inputTypes: Seq[AbstractDataType] = Seq( StringTypeWithCollation, - StringTypeWithCollation, StringTypeWithCollation) + StringTypeWithCollation, + StringTypeWithCollation + ) override def first: Expression = str override def second: Expression = language override def third: Expression = country @@ -3499,7 +3621,11 @@ case class SplitPart ( false) override def nodeName: String = "split_part" override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) + Seq( + StringTypeNonCSAICollation(supportsTrimCollation = true), + StringTypeNonCSAICollation(supportsTrimCollation = true), + IntegerType + ) def children: Seq[Expression] = Seq(str, delimiter, partNum) protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = { copy(str = newChildren.apply(0), delimiter = newChildren.apply(1), @@ -3560,7 +3686,8 @@ case class Luhncheck(input: Expression) extends RuntimeReplaceable with Implicit "isLuhnNumber", Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def prettyName: String = "luhn_check" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala index 8cfc939755ef7..0d7f2b1d0f3f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical +import java.util.UUID import java.util.concurrent.TimeUnit import org.apache.spark.sql.catalyst.expressions.Attribute @@ -69,6 +70,7 @@ object EventTimeWatermark { * Used to mark a user specified column as holding the event time for a row. 
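For context on the nodeId: UUID field being added to EventTimeWatermark just below: the node is what Dataset.withWatermark produces, and the UUID gives each watermark operator a distinct identity through the analyzer rules updated earlier in this diff. A brief user-level sketch using the built-in rate source (object name illustrative):

import org.apache.spark.sql.SparkSession

object WatermarkPlanSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("watermark-plan").getOrCreate()

    // A streaming source keeps the EventTimeWatermark node in the plan;
    // EliminateEventTimeWatermark (updated above to the four-field pattern)
    // only drops it when the child is a resolved batch relation.
    val stream = spark.readStream.format("rate").option("rowsPerSecond", "1").load()
    stream.withWatermark("timestamp", "10 seconds").explain(extended = true)

    spark.stop()
  }
}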
*/ case class EventTimeWatermark( + nodeId: UUID, eventTime: Attribute, delay: CalendarInterval, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index efdc06d4cbd8a..2cc223ba69fa7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -257,11 +257,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE summary = "") } - def invalidFractionOfSecondError(): DateTimeException = { + def invalidFractionOfSecondError(secAndMicros: Decimal): DateTimeException = { new SparkDateTimeException( errorClass = "INVALID_FRACTION_OF_SECOND", messageParameters = Map( - "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key) + "secAndMicros" -> s"$secAndMicros" ), context = Array.empty, summary = "") @@ -295,8 +295,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key))) } - def overflowInSumOfDecimalError(context: QueryContext): ArithmeticException = { - arithmeticOverflowError("Overflow in sum of decimals", context = context) + def overflowInSumOfDecimalError( + context: QueryContext, + suggestedFunc: String): ArithmeticException = { + arithmeticOverflowError("Overflow in sum of decimals", suggestedFunc = suggestedFunc, + context = context) } def overflowInIntegralDivideError(context: QueryContext): ArithmeticException = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 8409f454bfb88..939801e3f07af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import java.util.TimeZone +import java.util.{TimeZone, UUID} import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag @@ -1763,7 +1763,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-46064 Basic functionality of elimination for watermark node in batch query") { - val dfWithEventTimeWatermark = EventTimeWatermark($"ts", + val dfWithEventTimeWatermark = EventTimeWatermark(UUID.randomUUID(), $"ts", IntervalUtils.fromIntervalString("10 seconds"), batchRelationWithTs) val analyzed = getAnalyzer.executeAndCheck(dfWithEventTimeWatermark, new QueryPlanningTracker) @@ -1776,7 +1776,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { "EventTimeWatermark changes the isStreaming flag during resolution") { // UnresolvedRelation which is batch initially and will be resolved as streaming val dfWithTempView = UnresolvedRelation(TableIdentifier("streamingTable")) - val dfWithEventTimeWatermark = EventTimeWatermark($"ts", + val dfWithEventTimeWatermark = EventTimeWatermark(UUID.randomUUID(), $"ts", IntervalUtils.fromIntervalString("10 seconds"), dfWithTempView) val analyzed = getAnalyzer.executeAndCheck(dfWithEventTimeWatermark, new QueryPlanningTracker) diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala 
similarity index 100% rename from sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala index 2634008cdd9e4..a30f604550a38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala @@ -95,7 +95,10 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { (Seq("a"), Seq("A"), true, "UTF8_LCASE"), (Seq("a", "B"), Seq("A", "b"), true, "UTF8_LCASE"), (Seq("a"), Seq("A"), false, "UNICODE"), - (Seq("a", "B"), Seq("A", "b"), true, "UNICODE_CI") + (Seq("a", "B"), Seq("A", "b"), true, "UNICODE_CI"), + (Seq("c"), Seq("C"), false, "SR"), + (Seq("c"), Seq("C"), true, "SR_CI"), + (Seq("a", "c"), Seq("b", "C"), true, "SR_CI_AI") ) for ((inLeft, inRight, out, collName) <- overlap) { val left = arrayLiteral(inLeft, collName) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 55148978fa005..1907ec7c23aa6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -137,6 +137,8 @@ class CollectionExpressionsSuite test("ArrayBinarySearch") { // primitive type: boolean、byte、short、int、long、float、double + // boolean + // boolean foldable val a0_0 = Literal.create(Seq(false, true), ArrayType(BooleanType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a0_0, Literal(true)), 1) @@ -144,7 +146,23 @@ class CollectionExpressionsSuite checkEvaluation(ArrayBinarySearch(a0_1, Literal(false)), 1) val a0_2 = Literal.create(Seq(null, false, true), ArrayType(BooleanType)) checkEvaluation(ArrayBinarySearch(a0_2, Literal(null, BooleanType)), null) - + val a0_3 = CreateArray(Seq(Literal(false), Literal(true))) + checkEvaluation(ArrayBinarySearch(a0_3, Literal(true)), 1) + val a0_4 = CreateArray(Seq(Literal(null, BooleanType), Literal(false), Literal(true))) + checkEvaluation(ArrayBinarySearch(a0_4, Literal(false)), 1) + val a0_5 = CreateArray(Seq(Literal(null, BooleanType), Literal(false), Literal(true))) + checkEvaluation(ArrayBinarySearch(a0_5, Literal(null, BooleanType)), null) + // boolean non-foldable + val a0_6 = NonFoldableLiteral.create(Seq(false, true), + ArrayType(BooleanType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a0_6, Literal(true)), 1) + val a0_7 = NonFoldableLiteral.create(Seq(null, false, true), ArrayType(BooleanType)) + checkEvaluation(ArrayBinarySearch(a0_7, Literal(false)), 1) + val a0_8 = NonFoldableLiteral.create(Seq(null, false, true), ArrayType(BooleanType)) + checkEvaluation(ArrayBinarySearch(a0_8, Literal(null, BooleanType)), null) + + // byte + // byte foldable val a1_0 = Literal.create(Seq(1.toByte, 2.toByte, 3.toByte), ArrayType(ByteType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a1_0, Literal(3.toByte)), 2) @@ -155,18 +173,70 @@ class CollectionExpressionsSuite 
val a1_3 = Literal.create(Seq(1.toByte, 3.toByte, 4.toByte), ArrayType(ByteType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a1_3, Literal(2.toByte, ByteType)), -2) + val a1_4 = CreateArray(Seq(Literal(1.toByte), Literal(2.toByte), Literal(3.toByte))) + checkEvaluation(ArrayBinarySearch(a1_4, Literal(3.toByte)), 2) + val a1_5 = CreateArray(Seq(Literal(null, ByteType), + Literal(1.toByte), Literal(2.toByte), Literal(3.toByte))) + checkEvaluation(ArrayBinarySearch(a1_5, Literal(1.toByte)), 1) + val a1_6 = CreateArray(Seq(Literal(null, ByteType), + Literal(1.toByte), Literal(2.toByte), Literal(3.toByte))) + checkEvaluation(ArrayBinarySearch(a1_6, Literal(null, ByteType)), null) + val a1_7 = CreateArray(Seq(Literal(1.toByte), Literal(3.toByte), Literal(4.toByte))) + checkEvaluation(ArrayBinarySearch(a1_7, Literal(2.toByte, ByteType)), -2) + // byte non-foldable + val a1_8 = NonFoldableLiteral.create(Seq(1.toByte, 2.toByte, 3.toByte), + ArrayType(ByteType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a1_8, Literal(3.toByte)), 2) + val a1_9 = NonFoldableLiteral.create(Seq(null, 1.toByte, 2.toByte, 3.toByte), + ArrayType(ByteType)) + checkEvaluation(ArrayBinarySearch(a1_9, Literal(1.toByte)), 1) + val a1_10 = NonFoldableLiteral.create(Seq(null, 1.toByte, 2.toByte, 3.toByte), + ArrayType(ByteType)) + checkEvaluation(ArrayBinarySearch(a1_10, Literal(null, ByteType)), null) + val a1_11 = NonFoldableLiteral.create(Seq(1.toByte, 3.toByte, 4.toByte), + ArrayType(ByteType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a1_11, Literal(2.toByte, ByteType)), -2) + // short + // short foldable val a2_0 = Literal.create(Seq(1.toShort, 2.toShort, 3.toShort), ArrayType(ShortType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a2_0, Literal(1.toShort)), 0) - val a2_1 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort), ArrayType(ShortType)) + val a2_1 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort), + ArrayType(ShortType)) checkEvaluation(ArrayBinarySearch(a2_1, Literal(2.toShort)), 2) - val a2_2 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort), ArrayType(ShortType)) + val a2_2 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort), + ArrayType(ShortType)) checkEvaluation(ArrayBinarySearch(a2_2, Literal(null, ShortType)), null) val a2_3 = Literal.create(Seq(1.toShort, 3.toShort, 4.toShort), ArrayType(ShortType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a2_3, Literal(2.toShort, ShortType)), -2) + val a2_4 = CreateArray(Seq(Literal(1.toShort), Literal(2.toShort), Literal(3.toShort))) + checkEvaluation(ArrayBinarySearch(a2_4, Literal(1.toShort)), 0) + val a2_5 = CreateArray(Seq(Literal(null, ShortType), + Literal(1.toShort), Literal(2.toShort), Literal(3.toShort))) + checkEvaluation(ArrayBinarySearch(a2_5, Literal(2.toShort)), 2) + val a2_6 = CreateArray(Seq(Literal(null, ShortType), + Literal(1.toShort), Literal(2.toShort), Literal(3.toShort))) + checkEvaluation(ArrayBinarySearch(a2_6, Literal(null, ShortType)), null) + val a2_7 = CreateArray(Seq(Literal(1.toShort), Literal(3.toShort), Literal(4.toShort))) + checkEvaluation(ArrayBinarySearch(a2_7, Literal(2.toShort, ShortType)), -2) + // short non-foldable + val a2_8 = NonFoldableLiteral.create(Seq(1.toShort, 2.toShort, 3.toShort), + ArrayType(ShortType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a2_8, Literal(1.toShort)), 0) + val a2_9 = NonFoldableLiteral.create(Seq(null, 1.toShort, 2.toShort, 3.toShort), + ArrayType(ShortType)) + 
checkEvaluation(ArrayBinarySearch(a2_9, Literal(2.toShort)), 2) + val a2_10 = NonFoldableLiteral.create(Seq(null, 1.toShort, 2.toShort, 3.toShort), + ArrayType(ShortType)) + checkEvaluation(ArrayBinarySearch(a2_10, Literal(null, ShortType)), null) + val a2_11 = NonFoldableLiteral.create(Seq(1.toShort, 3.toShort, 4.toShort), + ArrayType(ShortType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a2_11, Literal(2.toShort, ShortType)), -2) + // int + // int foldable val a3_0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a3_0, Literal(2)), 1) val a3_1 = Literal.create(Seq(null, 1, 2, 3), ArrayType(IntegerType)) @@ -175,7 +245,28 @@ class CollectionExpressionsSuite checkEvaluation(ArrayBinarySearch(a3_2, Literal(null, IntegerType)), null) val a3_3 = Literal.create(Seq(1, 3, 4), ArrayType(IntegerType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a3_3, Literal(2, IntegerType)), -2) - + val a3_4 = CreateArray(Seq(Literal(1), Literal(2), Literal(3))) + checkEvaluation(ArrayBinarySearch(a3_4, Literal(2)), 1) + val a3_5 = CreateArray(Seq(Literal(null, IntegerType), Literal(1), Literal(2), Literal(3))) + checkEvaluation(ArrayBinarySearch(a3_5, Literal(2)), 2) + val a3_6 = CreateArray(Seq(Literal(null, IntegerType), Literal(1), Literal(2), Literal(3))) + checkEvaluation(ArrayBinarySearch(a3_6, Literal(null, IntegerType)), null) + val a3_7 = CreateArray(Seq(Literal(1), Literal(3), Literal(4))) + checkEvaluation(ArrayBinarySearch(a3_7, Literal(2, IntegerType)), -2) + // int non-foldable + val a3_8 = NonFoldableLiteral.create(Seq(1, 2, 3), + ArrayType(IntegerType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a3_8, Literal(2)), 1) + val a3_9 = NonFoldableLiteral.create(Seq(null, 1, 2, 3), ArrayType(IntegerType)) + checkEvaluation(ArrayBinarySearch(a3_9, Literal(2)), 2) + val a3_10 = NonFoldableLiteral.create(Seq(null, 1, 2, 3), ArrayType(IntegerType)) + checkEvaluation(ArrayBinarySearch(a3_10, Literal(null, IntegerType)), null) + val a3_11 = NonFoldableLiteral.create(Seq(1, 3, 4), + ArrayType(IntegerType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a3_11, Literal(2, IntegerType)), -2) + + // long + // long foldable val a4_0 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a4_0, Literal(2L)), 1) val a4_1 = Literal.create(Seq(null, 1L, 2L, 3L), ArrayType(LongType)) @@ -184,7 +275,30 @@ class CollectionExpressionsSuite checkEvaluation(ArrayBinarySearch(a4_2, Literal(null, LongType)), null) val a4_3 = Literal.create(Seq(1L, 3L, 4L), ArrayType(LongType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a4_3, Literal(2L, LongType)), -2) - + val a4_4 = CreateArray(Seq(Literal(1L), Literal(2L), Literal(3L))) + checkEvaluation(ArrayBinarySearch(a4_4, Literal(2L)), 1) + val a4_5 = CreateArray(Seq(Literal(null, LongType), + Literal(1L), Literal(2L), Literal(3L))) + checkEvaluation(ArrayBinarySearch(a4_5, Literal(2L)), 2) + val a4_6 = CreateArray(Seq(Literal(null, LongType), + Literal(1L), Literal(2L), Literal(3L))) + checkEvaluation(ArrayBinarySearch(a4_6, Literal(null, LongType)), null) + val a4_7 = CreateArray(Seq(Literal(1L), Literal(3L), Literal(4L))) + checkEvaluation(ArrayBinarySearch(a4_7, Literal(2L, LongType)), -2) + // long non-foldable + val a4_8 = NonFoldableLiteral.create(Seq(1L, 2L, 3L), + ArrayType(LongType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a4_8, Literal(2L)), 1) + val a4_9 = 
NonFoldableLiteral.create(Seq(null, 1L, 2L, 3L), ArrayType(LongType)) + checkEvaluation(ArrayBinarySearch(a4_9, Literal(2L)), 2) + val a4_10 = NonFoldableLiteral.create(Seq(null, 1L, 2L, 3L), ArrayType(LongType)) + checkEvaluation(ArrayBinarySearch(a4_10, Literal(null, LongType)), null) + val a4_11 = NonFoldableLiteral.create(Seq(1L, 3L, 4L), + ArrayType(LongType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a4_11, Literal(2L, LongType)), -2) + + // float + // float foldable val a5_0 = Literal.create(Seq(1.0F, 2.0F, 3.0F), ArrayType(FloatType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a5_0, Literal(3.0F)), 2) val a5_1 = Literal.create(Seq(null, 1.0F, 2.0F, 3.0F), ArrayType(FloatType)) @@ -193,7 +307,30 @@ class CollectionExpressionsSuite checkEvaluation(ArrayBinarySearch(a5_2, Literal(null, FloatType)), null) val a5_3 = Literal.create(Seq(1.0F, 2.0F, 3.0F), ArrayType(FloatType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a5_3, Literal(1.1F, FloatType)), -2) - + val a5_4 = CreateArray(Seq(Literal(1.0F), Literal(2.0F), Literal(3.0F))) + checkEvaluation(ArrayBinarySearch(a5_4, Literal(3.0F)), 2) + val a5_5 = CreateArray(Seq(Literal(null, FloatType), + Literal(1.0F), Literal(2.0F), Literal(3.0F))) + checkEvaluation(ArrayBinarySearch(a5_5, Literal(1.0F)), 1) + val a5_6 = CreateArray(Seq(Literal(null, FloatType), + Literal(1.0F), Literal(2.0F), Literal(3.0F))) + checkEvaluation(ArrayBinarySearch(a5_6, Literal(null, FloatType)), null) + val a5_7 = CreateArray(Seq(Literal(1.0F), Literal(2.0F), Literal(3.0F))) + checkEvaluation(ArrayBinarySearch(a5_7, Literal(1.1F, FloatType)), -2) + // float non-foldable + val a5_8 = NonFoldableLiteral.create(Seq(1.0F, 2.0F, 3.0F), + ArrayType(FloatType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a5_8, Literal(3.0F)), 2) + val a5_9 = NonFoldableLiteral.create(Seq(null, 1.0F, 2.0F, 3.0F), ArrayType(FloatType)) + checkEvaluation(ArrayBinarySearch(a5_9, Literal(1.0F)), 1) + val a5_10 = NonFoldableLiteral.create(Seq(null, 1.0F, 2.0F, 3.0F), ArrayType(FloatType)) + checkEvaluation(ArrayBinarySearch(a5_10, Literal(null, FloatType)), null) + val a5_11 = NonFoldableLiteral.create(Seq(1.0F, 2.0F, 3.0F), + ArrayType(FloatType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a5_11, Literal(1.1F, FloatType)), -2) + + // double + // double foldable val a6_0 = Literal.create(Seq(1.0d, 2.0d, 3.0d), ArrayType(DoubleType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a6_0, Literal(1.0d)), 0) val a6_1 = Literal.create(Seq(null, 1.0d, 2.0d, 3.0d), ArrayType(DoubleType)) @@ -202,8 +339,30 @@ class CollectionExpressionsSuite checkEvaluation(ArrayBinarySearch(a6_2, Literal(null, DoubleType)), null) val a6_3 = Literal.create(Seq(1.0d, 2.0d, 3.0d), ArrayType(DoubleType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a6_3, Literal(1.1d, DoubleType)), -2) + val a6_4 = CreateArray(Seq(Literal(1.0d), Literal(2.0d), Literal(3.0d))) + checkEvaluation(ArrayBinarySearch(a6_4, Literal(1.0d)), 0) + val a6_5 = CreateArray(Seq(Literal(null, DoubleType), + Literal(1.0d), Literal(2.0d), Literal(3.0d))) + checkEvaluation(ArrayBinarySearch(a6_5, Literal(1.0d)), 1) + val a6_6 = CreateArray(Seq(Literal(null, DoubleType), + Literal(1.0d), Literal(2.0d), Literal(3.0d))) + checkEvaluation(ArrayBinarySearch(a6_6, Literal(null, DoubleType)), null) + val a6_7 = CreateArray(Seq(Literal(1.0d), Literal(2.0d), Literal(3.0d))) + checkEvaluation(ArrayBinarySearch(a6_7, Literal(1.1d, DoubleType)), -2) + // double non-foldable + 
val a6_8 = NonFoldableLiteral.create(Seq(1.0d, 2.0d, 3.0d), + ArrayType(DoubleType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a6_8, Literal(1.0d)), 0) + val a6_9 = NonFoldableLiteral.create(Seq(null, 1.0d, 2.0d, 3.0d), ArrayType(DoubleType)) + checkEvaluation(ArrayBinarySearch(a6_9, Literal(1.0d)), 1) + val a6_10 = NonFoldableLiteral.create(Seq(null, 1.0d, 2.0d, 3.0d), ArrayType(DoubleType)) + checkEvaluation(ArrayBinarySearch(a6_10, Literal(null, DoubleType)), null) + val a6_11 = NonFoldableLiteral.create(Seq(1.0d, 2.0d, 3.0d), + ArrayType(DoubleType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a6_11, Literal(1.1d, DoubleType)), -2) // string + // string foldable val a7_0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a7_0, Literal("a")), 0) val a7_1 = Literal.create(Seq(null, "a", "b", "c"), ArrayType(StringType)) @@ -212,6 +371,27 @@ class CollectionExpressionsSuite checkEvaluation(ArrayBinarySearch(a7_2, Literal(null, StringType)), null) val a7_3 = Literal.create(Seq("a", "c", "d"), ArrayType(StringType, containsNull = false)) checkEvaluation(ArrayBinarySearch(a7_3, Literal(UTF8String.fromString("b"), StringType)), -2) + val a7_4 = CreateArray(Seq(Literal("a"), Literal("b"), Literal("c"))) + checkEvaluation(ArrayBinarySearch(a7_4, Literal("a")), 0) + val a7_5 = CreateArray(Seq(Literal(null, StringType), + Literal("a"), Literal("b"), Literal("c"))) + checkEvaluation(ArrayBinarySearch(a7_5, Literal("c")), 3) + val a7_6 = CreateArray(Seq(Literal(null, StringType), + Literal("a"), Literal("b"), Literal("c"))) + checkEvaluation(ArrayBinarySearch(a7_6, Literal(null, StringType)), null) + val a7_7 = CreateArray(Seq(Literal("a"), Literal("c"), Literal("d"))) + checkEvaluation(ArrayBinarySearch(a7_7, Literal(UTF8String.fromString("b"), StringType)), -2) + // string non-foldable + val a7_8 = NonFoldableLiteral.create(Seq("a", "b", "c"), + ArrayType(StringType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a7_8, Literal("a")), 0) + val a7_9 = NonFoldableLiteral.create(Seq(null, "a", "b", "c"), ArrayType(StringType)) + checkEvaluation(ArrayBinarySearch(a7_9, Literal("c")), 3) + val a7_10 = NonFoldableLiteral.create(Seq(null, "a", "b", "c"), ArrayType(StringType)) + checkEvaluation(ArrayBinarySearch(a7_10, Literal(null, StringType)), null) + val a7_11 = NonFoldableLiteral.create(Seq("a", "c", "d"), + ArrayType(StringType, containsNull = false)) + checkEvaluation(ArrayBinarySearch(a7_11, Literal(UTF8String.fromString("b"), StringType)), -2) } test("MapEntries") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 21ae35146282b..05d68504a7270 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -1202,7 +1202,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(23), Literal(59), Literal(Decimal(BigDecimal(60.0), 16, 6))) if (ansi) { checkExceptionInExpression[DateTimeException](makeTimestampExpr.copy(sec = Literal( - Decimal(BigDecimal(60.5), 16, 6))), EmptyRow, "The fraction of sec must be zero") + Decimal(BigDecimal(60.5), 16, 6))), EmptyRow, "Valid range for seconds is [0, 60]") } else { checkEvaluation(makeTimestampExpr, 
expectedAnswer("2019-07-01 00:00:00")) } diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala similarity index 100% rename from sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 5027222be6b80..9424ecda0ed8b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer +import java.util.UUID + import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -1229,9 +1231,10 @@ class FilterPushdownSuite extends PlanTest { // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. - val originalQuery = EventTimeWatermark($"b", interval, relation) + val nodeId = UUID.randomUUID() + val originalQuery = EventTimeWatermark(nodeId, $"b", interval, relation) .where($"a" === 5 && $"b" === new java.sql.Timestamp(0) && $"c" === 5) - val correctAnswer = EventTimeWatermark( + val correctAnswer = EventTimeWatermark(nodeId, $"b", interval, relation.where($"a" === 5 && $"c" === 5)) .where($"b" === new java.sql.Timestamp(0)) @@ -1244,9 +1247,10 @@ class FilterPushdownSuite extends PlanTest { // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. - val originalQuery = EventTimeWatermark($"c", interval, relation) + val nodeId = UUID.randomUUID() + val originalQuery = EventTimeWatermark(nodeId, $"c", interval, relation) .where($"a" === 5 && $"b" === Rand(10) && $"c" === new java.sql.Timestamp(0)) - val correctAnswer = EventTimeWatermark( + val correctAnswer = EventTimeWatermark(nodeId, $"c", interval, relation.where($"a" === 5)) .where($"b" === Rand(10) && $"c" === new java.sql.Timestamp(0)) @@ -1260,9 +1264,10 @@ class FilterPushdownSuite extends PlanTest { // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. 
- val originalQuery = EventTimeWatermark($"c", interval, relation) + val nodeId = UUID.randomUUID() + val originalQuery = EventTimeWatermark(nodeId, $"c", interval, relation) .where($"a" === 5 && $"b" === 10) - val correctAnswer = EventTimeWatermark( + val correctAnswer = EventTimeWatermark(nodeId, $"c", interval, relation.where($"a" === 5 && $"b" === 10)) comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, @@ -1273,9 +1278,10 @@ class FilterPushdownSuite extends PlanTest { val interval = new CalendarInterval(2, 2, 2000L) val relation = LocalRelation(Seq($"a".timestamp, attrB, attrC), Nil, isStreaming = true) - val originalQuery = EventTimeWatermark($"a", interval, relation) + val nodeId = UUID.randomUUID() + val originalQuery = EventTimeWatermark(nodeId, $"a", interval, relation) .where($"a" === new java.sql.Timestamp(0) && $"b" === 10) - val correctAnswer = EventTimeWatermark( + val correctAnswer = EventTimeWatermark(nodeId, $"a", interval, relation.where($"b" === 10)).where($"a" === new java.sql.Timestamp(0)) comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 6d307d1cd9a87..fc8bcfa3f6870 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -1082,8 +1082,8 @@ class ExpressionParserSuite extends AnalysisTest { // Unknown FROM TO intervals checkError( exception = parseException("interval '10' month to second"), - condition = "_LEGACY_ERROR_TEMP_0028", - parameters = Map("from" -> "month", "to" -> "second"), + condition = "INVALID_INTERVAL_FORMAT.UNSUPPORTED_FROM_TO_EXPRESSION", + parameters = Map("input" -> "10", "from" -> "month", "to" -> "second"), context = ExpectedContext( fragment = "'10' month to second", start = 9, diff --git a/sql/connect/common/src/main/buf.gen.yaml b/sql/connect/common/src/main/buf.gen.yaml index 9b0b07932eae8..a68bc880b8315 100644 --- a/sql/connect/common/src/main/buf.gen.yaml +++ b/sql/connect/common/src/main/buf.gen.yaml @@ -22,14 +22,14 @@ plugins: out: gen/proto/csharp - plugin: buf.build/protocolbuffers/java:v21.7 out: gen/proto/java - - plugin: buf.build/grpc/ruby:v1.62.0 + - plugin: buf.build/grpc/ruby:v1.67.0 out: gen/proto/ruby - plugin: buf.build/protocolbuffers/ruby:v21.7 out: gen/proto/ruby # Building the Python build and building the mypy interfaces. 
- plugin: buf.build/protocolbuffers/python:v21.7 out: gen/proto/python - - plugin: buf.build/grpc/python:v1.62.0 + - plugin: buf.build/grpc/python:v1.67.0 out: gen/proto/python - name: mypy out: gen/proto/python diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_ltz_with_timezone.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_ltz_with_timezone.explain new file mode 100644 index 0000000000000..ec8a7336a9b71 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_ltz_with_timezone.explain @@ -0,0 +1,2 @@ +Project [try_make_timestamp_ltz(a#0, a#0, a#0, a#0, a#0, cast(b#0 as decimal(16,6)), Some(g#0), Some(America/Los_Angeles), false, TimestampType) AS try_make_timestamp_ltz(a, a, a, a, a, b, g)#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_ltz_without_timezone.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_ltz_without_timezone.explain new file mode 100644 index 0000000000000..39f8095a1e095 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_ltz_without_timezone.explain @@ -0,0 +1,2 @@ +Project [try_make_timestamp_ltz(a#0, a#0, a#0, a#0, a#0, cast(b#0 as decimal(16,6)), None, Some(America/Los_Angeles), false, TimestampType) AS try_make_timestamp_ltz(a, a, a, a, a, b)#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_ntz.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_ntz.explain new file mode 100644 index 0000000000000..aa6613263622e --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_ntz.explain @@ -0,0 +1,2 @@ +Project [try_make_timestamp_ntz(a#0, a#0, a#0, a#0, a#0, cast(b#0 as decimal(16,6)), None, Some(America/Los_Angeles), false, TimestampNTZType) AS try_make_timestamp_ntz(a, a, a, a, a, b)#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_with_timezone.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_with_timezone.explain new file mode 100644 index 0000000000000..91d8e638750e6 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_with_timezone.explain @@ -0,0 +1,2 @@ +Project [make_timestamp(a#0, a#0, a#0, a#0, a#0, cast(b#0 as decimal(16,6)), Some(g#0), Some(America/Los_Angeles), false, TimestampType) AS try_make_timestamp(a, a, a, a, a, b)#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_without_timezone.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_without_timezone.explain new file mode 100644 index 0000000000000..5bca1302ead5e --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_make_timestamp_without_timezone.explain @@ -0,0 +1,2 @@ +Project [make_timestamp(a#0, a#0, a#0, a#0, a#0, cast(b#0 as 
decimal(16,6)), None, Some(America/Los_Angeles), false, TimestampType) AS try_make_timestamp(a, a, a, a, a, b)#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_with_timezone.json b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_with_timezone.json new file mode 100644 index 0000000000000..179f6e06988fc --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_with_timezone.json @@ -0,0 +1,49 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "try_make_timestamp_ltz", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "g" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_with_timezone.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_with_timezone.proto.bin new file mode 100644 index 0000000000000..d0c60ba1c7bf8 Binary files /dev/null and b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_with_timezone.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_without_timezone.json b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_without_timezone.json new file mode 100644 index 0000000000000..29aa2096c2273 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_without_timezone.json @@ -0,0 +1,45 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "try_make_timestamp_ltz", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_without_timezone.proto.bin 
b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_without_timezone.proto.bin new file mode 100644 index 0000000000000..9caf6f6ba5285 Binary files /dev/null and b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ltz_without_timezone.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ntz.json b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ntz.json new file mode 100644 index 0000000000000..6b8d31d0c58e5 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ntz.json @@ -0,0 +1,45 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "try_make_timestamp_ntz", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ntz.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ntz.proto.bin new file mode 100644 index 0000000000000..7d7e2a8029def Binary files /dev/null and b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_ntz.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_with_timezone.json b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_with_timezone.json new file mode 100644 index 0000000000000..79e11efc20d41 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_with_timezone.json @@ -0,0 +1,49 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "try_make_timestamp", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "g" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_with_timezone.proto.bin 
b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_with_timezone.proto.bin new file mode 100644 index 0000000000000..53b9839cf8c1f Binary files /dev/null and b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_with_timezone.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_without_timezone.json b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_without_timezone.json new file mode 100644 index 0000000000000..39ce728a38862 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_without_timezone.json @@ -0,0 +1,45 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "try_make_timestamp", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_without_timezone.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_without_timezone.proto.bin new file mode 100644 index 0000000000000..74918d42f89c6 Binary files /dev/null and b/sql/connect/common/src/test/resources/query-tests/queries/function_try_make_timestamp_without_timezone.proto.bin differ diff --git a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt index caca60875a8f2..4ab5f6d0061cc 100644 --- a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt +++ b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt @@ -2,143 +2,143 @@ put rows ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------------- -In-memory 9 11 1 1.1 943.6 1.0X -RocksDB (trackTotalNumberOfRows: true) 43 45 2 0.2 4332.8 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1583.2 0.6X +In-memory 10 14 1 1.0 1006.5 1.0X +RocksDB (trackTotalNumberOfRows: true) 43 45 2 0.2 4345.4 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 17 1 0.6 1547.6 0.7X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) 
Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------- -In-memory 9 11 1 1.1 938.2 1.0X -RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4452.3 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1586.1 0.6X +In-memory 10 12 1 1.0 1011.1 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 46 1 0.2 4441.2 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1521.7 0.7X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------- -In-memory 9 10 1 1.1 920.7 1.0X -RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4478.9 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1581.1 0.6X +In-memory 9 10 1 1.1 940.8 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 46 1 0.2 4425.1 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1515.2 0.6X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------- -In-memory 9 10 1 1.1 912.4 1.0X -RocksDB (trackTotalNumberOfRows: true) 44 46 1 0.2 4445.8 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 17 0 0.6 1587.8 0.6X +In-memory 9 11 2 1.1 932.2 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 46 1 0.2 4400.3 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 17 1 0.7 1506.0 0.6X ================================================================================================ merge rows ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------------------- -RocksDB (trackTotalNumberOfRows: true) 537 551 7 0.0 53664.8 1.0X -RocksDB (trackTotalNumberOfRows: false) 173 178 4 0.1 17277.2 3.1X +RocksDB (trackTotalNumberOfRows: true) 532 547 8 0.0 53154.1 1.0X +RocksDB (trackTotalNumberOfRows: false) 174 180 3 0.1 17410.5 3.1X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------ -RocksDB (trackTotalNumberOfRows: true) 474 486 5 0.0 47389.2 1.0X -RocksDB (trackTotalNumberOfRows: false) 172 177 2 0.1 17189.8 2.8X +RocksDB 
(trackTotalNumberOfRows: true) 472 484 5 0.0 47228.8 1.0X +RocksDB (trackTotalNumberOfRows: false) 174 179 3 0.1 17433.5 2.7X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------ -RocksDB (trackTotalNumberOfRows: true) 417 430 6 0.0 41696.3 1.0X -RocksDB (trackTotalNumberOfRows: false) 175 180 3 0.1 17458.6 2.4X +RocksDB (trackTotalNumberOfRows: true) 422 434 5 0.0 42226.0 1.0X +RocksDB (trackTotalNumberOfRows: false) 172 179 3 0.1 17235.9 2.4X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------------- -RocksDB (trackTotalNumberOfRows: true) 400 412 5 0.0 39958.9 1.0X -RocksDB (trackTotalNumberOfRows: false) 170 175 4 0.1 16952.8 2.4X +RocksDB (trackTotalNumberOfRows: true) 406 419 7 0.0 40646.7 1.0X +RocksDB (trackTotalNumberOfRows: false) 173 179 3 0.1 17265.8 2.4X ================================================================================================ delete rows ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(10000 rows are non-existing - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------------- In-memory 0 1 0 27.0 37.0 1.0X -RocksDB (trackTotalNumberOfRows: true) 43 44 1 0.2 4315.2 0.0X -RocksDB (trackTotalNumberOfRows: false) 15 15 0 0.7 1489.0 0.0X +RocksDB (trackTotalNumberOfRows: true) 44 46 1 0.2 4447.0 0.0X +RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1453.0 0.0X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(5000 rows are non-existing - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 8 0 1.3 781.5 1.0X -RocksDB (trackTotalNumberOfRows: true) 43 44 1 0.2 4323.9 0.2X -RocksDB (trackTotalNumberOfRows: false) 15 15 0 0.7 1500.1 0.5X +In-memory 8 9 1 1.3 796.5 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4384.0 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 15 0 0.7 1463.5 0.5X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows 
from 10000 rows(1000 rows are non-existing - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 1 1.2 829.7 1.0X -RocksDB (trackTotalNumberOfRows: true) 42 43 1 0.2 4234.1 0.2X -RocksDB (trackTotalNumberOfRows: false) 15 15 0 0.7 1491.0 0.6X +In-memory 9 9 1 1.2 853.3 1.0X +RocksDB (trackTotalNumberOfRows: true) 43 44 1 0.2 4278.0 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 15 1 0.7 1460.7 0.6X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(0 rows are non-existing - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 0 1.2 838.7 1.0X -RocksDB (trackTotalNumberOfRows: true) 42 43 1 0.2 4185.5 0.2X -RocksDB (trackTotalNumberOfRows: false) 15 15 0 0.7 1485.0 0.6X +In-memory 9 10 2 1.2 854.0 1.0X +RocksDB (trackTotalNumberOfRows: true) 42 44 1 0.2 4183.1 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1457.0 0.6X ================================================================================================ evict rows ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor evicting 10000 rows (maxTimestampToEvictInMillis: 9999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 0 1.2 832.1 1.0X -RocksDB (trackTotalNumberOfRows: true) 41 42 1 0.2 4142.6 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 16 1 0.6 1621.2 0.5X +In-memory 8 9 0 1.2 837.4 1.0X +RocksDB (trackTotalNumberOfRows: true) 41 42 1 0.2 4146.8 0.2X +RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1623.1 0.5X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor evicting 5000 rows (maxTimestampToEvictInMillis: 4999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------ -In-memory 8 8 0 1.3 783.9 1.0X -RocksDB (trackTotalNumberOfRows: true) 22 23 1 0.4 2226.5 0.4X -RocksDB (trackTotalNumberOfRows: false) 10 10 0 1.0 969.3 0.8X +In-memory 8 9 1 1.3 798.1 1.0X +RocksDB (trackTotalNumberOfRows: true) 22 23 1 0.5 2201.4 0.4X +RocksDB (trackTotalNumberOfRows: false) 10 10 1 1.0 956.5 0.8X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor evicting 1000 rows (maxTimestampToEvictInMillis: 999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
----------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 7 8 0 1.4 714.1 1.0X -RocksDB (trackTotalNumberOfRows: true) 7 7 1 1.4 725.5 1.0X -RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.1 476.0 1.5X +In-memory 7 8 1 1.4 724.1 1.0X +RocksDB (trackTotalNumberOfRows: true) 7 7 0 1.4 698.4 1.0X +RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.2 450.9 1.6X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor evicting 0 rows (maxTimestampToEvictInMillis: -1) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 0 0 0 24.1 41.5 1.0X -RocksDB (trackTotalNumberOfRows: true) 3 4 0 2.9 343.6 0.1X -RocksDB (trackTotalNumberOfRows: false) 3 4 0 2.9 343.8 0.1X +In-memory 0 0 0 24.0 41.6 1.0X +RocksDB (trackTotalNumberOfRows: true) 3 3 1 3.2 317.3 0.1X +RocksDB (trackTotalNumberOfRows: false) 3 3 0 3.2 317.2 0.1X diff --git a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt index 378cecf0271d2..856985b5d071f 100644 --- a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt +++ b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt @@ -2,143 +2,143 @@ put rows ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------------- -In-memory 9 10 1 1.1 948.8 1.0X -RocksDB (trackTotalNumberOfRows: true) 43 44 2 0.2 4286.1 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 16 1 0.6 1556.1 0.6X +In-memory 10 10 1 1.0 953.1 1.0X +RocksDB (trackTotalNumberOfRows: true) 43 44 2 0.2 4269.8 0.2X +RocksDB (trackTotalNumberOfRows: false) 16 16 1 0.6 1550.5 0.6X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------- -In-memory 9 10 0 1.1 937.8 1.0X -RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4368.4 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 16 1 0.6 1559.7 0.6X +In-memory 9 10 0 1.1 930.0 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4387.9 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1521.4 0.6X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------------------- -In-memory 9 10 0 1.1 921.3 1.0X -RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4406.5 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 16 1 0.6 1555.2 0.6X +In-memory 9 10 0 1.1 918.1 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4441.6 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1521.7 0.6X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor putting 10000 rows (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------- -In-memory 9 10 0 1.1 918.1 1.0X -RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4372.7 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 16 0 0.6 1559.8 0.6X +In-memory 9 10 0 1.1 916.9 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4413.7 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 16 0 0.7 1522.0 0.6X ================================================================================================ merge rows ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------------------- -RocksDB (trackTotalNumberOfRows: true) 542 555 7 0.0 54234.9 1.0X -RocksDB (trackTotalNumberOfRows: false) 179 185 3 0.1 17909.4 3.0X +RocksDB (trackTotalNumberOfRows: true) 542 553 6 0.0 54222.4 1.0X +RocksDB (trackTotalNumberOfRows: false) 174 179 3 0.1 17391.9 3.1X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------ -RocksDB (trackTotalNumberOfRows: true) 474 487 5 0.0 47434.9 1.0X -RocksDB (trackTotalNumberOfRows: false) 180 184 3 0.1 17961.1 2.6X +RocksDB (trackTotalNumberOfRows: true) 479 490 5 0.0 47921.1 1.0X +RocksDB (trackTotalNumberOfRows: false) 174 179 3 0.1 17446.2 2.7X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------ -RocksDB (trackTotalNumberOfRows: true) 419 428 4 0.0 41901.9 1.0X -RocksDB (trackTotalNumberOfRows: false) 175 181 2 0.1 17545.5 2.4X +RocksDB (trackTotalNumberOfRows: true) 423 433 5 0.0 42311.4 1.0X 
+RocksDB (trackTotalNumberOfRows: false) 173 178 3 0.1 17309.1 2.4X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor merging 10000 rows with 10 values per key (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------------- -RocksDB (trackTotalNumberOfRows: true) 400 410 5 0.0 39961.3 1.0X -RocksDB (trackTotalNumberOfRows: false) 175 182 3 0.1 17527.9 2.3X +RocksDB (trackTotalNumberOfRows: true) 408 419 5 0.0 40762.3 1.0X +RocksDB (trackTotalNumberOfRows: false) 174 183 3 0.1 17377.7 2.3X ================================================================================================ delete rows ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(10000 rows are non-existing - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 0 1 0 25.7 38.9 1.0X -RocksDB (trackTotalNumberOfRows: true) 43 45 1 0.2 4347.8 0.0X -RocksDB (trackTotalNumberOfRows: false) 15 16 0 0.7 1495.4 0.0X +In-memory 0 0 0 26.1 38.3 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 46 1 0.2 4444.2 0.0X +RocksDB (trackTotalNumberOfRows: false) 15 15 0 0.7 1489.6 0.0X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(5000 rows are non-existing - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 0 1.3 789.7 1.0X -RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4360.4 0.2X -RocksDB (trackTotalNumberOfRows: false) 15 15 0 0.7 1502.6 0.5X +In-memory 8 8 0 1.3 788.8 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4425.4 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1499.2 0.5X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor trying to delete 10000 rows from 10000 rows(1000 rows are non-existing - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 0 1.2 833.1 1.0X -RocksDB (trackTotalNumberOfRows: true) 43 44 1 0.2 4274.2 0.2X -RocksDB (trackTotalNumberOfRows: false) 15 15 0 0.7 1499.0 0.6X +In-memory 8 9 0 1.2 841.3 1.0X +RocksDB (trackTotalNumberOfRows: true) 43 44 1 0.2 4336.9 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1493.6 0.6X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 
64-Core Processor trying to delete 10000 rows from 10000 rows(0 rows are non-existing - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 0 1.2 845.5 1.0X -RocksDB (trackTotalNumberOfRows: true) 42 43 1 0.2 4220.8 0.2X -RocksDB (trackTotalNumberOfRows: false) 15 15 0 0.7 1479.2 0.6X +In-memory 8 9 0 1.2 848.9 1.0X +RocksDB (trackTotalNumberOfRows: true) 42 43 1 0.2 4216.8 0.2X +RocksDB (trackTotalNumberOfRows: false) 15 15 0 0.7 1467.4 0.6X ================================================================================================ evict rows ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor evicting 10000 rows (maxTimestampToEvictInMillis: 9999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 8 9 0 1.2 848.5 1.0X -RocksDB (trackTotalNumberOfRows: true) 42 43 0 0.2 4184.8 0.2X -RocksDB (trackTotalNumberOfRows: false) 16 16 0 0.6 1592.8 0.5X +In-memory 8 9 0 1.2 836.6 1.0X +RocksDB (trackTotalNumberOfRows: true) 42 43 2 0.2 4182.0 0.2X +RocksDB (trackTotalNumberOfRows: false) 16 17 0 0.6 1645.0 0.5X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor evicting 5000 rows (maxTimestampToEvictInMillis: 4999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------ -In-memory 8 8 0 1.3 792.3 1.0X -RocksDB (trackTotalNumberOfRows: true) 23 23 1 0.4 2267.5 0.3X -RocksDB (trackTotalNumberOfRows: false) 10 10 0 1.0 983.7 0.8X +In-memory 8 8 0 1.3 785.1 1.0X +RocksDB (trackTotalNumberOfRows: true) 23 23 1 0.4 2258.3 0.3X +RocksDB (trackTotalNumberOfRows: false) 10 10 0 1.0 999.7 0.8X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor evicting 1000 rows (maxTimestampToEvictInMillis: 999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 7 8 0 1.4 737.0 1.0X -RocksDB (trackTotalNumberOfRows: true) 7 8 0 1.3 742.5 1.0X -RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.1 486.5 1.5X +In-memory 7 8 0 1.4 726.0 1.0X +RocksDB (trackTotalNumberOfRows: true) 7 8 0 1.4 736.8 1.0X +RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.1 487.0 1.5X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor evicting 0 rows (maxTimestampToEvictInMillis: -1) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------------------------------- -In-memory 0 0 0 21.4 46.6 1.0X -RocksDB (trackTotalNumberOfRows: true) 4 4 0 2.8 354.6 0.1X -RocksDB (trackTotalNumberOfRows: false) 4 4 0 2.8 353.8 0.1X +In-memory 0 0 0 22.8 43.9 1.0X +RocksDB (trackTotalNumberOfRows: true) 4 4 0 2.8 354.8 0.1X +RocksDB (trackTotalNumberOfRows: false) 4 4 0 2.8 353.1 0.1X diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcCompressionCodec.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcCompressionCodec.java similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcCompressionCodec.java rename to sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcCompressionCodec.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b489f33cd63b9..3953a5c3704f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -573,7 +573,8 @@ class Dataset[T] private[sql]( require(!IntervalUtils.isNegative(parsedDelay), s"delay threshold ($delayThreshold) should not be negative.") EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + EventTimeWatermark(util.UUID.randomUUID(), UnresolvedAttribute(eventTime), + parsedDelay, logicalPlan)) } /** @inheritdoc */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 53c335c1eced6..30b395d0c1369 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -425,8 +425,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case _ if !plan.isStreaming => Nil - case EventTimeWatermark(columnName, delay, child) => - EventTimeWatermarkExec(columnName, delay, planLater(child)) :: Nil + case EventTimeWatermark(nodeId, columnName, delay, child) => + EventTimeWatermarkExec(nodeId, columnName, delay, planLater(child)) :: Nil case UpdateEventTimeWatermarkColumn(columnName, delay, child) => // we expect watermarkDelay to be resolved before physical planning. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala index 54041abdc9ab4..d25c4be0fb84a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Predicate, SortOrder, UnsafeProjection} @@ -90,6 +92,7 @@ class EventTimeStatsAccum(protected var currentStats: EventTimeStats = EventTime * period. Note that event time is measured in milliseconds. 
*/ case class EventTimeWatermarkExec( + nodeId: UUID, eventTime: Attribute, delay: CalendarInterval, child: SparkPlan) extends UnaryExecNode { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index dc141b21780e7..5ce9e13eb8fac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -485,7 +485,7 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.sessionState.conf) execCtx.offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) - watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf, logicalPlan) watermarkTracker.setWatermark(metadata.batchWatermarkMs) } @@ -539,7 +539,7 @@ class MicroBatchExecution( case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") execCtx.batchId = 0 - watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf, logicalPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 3e6f122f463d3..7228767c4d18a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.streaming -import java.util.Locale +import java.util.{Locale, UUID} import scala.collection.mutable import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.internal.SQLConf @@ -79,8 +80,21 @@ case object MaxWatermark extends MultipleWatermarkPolicy { } /** Tracks the watermark value of a streaming query based on a given `policy` */ -case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { - private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() +class WatermarkTracker( + policy: MultipleWatermarkPolicy, + initialPlan: LogicalPlan) extends Logging { + + private val operatorToWatermarkMap: mutable.Map[UUID, Option[Long]] = { + val map = mutable.HashMap[UUID, Option[Long]]() + val watermarkOperators = initialPlan.collect { + case e: EventTimeWatermark => e + } + watermarkOperators.foreach { op => + map.put(op.nodeId, None) + } + map + } + private var globalWatermarkMs: Long = 0 def setWatermark(newWatermarkMs: Long): Unit = synchronized { @@ -93,26 +107,33 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { } if (watermarkOperators.isEmpty) return - watermarkOperators.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + watermarkOperators.foreach { + case e if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats ${e.nodeId}: 
${e.eventTimeStats.value}") + + if (!operatorToWatermarkMap.isDefinedAt(e.nodeId)) { + throw new IllegalStateException(s"Unknown watermark node ID: ${e.nodeId}, known IDs: " + + s"${operatorToWatermarkMap.keys.mkString("[", ",", "]")}") + } + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = operatorToWatermarkMap.get(index) + val prevWatermarkMs = operatorToWatermarkMap(e.nodeId) if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - operatorToWatermarkMap.put(index, newWatermarkMs) + operatorToWatermarkMap.put(e.nodeId, Some(newWatermarkMs)) } - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!operatorToWatermarkMap.isDefinedAt(index)) { - operatorToWatermarkMap.put(index, 0) + case e => + if (!operatorToWatermarkMap.isDefinedAt(e.nodeId)) { + throw new IllegalStateException(s"Unknown watermark node ID: ${e.nodeId}, known IDs: " + + s"${operatorToWatermarkMap.keys.mkString("[", ",", "]")}") } } // Update the global watermark accordingly to the chosen policy. To find all available policies // and their semantics, please check the comments of // `org.apache.spark.sql.execution.streaming.MultipleWatermarkPolicy` implementations. - val chosenGlobalWatermark = policy.chooseGlobalWatermark(operatorToWatermarkMap.values.toSeq) + val chosenGlobalWatermark = policy.chooseGlobalWatermark( + operatorToWatermarkMap.values.map(_.getOrElse(0L)).toSeq) if (chosenGlobalWatermark > globalWatermarkMs) { logInfo(log"Updating event-time watermark from " + log"${MDC(GLOBAL_WATERMARK, globalWatermarkMs)} " + @@ -124,10 +145,14 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { } def currentWatermark: Long = synchronized { globalWatermarkMs } + + private[sql] def watermarkMap: Map[UUID, Option[Long]] = synchronized { + operatorToWatermarkMap.toMap + } } object WatermarkTracker { - def apply(conf: RuntimeConfig): WatermarkTracker = { + def apply(conf: RuntimeConfig, initialPlan: LogicalPlan): WatermarkTracker = { // If the session has been explicitly configured to use non-default policy then use it, // otherwise use the default `min` policy as thats the safe thing to do. 
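[Editor's note — not part of the diff] The WatermarkTracker hunks above replace the positional map (operator index -> watermark) with a map keyed by each EventTimeWatermark operator's nodeId UUID: the map is pre-populated from the initial logical plan, entries stay None until an operator has observed event-time data, unknown IDs raise an error, and the per-operator values (defaulting to 0) feed the configured MultipleWatermarkPolicy. A minimal, self-contained Scala sketch of that bookkeeping under the default "min" policy — hypothetical names (SimpleWatermarkTracker, update), not Spark's API:

import java.util.UUID
import scala.collection.mutable

// Illustrative only: per-operator watermarks keyed by a stable nodeId,
// None until that operator has seen any event-time data.
class SimpleWatermarkTracker(nodeIds: Seq[UUID]) {
  private val perOperator: mutable.Map[UUID, Option[Long]] =
    mutable.HashMap(nodeIds.map(id => id -> (None: Option[Long])): _*)
  private var globalWatermarkMs: Long = 0L

  // maxEventTimeMs: max observed event time for this operator in the last batch;
  // delayMs: the operator's watermark delay.
  def update(nodeId: UUID, maxEventTimeMs: Long, delayMs: Long): Unit = {
    require(perOperator.contains(nodeId), s"Unknown watermark node ID: $nodeId")
    val candidate = maxEventTimeMs - delayMs
    // Only move this operator's watermark forward.
    if (perOperator(nodeId).forall(candidate > _)) {
      perOperator(nodeId) = Some(candidate)
    }
    // Default "min" policy: operators with no data yet count as 0,
    // and the global watermark only ever advances.
    val chosen = perOperator.values.map(_.getOrElse(0L)).min
    if (chosen > globalWatermarkMs) globalWatermarkMs = chosen
  }

  def currentWatermark: Long = globalWatermarkMs
}

Keying the map by UUID rather than by position means a per-operator watermark is not mixed up when operators are added or reordered across restarts, which appears to be the motivation for threading nodeId through EventTimeWatermark, EventTimeWatermarkExec and WatermarkTracker in the hunks above.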
// When recovering from a checkpoint location, it is expected that the `conf` will already @@ -137,6 +162,6 @@ object WatermarkTracker { val policyName = conf.get( SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME) - new WatermarkTracker(MultipleWatermarkPolicy(policyName)) + new WatermarkTracker(MultipleWatermarkPolicy(policyName), initialPlan) } } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 9006a20d13f08..27d9367c49e9f 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -354,6 +354,9 @@ | org.apache.spark.sql.catalyst.expressions.TryAesDecrypt | try_aes_decrypt | SELECT try_aes_decrypt(unhex('6E7CA17BBB468D3084B5744BCA729FB7B2B7BCB8E4472847D02670489D95FA97DBBA7D3210'), '0000111122223333', 'GCM') | struct | | org.apache.spark.sql.catalyst.expressions.TryDivide | try_divide | SELECT try_divide(3, 2) | struct | | org.apache.spark.sql.catalyst.expressions.TryElementAt | try_element_at | SELECT try_element_at(array(1, 2, 3), 2) | struct | +| org.apache.spark.sql.catalyst.expressions.TryMakeTimestamp | try_make_timestamp | SELECT try_make_timestamp(2014, 12, 28, 6, 30, 45.887) | struct | +| org.apache.spark.sql.catalyst.expressions.TryMakeTimestampLTZExpressionBuilder | try_make_timestamp_ltz | SELECT try_make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887) | struct | +| org.apache.spark.sql.catalyst.expressions.TryMakeTimestampNTZExpressionBuilder | try_make_timestamp_ntz | SELECT try_make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887) | struct | | org.apache.spark.sql.catalyst.expressions.TryMod | try_mod | SELECT try_mod(3, 2) | struct | | org.apache.spark.sql.catalyst.expressions.TryMultiply | try_multiply | SELECT try_multiply(2, 3) | struct | | org.apache.spark.sql.catalyst.expressions.TryParseUrl | try_parse_url | SELECT try_parse_url('http://spark.apache.org/path?query=1', 'HOST') | struct | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out index b0d128c4cab69..c023e3b56f117 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out @@ -1233,9 +1233,11 @@ select interval '1' year to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0028", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNSUPPORTED_FROM_TO_EXPRESSION", + "sqlState" : "22006", "messageParameters" : { "from" : "year", + "input" : "1", "to" : "second" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index 883acd3ca966f..c8e28c2cfafc9 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -649,6 +649,22 @@ InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_d +- LocalRelation [col1#x, col2#x, col3#x] +-- !query +insert into t5 values ('İo', 'İo', 'İo ') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in 
comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('İo', 'İo ', 'İo') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + -- !query insert into t5 values ('İo', 'İo', 'i̇o') -- !query analysis @@ -1021,6 +1037,14 @@ Project [split_part(cast(utf8_binary#x as string collate UTF8_LCASE), collate(a, +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select split_part(utf8_binary, 'a ' collate utf8_lcase_rtrim, 3), split_part(utf8_lcase, 'a' collate utf8_binary, 3) from t5 +-- !query analysis +Project [split_part(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(a , utf8_lcase_rtrim), 3) AS split_part(utf8_binary, collate(a , utf8_lcase_rtrim), 3)#x, split_part(cast(utf8_lcase#x as string), collate(a, utf8_binary), 3) AS split_part(utf8_lcase, collate(a, utf8_binary), 3)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select contains(utf8_binary, utf8_lcase) from t5 -- !query analysis @@ -1111,6 +1135,14 @@ Project [Contains(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select contains(utf8_binary, 'AaAA ' collate utf8_lcase_rtrim), contains(utf8_lcase, 'AAa ' collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [Contains(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(AaAA , utf8_lcase_rtrim)) AS contains(utf8_binary, collate(AaAA , utf8_lcase_rtrim))#x, Contains(cast(utf8_lcase#x as string collate UTF8_BINARY_RTRIM), collate(AAa , utf8_binary_rtrim)) AS contains(utf8_lcase, collate(AAa , utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select substring_index(utf8_binary, utf8_lcase, 2) from t5 -- !query analysis @@ -1201,6 +1233,14 @@ Project [substring_index(cast(utf8_binary#x as string collate UTF8_LCASE), colla +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select substring_index(utf8_binary, 'AaAA ' collate utf8_lcase_rtrim, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5 +-- !query analysis +Project [substring_index(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(AaAA , utf8_lcase_rtrim), 2) AS substring_index(utf8_binary, collate(AaAA , utf8_lcase_rtrim), 2)#x, substring_index(cast(utf8_lcase#x as string), collate(AAa, utf8_binary), 2) AS substring_index(utf8_lcase, collate(AAa, utf8_binary), 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + 
+- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select instr(utf8_binary, utf8_lcase) from t5 -- !query analysis @@ -1357,6 +1397,14 @@ Project [find_in_set(cast(utf8_binary#x as string collate UTF8_LCASE), collate(a +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select find_in_set(utf8_binary, 'aaAaaAaA,i̇o ' collate utf8_lcase_rtrim), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o' collate utf8_binary) from t5 +-- !query analysis +Project [find_in_set(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(aaAaaAaA,i̇o , utf8_lcase_rtrim)) AS find_in_set(utf8_binary, collate(aaAaaAaA,i̇o , utf8_lcase_rtrim))#x, find_in_set(cast(utf8_lcase#x as string), collate(aaAaaAaA,i̇o, utf8_binary)) AS find_in_set(utf8_lcase, collate(aaAaaAaA,i̇o, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select startswith(utf8_binary, utf8_lcase) from t5 -- !query analysis @@ -1447,6 +1495,14 @@ Project [StartsWith(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aa +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select startswith(utf8_binary, 'aaAaaAaA ' collate utf8_lcase_rtrim), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5 +-- !query analysis +Project [StartsWith(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(aaAaaAaA , utf8_lcase_rtrim)) AS startswith(utf8_binary, collate(aaAaaAaA , utf8_lcase_rtrim))#x, StartsWith(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary)) AS startswith(utf8_lcase, collate(aaAaaAaA, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select translate(utf8_lcase, utf8_lcase, '12345') from t5 -- !query analysis @@ -1529,6 +1585,14 @@ Project [translate(cast(utf8_lcase#x as string), collate(aBc, utf8_binary), 1234 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select translate(utf8_lcase, 'aBc ' collate utf8_binary_rtrim, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5 +-- !query analysis +Project [translate(cast(utf8_lcase#x as string collate UTF8_BINARY_RTRIM), collate(aBc , utf8_binary_rtrim), cast(12345 as string collate UTF8_BINARY_RTRIM)) AS translate(utf8_lcase, collate(aBc , utf8_binary_rtrim), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select replace(utf8_binary, utf8_lcase, 'abc') from t5 -- !query analysis @@ -1619,6 +1683,14 @@ Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaa +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select replace(utf8_binary, 'aaAaaAaA ' collate utf8_lcase_rtrim, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5 +-- !query analysis +Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(aaAaaAaA , utf8_lcase_rtrim), cast(abc as string collate UTF8_LCASE_RTRIM)) AS replace(utf8_binary, collate(aaAaaAaA , utf8_lcase_rtrim), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, 
utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select endswith(utf8_binary, utf8_lcase) from t5 -- !query analysis @@ -1709,6 +1781,14 @@ Project [EndsWith(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAa +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select endswith(utf8_binary, 'aaAaaAaA ' collate utf8_lcase_rtrim), endswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5 +-- !query analysis +Project [EndsWith(cast(utf8_binary#x as string collate UTF8_LCASE_RTRIM), collate(aaAaaAaA , utf8_lcase_rtrim)) AS endswith(utf8_binary, collate(aaAaaAaA , utf8_lcase_rtrim))#x, EndsWith(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary)) AS endswith(utf8_lcase, collate(aaAaaAaA, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select repeat(utf8_binary, 3), repeat(utf8_lcase, 2) from t5 -- !query analysis @@ -2299,6 +2379,14 @@ Project [rpad(collate(utf8_binary#x, utf8_lcase), 8, collate(utf8_lcase#x, utf8_ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select lpad(utf8_binary collate utf8_binary_rtrim, 8, utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [lpad(collate(utf8_binary#x, utf8_binary_rtrim), 8, collate(utf8_lcase#x, utf8_binary_rtrim)) AS lpad(collate(utf8_binary, utf8_binary_rtrim), 8, collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select rpad(utf8_binary, 8, 'a'), rpad(utf8_lcase, 8, 'a') from t5 -- !query analysis @@ -2365,6 +2453,14 @@ Project [lpad(collate(utf8_binary#x, utf8_lcase), 8, collate(utf8_lcase#x, utf8_ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select lpad(utf8_binary collate utf8_binary_rtrim, 8, utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [lpad(collate(utf8_binary#x, utf8_binary_rtrim), 8, collate(utf8_lcase#x, utf8_binary_rtrim)) AS lpad(collate(utf8_binary, utf8_binary_rtrim), 8, collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select lpad(utf8_binary, 8, 'a'), lpad(utf8_lcase, 8, 'a') from t5 -- !query analysis @@ -2471,6 +2567,14 @@ Project [locate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select locate(utf8_binary, 'AaAA ' collate utf8_binary_rtrim, 4), locate(utf8_lcase, 'AAa ' collate utf8_binary, 4) from t5 +-- !query analysis +Project [locate(cast(utf8_binary#x as string collate UTF8_BINARY_RTRIM), collate(AaAA , utf8_binary_rtrim), 4) AS locate(utf8_binary, collate(AaAA , utf8_binary_rtrim), 4)#x, locate(cast(utf8_lcase#x as string), collate(AAa , utf8_binary), 4) AS locate(utf8_lcase, collate(AAa , utf8_binary), 4)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select TRIM(utf8_binary, utf8_lcase) from t5 -- !query analysis @@ -2545,6 +2649,14 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +select 
TRIM(utf8_binary collate utf8_binary_rtrim, utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [trim(collate(utf8_lcase#x, utf8_binary_rtrim), Some(collate(utf8_binary#x, utf8_binary_rtrim))) AS TRIM(BOTH collate(utf8_binary, utf8_binary_rtrim) FROM collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2635,6 +2747,14 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +select BTRIM(utf8_binary collate utf8_binary_rtrim, utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [btrim(collate(utf8_binary#x, utf8_binary_rtrim), collate(utf8_lcase#x, utf8_binary_rtrim)) AS btrim(collate(utf8_binary, utf8_binary_rtrim), collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2725,6 +2845,14 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +select LTRIM(utf8_binary collate utf8_binary_rtrim, utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [ltrim(collate(utf8_lcase#x, utf8_binary_rtrim), Some(collate(utf8_binary#x, utf8_binary_rtrim))) AS TRIM(LEADING collate(utf8_binary, utf8_binary_rtrim) FROM collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2815,6 +2943,14 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +select RTRIM(utf8_binary collate utf8_binary_rtrim, utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [rtrim(collate(utf8_lcase#x, utf8_binary_rtrim), Some(collate(utf8_binary#x, utf8_binary_rtrim))) AS TRIM(TRAILING collate(utf8_binary, utf8_binary_rtrim) FROM collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out index efa149509751d..c0196bbe118ef 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out @@ -1233,9 +1233,11 @@ select interval '1' year to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0028", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNSUPPORTED_FROM_TO_EXPRESSION", + "sqlState" : "22006", "messageParameters" : { "from" : "year", + "input" : "1", "to" : "second" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql index bbbb229ad1cd7..b4d33bb0196c9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql @@ -117,6 +117,8 @@ insert into t5 values ('aaAaAAaA', 
'aaAaAAaA', 'aaAaaAaA'); insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaaAaAaaAaaAaAaaAaaAaA'); insert into t5 values ('bbAbaAbA', 'bbAbAAbA', 'a'); insert into t5 values ('İo', 'İo', 'İo'); +insert into t5 values ('İo', 'İo', 'İo '); +insert into t5 values ('İo', 'İo ', 'İo'); insert into t5 values ('İo', 'İo', 'i̇o'); insert into t5 values ('efd2', 'efd2', 'efd2'); insert into t5 values ('Hello, world! Nice day.', 'Hello, world! Nice day.', 'Hello, world! Nice day.'); @@ -170,6 +172,7 @@ select split_part(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, select split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5; select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5; select split_part(utf8_binary, 'a' collate utf8_lcase, 3), split_part(utf8_lcase, 'a' collate utf8_binary, 3) from t5; +select split_part(utf8_binary, 'a ' collate utf8_lcase_rtrim, 3), split_part(utf8_lcase, 'a' collate utf8_binary, 3) from t5; -- Contains select contains(utf8_binary, utf8_lcase) from t5; @@ -180,6 +183,7 @@ select contains(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) f select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5; select contains(utf8_binary, 'AaAA' collate utf8_lcase), contains(utf8_lcase, 'AAa' collate utf8_binary) from t5; +select contains(utf8_binary, 'AaAA ' collate utf8_lcase_rtrim), contains(utf8_lcase, 'AAa ' collate utf8_binary_rtrim) from t5; -- SubstringIndex select substring_index(utf8_binary, utf8_lcase, 2) from t5; @@ -190,6 +194,7 @@ select substring_index(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_l select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5; select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5; select substring_index(utf8_binary, 'AaAA' collate utf8_lcase, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5; +select substring_index(utf8_binary, 'AaAA ' collate utf8_lcase_rtrim, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5; -- StringInStr select instr(utf8_binary, utf8_lcase) from t5; @@ -209,7 +214,7 @@ select find_in_set(utf8_binary, utf8_lcase collate utf8_binary) from t5; select find_in_set(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; select find_in_set(utf8_binary, 'aaAaaAaA,i̇o'), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o') from t5; select find_in_set(utf8_binary, 'aaAaaAaA,i̇o' collate utf8_lcase), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o' collate utf8_binary) from t5; - +select find_in_set(utf8_binary, 'aaAaaAaA,i̇o ' collate utf8_lcase_rtrim), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o' collate utf8_binary) from t5; -- StartsWith select startswith(utf8_binary, utf8_lcase) from t5; select startswith(s, utf8_binary) from t5; @@ -219,6 +224,7 @@ select startswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5; select startswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; +select startswith(utf8_binary, 'aaAaaAaA ' collate utf8_lcase_rtrim), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; -- StringTranslate select translate(utf8_lcase, utf8_lcase, '12345') 
from t5; @@ -228,6 +234,7 @@ select translate(utf8_binary, 'SQL' collate utf8_lcase, '12345' collate utf8_lca select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5; select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5; select translate(utf8_lcase, 'aBc' collate utf8_binary, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5; +select translate(utf8_lcase, 'aBc ' collate utf8_binary_rtrim, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5; -- Replace select replace(utf8_binary, utf8_lcase, 'abc') from t5; @@ -238,6 +245,7 @@ select replace(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 'a select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5; select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5; select replace(utf8_binary, 'aaAaaAaA' collate utf8_lcase, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5; +select replace(utf8_binary, 'aaAaaAaA ' collate utf8_lcase_rtrim, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5; -- EndsWith select endswith(utf8_binary, utf8_lcase) from t5; @@ -248,6 +256,7 @@ select endswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) f select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5; select endswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), endswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; +select endswith(utf8_binary, 'aaAaaAaA ' collate utf8_lcase_rtrim), endswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; -- StringRepeat select repeat(utf8_binary, 3), repeat(utf8_lcase, 2) from t5; @@ -362,6 +371,7 @@ select rpad(s, 8, utf8_binary) from t5; select rpad(utf8_binary collate utf8_binary, 8, s collate utf8_lcase) from t5; select rpad(utf8_binary, 8, utf8_lcase collate utf8_binary) from t5; select rpad(utf8_binary collate utf8_lcase, 8, utf8_lcase collate utf8_lcase) from t5; +select lpad(utf8_binary collate utf8_binary_rtrim, 8, utf8_lcase collate utf8_binary_rtrim) from t5; select rpad(utf8_binary, 8, 'a'), rpad(utf8_lcase, 8, 'a') from t5; select rpad(utf8_binary, 8, 'AaAA' collate utf8_lcase), rpad(utf8_lcase, 8, 'AAa' collate utf8_binary) from t5; @@ -371,6 +381,7 @@ select lpad(s, 8, utf8_binary) from t5; select lpad(utf8_binary collate utf8_binary, 8, s collate utf8_lcase) from t5; select lpad(utf8_binary, 8, utf8_lcase collate utf8_binary) from t5; select lpad(utf8_binary collate utf8_lcase, 8, utf8_lcase collate utf8_lcase) from t5; +select lpad(utf8_binary collate utf8_binary_rtrim, 8, utf8_lcase collate utf8_binary_rtrim) from t5; select lpad(utf8_binary, 8, 'a'), lpad(utf8_lcase, 8, 'a') from t5; select lpad(utf8_binary, 8, 'AaAA' collate utf8_lcase), lpad(utf8_lcase, 8, 'AAa' collate utf8_binary) from t5; @@ -383,6 +394,7 @@ select locate(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 3) select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5; select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5; select locate(utf8_binary, 'AaAA' collate utf8_lcase, 4), locate(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5; +select locate(utf8_binary, 'AaAA ' collate utf8_binary_rtrim, 4), locate(utf8_lcase, 'AAa ' collate utf8_binary, 4) from t5; -- StringTrim 
select TRIM(utf8_binary, utf8_lcase) from t5; @@ -391,6 +403,7 @@ select TRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select TRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select TRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select TRIM(utf8_binary collate utf8_binary_rtrim, utf8_lcase collate utf8_binary_rtrim) from t5; select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5; select TRIM('ABc' collate utf8_lcase, utf8_binary), TRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimBoth @@ -400,6 +413,7 @@ select BTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select BTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select BTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select BTRIM(utf8_binary collate utf8_binary_rtrim, utf8_lcase collate utf8_binary_rtrim) from t5; select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5; select BTRIM('ABc' collate utf8_lcase, utf8_binary), BTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimLeft @@ -409,6 +423,7 @@ select LTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select LTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select LTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select LTRIM(utf8_binary collate utf8_binary_rtrim, utf8_lcase collate utf8_binary_rtrim) from t5; select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5; select LTRIM('ABc' collate utf8_lcase, utf8_binary), LTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimRight @@ -418,6 +433,7 @@ select RTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select RTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select RTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select RTRIM(utf8_binary collate utf8_binary_rtrim, utf8_lcase collate utf8_binary_rtrim) from t5; select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5; select RTRIM('ABc' collate utf8_lcase, utf8_binary), RTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index b2f85835eb0df..766bfba7696f0 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1535,9 +1535,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0028", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNSUPPORTED_FROM_TO_EXPRESSION", + "sqlState" : "22006", "messageParameters" : { "from" : "year", + "input" : "1", "to" : "second" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out index d7a58e321b0f0..c64bd2ff57e17 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out @@ -126,7 +126,7 @@ org.apache.spark.SparkDateTimeException "errorClass" : "INVALID_FRACTION_OF_SECOND", "sqlState" : "22023", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"" + "secAndMicros" : "60.007000" } } diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index d64b8869905d4..f92fc5de8c3f4 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -708,6 +708,22 @@ struct<> +-- !query +insert into t5 values ('İo', 'İo', 'İo ') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('İo', 'İo ', 'İo') +-- !query schema +struct<> +-- !query output + + + -- !query insert into t5 values ('İo', 'İo', 'i̇o') -- !query schema @@ -893,6 +909,8 @@ abc abc efd2 efd2 i̇o i̇o sitTing sitTing +İo İo +İo İo İo İo @@ -942,6 +960,8 @@ abcdcba SQL bbAbAAbA SQL efd2 SQL kitten SQL +İo SQL +İo SQL İo SQL İo SQL @@ -963,6 +983,8 @@ abc,word abc,word efd2,word efd2,word i̇o,word İo,word sitTing,word kitten,word +İo ,word İo,word +İo,word İo ,word İo,word İo,word @@ -983,6 +1005,8 @@ abc,word abc,word efd2,word efd2,word i̇o,word İo,word sitTing,word kitten,word +İo ,word İo,word +İo,word İo ,word İo,word İo,word @@ -1004,6 +1028,8 @@ efd2 kitten İo İo +İo +İo -- !query @@ -1054,6 +1080,8 @@ efd2 kitten İo İo +İo +İo -- !query @@ -1074,6 +1102,8 @@ efd2 kitten İo İo +İo +İo -- !query @@ -1094,6 +1124,8 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo -- !query @@ -1114,6 +1146,8 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo -- !query @@ -1148,7 +1182,9 @@ struct + bbAbaAbA +İo -- !query @@ -1191,6 +1227,8 @@ struct +-- !query output + + + + + + + + + + + + A A A @@ -1281,6 +1345,8 @@ select contains(s, utf8_binary) from t5 struct -- !query output false +false +true true true true @@ -1322,6 +1388,8 @@ false false false false +false +true true true true @@ -1338,6 +1406,8 @@ struct +-- !query output +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false true false true false true true @@ -1448,7 +1544,9 @@ struct + bbAbaAbA +İo -- !query @@ -1484,6 +1582,8 @@ efd2 kitten İo İo +İo +İo -- !query @@ -1504,6 +1604,8 @@ efd2 kitten İo İo +İo +İo -- !query @@ -1550,6 +1652,8 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo -- !query @@ -1570,6 +1674,30 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo + + +-- !query +select substring_index(utf8_binary, 'AaAA ' collate utf8_lcase_rtrim, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. 
+Spark SQL +a aaAaAAaA +a aaAaaAaA +a aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo +İo İo +İo İo -- !query @@ -1593,6 +1721,8 @@ select instr(s, utf8_binary) from t5 struct -- !query output 0 +0 +1 1 1 1 @@ -1634,6 +1764,8 @@ struct 0 0 0 +0 +1 1 1 1 @@ -1650,6 +1782,8 @@ struct 0 0 0 0 0 0 +0 0 +0 0 0 1 1 1 1 1 @@ -1723,6 +1859,8 @@ struct -- !query output 0 0 +0 +1 1 1 1 @@ -1791,6 +1931,8 @@ struct 0 0 0 +0 +0 1 1 1 @@ -1808,6 +1950,8 @@ struct +-- !query output +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +1 0 +1 0 +1 1 +2 0 +2 0 +2 0 2 2 @@ -1879,6 +2049,8 @@ select startswith(s, utf8_binary) from t5 struct -- !query output false +false +true true true true @@ -1920,6 +2092,8 @@ false false false false +false +true true true true @@ -1937,6 +2111,8 @@ false false false false +false +true true true true @@ -1989,6 +2165,8 @@ false false false false false false false false +false false +false false false true false true false true @@ -2009,6 +2187,30 @@ false false false false false false false false +false false +false false +true false +true true +true true + + +-- !query +select startswith(utf8_binary, 'aaAaaAaA ' collate utf8_lcase_rtrim), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false true false true true true true @@ -2024,6 +2226,8 @@ struct 11111111 111111111111111111111111 12 +12 +123 123 123 123 @@ -2082,6 +2286,8 @@ efd2 kitten İo İo +İo +İo -- !query @@ -2128,6 +2334,8 @@ efd2 efd2 i̇o İo sitTing kitten İo İo +İo İo +İo İo -- !query @@ -2148,6 +2356,30 @@ efd2 efd2 i̇o İo sitTing kitten İo İo +İo İo +İo İo + + +-- !query +select translate(utf8_lcase, 'aBc ' collate utf8_binary_rtrim, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5 +-- !query schema +struct +-- !query output +1 22121121 +11A11A1A 11111111 +11A11A1A11A11A1A11A11A1A 11111111 +11A1AA1A 11111111 +123DCbA 123d321 +1b3 123 +Hello,4world!4Ni3e4d1y. Hello, world! Ni3e d1y. +SQL Sp1rk +Something4else.4Nothing4here. Something else. Nothing here. +efd2 efd2 +i̇o İo +sitTing kitten +İo İo +İo İo +İo4 İo -- !query @@ -2182,7 +2414,9 @@ abc abc abc abc +abc bbAbaAbA +İo -- !query @@ -2214,10 +2448,12 @@ abc abc abc abc +abc abcdcba bbAbAAbA kitten İo +İo -- !query @@ -2236,8 +2472,10 @@ abc abc abc abc +abc bbabcbabcabcbabc kitten +İo -- !query @@ -2284,6 +2522,8 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo -- !query @@ -2304,6 +2544,30 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo + + +-- !query +select replace(utf8_binary, 'aaAaaAaA ' collate utf8_lcase_rtrim, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. 
+Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA abc +aaAaAAaA abcabcabc +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo +İo İo +İo İo -- !query @@ -2327,6 +2591,8 @@ select endswith(s, utf8_binary) from t5 struct -- !query output false +false +true true true true @@ -2368,6 +2634,8 @@ false false false false +false +false true true true @@ -2384,6 +2652,8 @@ struct +-- !query output +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false true false true true true true @@ -2478,7 +2774,9 @@ abcdcbaabcdcbaabcdcba aBcDCbAaBcDCbA bbAbAAbAbbAbAAbAbbAbAAbA aa efd2efd2efd2 efd2efd2 kittenkittenkitten sitTingsitTing +İo İo İo İoİo İoİoİo i̇oi̇o +İoİoİo İo İo İoİoİo İoİo @@ -2498,7 +2796,9 @@ abcdcbaabcdcbaabcdcba aBcDCbAaBcDCbA bbAbAAbAbbAbAAbAbbAbAAbA aa efd2efd2efd2 efd2efd2 kittenkittenkitten sitTingsitTing +İo İo İo İoİo İoİoİo i̇oi̇o +İoİoİo İo İo İoİoİo İoİo @@ -2511,6 +2811,8 @@ struct 107 115 304 105 304 304 +304 304 +304 304 72 72 83 83 83 83 @@ -2531,6 +2833,8 @@ struct>,sentences(utf8_lcase, , ) [["kitten"]] [["sitTing"]] [["İo"]] [["i̇o"]] [["İo"]] [["İo"]] +[["İo"]] [["İo"]] +[["İo"]] [["İo"]] -- !query @@ -2785,6 +3107,8 @@ struct -- !query output 2 2 2 3 +2 3 23 23 29 29 +3 2 3 3 4 4 5 3 @@ -3128,8 +3480,10 @@ struct 24 24 24 24 24 32 +24 32 +32 24 32 32 40 24 48 56 @@ -3171,6 +3527,8 @@ struct 3 3 3 3 3 4 +3 4 +4 3 4 4 5 3 6 7 @@ -3211,6 +3571,8 @@ struct 0 0 0 +0 +1 1 @@ -3292,6 +3656,8 @@ struct 0 0 1 +1 +1 16 2 4 @@ -3312,6 +3678,8 @@ struct 2 2 2 2 2 3 +2 3 22 22 29 29 +3 2 4 3 4 4 6 6 @@ -3356,6 +3726,8 @@ struct +-- !query output +Hello, w +SQLSpark +Somethin +aaAaAAaA +aaAaAAaA +aaAaAAaA +aabcdcba +abcababc +bbAbAAbA +efd2efd2 +i̇oi̇oİo +sikitten +İo İo İo +İoİoİoİo +İoİoİİo + + -- !query select rpad(utf8_binary, 8, 'a'), rpad(utf8_lcase, 8, 'a') from t5 -- !query schema @@ -3746,7 +4174,9 @@ abcdcbaa aBcDCbAa bbAbAAbA aaaaaaaa efd2aaaa efd2aaaa kittenaa sitTinga +İo aaaaa İoaaaaaa İoaaaaaa i̇oaaaaa +İoaaaaaa İo aaaaa İoaaaaaa İoaaaaaa @@ -3766,7 +4196,9 @@ abcdcbaA aBcDCbAA bbAbAAbA aAAaAAaA efd2AaAA efd2AAaA kittenAa sitTingA +İo AaAAA İoAAaAAa İoAaAAAa i̇oAAaAA +İoAaAAAa İo AAaAA İoAaAAAa İoAAaAAa @@ -3801,6 +4233,8 @@ abcababc bbAbaAbA efd2efd2 kikitten +İo İo İo +İoİoİoİo İoİoİoİo İoİoİoİo @@ -3837,7 +4271,9 @@ bbAbAAbA efd2efd2 i̇oi̇oİo sikitten +İo İo İo İoİoİoİo +İoİoİİo -- !query @@ -3857,7 +4293,31 @@ bbAbAAbA efd2efd2 i̇oi̇oİo sikitten +İo İo İo İoİoİoİo +İoİoİİo + + +-- !query +select lpad(utf8_binary collate utf8_binary_rtrim, 8, utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query schema +struct +-- !query output +Hello, w +SQLSpark +Somethin +aaAaAAaA +aaAaAAaA +aaAaAAaA +aabcdcba +abcababc +bbAbAAbA +efd2efd2 +i̇oi̇oİo +sikitten +İo İo İo +İoİoİoİo +İoİoİİo -- !query @@ -3874,6 +4334,8 @@ aaaSpark aaaaaSQL aaaaaabc aaaaaabc aaaaaaİo aaaaaaİo aaaaaaİo aaaaai̇o +aaaaaaİo aaaaaİo +aaaaaİo aaaaaaİo aaaaefd2 aaaaefd2 aabcdcba aaBcDCbA aakitten asitTing @@ -3888,6 +4350,8 @@ struct 1 1 1 +1 +1 -- !query @@ -3962,6 +4428,8 @@ struct 0 0 0 +0 +1 1 1 1 @@ -3987,6 +4455,8 @@ struct 0 0 0 0 0 0 +0 0 +0 0 0 1 @@ -4054,6 +4526,30 @@ struct +-- !query output +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 -- !query @@ -4107,6 +4603,8 @@ struct + + BcDCbA QL a @@ -4130,6 +4628,8 @@ struct +-- !query output + + + + + + + + + + +BcDCbA +QL +a +i̇ +sitTing + + -- !query select TRIM('ABc', utf8_binary), TRIM('ABc', 
utf8_lcase) from t5 -- !query schema @@ -4178,6 +4700,8 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo -- !query @@ -4198,6 +4722,8 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo -- !query @@ -4232,6 +4758,8 @@ struct + + a @@ -4263,6 +4791,8 @@ struct + + bbAbAAbA d kitte @@ -4284,6 +4814,8 @@ struct +-- !query output + + + + + + + + + + +bbAbAAbA +d +kitte +park +İ + + -- !query select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -4330,6 +4884,8 @@ ABc ABc ABc ABc ABc ABc ABc ABc +ABc ABc +ABc ABc Bc Bc Bc Bc Bc Bc @@ -4348,6 +4904,8 @@ ABc AAa ABc AAa ABc AAa ABc AAa +ABc AAa +ABc AAa B AA Bc Bc @@ -4407,6 +4965,8 @@ struct + + BcDCbA QL a @@ -4430,6 +4990,8 @@ struct +-- !query output + + + + + + + + + + +BcDCbA +QL +a +i̇o +sitTing + + -- !query select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -4478,6 +5062,8 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo -- !query @@ -4498,6 +5084,8 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo -- !query @@ -4523,6 +5111,20 @@ struct + + + + + + + + + + + +İo + + -- !query select RTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 -- !query schema @@ -4551,11 +5153,13 @@ struct + SQL a aBcDCbA i̇ sitTing +İo -- !query @@ -4574,8 +5178,10 @@ struct +-- !query output + + + + + + + + + + +SQL +a +aBcDCbA +i̇ +sitTing + + -- !query select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -4622,6 +5250,8 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo -- !query @@ -4642,6 +5272,8 @@ efd2 efd2 kitten sitTing İo i̇o İo İo +İo İo +İo İo -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 5471dafaec8eb..7eed2d42da043 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1422,9 +1422,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0028", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNSUPPORTED_FROM_TO_EXPRESSION", + "sqlState" : "22006", "messageParameters" : { "from" : "year", + "input" : "1", "to" : "second" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out index cd94674d2bf2b..482a1efb6b095 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out @@ -126,7 +126,7 @@ org.apache.spark.SparkDateTimeException "errorClass" : "INVALID_FRACTION_OF_SECOND", "sqlState" : "22023", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"" + "secAndMicros" : "60.007000" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index daa9e6cf9e0a7..adfc5b703da47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -39,6 +39,7 @@ class CollationSQLExpressionsSuite with ExpressionEvalHelper { private val testSuppCollations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI") + private val 
testAdditionalCollations = Seq("UNICODE", "SR", "SR_CI", "SR_AI", "SR_CI_AI") test("Support Md5 hash expression with collation") { case class Md5TestCase( @@ -55,7 +56,8 @@ class CollationSQLExpressionsSuite Md5TestCase("SQL", "UNICODE", "9778840a0100cb30c982876741b0b5a2"), Md5TestCase("SQL", "UNICODE_RTRIM", "9778840a0100cb30c982876741b0b5a2"), Md5TestCase("SQL", "UNICODE_CI", "9778840a0100cb30c982876741b0b5a2"), - Md5TestCase("SQL", "UNICODE_CI_RTRIM", "9778840a0100cb30c982876741b0b5a2") + Md5TestCase("SQL", "UNICODE_CI_RTRIM", "9778840a0100cb30c982876741b0b5a2"), + Md5TestCase("SQL", "SR_CI_AI", "9778840a0100cb30c982876741b0b5a2") ) // Supported collations @@ -98,6 +100,8 @@ class CollationSQLExpressionsSuite Sha2TestCase("SQL", "UNICODE_CI", 256, "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), Sha2TestCase("SQL", "UNICODE_CI_RTRIM", 256, + "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), + Sha2TestCase("SQL", "SR_AI", 256, "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35") ) @@ -132,7 +136,8 @@ class CollationSQLExpressionsSuite Sha1TestCase("SQL", "UNICODE", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), Sha1TestCase("SQL", "UNICODE_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), Sha1TestCase("SQL", "UNICODE_CI", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), - Sha1TestCase("SQL", "UNICODE_CI_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d") + Sha1TestCase("SQL", "UNICODE_CI_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), + Sha1TestCase("Spark", "SR_CI", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c") ) // Supported collations @@ -549,7 +554,8 @@ class CollationSQLExpressionsSuite HexTestCase("Spark SQL", "UTF8_BINARY", "537061726B2053514C"), HexTestCase("Spark SQL", "UTF8_LCASE", "537061726B2053514C"), HexTestCase("Spark SQL", "UNICODE", "537061726B2053514C"), - HexTestCase("Spark SQL", "UNICODE_CI", "537061726B2053514C") + HexTestCase("Spark SQL", "UNICODE_CI", "537061726B2053514C"), + HexTestCase("Spark SQL", "DE_CI_AI", "537061726B2053514C") ) testCases.foreach(t => { val query = @@ -572,7 +578,8 @@ class CollationSQLExpressionsSuite UnHexTestCase("537061726B2053514C", "UTF8_BINARY", "Spark SQL"), UnHexTestCase("537061726B2053514C", "UTF8_LCASE", "Spark SQL"), UnHexTestCase("537061726B2053514C", "UNICODE", "Spark SQL"), - UnHexTestCase("537061726B2053514C", "UNICODE_CI", "Spark SQL") + UnHexTestCase("537061726B2053514C", "UNICODE_CI", "Spark SQL"), + UnHexTestCase("537061726B2053514C", "DE", "Spark SQL") ) testCases.foreach(t => { val query = @@ -640,7 +647,8 @@ class CollationSQLExpressionsSuite StringSpaceTestCase(1, "UTF8_BINARY", " "), StringSpaceTestCase(2, "UTF8_LCASE", " "), StringSpaceTestCase(3, "UNICODE", " "), - StringSpaceTestCase(4, "UNICODE_CI", " ") + StringSpaceTestCase(4, "UNICODE_CI", " "), + StringSpaceTestCase(5, "AF_CI_AI", " ") ) // Supported collations @@ -1008,7 +1016,11 @@ class CollationSQLExpressionsSuite StringToMapTestCase("1:ax2:bx3:c", "x", ":", "UNICODE", Map("1" -> "a", "2" -> "b", "3" -> "c")), StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI", - Map("1" -> "A", "2" -> "B", "3" -> "C")) + Map("1" -> "A", "2" -> "B", "3" -> "C")), + StringToMapTestCase("1:cx2:čx3:ć", "x", ":", "SR_CI_AI", + Map("1" -> "c", "2" -> "č", "3" -> "ć")), + StringToMapTestCase("c:1,č:2,ć:3", ",", ":", "SR_CI", + Map("c" -> "1", "č" -> "2", "ć" -> "3")) ) val unsupportedTestCases = Seq( StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null), @@ -1082,7 +1094,7 @@ class 
CollationSQLExpressionsSuite test("Support CurrentDatabase/Catalog/User expressions with collation") { // Supported collations - Seq("UTF8_LCASE", "UNICODE", "UNICODE_CI").foreach(collationName => + Seq("UTF8_LCASE", "UNICODE", "UNICODE_CI", "SR_CI_AI").foreach(collationName => withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { val queryDatabase = sql("SELECT current_schema()") val queryCatalog = sql("SELECT current_catalog()") @@ -1098,7 +1110,7 @@ class CollationSQLExpressionsSuite test("Support Uuid misc expression with collation") { // Supported collations - Seq("UTF8_LCASE", "UNICODE", "UNICODE_CI").foreach(collationName => + Seq("UTF8_LCASE", "UNICODE", "UNICODE_CI", "NO_CI_AI").foreach(collationName => withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { val query = s"SELECT uuid()" // Result & data type @@ -1114,7 +1126,7 @@ class CollationSQLExpressionsSuite test("Support SparkVersion misc expression with collation") { // Supported collations - Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI").foreach(collationName => + Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI", "DE").foreach(collationName => withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { val query = s"SELECT version()" // Result & data type @@ -1689,7 +1701,7 @@ class CollationSQLExpressionsSuite test("Support InputFileName expression with collation") { // Supported collations - Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI").foreach(collationName => { + Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI", "MT_CI_AI").foreach(collationName => { val query = s""" |select input_file_name() @@ -1737,7 +1749,7 @@ class CollationSQLExpressionsSuite } test("Support mode for string expression with collation - Basic Test") { - Seq("utf8_binary", "UTF8_LCASE", "unicode_ci", "unicode").foreach { collationId => + Seq("utf8_binary", "UTF8_LCASE", "unicode_ci", "unicode", "NL_AI").foreach { collationId => val query = s"SELECT mode(collate('abc', '${collationId}'))" checkAnswer(sql(query), Row("abc")) assert(sql(query).schema.fields.head.dataType.sameType(StringType(collationId))) @@ -1750,7 +1762,8 @@ class CollationSQLExpressionsSuite ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), - ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a") + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("SR", Map("c" -> 3L, "č" -> 2L, "Č" -> 2L), "c") ) testCases.foreach(t => { val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => @@ -2434,829 +2447,637 @@ class CollationSQLExpressionsSuite ) } - test("min_by supports collation") { - val collation = "UNICODE" - val query = s"SELECT min_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y);" + // common method for subsequent tests verifying various SQL expressions with collations + private def testCollationSqlExpressionCommon( + query: String, + collation: String, + result: Row, + expectedType: DataType): Unit = { + testCollationSqlExpressionCommon(query, collation, Seq(result), Seq(expectedType)) + } + + // common method for subsequent tests verifying various SQL expressions with collations + private def testCollationSqlExpressionCommon( + query: String, + collation: String, + result: Seq[Row], + expectedTypes: Seq[DataType]): Unit = { withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - 
sql(query), - Seq( - Row("a") - ) + // check result correctness + checkAnswer(sql(query), result) + // check result rows data types + for (i <- 0 until expectedTypes.length) + assert(sql(query).schema(i).dataType == expectedTypes(i)) + } + } + + test("min_by supports collation") { + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT min_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y)", + collation, + result = Row("a"), + expectedType = StringType(collation) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == dataType) } } test("max_by supports collation") { - val collation = "UNICODE" - val query = s"SELECT max_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y);" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row("b") - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT max_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y)", + collation, + result = Row("b"), + expectedType = StringType(collation) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == dataType) } } test("array supports collation") { - val collation = "UNICODE" - val query = s"SELECT array('a', 'b', 'c');" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b", "c")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array('a', 'b', 'c')", + collation, + result = Row(Seq("a", "b", "c")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("array_agg supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_agg(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col);" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b", "c")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_agg(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col)", + collation, + result = Row(Seq("a", "b", "c")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("array_contains supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_contains(array('a', 'b', 'c'), 'b');" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(true) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_contains(array('a', 'b', 'c'), 'b')", + collation, + result = Row(true), + expectedType = BooleanType ) - // check result row data type - val dataType = BooleanType - assert(sql(query).schema.head.dataType == dataType) } } test("arrays_overlap supports collation") { - val collation = "UNICODE" - val query = s"SELECT arrays_overlap(array('a', 'b', 'c'), array('c', 'd', 'e'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(true) - ) + testAdditionalCollations.foreach { collation => + 
testCollationSqlExpressionCommon( + query = "SELECT arrays_overlap(array('a', 'b', 'c'), array('c', 'd', 'e'))", + collation, + result = Row(true), + expectedType = BooleanType ) - // check result row data type - val dataType = BooleanType - assert(sql(query).schema.head.dataType == dataType) } } test("array_insert supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_insert(array('a', 'b', 'c', 'd'), 5, 'e');" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b", "c", "d", "e")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_insert(array('a', 'b', 'c', 'd'), 5, 'e')", + collation, + result = Row(Seq("a", "b", "c", "d", "e")), + expectedType = ArrayType(StringType(collation), true) ) - // check result row data type - val dataType = ArrayType(StringType(collation), true) - assert(sql(query).schema.head.dataType == dataType) } } test("array_intersect supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_intersect(array('a', 'b', 'c'), array('b', 'c', 'd'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("b", "c")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_intersect(array('a', 'b', 'c'), array('b', 'c', 'd'))", + collation, + result = Row(Seq("b", "c")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("array_join supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_join(array('hello', 'world'), ' ');" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row("hello world") - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_join(array('hello', 'world'), ' ')", + collation, + result = Row("hello world"), + expectedType = StringType(collation) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == dataType) } } test("array_position supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_position(array('a', 'b', 'c', 'c'), 'c');" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(3) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_position(array('a', 'b', 'c', 'c'), 'c')", + collation, + result = Row(3), + expectedType = LongType ) - // check result row data type - val dataType = LongType - assert(sql(query).schema.head.dataType == dataType) } } test("array_size supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_size(array('a', 'b', 'c', 'c'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(4) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_size(array('a', 'b', 'c', 'c'))", + collation, + result = Row(4), + expectedType = IntegerType ) - // check result row data type - val dataType = IntegerType - assert(sql(query).schema.head.dataType == dataType) } } test("array_sort supports collation") { - val collation = "UNICODE" - val query = s"SELECT 
array_sort(array('b', null, 'A'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("A", "b", null)) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_sort(array('b', null, 'A'))", + collation, + result = Row(Seq("A", "b", null)), + expectedType = ArrayType(StringType(collation), true) ) - // check result row data type - val dataType = ArrayType(StringType(collation), true) - assert(sql(query).schema.head.dataType == dataType) } } test("array_except supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_except(array('a', 'b', 'c'), array('c', 'd', 'e'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_except(array('a', 'b', 'c'), array('c', 'd', 'e'))", + collation, + result = Row(Seq("a", "b")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("array_union supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_union(array('a', 'b', 'c'), array('a', 'c', 'd'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b", "c", "d")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_union(array('a', 'b', 'c'), array('a', 'c', 'd'))", + collation, + result = Row(Seq("a", "b", "c", "d")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("array_compact supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_compact(array('a', 'b', null, 'c'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b", "c")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_compact(array('a', 'b', null, 'c'))", + collation, + result = Row(Seq("a", "b", "c")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("arrays_zip supports collation") { - val collation = "UNICODE" - val query = s"SELECT arrays_zip(array('a', 'b', 'c'), array(1, 2, 3));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq(Row("a", 1), Row("b", 2), Row("c", 3))) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT arrays_zip(array('a', 'b', 'c'), array(1, 2, 3))", + collation, + result = Row(Seq(Row("a", 1), Row("b", 2), Row("c", 3))), + expectedType = ArrayType(StructType( + StructField("0", StringType(collation), true) :: + StructField("1", IntegerType, true) :: Nil + ), false) ) - // check result row data type - val dataType = ArrayType(StructType( - StructField("0", StringType(collation), true) :: - StructField("1", IntegerType, true) :: Nil - ), false) - assert(sql(query).schema.head.dataType == dataType) } } test("array_min supports collation") { - 
val collation = "UNICODE" - val query = s"SELECT array_min(array('a', 'b', null, 'c'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row("a") - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_min(array('a', 'b', null, 'c'))", + collation, + result = Row("a"), + expectedType = StringType(collation) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == dataType) } } test("array_max supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_max(array('a', 'b', null, 'c'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row("c") - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_max(array('a', 'b', null, 'c'))", + collation, + result = Row("c"), + expectedType = StringType(collation) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == dataType) } } test("array_append supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_append(array('b', 'd', 'c', 'a'), 'e');" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("b", "d", "c", "a", "e")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_append(array('b', 'd', 'c', 'a'), 'e')", + collation, + result = Row(Seq("b", "d", "c", "a", "e")), + expectedType = ArrayType(StringType(collation), true) ) - // check result row data type - val dataType = ArrayType(StringType(collation), true) - assert(sql(query).schema.head.dataType == dataType) } } test("array_repeat supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_repeat('abc', 2);" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("abc", "abc")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_repeat('abc', 2)", + collation, + result = Row(Seq("abc", "abc")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("array_remove supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_remove(array('a', 'b', null, 'c'), 'b');" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", null, "c")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_remove(array('a', 'b', null, 'c'), 'b')", + collation, + result = Row(Seq("a", null, "c")), + expectedType = ArrayType(StringType(collation), true) ) - // check result row data type - val dataType = ArrayType(StringType(collation), true) - assert(sql(query).schema.head.dataType == dataType) } } test("array_prepend supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd');" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("d", "b", "d", "c", "a")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd')", + 
collation, + result = Row(Seq("d", "b", "d", "c", "a")), + expectedType = ArrayType(StringType(collation), true) ) - // check result row data type - val dataType = ArrayType(StringType(collation), true) - assert(sql(query).schema.head.dataType == dataType) } } test("array_distinct supports collation") { - val collation = "UNICODE" - val query = s"SELECT array_distinct(array('a', 'b', 'c', null, 'c'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b", "c", null)) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT array_distinct(array('a', 'b', 'c', null, 'c'))", + collation, + result = Row(Seq("a", "b", "c", null)), + expectedType = ArrayType(StringType(collation), true) ) - // check result row data type - val dataType = ArrayType(StringType(collation), true) - assert(sql(query).schema.head.dataType == dataType) } } test("collect_list supports collation") { - val collation = "UNICODE" - val query = s"SELECT collect_list(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col);" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b", "c")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT collect_list(col) FROM VALUES ('a'), ('b'), ('c') AS tab(col)", + collation, + result = Row(Seq("a", "b", "c")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("collect_set does not support collation") { - val collation = "UNICODE" - val query = s"SELECT collect_set(col) FROM VALUES ('a'), ('b'), ('a') AS tab(col);" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkError( - exception = intercept[AnalysisException] { - sql(query) - }, - condition = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", - sqlState = Some("42K09"), - parameters = Map( - "functionName" -> "`collect_set`", - "dataType" -> "\"MAP\" or \"COLLATED STRING\"", - "sqlExpr" -> "\"collect_set(col)\""), - context = ExpectedContext( - fragment = "collect_set(col)", - start = 7, - stop = 22)) + testAdditionalCollations.foreach { collation => + val query = "SELECT collect_set(col) FROM VALUES ('a'), ('b'), ('a') AS tab(col);" + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { + checkError( + exception = intercept[AnalysisException] { + sql(query) + }, + condition = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "functionName" -> "`collect_set`", + "dataType" -> "\"MAP\" or \"COLLATED STRING\"", + "sqlExpr" -> "\"collect_set(col)\""), + context = ExpectedContext( + fragment = "collect_set(col)", + start = 7, + stop = 22)) + } } } test("element_at supports collation") { - val collation = "UNICODE" - val query = s"SELECT element_at(array('a', 'b', 'c'), 2);" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row("b") - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT element_at(array('a', 'b', 'c'), 2)", + collation, + result = Row("b"), + expectedType = StringType(collation) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == dataType) } } test("aggregate supports collation") { - val collation = "UNICODE" - val query = s"SELECT 
aggregate(array('a', 'b', 'c'), '', (acc, x) -> concat(acc, x));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row("abc") - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT aggregate(array('a', 'b', 'c'), '', (acc, x) -> concat(acc, x))", + collation, + result = Row("abc"), + expectedType = StringType(collation) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == dataType) } } test("explode supports collation") { - val collation = "UNICODE" - val query = s"SELECT explode(array('a', 'b'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT explode(array('a', 'b'))", + collation, + result = Seq( Row("a"), Row("b") + ), + expectedTypes = Seq( + StringType(collation) ) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == dataType) } } test("posexplode supports collation") { - val collation = "UNICODE" - val query = s"SELECT posexplode(array('a', 'b'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT posexplode(array('a', 'b'))", + collation, + result = Seq( Row(0, "a"), Row(1, "b") + ), + expectedTypes = Seq( + IntegerType, + StringType(collation) ) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == IntegerType) - assert(sql(query).schema(1).dataType == dataType) } } test("filter supports collation") { - val collation = "UNICODE" - val query = s"SELECT filter(array('a', 'b', 'c'), x -> x < 'b');" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT filter(array('a', 'b', 'c'), x -> x < 'b')", + collation, + result = Row(Seq("a")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("flatten supports collation") { - val collation = "UNICODE" - val query = s"SELECT flatten(array(array('a', 'b'), array('c', 'd')));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b", "c", "d")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT flatten(array(array('a', 'b'), array('c', 'd')))", + collation, + result = Row(Seq("a", "b", "c", "d")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("inline supports collation") { - val collation = "UNICODE" - val query = s"SELECT inline(array(struct(1, 'a'), struct(2, 'b')));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT inline(array(struct(1, 'a'), struct(2, 'b')))", + collation, Seq( Row(1, "a"), Row(2, "b") + ), + expectedTypes = 
Seq( + IntegerType, + StringType(collation) ) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == IntegerType) - assert(sql(query).schema(1).dataType == dataType) } } test("shuffle supports collation") { - val collation = "UNICODE" - val query = s"SELECT shuffle(array('a', 'b', 'c', 'd'));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) + testAdditionalCollations.foreach { collation => + val query = "SELECT shuffle(array('a', 'b', 'c', 'd'));" + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { + // check result row data type + val dataType = ArrayType(StringType(collation), false) + assert(sql(query).schema.head.dataType == dataType) + } } } test("slice supports collation") { - val collation = "UNICODE" - val query = s"SELECT slice(array('a', 'b', 'c', 'd'), 2, 2);" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("b", "c")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT slice(array('a', 'b', 'c', 'd'), 2, 2)", + collation, + result = Row(Seq("b", "c")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("sort_array supports collation") { - val collation = "UNICODE" - val query = s"SELECT sort_array(array('b', 'd', null, 'c', 'a'), true);" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq(null, "a", "b", "c", "d")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT sort_array(array('b', 'd', null, 'c', 'a'), true)", + collation, + result = Row(Seq(null, "a", "b", "c", "d")), + expectedType = ArrayType(StringType(collation), true) ) - // check result row data type - val dataType = ArrayType(StringType(collation), true) - assert(sql(query).schema.head.dataType == dataType) } } test("zip_with supports collation") { - val collation = "UNICODE" - val query = s"SELECT zip_with(array('a', 'b'), array('x', 'y'), (x, y) -> concat(x, y));" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("ax", "by")) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT zip_with(array('a', 'b'), array('x', 'y'), (x, y) -> concat(x, y))", + collation, + result = Row(Seq("ax", "by")), + expectedType = ArrayType( + StringType(collation), + containsNull = true ) ) - // check result row data type - val dataType = ArrayType( - StringType(collation), - containsNull = true - ) - assert(sql(query).schema.head.dataType == dataType) } } test("map_contains_key supports collation") { - val collation = "UNICODE" - val query = s"SELECT map_contains_key(map('a', 1, 'b', 2), 'a')" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(true) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT map_contains_key(map('a', 1, 'b', 2), 'a')", + collation, + result = Row(true), + expectedType = BooleanType ) - // check result row data type - val dataType = BooleanType - assert(sql(query).schema.head.dataType == dataType) } } 
test("map_from_arrays supports collation") { - val collation = "UNICODE" - val query = s"SELECT map_from_arrays(array('a','b','c'), array(1,2,3))" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Map("a" -> 1, "b" -> 2, "c" -> 3)) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT map_from_arrays(array('a','b','c'), array(1,2,3))", + collation, + result = Row(Map("a" -> 1, "b" -> 2, "c" -> 3)), + expectedType = MapType( + StringType(collation), + IntegerType, false ) ) - // check result row data type - val dataType = MapType( - StringType(collation), - IntegerType, false - ) - assert(sql(query).schema.head.dataType == dataType) } } test("map_keys supports collation") { - val collation = "UNICODE" - val query = s"SELECT map_keys(map('a', 1, 'b', 2))" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT map_keys(map('a', 1, 'b', 2))", + collation, + result = Row(Seq("a", "b")), + expectedType = ArrayType(StringType(collation), true) ) - // check result row data type - val dataType = ArrayType( - StringType(collation), true - ) - assert(sql(query).schema.head.dataType == dataType) } } test("map_values supports collation") { - val collation = "UNICODE" - val query = s"SELECT map_values(map(1, 'a', 2, 'b'))" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b")) - ) - ) - // check result row data type - val dataType = ArrayType( - StringType(collation), true + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT map_values(map(1, 'a', 2, 'b'))", + collation, + result = Row(Seq("a", "b")), + expectedType = ArrayType(StringType(collation), true) ) - assert(sql(query).schema.head.dataType == dataType) } } test("map_entries supports collation") { - val collation = "UNICODE" - val query = s"SELECT map_entries(map('a', 1, 'b', 2))" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq(Row("a", 1), Row("b", 2))) - ) - ) - // check result row data type - val dataType = ArrayType(StructType( + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT map_entries(map('a', 1, 'b', 2))", + collation, + result = Row(Seq(Row("a", 1), Row("b", 2))), + expectedType = ArrayType(StructType( StructField("key", StringType(collation), false) :: - StructField("value", IntegerType, false) :: Nil - ), false) - assert(sql(query).schema.head.dataType == dataType) + StructField("value", IntegerType, false) :: Nil + ), false) + ) } } test("map_from_entries supports collation") { - val collation = "UNICODE" - val query = s"SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b')))" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Map(1 -> "a", 2 -> "b")) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b')))", + collation, + result = Row(Map(1 -> "a", 2 -> "b")), + expectedType = MapType( + IntegerType, + StringType(collation), + valueContainsNull = false ) ) - // check result row data type - val dataType = MapType( - IntegerType, - StringType(collation), - valueContainsNull = false - ) - 
assert(sql(query).schema.head.dataType == dataType) } } test("map_concat supports collation") { - val collation = "UNICODE" - val query = s"SELECT map_concat(map(1, 'a'), map(2, 'b'))" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Map(1 -> "a", 2 -> "b")) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT map_concat(map(1, 'a'), map(2, 'b'))", + collation, + result = Row(Map(1 -> "a", 2 -> "b")), + expectedType = MapType( + IntegerType, + StringType(collation), + valueContainsNull = false ) ) - // check result row data type - val dataType = MapType( - IntegerType, - StringType(collation), - valueContainsNull = false - ) - assert(sql(query).schema.head.dataType == dataType) } } test("map_filter supports collation") { - val collation = "UNICODE" - val query = s"SELECT map_filter(map('a', 1, 'b', 2, 'c', 3), (k, v) -> k < 'c')" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Map("a" -> 1, "b" -> 2)) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT map_filter(map('a', 1, 'b', 2, 'c', 3), (k, v) -> k < 'c')", + collation, + result = Row(Map("a" -> 1, "b" -> 2)), + expectedType = MapType( + StringType(collation), + IntegerType, + valueContainsNull = false ) ) - // check result row data type - val dataType = MapType( - StringType(collation), - IntegerType, - valueContainsNull = false - ) - assert(sql(query).schema.head.dataType == dataType) } } test("map_zip_with supports collation") { - val collation = "UNICODE" - val query = s"SELECT map_zip_with(map(1, 'a'), map(1, 'x'), (k, v1, v2) -> concat(v1, v2))" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Map(1 -> "ax")) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT map_zip_with(map(1, 'a'), map(1, 'x'), (k, v1, v2) -> concat(v1, v2))", + collation, + result = Row(Map(1 -> "ax")), + expectedType = MapType( + IntegerType, + StringType(collation), + valueContainsNull = true ) ) - // check result row data type - val dataType = MapType( - IntegerType, - StringType(collation), - valueContainsNull = true - ) - assert(sql(query).schema.head.dataType == dataType) } } test("transform supports collation") { - val collation = "UNICODE" - val query = s"SELECT transform(array('aa', 'bb', 'cc'), x -> substring(x, 2))" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Seq("a", "b", "c")) - ) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT transform(array('aa', 'bb', 'cc'), x -> substring(x, 2))", + collation, + result = Row(Seq("a", "b", "c")), + expectedType = ArrayType(StringType(collation), false) ) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("transform_values supports collation") { - val collation = "UNICODE" - val query = s"SELECT transform_values(map_from_arrays(array(1, 2, 3)," + - s"array('aa', 'bb', 'cc')), (k, v) -> substring(v, 2))" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Map(1 -> "a", 2 -> "b", 3 -> "c")) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT transform_values(map_from_arrays(array(1, 2, 3)," + + 
"array('aa', 'bb', 'cc')), (k, v) -> substring(v, 2))", + collation, + result = Row(Map(1 -> "a", 2 -> "b", 3 -> "c")), + expectedType = MapType( + IntegerType, + StringType(collation), + false ) ) - // check result row data type - val dataType = MapType(IntegerType, - StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) } } test("transform_keys supports collation") { - val collation = "UNICODE" - val query = s"SELECT transform_keys(map_from_arrays(array('aa', 'bb', 'cc')," + - s"array(1, 2, 3)), (k, v) -> substring(k, 2))" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( - Row(Map("a" -> 1, "b" -> 2, "c" -> 3)) + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT transform_keys(map_from_arrays(array('aa', 'bb', 'cc')," + + "array(1, 2, 3)), (k, v) -> substring(k, 2))", + collation, + result = Row(Map("a" -> 1, "b" -> 2, "c" -> 3)), + expectedType = MapType( + StringType(collation), + IntegerType, + false ) ) - // check result row data type - val dataType = MapType( - StringType(collation), IntegerType, false - ) - assert(sql(query).schema.head.dataType == dataType) } } test("stack supports collation") { - val query = s"SELECT stack(2, 'a', 'b', 'c')" - val collation = "UNICODE" - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - checkAnswer( - sql(query), - Seq( + testAdditionalCollations.foreach { collation => + testCollationSqlExpressionCommon( + query = "SELECT stack(2, 'a', 'b', 'c')", + collation, + result = Seq( Row("a", "b"), Row("c", null) + ), + expectedTypes = Seq( + StringType(collation) ) ) - // check result row data type - val dataType = StringType(collation) - assert(sql(query).schema.head.dataType == dataType) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 6db30a7ed0c6f..9ee2cfb964feb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -38,9 +38,15 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_BINARY", "Spark SQL"), + ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_BINARY_RTRIM", "Spark SQL"), ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_LCASE", "Spark SQL"), + ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_LCASE_RTRIM", "Spark SQL"), ConcatWsTestCase(" ", Array("Spark", "SQL"), "UNICODE", "Spark SQL"), - ConcatWsTestCase(" ", Array("Spark", "SQL"), "UNICODE_CI", "Spark SQL") + ConcatWsTestCase(" ", Array("Spark", "SQL"), "UNICODE_RTRIM", "Spark SQL"), + ConcatWsTestCase(" ", Array("Spark", "SQL"), "UNICODE_CI", "Spark SQL"), + ConcatWsTestCase(" ", Array("Spark", "SQL"), "UNICODE_CI_RTRIM", "Spark SQL"), + ConcatWsTestCase(" ", Array("Spark", "SQL"), "UNICODE_CI", "Spark SQL"), + ConcatWsTestCase(" ", Array("Spark", "Unterstützung"), "DE_CI_AI", "Spark Unterstützung") ) testCases.foreach(t => { // Unit test. 
@@ -64,9 +70,15 @@ class CollationStringExpressionsSuite case class EltTestCase[R](index: Integer, inputs: Array[String], collation: String, result: R) val testCases = Seq( EltTestCase(1, Array("Spark", "SQL"), "UTF8_BINARY", "Spark"), + EltTestCase(1, Array("Spark", "SQL"), "UTF8_BINARY_RTRIM", "Spark"), EltTestCase(1, Array("Spark", "SQL"), "UTF8_LCASE", "Spark"), + EltTestCase(1, Array("Spark", "SQL"), "UTF8_LCASE_RTRIM", "Spark"), EltTestCase(2, Array("Spark", "SQL"), "UNICODE", "SQL"), - EltTestCase(2, Array("Spark", "SQL"), "UNICODE_CI", "SQL") + EltTestCase(2, Array("Spark", "SQL"), "UNICODE_RTRIM", "SQL"), + EltTestCase(2, Array("Spark", "SQL"), "UNICODE_CI", "SQL"), + EltTestCase(2, Array("Spark", "SQL"), "UNICODE_CI_RTRIM", "SQL"), + EltTestCase(2, Array("Spark", "SQL"), "UNICODE_CI", "SQL"), + EltTestCase(2, Array("Spark", "Unterstützung"), "DE_CI", "Unterstützung") ) testCases.foreach(t => { // Unit test. @@ -94,9 +106,15 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( SplitPartTestCase("1a2", "a", 2, "UTF8_BINARY", "2"), + SplitPartTestCase("1a2", "a ", 1, "UTF8_BINARY_RTRIM", "1"), SplitPartTestCase("1a2", "a", 2, "UNICODE", "2"), + SplitPartTestCase("1a 2", "a ", 2, "UNICODE_RTRIM", " 2"), SplitPartTestCase("1a2", "A", 2, "UTF8_LCASE", "2"), - SplitPartTestCase("1a2", "A", 2, "UNICODE_CI", "2") + SplitPartTestCase("1 a2", "A ", 2, "UTF8_LCASE_RTRIM", "2"), + SplitPartTestCase("1a2", "A", 2, "UNICODE_CI", "2"), + SplitPartTestCase("1 a2 ", "A ", 2, "UNICODE_CI_RTRIM", "2 "), + SplitPartTestCase("1a2", "A", 2, "UNICODE_CI", "2"), + SplitPartTestCase("1ö2", "O", 2, "DE_CI_AI", "2") ) val unsupportedTestCase = SplitPartTestCase("1a2", "a", 2, "UNICODE_AI", "2") testCases.foreach(t => { @@ -142,9 +160,16 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( StringSplitSQLTestCase("1a2", "a", "UTF8_BINARY", Array("1", "2")), + StringSplitSQLTestCase("1a2", "a ", "UTF8_BINARY_RTRIM", Array("1", "2")), StringSplitSQLTestCase("1a2", "a", "UNICODE", Array("1", "2")), + StringSplitSQLTestCase("1a 2", "a ", "UNICODE_RTRIM", Array("1", " 2")), StringSplitSQLTestCase("1a2", "A", "UTF8_LCASE", Array("1", "2")), - StringSplitSQLTestCase("1a2", "A", "UNICODE_CI", Array("1", "2")) + StringSplitSQLTestCase("1 a2", "A ", "UTF8_LCASE_RTRIM", Array("1 ", "2")), + StringSplitSQLTestCase("1a2", "A", "UNICODE_CI", Array("1", "2")), + StringSplitSQLTestCase("1 a2 ", "A ", "UNICODE_CI_RTRIM", Array("1 ", "2 ")), + StringSplitSQLTestCase("1a2", "A", "UNICODE_CI", Array("1", "2")), + StringSplitSQLTestCase("1ä2", "Ä", "DE_CI", Array("1", "2")), + StringSplitSQLTestCase("1ä2", "A", "DE_CI_AI", Array("1", "2")) ) testCases.foreach(t => { // Unit test. 
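The _RTRIM rows in the split_part and split hunks above all follow the same pattern: trailing spaces are ignored when the delimiter is matched, but preserved in the returned pieces. A hedged, standalone sketch of the UNICODE_RTRIM split_part row, expressed as a plain query rather than through the suite's own harness (assumes a local SparkSession and that the RTRIM collation variants from this diff are available):

```scala
import org.apache.spark.sql.SparkSession

// Standalone sketch: under an RTRIM collation the delimiter 'a ' matches the 'a' in '1a 2',
// and the trailing space of the extracted piece is kept, mirroring
// SplitPartTestCase("1a 2", "a ", 2, "UNICODE_RTRIM", " 2") in the hunk above.
object SplitPartRtrimSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("split-part-rtrim").getOrCreate()
    val df = spark.sql(
      "SELECT split_part(collate('1a 2', 'UNICODE_RTRIM'), collate('a ', 'UNICODE_RTRIM'), 2) AS part")
    df.show(truncate = false) // expected, per the test case above: " 2"
    spark.stop()
  }
}
```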
@@ -187,9 +212,20 @@ class CollationStringExpressionsSuite case class ContainsTestCase[R](left: String, right: String, collation: String, result: R) val testCases = Seq( ContainsTestCase("", "", "UTF8_BINARY", true), + ContainsTestCase("", " ", "UTF8_BINARY_RTRIM", true), ContainsTestCase("abcde", "C", "UNICODE", false), + ContainsTestCase("abcde", " C ", "UNICODE_RTRIM", false), ContainsTestCase("abcde", "FGH", "UTF8_LCASE", false), - ContainsTestCase("abcde", "BCD", "UNICODE_CI", true) + ContainsTestCase("abcde", "ABC ", "UTF8_LCASE_RTRIM", true), + ContainsTestCase("abcde", "BCD", "UNICODE_CI", true), + ContainsTestCase("ab c de ", "B C D ", "UNICODE_CI_RTRIM", true), + ContainsTestCase("abcde", "BCD", "UNICODE_CI", true), + ContainsTestCase("Priča o Maču u kamenu", "MAC", "SR_CI_AI", true), + ContainsTestCase("Priča o Maču u kamenu", "MAC", "SR_CI", false), + ContainsTestCase("Priča o Maču u kamenu", "MAČ", "SR", false), + ContainsTestCase("Priča o Maču u kamenu", "Mač", "SR", true), + ContainsTestCase("Прича о Мачу у камену", "мач", "sr_Cyrl_CI_AI", true), + ContainsTestCase("Прича о Мачу у камену", "мац", "sr_Cyrl_CI_AI", false) ) val unsupportedTestCase = ContainsTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { @@ -235,9 +271,16 @@ class CollationStringExpressionsSuite val testCases = Seq( SubstringIndexTestCase("wwwgapachegorg", "g", -3, "UTF8_BINARY", "apachegorg"), SubstringIndexTestCase("www||apache||org", "||", 2, "UTF8_BINARY", "www||apache"), + SubstringIndexTestCase("wwwgapachegorg", "g ", -3, "UTF8_BINARY_RTRIM", "apachegorg"), + SubstringIndexTestCase("www ||apache||org", "|| ", 2, "UTF8_BINARY_RTRIM", "www ||apache"), SubstringIndexTestCase("wwwXapacheXorg", "x", 2, "UTF8_LCASE", "wwwXapache"), + SubstringIndexTestCase("AAA ", "a ", -2, "UTF8_LCASE_RTRIM", "A "), SubstringIndexTestCase("aaaaaaaaaa", "aa", 2, "UNICODE", "a"), - SubstringIndexTestCase("wwwmapacheMorg", "M", -2, "UNICODE_CI", "apacheMorg") + SubstringIndexTestCase("aaaaaaaaaa ", "aa ", 2, "UNICODE_RTRIM", "a"), + SubstringIndexTestCase("wwwmapacheMorg", "M", -2, "UNICODE_CI", "apacheMorg"), + SubstringIndexTestCase("AA A ", "a ", -2, "UNICODE_CI_RTRIM", " A "), + SubstringIndexTestCase("wwwmapacheMorg", "M", -2, "UNICODE_CI", "apacheMorg"), + SubstringIndexTestCase("wwwüapacheüorg", "U", 2, "DE_CI_AI", "wwwüapache") ) val unsupportedTestCase = SubstringIndexTestCase("abacde", "a", 2, "UNICODE_AI", "cde") testCases.foreach(t => { @@ -282,11 +325,16 @@ class CollationStringExpressionsSuite case class StringInStrTestCase[R](str: String, substr: String, collation: String, result: R) val testCases = Seq( StringInStrTestCase("test大千世界X大千世界", "大千", "UTF8_BINARY", 5), + StringInStrTestCase("test大千世界X大千世界", "大千 ", "UTF8_BINARY_RTRIM", 5), StringInStrTestCase("test大千世界X大千世界", "界x", "UTF8_LCASE", 8), + StringInStrTestCase(" test大千世界X大千世界 ", "界x ", "UTF8_LCASE_RTRIM", 9), StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE", 0), + StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_RTRIM", 0), StringInStrTestCase("test大千世界X大千世界", "界y", "UNICODE_CI", 0), StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8), - StringInStrTestCase("abİo12", "i̇o", "UNICODE_CI", 3) + StringInStrTestCase("abİo12", "i̇o", "UNICODE_CI", 3), + StringInStrTestCase("test大千世界X大千世界", "大 ", "UNICODE_CI_RTRIM", 5), + StringInStrTestCase("test大千世界X大千世界", " 大 ", "UNICODE_CI_RTRIM", 0) ) val unsupportedTestCase = StringInStrTestCase("a", "abcde", "UNICODE_AI", 0) testCases.foreach(t => { @@ -326,10 +374,17 @@ class 
CollationStringExpressionsSuite case class FindInSetTestCase[R](left: String, right: String, collation: String, result: R) val testCases = Seq( FindInSetTestCase("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0), + FindInSetTestCase("b ", "abc,b,ab,c,def", "UTF8_BINARY_RTRIM", 2), + FindInSetTestCase("def", "abc,b,ab,c,def ", "UTF8_BINARY_RTRIM", 5), FindInSetTestCase("C", "abc,b,ab,c,def", "UTF8_LCASE", 4), + FindInSetTestCase("C ", "abc,b,ab,c ,def", "UTF8_LCASE_RTRIM", 4), FindInSetTestCase("d,ef", "abc,b,ab,c,def", "UNICODE", 0), + FindInSetTestCase(" def", "abc,b,ab,c,def", "UNICODE_RTRIM", 0), FindInSetTestCase("i̇o", "ab,İo,12", "UNICODE_CI", 2), - FindInSetTestCase("İo", "ab,i̇o,12", "UNICODE_CI", 2) + FindInSetTestCase("İo", "ab,i̇o,12", "UNICODE_CI", 2), + FindInSetTestCase("İo", "ab,i̇o,12", "UNICODE_CI", 2), + FindInSetTestCase("a", "A ,B ,C", "UNICODE_CI_RTRIM", 1), + FindInSetTestCase(" a", "A ,B ,C", "UNICODE_CI_RTRIM", 0) ) testCases.foreach(t => { // Unit test. @@ -349,9 +404,16 @@ class CollationStringExpressionsSuite case class StartsWithTestCase[R](left: String, right: String, collation: String, result: R) val testCases = Seq( StartsWithTestCase("", "", "UTF8_BINARY", true), + StartsWithTestCase("", " ", "UTF8_BINARY_RTRIM", true), StartsWithTestCase("abcde", "A", "UNICODE", false), + StartsWithTestCase("abcde", "a ", "UNICODE_RTRIM", true), StartsWithTestCase("abcde", "FGH", "UTF8_LCASE", false), - StartsWithTestCase("abcde", "ABC", "UNICODE_CI", true) + StartsWithTestCase("abcde ", "FGH ", "UTF8_LCASE_RTRIM", false), + StartsWithTestCase("abcde", "ABC", "UNICODE_CI", true), + StartsWithTestCase("a b c de ", "A B C ", "UNICODE_CI_RTRIM", true), + StartsWithTestCase("abcde", "ABC", "UNICODE_CI", true), + StartsWithTestCase("Šuma", "šum", "SR_CI_AI", true), + StartsWithTestCase("Šuma", "šum", "SR", false) ) val unsupportedTestCase = StartsWithTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { @@ -396,9 +458,15 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( StringTranslateTestCase("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae"), + StringTranslateTestCase(" abc ", "abc", "123", "UTF8_BINARY_RTRIM", " 123 "), StringTranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), + StringTranslateTestCase(" abc ", " AB", "123", "UTF8_LCASE_RTRIM", "123c1"), StringTranslateTestCase("Translate", "Rn", "\u0000\u0000", "UNICODE", "Traslate"), - StringTranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate") + StringTranslateTestCase(" a b c ", "abc ", "1234", "UNICODE_RTRIM", "4142434"), + StringTranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate"), + StringTranslateTestCase(" abc ", "AB ", "123", "UNICODE_CI_RTRIM", "312c3"), + StringTranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate"), + StringTranslateTestCase("Êtèréêë", "te", "12", "AF_CI_AI", "212r222") ) val unsupportedTestCase = StringTranslateTestCase("ABC", "AB", "12", "UNICODE_AI", "12C") testCases.foreach(t => { @@ -446,11 +514,16 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( StringReplaceTestCase("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"), + StringReplaceTestCase(" abc ", "b ", "x", "UTF8_BINARY_RTRIM", " abc "), StringReplaceTestCase("repl世ace", "PL", "AB", "UTF8_LCASE", "reAB世ace"), + StringReplaceTestCase(" abc ", " AB", "123", "UTF8_LCASE_RTRIM", "123c "), StringReplaceTestCase("abcdabcd", "bc", "", "UNICODE", "adad"), + StringReplaceTestCase(" abc ", "b ", "x", 
"UNICODE_RTRIM", " abc "), StringReplaceTestCase("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"), StringReplaceTestCase("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"), - StringReplaceTestCase("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx") + StringReplaceTestCase("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx"), + StringReplaceTestCase(" ABC ", "bc ", "123", "UNICODE_CI_RTRIM", " A123"), + StringReplaceTestCase("češalj", "eSal", "A", "SR_CI_AI", "čAj") ) val unsupportedTestCase = StringReplaceTestCase("abcde", "A", "B", "UNICODE_AI", "abcde") testCases.foreach(t => { @@ -493,9 +566,20 @@ class CollationStringExpressionsSuite case class EndsWithTestCase[R](left: String, right: String, collation: String, result: R) val testCases = Seq( EndsWithTestCase("", "", "UTF8_BINARY", true), + EndsWithTestCase("", " ", "UTF8_BINARY_RTRIM", true), EndsWithTestCase("abcde", "E", "UNICODE", false), + EndsWithTestCase("abcde ", "E ", "UNICODE_RTRIM", false), EndsWithTestCase("abcde", "FGH", "UTF8_LCASE", false), - EndsWithTestCase("abcde", "CDE", "UNICODE_CI", true) + EndsWithTestCase("abcde ", "FGH ", "UTF8_LCASE_RTRIM", false), + EndsWithTestCase("abcde", "CDE", "UNICODE_CI", true), + EndsWithTestCase("abc d e ", "C D E", "UNICODE_CI_RTRIM", true), + EndsWithTestCase("abcde", "CDE", "UNICODE_CI", true), + EndsWithTestCase("xnigħat", "għat", "MT", true), + // The following two test cases showcase different behavior based on collation. + EndsWithTestCase("xnigħat", "ħat", "MT_CI", false), + EndsWithTestCase("muljavo", "javo", "SR_CI", true), + EndsWithTestCase("xnigħat", "GĦat", "MT_CI", true), + EndsWithTestCase("xnigħat", "Għat", "MT_CI", true) ) val unsupportedTestCase = EndsWithTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { @@ -535,9 +619,13 @@ class CollationStringExpressionsSuite case class StringRepeatTestCase[R](str: String, times: Integer, collation: String, result: R) val testCases = Seq( StringRepeatTestCase("", 1, "UTF8_BINARY", ""), + StringRepeatTestCase(" ", 1, "UTF8_BINARY_RTRIM", " "), StringRepeatTestCase("a", 0, "UNICODE", ""), + StringRepeatTestCase("a", 0, "UNICODE_RTRIM", ""), StringRepeatTestCase("XY", 3, "UTF8_LCASE", "XYXYXY"), - StringRepeatTestCase("123", 2, "UNICODE_CI", "123123") + StringRepeatTestCase("XY ", 3, "UTF8_LCASE_RTRIM", "XY XY XY "), + StringRepeatTestCase("123", 2, "UNICODE_CI", "123123"), + StringRepeatTestCase("123 ", 2, "UNICODE_CI_RTRIM", "123 123 ") ) testCases.foreach(t => { // Unit test. @@ -557,9 +645,13 @@ class CollationStringExpressionsSuite case class AsciiTestCase[R](input: String, collation: String, result: R) val testCases = Seq( AsciiTestCase("a", "UTF8_BINARY", 97), + AsciiTestCase("a ", "UTF8_BINARY_RTRIM", 97), AsciiTestCase("B", "UTF8_LCASE", 66), + AsciiTestCase("B ", "UTF8_LCASE_RTRIM", 66), AsciiTestCase("#", "UNICODE", 35), - AsciiTestCase("!", "UNICODE_CI", 33) + AsciiTestCase("# ", "UNICODE_RTRIM", 35), + AsciiTestCase("!", "UNICODE_CI", 33), + AsciiTestCase("! ", "UNICODE_CI_RTRIM", 33) ) testCases.foreach(t => { // Unit test. 
@@ -577,9 +669,13 @@ class CollationStringExpressionsSuite case class ChrTestCase[R](input: Long, collation: String, result: R) val testCases = Seq( ChrTestCase(65, "UTF8_BINARY", "A"), + ChrTestCase(65, "UTF8_BINARY_RTRIM", "A"), ChrTestCase(66, "UTF8_LCASE", "B"), + ChrTestCase(66, "UTF8_LCASE_RTRIM", "B"), ChrTestCase(97, "UNICODE", "a"), - ChrTestCase(98, "UNICODE_CI", "b") + ChrTestCase(97, "UNICODE_RTRIM", "a"), + ChrTestCase(98, "UNICODE_CI", "b"), + ChrTestCase(98, "UNICODE_CI_RTRIM", "b") ) testCases.foreach(t => { // Unit test. @@ -597,9 +693,13 @@ class CollationStringExpressionsSuite case class UnBase64TestCase[R](input: String, collation: String, result: R) val testCases = Seq( UnBase64TestCase("QUJD", "UTF8_BINARY", Array(65, 66, 67)), + UnBase64TestCase("QUJD", "UTF8_BINARY_RTRIM", Array(65, 66, 67)), UnBase64TestCase("eHl6", "UTF8_LCASE", Array(120, 121, 122)), + UnBase64TestCase("eHl6", "UTF8_LCASE_RTRIM", Array(120, 121, 122)), UnBase64TestCase("IyMj", "UNICODE", Array(35, 35, 35)), - UnBase64TestCase("IQ==", "UNICODE_CI", Array(33)) + UnBase64TestCase("IyMj", "UNICODE_RTRIM", Array(35, 35, 35)), + UnBase64TestCase("IQ==", "UNICODE_CI", Array(33)), + UnBase64TestCase("IQ==", "UNICODE_CI_RTRIM", Array(33)) ) testCases.foreach(t => { // Unit test. @@ -617,9 +717,13 @@ class CollationStringExpressionsSuite case class Base64TestCase[R](input: Array[Byte], collation: String, result: R) val testCases = Seq( Base64TestCase(Array(65, 66, 67), "UTF8_BINARY", "QUJD"), + Base64TestCase(Array(65, 66, 67), "UTF8_BINARY_RTRIM", "QUJD"), Base64TestCase(Array(120, 121, 122), "UTF8_LCASE", "eHl6"), + Base64TestCase(Array(120, 121, 122), "UTF8_LCASE_RTRIM", "eHl6"), Base64TestCase(Array(35, 35, 35), "UNICODE", "IyMj"), - Base64TestCase(Array(33), "UNICODE_CI", "IQ==") + Base64TestCase(Array(35, 35, 35), "UNICODE_RTRIM", "IyMj"), + Base64TestCase(Array(33), "UNICODE_CI", "IQ=="), + Base64TestCase(Array(33), "UNICODE_CI_RTRIM", "IQ==") ) testCases.foreach(t => { // Unit test. @@ -638,9 +742,15 @@ class CollationStringExpressionsSuite case class FormatNumberTestCase[R](x: Double, d: String, collation: String, r: R) val testCases = Seq( FormatNumberTestCase(123.123, "###.###", "UTF8_BINARY", "123.123"), + FormatNumberTestCase(123.123, "###.###", "UTF8_BINARY_RTRIM", "123.123"), FormatNumberTestCase(99.99, "##.##", "UTF8_LCASE", "99.99"), + FormatNumberTestCase(99.99, "##.##", "UTF8_LCASE_RTRIM", "99.99"), FormatNumberTestCase(123.123, "###.###", "UNICODE", "123.123"), - FormatNumberTestCase(99.99, "##.##", "UNICODE_CI", "99.99") + FormatNumberTestCase(123.123, "###.###", "UNICODE_RTRIM", "123.123"), + FormatNumberTestCase(99.99, "##.##", "UNICODE_CI", "99.99"), + FormatNumberTestCase(99.99, "##.##", "UNICODE_CI_RTRIM", "99.99"), + FormatNumberTestCase(99.99, "##.##", "UNICODE_CI", "99.99"), + FormatNumberTestCase(99.999, "##.###", "AF_CI_AI", "99.999") ) testCases.foreach(t => { // Unit test. @@ -660,9 +770,13 @@ class CollationStringExpressionsSuite case class DecodeTestCase[R](input: String, collation: String, result: R) val testCases = Seq( DecodeTestCase("a", "UTF8_BINARY", "a"), + DecodeTestCase("a", "UTF8_BINARY_RTRIM", "a"), DecodeTestCase("A", "UTF8_LCASE", "A"), + DecodeTestCase("A", "UTF8_LCASE_RTRIM", "A"), DecodeTestCase("b", "UNICODE", "b"), - DecodeTestCase("B", "UNICODE_CI", "B") + DecodeTestCase("b", "UNICODE_RTRIM", "b"), + DecodeTestCase("B", "UNICODE_CI", "B"), + DecodeTestCase("B", "UNICODE_CI_RTRIM", "B") ) testCases.foreach(t => { // Unit test. 
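Several hunks back in this same file, the ends_with cases are the ones the inline comment flags as showcasing collation-dependent behavior, and they can be reproduced directly with the endswith SQL function. A standalone sketch under the same assumptions as above (local SparkSession, Maltese collations available as this diff's test cases imply, not the suite's harness):

```scala
import org.apache.spark.sql.SparkSession

// Standalone sketch: endswith under the Maltese collations from the EndsWithTestCase rows above.
// Expected per those rows: 'għat' is a suffix of 'xnigħat' under MT; 'ħat' is not a suffix
// under MT_CI; 'GĦat' is a suffix under MT_CI (case-insensitive match of 'għat').
object EndsWithMalteseSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("endswith-mt").getOrCreate()
    val df = spark.sql(
      """SELECT
        |  endswith(collate('xnigħat', 'MT'), collate('għat', 'MT')) AS mt_ghat,
        |  endswith(collate('xnigħat', 'MT_CI'), collate('ħat', 'MT_CI')) AS mt_ci_hat,
        |  endswith(collate('xnigħat', 'MT_CI'), collate('GĦat', 'MT_CI')) AS mt_ci_ghat
        |""".stripMargin)
    df.show() // expected per the test cases above: true, false, true
    spark.stop()
  }
}
```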
@@ -689,9 +803,13 @@ class CollationStringExpressionsSuite case class EncodeTestCase[R](input: String, collation: String, result: R) val testCases = Seq( EncodeTestCase("a", "UTF8_BINARY", Array(97)), + EncodeTestCase("a ", "UTF8_BINARY_RTRIM", Array(97, 32)), EncodeTestCase("A", "UTF8_LCASE", Array(65)), + EncodeTestCase("A ", "UTF8_LCASE_RTRIM", Array(65, 32)), EncodeTestCase("b", "UNICODE", Array(98)), - EncodeTestCase("B", "UNICODE_CI", Array(66)) + EncodeTestCase("b ", "UNICODE_RTRIM", Array(98, 32)), + EncodeTestCase("B", "UNICODE_CI", Array(66)), + EncodeTestCase("B ", "UNICODE_CI_RTRIM", Array(66, 32)) ) testCases.foreach(t => { // Unit test. @@ -711,9 +829,13 @@ class CollationStringExpressionsSuite case class ToBinaryTestCase[R](expr: String, format: String, collation: String, result: R) val testCases = Seq( ToBinaryTestCase("a", "utf-8", "UTF8_BINARY", Array(97)), + ToBinaryTestCase("a ", "utf-8", "UTF8_BINARY_RTRIM", Array(97, 32)), ToBinaryTestCase("A", "utf-8", "UTF8_LCASE", Array(65)), + ToBinaryTestCase("A ", "utf-8", "UTF8_LCASE_RTRIM", Array(65, 32)), ToBinaryTestCase("b", "utf-8", "UNICODE", Array(98)), - ToBinaryTestCase("B", "utf-8", "UNICODE_CI", Array(66)) + ToBinaryTestCase("b ", "utf-8", "UNICODE_RTRIM", Array(98, 32)), + ToBinaryTestCase("B", "utf-8", "UNICODE_CI", Array(66)), + ToBinaryTestCase("B ", "utf-8", "UNICODE_CI_RTRIM", Array(66, 32)) ) testCases.foreach(t => { // Unit test. @@ -751,6 +873,11 @@ class CollationStringExpressionsSuite "Something else. Nothing here.", "UNICODE_CI", Seq(Seq("Something", "else"), Seq("Nothing", "here")) + ), + SentencesTestCase( + "Hello, dinja! Ġurnata sabiħa.", + "MT_AI", + Seq(Seq("Hello", "dinja"), Seq("Ġurnata", "sabiħa")) ) ) testCases.foreach(t => { @@ -770,9 +897,15 @@ class CollationStringExpressionsSuite case class UpperTestCase[R](input: String, collation: String, result: R) val testCases = Seq( UpperTestCase("aBc", "UTF8_BINARY", "ABC"), + UpperTestCase("aBc ", "UTF8_BINARY_RTRIM", "ABC "), UpperTestCase("aBc", "UTF8_LCASE", "ABC"), + UpperTestCase("aBc ", "UTF8_LCASE_RTRIM", "ABC "), UpperTestCase("aBc", "UNICODE", "ABC"), - UpperTestCase("aBc", "UNICODE_CI", "ABC") + UpperTestCase("aBc ", "UNICODE_RTRIM", "ABC "), + UpperTestCase("aBc", "UNICODE_CI", "ABC"), + UpperTestCase("aBc ", "UNICODE_CI_RTRIM", "ABC "), + UpperTestCase("aBc", "UNICODE_CI", "ABC"), + UpperTestCase("xnìgħat", "MT_CI_AI", "XNÌGĦAT") ) testCases.foreach(t => { // Unit test. @@ -790,9 +923,15 @@ class CollationStringExpressionsSuite case class LowerTestCase[R](input: String, collation: String, result: R) val testCases = Seq( LowerTestCase("aBc", "UTF8_BINARY", "abc"), + LowerTestCase("aBc ", "UTF8_BINARY_RTRIM", "abc "), LowerTestCase("aBc", "UTF8_LCASE", "abc"), + LowerTestCase("aBc ", "UTF8_LCASE_RTRIM", "abc "), LowerTestCase("aBc", "UNICODE", "abc"), - LowerTestCase("aBc", "UNICODE_CI", "abc") + LowerTestCase("aBc ", "UNICODE_RTRIM", "abc "), + LowerTestCase("aBc", "UNICODE_CI", "abc"), + LowerTestCase("aBc ", "UNICODE_CI_RTRIM", "abc "), + LowerTestCase("aBc", "UNICODE_CI", "abc"), + LowerTestCase("VeRGrÖßeRn", "DE_CI", "vergrößern") ) testCases.foreach(t => { // Unit test. 
@@ -810,9 +949,15 @@ class CollationStringExpressionsSuite case class InitCapTestCase[R](input: String, collation: String, result: R) val testCases = Seq( InitCapTestCase("aBc ABc", "UTF8_BINARY", "Abc Abc"), + InitCapTestCase(" aBc ABc ", "UTF8_BINARY_RTRIM", " Abc Abc "), InitCapTestCase("aBc ABc", "UTF8_LCASE", "Abc Abc"), + InitCapTestCase(" aBc ABc ", "UTF8_LCASE_RTRIM", " Abc Abc "), InitCapTestCase("aBc ABc", "UNICODE", "Abc Abc"), - InitCapTestCase("aBc ABc", "UNICODE_CI", "Abc Abc") + InitCapTestCase(" aBc ABc ", "UNICODE_RTRIM", " Abc Abc "), + InitCapTestCase("aBc ABc", "UNICODE_CI", "Abc Abc"), + InitCapTestCase(" aBc ABc ", "UNICODE_CI_RTRIM", " Abc Abc "), + InitCapTestCase("aBc ABc", "UNICODE_CI", "Abc Abc"), + InitCapTestCase("æØÅ ÆøÅ", "NO", "Æøå Æøå") ) testCases.foreach(t => { // Unit test. @@ -836,9 +981,13 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( OverlayTestCase("hello", " world", 6, -1, "UTF8_BINARY", "hello world"), + OverlayTestCase("hello ", " world ", 7, -1, "UTF8_BINARY_RTRIM", "hello world "), OverlayTestCase("nice", " day", 5, -1, "UTF8_LCASE", "nice day"), + OverlayTestCase(" nice ", " day ", 7, -1, "UTF8_LCASE_RTRIM", " nice day "), OverlayTestCase("A", "B", 1, -1, "UNICODE", "B"), - OverlayTestCase("!", "!!!", 1, -1, "UNICODE_CI", "!!!") + OverlayTestCase("A", " B ", 1, -1, "UNICODE_RTRIM", " B "), + OverlayTestCase("!", "!!!", 1, -1, "UNICODE_CI", "!!!"), + OverlayTestCase("!", " !!! ", 1, -1, "UNICODE_CI_RTRIM", " !!! ") ) testCases.foreach(t => { // Unit test. @@ -864,9 +1013,15 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( FormatStringTestCase("%s%s", Seq("a", "b"), "UTF8_BINARY", "ab"), + FormatStringTestCase("%s%s", Seq("a", "b "), "UTF8_BINARY_RTRIM", "ab "), FormatStringTestCase("%d", Seq(123), "UTF8_LCASE", "123"), + FormatStringTestCase("%d", Seq(123), "UTF8_LCASE_RTRIM", "123"), FormatStringTestCase("%s%d", Seq("A", 0), "UNICODE", "A0"), - FormatStringTestCase("%s%s", Seq("Hello", "!!!"), "UNICODE_CI", "Hello!!!") + FormatStringTestCase("%s%d", Seq(" A ", 0), "UNICODE_RTRIM", " A 0"), + FormatStringTestCase("%s%s", Seq("Hello", "!!!"), "UNICODE_CI", "Hello!!!"), + FormatStringTestCase("%s%s", Seq(" Hello ", " !!! "), "UNICODE_CI_RTRIM", " Hello !!! "), + FormatStringTestCase("%s%s", Seq("Hello", "!!!"), "UNICODE_CI", "Hello!!!"), + FormatStringTestCase("%s%s", Seq("Storslått", ".?!"), "NN_AI", "Storslått.?!") ) testCases.foreach(t => { // Unit test. @@ -895,9 +1050,15 @@ class CollationStringExpressionsSuite case class SoundExTestCase[R](input: String, collation: String, result: R) val testCases = Seq( SoundExTestCase("A", "UTF8_BINARY", "A000"), + SoundExTestCase("A", "UTF8_BINARY_RTRIM", "A000"), SoundExTestCase("!", "UTF8_LCASE", "!"), + SoundExTestCase("!", "UTF8_LCASE_RTRIM", "!"), SoundExTestCase("$", "UNICODE", "$"), - SoundExTestCase("X", "UNICODE_CI", "X000") + SoundExTestCase("$", "UNICODE_RTRIM", "$"), + SoundExTestCase("X", "UNICODE_CI", "X000"), + SoundExTestCase("X", "UNICODE_CI_RTRIM", "X000"), + SoundExTestCase("X", "UNICODE_CI", "X000"), + SoundExTestCase("ß", "DE", "ß") ) testCases.foreach(t => { // Unit test. 
@@ -915,9 +1076,16 @@ class CollationStringExpressionsSuite case class LengthTestCase[R](input: String, collation: String, result: R) val testCases = Seq( LengthTestCase("", "UTF8_BINARY", 0), + LengthTestCase(" ", "UTF8_BINARY_RTRIM", 1), LengthTestCase("abc", "UTF8_LCASE", 3), + LengthTestCase("abc ", "UTF8_LCASE_RTRIM", 4), LengthTestCase("hello", "UNICODE", 5), - LengthTestCase("ff", "UNICODE_CI", 1) + LengthTestCase("hello ", "UNICODE_RTRIM", 6), + LengthTestCase("ff", "UNICODE_CI", 1), + LengthTestCase("ff ", "UNICODE_CI_RTRIM", 2), + LengthTestCase("ff", "UNICODE_CI", 1), + LengthTestCase("groß", "DE_CI_AI", 4), + LengthTestCase("gross", "DE_AI", 5) ) testCases.foreach(t => { // Unit test. @@ -935,9 +1103,15 @@ class CollationStringExpressionsSuite case class BitLengthTestCase[R](input: String, collation: String, result: R) val testCases = Seq( BitLengthTestCase("", "UTF8_BINARY", 0), + BitLengthTestCase(" ", "UTF8_BINARY_RTRIM", 8), BitLengthTestCase("abc", "UTF8_LCASE", 24), + BitLengthTestCase("abc ", "UTF8_LCASE_RTRIM", 32), BitLengthTestCase("hello", "UNICODE", 40), - BitLengthTestCase("ff", "UNICODE_CI", 24) + BitLengthTestCase("hello ", "UNICODE_RTRIM", 48), + BitLengthTestCase("ff", "UNICODE_CI", 24), + BitLengthTestCase("ff ", "UNICODE_CI_RTRIM", 32), + BitLengthTestCase("ff", "UNICODE_CI", 24), + BitLengthTestCase("GROß", "DE", 40) ) testCases.foreach(t => { // Unit test. @@ -955,9 +1129,13 @@ class CollationStringExpressionsSuite case class OctetLengthTestCase[R](input: String, collation: String, result: R) val testCases = Seq( OctetLengthTestCase("", "UTF8_BINARY", 0), + OctetLengthTestCase(" ", "UTF8_BINARY_RTRIM", 1), OctetLengthTestCase("abc", "UTF8_LCASE", 3), + OctetLengthTestCase("abc ", "UTF8_LCASE_RTRIM", 4), OctetLengthTestCase("hello", "UNICODE", 5), - OctetLengthTestCase("ff", "UNICODE_CI", 3) + OctetLengthTestCase("hello ", "UNICODE_RTRIM", 6), + OctetLengthTestCase("ff", "UNICODE_CI", 3), + OctetLengthTestCase("ff ", "UNICODE_CI_RTRIM", 4) ) testCases.foreach(t => { // Unit test. @@ -975,9 +1153,13 @@ class CollationStringExpressionsSuite case class LuhncheckTestCase[R](input: String, collation: String, result: R) val testCases = Seq( LuhncheckTestCase("123", "UTF8_BINARY", false), + LuhncheckTestCase("123", "UTF8_BINARY_RTRIM", false), LuhncheckTestCase("000", "UTF8_LCASE", true), + LuhncheckTestCase("000", "UTF8_LCASE_RTRIM", true), LuhncheckTestCase("111", "UNICODE", false), - LuhncheckTestCase("222", "UNICODE_CI", false) + LuhncheckTestCase("111", "UNICODE_RTRIM", false), + LuhncheckTestCase("222", "UNICODE_CI", false), + LuhncheckTestCase("222", "UNICODE_CI_RTRIM", false) ) testCases.foreach(t => { // Unit test. 
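The Length, BitLength and OctetLength cases above pin down the other half of the RTRIM contract: trailing spaces are ignored only for comparison-style semantics, while character- and byte-counting expressions still see the raw string. A minimal sketch of that split (not from the patch), assuming a collation-enabled Spark build and an active SparkSession named `spark`:

```scala
// Equality under an RTRIM collation disregards trailing spaces...
spark.sql("SELECT 'abc ' COLLATE UTF8_BINARY_RTRIM = 'abc'").show()   // true
// ...but length-style expressions still count them, as asserted above.
spark.sql("SELECT length('abc ' COLLATE UTF8_LCASE_RTRIM)").show()    // 4
spark.sql("SELECT bit_length(' ' COLLATE UTF8_BINARY_RTRIM)").show()  // 8
```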
@@ -1000,9 +1182,17 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( LevenshteinTestCase("kitten", "sitTing", "UTF8_BINARY", None, 4), + LevenshteinTestCase("kitten", "sitTing ", "UTF8_BINARY_RTRIM", None, 6), LevenshteinTestCase("kitten", "sitTing", "UTF8_LCASE", None, 4), + LevenshteinTestCase("kitten", "sitTing ", "UTF8_LCASE", None, 6), LevenshteinTestCase("kitten", "sitTing", "UNICODE", Some(3), -1), - LevenshteinTestCase("kitten", "sitTing", "UNICODE_CI", Some(3), -1) + LevenshteinTestCase("kitten", "sitTing ", "UNICODE_RTRIM", Some(3), -1), + LevenshteinTestCase("kitten", "sitTing", "UNICODE_CI", Some(3), -1), + LevenshteinTestCase("kitten ", "sitTing ", "UNICODE_CI_RTRIM", Some(3), -1), + LevenshteinTestCase("kitten", "sitTing", "UNICODE_CI", Some(3), -1), + // Levenshtein function is currently not collation-aware (not considering case or accent). + LevenshteinTestCase("gr", "GR", "UNICODE_CI_AI", None, 2), + LevenshteinTestCase("groß", "Größer", "UNICODE_CI_AI", None, 4) ) testCases.foreach(t => { // Unit test. @@ -1024,9 +1214,15 @@ class CollationStringExpressionsSuite case class IsValidUTF8TestCase[R](input: Any, collation: String, result: R) val testCases = Seq( IsValidUTF8TestCase(null, "UTF8_BINARY", null), + IsValidUTF8TestCase(null, "UTF8_BINARY_RTRIM", null), IsValidUTF8TestCase("", "UTF8_LCASE", true), + IsValidUTF8TestCase("", "UTF8_LCASE_RTRIM", true), IsValidUTF8TestCase("abc", "UNICODE", true), - IsValidUTF8TestCase("hello", "UNICODE_CI", true) + IsValidUTF8TestCase("abc", "UNICODE_RTRIM", true), + IsValidUTF8TestCase("hello", "UNICODE_CI", true), + IsValidUTF8TestCase("hello", "UNICODE_CI_RTRIM", true), + IsValidUTF8TestCase("hello", "UNICODE_CI", true), + IsValidUTF8TestCase("ćao", "SR_CI_AI", true) ) testCases.foreach(t => { // Unit test. @@ -1045,9 +1241,13 @@ class CollationStringExpressionsSuite case class MakeValidUTF8TestCase[R](input: String, collation: String, result: R) val testCases = Seq( MakeValidUTF8TestCase(null, "UTF8_BINARY", null), + MakeValidUTF8TestCase(null, "UTF8_BINARY_RTRIM", null), MakeValidUTF8TestCase("", "UTF8_LCASE", ""), + MakeValidUTF8TestCase("", "UTF8_LCASE_RTRIM", ""), MakeValidUTF8TestCase("abc", "UNICODE", "abc"), - MakeValidUTF8TestCase("hello", "UNICODE_CI", "hello") + MakeValidUTF8TestCase("abc", "UNICODE_RTRIM", "abc"), + MakeValidUTF8TestCase("hello", "UNICODE_CI", "hello"), + MakeValidUTF8TestCase("hello", "UNICODE_CI_RTRIM", "hello") ) testCases.foreach(t => { // Unit test. @@ -1066,9 +1266,13 @@ class CollationStringExpressionsSuite case class ValidateUTF8TestCase[R](input: String, collation: String, result: R) val testCases = Seq( ValidateUTF8TestCase(null, "UTF8_BINARY", null), + ValidateUTF8TestCase(null, "UTF8_BINARY_RTRIM", null), ValidateUTF8TestCase("", "UTF8_LCASE", ""), + ValidateUTF8TestCase("", "UTF8_LCASE_RTRIM", ""), ValidateUTF8TestCase("abc", "UNICODE", "abc"), - ValidateUTF8TestCase("hello", "UNICODE_CI", "hello") + ValidateUTF8TestCase("abc", "UNICODE_RTRIM", "abc"), + ValidateUTF8TestCase("hello", "UNICODE_CI", "hello"), + ValidateUTF8TestCase("hello", "UNICODE_CI_RTRIM", "hello") ) testCases.foreach(t => { // Unit test. 
@@ -1087,9 +1291,13 @@ class CollationStringExpressionsSuite case class ValidateUTF8TestCase(input: String, collation: String, result: Any) val testCases = Seq( ValidateUTF8TestCase(null, "UTF8_BINARY", null), + ValidateUTF8TestCase(null, "UTF8_BINARY_RTRIM", null), ValidateUTF8TestCase("", "UTF8_LCASE", ""), + ValidateUTF8TestCase("", "UTF8_LCASE_RTRIM", ""), ValidateUTF8TestCase("abc", "UNICODE", "abc"), - ValidateUTF8TestCase("hello", "UNICODE_CI", "hello") + ValidateUTF8TestCase("abc", "UNICODE_RTRIM", "abc"), + ValidateUTF8TestCase("hello", "UNICODE_CI", "hello"), + ValidateUTF8TestCase("hello", "UNICODE_CI_RTRIM", "hello") ) testCases.foreach(t => { // Unit test. @@ -1113,9 +1321,12 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( SubstringTestCase("example", 1, Some(100), "UTF8_LCASE", "example"), + SubstringTestCase("example ", 1, Some(100), "UTF8_LCASE_RTRIM", "example "), SubstringTestCase("example", 2, Some(2), "UTF8_BINARY", "xa"), SubstringTestCase("example", 0, Some(0), "UNICODE", ""), + SubstringTestCase("example", 0, Some(0), "UNICODE_RTRIM", ""), SubstringTestCase("example", -3, Some(2), "UNICODE_CI", "pl"), + SubstringTestCase("example ", -3, Some(2), "UNICODE_CI_RTRIM", "le"), SubstringTestCase(" a世a ", 2, Some(3), "UTF8_LCASE", "a世a"), SubstringTestCase("", 1, Some(1), "UTF8_LCASE", ""), SubstringTestCase("", 1, Some(1), "UNICODE", ""), @@ -1124,7 +1335,10 @@ class CollationStringExpressionsSuite SubstringTestCase(null, null, Some(null), "UTF8_BINARY", null), SubstringTestCase(null, null, None, "UNICODE_CI", null), SubstringTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", null, None, "UTF8_BINARY", null), - SubstringTestCase("", null, None, "UNICODE_CI", null) + SubstringTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", null, None, "UTF8_BINARY_RTRIM", null), + SubstringTestCase("", null, None, "UNICODE_CI", null), + SubstringTestCase("", null, None, "UNICODE_CI", null), + SubstringTestCase("xnigħat", 4, Some(2), "MT_CI_AI", "għ") ) testCases.foreach(t => { // Unit test. @@ -1147,9 +1361,15 @@ class CollationStringExpressionsSuite case class LeftTestCase[R](str: String, len: Integer, collation: String, result: R) val testCases = Seq( LeftTestCase(null, null, "UTF8_BINARY", null), + LeftTestCase(null, null, "UTF8_BINARY_RTRIM", null), LeftTestCase(" a世a ", 3, "UTF8_LCASE", " a世"), + LeftTestCase(" a世a ", 3, "UTF8_LCASE_RTRIM", " a世"), LeftTestCase("", 1, "UNICODE", ""), - LeftTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", 3, "UNICODE", "ÀÃÂ") + LeftTestCase("", 1, "UNICODE_RTRIM", ""), + LeftTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", 3, "UNICODE", "ÀÃÂ"), + LeftTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", 3, "UNICODE_RTRIM", "ÀÃÂ"), + LeftTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", 3, "UNICODE", "ÀÃÂ"), + LeftTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", 7, "NO_AI", "ÀÃÂĀĂȦÄ") ) testCases.foreach(t => { // Unit test. 
@@ -1170,9 +1390,15 @@ class CollationStringExpressionsSuite case class RightTestCase[R](str: String, len: Integer, collation: String, result: R) val testCases = Seq( RightTestCase(null, null, "UTF8_BINARY", null), + RightTestCase(null, null, "UTF8_BINARY_RTRIM", null), RightTestCase(" a世a ", 3, "UTF8_LCASE", "世a "), + RightTestCase(" a世a ", 3, "UTF8_LCASE_RTRIM", "世a "), RightTestCase("", 1, "UNICODE", ""), - RightTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", 3, "UNICODE", "ǢǼÆ") + RightTestCase("", 1, "UNICODE_RTRIM", ""), + RightTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", 3, "UNICODE", "ǢǼÆ"), + RightTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", 3, "UNICODE_RTRIM", "ǢǼÆ"), + RightTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", 3, "UNICODE", "ǢǼÆ"), + RightTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", 5, "NO_CI_AI", "ȻȻǢǼÆ") ) testCases.foreach(t => { // Unit test. @@ -1198,13 +1424,18 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( StringRPadTestCase("", 5, " ", "UTF8_BINARY", " "), + StringRPadTestCase("", 5, " ", "UTF8_BINARY_RTRIM", " "), StringRPadTestCase("abc", 5, " ", "UNICODE", "abc "), + StringRPadTestCase("ab c ", 5, " ", "UNICODE_RTRIM", "ab c "), StringRPadTestCase("Hello", 7, "Wörld", "UTF8_LCASE", "HelloWö"), StringRPadTestCase("1234567890", 5, "aaaAAa", "UNICODE_CI", "12345"), + StringRPadTestCase("12 34567890", 5, "aaaAAa", "UNICODE_CI_RTRIM", "12 34"), StringRPadTestCase("aaAA", 2, " ", "UTF8_BINARY", "aa"), StringRPadTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ℀℃", 2, "1", "UTF8_LCASE", "ÀÃ"), StringRPadTestCase("ĂȦÄäåäá", 20, "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "UNICODE", "ĂȦÄäåäáÀÃÂĀĂȦÄäåäáâã"), - StringRPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI", "aȦÄäa1a1") + StringRPadTestCase("ĂȦÄäåäá", 20, "ÀÃÂĀĂȦÄäåäáâãȻȻ", "UNICODE_RTRIM", "ĂȦÄäåäáÀÃÂĀĂȦÄäåäáâã"), + StringRPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI", "aȦÄäa1a1"), + StringRPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI_RTRIM", "aȦÄäa1a1") ) testCases.foreach(t => { // Unit test. @@ -1230,13 +1461,17 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( StringLPadTestCase("", 5, " ", "UTF8_BINARY", " "), + StringLPadTestCase("", 5, " ", "UTF8_BINARY_RTRIM", " "), StringLPadTestCase("abc", 5, " ", "UNICODE", " abc"), StringLPadTestCase("Hello", 7, "Wörld", "UTF8_LCASE", "WöHello"), + StringLPadTestCase("Hello", 7, "W örld", "UTF8_LCASE_RTRIM", "W Hello"), StringLPadTestCase("1234567890", 5, "aaaAAa", "UNICODE_CI", "12345"), StringLPadTestCase("aaAA", 2, " ", "UTF8_BINARY", "aa"), StringLPadTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ℀℃", 2, "1", "UTF8_LCASE", "ÀÃ"), StringLPadTestCase("ĂȦÄäåäá", 20, "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "UNICODE", "ÀÃÂĀĂȦÄäåäáâãĂȦÄäåäá"), - StringLPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI", "a1a1aȦÄä") + StringLPadTestCase("ĂȦÄäåäá", 20, "ÀÃÂĀĂȦÄäåäáâãȻȻ", "UNICODE_RTRIM", "ÀÃÂĀĂȦÄäåäáâãĂȦÄäåäá"), + StringLPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI", "a1a1aȦÄä"), + StringLPadTestCase("aȦÄ ", 8, "a1", "UNICODE_CI_RTRIM", "a1a1aȦÄ ") ) testCases.foreach(t => { // Unit test. 
@@ -1262,13 +1497,25 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( StringLocateTestCase("aa", "aaads", 0, "UTF8_BINARY", 0), + StringLocateTestCase(" ", "", 1, "UTF8_BINARY_RTRIM", 1), + StringLocateTestCase(" abc ", " cdfg abc ", 1, "UTF8_BINARY_RTRIM", 12), StringLocateTestCase("aa", "Aaads", 0, "UTF8_LCASE", 0), StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UTF8_LCASE", 8), StringLocateTestCase("aBc", "abcabc", 4, "UTF8_LCASE", 4), + StringLocateTestCase("aa", "Aaads", 0, "UTF8_LCASE_RTRIM", 0), + StringLocateTestCase("界 ", "test大千世界X大千世界", 1, "UTF8_LCASE_RTRIM", 8), + StringLocateTestCase("aBc", "a bc abc ", 4, "UTF8_LCASE_RTRIM", 6), StringLocateTestCase("aa", "Aaads", 0, "UNICODE", 0), StringLocateTestCase("abC", "abCabC", 2, "UNICODE", 4), + StringLocateTestCase("aa", "Aaads", 0, "UNICODE_RTRIM", 0), + StringLocateTestCase("abC ", "ab C abC ", 2, "UNICODE_RTRIM", 6), StringLocateTestCase("aa", "Aaads", 0, "UNICODE_CI", 0), - StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8) + StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8), + StringLocateTestCase("aa", "Aaads", 0, "UNICODE_CI_RTRIM", 0), + StringLocateTestCase(" 界", "test大千世界X大千世界", 1, "UNICODE_CI_RTRIM", 0), + StringLocateTestCase("oa", "TÖäöäoAoa", 1, "DE", 8), + StringLocateTestCase("oa", "TÖäöäoAoa", 1, "DE_CI", 6), + StringLocateTestCase("oa", "TÖäöäoAoa", 1, "DE_CI_AI", 2) ) val unsupportedTestCase = StringLocateTestCase("aa", "Aaads", 0, "UNICODE_AI", 1) testCases.foreach(t => { @@ -1314,9 +1561,22 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( StringTrimLeftTestCase("xxasdxx", Some("x"), "UTF8_BINARY", "asdxx"), + StringTrimLeftTestCase(" xxasdxx", Some("x"), "UTF8_BINARY_RTRIM", " xxasdxx"), + StringTrimLeftTestCase(" xxasdxx", Some("x "), "UTF8_BINARY_RTRIM", "asdxx"), + StringTrimLeftTestCase(" xxasdxx ", Some("x "), "UTF8_BINARY_RTRIM", "asdxx "), StringTrimLeftTestCase("xxasdxx", Some("X"), "UTF8_LCASE", "asdxx"), + StringTrimLeftTestCase("xxasdxx ", Some("X "), "UTF8_LCASE_RTRIM", "asdxx "), + StringTrimLeftTestCase("xxasdxx ", Some("X"), "UTF8_LCASE_RTRIM", "asdxx "), + StringTrimLeftTestCase(" xxasdxx ", Some("X "), "UTF8_LCASE_RTRIM", "asdxx "), StringTrimLeftTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), - StringTrimLeftTestCase(" asd ", None, "UNICODE_CI", "asd ") + StringTrimLeftTestCase("xxasdxx", Some("y"), "UNICODE_RTRIM", "xxasdxx"), + StringTrimLeftTestCase(" asd ", None, "UNICODE_RTRIM", "asd "), + StringTrimLeftTestCase(" asd ", None, "UNICODE_CI", "asd "), + StringTrimLeftTestCase(" asd ", Some("A"), "UNICODE_CI_RTRIM", " asd "), + StringTrimLeftTestCase(" asd ", None, "UNICODE_CI", "asd "), + StringTrimLeftTestCase("ćĆčČcCabCcČčĆć", Some("ć"), "SR", "ĆčČcCabCcČčĆć"), + StringTrimLeftTestCase("ćĆčČcCabCcČčĆć", Some("Ć"), "SR_CI", "čČcCabCcČčĆć"), + StringTrimLeftTestCase("ćĆčČcCabCcČčĆć", Some("ć"), "SR_CI_AI", "abCcČčĆć") ) val unsupportedTestCase = StringTrimLeftTestCase("xxasdxx", Some("x"), "UNICODE_AI", null) testCases.foreach(t => { @@ -1360,10 +1620,25 @@ class CollationStringExpressionsSuite collation: String, result: R) val testCases = Seq( - StringTrimRightTestCase("xxasdxx", Some("x"), "UTF8_BINARY", "xxasd"), + StringTrimRightTestCase(" xxasdxx", Some("x "), "UTF8_BINARY", " xxasd"), + StringTrimRightTestCase("xxasdxx ", Some("x "), "UTF8_BINARY_RTRIM", "xxasd"), + StringTrimRightTestCase("xxasdxx ", Some("x"), "UTF8_BINARY_RTRIM", "xxasd "), + StringTrimRightTestCase(" xxasdxx ", Some("x "), 
"UTF8_BINARY_RTRIM", " xxasd"), + StringTrimRightTestCase(" xxasdxx", Some("x"), "UTF8_BINARY_RTRIM", " xxasd"), StringTrimRightTestCase("xxasdxx", Some("X"), "UTF8_LCASE", "xxasd"), + StringTrimRightTestCase("xxasdxx ", Some("X "), "UTF8_LCASE_RTRIM", "xxasd"), + StringTrimRightTestCase("xxasdxx ", Some("X"), "UTF8_LCASE_RTRIM", "xxasd "), + StringTrimRightTestCase(" xxasdxx ", Some("X "), "UTF8_LCASE_RTRIM", " xxasd"), + StringTrimRightTestCase(" xxasdxx", Some("x"), "UTF8_LCASE_RTRIM", " xxasd"), StringTrimRightTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), - StringTrimRightTestCase(" asd ", None, "UNICODE_CI", " asd") + StringTrimRightTestCase("xxasdxx", Some("y"), "UNICODE_RTRIM", "xxasdxx"), + StringTrimRightTestCase(" asd ", None, "UNICODE_RTRIM", " asd"), + StringTrimRightTestCase(" asd ", None, "UNICODE_CI", " asd"), + StringTrimRightTestCase(" asd ", Some("D"), "UNICODE_CI_RTRIM", " as "), + StringTrimRightTestCase(" asd ", None, "UNICODE_CI", " asd"), + StringTrimRightTestCase("ćĆčČcCabCcČčĆć", Some("ć"), "SR", "ćĆčČcCabCcČčĆ"), + StringTrimRightTestCase("ćĆčČcCabCcČčĆć", Some("Ć"), "SR_CI", "ćĆčČcCabCcČč"), + StringTrimRightTestCase("ćĆčČcCabCcČčĆć", Some("ć"), "SR_CI_AI", "ćĆčČcCab") ) val unsupportedTestCase = StringTrimRightTestCase("xxasdxx", Some("x"), "UNICODE_AI", "xxasd") testCases.foreach(t => { @@ -1409,9 +1684,25 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( StringTrimTestCase("xxasdxx", Some("x"), "UTF8_BINARY", "asd"), + StringTrimTestCase("xxasdxx ", Some("x "), "UTF8_BINARY_RTRIM", "asd"), + StringTrimTestCase("xxasdxx ", Some("x"), "UTF8_BINARY_RTRIM", "asd "), + StringTrimTestCase(" xxasdxx ", Some("x "), "UTF8_BINARY_RTRIM", "asd"), + StringTrimTestCase(" xxasdxx", Some("x"), "UTF8_BINARY_RTRIM", " xxasd"), StringTrimTestCase("xxasdxx", Some("X"), "UTF8_LCASE", "asd"), + StringTrimTestCase("xxasdxx ", Some("X "), "UTF8_LCASE_RTRIM", "asd"), + StringTrimTestCase("xxasdxx ", Some("X"), "UTF8_LCASE_RTRIM", "asd "), + StringTrimTestCase(" xxasdxx ", Some("X "), "UTF8_LCASE_RTRIM", "asd"), + StringTrimTestCase(" xxasdxx", Some("x"), "UTF8_LCASE_RTRIM", " xxasd"), StringTrimTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), - StringTrimTestCase(" asd ", None, "UNICODE_CI", "asd") + StringTrimTestCase("xxasdxx", Some("y"), "UNICODE_RTRIM", "xxasdxx"), + StringTrimTestCase(" asd ", None, "UNICODE_RTRIM", "asd"), + StringTrimTestCase(" asd ", None, "UNICODE_CI", "asd"), + StringTrimTestCase(" asd ", Some("D"), "UNICODE_CI_RTRIM", " as "), + StringTrimTestCase(" asd ", None, "UNICODE_CI", "asd"), + StringTrimTestCase("ćĆčČcCabCcČčĆć", Some("ć"), "SR", "ĆčČcCabCcČčĆ"), + StringTrimTestCase("ćĆčČcCabCcČčĆć", Some("Ć"), "SR_CI", "čČcCabCcČč"), + StringTrimTestCase("ćĆčČcCabCcČčĆć", Some("ć"), "SR_CI_AI", "ab"), + StringTrimTestCase(" ćĆčČcCabCcČčĆć ", None, "SR_CI_AI", "ćĆčČcCabCcČčĆć") ) val unsupportedTestCase = StringTrimTestCase("xxasdxx", Some("x"), "UNICODE_AI", "asd") testCases.foreach(t => { @@ -1456,9 +1747,25 @@ class CollationStringExpressionsSuite result: R) val testCases = Seq( StringTrimBothTestCase("xxasdxx", Some("x"), "UTF8_BINARY", "asd"), + StringTrimBothTestCase("xxasdxx ", Some("x "), "UTF8_BINARY_RTRIM", "asd"), + StringTrimBothTestCase("xxasdxx ", Some("x"), "UTF8_BINARY_RTRIM", "asd "), + StringTrimBothTestCase(" xxasdxx ", Some("x "), "UTF8_BINARY_RTRIM", "asd"), + StringTrimBothTestCase(" xxasdxx", Some("x"), "UTF8_BINARY_RTRIM", " xxasd"), StringTrimBothTestCase("xxasdxx", Some("X"), "UTF8_LCASE", "asd"), + 
StringTrimBothTestCase("xxasdxx ", Some("X "), "UTF8_LCASE_RTRIM", "asd"), + StringTrimBothTestCase("xxasdxx ", Some("X"), "UTF8_LCASE_RTRIM", "asd "), + StringTrimBothTestCase(" xxasdxx ", Some("X "), "UTF8_LCASE_RTRIM", "asd"), + StringTrimBothTestCase(" xxasdxx", Some("x"), "UTF8_LCASE_RTRIM", " xxasd"), StringTrimBothTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), - StringTrimBothTestCase(" asd ", None, "UNICODE_CI", "asd") + StringTrimBothTestCase("xxasdxx", Some("y"), "UNICODE_RTRIM", "xxasdxx"), + StringTrimBothTestCase(" asd ", None, "UNICODE_RTRIM", "asd"), + StringTrimBothTestCase(" asd ", None, "UNICODE_CI", "asd"), + StringTrimBothTestCase(" asd ", Some("D"), "UNICODE_CI_RTRIM", " as "), + StringTrimBothTestCase(" asd ", None, "UNICODE_CI", "asd"), + StringTrimBothTestCase("ćĆčČcCabCcČčĆć", Some("ć"), "SR", "ĆčČcCabCcČčĆ"), + StringTrimBothTestCase("ćĆčČcCabCcČčĆć", Some("Ć"), "SR_CI", "čČcCabCcČč"), + StringTrimBothTestCase("ćĆčČcCabCcČčĆć", Some("ć"), "SR_CI_AI", "ab"), + StringTrimBothTestCase(" ćĆčČcCabCcČčĆć ", None, "SR_CI_AI", "ćĆčČcCabCcČčĆć") ) testCases.foreach(t => { // Unit test. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index f2444df38b3b2..3a5a650e5a0c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1941,7 +1941,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } test("cache table with collated columns") { - val collations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI") + val collations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI", "SR_CI_AI") val lazyOptions = Seq(false, true) for ( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 25f4d9f62354a..7ebcb280def6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2270,7 +2270,7 @@ class DataFrameAggregateSuite extends QueryTest } private def assertDecimalSumOverflow( - df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { + df: DataFrame, ansiEnabled: Boolean, fnName: String, expectedAnswer: Row): Unit = { if (!ansiEnabled) { checkAnswer(df, expectedAnswer) } else { @@ -2278,11 +2278,12 @@ class DataFrameAggregateSuite extends QueryTest df.collect() } assert(e.getMessage.contains("cannot be represented as Decimal") || - e.getMessage.contains("Overflow in sum of decimals")) + e.getMessage.contains(s"Overflow in sum of decimals. 
Use 'try_$fnName' to tolerate " + + s"overflow and return NULL instead.")) } } - def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = { + def checkAggResultsForDecimalOverflow(aggFn: Column => Column, fnName: String): Unit = { Seq("true", "false").foreach { wholeStageEnabled => withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { Seq(true, false).foreach { ansiEnabled => @@ -2306,27 +2307,27 @@ class DataFrameAggregateSuite extends QueryTest join(df, "intNum").agg(aggFn($"decNum")) val expectedAnswer = Row(null) - assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(df2, ansiEnabled, fnName, expectedAnswer) val decStr = "1" + "0" * 19 val d1 = spark.range(0, 12, 1, 1) val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d")) - assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(d2, ansiEnabled, fnName, expectedAnswer) val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d")) - assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(d4, ansiEnabled, fnName, expectedAnswer) val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"), lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd") - assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(d5, ansiEnabled, fnName, expectedAnswer) val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")). toDF("d") assertDecimalSumOverflow( - nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, expectedAnswer) + nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, fnName, expectedAnswer) val df3 = Seq( (BigDecimal("10000000000000000000"), 1), @@ -2344,9 +2345,9 @@ class DataFrameAggregateSuite extends QueryTest (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum") val df6 = df3.union(df4).union(df5) - val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")). + val df7 = df6.groupBy("intNum").agg(aggFn($"decNum"), countDistinct("decNum")). 
filter("intNum == 1") - assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2)) + assertDecimalSumOverflow(df7, ansiEnabled, fnName, Row(1, null, 2)) } } } @@ -2354,11 +2355,11 @@ class DataFrameAggregateSuite extends QueryTest } test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") { - checkAggResultsForDecimalOverflow(c => sum(c)) + checkAggResultsForDecimalOverflow(c => sum(c), "sum") } test("SPARK-35955: Aggregate avg should not return wrong results for decimal overflow") { - checkAggResultsForDecimalOverflow(c => avg(c)) + checkAggResultsForDecimalOverflow(c => avg(c), "avg") } test("SPARK-28224: Aggregate sum big decimal overflow") { @@ -2369,7 +2370,7 @@ class DataFrameAggregateSuite extends QueryTest Seq(true, false).foreach { ansiEnabled => withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { val structDf = largeDecimals.select("a").agg(sum("a")) - assertDecimalSumOverflow(structDf, ansiEnabled, Row(null)) + assertDecimalSumOverflow(structDf, ansiEnabled, "sum", Row(null)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 4cab05dfd2b9b..b65636dfcde07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -1366,6 +1366,84 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(result1, result2) } + test("try_make_timestamp") { + val df = Seq((100, 11, 1, 12, 30, 01.001001, "UTC")). + toDF("year", "month", "day", "hour", "min", "sec", "timezone") + + val result1 = df.selectExpr("try_make_timestamp(year, month, day, hour, min, sec, timezone)") + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + val result2 = df.select(make_timestamp( + col("year"), col("month"), col("day"), col("hour"), + col("min"), col("sec"), col("timezone"))) + checkAnswer(result1, result2) + } + + val result3 = df.selectExpr("try_make_timestamp(year, month, day, hour, min, sec)") + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + val result4 = df.select(make_timestamp( + col("year"), col("month"), col("day"), col("hour"), + col("min"), col("sec"))) + checkAnswer(result3, result4) + } + + val result5 = df.selectExpr("try_make_timestamp(year, month, day, hour, min, sec)") + val result6 = df.select(try_make_timestamp( + col("year"), col("month"), col("day"), col("hour"), + col("min"), col("sec"))) + checkAnswer(result5, result6) + } + + test("try_make_timestamp_ntz") { + val df = Seq((100, 11, 1, 12, 30, 01.001001)). + toDF("year", "month", "day", "hour", "min", "sec") + + val result1 = df.selectExpr( + "try_make_timestamp_ntz(year, month, day, hour, min, sec)") + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + val result2 = df.select(make_timestamp_ntz( + col("year"), col("month"), col("day"), col("hour"), + col("min"), col("sec"))) + checkAnswer(result1, result2) + } + + val result3 = df.selectExpr( + "try_make_timestamp_ntz(year, month, day, hour, min, sec)") + val result4 = df.select(try_make_timestamp_ntz( + col("year"), col("month"), col("day"), col("hour"), + col("min"), col("sec"))) + checkAnswer(result3, result4) + } + + test("try_make_timestamp_ltz") { + val df = Seq((100, 11, 1, 12, 30, 01.001001, "UTC")). 
+ toDF("year", "month", "day", "hour", "min", "sec", "timezone") + + val result1 = df.selectExpr( + "try_make_timestamp_ltz(year, month, day, hour, min, sec, timezone)") + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + val result2 = df.select(make_timestamp_ltz( + col("year"), col("month"), col("day"), col("hour"), + col("min"), col("sec"), col("timezone"))) + checkAnswer(result1, result2) + } + + val result3 = df.selectExpr( + "try_make_timestamp_ltz(year, month, day, hour, min, sec)") + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + val result4 = df.select(make_timestamp_ltz( + col("year"), col("month"), col("day"), col("hour"), + col("min"), col("sec"))) + checkAnswer(result3, result4) + } + + val result5 = df.selectExpr( + "try_make_timestamp_ltz(year, month, day, hour, min, sec)") + val result6 = df.select(try_make_timestamp_ltz( + col("year"), col("month"), col("day"), col("hour"), + col("min"), col("sec"))) + checkAnswer(result5, result6) + } + test("make_ym_interval") { val df = Seq((100, 11)).toDF("year", "month") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala index 2e0983fe0319c..3e896ae69b686 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala @@ -97,11 +97,13 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest test("INVALID_FRACTION_OF_SECOND: in the function make_timestamp") { checkError( exception = intercept[SparkDateTimeException] { - sql("select make_timestamp(2012, 11, 30, 9, 19, 60.66666666)").collect() + sql("select make_timestamp(2012, 11, 30, 9, 19, 60.1)").collect() }, condition = "INVALID_FRACTION_OF_SECOND", sqlState = "22023", - parameters = Map("ansiConfig" -> ansiConf)) + parameters = Map( + "secAndMicros" -> "60.100000" + )) } test("NUMERIC_VALUE_OUT_OF_RANGE: cast string to decimal") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/WatermarkTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/WatermarkTrackerSuite.scala new file mode 100644 index 0000000000000..6018d286fc21e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/WatermarkTrackerSuite.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.util.UUID + +import scala.collection.mutable + +import org.apache.spark.sql.execution.{SparkPlan, UnionExec} +import org.apache.spark.sql.functions.timestamp_seconds +import org.apache.spark.sql.streaming.StreamTest + +class WatermarkTrackerSuite extends StreamTest { + + import testImplicits._ + + test("SPARK-50046 proper watermark advancement with dropped watermark nodes") { + val inputStream1 = MemoryStream[Int] + val inputStream2 = MemoryStream[Int] + val inputStream3 = MemoryStream[Int] + + val df1 = inputStream1.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "10 seconds") + + val df2 = inputStream2.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "20 seconds") + + val df3 = inputStream3.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "30 seconds") + + val union = df1.union(df2).union(df3) + + testStream(union)( + // Just to ensure that executedPlan has watermark nodes for every stream. + MultiAddData( + (inputStream1, Seq(0)), + (inputStream2, Seq(0)), + (inputStream3, Seq(0)) + ), + ProcessAllAvailable(), + Execute { q => + val initialPlan = q.logicalPlan + val executedPlan = q.lastExecution.executedPlan + + val tracker = WatermarkTracker(spark.conf, initialPlan) + tracker.setWatermark(5) + + val delayMsToNodeId = executedPlan.collect { + case e: EventTimeWatermarkExec => e.delayMs -> e.nodeId + }.toMap + + def setupScenario( + data: Map[Long, Seq[Long]])(fnToPruneSubtree: UnionExec => UnionExec): SparkPlan = { + val eventTimeStatsMap = new mutable.HashMap[Long, EventTimeStatsAccum]() + executedPlan.foreach { + case e: EventTimeWatermarkExec => + eventTimeStatsMap.put(e.delayMs, e.eventTimeStats) + + case _ => + } + + data.foreach { case (delayMs, values) => + val stats = eventTimeStatsMap(delayMs) + values.foreach { value => + stats.add(value) + } + } + + executedPlan.transform { + case e: UnionExec => fnToPruneSubtree(e) + } + } + + def verifyWatermarkMap(expectation: Map[UUID, Option[Long]]): Unit = { + expectation.foreach { case (nodeId, watermarkValue) => + assert(tracker.watermarkMap(nodeId) === watermarkValue, + s"Watermark value for nodeId $nodeId is ${tracker.watermarkMap(nodeId)}, whereas " + + s"we expect $watermarkValue") + } + } + + // Before SPARK-50046, WatermarkTracker simply assumed that watermark nodes would never + // be dropped and that their order would never change. We have not found a case that + // breaks this for watermark nodes themselves, but it has happened for other operators + // (e.g. PruneFilters), so it is better to guard against it up front. + + // Scenario: We have three streams, each with its own watermark. The query has executed + // the first batch of the run, and (for some reason) Spark drops one of the subtrees. + // The stream in the dropped subtree should be treated as if it had no data (because we + // do not know), hence the watermark should not be advanced. But before SPARK-50046, + // WatermarkTracker had no indication that a watermark node had been dropped, so it + // advanced the watermark based on the calculation over the remaining two + // streams.
+ + val executedPlanFor1stBatch = setupScenario( + Map( + // watermark value for this node: 22 - 10 = 12 + 10000L -> Seq(20000L, 21000L, 22000L), + // watermark value for this node: 42 - 20 = 22 + 20000L -> Seq(40000L, 41000L, 42000L), + // watermark value for this node: 62 - 30 = 32 + 30000L -> Seq(60000L, 61000L, 62000L) + ) + ) { unionExec => + // Drop the subtree whose watermark node has a delay of 10 seconds. + unionExec.copy(unionExec.children.drop(1)) + } + + tracker.updateWatermark(executedPlanFor1stBatch) + + // The watermark has not advanced, so it keeps the initial value. + assert(tracker.currentWatermark === 5) + + verifyWatermarkMap( + Map( + delayMsToNodeId(10000L) -> None, + delayMsToNodeId(20000L) -> Some(22000L), + delayMsToNodeId(30000L) -> Some(32000L)) + ) + + // NOTE: Before SPARK-50046, the verification above failed and the verification below + // passed. WatermarkTracker could not track the dropped node, so it advanced the watermark + // from the remaining nodes: min(22, 32) = 22. + // + // assert(tracker.currentWatermark === 22000) + // + // It also updated the map with shifted indices: it should have updated only indices 1 + // and 2, but it updated 0 and 1. + // verifyWatermarkMap(Map(0 -> Some(22000L), 1 -> Some(32000L))) + + // Scenario: After the first batch, the query executes the second batch, and (for some + // reason) Spark retains only the middle subtree. Before SPARK-50046, WatermarkTracker + // tracked the watermark nodes of the physical plan by index, so the watermark node at + // index 1 in the logical plan was shifted to index 0, updating the map incorrectly and + // also advancing the watermark. The correct behavior is that, since the watermark node + // for the first stream has been dropped in both batches, the watermark must not be + // advanced. + + val executedPlanFor2ndBatch = setupScenario( + Map( + // watermark value for this node: 52 - 10 = 42 + 10000L -> Seq(50000L, 51000L, 52000L), + // watermark value for this node: 72 - 20 = 52 + 20000L -> Seq(70000L, 71000L, 72000L), + // watermark value for this node: 92 - 30 = 62 + 30000L -> Seq(90000L, 91000L, 92000L) + ) + ) { unionExec => + // Keep only the middle subtree, dropping the rest. + unionExec.copy(Seq(unionExec.children(1))) + } + + tracker.updateWatermark(executedPlanFor2ndBatch) + + // The watermark has not advanced, so it keeps the initial value. + assert(tracker.currentWatermark === 5) + + // WatermarkTracker properly updates the map entry for the middle watermark node. + verifyWatermarkMap( + Map( + delayMsToNodeId(10000L) -> None, + delayMsToNodeId(20000L) -> Some(52000L), + delayMsToNodeId(30000L) -> Some(32000L)) + ) + } + ) + } +}
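Stepping back to the try_make_timestamp and decimal-overflow changes earlier in this patch, the common thread is the null-on-error contract of the try_ family. A minimal sketch (not from the patch) of what the new DateFunctionsSuite tests assert, assuming an active SparkSession named `spark`: make_timestamp with a fractional 60th second raises INVALID_FRACTION_OF_SECOND under ANSI mode, while the try_ variant is expected to return NULL instead.

```scala
// Invalid fraction of a leap second: make_timestamp fails under ANSI mode,
// whereas try_make_timestamp tolerates the error and yields NULL.
spark.sql("SELECT try_make_timestamp(2012, 11, 30, 9, 19, 60.1)").show()  // NULL
```

The reworded overflow message in DataFrameAggregateSuite points users in the same direction: try_sum and try_avg return NULL on decimal overflow instead of raising an error.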