diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index aa51b98fa3c58..52d7e1f4daa53 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -23,7 +23,7 @@ Suggests: testthat, e1071, survival, - arrow + arrow (>= 0.15.1) Collate: 'schema.R' 'generics.R' diff --git a/appveyor.yml b/appveyor.yml index fc0b7d53ddabc..a4da5f9040ded 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -48,7 +48,8 @@ install: build_script: # '-Djna.nosys=true' is required to avoid kernel32.dll load failure. # See SPARK-28759. - - cmd: mvn -DskipTests -Psparkr -Phive -Djna.nosys=true package + # Ideally we should check the tests related to Hive in SparkR as well (SPARK-31745). + - cmd: mvn -DskipTests -Psparkr -Djna.nosys=true package environment: NOT_CRAN: true diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 57b86254ab424..8a01b80c4164b 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -19,6 +19,8 @@ # This script builds and pushes docker images when run from a release of Spark # with Kubernetes support. +set -x + function error { echo "$@" 1>&2 exit 1 @@ -172,6 +174,7 @@ function build { local BASEDOCKERFILE=${BASEDOCKERFILE:-"kubernetes/dockerfiles/spark/Dockerfile"} local PYDOCKERFILE=${PYDOCKERFILE:-false} local RDOCKERFILE=${RDOCKERFILE:-false} + local ARCHS=${ARCHS:-"--platform linux/amd64,linux/arm64"} (cd $(img_ctx_dir base) && docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ @@ -179,6 +182,11 @@ function build { if [ $? -ne 0 ]; then error "Failed to build Spark JVM Docker image, please refer to Docker build output for details." fi + if [ "${CROSS_BUILD}" != "false" ]; then + (cd $(img_ctx_dir base) && docker buildx build $ARCHS $NOCACHEARG "${BUILD_ARGS[@]}" \ + -t $(image_ref spark) \ + -f "$BASEDOCKERFILE" .) + fi if [ "${PYDOCKERFILE}" != "false" ]; then (cd $(img_ctx_dir pyspark) && docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ @@ -187,6 +195,11 @@ function build { if [ $? -ne 0 ]; then error "Failed to build PySpark Docker image, please refer to Docker build output for details." fi + if [ "${CROSS_BUILD}" != "false" ]; then + (cd $(img_ctx_dir pyspark) && docker buildx build $ARCHS $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-py) \ + -f "$PYDOCKERFILE" .) + fi fi if [ "${RDOCKERFILE}" != "false" ]; then @@ -196,6 +209,11 @@ function build { if [ $? -ne 0 ]; then error "Failed to build SparkR Docker image, please refer to Docker build output for details." fi + if [ "${CROSS_BUILD}" != "false" ]; then + (cd $(img_ctx_dir sparkr) && docker buildx build $ARCHS $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-r) \ + -f "$RDOCKERFILE" .) + fi fi } @@ -227,6 +245,8 @@ Options: -n Build docker image with --no-cache -u uid UID to use in the USER directive to set the user the main Spark process runs as inside the resulting container + -X Use docker buildx to cross build. Automatically pushes. + See https://docs.docker.com/buildx/working-with-buildx/ for steps to setup buildx. -b arg Build arg to build or push the image. For multiple build args, this option needs to be used separately for each build arg. 
@@ -252,6 +272,12 @@ Examples: - Build and push JDK11-based image with tag "v3.0.0" to docker.io/myrepo $0 -r docker.io/myrepo -t v3.0.0 -b java_image_tag=11-jre-slim build $0 -r docker.io/myrepo -t v3.0.0 push + + - Build and push JDK11-based image for multiple archs to docker.io/myrepo + $0 -r docker.io/myrepo -t v3.0.0 -X -b java_image_tag=11-jre-slim build + # Note: buildx, which does cross building, needs to do the push during build + # So there is no seperate push step with -X + EOF } @@ -268,7 +294,8 @@ RDOCKERFILE= NOCACHEARG= BUILD_PARAMS= SPARK_UID= -while getopts f:p:R:mr:t:nb:u: option +CROSS_BUILD="false" +while getopts f:p:R:mr:t:Xnb:u: option do case "${option}" in @@ -279,6 +306,7 @@ do t) TAG=${OPTARG};; n) NOCACHEARG="--no-cache";; b) BUILD_PARAMS=${BUILD_PARAMS}" --build-arg "${OPTARG};; + X) CROSS_BUILD=1;; m) if ! which minikube 1>/dev/null; then error "Cannot find minikube." diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 8c05288fb4111..33865a21ea914 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -229,8 +229,6 @@ public class ShuffleMetrics implements MetricSet { private final Meter blockTransferRateBytes = new Meter(); // Number of active connections to the shuffle service private Counter activeConnections = new Counter(); - // Number of registered connections to the shuffle service - private Counter registeredConnections = new Counter(); // Number of exceptions caught in connections to the shuffle service private Counter caughtExceptions = new Counter(); @@ -242,7 +240,6 @@ public ShuffleMetrics() { allMetrics.put("registeredExecutorsSize", (Gauge) () -> blockManager.getRegisteredExecutorsSize()); allMetrics.put("numActiveConnections", activeConnections); - allMetrics.put("numRegisteredConnections", registeredConnections); allMetrics.put("numCaughtExceptions", caughtExceptions); } diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index c41efbad8ffec..3d14318bf90f0 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -198,6 +198,7 @@ protected void serviceInit(Configuration conf) throws Exception { // register metrics on the block handler into the Node Manager's metrics system. 
blockHandler.getAllMetrics().getMetrics().put("numRegisteredConnections", shuffleServer.getRegisteredConnections()); + blockHandler.getAllMetrics().getMetrics().putAll(shuffleServer.getAllMetrics().getMetrics()); YarnShuffleServiceMetrics serviceMetrics = new YarnShuffleServiceMetrics(blockHandler.getAllMetrics()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/common/tags/src/test/java/org/apache/spark/tags/ChromeUITest.java similarity index 76% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala rename to common/tags/src/test/java/org/apache/spark/tags/ChromeUITest.java index 134376628ae7f..e3fed3d656d20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala +++ b/common/tags/src/test/java/org/apache/spark/tags/ChromeUITest.java @@ -15,17 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.tags; -/** - * Physical execution operators for join operations. - */ -package object joins { - - sealed abstract class BuildSide - - case object BuildRight extends BuildSide +import java.lang.annotation.*; - case object BuildLeft extends BuildSide +import org.scalatest.TagAnnotation; -} +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ChromeUITest { } diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index df39ad8b0dcc2..3c003f45ed27a 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -62,6 +62,7 @@ # Generic options for the daemons used in the standalone deploy mode # - SPARK_CONF_DIR Alternate conf dir. (Default: ${SPARK_HOME}/conf) # - SPARK_LOG_DIR Where log files are stored. (Default: ${SPARK_HOME}/logs) +# - SPARK_LOG_MAX_FILES Max log files of Spark daemons can rotate to. Default is 5. # - SPARK_PID_DIR Where the pid file is stored. (Default: /tmp) # - SPARK_IDENT_STRING A string representing this instance of spark. (Default: $USER) # - SPARK_NICENESS The scheduling priority for daemons. (Default: 0) diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 731f6fc767dfd..579e7ff320f5c 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -162,6 +162,11 @@ public void onSpeculativeTaskSubmitted(SparkListenerSpeculativeTaskSubmitted spe onEvent(speculativeTask); } + @Override + public void onResourceProfileAdded(SparkListenerResourceProfileAdded event) { + onEvent(event); + } + @Override public void onOtherEvent(SparkListenerEvent event) { onEvent(event); diff --git a/core/src/main/resources/org/apache/spark/ui/static/bootstrap-tooltip.js b/core/src/main/resources/org/apache/spark/ui/static/bootstrap-tooltip.js deleted file mode 100644 index acd6096e6743e..0000000000000 --- a/core/src/main/resources/org/apache/spark/ui/static/bootstrap-tooltip.js +++ /dev/null @@ -1,361 +0,0 @@ -/* =========================================================== - * bootstrap-tooltip.js v2.3.2 - * http://getbootstrap.com/2.3.2/javascript.html#tooltips - * Inspired by the original jQuery.tipsy by Jason Frame - * =========================================================== - * Copyright 2013 Twitter, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ========================================================== */ - - -!function ($) { - - "use strict"; // jshint ;_; - - - /* TOOLTIP PUBLIC CLASS DEFINITION - * =============================== */ - - var Tooltip = function (element, options) { - this.init('tooltip', element, options) - } - - Tooltip.prototype = { - - constructor: Tooltip - - , init: function (type, element, options) { - var eventIn - , eventOut - , triggers - , trigger - , i - - this.type = type - this.$element = $(element) - this.options = this.getOptions(options) - this.enabled = true - - triggers = this.options.trigger.split(' ') - - for (i = triggers.length; i--;) { - trigger = triggers[i] - if (trigger == 'click') { - this.$element.on('click.' + this.type, this.options.selector, $.proxy(this.toggle, this)) - } else if (trigger != 'manual') { - eventIn = trigger == 'hover' ? 'mouseenter' : 'focus' - eventOut = trigger == 'hover' ? 'mouseleave' : 'blur' - this.$element.on(eventIn + '.' + this.type, this.options.selector, $.proxy(this.enter, this)) - this.$element.on(eventOut + '.' + this.type, this.options.selector, $.proxy(this.leave, this)) - } - } - - this.options.selector ? - (this._options = $.extend({}, this.options, { trigger: 'manual', selector: '' })) : - this.fixTitle() - } - - , getOptions: function (options) { - options = $.extend({}, $.fn[this.type].defaults, this.$element.data(), options) - - if (options.delay && typeof options.delay == 'number') { - options.delay = { - show: options.delay - , hide: options.delay - } - } - - return options - } - - , enter: function (e) { - var defaults = $.fn[this.type].defaults - , options = {} - , self - - this._options && $.each(this._options, function (key, value) { - if (defaults[key] != value) options[key] = value - }, this) - - self = $(e.currentTarget)[this.type](options).data(this.type) - - if (!self.options.delay || !self.options.delay.show) return self.show() - - clearTimeout(this.timeout) - self.hoverState = 'in' - this.timeout = setTimeout(function() { - if (self.hoverState == 'in') self.show() - }, self.options.delay.show) - } - - , leave: function (e) { - var self = $(e.currentTarget)[this.type](this._options).data(this.type) - - if (this.timeout) clearTimeout(this.timeout) - if (!self.options.delay || !self.options.delay.hide) return self.hide() - - self.hoverState = 'out' - this.timeout = setTimeout(function() { - if (self.hoverState == 'out') self.hide() - }, self.options.delay.hide) - } - - , show: function () { - var $tip - , pos - , actualWidth - , actualHeight - , placement - , tp - , e = $.Event('show') - - if (this.hasContent() && this.enabled) { - this.$element.trigger(e) - if (e.isDefaultPrevented()) return - $tip = this.tip() - this.setContent() - - if (this.options.animation) { - $tip.addClass('fade') - } - - placement = typeof this.options.placement == 'function' ? - this.options.placement.call(this, $tip[0], this.$element[0]) : - this.options.placement - - $tip - .detach() - .css({ top: 0, left: 0, display: 'block' }) - - this.options.container ? 
$tip.appendTo(this.options.container) : $tip.insertAfter(this.$element) - - pos = this.getPosition() - - actualWidth = $tip[0].offsetWidth - actualHeight = $tip[0].offsetHeight - - switch (placement) { - case 'bottom': - tp = {top: pos.top + pos.height, left: pos.left + pos.width / 2 - actualWidth / 2} - break - case 'top': - tp = {top: pos.top - actualHeight, left: pos.left + pos.width / 2 - actualWidth / 2} - break - case 'left': - tp = {top: pos.top + pos.height / 2 - actualHeight / 2, left: pos.left - actualWidth} - break - case 'right': - tp = {top: pos.top + pos.height / 2 - actualHeight / 2, left: pos.left + pos.width} - break - } - - this.applyPlacement(tp, placement) - this.$element.trigger('shown') - } - } - - , applyPlacement: function(offset, placement){ - var $tip = this.tip() - , width = $tip[0].offsetWidth - , height = $tip[0].offsetHeight - , actualWidth - , actualHeight - , delta - , replace - - $tip - .offset(offset) - .addClass(placement) - .addClass('in') - - actualWidth = $tip[0].offsetWidth - actualHeight = $tip[0].offsetHeight - - if (placement == 'top' && actualHeight != height) { - offset.top = offset.top + height - actualHeight - replace = true - } - - if (placement == 'bottom' || placement == 'top') { - delta = 0 - - if (offset.left < 0){ - delta = offset.left * -2 - offset.left = 0 - $tip.offset(offset) - actualWidth = $tip[0].offsetWidth - actualHeight = $tip[0].offsetHeight - } - - this.replaceArrow(delta - width + actualWidth, actualWidth, 'left') - } else { - this.replaceArrow(actualHeight - height, actualHeight, 'top') - } - - if (replace) $tip.offset(offset) - } - - , replaceArrow: function(delta, dimension, position){ - this - .arrow() - .css(position, delta ? (50 * (1 - delta / dimension) + "%") : '') - } - - , setContent: function () { - var $tip = this.tip() - , title = this.getTitle() - - $tip.find('.tooltip-inner')[this.options.html ? 'html' : 'text'](title) - $tip.removeClass('fade in top bottom left right') - } - - , hide: function () { - var that = this - , $tip = this.tip() - , e = $.Event('hide') - - this.$element.trigger(e) - if (e.isDefaultPrevented()) return - - $tip.removeClass('in') - - function removeWithAnimation() { - var timeout = setTimeout(function () { - $tip.off($.support.transition.end).detach() - }, 500) - - $tip.one($.support.transition.end, function () { - clearTimeout(timeout) - $tip.detach() - }) - } - - $.support.transition && this.$tip.hasClass('fade') ? - removeWithAnimation() : - $tip.detach() - - this.$element.trigger('hidden') - - return this - } - - , fixTitle: function () { - var $e = this.$element - if ($e.attr('title') || typeof($e.attr('data-original-title')) != 'string') { - $e.attr('data-original-title', $e.attr('title') || '').attr('title', '') - } - } - - , hasContent: function () { - return this.getTitle() - } - - , getPosition: function () { - var el = this.$element[0] - return $.extend({}, (typeof el.getBoundingClientRect == 'function') ? el.getBoundingClientRect() : { - width: el.offsetWidth - , height: el.offsetHeight - }, this.$element.offset()) - } - - , getTitle: function () { - var title - , $e = this.$element - , o = this.options - - title = $e.attr('data-original-title') - || (typeof o.title == 'function' ? 
o.title.call($e[0]) : o.title) - - return title - } - - , tip: function () { - return this.$tip = this.$tip || $(this.options.template) - } - - , arrow: function(){ - return this.$arrow = this.$arrow || this.tip().find(".tooltip-arrow") - } - - , validate: function () { - if (!this.$element[0].parentNode) { - this.hide() - this.$element = null - this.options = null - } - } - - , enable: function () { - this.enabled = true - } - - , disable: function () { - this.enabled = false - } - - , toggleEnabled: function () { - this.enabled = !this.enabled - } - - , toggle: function (e) { - var self = e ? $(e.currentTarget)[this.type](this._options).data(this.type) : this - self.tip().hasClass('in') ? self.hide() : self.show() - } - - , destroy: function () { - this.hide().$element.off('.' + this.type).removeData(this.type) - } - - } - - - /* TOOLTIP PLUGIN DEFINITION - * ========================= */ - - var old = $.fn.tooltip - - $.fn.tooltip = function ( option ) { - return this.each(function () { - var $this = $(this) - , data = $this.data('tooltip') - , options = typeof option == 'object' && option - if (!data) $this.data('tooltip', (data = new Tooltip(this, options))) - if (typeof option == 'string') data[option]() - }) - } - - $.fn.tooltip.Constructor = Tooltip - - $.fn.tooltip.defaults = { - animation: true - , placement: 'top' - , selector: false - , template: '
' - , trigger: 'hover focus' - , title: '' - , delay: 0 - , html: false - , container: false - } - - - /* TOOLTIP NO CONFLICT - * =================== */ - - $.fn.tooltip.noConflict = function () { - $.fn.tooltip = old - return this - } - -}(window.jQuery); diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html index 0b26bfc5b2d82..0729dfe1cef72 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html @@ -89,6 +89,7 @@

Executors

Disk Used Cores Resources + Resource Profile Id Active Tasks Failed Tasks Complete Tasks diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index ec57797ba0909..520edb9cc3e34 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -119,7 +119,7 @@ function totalDurationColor(totalGCTime, totalDuration) { } var sumOptionalColumns = [3, 4]; -var execOptionalColumns = [5, 6, 9]; +var execOptionalColumns = [5, 6, 9, 10]; var execDataTable; var sumDataTable; @@ -415,6 +415,7 @@ $(document).ready(function () { {data: 'diskUsed', render: formatBytes}, {data: 'totalCores'}, {name: 'resourcesCol', data: 'resources', render: formatResourceCells, orderable: false}, + {name: 'resourceProfileIdCol', data: 'resourceProfileId'}, { data: 'activeTasks', "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { @@ -461,7 +462,8 @@ $(document).ready(function () { "columnDefs": [ {"visible": false, "targets": 5}, {"visible": false, "targets": 6}, - {"visible": false, "targets": 9} + {"visible": false, "targets": 9}, + {"visible": false, "targets": 10} ], "deferRender": true }; @@ -570,6 +572,7 @@ $(document).ready(function () { "
On Heap Memory
" + "
Off Heap Memory
" + "
Resources
" + + "
Resource Profile Id
" + ""); reselectCheckboxesBasedOnTaskTableState(); diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index ae02defd9bb9c..474c453643365 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -173,9 +173,11 @@ function renderDagViz(forJob) { }); metadataContainer().selectAll(".barrier-rdd").each(function() { - var rddId = d3.select(this).text().trim(); - var clusterId = VizConstants.clusterPrefix + rddId; - svg.selectAll("g." + clusterId).classed("barrier", true) + var opId = d3.select(this).text().trim(); + var opClusterId = VizConstants.clusterPrefix + opId; + var stageId = $(this).parents(".stage-metadata").attr("stage-id"); + var stageClusterId = VizConstants.graphPrefix + stageId; + svg.selectAll("g[id=" + stageClusterId + "] g." + opClusterId).classed("barrier", true) }); resizeSvg(svg); @@ -216,7 +218,7 @@ function renderDagVizForJob(svgContainer) { var dot = metadata.select(".dot-file").text(); var stageId = metadata.attr("stage-id"); var containerId = VizConstants.graphPrefix + stageId; - var isSkipped = metadata.attr("skipped") == "true"; + var isSkipped = metadata.attr("skipped") === "true"; var container; if (isSkipped) { container = svgContainer @@ -225,11 +227,8 @@ function renderDagVizForJob(svgContainer) { .attr("skipped", "true"); } else { // Link each graph to the corresponding stage page (TODO: handle stage attempts) - // Use the link from the stage table so it also works for the history server var attemptId = 0; - var stageLink = d3.select("#stage-" + stageId + "-" + attemptId) - .select("a.name-link") - .attr("href"); + var stageLink = uiRoot + appBasePath + "/stages/stage/?id=" + stageId + "&attempt=" + attemptId; container = svgContainer .append("a") .attr("xlink:href", stageLink) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 0ba461f02317f..4f8409ca2b7c2 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -16,11 +16,16 @@ */ var uiRoot = ""; +var appBasePath = ""; function setUIRoot(val) { uiRoot = val; } +function setAppBasePath(path) { + appBasePath = path; +} + function collapseTablePageLoad(name, table){ if (window.localStorage.getItem(name) == "true") { // Set it to false so that the click function can revert it @@ -33,7 +38,7 @@ function collapseTable(thisName, table){ var status = window.localStorage.getItem(thisName) == "true"; status = !status; - var thisClass = '.' + thisName + var thisClass = '.' + thisName; // Expand the list of additional metrics. var tableDiv = $(thisClass).parent().find('.' 
+ table); diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 5663055129d19..04faf7f87cf2b 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -21,7 +21,7 @@ import java.util.{Timer, TimerTask} import java.util.concurrent.ConcurrentHashMap import java.util.function.Consumer -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashSet} import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} @@ -106,9 +106,11 @@ private[spark] class BarrierCoordinator( // The messages will be replied to all tasks once sync finished. private val messages = Array.ofDim[String](numTasks) - // The request method which is called inside this barrier sync. All tasks should make sure - // that they're calling the same method within the same barrier sync phase. - private var requestMethod: RequestMethod.Value = _ + // Request methods collected from tasks inside this barrier sync. All tasks should make sure + // that they're calling the same method within the same barrier sync phase. In other words, + // the size of requestMethods should always be 1 for a legitimate barrier sync. Otherwise, + // the barrier sync would fail if the size of requestMethods becomes greater than 1. + private val requestMethods = new HashSet[RequestMethod.Value] // A timer task that ensures we may timeout for a barrier() call. private var timerTask: TimerTask = null @@ -141,17 +143,14 @@ private[spark] class BarrierCoordinator( val taskId = request.taskAttemptId val epoch = request.barrierEpoch val curReqMethod = request.requestMethod - - if (requesters.isEmpty) { - requestMethod = curReqMethod - } else if (requestMethod != curReqMethod) { - requesters.foreach( - _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " + - s"`$curReqMethod` during barrier epoch $barrierEpoch, which does not match " + - s"the current synchronized requestMethod `$requestMethod`" - )) - ) - cleanupBarrierStage(barrierId) + requestMethods.add(curReqMethod) + if (requestMethods.size > 1) { + val error = new SparkException(s"Different barrier sync types found for the " + + s"sync $barrierId: ${requestMethods.mkString(", ")}. Please use the " + + s"same barrier sync type within a single sync.") + (requesters :+ requester).foreach(_.sendFailure(error)) + clear() + return } // Require the number of tasks is correctly set from the BarrierTaskContext. @@ -184,6 +183,7 @@ private[spark] class BarrierCoordinator( s"tasks, finished successfully.") barrierEpoch += 1 requesters.clear() + requestMethods.clear() cancelTimerTask() } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5c92527b7b80e..38d7319b1f0ef 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -435,7 +435,7 @@ class SparkContext(config: SparkConf) extends Logging { } _listenerBus = new LiveListenerBus(_conf) - _resourceProfileManager = new ResourceProfileManager(_conf) + _resourceProfileManager = new ResourceProfileManager(_conf, _listenerBus) // Initialize the app status store and listener before SparkEnv is created so that it gets // all events. 
diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index 81c087e314be1..41382133bd84c 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -48,5 +48,5 @@ private[spark] case class ExecutorDeadException(message: String) * Exception thrown when Spark returns different result after upgrading to a new version. */ private[spark] class SparkUpgradeException(version: String, message: String, cause: Throwable) - extends SparkException("You may get a different result due to the upgrading of Spark" + + extends RuntimeException("You may get a different result due to the upgrading of Spark" + s" $version: $message", cause) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 99d3eceb1121a..25ea75acc37d3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -108,7 +108,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val historyUiAdminAclsGroups = conf.get(History.HISTORY_SERVER_UI_ADMIN_ACLS_GROUPS) logInfo(s"History server ui acls " + (if (historyUiAclsEnable) "enabled" else "disabled") + "; users with admin permissions: " + historyUiAdminAcls.mkString(",") + - "; groups with admin permissions" + historyUiAdminAclsGroups.mkString(",")) + "; groups with admin permissions: " + historyUiAdminAclsGroups.mkString(",")) private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) // Visible for testing diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryAppStatusStore.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryAppStatusStore.scala index 741050027fc6b..7973652b3e254 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryAppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryAppStatusStore.scala @@ -72,7 +72,8 @@ private[spark] class HistoryAppStatusStore( source.totalGCTime, source.totalInputBytes, source.totalShuffleRead, source.totalShuffleWrite, source.isBlacklisted, source.maxMemory, source.addTime, source.removeTime, source.removeReason, newExecutorLogs, source.memoryMetrics, - source.blacklistedInStages, source.peakMemoryMetrics, source.attributes, source.resources) + source.blacklistedInStages, source.peakMemoryMetrics, source.attributes, source.resources, + source.resourceProfileId) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index f9f0308aa5138..aa9e9a6dd4887 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -69,6 +69,9 @@ class HistoryServer( private val loaderServlet = new HttpServlet { protected override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = { + + res.setContentType("text/html;charset=utf-8") + // Parse the URI created by getAttemptURI(). It contains an app ID and an optional // attempt ID (separated by a slash). 
val parts = Option(req.getPathInfo()).getOrElse("").split("/") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala index 0a1f33395ad62..b1adc3c112ed3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala @@ -122,10 +122,12 @@ private class HistoryServerDiskManager( * being used so that it's not evicted when running out of designated space. */ def openStore(appId: String, attemptId: Option[String]): Option[File] = { + var newSize: Long = 0 val storePath = active.synchronized { val path = appStorePath(appId, attemptId) if (path.isDirectory()) { - active(appId -> attemptId) = sizeOf(path) + newSize = sizeOf(path) + active(appId -> attemptId) = newSize Some(path) } else { None @@ -133,7 +135,7 @@ private class HistoryServerDiskManager( } storePath.foreach { path => - updateAccessTime(appId, attemptId) + updateApplicationStoreInfo(appId, attemptId, newSize) } storePath @@ -238,10 +240,11 @@ private class HistoryServerDiskManager( new File(appStoreDir, fileName) } - private def updateAccessTime(appId: String, attemptId: Option[String]): Unit = { + private def updateApplicationStoreInfo( + appId: String, attemptId: Option[String], newSize: Long): Unit = { val path = appStorePath(appId, attemptId) - val info = ApplicationStoreInfo(path.getAbsolutePath(), clock.getTimeMillis(), appId, attemptId, - sizeOf(path)) + val info = ApplicationStoreInfo(path.getAbsolutePath(), clock.getTimeMillis(), appId, + attemptId, newSize) listing.write(info) } @@ -297,7 +300,7 @@ private class HistoryServerDiskManager( s"exceeded ($current > $max)") } - updateAccessTime(appId, attemptId) + updateApplicationStoreInfo(appId, attemptId, newSize) active.synchronized { active(appId -> attemptId) = newSize diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2bfa1cea4b26f..45cec726c4ca7 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -34,6 +34,7 @@ import scala.concurrent.duration._ import scala.util.control.NonFatal import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.slf4j.MDC import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -320,7 +321,12 @@ private[spark] class Executor( val taskId = taskDescription.taskId val threadName = s"Executor task launch worker for task $taskId" - private val taskName = taskDescription.name + val taskName = taskDescription.name + val mdcProperties = taskDescription.properties.asScala + .filter(_._1.startsWith("mdc.")).map { item => + val key = item._1.substring(4) + (key, item._2) + }.toSeq /** If specified, this task has been killed and this option contains the reason. 
*/ @volatile private var reasonIfKilled: Option[String] = None @@ -395,6 +401,9 @@ private[spark] class Executor( } override def run(): Unit = { + + setMDCForTask(taskName, mdcProperties) + threadId = Thread.currentThread.getId Thread.currentThread.setName(threadName) val threadMXBean = ManagementFactory.getThreadMXBean @@ -693,6 +702,14 @@ private[spark] class Executor( } } + private def setMDCForTask(taskName: String, mdc: Seq[(String, String)]): Unit = { + MDC.put("taskName", taskName) + + mdc.foreach { case (key, value) => + MDC.put(key, value) + } + } + /** * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally * sending a Thread.interrupt(), and monitoring the task until it finishes. @@ -733,6 +750,9 @@ private[spark] class Executor( private[this] val takeThreadDump: Boolean = conf.get(TASK_REAPER_THREAD_DUMP) override def run(): Unit = { + + setMDCForTask(taskRunner.taskName, taskRunner.mdcProperties) + val startTimeNs = System.nanoTime() def elapsedTimeNs = System.nanoTime() - startTimeNs def timeoutExceeded(): Boolean = killTimeoutNs > 0 && elapsedTimeNs > killTimeoutNs diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4cda4b180d97d..8ef0c37198568 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -413,6 +413,34 @@ package object config { .intConf .createWithDefault(1) + private[spark] val STORAGE_DECOMMISSION_ENABLED = + ConfigBuilder("spark.storage.decommission.enabled") + .doc("Whether to decommission the block manager when decommissioning executor") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + private[spark] val STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK = + ConfigBuilder("spark.storage.decommission.maxReplicationFailuresPerBlock") + .internal() + .doc("Maximum number of failures which can be handled for the replication of " + + "one RDD block when block manager is decommissioning and trying to move its " + + "existing blocks.") + .version("3.1.0") + .intConf + .createWithDefault(3) + + private[spark] val STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL = + ConfigBuilder("spark.storage.decommission.replicationReattemptInterval") + .internal() + .doc("The interval of time between consecutive cache block replication reattempts " + + "happening on each decommissioning executor (due to storage decommissioning).") + .version("3.1.0") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(_ > 0, "Time interval between two consecutive attempts of " + + "cache block replication should be positive.") + .createWithDefaultString("30s") + private[spark] val STORAGE_REPLICATION_TOPOLOGY_FILE = ConfigBuilder("spark.storage.replication.topologyFile") .version("2.1.0") diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala index 7c33bce78378d..59b863b89f75a 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala @@ -24,15 +24,18 @@ import com.codahale.metrics.MetricRegistry import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.annotation.Experimental import org.apache.spark.ui.JettyUtils._ /** + * :: Experimental :: * This 
exposes the metrics of the given registry with Prometheus format. * * The output is consistent with /metrics/json result in terms of item ordering * and with the previous result of Spark JMX Sink + Prometheus JMX Converter combination * in terms of key string format. */ +@Experimental private[spark] class PrometheusServlet( val property: Properties, val registry: MetricRegistry, @@ -53,58 +56,65 @@ private[spark] class PrometheusServlet( def getMetricsSnapshot(request: HttpServletRequest): String = { import scala.collection.JavaConverters._ + val guagesLabel = """{type="gauges"}""" + val countersLabel = """{type="counters"}""" + val metersLabel = countersLabel + val histogramslabels = """{type="histograms"}""" + val timersLabels = """{type="timers"}""" + val sb = new StringBuilder() registry.getGauges.asScala.foreach { case (k, v) => if (!v.getValue.isInstanceOf[String]) { - sb.append(s"${normalizeKey(k)}Value ${v.getValue}\n") + sb.append(s"${normalizeKey(k)}Number$guagesLabel ${v.getValue}\n") + sb.append(s"${normalizeKey(k)}Value$guagesLabel ${v.getValue}\n") } } registry.getCounters.asScala.foreach { case (k, v) => - sb.append(s"${normalizeKey(k)}Count ${v.getCount}\n") + sb.append(s"${normalizeKey(k)}Count$countersLabel ${v.getCount}\n") } registry.getHistograms.asScala.foreach { case (k, h) => val snapshot = h.getSnapshot val prefix = normalizeKey(k) - sb.append(s"${prefix}Count ${h.getCount}\n") - sb.append(s"${prefix}Max ${snapshot.getMax}\n") - sb.append(s"${prefix}Mean ${snapshot.getMean}\n") - sb.append(s"${prefix}Min ${snapshot.getMin}\n") - sb.append(s"${prefix}50thPercentile ${snapshot.getMedian}\n") - sb.append(s"${prefix}75thPercentile ${snapshot.get75thPercentile}\n") - sb.append(s"${prefix}95thPercentile ${snapshot.get95thPercentile}\n") - sb.append(s"${prefix}98thPercentile ${snapshot.get98thPercentile}\n") - sb.append(s"${prefix}99thPercentile ${snapshot.get99thPercentile}\n") - sb.append(s"${prefix}999thPercentile ${snapshot.get999thPercentile}\n") - sb.append(s"${prefix}StdDev ${snapshot.getStdDev}\n") + sb.append(s"${prefix}Count$histogramslabels ${h.getCount}\n") + sb.append(s"${prefix}Max$histogramslabels ${snapshot.getMax}\n") + sb.append(s"${prefix}Mean$histogramslabels ${snapshot.getMean}\n") + sb.append(s"${prefix}Min$histogramslabels ${snapshot.getMin}\n") + sb.append(s"${prefix}50thPercentile$histogramslabels ${snapshot.getMedian}\n") + sb.append(s"${prefix}75thPercentile$histogramslabels ${snapshot.get75thPercentile}\n") + sb.append(s"${prefix}95thPercentile$histogramslabels ${snapshot.get95thPercentile}\n") + sb.append(s"${prefix}98thPercentile$histogramslabels ${snapshot.get98thPercentile}\n") + sb.append(s"${prefix}99thPercentile$histogramslabels ${snapshot.get99thPercentile}\n") + sb.append(s"${prefix}999thPercentile$histogramslabels ${snapshot.get999thPercentile}\n") + sb.append(s"${prefix}StdDev$histogramslabels ${snapshot.getStdDev}\n") } registry.getMeters.entrySet.iterator.asScala.foreach { kv => val prefix = normalizeKey(kv.getKey) val meter = kv.getValue - sb.append(s"${prefix}Count ${meter.getCount}\n") - sb.append(s"${prefix}MeanRate ${meter.getMeanRate}\n") - sb.append(s"${prefix}OneMinuteRate ${meter.getOneMinuteRate}\n") - sb.append(s"${prefix}FiveMinuteRate ${meter.getFiveMinuteRate}\n") - sb.append(s"${prefix}FifteenMinuteRate ${meter.getFifteenMinuteRate}\n") + sb.append(s"${prefix}Count$metersLabel ${meter.getCount}\n") + sb.append(s"${prefix}MeanRate$metersLabel ${meter.getMeanRate}\n") + sb.append(s"${prefix}OneMinuteRate$metersLabel 
${meter.getOneMinuteRate}\n") + sb.append(s"${prefix}FiveMinuteRate$metersLabel ${meter.getFiveMinuteRate}\n") + sb.append(s"${prefix}FifteenMinuteRate$metersLabel ${meter.getFifteenMinuteRate}\n") } registry.getTimers.entrySet.iterator.asScala.foreach { kv => val prefix = normalizeKey(kv.getKey) val timer = kv.getValue val snapshot = timer.getSnapshot - sb.append(s"${prefix}Count ${timer.getCount}\n") - sb.append(s"${prefix}Max ${snapshot.getMax}\n") - sb.append(s"${prefix}Mean ${snapshot.getMax}\n") - sb.append(s"${prefix}Min ${snapshot.getMin}\n") - sb.append(s"${prefix}50thPercentile ${snapshot.getMedian}\n") - sb.append(s"${prefix}75thPercentile ${snapshot.get75thPercentile}\n") - sb.append(s"${prefix}95thPercentile ${snapshot.get95thPercentile}\n") - sb.append(s"${prefix}98thPercentile ${snapshot.get98thPercentile}\n") - sb.append(s"${prefix}99thPercentile ${snapshot.get99thPercentile}\n") - sb.append(s"${prefix}999thPercentile ${snapshot.get999thPercentile}\n") - sb.append(s"${prefix}StdDev ${snapshot.getStdDev}\n") - sb.append(s"${prefix}FifteenMinuteRate ${timer.getFifteenMinuteRate}\n") - sb.append(s"${prefix}FiveMinuteRate ${timer.getFiveMinuteRate}\n") - sb.append(s"${prefix}OneMinuteRate ${timer.getOneMinuteRate}\n") - sb.append(s"${prefix}MeanRate ${timer.getMeanRate}\n") + sb.append(s"${prefix}Count$timersLabels ${timer.getCount}\n") + sb.append(s"${prefix}Max$timersLabels ${snapshot.getMax}\n") + sb.append(s"${prefix}Mean$timersLabels ${snapshot.getMax}\n") + sb.append(s"${prefix}Min$timersLabels ${snapshot.getMin}\n") + sb.append(s"${prefix}50thPercentile$timersLabels ${snapshot.getMedian}\n") + sb.append(s"${prefix}75thPercentile$timersLabels ${snapshot.get75thPercentile}\n") + sb.append(s"${prefix}95thPercentile$timersLabels ${snapshot.get95thPercentile}\n") + sb.append(s"${prefix}98thPercentile$timersLabels ${snapshot.get98thPercentile}\n") + sb.append(s"${prefix}99thPercentile$timersLabels ${snapshot.get99thPercentile}\n") + sb.append(s"${prefix}999thPercentile$timersLabels ${snapshot.get999thPercentile}\n") + sb.append(s"${prefix}StdDev$timersLabels ${snapshot.getStdDev}\n") + sb.append(s"${prefix}FifteenMinuteRate$timersLabels ${timer.getFifteenMinuteRate}\n") + sb.append(s"${prefix}FiveMinuteRate$timersLabels ${timer.getFiveMinuteRate}\n") + sb.append(s"${prefix}OneMinuteRate$timersLabels ${timer.getOneMinuteRate}\n") + sb.append(s"${prefix}MeanRate$timersLabels ${timer.getMeanRate}\n") } sb.toString() } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 39068a5ab046d..6095042de7f0c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1780,10 +1780,9 @@ abstract class RDD[T: ClassTag]( * It will result in new executors with the resources specified being acquired to * calculate the RDD. 
*/ - // PRIVATE for now, added for testing purposes, will be made public with SPARK-29150 @Experimental - @Since("3.0.0") - private[spark] def withResources(rp: ResourceProfile): this.type = { + @Since("3.1.0") + def withResources(rp: ResourceProfile): this.type = { resourceProfile = Option(rp) sc.resourceProfileManager.addResourceProfile(resourceProfile.get) this @@ -1794,10 +1793,9 @@ abstract class RDD[T: ClassTag]( * @return the user specified ResourceProfile or null (for Java compatibility) if * none was specified */ - // PRIVATE for now, added for testing purposes, will be made public with SPARK-29150 @Experimental - @Since("3.0.0") - private[spark] def getResourceProfile(): ResourceProfile = resourceProfile.getOrElse(null) + @Since("3.1.0") + def getResourceProfile(): ResourceProfile = resourceProfile.getOrElse(null) // ======================================================================= // Other internal methods and fields diff --git a/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequest.scala b/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequest.scala index 9a920914ed674..3e3db7e8c8910 100644 --- a/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequest.scala +++ b/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequest.scala @@ -17,6 +17,8 @@ package org.apache.spark.resource +import org.apache.spark.annotation.{Evolving, Since} + /** * An Executor resource request. This is used in conjunction with the ResourceProfile to * programmatically specify the resources needed for an RDD that will be applied at the @@ -46,11 +48,10 @@ package org.apache.spark.resource * allocated. The script runs on Executors startup to discover the addresses * of the resources available. * @param vendor Optional vendor, required for some cluster managers - * - * This api is currently private until the rest of the pieces are in place and then it - * will become public. */ -private[spark] class ExecutorResourceRequest( +@Evolving +@Since("3.1.0") +class ExecutorResourceRequest( val resourceName: String, val amount: Long, val discoveryScript: String = "", diff --git a/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequests.scala b/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequests.scala index 4ee1a07564042..9da6ffb1d2577 100644 --- a/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequests.scala +++ b/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequests.scala @@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ +import org.apache.spark.annotation.{Evolving, Since} import org.apache.spark.network.util.JavaUtils import org.apache.spark.resource.ResourceProfile._ @@ -29,11 +30,10 @@ import org.apache.spark.resource.ResourceProfile._ * A set of Executor resource requests. This is used in conjunction with the ResourceProfile to * programmatically specify the resources needed for an RDD that will be applied at the * stage level. - * - * This api is currently private until the rest of the pieces are in place and then it - * will become public. 
*/ -private[spark] class ExecutorResourceRequests() extends Serializable { +@Evolving +@Since("3.1.0") +class ExecutorResourceRequests() extends Serializable { private val _executorResources = new ConcurrentHashMap[String, ExecutorResourceRequest]() diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceDiscoveryScriptPlugin.scala b/core/src/main/scala/org/apache/spark/resource/ResourceDiscoveryScriptPlugin.scala index 7027d1e3511b5..11a9bb86d3034 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceDiscoveryScriptPlugin.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceDiscoveryScriptPlugin.scala @@ -32,6 +32,8 @@ import org.apache.spark.util.Utils.executeAndGetOutput * and gets the json output back and contructs ResourceInformation objects from that. * If the user specifies custom plugins, this is the last one to be executed and * throws if the resource isn't discovered. + * + * @since 3.0.0 */ @DeveloperApi class ResourceDiscoveryScriptPlugin extends ResourceDiscoveryPlugin with Logging { diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala b/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala index d5ac41b995559..be056e15b6d03 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala @@ -33,6 +33,8 @@ import org.apache.spark.annotation.Evolving * * @param name the name of the resource * @param addresses an array of strings describing the addresses of the resource + * + * @since 3.0.0 */ @Evolving class ResourceInformation( diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala index 761c4450ca5f1..1dbdc3d81e44d 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.annotation.Evolving +import org.apache.spark.annotation.{Evolving, Since} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Python.PYSPARK_EXECUTOR_MEMORY @@ -37,6 +37,7 @@ import org.apache.spark.internal.config.Python.PYSPARK_EXECUTOR_MEMORY * This is meant to be immutable so user can't change it after building. */ @Evolving +@Since("3.1.0") class ResourceProfile( val executorResources: Map[String, ExecutorResourceRequest], val taskResources: Map[String, TaskResourceRequest]) extends Serializable with Logging { diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfileBuilder.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfileBuilder.scala index 26f23f4bf0476..29a117b47fe95 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfileBuilder.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfileBuilder.scala @@ -22,16 +22,19 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Evolving +import org.apache.spark.annotation.{Evolving, Since} + /** * Resource profile builder to build a Resource profile to associate with an RDD. * A ResourceProfile allows the user to specify executor and task requirements for an RDD * that will get applied during a stage. 
This allows the user to change the resource * requirements between stages. + * */ @Evolving -private[spark] class ResourceProfileBuilder() { +@Since("3.1.0") +class ResourceProfileBuilder() { private val _taskResources = new ConcurrentHashMap[String, TaskResourceRequest]() private val _executorResources = new ConcurrentHashMap[String, ExecutorResourceRequest]() diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala index c3e244474a692..f365548c75359 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala @@ -25,17 +25,19 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Tests._ +import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerResourceProfileAdded} import org.apache.spark.util.Utils import org.apache.spark.util.Utils.isTesting /** * Manager of resource profiles. The manager allows one place to keep the actual ResourceProfiles * and everywhere else we can use the ResourceProfile Id to save on space. - * Note we never remove a resource profile at this point. Its expected this number if small + * Note we never remove a resource profile at this point. Its expected this number is small * so this shouldn't be much overhead. */ @Evolving -private[spark] class ResourceProfileManager(sparkConf: SparkConf) extends Logging { +private[spark] class ResourceProfileManager(sparkConf: SparkConf, + listenerBus: LiveListenerBus) extends Logging { private val resourceProfileIdToResourceProfile = new HashMap[Int, ResourceProfile]() private val (readLock, writeLock) = { @@ -83,6 +85,7 @@ private[spark] class ResourceProfileManager(sparkConf: SparkConf) extends Loggin // force the computation of maxTasks and limitingResource now so we don't have cost later rp.limitingResource(sparkConf) logInfo(s"Added ResourceProfile id: ${rp.id}") + listenerBus.post(SparkListenerResourceProfileAdded(rp)) } } diff --git a/core/src/main/scala/org/apache/spark/resource/TaskResourceRequest.scala b/core/src/main/scala/org/apache/spark/resource/TaskResourceRequest.scala index bffb0a2f523b1..d3f979fa8672f 100644 --- a/core/src/main/scala/org/apache/spark/resource/TaskResourceRequest.scala +++ b/core/src/main/scala/org/apache/spark/resource/TaskResourceRequest.scala @@ -17,17 +17,18 @@ package org.apache.spark.resource +import org.apache.spark.annotation.{Evolving, Since} + /** * A task resource request. This is used in conjuntion with the ResourceProfile to * programmatically specify the resources needed for an RDD that will be applied at the * stage level. * * Use TaskResourceRequests class as a convenience API. - * - * This api is currently private until the rest of the pieces are in place and then it - * will become public. 
*/ -private[spark] class TaskResourceRequest(val resourceName: String, val amount: Double) +@Evolving +@Since("3.1.0") +class TaskResourceRequest(val resourceName: String, val amount: Double) extends Serializable { assert(amount <= 0.5 || amount % 1 == 0, diff --git a/core/src/main/scala/org/apache/spark/resource/TaskResourceRequests.scala b/core/src/main/scala/org/apache/spark/resource/TaskResourceRequests.scala index 09f4e02eee9e0..b4e70b3b046ce 100644 --- a/core/src/main/scala/org/apache/spark/resource/TaskResourceRequests.scala +++ b/core/src/main/scala/org/apache/spark/resource/TaskResourceRequests.scala @@ -22,17 +22,17 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ +import org.apache.spark.annotation.{Evolving, Since} import org.apache.spark.resource.ResourceProfile._ /** * A set of task resource requests. This is used in conjunction with the ResourceProfile to * programmatically specify the resources needed for an RDD that will be applied at the * stage level. - * - * This api is currently private until the rest of the pieces are in place and then it - * will become public. */ -private[spark] class TaskResourceRequests() extends Serializable { +@Evolving +@Since("3.1.0") +class TaskResourceRequests() extends Serializable { private val _taskResources = new ConcurrentHashMap[String, TaskResourceRequest]() diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 24e2a5e4d4a62..b2e9a0b2a04e8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -235,6 +235,10 @@ private[spark] class EventLoggingListener( } } + override def onResourceProfileAdded(event: SparkListenerResourceProfileAdded): Unit = { + logEvent(event, flushLogger = true) + } + override def onOtherEvent(event: SparkListenerEvent): Unit = { if (event.logEvent) { logEvent(event, flushLogger = true) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index c150b0341500c..62d54f3b74a47 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -27,6 +27,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo import org.apache.spark.TaskEndReason import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} +import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} @@ -207,6 +208,10 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent @DeveloperApi case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerResourceProfileAdded(resourceProfile: ResourceProfile) + extends SparkListenerEvent + /** * Interface for listening to events from the Spark scheduler. Most applications should probably * extend SparkListener or SparkFirehoseListener directly, rather than implementing this class. @@ -348,6 +353,11 @@ private[spark] trait SparkListenerInterface { * Called when other events like SQL-specific events are posted. 
*/ def onOtherEvent(event: SparkListenerEvent): Unit + + /** + * Called when a Resource Profile is added to the manager. + */ + def onResourceProfileAdded(event: SparkListenerResourceProfileAdded): Unit } @@ -421,4 +431,6 @@ abstract class SparkListener extends SparkListenerInterface { speculativeTask: SparkListenerSpeculativeTaskSubmitted): Unit = { } override def onOtherEvent(event: SparkListenerEvent): Unit = { } + + override def onResourceProfileAdded(event: SparkListenerResourceProfileAdded): Unit = { } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 8f6b7ad309602..3d316c948db7e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -79,6 +79,8 @@ private[spark] trait SparkListenerBus listener.onBlockUpdated(blockUpdated) case speculativeTaskSubmitted: SparkListenerSpeculativeTaskSubmitted => listener.onSpeculativeTaskSubmitted(speculativeTaskSubmitted) + case resourceProfileAdded: SparkListenerResourceProfileAdded => + listener.onResourceProfileAdded(resourceProfileAdded) case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a0e84b94735ec..a302f680a272e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -1107,10 +1107,19 @@ private[spark] class TaskSetManager( def recomputeLocality(): Unit = { // A zombie TaskSetManager may reach here while executorLost happens if (isZombie) return + val previousLocalityIndex = currentLocalityIndex val previousLocalityLevel = myLocalityLevels(currentLocalityIndex) + val previousMyLocalityLevels = myLocalityLevels myLocalityLevels = computeValidLocalityLevels() localityWaits = myLocalityLevels.map(getLocalityWait) currentLocalityIndex = getLocalityIndex(previousLocalityLevel) + if (currentLocalityIndex > previousLocalityIndex) { + // SPARK-31837: If the new level is more local, shift to the new most local locality + // level in terms of better data locality. For example, say the previous locality + // levels are [PROCESS, NODE, ANY] and current level is ANY. After recompute, the + // locality levels are [PROCESS, NODE, RACK, ANY]. Then, we'll shift to RACK level. 
+ currentLocalityIndex = getLocalityIndex(myLocalityLevels.diff(previousMyLocalityLevels).head) + } } def executorAdded(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 701d69ba43498..67638a5f9593c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -438,6 +438,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logError(s"Unexpected error during decommissioning ${e.toString}", e) } logInfo(s"Finished decommissioning executor $executorId.") + + if (conf.get(STORAGE_DECOMMISSION_ENABLED)) { + try { + logInfo("Starting decommissioning block manager corresponding to " + + s"executor $executorId.") + scheduler.sc.env.blockManager.master.decommissionBlockManagers(Seq(executorId)) + } catch { + case e: Exception => + logError("Unexpected error during block manager " + + s"decommissioning for executor $executorId: ${e.toString}", e) + } + logInfo(s"Acknowledged decommissioning block manager corresponding to $executorId.") + } } else { logInfo(s"Skipping decommissioning of executor $executorId.") } @@ -574,7 +587,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp */ private[spark] def decommissionExecutor(executorId: String): Unit = { if (driverEndpoint != null) { - logInfo("Propegating executor decommission to driver.") + logInfo("Propagating executor decommission to driver.") driverEndpoint.send(DecommissionExecutor(executorId)) } } @@ -658,7 +671,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Update the cluster manager on our scheduling needs. Three bits of information are included * to help it make decisions. - * @param resourceProfileToNumExecutors The total number of executors we'd like to have per + * @param resourceProfileIdToNumExecutors The total number of executors we'd like to have per * ResourceProfile. The cluster manager shouldn't kill any * running executor to reach this number, but, if all * existing executors were to die, this is the number diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index c3f22f32993a8..f7b0e9b62fc29 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -28,6 +28,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging import org.apache.spark.internal.config.CPUS_PER_TASK import org.apache.spark.internal.config.Status._ +import org.apache.spark.resource.ResourceProfile.CPUS import org.apache.spark.scheduler._ import org.apache.spark.status.api.v1 import org.apache.spark.storage._ @@ -51,7 +52,7 @@ private[spark] class AppStatusListener( private var sparkVersion = SPARK_VERSION private var appInfo: v1.ApplicationInfo = null private var appSummary = new AppSummary(0, 0) - private var coresPerTask: Int = 1 + private var defaultCpusPerTask: Int = 1 // How often to update live entities. -1 means "never update" when replaying applications, // meaning only the last write will happen. 
For live applications, this avoids a few @@ -76,6 +77,7 @@ private[spark] class AppStatusListener( private val liveTasks = new HashMap[Long, LiveTask]() private val liveRDDs = new HashMap[Int, LiveRDD]() private val pools = new HashMap[String, SchedulerPool]() + private val liveResourceProfiles = new HashMap[Int, LiveResourceProfile]() private val SQL_EXECUTION_ID_KEY = "spark.sql.execution.id" // Keep the active executor count as a separate variable to avoid having to do synchronization @@ -145,6 +147,20 @@ private[spark] class AppStatusListener( } } + override def onResourceProfileAdded(event: SparkListenerResourceProfileAdded): Unit = { + val maxTasks = if (event.resourceProfile.isCoresLimitKnown) { + Some(event.resourceProfile.maxTasksPerExecutor(conf)) + } else { + None + } + val liveRP = new LiveResourceProfile(event.resourceProfile.id, + event.resourceProfile.executorResources, event.resourceProfile.taskResources, maxTasks) + liveResourceProfiles(event.resourceProfile.id) = liveRP + val rpInfo = new v1.ResourceProfileInfo(liveRP.resourceProfileId, + liveRP.executorResources, liveRP.taskResources) + kvstore.write(new ResourceProfileWrapper(rpInfo)) + } + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { val details = event.environmentDetails @@ -159,10 +175,11 @@ private[spark] class AppStatusListener( details.getOrElse("Spark Properties", Nil), details.getOrElse("Hadoop Properties", Nil), details.getOrElse("System Properties", Nil), - details.getOrElse("Classpath Entries", Nil)) + details.getOrElse("Classpath Entries", Nil), + Nil) - coresPerTask = envInfo.sparkProperties.toMap.get(CPUS_PER_TASK.key).map(_.toInt) - .getOrElse(coresPerTask) + defaultCpusPerTask = envInfo.sparkProperties.toMap.get(CPUS_PER_TASK.key).map(_.toInt) + .getOrElse(defaultCpusPerTask) kvstore.write(new ApplicationEnvironmentInfoWrapper(envInfo)) } @@ -197,10 +214,16 @@ private[spark] class AppStatusListener( exec.host = event.executorInfo.executorHost exec.isActive = true exec.totalCores = event.executorInfo.totalCores - exec.maxTasks = event.executorInfo.totalCores / coresPerTask + val rpId = event.executorInfo.resourceProfileId + val liveRP = liveResourceProfiles.get(rpId) + val cpusPerTask = liveRP.flatMap(_.taskResources.get(CPUS)) + .map(_.amount.toInt).getOrElse(defaultCpusPerTask) + val maxTasksPerExec = liveRP.flatMap(_.maxTasksPerExecutor) + exec.maxTasks = maxTasksPerExec.getOrElse(event.executorInfo.totalCores / cpusPerTask) exec.executorLogs = event.executorInfo.logUrlMap exec.resources = event.executorInfo.resourcesInfo exec.attributes = event.executorInfo.attributes + exec.resourceProfileId = rpId liveUpdate(exec, System.nanoTime()) } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 6b89812cc2bf0..ea033d0c890ac 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -22,7 +22,8 @@ import java.util.{List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap -import org.apache.spark.{JobExecutionStatus, SparkConf} +import org.apache.spark.{JobExecutionStatus, SparkConf, SparkException} +import org.apache.spark.resource.ResourceProfileManager import org.apache.spark.status.api.v1 import org.apache.spark.ui.scope._ import org.apache.spark.util.Utils @@ -36,7 +37,14 @@ private[spark] class AppStatusStore( val listener: 
Option[AppStatusListener] = None) { def applicationInfo(): v1.ApplicationInfo = { - store.view(classOf[ApplicationInfoWrapper]).max(1).iterator().next().info + try { + // The ApplicationInfo may not be available when Spark is starting up. + store.view(classOf[ApplicationInfoWrapper]).max(1).iterator().next().info + } catch { + case _: NoSuchElementException => + throw new SparkException("Failed to get the application information. " + + "If you are starting up Spark, please wait a while until it's ready.") + } } def environmentInfo(): v1.ApplicationEnvironmentInfo = { @@ -44,6 +52,10 @@ private[spark] class AppStatusStore( store.read(klass, klass.getName()).info } + def resourceProfileInfo(): Seq[v1.ResourceProfileInfo] = { + store.view(classOf[ResourceProfileWrapper]).asScala.map(_.rpInfo).toSeq + } + def jobsList(statuses: JList[JobExecutionStatus]): Seq[v1.JobData] = { val it = store.view(classOf[JobDataWrapper]).reverse().asScala.map(_.info) if (statuses != null && !statuses.isEmpty()) { @@ -479,7 +491,8 @@ private[spark] class AppStatusStore( accumulatorUpdates = stage.accumulatorUpdates, tasks = Some(tasks), executorSummary = Some(executorSummary(stage.stageId, stage.attemptId)), - killedTasksSummary = stage.killedTasksSummary) + killedTasksSummary = stage.killedTasksSummary, + resourceProfileId = stage.resourceProfileId) } def rdd(rddId: Int): v1.RDDStorageInfo = { diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 2714f30de14f0..86cb4fe138773 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -28,7 +28,7 @@ import com.google.common.collect.Interners import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} -import org.apache.spark.resource.ResourceInformation +import org.apache.spark.resource.{ExecutorResourceRequest, ResourceInformation, ResourceProfile, TaskResourceRequest} import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} import org.apache.spark.status.api.v1 import org.apache.spark.storage.{RDDInfo, StorageLevel} @@ -245,6 +245,21 @@ private class LiveTask( } +private class LiveResourceProfile( + val resourceProfileId: Int, + val executorResources: Map[String, ExecutorResourceRequest], + val taskResources: Map[String, TaskResourceRequest], + val maxTasksPerExecutor: Option[Int]) extends LiveEntity { + + def toApi(): v1.ResourceProfileInfo = { + new v1.ResourceProfileInfo(resourceProfileId, executorResources, taskResources) + } + + override protected def doUpdate(): Any = { + new ResourceProfileWrapper(toApi()) + } +} + private[spark] class LiveExecutor(val executorId: String, _addTime: Long) extends LiveEntity { var hostPort: String = null @@ -285,6 +300,8 @@ private[spark] class LiveExecutor(val executorId: String, _addTime: Long) extend var usedOnHeap = 0L var usedOffHeap = 0L + var resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID + def hasMemoryInfo: Boolean = totalOnHeap >= 0L // peak values for executor level metrics @@ -327,7 +344,8 @@ private[spark] class LiveExecutor(val executorId: String, _addTime: Long) extend blacklistedInStages, Some(peakExecutorMetrics).filter(_.isSet), attributes, - resources) + resources, + resourceProfileId) new ExecutorSummaryWrapper(info) } } @@ -465,7 +483,8 @@ private class LiveStage extends LiveEntity { accumulatorUpdates = newAccumulatorInfos(info.accumulables.values), 
tasks = None, executorSummary = None, - killedTasksSummary = killedSummary) + killedTasksSummary = killedSummary, + resourceProfileId = info.resourceProfileId) } override protected def doUpdate(): Any = { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index cf5c759bebdbb..e0c85fdf6fb5d 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -101,12 +101,14 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { @Path("environment") def environmentInfo(): ApplicationEnvironmentInfo = withUI { ui => val envInfo = ui.store.environmentInfo() + val resourceProfileInfo = ui.store.resourceProfileInfo() new v1.ApplicationEnvironmentInfo( envInfo.runtime, Utils.redact(ui.conf, envInfo.sparkProperties), Utils.redact(ui.conf, envInfo.hadoopProperties), Utils.redact(ui.conf, envInfo.systemProperties), - envInfo.classpathEntries) + envInfo.classpathEntries, + resourceProfileInfo) } @GET diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala index f9fb78e65a3d9..9658e5e627724 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala @@ -23,15 +23,19 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.glassfish.jersey.server.ServerProperties import org.glassfish.jersey.servlet.ServletContainer +import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT} +import org.apache.spark.annotation.Experimental import org.apache.spark.ui.SparkUI /** + * :: Experimental :: * This aims to expose Executor metrics like REST API which is documented in * * https://spark.apache.org/docs/3.0.0/monitoring.html#executor-metrics * * Note that this is based on ExecutorSummary which is different from ExecutorSource. 
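The hunks below mark this endpoint experimental, add a spark_info line, and rename the executor metrics to Prometheus conventions: gauges drop the misleading _Count suffix, counters gain _total, byte values gain _bytes, and millisecond durations are exposed in seconds. A standalone sketch of the resulting scrape-line format, with invented values and the label construction mirroring that code:

// Values are placeholders; the formatting mirrors the sb.append calls in the hunks below.
val labels = Seq(
  "application_id" -> "app-20200527000000-0000",
  "application_name" -> "example",
  "executor_id" -> "1"
).map { case (k, v) => s"""$k="$v"""" }.mkString("{", ", ", "}")

val prefix = "metrics_executor_"
println(s"${prefix}memoryUsed_bytes$labels 1048576")
println(s"${prefix}totalDuration_seconds_total$labels ${123456L * 0.001}")
// prints e.g. metrics_executor_memoryUsed_bytes{application_id="app-...", application_name="example", executor_id="1"} 1048576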
*/ +@Experimental @Path("/executors") private[v1] class PrometheusResource extends ApiRequestContext { @GET @@ -39,6 +43,7 @@ private[v1] class PrometheusResource extends ApiRequestContext { @Produces(Array(MediaType.TEXT_PLAIN)) def executors(): String = { val sb = new StringBuilder + sb.append(s"""spark_info{version="$SPARK_VERSION_SHORT", revision="$SPARK_REVISION"} 1.0\n""") val store = uiRoot.asInstanceOf[SparkUI].store store.executorList(true).foreach { executor => val prefix = "metrics_executor_" @@ -47,27 +52,27 @@ private[v1] class PrometheusResource extends ApiRequestContext { "application_name" -> store.applicationInfo.name, "executor_id" -> executor.id ).map { case (k, v) => s"""$k="$v"""" }.mkString("{", ", ", "}") - sb.append(s"${prefix}rddBlocks_Count$labels ${executor.rddBlocks}\n") - sb.append(s"${prefix}memoryUsed_Count$labels ${executor.memoryUsed}\n") - sb.append(s"${prefix}diskUsed_Count$labels ${executor.diskUsed}\n") - sb.append(s"${prefix}totalCores_Count$labels ${executor.totalCores}\n") - sb.append(s"${prefix}maxTasks_Count$labels ${executor.maxTasks}\n") - sb.append(s"${prefix}activeTasks_Count$labels ${executor.activeTasks}\n") - sb.append(s"${prefix}failedTasks_Count$labels ${executor.failedTasks}\n") - sb.append(s"${prefix}completedTasks_Count$labels ${executor.completedTasks}\n") - sb.append(s"${prefix}totalTasks_Count$labels ${executor.totalTasks}\n") - sb.append(s"${prefix}totalDuration_Value$labels ${executor.totalDuration}\n") - sb.append(s"${prefix}totalGCTime_Value$labels ${executor.totalGCTime}\n") - sb.append(s"${prefix}totalInputBytes_Count$labels ${executor.totalInputBytes}\n") - sb.append(s"${prefix}totalShuffleRead_Count$labels ${executor.totalShuffleRead}\n") - sb.append(s"${prefix}totalShuffleWrite_Count$labels ${executor.totalShuffleWrite}\n") - sb.append(s"${prefix}maxMemory_Count$labels ${executor.maxMemory}\n") + sb.append(s"${prefix}rddBlocks$labels ${executor.rddBlocks}\n") + sb.append(s"${prefix}memoryUsed_bytes$labels ${executor.memoryUsed}\n") + sb.append(s"${prefix}diskUsed_bytes$labels ${executor.diskUsed}\n") + sb.append(s"${prefix}totalCores$labels ${executor.totalCores}\n") + sb.append(s"${prefix}maxTasks$labels ${executor.maxTasks}\n") + sb.append(s"${prefix}activeTasks$labels ${executor.activeTasks}\n") + sb.append(s"${prefix}failedTasks_total$labels ${executor.failedTasks}\n") + sb.append(s"${prefix}completedTasks_total$labels ${executor.completedTasks}\n") + sb.append(s"${prefix}totalTasks_total$labels ${executor.totalTasks}\n") + sb.append(s"${prefix}totalDuration_seconds_total$labels ${executor.totalDuration * 0.001}\n") + sb.append(s"${prefix}totalGCTime_seconds_total$labels ${executor.totalGCTime * 0.001}\n") + sb.append(s"${prefix}totalInputBytes_bytes_total$labels ${executor.totalInputBytes}\n") + sb.append(s"${prefix}totalShuffleRead_bytes_total$labels ${executor.totalShuffleRead}\n") + sb.append(s"${prefix}totalShuffleWrite_bytes_total$labels ${executor.totalShuffleWrite}\n") + sb.append(s"${prefix}maxMemory_bytes$labels ${executor.maxMemory}\n") executor.executorLogs.foreach { case (k, v) => } executor.memoryMetrics.foreach { m => - sb.append(s"${prefix}usedOnHeapStorageMemory_Count$labels ${m.usedOnHeapStorageMemory}\n") - sb.append(s"${prefix}usedOffHeapStorageMemory_Count$labels ${m.usedOffHeapStorageMemory}\n") - sb.append(s"${prefix}totalOnHeapStorageMemory_Count$labels ${m.totalOnHeapStorageMemory}\n") - sb.append(s"${prefix}totalOffHeapStorageMemory_Count$labels " + + 
sb.append(s"${prefix}usedOnHeapStorageMemory_bytes$labels ${m.usedOnHeapStorageMemory}\n") + sb.append(s"${prefix}usedOffHeapStorageMemory_bytes$labels ${m.usedOffHeapStorageMemory}\n") + sb.append(s"${prefix}totalOnHeapStorageMemory_bytes$labels ${m.totalOnHeapStorageMemory}\n") + sb.append(s"${prefix}totalOffHeapStorageMemory_bytes$labels " + s"${m.totalOffHeapStorageMemory}\n") } executor.peakMemoryMetrics.foreach { m => @@ -87,14 +92,16 @@ private[v1] class PrometheusResource extends ApiRequestContext { "ProcessTreePythonVMemory", "ProcessTreePythonRSSMemory", "ProcessTreeOtherVMemory", - "ProcessTreeOtherRSSMemory", - "MinorGCCount", - "MinorGCTime", - "MajorGCCount", - "MajorGCTime" + "ProcessTreeOtherRSSMemory" ) names.foreach { name => - sb.append(s"$prefix${name}_Count$labels ${m.getMetricValue(name)}\n") + sb.append(s"$prefix${name}_bytes$labels ${m.getMetricValue(name)}\n") + } + Seq("MinorGCCount", "MajorGCCount").foreach { name => + sb.append(s"$prefix${name}_total$labels ${m.getMetricValue(name)}\n") + } + Seq("MinorGCTime", "MajorGCTime").foreach { name => + sb.append(s"$prefix${name}_seconds_total$labels ${m.getMetricValue(name) * 0.001}\n") } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 5ec9b36393764..e89e29101a126 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -30,7 +30,7 @@ import com.fasterxml.jackson.databind.annotation.{JsonDeserialize, JsonSerialize import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.metrics.ExecutorMetricType -import org.apache.spark.resource.ResourceInformation +import org.apache.spark.resource.{ExecutorResourceRequest, ResourceInformation, TaskResourceRequest} case class ApplicationInfo private[spark]( id: String, @@ -62,6 +62,11 @@ case class ApplicationAttemptInfo private[spark]( } +class ResourceProfileInfo private[spark]( + val id: Int, + val executorResources: Map[String, ExecutorResourceRequest], + val taskResources: Map[String, TaskResourceRequest]) + class ExecutorStageSummary private[spark]( val taskTime : Long, val failedTasks : Int, @@ -109,7 +114,8 @@ class ExecutorSummary private[spark]( @JsonDeserialize(using = classOf[ExecutorMetricsJsonDeserializer]) val peakMemoryMetrics: Option[ExecutorMetrics], val attributes: Map[String, String], - val resources: Map[String, ResourceInformation]) + val resources: Map[String, ResourceInformation], + val resourceProfileId: Int) class MemoryMetrics private[spark]( val usedOnHeapStorageMemory: Long, @@ -252,7 +258,8 @@ class StageData private[spark]( val accumulatorUpdates: Seq[AccumulableInfo], val tasks: Option[Map[Long, TaskData]], val executorSummary: Option[Map[String, ExecutorStageSummary]], - val killedTasksSummary: Map[String, Int]) + val killedTasksSummary: Map[String, Int], + val resourceProfileId: Int) class TaskData private[spark]( val taskId: Long, @@ -365,12 +372,15 @@ class AccumulableInfo private[spark]( class VersionInfo private[spark]( val spark: String) +// Note the resourceProfiles information are only added here on return from the +// REST call, they are not stored with it. 
class ApplicationEnvironmentInfo private[spark] ( val runtime: RuntimeInfo, val sparkProperties: Seq[(String, String)], val hadoopProperties: Seq[(String, String)], val systemProperties: Seq[(String, String)], - val classpathEntries: Seq[(String, String)]) + val classpathEntries: Seq[(String, String)], + val resourceProfiles: Seq[ResourceProfileInfo]) class RuntimeInfo private[spark]( val javaVersion: String, diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index c957ff75a501f..b40f7304b7ce2 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -374,6 +374,13 @@ private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { } +private[spark] class ResourceProfileWrapper(val rpInfo: ResourceProfileInfo) { + + @JsonIgnore @KVIndex + def id: Int = rpInfo.id + +} + private[spark] class ExecutorStageSummaryWrapper( val stageId: Int, val stageAttemptId: Int, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e7f8de5ab7e4a..e0478ad09601d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -54,6 +54,7 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.{ShuffleManager, ShuffleWriteMetricsReporter} +import org.apache.spark.storage.BlockManagerMessages.ReplicateBlock import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ @@ -241,6 +242,9 @@ private[spark] class BlockManager( private var blockReplicationPolicy: BlockReplicationPolicy = _ + private var blockManagerDecommissioning: Boolean = false + private var decommissionManager: Option[BlockManagerDecommissionManager] = None + // A DownloadFileManager used to track all the files of remote blocks which are above the // specified memory threshold. Files will be deleted automatically based on weak reference. // Exposed for test @@ -1551,18 +1555,22 @@ private[spark] class BlockManager( } /** - * Called for pro-active replenishment of blocks lost due to executor failures + * Replicates a block to peer block managers based on existingReplicas and maxReplicas * * @param blockId blockId being replicate * @param existingReplicas existing block managers that have a replica * @param maxReplicas maximum replicas needed + * @param maxReplicationFailures number of replication failures to tolerate before + * giving up. 
+ * @return whether block was successfully replicated or not */ def replicateBlock( blockId: BlockId, existingReplicas: Set[BlockManagerId], - maxReplicas: Int): Unit = { + maxReplicas: Int, + maxReplicationFailures: Option[Int] = None): Boolean = { logInfo(s"Using $blockManagerId to pro-actively replicate $blockId") - blockInfoManager.lockForReading(blockId).foreach { info => + blockInfoManager.lockForReading(blockId).forall { info => val data = doGetLocalBytes(blockId, info) val storageLevel = StorageLevel( useDisk = info.level.useDisk, @@ -1570,11 +1578,13 @@ private[spark] class BlockManager( useOffHeap = info.level.useOffHeap, deserialized = info.level.deserialized, replication = maxReplicas) - // we know we are called as a result of an executor removal, so we refresh peer cache - // this way, we won't try to replicate to a missing executor with a stale reference + // we know we are called as a result of an executor removal or because the current executor + // is getting decommissioned. so we refresh peer cache before trying replication, we won't + // try to replicate to a missing executor/another decommissioning executor getPeers(forceFetch = true) try { - replicate(blockId, data, storageLevel, info.classTag, existingReplicas) + replicate( + blockId, data, storageLevel, info.classTag, existingReplicas, maxReplicationFailures) } finally { logDebug(s"Releasing lock for $blockId") releaseLockAndDispose(blockId, data) @@ -1591,9 +1601,11 @@ private[spark] class BlockManager( data: BlockData, level: StorageLevel, classTag: ClassTag[_], - existingReplicas: Set[BlockManagerId] = Set.empty): Unit = { + existingReplicas: Set[BlockManagerId] = Set.empty, + maxReplicationFailures: Option[Int] = None): Boolean = { - val maxReplicationFailures = conf.get(config.STORAGE_MAX_REPLICATION_FAILURE) + val maxReplicationFailureCount = maxReplicationFailures.getOrElse( + conf.get(config.STORAGE_MAX_REPLICATION_FAILURE)) val tLevel = StorageLevel( useDisk = level.useDisk, useMemory = level.useMemory, @@ -1617,7 +1629,7 @@ private[spark] class BlockManager( blockId, numPeersToReplicateTo) - while(numFailures <= maxReplicationFailures && + while(numFailures <= maxReplicationFailureCount && !peersForReplication.isEmpty && peersReplicatedTo.size < numPeersToReplicateTo) { val peer = peersForReplication.head @@ -1641,6 +1653,10 @@ private[spark] class BlockManager( peersForReplication = peersForReplication.tail peersReplicatedTo += peer } catch { + // Rethrow interrupt exception + case e: InterruptedException => + throw e + // Everything else we may retry case NonFatal(e) => logWarning(s"Failed to replicate $blockId to $peer, failure #$numFailures", e) peersFailedToReplicateTo += peer @@ -1665,9 +1681,11 @@ private[spark] class BlockManager( if (peersReplicatedTo.size < numPeersToReplicateTo) { logWarning(s"Block $blockId replicated to only " + s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers") + return false } logDebug(s"block $blockId replicated to ${peersReplicatedTo.mkString(", ")}") + return true } /** @@ -1761,6 +1779,60 @@ private[spark] class BlockManager( blocksToRemove.size } + def decommissionBlockManager(): Unit = { + if (!blockManagerDecommissioning) { + logInfo("Starting block manager decommissioning process") + blockManagerDecommissioning = true + decommissionManager = Some(new BlockManagerDecommissionManager(conf)) + decommissionManager.foreach(_.start()) + } else { + logDebug("Block manager already in decommissioning state") + } + } + + /** + * Tries to offload all 
cached RDD blocks from this BlockManager to peer BlockManagers + * Visible for testing + */ + def decommissionRddCacheBlocks(): Unit = { + val replicateBlocksInfo = master.getReplicateInfoForRDDBlocks(blockManagerId) + + if (replicateBlocksInfo.nonEmpty) { + logInfo(s"Need to replicate ${replicateBlocksInfo.size} blocks " + + "for block manager decommissioning") + } else { + logWarning(s"Asked to decommission RDD cache blocks, but no blocks to migrate") + return + } + + // Maximum number of storage replication failure which replicateBlock can handle + val maxReplicationFailures = conf.get( + config.STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK) + + // TODO: We can sort these blocks based on some policy (LRU/blockSize etc) + // so that we end up prioritize them over each other + val blocksFailedReplication = replicateBlocksInfo.map { + case ReplicateBlock(blockId, existingReplicas, maxReplicas) => + val replicatedSuccessfully = replicateBlock( + blockId, + existingReplicas.toSet, + maxReplicas, + maxReplicationFailures = Some(maxReplicationFailures)) + if (replicatedSuccessfully) { + logInfo(s"Block $blockId offloaded successfully, Removing block now") + removeBlock(blockId) + logInfo(s"Block $blockId removed") + } else { + logWarning(s"Failed to offload block $blockId") + } + (blockId, replicatedSuccessfully) + }.filterNot(_._2).map(_._1) + if (blocksFailedReplication.nonEmpty) { + logWarning("Blocks failed replication in cache decommissioning " + + s"process: ${blocksFailedReplication.mkString(",")}") + } + } + /** * Remove all blocks belonging to the given broadcast. */ @@ -1829,7 +1901,58 @@ private[spark] class BlockManager( data.dispose() } + /** + * Class to handle block manager decommissioning retries + * It creates a Thread to retry offloading all RDD cache blocks + */ + private class BlockManagerDecommissionManager(conf: SparkConf) { + @volatile private var stopped = false + private val sleepInterval = conf.get( + config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL) + + private val blockReplicationThread = new Thread { + override def run(): Unit = { + var failures = 0 + while (blockManagerDecommissioning + && !stopped + && !Thread.interrupted() + && failures < 20) { + try { + logDebug("Attempting to replicate all cached RDD blocks") + decommissionRddCacheBlocks() + logInfo("Attempt to replicate all cached blocks done") + Thread.sleep(sleepInterval) + } catch { + case _: InterruptedException => + logInfo("Interrupted during migration, will not refresh migrations.") + stopped = true + case NonFatal(e) => + failures += 1 + logError("Error occurred while trying to replicate cached RDD blocks" + + s" for block manager decommissioning (failure count: $failures)", e) + } + } + } + } + blockReplicationThread.setDaemon(true) + blockReplicationThread.setName("block-replication-thread") + + def start(): Unit = { + logInfo("Starting block replication thread") + blockReplicationThread.start() + } + + def stop(): Unit = { + if (!stopped) { + stopped = true + logInfo("Stopping block replication thread") + blockReplicationThread.interrupt() + } + } + } + def stop(): Unit = { + decommissionManager.foreach(_.stop()) blockTransferService.close() if (blockStoreClient ne blockTransferService) { // Closing should be idempotent, but maybe not for the NioBlockTransferService. 
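The whole migration path above is gated by STORAGE_DECOMMISSION_ENABLED, checked in CoarseGrainedSchedulerBackend earlier in this diff. A hedged sketch of enabling it; the literal key string is an assumption based on that constant, and the retry knobs referenced above (max replication failures per block, reattempt interval) have their own keys that are not spelled out here:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.storage.decommission.enabled", "true") // assumed key for STORAGE_DECOMMISSION_ENABLED
// STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK and
// STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL tune the retry behaviour of the
// BlockManagerDecommissionManager thread added above.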
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index e440c1ab7bcd9..3cfa5d2a25818 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -43,6 +43,16 @@ class BlockManagerMaster( logInfo("Removed " + execId + " successfully in removeExecutor") } + /** Decommission block managers corresponding to given set of executors */ + def decommissionBlockManagers(executorIds: Seq[String]): Unit = { + driverEndpoint.ask[Unit](DecommissionBlockManagers(executorIds)) + } + + /** Get Replication Info for all the RDD blocks stored in given blockManagerId */ + def getReplicateInfoForRDDBlocks(blockManagerId: BlockManagerId): Seq[ReplicateBlock] = { + driverEndpoint.askSync[Seq[ReplicateBlock]](GetReplicateInfoForRDDBlocks(blockManagerId)) + } + /** Request removal of a dead executor from the driver endpoint. * This is only called on the driver side. Non-blocking */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index d7f7eedc7f33b..d936420a99276 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -65,6 +65,9 @@ class BlockManagerMasterEndpoint( // Mapping from executor ID to block manager ID. private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] + // Set of block managers which are decommissioning + private val decommissioningBlockManagerSet = new mutable.HashSet[BlockManagerId] + // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] @@ -153,6 +156,13 @@ class BlockManagerMasterEndpoint( removeExecutor(execId) context.reply(true) + case DecommissionBlockManagers(executorIds) => + decommissionBlockManagers(executorIds.flatMap(blockManagerIdByExecutor.get)) + context.reply(true) + + case GetReplicateInfoForRDDBlocks(blockManagerId) => + context.reply(getReplicateInfoForRDDBlocks(blockManagerId)) + case StopBlockManagerMaster => context.reply(true) stop() @@ -257,6 +267,7 @@ class BlockManagerMasterEndpoint( // Remove the block manager from blockManagerIdByExecutor. blockManagerIdByExecutor -= blockManagerId.executorId + decommissioningBlockManagerSet.remove(blockManagerId) // Remove it from blockManagerInfo and remove all the blocks. 
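To make the driver-side flow concrete, a sketch of the new BlockManagerMaster calls. These are private[spark] APIs, so this only runs inside Spark (as CoarseGrainedSchedulerBackend does above); the executor id is a placeholder:

import org.apache.spark.SparkEnv

// Driver-side only; BlockManagerMaster is internal API.
val master = SparkEnv.get.blockManager.master
master.decommissionBlockManagers(Seq("1"))   // mark executor "1"'s block manager as decommissioning
// The decommissioning block manager then asks the master for its work list via
// GetReplicateInfoForRDDBlocks: one ReplicateBlock per cached RDD block, carrying the
// remaining replica locations and maxReplicas = current replica count + 1.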
blockManagerInfo.remove(blockManagerId) @@ -299,6 +310,39 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) } + /** + * Decommission the given Seq of blockmanagers + * - Adds these block managers to decommissioningBlockManagerSet Set + * - Sends the DecommissionBlockManager message to each of the [[BlockManagerSlaveEndpoint]] + */ + def decommissionBlockManagers(blockManagerIds: Seq[BlockManagerId]): Future[Seq[Unit]] = { + val newBlockManagersToDecommission = blockManagerIds.toSet.diff(decommissioningBlockManagerSet) + val futures = newBlockManagersToDecommission.map { blockManagerId => + decommissioningBlockManagerSet.add(blockManagerId) + val info = blockManagerInfo(blockManagerId) + info.slaveEndpoint.ask[Unit](DecommissionBlockManager) + } + Future.sequence{ futures.toSeq } + } + + /** + * Returns a Seq of ReplicateBlock for each RDD block stored by given blockManagerId + * @param blockManagerId - block manager id for which ReplicateBlock info is needed + * @return Seq of ReplicateBlock + */ + private def getReplicateInfoForRDDBlocks(blockManagerId: BlockManagerId): Seq[ReplicateBlock] = { + val info = blockManagerInfo(blockManagerId) + + val rddBlocks = info.blocks.keySet().asScala.filter(_.isRDD) + rddBlocks.map { blockId => + val currentBlockLocations = blockLocations.get(blockId) + val maxReplicas = currentBlockLocations.size + 1 + val remainingLocations = currentBlockLocations.toSeq.filter(bm => bm != blockManagerId) + val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas) + replicateMsg + }.toSeq + } + // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. private def removeBlockFromWorkers(blockId: BlockId): Unit = { @@ -536,7 +580,11 @@ class BlockManagerMasterEndpoint( private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { val blockManagerIds = blockManagerInfo.keySet if (blockManagerIds.contains(blockManagerId)) { - blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq + blockManagerIds + .filterNot { _.isDriver } + .filterNot { _ == blockManagerId } + .diff(decommissioningBlockManagerSet) + .toSeq } else { Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 895f48d0709fb..7d4f2fff5c34c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -36,6 +36,8 @@ private[spark] object BlockManagerMessages { case class ReplicateBlock(blockId: BlockId, replicas: Seq[BlockManagerId], maxReplicas: Int) extends ToBlockManagerSlave + case object DecommissionBlockManager extends ToBlockManagerSlave + // Remove all blocks belonging to a specific RDD. 
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave @@ -125,6 +127,11 @@ private[spark] object BlockManagerMessages { case object GetStorageStatus extends ToBlockManagerMaster + case class DecommissionBlockManagers(executorIds: Seq[String]) extends ToBlockManagerMaster + + case class GetReplicateInfoForRDDBlocks(blockManagerId: BlockManagerId) + extends ToBlockManagerMaster + case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 29e21142ce449..a3a7149103491 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -61,6 +61,9 @@ class BlockManagerSlaveEndpoint( SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) } + case DecommissionBlockManager => + context.reply(blockManager.decommissionBlockManager()) + case RemoveBroadcast(broadcastId, _) => doAsync[Int]("removing broadcast " + broadcastId, context) { blockManager.removeBroadcast(broadcastId, tellMaster = true) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 94c99d48e773c..4b4788f453243 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -496,10 +496,16 @@ private[spark] case class ServerInfo( } def stop(): Unit = { + val threadPool = server.getThreadPool + threadPool match { + case pool: QueuedThreadPool => + // Workaround for SPARK-30385 to avoid Jetty's acceptor thread shrink. + pool.setIdleTimeout(0) + case _ => + } server.stop() // Stop the ThreadPool if it supports stop() method (through LifeCycle). // It is needed because stopping the Server won't stop the ThreadPool it uses. - val threadPool = server.getThreadPool if (threadPool != null && threadPool.isInstanceOf[LifeCycle]) { threadPool.asInstanceOf[LifeCycle].stop } diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 008dcc6200d37..a002af70a919d 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -17,8 +17,9 @@ package org.apache.spark.ui -import java.net.URLDecoder +import java.net.{URLDecoder, URLEncoder} import java.nio.charset.StandardCharsets.UTF_8 +import javax.servlet.http.HttpServletRequest import scala.collection.JavaConverters._ import scala.xml.{Node, Unparsed} @@ -297,4 +298,102 @@ private[spark] trait PagedTable[T] { * Returns the submission path for the "go to page #" form. */ def goButtonFormPath: String + + /** + * Returns parameters of other tables in the page. + */ + def getParameterOtherTable(request: HttpServletRequest, tableTag: String): String = { + request.getParameterMap.asScala + .filterNot(_._1.startsWith(tableTag)) + .map(parameter => parameter._1 + "=" + parameter._2(0)) + .mkString("&") + } + + /** + * Returns parameter of this table. 
+ */ + def getTableParameters( + request: HttpServletRequest, + tableTag: String, + defaultSortColumn: String): (String, Boolean, Int) = { + val parameterSortColumn = request.getParameter(s"$tableTag.sort") + val parameterSortDesc = request.getParameter(s"$tableTag.desc") + val parameterPageSize = request.getParameter(s"$tableTag.pageSize") + val sortColumn = Option(parameterSortColumn).map { sortColumn => + UIUtils.decodeURLParameter(sortColumn) + }.getOrElse(defaultSortColumn) + val desc = Option(parameterSortDesc).map(_.toBoolean).getOrElse( + sortColumn == defaultSortColumn + ) + val pageSize = Option(parameterPageSize).map(_.toInt).getOrElse(100) + + (sortColumn, desc, pageSize) + } + + /** + * Check if given sort column is valid or not. If invalid then an exception is thrown. + */ + def isSortColumnValid( + headerInfo: Seq[(String, Boolean, Option[String])], + sortColumn: String): Unit = { + if (!headerInfo.filter(_._2).map(_._1).contains(sortColumn)) { + throw new IllegalArgumentException(s"Unknown column: $sortColumn") + } + } + + def headerRow( + headerInfo: Seq[(String, Boolean, Option[String])], + desc: Boolean, + pageSize: Int, + sortColumn: String, + parameterPath: String, + tableTag: String, + headerId: String): Seq[Node] = { + val row: Seq[Node] = { + headerInfo.map { case (header, sortable, tooltip) => + if (header == sortColumn) { + val headerLink = Unparsed( + parameterPath + + s"&$tableTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + + s"&$tableTag.desc=${!desc}" + + s"&$tableTag.pageSize=$pageSize" + + s"#$headerId") + val arrow = if (desc) "▾" else "▴" // UP or DOWN + + + + + {header} {Unparsed(arrow)} + + + + } else { + if (sortable) { + val headerLink = Unparsed( + parameterPath + + s"&$tableTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + + s"&$tableTag.pageSize=$pageSize" + + s"#$headerId") + + + + + {header} + + + + } else { + + + {header} + + + } + } + } + } + + {row} + + } } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 90167858df663..087a22d6c6140 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -292,6 +292,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes(request)} + {if (showVisualization) vizHeaderNodes(request) else Seq.empty} {if (useDataTables) dataTablesHeaderNodes(request) else Seq.empty} appEnv.runtime.javaHome, "Scala Version" -> appEnv.runtime.scalaVersion) + def constructExecutorRequestString(execReqs: Map[String, ExecutorResourceRequest]): String = { + execReqs.map { + case (_, execReq) => + val execStr = new StringBuilder(s"\t${execReq.resourceName}: [amount: ${execReq.amount}") + if (execReq.discoveryScript.nonEmpty) { + execStr ++= s", discovery: ${execReq.discoveryScript}" + } + if (execReq.vendor.nonEmpty) { + execStr ++= s", vendor: ${execReq.vendor}" + } + execStr ++= "]" + execStr.toString() + }.mkString("\n") + } + + def constructTaskRequestString(taskReqs: Map[String, TaskResourceRequest]): String = { + taskReqs.map { + case (_, taskReq) => s"\t${taskReq.resourceName}: [amount: ${taskReq.amount}]" + }.mkString("\n") + } + + val resourceProfileInfo = store.resourceProfileInfo().map { rinfo => + val einfo = constructExecutorRequestString(rinfo.executorResources) + val tinfo = constructTaskRequestString(rinfo.taskResources) + val res = s"Executor Reqs:\n$einfo\nTask Reqs:\n$tinfo" + (rinfo.id.toString, res) + }.toMap + + val 
resourceProfileInformationTable = UIUtils.listingTable(resourceProfileHeader, + jvmRowDataPre, resourceProfileInfo.toSeq.sortWith(_._1.toInt < _._1.toInt), + fixedWidth = true, headerClasses = headerClassesNoSortValues) val runtimeInformationTable = UIUtils.listingTable( propertyHeader, jvmRow, jvmInformation.toSeq.sorted, fixedWidth = true, headerClasses = headerClasses) @@ -77,6 +110,17 @@ private[ui] class EnvironmentPage(
           {sparkPropertiesTable}
         </div>
+        <span class="collapse-aggregated-execResourceProfileInformation collapse-table"
+            onClick="collapseTable('collapse-aggregated-execResourceProfileInformation',
+            'aggregated-execResourceProfileInformation')">
+          <h4>
+            <span class="collapse-table-arrow arrow-closed"></span>
+            <a>Resource Profiles</a>
+          </h4>
+        </span>
+        <div class="aggregated-execResourceProfileInformation collapsible-table collapsed">
+          {resourceProfileInformationTable}
+        </div>
@@ -115,10 +159,14 @@ private[ui] class EnvironmentPage( UIUtils.headerSparkPage(request, "Environment", content, parent) } + private def resourceProfileHeader = Seq("Resource Profile Id", "Resource Profile Contents") private def propertyHeader = Seq("Name", "Value") private def classPathHeader = Seq("Resource", "Source") private def headerClasses = Seq("sorttable_alpha", "sorttable_alpha") + private def headerClassesNoSortValues = Seq("sorttable_numeric", "sorttable_nosort") + private def jvmRowDataPre(kv: (String, String)) = + {kv._1}
{kv._2}
private def jvmRow(kv: (String, String)) = {kv._1}{kv._2} private def propertyRow(kv: (String, String)) = {kv._1}{kv._2} private def classPathRow(data: (String, String)) = {data._1}{data._2} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 0b362201a7846..066512d159d00 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -22,7 +22,6 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.xml._ @@ -211,45 +210,22 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We jobTag: String, jobs: Seq[v1.JobData], killEnabled: Boolean): Seq[Node] = { - val parameterOtherTable = request.getParameterMap().asScala - .filterNot(_._1.startsWith(jobTag)) - .map(para => para._1 + "=" + para._2(0)) val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id" - - val parameterJobPage = request.getParameter(jobTag + ".page") - val parameterJobSortColumn = request.getParameter(jobTag + ".sort") - val parameterJobSortDesc = request.getParameter(jobTag + ".desc") - val parameterJobPageSize = request.getParameter(jobTag + ".pageSize") - - val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1) - val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn => - UIUtils.decodeURLParameter(sortColumn) - }.getOrElse(jobIdTitle) - val jobSortDesc = Option(parameterJobSortDesc).map(_.toBoolean).getOrElse( - // New jobs should be shown above old jobs by default. 
- jobSortColumn == jobIdTitle - ) - val jobPageSize = Option(parameterJobPageSize).map(_.toInt).getOrElse(100) - - val currentTime = System.currentTimeMillis() + val jobPage = Option(request.getParameter(jobTag + ".page")).map(_.toInt).getOrElse(1) try { new JobPagedTable( + request, store, jobs, tableHeaderId, jobTag, UIUtils.prependBaseUri(request, parent.basePath), "jobs", // subPath - parameterOtherTable, killEnabled, - currentTime, - jobIdTitle, - pageSize = jobPageSize, - sortColumn = jobSortColumn, - desc = jobSortDesc + jobIdTitle ).table(jobPage) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -421,7 +397,6 @@ private[ui] class JobDataSource( store: AppStatusStore, jobs: Seq[v1.JobData], basePath: String, - currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[JobTableRowData](pageSize) { @@ -432,15 +407,9 @@ private[ui] class JobDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = jobs.map(jobRow).sorted(ordering(sortColumn, desc)) - private var _slicedJobIds: Set[Int] = null - override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[JobTableRowData] = { - val r = data.slice(from, to) - _slicedJobIds = r.map(_.jobData.jobId).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[JobTableRowData] = data.slice(from, to) private def jobRow(jobData: v1.JobData): JobTableRowData = { val duration: Option[Long] = JobDataUtil.getDuration(jobData) @@ -493,27 +462,25 @@ private[ui] class JobDataSource( } private[ui] class JobPagedTable( + request: HttpServletRequest, store: AppStatusStore, data: Seq[v1.JobData], tableHeaderId: String, jobTag: String, basePath: String, subPath: String, - parameterOtherTable: Iterable[String], killEnabled: Boolean, - currentTime: Long, - jobIdTitle: String, - pageSize: Int, - sortColumn: String, - desc: Boolean + jobIdTitle: String ) extends PagedTable[JobTableRowData] { - val parameterPath = basePath + s"/$subPath/?" + parameterOtherTable.mkString("&") + + private val (sortColumn, desc, pageSize) = getTableParameters(request, jobTag, jobIdTitle) + private val parameterPath = basePath + s"/$subPath/?" 
+ getParameterOtherTable(request, jobTag) + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) override def tableId: String = jobTag + "-table" override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageSizeFormField: String = jobTag + ".pageSize" @@ -523,13 +490,11 @@ private[ui] class JobPagedTable( store, data, basePath, - currentTime, pageSize, sortColumn, desc) override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$jobTag.sort=$encodedSortColumn" + @@ -538,96 +503,26 @@ private[ui] class JobPagedTable( s"#$tableHeaderId" } - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$jobTag.sort=$encodedSortColumn&$jobTag.desc=$desc#$tableHeaderId" - } override def headers: Seq[Node] = { - // Information for each header: title, cssClass, and sortable - val jobHeadersAndCssClasses: Seq[(String, String, Boolean, Option[String])] = + // Information for each header: title, sortable, tooltip + val jobHeadersAndCssClasses: Seq[(String, Boolean, Option[String])] = Seq( - (jobIdTitle, "", true, None), - ("Description", "", true, None), - ("Submitted", "", true, None), - ("Duration", "", true, Some("Elapsed time since the job was submitted " + + (jobIdTitle, true, None), + ("Description", true, None), + ("Submitted", true, None), + ("Duration", true, Some("Elapsed time since the job was submitted " + "until execution completion of all its stages.")), - ("Stages: Succeeded/Total", "", false, None), - ("Tasks (for all stages): Succeeded/Total", "", false, None) + ("Stages: Succeeded/Total", false, None), + ("Tasks (for all stages): Succeeded/Total", false, None) ) - if (!jobHeadersAndCssClasses.filter(_._3).map(_._1).contains(sortColumn)) { - throw new IllegalArgumentException(s"Unknown column: $sortColumn") - } + isSortColumnValid(jobHeadersAndCssClasses, sortColumn) - val headerRow: Seq[Node] = { - jobHeadersAndCssClasses.map { case (header, cssClass, sortable, tooltip) => - if (header == sortColumn) { - val headerLink = Unparsed( - parameterPath + - s"&$jobTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&$jobTag.desc=${!desc}" + - s"&$jobTag.pageSize=$pageSize" + - s"#$tableHeaderId") - val arrow = if (desc) "▾" else "▴" // UP or DOWN - - - - { - if (tooltip.nonEmpty) { - - {header} {Unparsed(arrow)} - - } else { - - {header} {Unparsed(arrow)} - - } - } - - - } else { - if (sortable) { - val headerLink = Unparsed( - parameterPath + - s"&$jobTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&$jobTag.pageSize=$pageSize" + - s"#$tableHeaderId") - - - - { - if (tooltip.nonEmpty) { - - {header} - - } else { - - {header} - - } - } - - - } else { - - { - if (tooltip.nonEmpty) { - - {header} - - } else { - - {header} - - } - } - - } - } - } - } - {headerRow} + headerRow(jobHeadersAndCssClasses, desc, pageSize, sortColumn, parameterPath, + jobTag, tableHeaderId) } override def row(jobTableRow: JobTableRowData): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 9be7124adcf7b..542dc39eee4f0 100644 --- 
a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -26,6 +26,7 @@ import scala.xml.{Node, NodeSeq, Unparsed, Utility} import org.apache.commons.text.StringEscapeUtils import org.apache.spark.JobExecutionStatus +import org.apache.spark.resource.ResourceProfile import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1 import org.apache.spark.ui._ @@ -253,7 +254,8 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP accumulatorUpdates = Nil, tasks = None, executorSummary = None, - killedTasksSummary = Map()) + killedTasksSummary = Map(), + ResourceProfile.UNKNOWN_RESOURCE_PROFILE_ID) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7973d30493a5a..47ba951953cec 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -142,6 +142,10 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val summary =
    +
+          <li>
+            <strong>Resource Profile Id: </strong>
+            {stageData.resourceProfileId}
+          </li>
  • Total Time Across All Tasks: {UIUtils.formatDuration(stageData.executorRunTime)} @@ -208,7 +212,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We stageData, UIUtils.prependBaseUri(request, parent.basePath) + s"/stages/stage/?id=${stageId}&attempt=${stageAttemptId}", - currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, desc = taskSortDesc, @@ -448,7 +451,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We private[ui] class TaskDataSource( stage: StageData, - currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, @@ -470,8 +472,6 @@ private[ui] class TaskDataSource( _tasksToShow } - def tasks: Seq[TaskData] = _tasksToShow - def executorLogs(id: String): Map[String, String] = { executorIdToLogs.getOrElseUpdate(id, store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty)) @@ -482,7 +482,6 @@ private[ui] class TaskDataSource( private[ui] class TaskPagedTable( stage: StageData, basePath: String, - currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, @@ -490,6 +489,8 @@ private[ui] class TaskPagedTable( import ApiHelper._ + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def tableId: String = "task-table" override def tableCssClass: String = @@ -501,14 +502,12 @@ private[ui] class TaskPagedTable( override val dataSource: TaskDataSource = new TaskDataSource( stage, - currentTime, pageSize, sortColumn, desc, store) override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) basePath + s"&$pageNumberFormField=$page" + s"&task.sort=$encodedSortColumn" + @@ -516,10 +515,7 @@ private[ui] class TaskPagedTable( s"&$pageSizeFormField=$pageSize" } - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) - s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc" - } + override def goButtonFormPath: String = s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc" def headers: Seq[Node] = { import ApiHelper._ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index a29483b5d5a5e..9e6eb418fe134 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -22,7 +22,6 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.collection.JavaConverters._ import scala.xml._ import org.apache.commons.text.StringEscapeUtils @@ -43,24 +42,8 @@ private[ui] class StageTableBase( isFairScheduler: Boolean, killEnabled: Boolean, isFailedStage: Boolean) { - val parameterOtherTable = request.getParameterMap().asScala - .filterNot(_._1.startsWith(stageTag)) - .map(para => para._1 + "=" + para._2(0)) - - val parameterStagePage = request.getParameter(stageTag + ".page") - val parameterStageSortColumn = request.getParameter(stageTag + ".sort") - val parameterStageSortDesc = request.getParameter(stageTag + ".desc") - val parameterStagePageSize = request.getParameter(stageTag + ".pageSize") - - val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1) - val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn => - UIUtils.decodeURLParameter(sortColumn) - }.getOrElse("Stage Id") - val stageSortDesc = Option(parameterStageSortDesc).map(_.toBoolean).getOrElse( - // New 
stages should be shown above old jobs by default. - stageSortColumn == "Stage Id" - ) - val stagePageSize = Option(parameterStagePageSize).map(_.toInt).getOrElse(100) + + val stagePage = Option(request.getParameter(stageTag + ".page")).map(_.toInt).getOrElse(1) val currentTime = System.currentTimeMillis() @@ -75,11 +58,7 @@ private[ui] class StageTableBase( isFairScheduler, killEnabled, currentTime, - stagePageSize, - stageSortColumn, - stageSortDesc, isFailedStage, - parameterOtherTable, request ).table(stagePage) } catch { @@ -131,25 +110,24 @@ private[ui] class StagePagedTable( isFairScheduler: Boolean, killEnabled: Boolean, currentTime: Long, - pageSize: Int, - sortColumn: String, - desc: Boolean, isFailedStage: Boolean, - parameterOtherTable: Iterable[String], request: HttpServletRequest) extends PagedTable[StageTableRowData] { override def tableId: String = stageTag + "-table" override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageSizeFormField: String = stageTag + ".pageSize" override def pageNumberFormField: String = stageTag + ".page" - val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + - parameterOtherTable.mkString("&") + private val (sortColumn, desc, pageSize) = getTableParameters(request, stageTag, "Stage Id") + + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + + private val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + + getParameterOtherTable(request, stageTag) override val dataSource = new StageDataSource( store, @@ -161,7 +139,6 @@ private[ui] class StagePagedTable( ) override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$stageTag.sort=$encodedSortColumn" + @@ -170,82 +147,31 @@ private[ui] class StagePagedTable( s"#$tableHeaderId" } - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$stageTag.sort=$encodedSortColumn&$stageTag.desc=$desc#$tableHeaderId" - } override def headers: Seq[Node] = { - // stageHeadersAndCssClasses has three parts: header title, tooltip information, and sortable. + // stageHeadersAndCssClasses has three parts: header title, sortable and tooltip information. // The tooltip information could be None, which indicates it does not have a tooltip. - // Otherwise, it has two parts: tooltip text, and position (true for left, false for default). 
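Both JobPagedTable and StagePagedTable now delegate their query-parameter handling to the shared getTableParameters helper. A standalone sketch of its defaulting rules, using a plain Map in place of HttpServletRequest and omitting the URL decoding step:

// Mirrors the defaults in PagedTable.getTableParameters: sort column falls back to the
// table's default, descending order defaults to true only for that default column,
// and the page size defaults to 100.
def tableParams(
    params: Map[String, String],
    tableTag: String,
    defaultSortColumn: String): (String, Boolean, Int) = {
  val sortColumn = params.getOrElse(s"$tableTag.sort", defaultSortColumn)
  val desc = params.get(s"$tableTag.desc").map(_.toBoolean)
    .getOrElse(sortColumn == defaultSortColumn)
  val pageSize = params.get(s"$tableTag.pageSize").map(_.toInt).getOrElse(100)
  (sortColumn, desc, pageSize)
}

// e.g. tableParams(Map("stages.sort" -> "Duration"), "stages", "Stage Id")
//      returns ("Duration", false, 100)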
- val stageHeadersAndCssClasses: Seq[(String, String, Boolean)] = - Seq(("Stage Id", null, true)) ++ - {if (isFairScheduler) {Seq(("Pool Name", null, true))} else Seq.empty} ++ + val stageHeadersAndCssClasses: Seq[(String, Boolean, Option[String])] = + Seq(("Stage Id", true, None)) ++ + {if (isFairScheduler) {Seq(("Pool Name", true, None))} else Seq.empty} ++ Seq( - ("Description", null, true), - ("Submitted", null, true), - ("Duration", ToolTips.DURATION, true), - ("Tasks: Succeeded/Total", null, false), - ("Input", ToolTips.INPUT, true), - ("Output", ToolTips.OUTPUT, true), - ("Shuffle Read", ToolTips.SHUFFLE_READ, true), - ("Shuffle Write", ToolTips.SHUFFLE_WRITE, true) + ("Description", true, None), + ("Submitted", true, None), + ("Duration", true, Some(ToolTips.DURATION)), + ("Tasks: Succeeded/Total", false, None), + ("Input", true, Some(ToolTips.INPUT)), + ("Output", true, Some(ToolTips.OUTPUT)), + ("Shuffle Read", true, Some(ToolTips.SHUFFLE_READ)), + ("Shuffle Write", true, Some(ToolTips.SHUFFLE_WRITE)) ) ++ - {if (isFailedStage) {Seq(("Failure Reason", null, false))} else Seq.empty} - - if (!stageHeadersAndCssClasses.filter(_._3).map(_._1).contains(sortColumn)) { - throw new IllegalArgumentException(s"Unknown column: $sortColumn") - } + {if (isFailedStage) {Seq(("Failure Reason", false, None))} else Seq.empty} - val headerRow: Seq[Node] = { - stageHeadersAndCssClasses.map { case (header, tooltip, sortable) => - val headerSpan = if (null != tooltip && !tooltip.isEmpty) { - - {header} - - } else { - {header} - } + isSortColumnValid(stageHeadersAndCssClasses, sortColumn) - if (header == sortColumn) { - val headerLink = Unparsed( - parameterPath + - s"&$stageTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&$stageTag.desc=${!desc}" + - s"&$stageTag.pageSize=$pageSize") + - s"#$tableHeaderId" - val arrow = if (desc) "▾" else "▴" // UP or DOWN - - - - {headerSpan} -  {Unparsed(arrow)} - - - - } else { - if (sortable) { - val headerLink = Unparsed( - parameterPath + - s"&$stageTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&$stageTag.pageSize=$pageSize") + - s"#$tableHeaderId" - - - - {headerSpan} - - - } else { - - {headerSpan} - - } - } - } - } - {headerRow} + headerRow(stageHeadersAndCssClasses, desc, pageSize, sortColumn, parameterPath, + stageTag, tableHeaderId) } override def row(data: StageTableRowData): Seq[Node] = { @@ -383,15 +309,9 @@ private[ui] class StageDataSource( // table so that we can avoid creating duplicate contents during sorting the data private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc)) - private var _slicedStageIds: Set[Int] = _ - override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[StageTableRowData] = { - val r = data.slice(from, to) - _slicedStageIds = r.map(_.stageId).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[StageTableRowData] = data.slice(from, to) private def stageRow(stageData: v1.StageData): StageTableRowData = { val formattedSubmissionTime = stageData.submissionTime match { @@ -422,7 +342,6 @@ private[ui] class StageDataSource( val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" - new StageTableRowData( stageData, Some(stageData), diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 04f8d8edd4d50..97f3cf534fb2c 100644 --- 
a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -35,15 +35,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val parameterBlockPage = request.getParameter("block.page") - val parameterBlockSortColumn = request.getParameter("block.sort") - val parameterBlockSortDesc = request.getParameter("block.desc") - val parameterBlockPageSize = request.getParameter("block.pageSize") - - val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) - val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") - val blockSortDesc = Option(parameterBlockSortDesc).map(_.toBoolean).getOrElse(false) - val blockPageSize = Option(parameterBlockPageSize).map(_.toInt).getOrElse(100) + val blockPage = Option(request.getParameter("block.page")).map(_.toInt).getOrElse(1) val rddId = parameterId.toInt val rddStorageInfo = try { @@ -60,11 +52,10 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web val blockTableHTML = try { val _blockTable = new BlockPagedTable( + request, + "block", UIUtils.prependBaseUri(request, parent.basePath) + s"/storage/rdd/?id=${rddId}", rddStorageInfo.partitions.get, - blockPageSize, - blockSortColumn, - blockSortDesc, store.executorList(true)) _blockTable.table(blockPage) } catch { @@ -216,21 +207,22 @@ private[ui] class BlockDataSource( } private[ui] class BlockPagedTable( + request: HttpServletRequest, + rddTag: String, basePath: String, rddPartitions: Seq[RDDPartitionInfo], - pageSize: Int, - sortColumn: String, - desc: Boolean, executorSummaries: Seq[ExecutorSummary]) extends PagedTable[BlockTableRowData] { + private val (sortColumn, desc, pageSize) = getTableParameters(request, rddTag, "Block Name") + override def tableId: String = "rdd-storage-by-block-table" override def tableCssClass: String = "table table-bordered table-sm table-striped table-head-clickable" - override def pageSizeFormField: String = "block.pageSize" + override def pageSizeFormField: String = s"$rddTag.pageSize" - override def pageNumberFormField: String = "block.page" + override def pageNumberFormField: String = s"$rddTag.page" override val dataSource: BlockDataSource = new BlockDataSource( rddPartitions, @@ -254,46 +246,16 @@ private[ui] class BlockPagedTable( } override def headers: Seq[Node] = { - val blockHeaders = Seq( + val blockHeaders: Seq[(String, Boolean, Option[String])] = Seq( "Block Name", "Storage Level", "Size in Memory", "Size on Disk", - "Executors") + "Executors").map(x => (x, true, None)) - if (!blockHeaders.contains(sortColumn)) { - throw new IllegalArgumentException(s"Unknown column: $sortColumn") - } + isSortColumnValid(blockHeaders, sortColumn) - val headerRow: Seq[Node] = { - blockHeaders.map { header => - if (header == sortColumn) { - val headerLink = Unparsed( - basePath + - s"&block.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&block.desc=${!desc}" + - s"&block.pageSize=$pageSize") - val arrow = if (desc) "▾" else "▴" // UP or DOWN - - - {header} -  {Unparsed(arrow)} - - - } else { - val headerLink = Unparsed( - basePath + - s"&block.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&block.pageSize=$pageSize") - - - {header} - - - } - } - } - {headerRow} + headerRow(blockHeaders, desc, pageSize, sortColumn, basePath, rddTag, "block") } override def row(block: 
BlockTableRowData): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index d2ad14f2a1a96..6ffd6605f75b8 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -18,12 +18,15 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import java.lang.invoke.SerializedLambda +import java.lang.invoke.{MethodHandleInfo, SerializedLambda} +import scala.collection.JavaConverters._ import scala.collection.mutable.{Map, Set, Stack} -import org.apache.xbean.asm7.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.commons.lang3.ClassUtils +import org.apache.xbean.asm7.{ClassReader, ClassVisitor, Handle, MethodVisitor, Type} import org.apache.xbean.asm7.Opcodes._ +import org.apache.xbean.asm7.tree.{ClassNode, MethodNode} import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging @@ -159,39 +162,6 @@ private[spark] object ClosureCleaner extends Logging { clean(closure, checkSerializable, cleanTransitively, Map.empty) } - /** - * Try to get a serialized Lambda from the closure. - * - * @param closure the closure to check. - */ - private def getSerializedLambda(closure: AnyRef): Option[SerializedLambda] = { - val isClosureCandidate = - closure.getClass.isSynthetic && - closure - .getClass - .getInterfaces.exists(_.getName == "scala.Serializable") - - if (isClosureCandidate) { - try { - Option(inspect(closure)) - } catch { - case e: Exception => - // no need to check if debug is enabled here the Spark - // logging api covers this. - logDebug("Closure is not a serialized lambda.", e) - None - } - } else { - None - } - } - - private def inspect(closure: AnyRef): SerializedLambda = { - val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") - writeReplace.setAccessible(true) - writeReplace.invoke(closure).asInstanceOf[java.lang.invoke.SerializedLambda] - } - /** * Helper method to clean the given closure in place. * @@ -239,12 +209,12 @@ private[spark] object ClosureCleaner extends Logging { cleanTransitively: Boolean, accessedFields: Map[Class[_], Set[String]]): Unit = { - // most likely to be the case with 2.12, 2.13 + // indylambda check. 
Most likely to be the case with 2.12, 2.13 // so we check first // non LMF-closures should be less frequent from now on - val lambdaFunc = getSerializedLambda(func) + val maybeIndylambdaProxy = IndylambdaScalaClosures.getSerializationProxy(func) - if (!isClosure(func.getClass) && lambdaFunc.isEmpty) { + if (!isClosure(func.getClass) && maybeIndylambdaProxy.isEmpty) { logDebug(s"Expected a closure; got ${func.getClass.getName}") return } @@ -256,7 +226,7 @@ private[spark] object ClosureCleaner extends Logging { return } - if (lambdaFunc.isEmpty) { + if (maybeIndylambdaProxy.isEmpty) { logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") // A list of classes that represents closures enclosed in the given one @@ -300,7 +270,7 @@ private[spark] object ClosureCleaner extends Logging { } } - logDebug(s" + fields accessed by starting closure: " + accessedFields.size) + logDebug(s" + fields accessed by starting closure: ${accessedFields.size} classes") accessedFields.foreach { f => logDebug(" " + f) } // List of outer (class, object) pairs, ordered from outermost to innermost @@ -372,14 +342,64 @@ private[spark] object ClosureCleaner extends Logging { logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") } else { - logDebug(s"Cleaning lambda: ${lambdaFunc.get.getImplMethodName}") + val lambdaProxy = maybeIndylambdaProxy.get + val implMethodName = lambdaProxy.getImplMethodName + + logDebug(s"Cleaning indylambda closure: $implMethodName") + + // capturing class is the class that declared this lambda + val capturingClassName = lambdaProxy.getCapturingClass.replace('/', '.') + val classLoader = func.getClass.getClassLoader // this is the safest option + // scalastyle:off classforname + val capturingClass = Class.forName(capturingClassName, false, classLoader) + // scalastyle:on classforname - val captClass = Utils.classForName(lambdaFunc.get.getCapturingClass.replace('/', '.'), - initialize = false, noSparkClassLoader = true) // Fail fast if we detect return statements in closures - getClassReader(captClass) - .accept(new ReturnStatementFinder(Some(lambdaFunc.get.getImplMethodName)), 0) - logDebug(s" +++ Lambda closure (${lambdaFunc.get.getImplMethodName}) is now cleaned +++") + val capturingClassReader = getClassReader(capturingClass) + capturingClassReader.accept(new ReturnStatementFinder(Option(implMethodName)), 0) + + val isClosureDeclaredInScalaRepl = capturingClassName.startsWith("$line") && + capturingClassName.endsWith("$iw") + val outerThisOpt = if (lambdaProxy.getCapturedArgCount > 0) { + Option(lambdaProxy.getCapturedArg(0)) + } else { + None + } + + // only need to clean when there is an enclosing "this" captured by the closure, and it + // should be something cleanable, i.e. a Scala REPL line object + val needsCleaning = isClosureDeclaredInScalaRepl && + outerThisOpt.isDefined && outerThisOpt.get.getClass.getName == capturingClassName + + if (needsCleaning) { + // indylambda closures do not reference enclosing closures via an `$outer` chain, so no + // transitive cleaning on the `$outer` chain is needed. + // Thus clean() shouldn't be recursively called with a non-empty accessedFields. 
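(Aside, not part of the patch: for background on the `SerializedLambda` proxy that the indylambda branch above works with. Scala 2.12+ compiles closures through `LambdaMetafactory`, and serializable lambdas expose their serialization proxy via a synthetic `writeReplace` method, which is the same mechanism `IndylambdaScalaClosures.inspect` uses further down. A standalone sketch with illustrative names; the printed implementation-method name varies by compiler and enclosing class.)

```scala
// Standalone illustration (assumes Scala 2.12+ on a JVM).
import java.lang.invoke.SerializedLambda

object SerializedLambdaDemo {
  def main(args: Array[String]): Unit = {
    val captured = 42
    val closure: Int => Int = x => x + captured // compiled via LambdaMetafactory

    // Serializable lambdas carry a compiler-generated writeReplace that returns
    // their SerializedLambda serialization proxy.
    val writeReplace = closure.getClass.getDeclaredMethod("writeReplace")
    writeReplace.setAccessible(true)
    val proxy = writeReplace.invoke(closure).asInstanceOf[SerializedLambda]

    println(proxy.getCapturingClass)   // class that declared the lambda
    println(proxy.getImplMethodName)   // e.g. $anonfun$main$1 (exact name varies)
    println(proxy.getCapturedArgCount) // number of captured values
  }
}
```

In the patch itself, `getSerializationProxy` additionally requires the class to be synthetic, `Serializable`, and a `scala.Function*` implementation, and filters with `isIndylambdaScalaClosure` before treating a function value as cleanable.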
+ assert(accessedFields.isEmpty) + + initAccessedFields(accessedFields, Seq(capturingClass)) + IndylambdaScalaClosures.findAccessedFields( + lambdaProxy, classLoader, accessedFields, cleanTransitively) + + logDebug(s" + fields accessed by starting closure: ${accessedFields.size} classes") + accessedFields.foreach { f => logDebug(" " + f) } + + if (accessedFields(capturingClass).size < capturingClass.getDeclaredFields.length) { + // clone and clean the enclosing `this` only when there are fields to null out + + val outerThis = outerThisOpt.get + + logDebug(s" + cloning instance of REPL class $capturingClassName") + val clonedOuterThis = cloneAndSetFields( + parent = null, outerThis, capturingClass, accessedFields) + + val outerField = func.getClass.getDeclaredField("arg$1") + outerField.setAccessible(true) + outerField.set(func, clonedOuterThis) + } + } + + logDebug(s" +++ indylambda closure ($implMethodName) is now cleaned +++") } if (checkSerializable) { @@ -414,6 +434,312 @@ private[spark] object ClosureCleaner extends Logging { } } +private[spark] object IndylambdaScalaClosures extends Logging { + // internal name of java.lang.invoke.LambdaMetafactory + val LambdaMetafactoryClassName = "java/lang/invoke/LambdaMetafactory" + // the method that Scala indylambda use for bootstrap method + val LambdaMetafactoryMethodName = "altMetafactory" + val LambdaMetafactoryMethodDesc = "(Ljava/lang/invoke/MethodHandles$Lookup;" + + "Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)" + + "Ljava/lang/invoke/CallSite;" + + /** + * Check if the given reference is a indylambda style Scala closure. + * If so (e.g. for Scala 2.12+ closures), return a non-empty serialization proxy + * (SerializedLambda) of the closure; + * otherwise (e.g. for Scala 2.11 closures) return None. + * + * @param maybeClosure the closure to check. + */ + def getSerializationProxy(maybeClosure: AnyRef): Option[SerializedLambda] = { + def isClosureCandidate(cls: Class[_]): Boolean = { + // TODO: maybe lift this restriction to support other functional interfaces in the future + val implementedInterfaces = ClassUtils.getAllInterfaces(cls).asScala + implementedInterfaces.exists(_.getName.startsWith("scala.Function")) + } + + maybeClosure.getClass match { + // shortcut the fast check: + // 1. indylambda closure classes are generated by Java's LambdaMetafactory, and they're + // always synthetic. + // 2. We only care about Serializable closures, so let's check that as well + case c if !c.isSynthetic || !maybeClosure.isInstanceOf[Serializable] => None + + case c if isClosureCandidate(c) => + try { + Option(inspect(maybeClosure)).filter(isIndylambdaScalaClosure) + } catch { + case e: Exception => + logDebug("The given reference is not an indylambda Scala closure.", e) + None + } + + case _ => None + } + } + + def isIndylambdaScalaClosure(lambdaProxy: SerializedLambda): Boolean = { + lambdaProxy.getImplMethodKind == MethodHandleInfo.REF_invokeStatic && + lambdaProxy.getImplMethodName.contains("$anonfun$") + } + + def inspect(closure: AnyRef): SerializedLambda = { + val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + writeReplace.invoke(closure).asInstanceOf[SerializedLambda] + } + + /** + * Check if the handle represents the LambdaMetafactory that indylambda Scala closures + * use for creating the lambda class and getting a closure instance. 
+ */ + def isLambdaMetafactory(bsmHandle: Handle): Boolean = { + bsmHandle.getOwner == LambdaMetafactoryClassName && + bsmHandle.getName == LambdaMetafactoryMethodName && + bsmHandle.getDesc == LambdaMetafactoryMethodDesc + } + + /** + * Check if the handle represents a target method that is: + * - a STATIC method that implements a Scala lambda body in the indylambda style + * - captures the enclosing `this`, i.e. the first argument is a reference to the same type as + * the owning class. + * Returns true if both criteria above are met. + */ + def isLambdaBodyCapturingOuter(handle: Handle, ownerInternalName: String): Boolean = { + handle.getTag == H_INVOKESTATIC && + handle.getName.contains("$anonfun$") && + handle.getOwner == ownerInternalName && + handle.getDesc.startsWith(s"(L$ownerInternalName;") + } + + /** + * Check if the callee of a call site is an inner class constructor. + * - A constructor has to be invoked via INVOKESPECIAL + * - A constructor's internal name is "<init>" and the return type is "V" (void) + * - An inner class' first argument in the signature has to be a reference to the + * enclosing "this", aka `$outer` in Scala. + */ + def isInnerClassCtorCapturingOuter( + op: Int, owner: String, name: String, desc: String, callerInternalName: String): Boolean = { + op == INVOKESPECIAL && name == "<init>" && desc.startsWith(s"(L$callerInternalName;") + } + + /** + * Scans an indylambda Scala closure, along with its lexically nested closures, and populates + * the accessed fields info on which fields on the outer object are accessed. + * + * This is equivalent to getInnerClosureClasses() + InnerClosureFinder + FieldAccessFinder fused + * into one for processing indylambda closures. The traversal order along the call graph is the + * same for all three combined, so they can be fused together easily while maintaining the same + * ordering as the existing implementation. + * + * Precondition: this function expects the `accessedFields` to be populated with all known + * outer classes and their super classes to be in the map as keys, e.g. + * initializing via ClosureCleaner.initAccessedFields. + */ + // scalastyle:off line.size.limit + // Example: run the following code snippet in a Spark Shell w/ Scala 2.12+: + // val topLevelValue = "someValue"; val closure = (j: Int) => { + // class InnerFoo { + // val innerClosure = (x: Int) => (1 to x).map { y => y + topLevelValue } + // } + // val innerFoo = new InnerFoo + // (1 to j).flatMap(innerFoo.innerClosure) + // } + // sc.parallelize(0 to 2).map(closure).collect + // + // produces the following trace-level logs: + // (slightly simplified: + // - omitting the "ignoring ..."
lines; + // - "$iw" is actually "$line14.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw"; + // - "invokedynamic" lines are simplified to just show the name+desc, omitting the bsm info) + // Cleaning indylambda closure: $anonfun$closure$1$adapted + // scanning $iw.$anonfun$closure$1$adapted(L$iw;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found intra class call to $iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq; + // scanning $iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq; + // found inner class $iw$InnerFoo$1 + // found method innerClosure()Lscala/Function1; + // found method $anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found method $anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // found method (L$iw;)V + // found method $anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found method $anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found method $deserializeLambda$(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object; + // found call to outer $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // scanning $iw$InnerFoo$1.$deserializeLambda$(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object; + // invokedynamic: lambdaDeserialize(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;, bsm...) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) + // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found call to outer $iw.topLevelValue()Ljava/lang/String; + // scanning $iw.topLevelValue()Ljava/lang/String; + // found field access topLevelValue on $iw + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // scanning $iw$InnerFoo$1.(L$iw;)V + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) + // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) 
+ // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found call to outer $iw.topLevelValue()Ljava/lang/String; + // scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // + fields accessed by starting closure: 2 classes + // (class java.lang.Object,Set()) + // (class $iw,Set(topLevelValue)) + // + cloning instance of REPL class $iw + // +++ indylambda closure ($anonfun$closure$1$adapted) is now cleaned +++ + // + // scalastyle:on line.size.limit + def findAccessedFields( + lambdaProxy: SerializedLambda, + lambdaClassLoader: ClassLoader, + accessedFields: Map[Class[_], Set[String]], + findTransitively: Boolean): Unit = { + + // We may need to visit the same class multiple times for different methods on it, and we'll + // need to lookup by name. So we use ASM's Tree API and cache the ClassNode/MethodNode. + val classInfoByInternalName = Map.empty[String, (Class[_], ClassNode)] + val methodNodeById = Map.empty[MethodIdentifier[_], MethodNode] + def getOrUpdateClassInfo(classInternalName: String): (Class[_], ClassNode) = { + val classInfo = classInfoByInternalName.getOrElseUpdate(classInternalName, { + val classExternalName = classInternalName.replace('/', '.') + // scalastyle:off classforname + val clazz = Class.forName(classExternalName, false, lambdaClassLoader) + // scalastyle:on classforname + val classNode = new ClassNode() + val classReader = ClosureCleaner.getClassReader(clazz) + classReader.accept(classNode, 0) + + for (m <- classNode.methods.asScala) { + methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m + } + + (clazz, classNode) + }) + classInfo + } + + val implClassInternalName = lambdaProxy.getImplClass + val (implClass, _) = getOrUpdateClassInfo(implClassInternalName) + + val implMethodId = MethodIdentifier( + implClass, lambdaProxy.getImplMethodName, lambdaProxy.getImplMethodSignature) + + // The set internal names of classes that we would consider following the calls into. + // Candidates are: known outer class which happens to be the starting closure's impl class, + // and all inner classes discovered below. + // Note that code in an inner class can make calls to methods in any of its enclosing classes, + // e.g. + // starting closure (in class T) + // inner class A + // inner class B + // inner closure + // we need to track calls from "inner closure" to outer classes relative to it (class T, A, B) + // to better find and track field accesses. + val trackedClassInternalNames = Set[String](implClassInternalName) + + // Depth-first search for inner closures and track the fields that were accessed in them. + // Start from the lambda body's implementation method, follow method invocations + val visited = Set.empty[MethodIdentifier[_]] + val stack = Stack[MethodIdentifier[_]](implMethodId) + def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = { + if (!visited.contains(methodId)) { + stack.push(methodId) + } + } + + while (!stack.isEmpty) { + val currentId = stack.pop + visited += currentId + + val currentClass = currentId.cls + val currentMethodNode = methodNodeById(currentId) + logTrace(s" scanning ${currentId.cls.getName}.${currentId.name}${currentId.desc}") + currentMethodNode.accept(new MethodVisitor(ASM7) { + val currentClassName = currentClass.getName + val currentClassInternalName = currentClassName.replace('.', '/') + + // Find and update the accessedFields info. 
Only fields on known outer classes are tracked. + // This is the FieldAccessFinder equivalent. + override def visitFieldInsn(op: Int, owner: String, name: String, desc: String): Unit = { + if (op == GETFIELD || op == PUTFIELD) { + val ownerExternalName = owner.replace('/', '.') + for (cl <- accessedFields.keys if cl.getName == ownerExternalName) { + logTrace(s" found field access $name on $ownerExternalName") + accessedFields(cl) += name + } + } + } + + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = { + val ownerExternalName = owner.replace('/', '.') + if (owner == currentClassInternalName) { + logTrace(s" found intra class call to $ownerExternalName.$name$desc") + // could be invoking a helper method or a field accessor method, just follow it. + pushIfNotVisited(MethodIdentifier(currentClass, name, desc)) + } else if (isInnerClassCtorCapturingOuter( + op, owner, name, desc, currentClassInternalName)) { + // Discover inner classes. + // This is the InnerClassFinder equivalent for inner classes, which still use the + // `$outer` chain. So this is NOT controlled by the `findTransitively` flag. + logDebug(s" found inner class $ownerExternalName") + val innerClassInfo = getOrUpdateClassInfo(owner) + val innerClass = innerClassInfo._1 + val innerClassNode = innerClassInfo._2 + trackedClassInternalNames += owner + // We need to visit all methods on the inner class so that we don't miss anything. + for (m <- innerClassNode.methods.asScala) { + logTrace(s" found method ${m.name}${m.desc}") + pushIfNotVisited(MethodIdentifier(innerClass, m.name, m.desc)) + } + } else if (findTransitively && trackedClassInternalNames.contains(owner)) { + logTrace(s" found call to outer $ownerExternalName.$name$desc") + val (calleeClass, _) = getOrUpdateClassInfo(owner) // make sure MethodNodes are cached + pushIfNotVisited(MethodIdentifier(calleeClass, name, desc)) + } else { + // keep the same behavior as the original ClosureCleaner + logTrace(s" ignoring call to $ownerExternalName.$name$desc") + } + } + + // Find the lexically nested closures + // This is the InnerClosureFinder equivalent for indylambda nested closures + override def visitInvokeDynamicInsn( + name: String, desc: String, bsmHandle: Handle, bsmArgs: Object*): Unit = { + logTrace(s" invokedynamic: $name$desc, bsmHandle=$bsmHandle, bsmArgs=$bsmArgs") + + // fast check: we only care about Scala lambda creation + // TODO: maybe lift this restriction and support other functional interfaces + if (!name.startsWith("apply")) return + if (!Type.getReturnType(desc).getDescriptor.startsWith("Lscala/Function")) return + + if (isLambdaMetafactory(bsmHandle)) { + // OK we're in the right bootstrap method for serializable Java 8 style lambda creation + val targetHandle = bsmArgs(1).asInstanceOf[Handle] + if (isLambdaBodyCapturingOuter(targetHandle, currentClassInternalName)) { + // this is a lexically nested closure that also captures the enclosing `this` + logDebug(s" found inner closure $targetHandle") + val calleeMethodId = + MethodIdentifier(currentClass, targetHandle.getName, targetHandle.getDesc) + pushIfNotVisited(calleeMethodId) + } + } + } + }) + } + } +} + private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") @@ -422,7 +748,7 @@ private class ReturnStatementFinder(targetMethodName: Option[String] = None) override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]):
MethodVisitor = { - // $anonfun$ covers Java 8 lambdas + // $anonfun$ covers indylambda closures if (name.contains("apply") || name.contains("$anonfun$")) { // A method with suffix "$adapted" will be generated in cases like // { _:Int => return; Seq()} but not { _:Int => return; true} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 9254ac94005f1..844d9b7cf2c27 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.RDDOperationScope -import org.apache.spark.resource.{ResourceInformation, ResourceProfile} +import org.apache.spark.resource.{ExecutorResourceRequest, ResourceInformation, ResourceProfile, TaskResourceRequest} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ @@ -105,6 +105,8 @@ private[spark] object JsonProtocol { stageExecutorMetricsToJson(stageExecutorMetrics) case blockUpdate: SparkListenerBlockUpdated => blockUpdateToJson(blockUpdate) + case resourceProfileAdded: SparkListenerResourceProfileAdded => + resourceProfileAddedToJson(resourceProfileAdded) case _ => parse(mapper.writeValueAsString(event)) } } @@ -224,6 +226,15 @@ private[spark] object JsonProtocol { ("Timestamp" -> applicationEnd.time) } + def resourceProfileAddedToJson(profileAdded: SparkListenerResourceProfileAdded): JValue = { + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.resourceProfileAdded) ~ + ("Resource Profile Id" -> profileAdded.resourceProfile.id) ~ + ("Executor Resource Requests" -> + executorResourceRequestMapToJson(profileAdded.resourceProfile.executorResources)) ~ + ("Task Resource Requests" -> + taskResourceRequestMapToJson(profileAdded.resourceProfile.taskResources)) + } + def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = { ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorAdded) ~ ("Timestamp" -> executorAdded.time) ~ @@ -297,7 +308,8 @@ private[spark] object JsonProtocol { ("Submission Time" -> submissionTime) ~ ("Completion Time" -> completionTime) ~ ("Failure Reason" -> failureReason) ~ - ("Accumulables" -> accumulablesToJson(stageInfo.accumulables.values)) + ("Accumulables" -> accumulablesToJson(stageInfo.accumulables.values)) ~ + ("Resource Profile Id" -> stageInfo.resourceProfileId) } def taskInfoToJson(taskInfo: TaskInfo): JValue = { @@ -475,6 +487,7 @@ private[spark] object JsonProtocol { ("Callsite" -> rddInfo.callSite) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ + ("Barrier" -> rddInfo.isBarrier) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ ("Number of Cached Partitions" -> rddInfo.numCachedPartitions) ~ ("Memory Size" -> rddInfo.memSize) ~ @@ -500,7 +513,8 @@ private[spark] object JsonProtocol { ("Total Cores" -> executorInfo.totalCores) ~ ("Log Urls" -> mapToJson(executorInfo.logUrlMap)) ~ ("Attributes" -> mapToJson(executorInfo.attributes)) ~ - ("Resources" -> resourcesMapToJson(executorInfo.resourcesInfo)) + ("Resources" -> resourcesMapToJson(executorInfo.resourcesInfo)) ~ + ("Resource Profile Id" -> executorInfo.resourceProfileId) } def resourcesMapToJson(m: Map[String, ResourceInformation]): JValue = { @@ -518,6 +532,34 @@ private[spark] object JsonProtocol { ("Disk Size" -> 
blockUpdatedInfo.diskSize) } + def executorResourceRequestToJson(execReq: ExecutorResourceRequest): JValue = { + ("Resource Name" -> execReq.resourceName) ~ + ("Amount" -> execReq.amount) ~ + ("Discovery Script" -> execReq.discoveryScript) ~ + ("Vendor" -> execReq.vendor) + } + + def executorResourceRequestMapToJson(m: Map[String, ExecutorResourceRequest]): JValue = { + val jsonFields = m.map { + case (k, execReq) => + JField(k, executorResourceRequestToJson(execReq)) + } + JObject(jsonFields.toList) + } + + def taskResourceRequestToJson(taskReq: TaskResourceRequest): JValue = { + ("Resource Name" -> taskReq.resourceName) ~ + ("Amount" -> taskReq.amount) + } + + def taskResourceRequestMapToJson(m: Map[String, TaskResourceRequest]): JValue = { + val jsonFields = m.map { + case (k, taskReq) => + JField(k, taskResourceRequestToJson(taskReq)) + } + JObject(jsonFields.toList) + } + /** ------------------------------ * * Util JSON serialization methods | * ------------------------------- */ @@ -577,6 +619,7 @@ private[spark] object JsonProtocol { val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) val stageExecutorMetrics = Utils.getFormattedClassName(SparkListenerStageExecutorMetrics) val blockUpdate = Utils.getFormattedClassName(SparkListenerBlockUpdated) + val resourceProfileAdded = Utils.getFormattedClassName(SparkListenerResourceProfileAdded) } def sparkEventFromJson(json: JValue): SparkListenerEvent = { @@ -602,6 +645,7 @@ private[spark] object JsonProtocol { case `metricsUpdate` => executorMetricsUpdateFromJson(json) case `stageExecutorMetrics` => stageExecutorMetricsFromJson(json) case `blockUpdate` => blockUpdateFromJson(json) + case `resourceProfileAdded` => resourceProfileAddedFromJson(json) case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) .asInstanceOf[SparkListenerEvent] } @@ -678,6 +722,45 @@ private[spark] object JsonProtocol { SparkListenerJobEnd(jobId, completionTime, jobResult) } + def resourceProfileAddedFromJson(json: JValue): SparkListenerResourceProfileAdded = { + val profId = (json \ "Resource Profile Id").extract[Int] + val executorReqs = executorResourceRequestMapFromJson(json \ "Executor Resource Requests") + val taskReqs = taskResourceRequestMapFromJson(json \ "Task Resource Requests") + val rp = new ResourceProfile(executorReqs.toMap, taskReqs.toMap) + rp.setResourceProfileId(profId) + SparkListenerResourceProfileAdded(rp) + } + + def executorResourceRequestFromJson(json: JValue): ExecutorResourceRequest = { + val rName = (json \ "Resource Name").extract[String] + val amount = (json \ "Amount").extract[Int] + val discoveryScript = (json \ "Discovery Script").extract[String] + val vendor = (json \ "Vendor").extract[String] + new ExecutorResourceRequest(rName, amount, discoveryScript, vendor) + } + + def taskResourceRequestFromJson(json: JValue): TaskResourceRequest = { + val rName = (json \ "Resource Name").extract[String] + val amount = (json \ "Amount").extract[Int] + new TaskResourceRequest(rName, amount) + } + + def taskResourceRequestMapFromJson(json: JValue): Map[String, TaskResourceRequest] = { + val jsonFields = json.asInstanceOf[JObject].obj + jsonFields.map { case JField(k, v) => + val req = taskResourceRequestFromJson(v) + (k, req) + }.toMap + } + + def executorResourceRequestMapFromJson(json: JValue): Map[String, ExecutorResourceRequest] = { + val jsonFields = json.asInstanceOf[JObject].obj + jsonFields.map { case JField(k, v) => + val req = executorResourceRequestFromJson(v) + (k, req) + 
}.toMap + } + def environmentUpdateFromJson(json: JValue): SparkListenerEnvironmentUpdate = { // For compatible with previous event logs val hadoopProperties = jsonOption(json \ "Hadoop Properties").map(mapFromJson(_).toSeq) @@ -804,9 +887,10 @@ private[spark] object JsonProtocol { } } - val stageInfo = new StageInfo( - stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details, - resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + val rpId = jsonOption(json \ "Resource Profile Id").map(_.extract[Int]) + val stageProf = rpId.getOrElse(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + val stageInfo = new StageInfo(stageId, attemptId, stageName, numTasks, rddInfos, + parentIds, details, resourceProfileId = stageProf) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime stageInfo.failureReason = failureReason @@ -1109,7 +1193,11 @@ private[spark] object JsonProtocol { case Some(resources) => resourcesMapFromJson(resources).toMap case None => Map.empty[String, ResourceInformation] } - new ExecutorInfo(executorHost, totalCores, logUrls, attributes, resources) + val resourceProfileId = jsonOption(json \ "Resource Profile Id") match { + case Some(id) => id.extract[Int] + case None => ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID + } + new ExecutorInfo(executorHost, totalCores, logUrls, attributes, resources, resourceProfileId) } def blockUpdatedInfoFromJson(json: JValue): BlockUpdatedInfo = { diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 4f1311224bb95..4db268604a3e9 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -209,9 +209,7 @@ private [util] class SparkShutdownHookManager { private class SparkShutdownHook(private val priority: Int, hook: () => Unit) extends Comparable[SparkShutdownHook] { - override def compareTo(other: SparkShutdownHook): Int = { - other.priority - priority - } + override def compareTo(other: SparkShutdownHook): Int = other.priority.compareTo(priority) def run(): Unit = hook() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c7db2127a6f04..9636fe88c77c2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2904,6 +2904,24 @@ private[spark] object Utils extends Logging { props.forEach((k, v) => resultProps.put(k, v)) resultProps } + + /** + * Convert a sequence of `Path`s to a metadata string. When the length of metadata string + * exceeds `stopAppendingThreshold`, stop appending paths for saving memory. 
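(Aside, not part of the patch: the `SparkShutdownHook.compareTo` change above swaps integer subtraction for `compareTo`. Subtraction-based comparators silently overflow for extreme priority values and can invert the ordering; a tiny standalone illustration.)

```scala
object ComparatorOverflowDemo {
  def main(args: Array[String]): Unit = {
    val highPriority = Int.MaxValue
    val lowPriority = -1

    // Subtraction overflows: Int.MaxValue - (-1) wraps around to Int.MinValue,
    // so the "comparison" reports the wrong sign.
    println(highPriority - lowPriority)          // -2147483648

    // compareTo never overflows and returns the expected positive result.
    println(highPriority.compareTo(lowPriority)) // 1
  }
}
```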
+ */ + def buildLocationMetadata(paths: Seq[Path], stopAppendingThreshold: Int): String = { + val metadata = new StringBuilder("[") + var index: Int = 0 + while (index < paths.length && metadata.length < stopAppendingThreshold) { + if (index > 0) { + metadata.append(", ") + } + metadata.append(paths(index).toString) + index += 1 + } + metadata.append("]") + metadata.toString + } } private[util] object CallerContext extends Logging { diff --git a/core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json b/core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json index a64617256d63a..0b617a7d0aced 100644 --- a/core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json @@ -282,5 +282,6 @@ [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.2.0.jar", "System Classpath" ], [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-model-1.2.15.jar", "System Classpath" ] - ] + ], + "resourceProfiles" : [ ] } diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index 6e6d28b6a57ec..d2b3d1b069204 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1578436911597_0052", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2020-01-11T17:44:22.851GMT", + "endTime" : "2020-01-11T17:46:42.615GMT", + "lastUpdated" : "", + "duration" : 139764, + "sparkUser" : "tgraves", + "completed" : true, + "appSparkVersion" : "3.0.0-SNAPSHOT", + "endTimeEpoch" : 1578764802615, + "startTimeEpoch" : 1578764662851, + "lastUpdatedEpoch" : 0 + } ] +}, { "id" : "application_1555004656427_0144", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json index b18b19f7eeffb..0d197eab0e25d 100644 --- a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json @@ -717,5 +717,6 @@ "isBlacklistedForStage" : false } }, - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 } diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json index 8d11081247913..24d73faa45021 100644 --- a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json @@ -876,5 +876,6 @@ "isBlacklistedForStage" : true } }, - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 } diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json 
b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index a47cd26ed102b..a452488294547 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -41,7 +41,8 @@ "schedulingPool" : "default", "rddIds" : [ 6, 5 ], "accumulatorUpdates" : [ ], - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 }, { "status" : "COMPLETE", "stageId" : 1, @@ -85,7 +86,8 @@ "schedulingPool" : "default", "rddIds" : [ 1, 0 ], "accumulatorUpdates" : [ ], - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 }, { "status" : "COMPLETE", "stageId" : 0, @@ -129,5 +131,6 @@ "schedulingPool" : "default", "rddIds" : [ 0 ], "accumulatorUpdates" : [ ], - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index 6e6d28b6a57ec..d2b3d1b069204 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1578436911597_0052", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2020-01-11T17:44:22.851GMT", + "endTime" : "2020-01-11T17:46:42.615GMT", + "lastUpdated" : "", + "duration" : 139764, + "sparkUser" : "tgraves", + "completed" : true, + "appSparkVersion" : "3.0.0-SNAPSHOT", + "endTimeEpoch" : 1578764802615, + "startTimeEpoch" : 1578764662851, + "lastUpdatedEpoch" : 0 + } ] +}, { "id" : "application_1555004656427_0144", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index eadf27164c814..67425676a62d6 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -22,5 +22,6 @@ "executorLogs" : { }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json index d322485baa8de..d052a27385f66 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json @@ -50,7 +50,8 @@ "MajorGCTime" : 144 }, "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "3", "hostPort" : "test-3.vpc.company.com:37641", @@ -116,7 +117,8 @@ "NM_HOST" : "test-3.vpc.company.com", "CONTAINER_ID" : "container_1553914137147_0018_01_000004" }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "2", "hostPort" : "test-4.vpc.company.com:33179", @@ -182,7 +184,8 @@ "NM_HOST" : "test-4.vpc.company.com", "CONTAINER_ID" : "container_1553914137147_0018_01_000003" }, - 
"resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "1", "hostPort" : "test-2.vpc.company.com:43764", @@ -248,5 +251,6 @@ "NM_HOST" : "test-2.vpc.company.com", "CONTAINER_ID" : "container_1553914137147_0018_01_000002" }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index 7c3f77d8c10cf..91574ca8266b2 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -28,7 +28,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -62,7 +63,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 } ,{ "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -96,7 +98,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -130,7 +133,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -164,5 +168,6 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 0986e85f16b3e..f14b9a5085a42 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -28,7 +28,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -62,7 +63,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -96,7 +98,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -130,7 +133,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -164,5 +168,6 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json index 26d665151a52d..3645387317ca1 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json @@ -22,7 +22,8 @@ "executorLogs" : { }, 
"blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "3", "hostPort" : "172.22.0.111:64543", @@ -50,7 +51,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "2", "hostPort" : "172.22.0.111:64539", @@ -78,7 +80,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "1", "hostPort" : "172.22.0.111:64541", @@ -106,7 +109,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "0", "hostPort" : "172.22.0.111:64540", @@ -134,5 +138,6 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_resource_information_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_resource_information_expectation.json index e69ab3b49d455..165389cf25027 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_resource_information_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_resource_information_expectation.json @@ -28,7 +28,8 @@ }, "blacklistedInStages" : [ ], "attributes" : { }, - "resources" : { } + "resources" : { }, + "resourceProfileId" : 0 }, { "id" : "2", "hostPort" : "tomg-test:46005", @@ -77,7 +78,8 @@ "name" : "gpu", "addresses" : [ "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12" ] } - } + }, + "resourceProfileId" : 0 }, { "id" : "1", "hostPort" : "tomg-test:44873", @@ -126,5 +128,6 @@ "name" : "gpu", "addresses" : [ "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12" ] } - } + }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index da26271e66bc4..c38741646c64b 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -42,5 +42,6 @@ "schedulingPool" : "default", "rddIds" : [ 3, 2 ], "accumulatorUpdates" : [ ], - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json index 3102909f81116..82489e94a84c8 100644 --- a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1578436911597_0052", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2020-01-11T17:44:22.851GMT", + "endTime" : "2020-01-11T17:46:42.615GMT", + "lastUpdated" : "", + "duration" : 139764, + "sparkUser" : "tgraves", + "completed" : true, + "appSparkVersion" : "3.0.0-SNAPSHOT", + "endTimeEpoch" : 1578764802615, + "startTimeEpoch" : 1578764662851, + "lastUpdatedEpoch" : 0 + } ] +}, { "id" : "application_1555004656427_0144", "name" : "Spark shell", "attempts" : [ { @@ -28,19 +43,4 @@ "endTimeEpoch" : 1554756046454, "lastUpdatedEpoch" : 0 } ] -}, { - "id" : 
"application_1516285256255_0012", - "name" : "Spark shell", - "attempts" : [ { - "startTime" : "2018-01-18T18:30:35.119GMT", - "endTime" : "2018-01-18T18:38:27.938GMT", - "lastUpdated" : "", - "duration" : 472819, - "sparkUser" : "attilapiros", - "completed" : true, - "appSparkVersion" : "2.3.0-SNAPSHOT", - "lastUpdatedEpoch" : 0, - "startTimeEpoch" : 1516300235119, - "endTimeEpoch" : 1516300707938 - } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 794f1514a6708..ac2bb0e29b2fb 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -1,5 +1,19 @@ -[ - { +[ { + "id" : "application_1578436911597_0052", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2020-01-11T17:44:22.851GMT", + "endTime" : "2020-01-11T17:46:42.615GMT", + "lastUpdated" : "", + "duration" : 139764, + "sparkUser" : "tgraves", + "completed" : true, + "appSparkVersion" : "3.0.0-SNAPSHOT", + "endTimeEpoch" : 1578764802615, + "startTimeEpoch" : 1578764662851, + "lastUpdatedEpoch" : 0 + } ] +}, { "id": "application_1555004656427_0144", "name": "Spark shell", "attempts": [ diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json index adcdccef48450..156167606ff20 100644 --- a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1578436911597_0052", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2020-01-11T17:44:22.851GMT", + "endTime" : "2020-01-11T17:46:42.615GMT", + "lastUpdated" : "", + "duration" : 139764, + "sparkUser" : "tgraves", + "completed" : true, + "appSparkVersion" : "3.0.0-SNAPSHOT", + "endTimeEpoch" : 1578764802615, + "startTimeEpoch" : 1578764662851, + "lastUpdatedEpoch" : 0 + } ] +}, { "id" : "application_1555004656427_0144", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/multiple_resource_profiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/multiple_resource_profiles_expectation.json new file mode 100644 index 0000000000000..5c1e4cc2337be --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/multiple_resource_profiles_expectation.json @@ -0,0 +1,112 @@ +{ + "runtime" : { + "javaVersion" : "1.8.0_232 (Private Build)", + "javaHome" : "/usr/lib/jvm/java-8-openjdk-amd64/jre", + "scalaVersion" : "version 2.12.10" + }, + "sparkProperties" : [ ], + "hadoopProperties" : [ ], + "systemProperties" : [ ], + "classpathEntries" : [ ], + "resourceProfiles" : [ { + "id" : 0, + "executorResources" : { + "cores" : { + "resourceName" : "cores", + "amount" : 1, + "discoveryScript" : "", + "vendor" : "" + }, + "memory" : { + "resourceName" : "memory", + "amount" : 1024, + "discoveryScript" : "", + "vendor" : "" + }, + "gpu" : { + "resourceName" : "gpu", + "amount" : 1, + "discoveryScript" : "/home/tgraves/getGpus", + "vendor" : "" + } + }, + "taskResources" : { + "cpus" : { + "resourceName" : "cpus", + "amount" : 1.0 + }, + "gpu" : { + "resourceName" : "gpu", + "amount" : 1.0 + } + } + }, { + "id" 
: 1, + "executorResources" : { + "cores" : { + "resourceName" : "cores", + "amount" : 4, + "discoveryScript" : "", + "vendor" : "" + }, + "gpu" : { + "resourceName" : "gpu", + "amount" : 1, + "discoveryScript" : "./getGpus", + "vendor" : "" + } + }, + "taskResources" : { + "cpus" : { + "resourceName" : "cpus", + "amount" : 1.0 + }, + "gpu" : { + "resourceName" : "gpu", + "amount" : 1.0 + } + } + }, { + "id" : 2, + "executorResources" : { + "cores" : { + "resourceName" : "cores", + "amount" : 2, + "discoveryScript" : "", + "vendor" : "" + } + }, + "taskResources" : { + "cpus" : { + "resourceName" : "cpus", + "amount" : 2.0 + } + } + }, { + "id" : 3, + "executorResources" : { + "cores" : { + "resourceName" : "cores", + "amount" : 4, + "discoveryScript" : "", + "vendor" : "" + }, + "gpu" : { + "resourceName" : "gpu", + "amount" : 1, + "discoveryScript" : "./getGpus", + "vendor" : "" + } + }, + "taskResources" : { + "cpus" : { + "resourceName" : "cpus", + "amount" : 2.0 + }, + "gpu" : { + "resourceName" : "gpu", + "amount" : 1.0 + } + } + } ] +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 791907045e500..3db7d551b6130 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -462,5 +462,6 @@ "isBlacklistedForStage" : false } }, - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 50d3f74ae775f..8ef3769c1ca6b 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -462,5 +462,6 @@ "isBlacklistedForStage" : false } }, - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index edbac7127039d..a31c907221388 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -41,7 +41,8 @@ "schedulingPool" : "default", "rddIds" : [ 6, 5 ], "accumulatorUpdates" : [ ], - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 }, { "status" : "FAILED", "stageId" : 2, @@ -86,7 +87,8 @@ "schedulingPool" : "default", "rddIds" : [ 3, 2 ], "accumulatorUpdates" : [ ], - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 }, { "status" : "COMPLETE", "stageId" : 1, @@ -130,7 +132,8 @@ "schedulingPool" : "default", "rddIds" : [ 1, 0 ], "accumulatorUpdates" : [ ], - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 }, { "status" : "COMPLETE", "stageId" : 0, @@ -174,5 +177,6 @@ "schedulingPool" : "default", "rddIds" : [ 0 ], "accumulatorUpdates" : [ ], - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json 
b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index 836f2cb095097..08089d4f3f65b 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -45,5 +45,6 @@ "name" : "my counter", "value" : "5050" } ], - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 735a8257fc343..3b5476ae8b160 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -506,5 +506,6 @@ "isBlacklistedForStage" : false } }, - "killedTasksSummary" : { } + "killedTasksSummary" : { }, + "resourceProfileId" : 0 } diff --git a/core/src/test/resources/spark-events/application_1578436911597_0052 b/core/src/test/resources/spark-events/application_1578436911597_0052 new file mode 100644 index 0000000000000..c57481a348a89 --- /dev/null +++ b/core/src/test/resources/spark-events/application_1578436911597_0052 @@ -0,0 +1,27 @@ +{"Event":"SparkListenerLogStart","Spark Version":"3.0.0-SNAPSHOT"} +{"Event":"SparkListenerResourceProfileAdded","Resource Profile Id":0,"Executor Resource Requests":{"cores":{"Resource Name":"cores","Amount":1,"Discovery Script":"","Vendor":""},"memory":{"Resource Name":"memory","Amount":1024,"Discovery Script":"","Vendor":""},"gpu":{"Resource Name":"gpu","Amount":1,"Discovery Script":"/home/tgraves/getGpus","Vendor":""}},"Task Resource Requests":{"cpus":{"Resource Name":"cpus","Amount":1.0},"gpu":{"Resource Name":"gpu","Amount":1.0}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"10.10.10.10","Port":32957},"Maximum Memory":428762726,"Timestamp":1578764671818,"Maximum Onheap Memory":428762726,"Maximum Offheap Memory":0} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/usr/lib/jvm/java-8-openjdk-amd64/jre","Java Version":"1.8.0_232 (Private Build)","Scala Version":"version 2.12.10"},"Spark Properties":{},"Hadoop Properties":{},"System Properties":{}, "Classpath Entries": {}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"application_1578436911597_0052","Timestamp":1578764662851,"User":"tgraves"} +{"Event":"SparkListenerResourceProfileAdded","Resource Profile Id":1,"Executor Resource Requests":{"cores":{"Resource Name":"cores","Amount":4,"Discovery Script":"","Vendor":""},"gpu":{"Resource Name":"gpu","Amount":1,"Discovery Script":"./getGpus","Vendor":""}},"Task Resource Requests":{"cpus":{"Resource Name":"cpus","Amount":1.0},"gpu":{"Resource Name":"gpu","Amount":1.0}}} +{"Event":"SparkListenerResourceProfileAdded","Resource Profile Id":2,"Executor Resource Requests":{"cores":{"Resource Name":"cores","Amount":2,"Discovery Script":"","Vendor":""}},"Task Resource Requests":{"cpus":{"Resource Name":"cpus","Amount":2.0}}} +{"Event":"SparkListenerResourceProfileAdded","Resource Profile Id":3,"Executor Resource Requests":{"cores":{"Resource Name":"cores","Amount":4,"Discovery Script":"","Vendor":""},"gpu":{"Resource Name":"gpu","Amount":1,"Discovery Script":"./getGpus","Vendor":""}},"Task Resource 
Requests":{"cpus":{"Resource Name":"cpus","Amount":2.0},"gpu":{"Resource Name":"gpu","Amount":1.0}}} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1578764765274,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":6,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :31","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":6,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :31","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":6,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:1004)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:39)\n$line37.$read$$iw$$iw$$iw$$iw$$iw.(:41)\n$line37.$read$$iw$$iw$$iw$$iw.(:43)\n$line37.$read$$iw$$iw$$iw.(:45)\n$line37.$read$$iw$$iw.(:47)\n$line37.$read$$iw.(:49)\n$line37.$read.(:51)\n$line37.$read$.(:55)\n$line37.$read$.()\n$line37.$eval$.$print$lzycompute(:7)\n$line37.$eval$.$print(:6)\n$line37.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)","Accumulables":[],"Resource Profile Id":3}],"Stage IDs":[0],"Properties":{"spark.rdd.scope":"{\"id\":\"2\",\"name\":\"collect\"}","spark.rdd.scope.noOverride":"true"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":6,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :31","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":6,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :31","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":6,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:1004)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:39)\n$line37.$read$$iw$$iw$$iw$$iw$$iw.(:41)\n$line37.$read$$iw$$iw$$iw$$iw.(:43)\n$line37.$read$$iw$$iw$$iw.(:45)\n$line37.$read$$iw$$iw.(:47)\n$line37.$read$$iw.(:49)\n$line37.$read.(:51)\n$line37.$read$.(:55)\n$line37.$read$.()\n$line37.$eval$.$print$lzycompute(:7)\n$line37.$eval$.$print(:6)\n$line37.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native 
Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)","Submission Time":1578764765293,"Accumulables":[],"Resource Profile Id":3},"Properties":{"spark.rdd.scope":"{\"id\":\"2\",\"name\":\"collect\"}","spark.rdd.scope.noOverride":"true"}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1578764769706,"Executor ID":"1","Executor Info":{"Host":"host1","Total Cores":4,"Log Urls":{"stdout":"http://host1:8042/node/containerlogs/container_1578436911597_0052_01_000002/tgraves/stdout?start=-4096","stderr":"http://host1:8042/node/containerlogs/container_1578436911597_0052_01_000002/tgraves/stderr?start=-4096"},"Attributes":{"NM_HTTP_ADDRESS":"host1:8042","USER":"tgraves","LOG_FILES":"stderr,stdout","NM_HTTP_PORT":"8042","CLUSTER_ID":"","NM_PORT":"37783","HTTP_SCHEME":"http://","NM_HOST":"host1","CONTAINER_ID":"container_1578436911597_0052_01_000002"},"Resources":{"gpu":{"name":"gpu","addresses":["0","1","2"]}},"Resource Profile Id":3}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"host1","Port":40787},"Maximum Memory":384093388,"Timestamp":1578764769796,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1578764769858,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1578764769877,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1578764770507,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1578764770509,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1578764769858,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1578764770512,"Failed":false,"Killed":false,"Accumulables":[{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":49,"Value":49,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":3706,"Value":3706,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20740892,"Value":20740892,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":32,"Value":32,"Internal":true,"Count Failed 
Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":250921658,"Value":250921658,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":555,"Value":555,"Internal":true,"Count Failed Values":true}]},"Task Executor Metrics":{"JVMHeapMemory":0,"JVMOffHeapMemory":0,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":0,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":0,"OffHeapUnifiedMemory":0,"DirectPoolMemory":0,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":0,"ProcessTreeJVMRSSMemory":0,"ProcessTreePythonVMemory":0,"ProcessTreePythonRSSMemory":0,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0,"MinorGCCount":0,"MinorGCTime":0,"MajorGCCount":0,"MajorGCTime":0},"Task Metrics":{"Executor Deserialize Time":555,"Executor Deserialize CPU Time":250921658,"Executor Run Time":32,"Executor CPU Time":20740892,"Peak Execution Memory":0,"Result Size":3706,"JVM GC Time":49,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1578764769877,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1578764770515,"Failed":false,"Killed":false,"Accumulables":[{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":49,"Value":98,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":3722,"Value":7428,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":25185125,"Value":45926017,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":32,"Value":64,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":416274503,"Value":667196161,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":555,"Value":1110,"Internal":true,"Count Failed Values":true}]},"Task Executor Metrics":{"JVMHeapMemory":0,"JVMOffHeapMemory":0,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":0,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":0,"OffHeapUnifiedMemory":0,"DirectPoolMemory":0,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":0,"ProcessTreeJVMRSSMemory":0,"ProcessTreePythonVMemory":0,"ProcessTreePythonRSSMemory":0,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0,"MinorGCCount":0,"MinorGCTime":0,"MajorGCCount":0,"MajorGCTime":0},"Task Metrics":{"Executor Deserialize Time":555,"Executor Deserialize CPU Time":416274503,"Executor Run Time":32,"Executor CPU Time":25185125,"Peak Execution Memory":0,"Result Size":3722,"JVM GC Time":49,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote 
Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1578764770525,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1578764770507,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1578764770526,"Failed":false,"Killed":false,"Accumulables":[{"ID":4,"Name":"internal.metrics.resultSize","Update":3636,"Value":11064,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":2203515,"Value":48129532,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":66,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2733237,"Value":669929398,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":1112,"Internal":true,"Count Failed Values":true}]},"Task Executor Metrics":{"JVMHeapMemory":0,"JVMOffHeapMemory":0,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":0,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":0,"OffHeapUnifiedMemory":0,"DirectPoolMemory":0,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":0,"ProcessTreeJVMRSSMemory":0,"ProcessTreePythonVMemory":0,"ProcessTreePythonRSSMemory":0,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0,"MinorGCCount":0,"MinorGCTime":0,"MajorGCCount":0,"MajorGCTime":0},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2733237,"Executor Run Time":2,"Executor CPU Time":2203515,"Peak Execution Memory":0,"Result Size":3636,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1578764770527,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1578764770509,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish 
Time":1578764770529,"Failed":false,"Killed":false,"Accumulables":[{"ID":4,"Name":"internal.metrics.resultSize","Update":3620,"Value":14684,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":2365599,"Value":50495131,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":68,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3387884,"Value":673317282,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":1115,"Internal":true,"Count Failed Values":true}]},"Task Executor Metrics":{"JVMHeapMemory":0,"JVMOffHeapMemory":0,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":0,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":0,"OffHeapUnifiedMemory":0,"DirectPoolMemory":0,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":0,"ProcessTreeJVMRSSMemory":0,"ProcessTreePythonVMemory":0,"ProcessTreePythonRSSMemory":0,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0,"MinorGCCount":0,"MinorGCTime":0,"MajorGCCount":0,"MajorGCTime":0},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":3387884,"Executor Run Time":2,"Executor CPU Time":2365599,"Peak Execution Memory":0,"Result Size":3620,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1578764770525,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1578764770542,"Failed":false,"Killed":false,"Accumulables":[{"ID":4,"Name":"internal.metrics.resultSize","Update":3636,"Value":18320,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":2456346,"Value":52951477,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":70,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3502860,"Value":676820142,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":1118,"Internal":true,"Count Failed Values":true}]},"Task Executor Metrics":{"JVMHeapMemory":0,"JVMOffHeapMemory":0,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":0,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":0,"OffHeapUnifiedMemory":0,"DirectPoolMemory":0,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":0,"ProcessTreeJVMRSSMemory":0,"ProcessTreePythonVMemory":0,"ProcessTreePythonRSSMemory":0,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0,"MinorGCCount":0,"MinorGCTime":0,"MajorGCCount":0,"MajorGCTime":0},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":3502860,"Executor Run Time":2,"Executor CPU Time":2456346,"Peak Execution 
Memory":0,"Result Size":3636,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1578764770527,"Executor ID":"1","Host":"host1","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1578764770542,"Failed":false,"Killed":false,"Accumulables":[{"ID":4,"Name":"internal.metrics.resultSize","Update":3636,"Value":21956,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":2162370,"Value":55113847,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":72,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3622437,"Value":680442579,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":1121,"Internal":true,"Count Failed Values":true}]},"Task Executor Metrics":{"JVMHeapMemory":0,"JVMOffHeapMemory":0,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":0,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":0,"OffHeapUnifiedMemory":0,"DirectPoolMemory":0,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":0,"ProcessTreeJVMRSSMemory":0,"ProcessTreePythonVMemory":0,"ProcessTreePythonRSSMemory":0,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0,"MinorGCCount":0,"MinorGCTime":0,"MajorGCCount":0,"MajorGCTime":0},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":3622437,"Executor Run Time":2,"Executor CPU Time":2162370,"Peak Execution Memory":0,"Result Size":3636,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":6,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :31","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":6,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :31","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":6,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:1004)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line37.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:39)\n$line37.$read$$iw$$iw$$iw$$iw$$iw.(:41)\n$line37.$read$$iw$$iw$$iw$$iw.(:43)\n$line37.$read$$iw$$iw$$iw.(:45)\n$line37.$read$$iw$$iw.(:47)\n$line37.$read$$iw.(:49)\n$line37.$read.(:51)\n$line37.$read$.(:55)\n$line37.$read$.()\n$line37.$eval$.$print$lzycompute(:7)\n$line37.$eval$.$print(:6)\n$line37.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)","Submission Time":1578764765293,"Completion Time":1578764770543,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Value":72,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":98,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":21956,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":680442579,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":55113847,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":4,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":1121,"Internal":true,"Count Failed Values":true}],"Resource Profile Id":3}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1578764770546,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1578764802615} diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 807f0eb808f9b..8037f4a9447dd 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -1442,7 +1442,7 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { conf: SparkConf, clock: Clock = new SystemClock()): ExecutorAllocationManager = { ResourceProfile.reInitDefaultProfile(conf) - rpManager = new ResourceProfileManager(conf) + rpManager = new ResourceProfileManager(conf, listenerBus) val manager = new ExecutorAllocationManager(client, listenerBus, conf, clock = clock, resourceProfileManager = rpManager) managers += manager diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index cf4400e080e37..ec641f8294b29 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark // scalastyle:off import java.io.File +import java.util.{Locale, TimeZone} import org.apache.log4j.spi.LoggingEvent @@ -63,6 +64,11 @@ abstract class SparkFunSuite with Logging { // scalastyle:on + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + protected val enableAutoThreadAudit = true 
protected override def beforeAll(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala index 55e34b32fe0d4..e97b9d5d6bea6 100644 --- a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala @@ -46,7 +46,8 @@ abstract class BenchmarkBase { if (regenerateBenchmarkFiles) { val version = System.getProperty("java.version").split("\\D+")(0).toInt val jdkString = if (version > 8) s"-jdk$version" else "" - val resultFileName = s"${this.getClass.getSimpleName.replace("$", "")}$jdkString-results.txt" + val resultFileName = + s"${this.getClass.getSimpleName.replace("$", "")}$jdkString$suffix-results.txt" val file = new File(s"benchmarks/$resultFileName") if (!file.exists()) { file.createNewFile() @@ -65,6 +66,8 @@ abstract class BenchmarkBase { afterAll() } + def suffix: String = "" + /** * Any shutdown code to ensure a clean shutdown */ diff --git a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala new file mode 100644 index 0000000000000..d681c13337e0d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.internal.config.{SHUFFLE_SERVICE_DB_ENABLED, SHUFFLE_SERVICE_ENABLED} +import org.apache.spark.util.Utils + +class ExternalShuffleServiceMetricsSuite extends SparkFunSuite { + + var sparkConf: SparkConf = _ + var externalShuffleService: ExternalShuffleService = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sparkConf = new SparkConf() + sparkConf.set(SHUFFLE_SERVICE_ENABLED, true) + sparkConf.set(SHUFFLE_SERVICE_DB_ENABLED, false) + sparkConf.set("spark.local.dir", System.getProperty("java.io.tmpdir")) + Utils.loadDefaultSparkProperties(sparkConf, null) + val securityManager = new SecurityManager(sparkConf) + externalShuffleService = new ExternalShuffleService(sparkConf, securityManager) + externalShuffleService.start() + } + + override def afterAll(): Unit = { + if (externalShuffleService != null) { + externalShuffleService.stop() + } + super.afterAll() + } + + test("SPARK-31646: metrics should be registered") { + val sourceRef = classOf[ExternalShuffleService].getDeclaredField("shuffleServiceSource") + sourceRef.setAccessible(true) + val source = sourceRef.get(externalShuffleService).asInstanceOf[ExternalShuffleServiceSource] + assert(source.metricRegistry.getMetrics.keySet().asScala == + Set( + "blockTransferRateBytes", + "numActiveConnections", + "numCaughtExceptions", + "numRegisteredConnections", + "openBlockRequestLatencyMillis", + "registeredExecutorsSize", + "registerExecutorRequestLatencyMillis", + "shuffle-server.usedDirectMemory", + "shuffle-server.usedHeapMemory") + ) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ChromeUIHistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ChromeUIHistoryServerSuite.scala new file mode 100644 index 0000000000000..1fa2d0ab882c9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/history/ChromeUIHistoryServerSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import org.openqa.selenium.WebDriver +import org.openqa.selenium.chrome.{ChromeDriver, ChromeOptions} + +import org.apache.spark.tags.ChromeUITest + +/** + * Tests for HistoryServer with Chrome. 
+ */ +@ChromeUITest +class ChromeUIHistoryServerSuite + extends RealBrowserUIHistoryServerSuite("webdriver.chrome.driver") { + + override var webDriver: WebDriver = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val chromeOptions = new ChromeOptions + chromeOptions.addArguments("--headless", "--disable-gpu") + webDriver = new ChromeDriver(chromeOptions) + } + + override def afterAll(): Unit = { + try { + if (webDriver != null) { + webDriver.quit() + } + } finally { + super.afterAll() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 206db0feb5716..8737cd5bb3241 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -29,8 +29,6 @@ import scala.concurrent.duration._ import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.eclipse.jetty.proxy.ProxyServlet -import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ @@ -171,6 +169,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "executor node blacklisting unblacklisting" -> "applications/app-20161115172038-0000/executors", "executor memory usage" -> "applications/app-20161116163331-0000/executors", "executor resource information" -> "applications/application_1555004656427_0144/executors", + "multiple resource profiles" -> "applications/application_1578436911597_0052/environment", "app environment" -> "applications/app-20161116163331-0000/environment", @@ -314,7 +313,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (directSiteRelativeLinks) should not startWith (knoxBaseUrl) } - test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { + // TODO (SPARK-31723): re-enable it + ignore("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) val request = mock[HttpServletRequest] @@ -334,66 +334,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(response.contains(SPARK_VERSION)) } - test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { - val uiRoot = "/testwebproxybase" - System.setProperty("spark.ui.proxyBase", uiRoot) - - stop() - init() - - val port = server.boundPort - - val servlet = new ProxyServlet { - override def rewriteTarget(request: HttpServletRequest): String = { - // servlet acts like a proxy that redirects calls made on - // spark.ui.proxyBase context path to the normal servlet handlers operating off "/" - val sb = request.getRequestURL() - - if (request.getQueryString() != null) { - sb.append(s"?${request.getQueryString()}") - } - - val proxyidx = sb.indexOf(uiRoot) - sb.delete(proxyidx, proxyidx + uiRoot.length).toString - } - } - - val contextHandler = new ServletContextHandler - val holder = new ServletHolder(servlet) - contextHandler.setContextPath(uiRoot) - contextHandler.addServlet(holder, "/") - server.attachHandler(contextHandler) - - implicit val webDriver: WebDriver = new HtmlUnitDriver(true) - - try { - val url = 
s"http://localhost:$port" - - go to s"$url$uiRoot" - - // expect the ajax call to finish in 5 seconds - implicitlyWait(org.scalatest.time.Span(5, org.scalatest.time.Seconds)) - - // once this findAll call returns, we know the ajax load of the table completed - findAll(ClassNameQuery("odd")) - - val links = findAll(TagNameQuery("a")) - .map(_.attribute("href")) - .filter(_.isDefined) - .map(_.get) - .filter(_.startsWith(url)).toList - - // there are at least some URL links that were generated via javascript, - // and they all contain the spark.ui.proxyBase (uiRoot) - links.length should be > 4 - all(links) should startWith(url + uiRoot) - } finally { - contextHandler.stop() - quit() - } - - } - /** * Verify that the security manager needed for the history server can be instantiated * when `spark.authenticate` is `true`, rather than raise an `IllegalArgumentException`. @@ -693,6 +633,17 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers out.close() } + test("SPARK-31697: HistoryServer should set Content-Type") { + val port = server.boundPort + val nonExistenceAppId = "local-non-existence" + val url = new URL(s"http://localhost:$port/history/$nonExistenceAppId") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("GET") + conn.connect() + val expectedContentType = "text/html;charset=utf-8" + val actualContentType = conn.getContentType + assert(actualContentType === expectedContentType) + } } object HistoryServerSuite { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/RealBrowserUIHistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/RealBrowserUIHistoryServerSuite.scala new file mode 100644 index 0000000000000..8a1e22c694497 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/history/RealBrowserUIHistoryServerSuite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.proxy.ProxyServlet +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} +import org.openqa.selenium.WebDriver +import org.scalatest._ +import org.scalatestplus.selenium.WebBrowser + +import org.apache.spark._ +import org.apache.spark.internal.config.{EVENT_LOG_STAGE_EXECUTOR_METRICS, EXECUTOR_PROCESS_TREE_METRICS_ENABLED} +import org.apache.spark.internal.config.History.{HISTORY_LOG_DIR, LOCAL_STORE_DIR, UPDATE_INTERVAL_S} +import org.apache.spark.internal.config.Tests.IS_TESTING +import org.apache.spark.util.{ResetSystemProperties, Utils} + +/** + * Tests for HistoryServer with real web browsers. 
+ */ +abstract class RealBrowserUIHistoryServerSuite(val driverProp: String) + extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll + with BeforeAndAfterEach with ResetSystemProperties { + + implicit var webDriver: WebDriver + + private val driverPropPrefix = "spark.test." + private val logDir = getTestResourcePath("spark-events") + private val storeDir = Utils.createTempDir(namePrefix = "history") + + private var provider: FsHistoryProvider = null + private var server: HistoryServer = null + private var port: Int = -1 + + override def beforeAll(): Unit = { + super.beforeAll() + assume( + sys.props(driverPropPrefix + driverProp) !== null, + "System property " + driverPropPrefix + driverProp + + " should be set to the corresponding driver path.") + sys.props(driverProp) = sys.props(driverPropPrefix + driverProp) + } + + override def beforeEach(): Unit = { + super.beforeEach() + if (server == null) { + init() + } + } + + override def afterAll(): Unit = { + sys.props.remove(driverProp) + super.afterAll() + } + + def init(extraConf: (String, String)*): Unit = { + Utils.deleteRecursively(storeDir) + assert(storeDir.mkdir()) + val conf = new SparkConf() + .set(HISTORY_LOG_DIR, logDir) + .set(UPDATE_INTERVAL_S.key, "0") + .set(IS_TESTING, true) + .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) + .set(EVENT_LOG_STAGE_EXECUTOR_METRICS, true) + .set(EXECUTOR_PROCESS_TREE_METRICS_ENABLED, true) + conf.setAll(extraConf) + provider = new FsHistoryProvider(conf) + provider.checkForLogs() + val securityManager = HistoryServer.createSecurityManager(conf) + + server = new HistoryServer(conf, provider, securityManager, 18080) + server.initialize() + server.bind() + provider.start() + port = server.boundPort + } + + def stop(): Unit = { + server.stop() + server = null + } + + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { + val uiRoot = "/testwebproxybase" + System.setProperty("spark.ui.proxyBase", uiRoot) + + stop() + init() + + val port = server.boundPort + + val servlet = new ProxyServlet { + override def rewriteTarget(request: HttpServletRequest): String = { + // servlet acts like a proxy that redirects calls made on + // spark.ui.proxyBase context path to the normal servlet handlers operating off "/" + val sb = request.getRequestURL() + + if (request.getQueryString() != null) { + sb.append(s"?${request.getQueryString()}") + } + + val proxyidx = sb.indexOf(uiRoot) + sb.delete(proxyidx, proxyidx + uiRoot.length).toString + } + } + + val contextHandler = new ServletContextHandler + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(uiRoot) + contextHandler.addServlet(holder, "/") + server.attachHandler(contextHandler) + + try { + val url = s"http://localhost:$port" + + go to s"$url$uiRoot" + + // expect the ajax call to finish in 5 seconds + implicitlyWait(org.scalatest.time.Span(5, org.scalatest.time.Seconds)) + + // once this findAll call returns, we know the ajax load of the table completed + findAll(ClassNameQuery("odd")) + + val links = findAll(TagNameQuery("a")) + .map(_.attribute("href")) + .filter(_.isDefined) + .map(_.get) + .filter(_.startsWith(url)).toList + + // there are at least some URL links that were generated via javascript, + // and they all contain the spark.ui.proxyBase (uiRoot) + links.length should be > 4 + all(links) should startWith(url + uiRoot) + } finally { + contextHandler.stop() + quit() + } + } +} diff --git 
a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala index 275bca3459855..d9d559509f4fb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -19,10 +19,14 @@ package org.apache.spark.deploy.security import java.security.PrivilegedExceptionAction +import scala.util.control.NonFatal + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION import org.apache.hadoop.minikdc.MiniKdc import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil @@ -88,8 +92,30 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite { // krb5.conf. MiniKdc sets "java.security.krb5.conf" in start and removes it when stop called. val kdcDir = Utils.createTempDir() val kdcConf = MiniKdc.createConf() - kdc = new MiniKdc(kdcConf, kdcDir) - kdc.start() + // The port for the MiniKdc service is selected in the constructor, but it is only bound + // later, in MiniKdc.start() -> MiniKdc.initKDCServer() -> KdcServer.start(). + // In the meantime some other service might grab the port and + // cause a BindException. + // This makes our tests, which run in dedicated JVMs and rely on MiniKdc, flaky. + // + // https://issues.apache.org/jira/browse/HADOOP-12656 is fixed in Hadoop 2.8.0. + // + // The workaround here is to periodically retry this process with a timeout, since we + // use Hadoop 2.7.4 by default.
+ // https://issues.apache.org/jira/browse/SPARK-31631 + eventually(timeout(60.seconds), interval(1.second)) { + try { + kdc = new MiniKdc(kdcConf, kdcDir) + kdc.start() + } catch { + case NonFatal(e) => + if (kdc != null) { + kdc.stop() + kdc = null + } + throw e + } + } val krbConf = new Configuration() krbConf.set(HADOOP_SECURITY_AUTHENTICATION, "kerberos") diff --git a/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala b/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala index 004618a161b44..f4521738c4870 100644 --- a/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.resource import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Tests._ +import org.apache.spark.scheduler.LiveListenerBus class ResourceProfileManagerSuite extends SparkFunSuite { @@ -39,9 +40,11 @@ class ResourceProfileManagerSuite extends SparkFunSuite { } } + val listenerBus = new LiveListenerBus(new SparkConf()) + test("ResourceProfileManager") { val conf = new SparkConf().set(EXECUTOR_CORES, 4) - val rpmanager = new ResourceProfileManager(conf) + val rpmanager = new ResourceProfileManager(conf, listenerBus) val defaultProf = rpmanager.defaultResourceProfile assert(defaultProf.id === ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) assert(defaultProf.executorResources.size === 2, @@ -53,7 +56,7 @@ class ResourceProfileManagerSuite extends SparkFunSuite { test("isSupported yarn no dynamic allocation") { val conf = new SparkConf().setMaster("yarn").set(EXECUTOR_CORES, 4) conf.set(RESOURCE_PROFILE_MANAGER_TESTING.key, "true") - val rpmanager = new ResourceProfileManager(conf) + val rpmanager = new ResourceProfileManager(conf, listenerBus) // default profile should always work val defaultProf = rpmanager.defaultResourceProfile val rprof = new ResourceProfileBuilder() @@ -71,7 +74,7 @@ class ResourceProfileManagerSuite extends SparkFunSuite { val conf = new SparkConf().setMaster("yarn").set(EXECUTOR_CORES, 4) conf.set(DYN_ALLOCATION_ENABLED, true) conf.set(RESOURCE_PROFILE_MANAGER_TESTING.key, "true") - val rpmanager = new ResourceProfileManager(conf) + val rpmanager = new ResourceProfileManager(conf, listenerBus) // default profile should always work val defaultProf = rpmanager.defaultResourceProfile val rprof = new ResourceProfileBuilder() @@ -84,7 +87,7 @@ class ResourceProfileManagerSuite extends SparkFunSuite { test("isSupported yarn with local mode") { val conf = new SparkConf().setMaster("local").set(EXECUTOR_CORES, 4) conf.set(RESOURCE_PROFILE_MANAGER_TESTING.key, "true") - val rpmanager = new ResourceProfileManager(conf) + val rpmanager = new ResourceProfileManager(conf, listenerBus) // default profile should always work val defaultProf = rpmanager.defaultResourceProfile val rprof = new ResourceProfileBuilder() @@ -100,7 +103,7 @@ class ResourceProfileManagerSuite extends SparkFunSuite { test("ResourceProfileManager has equivalent profile") { val conf = new SparkConf().set(EXECUTOR_CORES, 4) - val rpmanager = new ResourceProfileManager(conf) + val rpmanager = new ResourceProfileManager(conf, listenerBus) var rpAlreadyExist: Option[ResourceProfile] = None val checkId = 500 for (i <- 1 to 1000) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index c4e5e7c700652..01c82f894cf98 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with Eventually { @@ -37,6 +38,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with .setAppName("test-cluster") .set(TEST_NO_STAGE_RETRY, true) sc = new SparkContext(conf) + TestUtils.waitUntilExecutorsUp(sc, numWorker, 60000) } test("global sync by barrier() call") { @@ -56,10 +58,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with } test("share messages with allGather() call") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) + initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() @@ -68,19 +67,16 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with // Pass partitionId message in val message: String = context.partitionId().toString val messages: Array[String] = context.allGather(message) - messages.toList.iterator + Iterator.single(messages.toList) } - // Take a sorted list of all the partitionId messages - val messages = rdd2.collect().head - // All the task partitionIds are shared - for((x, i) <- messages.view.zipWithIndex) assert(x.toString == i.toString) + val messages = rdd2.collect() + // All the task partitionIds are shared across all tasks + assert(messages.length === 4) + assert(messages.forall(_ == List("0", "1", "2", "3"))) } test("throw exception if we attempt to synchronize with different blocking calls") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) + initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() @@ -95,17 +91,11 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with val error = intercept[SparkException] { rdd2.collect() }.getMessage - assert( - error.contains("does not match the current synchronized requestMethod") || - error.contains("not properly killed") - ) + assert(error.contains("Different barrier sync types found")) } test("successively sync with allGather and barrier") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) + initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 61ea21fa86c5a..7c23e4449f461 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -36,6 +36,7 @@ import 
org.apache.spark.deploy.history.{EventLogFileReader, SingleEventLogFileWr import org.apache.spark.deploy.history.EventLogTestHelper._ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{EVENT_LOG_DIR, EVENT_LOG_ENABLED} import org.apache.spark.io._ import org.apache.spark.metrics.{ExecutorMetricType, MetricsSystem} import org.apache.spark.resource.ResourceProfile @@ -100,6 +101,49 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit testStageExecutorMetricsEventLogging() } + test("SPARK-31764: isBarrier should be logged in event log") { + val conf = new SparkConf() + conf.set(EVENT_LOG_ENABLED, true) + conf.set(EVENT_LOG_DIR, testDirPath.toString) + val sc = new SparkContext("local", "test-SPARK-31764", conf) + val appId = sc.applicationId + + sc.parallelize(1 to 10) + .barrier() + .mapPartitions(_.map(elem => (elem, elem))) + .filter(elem => elem._1 % 2 == 0) + .reduceByKey(_ + _) + .collect + sc.stop() + + val eventLogStream = EventLogFileReader.openEventLog(new Path(testDirPath, appId), fileSystem) + val events = readLines(eventLogStream).map(line => JsonProtocol.sparkEventFromJson(parse(line))) + val jobStartEvents = events + .filter(event => event.isInstanceOf[SparkListenerJobStart]) + .map(_.asInstanceOf[SparkListenerJobStart]) + + assert(jobStartEvents.size === 1) + val stageInfos = jobStartEvents.head.stageInfos + assert(stageInfos.size === 2) + + val stage0 = stageInfos(0) + val rddInfosInStage0 = stage0.rddInfos + assert(rddInfosInStage0.size === 3) + val sortedRddInfosInStage0 = rddInfosInStage0.sortBy(_.scope.get.name) + assert(sortedRddInfosInStage0(0).scope.get.name === "filter") + assert(sortedRddInfosInStage0(0).isBarrier === true) + assert(sortedRddInfosInStage0(1).scope.get.name === "mapPartitions") + assert(sortedRddInfosInStage0(1).isBarrier === true) + assert(sortedRddInfosInStage0(2).scope.get.name === "parallelize") + assert(sortedRddInfosInStage0(2).isBarrier === false) + + val stage1 = stageInfos(1) + val rddInfosInStage1 = stage1.rddInfos + assert(rddInfosInStage1.size === 1) + assert(rddInfosInStage1(0).scope.get.name === "reduceByKey") + assert(rddInfosInStage1(0).isBarrier === false) // reduceByKey + } + /* ----------------- * * Actual test logic * * ----------------- */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index a8541cb863478..a75bae56229b4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -1208,7 +1208,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B test("SPARK-16106 locality levels updated if executor added to existing host") { val taskScheduler = setupScheduler() - taskScheduler.resourceOffers(IndexedSeq(new WorkerOffer("executor0", "host0", 1))) taskScheduler.submitTasks(FakeTask.createTaskSet(2, stageId = 0, stageAttemptId = 0, (0 until 2).map { _ => Seq(TaskLocation("host0", "executor2")) }: _* )) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 4978be3e04c1e..e4aad58d25064 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -620,7 +620,7 @@ class TaskSetManagerSuite manager.executorAdded() sched.addExecutor("execC", "host2") manager.executorAdded() - assert(manager.resourceOffer("exec1", "host1", ANY)._1.isDefined) + assert(manager.resourceOffer("execB", "host1", ANY)._1.isDefined) sched.removeExecutor("execA") manager.executorLost( "execA", @@ -634,6 +634,25 @@ class TaskSetManagerSuite assert(sched.taskSetsFailed.contains(taskSet.id)) } + test("SPARK-31837: Shift to the new highest locality level if there is when recomputeLocality") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(2, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host1", "execA"))) + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, 1, clock = clock) + // before any executors are added to TaskScheduler, the manager's + // locality level only has ANY, so tasks can be scheduled anyway. + assert(manager.resourceOffer("execB", "host2", ANY)._1.isDefined) + sched.addExecutor("execA", "host1") + manager.executorAdded() + // after adding a new executor, the manager locality has PROCESS_LOCAL, NODE_LOCAL, ANY. + // And we'll shift to the new highest locality level, which is PROCESS_LOCAL in this case. + assert(manager.resourceOffer("execC", "host3", ANY)._1.isEmpty) + assert(manager.resourceOffer("execA", "host1", ANY)._1.isDefined) + } + test("test RACK_LOCAL tasks") { // Assign host1 to rack1 FakeRackUtil.assignHostToRack("host1", "rack1") diff --git a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala index 8c6f86a6c0e88..148d20ee659a2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala @@ -22,7 +22,8 @@ import java.util.concurrent.Semaphore import scala.concurrent.TimeoutException import scala.concurrent.duration._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite, + TestUtils} import org.apache.spark.internal.config import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils} @@ -48,12 +49,6 @@ class WorkerDecommissionSuite extends SparkFunSuite with LocalSparkContext { test("verify a task with all workers decommissioned succeeds") { val input = sc.parallelize(1 to 10) - // Do a count to wait for the executors to be registered. - input.count() - val sleepyRdd = input.mapPartitions{ x => - Thread.sleep(50) - x - } // Listen for the job val sem = new Semaphore(0) sc.addSparkListener(new SparkListener { @@ -61,22 +56,31 @@ class WorkerDecommissionSuite extends SparkFunSuite with LocalSparkContext { sem.release() } }) + TestUtils.waitUntilExecutorsUp(sc = sc, + numExecutors = 2, + timeout = 10000) // 10s + val sleepyRdd = input.mapPartitions{ x => + Thread.sleep(5000) // 5s + x + } // Start the task. val asyncCount = sleepyRdd.countAsync() // Wait for the job to have started sem.acquire(1) + // Give it time to make it to the worker otherwise we'll block + Thread.sleep(2000) // 2s // Decommission all the executors, this should not halt the current task. // decom.sh message passing is tested manually. 
val sched = sc.schedulerBackend.asInstanceOf[StandaloneSchedulerBackend] val execs = sched.getExecutorIds() execs.foreach(execId => sched.decommissionExecutor(execId)) - val asyncCountResult = ThreadUtils.awaitResult(asyncCount, 10.seconds) + val asyncCountResult = ThreadUtils.awaitResult(asyncCount, 20.seconds) assert(asyncCountResult === 10) // Try and launch task after decommissioning, this should fail val postDecommissioned = input.map(x => x) val postDecomAsyncCount = postDecommissioned.countAsync() val thrown = intercept[java.util.concurrent.TimeoutException]{ - val result = ThreadUtils.awaitResult(postDecomAsyncCount, 10.seconds) + val result = ThreadUtils.awaitResult(postDecomAsyncCount, 20.seconds) } assert(postDecomAsyncCount.isCompleted === false, "After exec decommission new task could not launch") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionSuite.scala new file mode 100644 index 0000000000000..7456ca7f02a2e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.concurrent.Semaphore + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, Success} +import org.apache.spark.internal.config +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} +import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend +import org.apache.spark.util.{ResetSystemProperties, ThreadUtils} + +class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext + with ResetSystemProperties { + + override def beforeEach(): Unit = { + val conf = new SparkConf().setAppName("test") + .set(config.Worker.WORKER_DECOMMISSION_ENABLED, true) + .set(config.STORAGE_DECOMMISSION_ENABLED, true) + + sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) + } + + test(s"verify that an already running task which is going to cache data succeeds " + + s"on a decommissioned executor") { + // Create input RDD with 10 partitions + val input = sc.parallelize(1 to 10, 10) + val accum = sc.longAccumulator("mapperRunAccumulator") + // Do a count to wait for the executors to be registered. 
+ input.count() + + // Create a new RDD where we have sleep in each partition, we are also increasing + // the value of accumulator in each partition + val sleepyRdd = input.mapPartitions { x => + Thread.sleep(500) + accum.add(1) + x + } + + // Listen for the job + val sem = new Semaphore(0) + val taskEndEvents = ArrayBuffer.empty[SparkListenerTaskEnd] + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sem.release() + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + taskEndEvents.append(taskEnd) + } + }) + + // Cache the RDD lazily + sleepyRdd.persist() + + // Start the computation of RDD - this step will also cache the RDD + val asyncCount = sleepyRdd.countAsync() + + // Wait for the job to have started + sem.acquire(1) + + // Give Spark a tiny bit to start the tasks after the listener says hello + Thread.sleep(100) + // Decommission one of the executor + val sched = sc.schedulerBackend.asInstanceOf[StandaloneSchedulerBackend] + val execs = sched.getExecutorIds() + assert(execs.size == 2, s"Expected 2 executors but found ${execs.size}") + val execToDecommission = execs.head + sched.decommissionExecutor(execToDecommission) + + // Wait for job to finish + val asyncCountResult = ThreadUtils.awaitResult(asyncCount, 6.seconds) + assert(asyncCountResult === 10) + // All 10 tasks finished, so accum should have been increased 10 times + assert(accum.value === 10) + + // All tasks should be successful, nothing should have failed + sc.listenerBus.waitUntilEmpty() + assert(taskEndEvents.size === 10) // 10 mappers + assert(taskEndEvents.map(_.reason).toSet === Set(Success)) + + // Since the RDD is cached, so further usage of same RDD should use the + // cached data. Original RDD partitions should not be recomputed i.e. 
accum + // should have same value like before + assert(sleepyRdd.count() === 10) + assert(accum.value === 10) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 8d06768a2b284..bfef8f1ab29d8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1706,6 +1706,64 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE verify(liveListenerBus, never()).post(SparkListenerBlockUpdated(BlockUpdatedInfo(updateInfo))) } + test("test decommission block manager should not be part of peers") { + val exec1 = "exec1" + val exec2 = "exec2" + val exec3 = "exec3" + val store1 = makeBlockManager(800, exec1) + val store2 = makeBlockManager(800, exec2) + val store3 = makeBlockManager(800, exec3) + + assert(master.getPeers(store3.blockManagerId).map(_.executorId).toSet === Set(exec1, exec2)) + + val data = new Array[Byte](4) + val blockId = rdd(0, 0) + store1.putSingle(blockId, data, StorageLevel.MEMORY_ONLY_2) + assert(master.getLocations(blockId).size === 2) + + master.decommissionBlockManagers(Seq(exec1)) + // store1 is decommissioned, so it should not be part of peer list for store3 + assert(master.getPeers(store3.blockManagerId).map(_.executorId).toSet === Set(exec2)) + } + + test("test decommissionRddCacheBlocks should offload all cached blocks") { + val store1 = makeBlockManager(800, "exec1") + val store2 = makeBlockManager(800, "exec2") + val store3 = makeBlockManager(800, "exec3") + + val data = new Array[Byte](4) + val blockId = rdd(0, 0) + store1.putSingle(blockId, data, StorageLevel.MEMORY_ONLY_2) + assert(master.getLocations(blockId).size === 2) + assert(master.getLocations(blockId).contains(store1.blockManagerId)) + + store1.decommissionRddCacheBlocks() + assert(master.getLocations(blockId).size === 2) + assert(master.getLocations(blockId).toSet === Set(store2.blockManagerId, + store3.blockManagerId)) + } + + test("test decommissionRddCacheBlocks should keep the block if it is not able to offload") { + val store1 = makeBlockManager(3500, "exec1") + val store2 = makeBlockManager(1000, "exec2") + + val dataLarge = new Array[Byte](1500) + val blockIdLarge = rdd(0, 0) + val dataSmall = new Array[Byte](1) + val blockIdSmall = rdd(0, 1) + + store1.putSingle(blockIdLarge, dataLarge, StorageLevel.MEMORY_ONLY) + store1.putSingle(blockIdSmall, dataSmall, StorageLevel.MEMORY_ONLY) + assert(master.getLocations(blockIdLarge) === Seq(store1.blockManagerId)) + assert(master.getLocations(blockIdSmall) === Seq(store1.blockManagerId)) + + store1.decommissionRddCacheBlocks() + // Smaller block offloaded to store2 + assert(master.getLocations(blockIdSmall) === Seq(store2.blockManagerId)) + // Larger block still present in store1 as it can't be offloaded + assert(master.getLocations(blockIdLarge) === Seq(store1.blockManagerId)) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 var tempFileManager: DownloadFileManager = null diff --git a/core/src/test/scala/org/apache/spark/ui/ChromeUISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/ChromeUISeleniumSuite.scala new file mode 100644 index 0000000000000..9ba705c4abd75 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/ChromeUISeleniumSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import org.openqa.selenium.WebDriver +import org.openqa.selenium.chrome.{ChromeDriver, ChromeOptions} + +import org.apache.spark.tags.ChromeUITest + +/** + * Selenium tests for the Spark Web UI with Chrome. + */ +@ChromeUITest +class ChromeUISeleniumSuite extends RealBrowserUISeleniumSuite("webdriver.chrome.driver") { + + override var webDriver: WebDriver = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val chromeOptions = new ChromeOptions + chromeOptions.addArguments("--headless", "--disable-gpu") + webDriver = new ChromeDriver(chromeOptions) + } + + override def afterAll(): Unit = { + try { + if (webDriver != null) { + webDriver.quit() + } + } finally { + super.afterAll() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/ui/RealBrowserUISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/RealBrowserUISeleniumSuite.scala new file mode 100644 index 0000000000000..4b018f69b1660 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/RealBrowserUISeleniumSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import org.openqa.selenium.{By, WebDriver} +import org.scalatest._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ +import org.scalatestplus.selenium.WebBrowser + +import org.apache.spark._ +import org.apache.spark.LocalSparkContext.withSpark +import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE +import org.apache.spark.internal.config.UI.{UI_ENABLED, UI_KILL_ENABLED, UI_PORT} +import org.apache.spark.util.CallSite + +/** + * Selenium tests for the Spark Web UI with real web browsers. + */ +abstract class RealBrowserUISeleniumSuite(val driverProp: String) + extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll { + + implicit var webDriver: WebDriver + private val driverPropPrefix = "spark.test." 
+ + override def beforeAll(): Unit = { + super.beforeAll() + assume( + sys.props(driverPropPrefix + driverProp) !== null, + "System property " + driverPropPrefix + driverProp + + " should be set to the corresponding driver path.") + sys.props(driverProp) = sys.props(driverPropPrefix + driverProp) + } + + override def afterAll(): Unit = { + sys.props.remove(driverProp) + super.afterAll() + } + + test("SPARK-31534: text for tooltip should be escaped") { + withSpark(newSparkContext()) { sc => + sc.setLocalProperty(CallSite.LONG_FORM, "collect at :25") + sc.setLocalProperty(CallSite.SHORT_FORM, "collect at :25") + sc.parallelize(1 to 10).collect + + eventually(timeout(10.seconds), interval(50.milliseconds)) { + goToUi(sc, "/jobs") + + val jobDesc = + webDriver.findElement(By.cssSelector("div[class='application-timeline-content']")) + jobDesc.getAttribute("data-title") should include ("collect at <console>:25") + + goToUi(sc, "/jobs/job/?id=0") + webDriver.get(sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/?id=0") + val stageDesc = webDriver.findElement(By.cssSelector("div[class='job-timeline-content']")) + stageDesc.getAttribute("data-title") should include ("collect at <console>:25") + + // Open DAG Viz. + webDriver.findElement(By.id("job-dag-viz")).click() + val nodeDesc = webDriver.findElement(By.cssSelector("g[class='node_0 node']")) + nodeDesc.getAttribute("name") should include ("collect at <console>:25") + } + } + } + + test("SPARK-31882: Link URL for Stage DAGs should not depend on paged table.") { + withSpark(newSparkContext()) { sc => + sc.parallelize(1 to 100).map(v => (v, v)).repartition(10).reduceByKey(_ + _).collect + + eventually(timeout(10.seconds), interval(50.microseconds)) { + val pathWithPagedTable = + "/jobs/job/?id=0&completedStage.page=2&completedStage.sort=Stage+Id&" + + "completedStage.desc=true&completedStage.pageSize=1#completed" + goToUi(sc, pathWithPagedTable) + + // Open DAG Viz. + webDriver.findElement(By.id("job-dag-viz")).click() + val stages = webDriver.findElements(By.cssSelector("svg[class='job'] > a")) + stages.size() should be (3) + + stages.get(0).getAttribute("href") should include ("/stages/stage/?id=0&attempt=0") + stages.get(1).getAttribute("href") should include ("/stages/stage/?id=1&attempt=0") + stages.get(2).getAttribute("href") should include ("/stages/stage/?id=2&attempt=0") + } + } + } + + test("SPARK-31886: Color barrier execution mode RDD correctly") { + withSpark(newSparkContext()) { sc => + sc.parallelize(1 to 10).barrier.mapPartitions(identity).repartition(1).collect() + + eventually(timeout(10.seconds), interval(50.milliseconds)) { + goToUi(sc, "/jobs/job/?id=0") + webDriver.findElement(By.id("job-dag-viz")).click() + + val stage0 = webDriver.findElement(By.cssSelector("g[id='graph_0']")) + val stage1 = webDriver.findElement(By.cssSelector("g[id='graph_1']")) + val barrieredOps = webDriver.findElements(By.className("barrier-rdd")).iterator() + + while (barrieredOps.hasNext) { + val barrieredOpId = barrieredOps.next().getAttribute("innerHTML") + val foundInStage0 = + stage0.findElements( + By.cssSelector("g.barrier.cluster.cluster_" + barrieredOpId)) + assert(foundInStage0.size === 1) + + val foundInStage1 = + stage1.findElements( + By.cssSelector("g.barrier.cluster.cluster_" + barrieredOpId)) + assert(foundInStage1.size === 0) + } + } + } + } + + /** + * Create a test SparkContext with the SparkUI enabled. + * It is safe to `get` the SparkUI directly from the SparkContext returned here. 
+ */ + private def newSparkContext( + killEnabled: Boolean = true, + master: String = "local", + additionalConfs: Map[String, String] = Map.empty): SparkContext = { + val conf = new SparkConf() + .setMaster(master) + .setAppName("test") + .set(UI_ENABLED, true) + .set(UI_PORT, 0) + .set(UI_KILL_ENABLED, killEnabled) + .set(MEMORY_OFFHEAP_SIZE.key, "64m") + additionalConfs.foreach { case (k, v) => conf.set(k, v) } + val sc = new SparkContext(conf) + assert(sc.ui.isDefined) + sc + } + + def goToUi(sc: SparkContext, path: String): Unit = { + goToUi(sc.ui.get, path) + } + + def goToUi(ui: SparkUI, path: String): Unit = { + go to (ui.webUrl.stripSuffix("/") + path) + } +} diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 7711934cbe8a6..3d52199b01327 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -92,12 +92,12 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { accumulatorUpdates = Seq(new UIAccumulableInfo(0L, "acc", None, "value")), tasks = None, executorSummary = None, - killedTasksSummary = Map.empty + killedTasksSummary = Map.empty, + ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID ) val taskTable = new TaskPagedTable( stageData, basePath = "/a/b/c", - currentTime = 0, pageSize = 10, sortColumn = "Index", desc = false, diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 3ec9385116408..909056eab8c5a 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -773,33 +773,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } - test("SPARK-31534: text for tooltip should be escaped") { - withSpark(newSparkContext()) { sc => - sc.setLocalProperty(CallSite.LONG_FORM, "collect at :25") - sc.setLocalProperty(CallSite.SHORT_FORM, "collect at :25") - sc.parallelize(1 to 10).collect - - val driver = webDriver.asInstanceOf[HtmlUnitDriver] - driver.setJavascriptEnabled(true) - - eventually(timeout(10.seconds), interval(50.milliseconds)) { - goToUi(sc, "/jobs") - val jobDesc = - driver.findElement(By.cssSelector("div[class='application-timeline-content']")) - jobDesc.getAttribute("data-title") should include ("collect at <console>:25") - - goToUi(sc, "/jobs/job/?id=0") - val stageDesc = driver.findElement(By.cssSelector("div[class='job-timeline-content']")) - stageDesc.getAttribute("data-title") should include ("collect at <console>:25") - - // Open DAG Viz. 
- driver.findElement(By.id("job-dag-viz")).click() - val nodeDesc = driver.findElement(By.cssSelector("g[class='node_0 node']")) - nodeDesc.getAttribute("name") should include ("collect at <console>:25") - } - } - } - def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index eb7f3079bee36..248142a5ad633 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.RDDOperationScope -import org.apache.spark.resource.{ResourceInformation, ResourceProfile, ResourceUtils} +import org.apache.spark.resource._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.shuffle.MetadataFetchFailedException @@ -92,7 +92,7 @@ class JsonProtocolSuite extends SparkFunSuite { 42L, "Garfield", Some("appAttempt"), Some(logUrlMap)) val applicationEnd = SparkListenerApplicationEnd(42L) val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", - new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap, attributes, resources.toMap)) + new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap, attributes, resources.toMap, 4)) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") val executorBlacklisted = SparkListenerExecutorBlacklisted(executorBlacklistedTime, "exec1", 22) val executorUnblacklisted = @@ -119,6 +119,14 @@ class JsonProtocolSuite extends SparkFunSuite { SparkListenerStageExecutorMetrics("1", 2, 3, new ExecutorMetrics(Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L, 256912L, 123456L, 123456L, 61728L, 30364L, 15182L, 10L, 90L, 2L, 20L))) + val rprofBuilder = new ResourceProfileBuilder() + val taskReq = new TaskResourceRequests().cpus(1).resource("gpu", 1) + val execReq = + new ExecutorResourceRequests().cores(2).resource("gpu", 2, "myscript") + rprofBuilder.require(taskReq).require(execReq) + val resourceProfile = rprofBuilder.build + resourceProfile.setResourceProfileId(21) + val resourceProfileAdded = SparkListenerResourceProfileAdded(resourceProfile) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) testEvent(taskStart, taskStartJsonString) @@ -144,6 +152,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString) testEvent(blockUpdated, blockUpdatedJsonString) testEvent(stageExecutorMetrics, stageExecutorMetricsJsonString) + testEvent(resourceProfileAdded, resourceProfileJsonString) } test("Dependent Classes") { @@ -231,6 +240,20 @@ class JsonProtocolSuite extends SparkFunSuite { assert(0 === newInfo.accumulables.size) } + test("StageInfo resourceProfileId") { + val info = makeStageInfo(1, 2, 3, 4L, 5L, 5) + val json = JsonProtocol.stageInfoToJson(info) + + // Fields added after 1.0.0. + assert(info.details.nonEmpty) + assert(info.resourceProfileId === 5) + + val newInfo = JsonProtocol.stageInfoFromJson(json) + + assert(info.name === newInfo.name) + assert(5 === newInfo.resourceProfileId) + } + test("InputMetrics backward compatibility") { // InputMetrics were added after 1.0.1. 
val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = true, hasOutput = false) @@ -865,6 +888,10 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(ste1.getFileName === ste2.getFileName) } + private def assertEquals(rp1: ResourceProfile, rp2: ResourceProfile): Unit = { + assert(rp1 === rp2) + } + /** ----------------------------------- * | Util methods for constructing events | * ------------------------------------ */ @@ -895,10 +922,16 @@ private[spark] object JsonProtocolSuite extends Assertions { r } - private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { + private def makeStageInfo( + a: Int, + b: Int, + c: Int, + d: Long, + e: Long, + rpId: Int = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) = { val rddInfos = (0 until a % 5).map { i => makeRddInfo(a + i, b + i, c + i, d + i, e + i) } val stageInfo = new StageInfo(a, 0, "greetings", b, rddInfos, Seq(100, 200, 300), "details", - resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + resourceProfileId = rpId) val (acc1, acc2) = (makeAccumulableInfo(1), makeAccumulableInfo(2)) stageInfo.accumulables(acc1.id) = acc1 stageInfo.accumulables(acc2.id) = acc2 @@ -1034,7 +1067,8 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Internal": false, | "Count Failed Values": false | } - | ] + | ], + | "Resource Profile Id" : 0 | }, | "Properties": { | "France": "Paris", @@ -1066,6 +1100,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 201, | "Number of Cached Partitions": 301, | "Memory Size": 401, @@ -1091,7 +1126,8 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Internal": false, | "Count Failed Values": false | } - | ] + | ], + | "Resource Profile Id" : 0 | } |} """.stripMargin @@ -1588,6 +1624,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 200, | "Number of Cached Partitions": 300, | "Memory Size": 400, @@ -1613,7 +1650,8 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Internal": false, | "Count Failed Values": false | } - | ] + | ], + | "Resource Profile Id" : 0 | }, | { | "Stage ID": 2, @@ -1632,6 +1670,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 400, | "Number of Cached Partitions": 600, | "Memory Size": 800, @@ -1648,6 +1687,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 401, | "Number of Cached Partitions": 601, | "Memory Size": 801, @@ -1673,7 +1713,8 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Internal": false, | "Count Failed Values": false | } - | ] + | ], + | "Resource Profile Id" : 0 | }, | { | "Stage ID": 3, @@ -1692,6 +1733,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 600, | "Number of Cached Partitions": 900, | "Memory Size": 1200, @@ -1708,6 +1750,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 601, | "Number of Cached Partitions": 901, | "Memory Size": 1201, @@ -1724,6 +1767,7 @@ private[spark] object 
JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 602, | "Number of Cached Partitions": 902, | "Memory Size": 1202, @@ -1749,7 +1793,8 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Internal": false, | "Count Failed Values": false | } - | ] + | ], + | "Resource Profile Id" : 0 | }, | { | "Stage ID": 4, @@ -1768,6 +1813,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 800, | "Number of Cached Partitions": 1200, | "Memory Size": 1600, @@ -1784,6 +1830,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 801, | "Number of Cached Partitions": 1201, | "Memory Size": 1601, @@ -1800,6 +1847,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 802, | "Number of Cached Partitions": 1202, | "Memory Size": 1602, @@ -1816,6 +1864,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 803, | "Number of Cached Partitions": 1203, | "Memory Size": 1603, @@ -1841,7 +1890,8 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Internal": false, | "Count Failed Values": false | } - | ] + | ], + | "Resource Profile Id" : 0 | } | ], | "Stage IDs": [ @@ -1988,7 +2038,8 @@ private[spark] object JsonProtocolSuite extends Assertions { | "name" : "gpu", | "addresses" : [ "0", "1" ] | } - | } + | }, + | "Resource Profile Id": 4 | } |} """.stripMargin @@ -2334,6 +2385,38 @@ private[spark] object JsonProtocolSuite extends Assertions { | "hostId" : "node1" |} """.stripMargin + private val resourceProfileJsonString = + """ + |{ + | "Event":"SparkListenerResourceProfileAdded", + | "Resource Profile Id":21, + | "Executor Resource Requests":{ + | "cores" : { + | "Resource Name":"cores", + | "Amount":2, + | "Discovery Script":"", + | "Vendor":"" + | }, + | "gpu":{ + | "Resource Name":"gpu", + | "Amount":2, + | "Discovery Script":"myscript", + | "Vendor":"" + | } + | }, + | "Task Resource Requests":{ + | "cpus":{ + | "Resource Name":"cpus", + | "Amount":1.0 + | }, + | "gpu":{ + | "Resource Name":"gpu", + | "Amount":1.0 + | } + | } + |} + """.stripMargin + } case class TestListenerEvent(foo: String, bar: Int) extends SparkListenerEvent diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index f5e438b0f1a52..c9c8ae6023877 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -745,10 +745,14 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { manager.add(3, () => output += 3) manager.add(2, () => output += 2) manager.add(4, () => output += 4) + manager.add(Int.MinValue, () => output += Int.MinValue) + manager.add(Int.MinValue, () => output += Int.MinValue) + manager.add(Int.MaxValue, () => output += Int.MaxValue) + manager.add(Int.MaxValue, () => output += Int.MaxValue) manager.remove(hook1) manager.runAll() - assert(output.toList === List(4, 3, 2)) + assert(output.toList === List(Int.MaxValue, Int.MaxValue, 4, 3, 2, Int.MinValue, Int.MinValue)) } test("isInDirectory") { @@ -1297,6 
+1301,14 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.trimExceptCRLF(s"b${s}b") === s"b${s}b") } } + + test("pathsToMetadata") { + val paths = (0 to 4).map(i => new Path(s"path$i")) + assert(Utils.buildLocationMetadata(paths, 5) == "[path0]") + assert(Utils.buildLocationMetadata(paths, 10) == "[path0, path1]") + assert(Utils.buildLocationMetadata(paths, 15) == "[path0, path1, path2]") + assert(Utils.buildLocationMetadata(paths, 25) == "[path0, path1, path2, path3]") + } } private class SimpleExtension diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 473551f208994..4faaaecfe8457 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -25,7 +25,6 @@ bootstrap.bundle.min.js bootstrap.bundle.min.js.map bootstrap.min.css bootstrap.min.css.map -bootstrap-tooltip.js jquery-3.4.1.min.js d3.min.js dagre-d3.min.js @@ -125,3 +124,4 @@ vote.tmpl SessionManager.java SessionHandler.java GangliaReporter.java +application_1578436911597_0052 diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh index 2f794c0e0a174..8f53f4a4e13ad 100755 --- a/dev/create-release/do-release-docker.sh +++ b/dev/create-release/do-release-docker.sh @@ -128,6 +128,7 @@ ASF_PASSWORD=$ASF_PASSWORD GPG_PASSPHRASE=$GPG_PASSPHRASE RELEASE_STEP=$RELEASE_STEP USER=$USER +ZINC_OPTS=${RELEASE_ZINC_OPTS:-"-Xmx4g -XX:ReservedCodeCacheSize=2g"} EOF JAVA_VOL= diff --git a/dev/create-release/do-release.sh b/dev/create-release/do-release.sh index 4f18a55a3bceb..64fba8a56affe 100755 --- a/dev/create-release/do-release.sh +++ b/dev/create-release/do-release.sh @@ -17,6 +17,8 @@ # limitations under the License. # +set -e + SELF=$(cd $(dirname $0) && pwd) . "$SELF/release-util.sh" @@ -52,9 +54,6 @@ function should_build { if should_build "tag" && [ $SKIP_TAG = 0 ]; then run_silent "Creating release tag $RELEASE_TAG..." "tag.log" \ "$SELF/release-tag.sh" - echo "It may take some time for the tag to be synchronized to github." - echo "Press enter when you've verified that the new tag ($RELEASE_TAG) is available." - read else echo "Skipping tag creation for $RELEASE_TAG." fi @@ -79,3 +78,12 @@ if should_build "publish"; then else echo "Skipping publish step." fi + +if should_build "tag" && [ $SKIP_TAG = 0 ]; then + git push origin $RELEASE_TAG + if [[ $RELEASE_TAG != *"preview"* ]]; then + git push origin HEAD:$GIT_BRANCH + else + echo "It's preview release. We only push $RELEASE_TAG to remote." + fi +fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 022d3af95c05d..66c51845cc1d0 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -92,9 +92,12 @@ BASE_DIR=$(pwd) init_java init_maven_sbt -rm -rf spark -git clone "$ASF_REPO" +# Only clone repo fresh if not present, otherwise use checkout from the tag step +if [ ! -d spark ]; then + git clone "$ASF_REPO" +fi cd spark +git fetch git checkout $GIT_REF git_hash=`git rev-parse --short HEAD` echo "Checked out Spark git hash $git_hash" @@ -103,7 +106,7 @@ if [ -z "$SPARK_VERSION" ]; then # Run $MVN in a separate command so that 'set -e' does the right thing. 
TMP=$(mktemp) $MVN help:evaluate -Dexpression=project.version > $TMP - SPARK_VERSION=$(cat $TMP | grep -v INFO | grep -v WARNING | grep -v Download) + SPARK_VERSION=$(cat $TMP | grep -v INFO | grep -v WARNING | grep -vi Download) rm $TMP fi @@ -380,7 +383,7 @@ if [[ "$1" == "publish-snapshot" ]]; then echo "" >> $tmp_settings # Generate random point for Zinc - export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + export ZINC_PORT=$(python -S -c "import random; print(random.randrange(3030,4030))") $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $SCALA_2_12_PROFILES $PUBLISH_PROFILES deploy @@ -412,7 +415,7 @@ if [[ "$1" == "publish-release" ]]; then tmp_repo=$(mktemp -d spark-repo-XXXXX) # Generate random point for Zinc - export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + export ZINC_PORT=$(python -S -c "import random; print(random.randrange(3030,4030))") # TODO: revisit for Scala 2.13 support diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index 39856a9955955..e37aa27fc0aac 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -25,6 +25,7 @@ function exit_with_usage { cat << EOF usage: $NAME Tags a Spark release on a particular branch. +You must push the tags after. Inputs are specified with the following environment variables: ASF_USERNAME - Apache Username @@ -105,19 +106,8 @@ sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION" git commit -a -m "Preparing development version $NEXT_VERSION" -if ! is_dry_run; then - # Push changes - git push origin $RELEASE_TAG - if [[ $RELEASE_VERSION != *"preview"* ]]; then - git push origin HEAD:$GIT_BRANCH - else - echo "It's preview release. We only push $RELEASE_TAG to remote." - fi - - cd .. - rm -rf spark -else - cd .. +cd .. +if is_dry_run; then mv spark spark.tag echo "Clone with version changes and tag available as spark.tag in the output directory." fi diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh index 8ee94a67f34fd..af9ed201b3b47 100755 --- a/dev/create-release/release-util.sh +++ b/dev/create-release/release-util.sh @@ -19,9 +19,8 @@ DRY_RUN=${DRY_RUN:-0} GPG="gpg --no-tty --batch" -ASF_REPO="https://gitbox.apache.org/repos/asf/spark.git" -ASF_REPO_WEBUI="https://gitbox.apache.org/repos/asf?p=spark.git" -ASF_GITHUB_REPO="https://github.com/apache/spark" +ASF_REPO="https://github.com/apache/spark" +ASF_REPO_WEBUI="https://raw.githubusercontent.com/apache/spark" function error { echo "$*" @@ -74,7 +73,7 @@ function fcreate_secure { } function check_for_tag { - curl -s --head --fail "$ASF_GITHUB_REPO/releases/tag/$1" > /dev/null + curl -s --head --fail "$ASF_REPO/releases/tag/$1" > /dev/null } function get_release_info { @@ -91,7 +90,7 @@ function get_release_info { export GIT_BRANCH=$(read_config "Branch" "$GIT_BRANCH") # Find the current version for the branch. - local VERSION=$(curl -s "$ASF_REPO_WEBUI;a=blob_plain;f=pom.xml;hb=refs/heads/$GIT_BRANCH" | + local VERSION=$(curl -s "$ASF_REPO_WEBUI/$GIT_BRANCH/pom.xml" | parse_version) echo "Current branch version is $VERSION." 
diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-1.2 b/dev/deps/spark-deps-hadoop-2.7-hive-1.2 index 0f8da141249ca..0fd8005582738 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-1.2 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-1.2 @@ -35,7 +35,7 @@ commons-beanutils/1.9.4//commons-beanutils-1.9.4.jar commons-cli/1.2//commons-cli-1.2.jar commons-codec/1.10//commons-codec-1.10.jar commons-collections/3.2.2//commons-collections-3.2.2.jar -commons-compiler/3.0.16//commons-compiler-3.0.16.jar +commons-compiler/3.1.2//commons-compiler-3.1.2.jar commons-compress/1.8.1//commons-compress-1.8.1.jar commons-configuration/1.6//commons-configuration-1.6.jar commons-crypto/1.0.0//commons-crypto-1.0.0.jar @@ -93,6 +93,7 @@ jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar jackson-core/2.10.0//jackson-core-2.10.0.jar jackson-databind/2.10.0//jackson-databind-2.10.0.jar jackson-dataformat-yaml/2.10.0//jackson-dataformat-yaml-2.10.0.jar +jackson-datatype-jsr310/2.10.3//jackson-datatype-jsr310-2.10.3.jar jackson-jaxrs/1.9.13//jackson-jaxrs-1.9.13.jar jackson-mapper-asl/1.9.13//jackson-mapper-asl-1.9.13.jar jackson-module-jaxb-annotations/2.10.0//jackson-module-jaxb-annotations-2.10.0.jar @@ -105,7 +106,7 @@ jakarta.inject/2.6.1//jakarta.inject-2.6.1.jar jakarta.validation-api/2.0.2//jakarta.validation-api-2.0.2.jar jakarta.ws.rs-api/2.1.6//jakarta.ws.rs-api-2.1.6.jar jakarta.xml.bind-api/2.3.2//jakarta.xml.bind-api-2.3.2.jar -janino/3.0.16//janino-3.0.16.jar +janino/3.1.2//janino-3.1.2.jar javassist/3.25.0-GA//javassist-3.25.0-GA.jar javax.inject/1//javax.inject-1.jar javax.servlet-api/3.1.0//javax.servlet-api-3.1.0.jar @@ -137,9 +138,9 @@ jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar jul-to-slf4j/1.7.30//jul-to-slf4j-1.7.30.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client/4.7.1//kubernetes-client-4.7.1.jar -kubernetes-model-common/4.7.1//kubernetes-model-common-4.7.1.jar -kubernetes-model/4.7.1//kubernetes-model-4.7.1.jar +kubernetes-client/4.9.2//kubernetes-client-4.9.2.jar +kubernetes-model-common/4.9.2//kubernetes-model-common-4.9.2.jar +kubernetes-model/4.9.2//kubernetes-model-4.9.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.12.0//libthrift-0.12.0.jar @@ -187,7 +188,7 @@ shims/0.7.45//shims-0.7.45.jar slf4j-api/1.7.30//slf4j-api-1.7.30.jar slf4j-log4j12/1.7.30//slf4j-log4j12-1.7.30.jar snakeyaml/1.24//snakeyaml-1.24.jar -snappy-java/1.1.7.3//snappy-java-1.1.7.3.jar +snappy-java/1.1.7.5//snappy-java-1.1.7.5.jar snappy/0.2//snappy-0.2.jar spire-macros_2.12/0.17.0-M1//spire-macros_2.12-0.17.0-M1.jar spire-platform_2.12/0.17.0-M1//spire-platform_2.12-0.17.0-M1.jar @@ -207,4 +208,4 @@ xmlenc/0.52//xmlenc-0.52.jar xz/1.5//xz-1.5.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper/3.4.14//zookeeper-3.4.14.jar -zstd-jni/1.4.4-3//zstd-jni-1.4.4-3.jar +zstd-jni/1.4.5-2//zstd-jni-1.4.5-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 index 6abdd7409eb14..e4df088e08b65 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 @@ -33,7 +33,7 @@ commons-beanutils/1.9.4//commons-beanutils-1.9.4.jar commons-cli/1.2//commons-cli-1.2.jar commons-codec/1.10//commons-codec-1.10.jar commons-collections/3.2.2//commons-collections-3.2.2.jar -commons-compiler/3.0.16//commons-compiler-3.0.16.jar +commons-compiler/3.1.2//commons-compiler-3.1.2.jar commons-compress/1.8.1//commons-compress-1.8.1.jar commons-configuration/1.6//commons-configuration-1.6.jar 
commons-crypto/1.0.0//commons-crypto-1.0.0.jar @@ -106,6 +106,7 @@ jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar jackson-core/2.10.0//jackson-core-2.10.0.jar jackson-databind/2.10.0//jackson-databind-2.10.0.jar jackson-dataformat-yaml/2.10.0//jackson-dataformat-yaml-2.10.0.jar +jackson-datatype-jsr310/2.10.3//jackson-datatype-jsr310-2.10.3.jar jackson-jaxrs/1.9.13//jackson-jaxrs-1.9.13.jar jackson-mapper-asl/1.9.13//jackson-mapper-asl-1.9.13.jar jackson-module-jaxb-annotations/2.10.0//jackson-module-jaxb-annotations-2.10.0.jar @@ -118,7 +119,7 @@ jakarta.inject/2.6.1//jakarta.inject-2.6.1.jar jakarta.validation-api/2.0.2//jakarta.validation-api-2.0.2.jar jakarta.ws.rs-api/2.1.6//jakarta.ws.rs-api-2.1.6.jar jakarta.xml.bind-api/2.3.2//jakarta.xml.bind-api-2.3.2.jar -janino/3.0.16//janino-3.0.16.jar +janino/3.1.2//janino-3.1.2.jar javassist/3.25.0-GA//javassist-3.25.0-GA.jar javax.inject/1//javax.inject-1.jar javax.jdo/3.2.0-m3//javax.jdo-3.2.0-m3.jar @@ -152,9 +153,9 @@ jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar jul-to-slf4j/1.7.30//jul-to-slf4j-1.7.30.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client/4.7.1//kubernetes-client-4.7.1.jar -kubernetes-model-common/4.7.1//kubernetes-model-common-4.7.1.jar -kubernetes-model/4.7.1//kubernetes-model-4.7.1.jar +kubernetes-client/4.9.2//kubernetes-client-4.9.2.jar +kubernetes-model-common/4.9.2//kubernetes-model-common-4.9.2.jar +kubernetes-model/4.9.2//kubernetes-model-4.9.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.12.0//libthrift-0.12.0.jar @@ -201,7 +202,7 @@ shims/0.7.45//shims-0.7.45.jar slf4j-api/1.7.30//slf4j-api-1.7.30.jar slf4j-log4j12/1.7.30//slf4j-log4j12-1.7.30.jar snakeyaml/1.24//snakeyaml-1.24.jar -snappy-java/1.1.7.3//snappy-java-1.1.7.3.jar +snappy-java/1.1.7.5//snappy-java-1.1.7.5.jar spire-macros_2.12/0.17.0-M1//spire-macros_2.12-0.17.0-M1.jar spire-platform_2.12/0.17.0-M1//spire-platform_2.12-0.17.0-M1.jar spire-util_2.12/0.17.0-M1//spire-util_2.12-0.17.0-M1.jar @@ -221,4 +222,4 @@ xmlenc/0.52//xmlenc-0.52.jar xz/1.5//xz-1.5.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper/3.4.14//zookeeper-3.4.14.jar -zstd-jni/1.4.4-3//zstd-jni-1.4.4-3.jar +zstd-jni/1.4.5-2//zstd-jni-1.4.5-2.jar diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index 3553734b35fe6..7f3f74e3e039d 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -30,14 +30,14 @@ commons-beanutils/1.9.4//commons-beanutils-1.9.4.jar commons-cli/1.2//commons-cli-1.2.jar commons-codec/1.10//commons-codec-1.10.jar commons-collections/3.2.2//commons-collections-3.2.2.jar -commons-compiler/3.0.16//commons-compiler-3.0.16.jar +commons-compiler/3.1.2//commons-compiler-3.1.2.jar commons-compress/1.8.1//commons-compress-1.8.1.jar commons-configuration2/2.1.1//commons-configuration2-2.1.1.jar commons-crypto/1.0.0//commons-crypto-1.0.0.jar commons-daemon/1.0.13//commons-daemon-1.0.13.jar commons-dbcp/1.4//commons-dbcp-1.4.jar commons-httpclient/3.1//commons-httpclient-3.1.jar -commons-io/2.4//commons-io-2.4.jar +commons-io/2.5//commons-io-2.5.jar commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.9//commons-lang3-3.9.jar commons-logging/1.1.3//commons-logging-1.1.3.jar @@ -105,6 +105,7 @@ jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar jackson-core/2.10.0//jackson-core-2.10.0.jar jackson-databind/2.10.0//jackson-databind-2.10.0.jar jackson-dataformat-yaml/2.10.0//jackson-dataformat-yaml-2.10.0.jar 
+jackson-datatype-jsr310/2.10.3//jackson-datatype-jsr310-2.10.3.jar jackson-jaxrs-base/2.9.5//jackson-jaxrs-base-2.9.5.jar jackson-jaxrs-json-provider/2.9.5//jackson-jaxrs-json-provider-2.9.5.jar jackson-mapper-asl/1.9.13//jackson-mapper-asl-1.9.13.jar @@ -117,7 +118,7 @@ jakarta.inject/2.6.1//jakarta.inject-2.6.1.jar jakarta.validation-api/2.0.2//jakarta.validation-api-2.0.2.jar jakarta.ws.rs-api/2.1.6//jakarta.ws.rs-api-2.1.6.jar jakarta.xml.bind-api/2.3.2//jakarta.xml.bind-api-2.3.2.jar -janino/3.0.16//janino-3.0.16.jar +janino/3.1.2//janino-3.1.2.jar javassist/3.25.0-GA//javassist-3.25.0-GA.jar javax.inject/1//javax.inject-1.jar javax.jdo/3.2.0-m3//javax.jdo-3.2.0-m3.jar @@ -164,9 +165,9 @@ kerby-pkix/1.0.1//kerby-pkix-1.0.1.jar kerby-util/1.0.1//kerby-util-1.0.1.jar kerby-xdr/1.0.1//kerby-xdr-1.0.1.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client/4.7.1//kubernetes-client-4.7.1.jar -kubernetes-model-common/4.7.1//kubernetes-model-common-4.7.1.jar -kubernetes-model/4.7.1//kubernetes-model-4.7.1.jar +kubernetes-client/4.9.2//kubernetes-client-4.9.2.jar +kubernetes-model-common/4.9.2//kubernetes-model-common-4.9.2.jar +kubernetes-model/4.9.2//kubernetes-model-4.9.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.12.0//libthrift-0.12.0.jar @@ -182,7 +183,6 @@ metrics-jmx/4.1.1//metrics-jmx-4.1.1.jar metrics-json/4.1.1//metrics-json-4.1.1.jar metrics-jvm/4.1.1//metrics-jvm-4.1.1.jar minlog/1.3.0//minlog-1.3.0.jar -mssql-jdbc/6.2.1.jre7//mssql-jdbc-6.2.1.jre7.jar netty-all/4.1.47.Final//netty-all-4.1.47.Final.jar nimbus-jose-jwt/4.41.1//nimbus-jose-jwt-4.41.1.jar objenesis/2.5.1//objenesis-2.5.1.jar @@ -217,7 +217,7 @@ shims/0.7.45//shims-0.7.45.jar slf4j-api/1.7.30//slf4j-api-1.7.30.jar slf4j-log4j12/1.7.30//slf4j-log4j12-1.7.30.jar snakeyaml/1.24//snakeyaml-1.24.jar -snappy-java/1.1.7.3//snappy-java-1.1.7.3.jar +snappy-java/1.1.7.5//snappy-java-1.1.7.5.jar spire-macros_2.12/0.17.0-M1//spire-macros_2.12-0.17.0-M1.jar spire-platform_2.12/0.17.0-M1//spire-platform_2.12-0.17.0-M1.jar spire-util_2.12/0.17.0-M1//spire-util_2.12-0.17.0-M1.jar @@ -236,4 +236,4 @@ xbean-asm7-shaded/4.15//xbean-asm7-shaded-4.15.jar xz/1.5//xz-1.5.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper/3.4.14//zookeeper-3.4.14.jar -zstd-jni/1.4.4-3//zstd-jni-1.4.4-3.jar +zstd-jni/1.4.5-2//zstd-jni-1.4.5-2.jar diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 72e32d4e16e14..13be9592d771f 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -198,7 +198,7 @@ def main(): # format: http://linux.die.net/man/1/timeout # must be less than the timeout configured on Jenkins. Usually Jenkins's timeout is higher # then this. Please consult with the build manager or a committer when it should be increased. - tests_timeout = "400m" + tests_timeout = "500m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 936ac00f6b9e7..b3e68bed1d1e7 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -17,7 +17,7 @@ # limitations under the License. 
# -set -e +set -ex FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" @@ -47,7 +47,7 @@ OLD_VERSION=$($MVN -q \ -Dexec.executable="echo" \ -Dexec.args='${project.version}' \ --non-recursive \ - org.codehaus.mojo:exec-maven-plugin:1.6.0:exec) + org.codehaus.mojo:exec-maven-plugin:1.6.0:exec | grep -E '[0-9]+\.[0-9]+\.[0-9]+') if [ $? != 0 ]; then echo -e "Error while getting version string from Maven:\n$OLD_VERSION" exit 1 diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index dfe4cfab2a2ab..219e6809a96f0 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -76,14 +76,6 @@ - text: SQL Reference url: sql-ref.html subitems: - - text: Data Types - url: sql-ref-datatypes.html - - text: Identifiers - url: sql-ref-identifier.html - - text: Literals - url: sql-ref-literals.html - - text: Null Semantics - url: sql-ref-null-semantics.html - text: ANSI Compliance url: sql-ref-ansi-compliance.html subitems: @@ -93,6 +85,27 @@ url: sql-ref-ansi-compliance.html#type-conversion - text: SQL Keywords url: sql-ref-ansi-compliance.html#sql-keywords + - text: Data Types + url: sql-ref-datatypes.html + - text: Datetime Pattern + url: sql-ref-datetime-pattern.html + - text: Functions + url: sql-ref-functions.html + subitems: + - text: Built-in Functions + url: sql-ref-functions-builtin.html + - text: Scalar UDFs (User-Defined Functions) + url: sql-ref-functions-udf-scalar.html + - text: UDAFs (User-Defined Aggregate Functions) + url: sql-ref-functions-udf-aggregate.html + - text: Integration with Hive UDFs/UDAFs/UDTFs + url: sql-ref-functions-udf-hive.html + - text: Identifiers + url: sql-ref-identifier.html + - text: Literals + url: sql-ref-literals.html + - text: Null Semantics + url: sql-ref-null-semantics.html - text: SQL Syntax url: sql-ref-syntax.html subitems: @@ -156,24 +169,24 @@ url: sql-ref-syntax-qry-select-distribute-by.html - text: LIMIT Clause url: sql-ref-syntax-qry-select-limit.html + - text: Common Table Expression + url: sql-ref-syntax-qry-select-cte.html + - text: Hints + url: sql-ref-syntax-qry-select-hints.html + - text: Inline Table + url: sql-ref-syntax-qry-select-inline-table.html - text: JOIN url: sql-ref-syntax-qry-select-join.html - - text: Join Hints - url: sql-ref-syntax-qry-select-hints.html + - text: LIKE Predicate + url: sql-ref-syntax-qry-select-like.html - text: Set Operators url: sql-ref-syntax-qry-select-setops.html - text: TABLESAMPLE - url: sql-ref-syntax-qry-sampling.html + url: sql-ref-syntax-qry-select-sampling.html - text: Table-valued Function url: sql-ref-syntax-qry-select-tvf.html - - text: Inline Table - url: sql-ref-syntax-qry-select-inline-table.html - - text: Common Table Expression - url: sql-ref-syntax-qry-select-cte.html - - text: LIKE Predicate - url: sql-ref-syntax-qry-select-like.html - text: Window Function - url: sql-ref-syntax-qry-window.html + url: sql-ref-syntax-qry-select-window.html - text: EXPLAIN url: sql-ref-syntax-qry-explain.html - text: Auxiliary Statements @@ -213,20 +226,20 @@ subitems: - text: SHOW COLUMNS url: sql-ref-syntax-aux-show-columns.html + - text: SHOW CREATE TABLE + url: sql-ref-syntax-aux-show-create-table.html - text: SHOW DATABASES url: sql-ref-syntax-aux-show-databases.html - text: SHOW FUNCTIONS url: sql-ref-syntax-aux-show-functions.html + - text: SHOW PARTITIONS + url: sql-ref-syntax-aux-show-partitions.html - text: SHOW TABLE url: sql-ref-syntax-aux-show-table.html - text: SHOW TABLES url: sql-ref-syntax-aux-show-tables.html - text: SHOW TBLPROPERTIES url: 
sql-ref-syntax-aux-show-tblproperties.html - - text: SHOW PARTITIONS - url: sql-ref-syntax-aux-show-partitions.html - - text: SHOW CREATE TABLE - url: sql-ref-syntax-aux-show-create-table.html - text: SHOW VIEWS url: sql-ref-syntax-aux-show-views.html - text: CONFIGURATION MANAGEMENT @@ -247,16 +260,3 @@ url: sql-ref-syntax-aux-resource-mgmt-list-file.html - text: LIST JAR url: sql-ref-syntax-aux-resource-mgmt-list-jar.html - - text: Functions - url: sql-ref-functions.html - subitems: - - text: Built-in Functions - url: sql-ref-functions-builtin.html - - text: Scalar UDFs (User-Defined Functions) - url: sql-ref-functions-udf-scalar.html - - text: UDAFs (User-Defined Aggregate Functions) - url: sql-ref-functions-udf-aggregate.html - - text: Integration with Hive UDFs/UDAFs/UDTFs - url: sql-ref-functions-udf-hive.html - - text: Datetime Pattern - url: sql-ref-datetime-pattern.html diff --git a/docs/configuration.md b/docs/configuration.md index fce04b940594b..420942f7b7bbb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2955,6 +2955,12 @@ Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can config `log4j.properties` file in the `conf` directory. One way to start is to copy the existing `log4j.properties.template` located there. +By default, Spark adds 1 record to the MDC (Mapped Diagnostic Context): `taskName`, which shows something +like `task 1.0 in stage 0.0`. You can add `%X{taskName}` to your patternLayout in +order to print it in the logs. +Moreover, you can use `spark.sparkContext.setLocalProperty("mdc." + name, "value")` to add user specific data into MDC. +The key in MDC will be the string after the `mdc.` prefix. + # Overriding configuration directory To specify a different configuration directory other than the default "SPARK_HOME/conf", diff --git a/docs/img/webui-structured-streaming-detail.png b/docs/img/webui-structured-streaming-detail.png new file mode 100644 index 0000000000000..f4850523c5c2f Binary files /dev/null and b/docs/img/webui-structured-streaming-detail.png differ diff --git a/docs/ml-features.md b/docs/ml-features.md index 05ef848aefef8..660c27250ebfb 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1793,6 +1793,210 @@ for more details on the API.
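Referring back to the MDC note added to `docs/configuration.md` above: the following is a minimal sketch, not part of this patch, showing how the `mdc.`-prefixed local property and the `%X{taskName}` pattern layout fit together. The property name `mdc.appTag` and the log4j pattern shown in the comment are illustrative assumptions only.

```scala
// Hypothetical log4j.properties pattern printing the built-in taskName MDC entry
// plus a user-defined "appTag" entry:
//   log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %X{taskName} %X{appTag} %m%n

import org.apache.spark.sql.SparkSession

object MdcSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("mdc-sketch").getOrCreate()
    val sc = spark.sparkContext

    // Keys set with the "mdc." prefix become MDC entries (key without the prefix)
    // on the executor log lines of tasks started after this call.
    sc.setLocalProperty("mdc.appTag", "nightly-etl")
    sc.parallelize(1 to 10).map(_ * 2).count()

    spark.stop()
  }
}
```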
+## ANOVASelector + +`ANOVASelector` operates on categorical labels with continuous features. It uses the +[one-way ANOVA F-test](https://en.wikipedia.org/wiki/F-test#Multiple-comparison_ANOVA_problems) to decide which +features to choose. +It supports five selection methods: `numTopFeatures`, `percentile`, `fpr`, `fdr`, `fwe`: +* `numTopFeatures` chooses a fixed number of top features according to ANOVA F-test. +* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number. +* `fpr` chooses all features whose p-values are below a threshold, thus controlling the false positive rate of selection. +* `fdr` uses the [Benjamini-Hochberg procedure](https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure) to choose all features whose false discovery rate is below a threshold. +* `fwe` chooses all features whose p-values are below a threshold. The threshold is scaled by 1/numFeatures, thus controlling the family-wise error rate of selection. +By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. +The user can choose a selection method using `setSelectorType`. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `features`, and `label`, which is used as +our target to be predicted: + +~~~ +id | features | label +---|--------------------------------|--------- + 1 | [1.7, 4.4, 7.6, 5.8, 9.6, 2.3] | 3.0 + 2 | [8.8, 7.3, 5.7, 7.3, 2.2, 4.1] | 2.0 + 3 | [1.2, 9.5, 2.5, 3.1, 8.7, 2.5] | 3.0 + 4 | [3.7, 9.2, 6.1, 4.1, 7.5, 3.8] | 2.0 + 5 | [8.9, 5.2, 7.8, 8.3, 5.2, 3.0] | 4.0 + 6 | [7.9, 8.5, 9.2, 4.0, 9.4, 2.1] | 4.0 +~~~ + +If we use `ANOVASelector` with `numTopFeatures = 1`, the +last column in our `features` is chosen as the most useful feature: + +~~~ +id | features | label | selectedFeatures +---|--------------------------------|---------|------------------ + 1 | [1.7, 4.4, 7.6, 5.8, 9.6, 2.3] | 3.0 | [2.3] + 2 | [8.8, 7.3, 5.7, 7.3, 2.2, 4.1] | 2.0 | [4.1] + 3 | [1.2, 9.5, 2.5, 3.1, 8.7, 2.5] | 3.0 | [2.5] + 4 | [3.7, 9.2, 6.1, 4.1, 7.5, 3.8] | 2.0 | [3.8] + 5 | [8.9, 5.2, 7.8, 8.3, 5.2, 3.0] | 4.0 | [3.0] + 6 | [7.9, 8.5, 9.2, 4.0, 9.4, 2.1] | 4.0 | [2.1] +~~~ + +
+
+ +Refer to the [ANOVASelector Scala docs](api/scala/org/apache/spark/ml/feature/ANOVASelector.html) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ANOVASelectorExample.scala %} +
+ +
+ +Refer to the [ANOVASelector Java docs](api/java/org/apache/spark/ml/feature/ANOVASelector.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaANOVASelectorExample.java %} +
+ +
+ +Refer to the [ANOVASelector Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ANOVASelector) +for more details on the API. + +{% include_example python/ml/anova_selector_example.py %} +
+
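To make the `ANOVASelector` walkthrough above concrete, here is a short Scala sketch using the toy data from the first table. It assumes the usual Spark ML setter names (`setNumTopFeatures`, `setFeaturesCol`, `setLabelCol`, `setOutputCol`), which the prose does not spell out; the linked example files remain the authoritative versions.

```scala
import org.apache.spark.ml.feature.ANOVASelector
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("anova-selector-sketch").getOrCreate()

// Continuous features with a categorical label, as in the table above.
val df = spark.createDataFrame(Seq(
  (1, Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3), 3.0),
  (2, Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1), 2.0),
  (3, Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5), 3.0),
  (4, Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8), 2.0),
  (5, Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0), 4.0),
  (6, Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1), 4.0)
)).toDF("id", "features", "label")

// selectorType defaults to "numTopFeatures"; setSelectorType switches the method.
val selector = new ANOVASelector()
  .setNumTopFeatures(1)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("selectedFeatures")

// Fitting runs the one-way ANOVA F-test per feature and keeps the top-ranked one.
selector.fit(df).transform(df).show(truncate = false)
```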
+ +## FValueSelector + +`FValueSelector` operates on continuous labels with continuous features. It uses the +[F-test for regression](https://en.wikipedia.org/wiki/F-test#Regression_problems) to decide which +features to choose. +It supports five selection methods: `numTopFeatures`, `percentile`, `fpr`, `fdr`, `fwe`: +* `numTopFeatures` chooses a fixed number of top features according to an F-test for regression. +* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number. +* `fpr` chooses all features whose p-values are below a threshold, thus controlling the false positive rate of selection. +* `fdr` uses the [Benjamini-Hochberg procedure](https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure) to choose all features whose false discovery rate is below a threshold. +* `fwe` chooses all features whose p-values are below a threshold. The threshold is scaled by 1/numFeatures, thus controlling the family-wise error rate of selection. +By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. +The user can choose a selection method using `setSelectorType`. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `features`, and `label`, which is used as +our target to be predicted: + +~~~ +id | features | label +---|--------------------------------|--------- + 1 | [6.0, 7.0, 0.0, 7.0, 6.0, 0.0] | 4.6 + 2 | [0.0, 9.0, 6.0, 0.0, 5.0, 9.0] | 6.6 + 3 | [0.0, 9.0, 3.0, 0.0, 5.0, 5.0] | 5.1 + 4 | [0.0, 9.0, 8.0, 5.0, 6.0, 4.0] | 7.6 + 5 | [8.0, 9.0, 6.0, 5.0, 4.0, 4.0] | 9.0 + 6 | [8.0, 9.0, 6.0, 4.0, 0.0, 0.0] | 9.0 +~~~ + +If we use `FValueSelector` with `numTopFeatures = 1`, the +3rd column in our `features` is chosen as the most useful feature: + +~~~ +id | features | label | selectedFeatures +---|--------------------------------|---------|------------------ + 1 | [6.0, 7.0, 0.0, 7.0, 6.0, 0.0] | 4.6 | [0.0] + 2 | [0.0, 9.0, 6.0, 0.0, 5.0, 9.0] | 6.6 | [6.0] + 3 | [0.0, 9.0, 3.0, 0.0, 5.0, 5.0] | 5.1 | [3.0] + 4 | [0.0, 9.0, 8.0, 5.0, 6.0, 4.0] | 7.6 | [8.0] + 5 | [8.0, 9.0, 6.0, 5.0, 4.0, 4.0] | 9.0 | [6.0] + 6 | [8.0, 9.0, 6.0, 4.0, 0.0, 0.0] | 9.0 | [6.0] +~~~ + +
+
+ +Refer to the [FValueSelector Scala docs](api/scala/org/apache/spark/ml/feature/FValueSelector.html) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/FValueSelectorExample.scala %} +
+ +
+ +Refer to the [FValueSelector Java docs](api/java/org/apache/spark/ml/feature/FValueSelector.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaFValueSelectorExample.java %} +
+ +
+
+Refer to the [FValueSelector Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.FValueSelector)
+for more details on the API.
+
+{% include_example python/ml/fvalue_selector_example.py %}
+</div>
+
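+
+A minimal sketch along the same lines, again assuming an active `SparkSession` named `spark`; the
+final comment shows how `setSelectorType` switches the selection method:
+
+```scala
+import org.apache.spark.ml.feature.FValueSelector
+import org.apache.spark.ml.linalg.Vectors
+
+// Toy data matching the table above: continuous label, continuous features.
+val df = spark.createDataFrame(Seq(
+  (1, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0), 4.6),
+  (2, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0), 6.6),
+  (3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0), 5.1),
+  (4, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0), 7.6),
+  (5, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0), 9.0),
+  (6, Vectors.dense(8.0, 9.0, 6.0, 4.0, 0.0, 0.0), 9.0)
+)).toDF("id", "features", "label")
+
+// Default selection method: keep the single best feature by F-value.
+val selector = new FValueSelector()
+  .setNumTopFeatures(1)
+  .setFeaturesCol("features")
+  .setLabelCol("label")
+  .setOutputCol("selectedFeatures")
+
+selector.fit(df).transform(df).show()
+
+// To keep, e.g., the top half of the features instead:
+// selector.setSelectorType("percentile").setPercentile(0.5)
+```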
+
+## VarianceThresholdSelector
+
+`VarianceThresholdSelector` is a selector that removes low-variance features. Features with a
+ variance not greater than the `varianceThreshold` will be removed. If not set, `varianceThreshold`
+ defaults to 0, which means only features with variance 0 (i.e. features that have the same value in all samples)
+ will be removed.
+
+**Examples**
+
+Assume that we have a DataFrame with the columns `id` and `features`:
+
+~~~
+id | features
+---|--------------------------------
+ 1 | [6.0, 7.0, 0.0, 7.0, 6.0, 0.0]
+ 2 | [0.0, 9.0, 6.0, 0.0, 5.0, 9.0]
+ 3 | [0.0, 9.0, 3.0, 0.0, 5.0, 5.0]
+ 4 | [0.0, 9.0, 8.0, 5.0, 6.0, 4.0]
+ 5 | [8.0, 9.0, 6.0, 5.0, 4.0, 4.0]
+ 6 | [8.0, 9.0, 6.0, 0.0, 0.0, 0.0]
+~~~
+
+The variances of the 6 features are 16.67, 0.67, 8.17, 10.17,
+5.07, and 11.47, respectively. If we use `VarianceThresholdSelector` with
+`varianceThreshold = 8.0`, then the features with variance <= 8.0 are removed:
+
+~~~
+id | features                       | selectedFeatures
+---|--------------------------------|-------------------
+ 1 | [6.0, 7.0, 0.0, 7.0, 6.0, 0.0] | [6.0,0.0,7.0,0.0]
+ 2 | [0.0, 9.0, 6.0, 0.0, 5.0, 9.0] | [0.0,6.0,0.0,9.0]
+ 3 | [0.0, 9.0, 3.0, 0.0, 5.0, 5.0] | [0.0,3.0,0.0,5.0]
+ 4 | [0.0, 9.0, 8.0, 5.0, 6.0, 4.0] | [0.0,8.0,5.0,4.0]
+ 5 | [8.0, 9.0, 6.0, 5.0, 4.0, 4.0] | [8.0,6.0,5.0,4.0]
+ 6 | [8.0, 9.0, 6.0, 0.0, 0.0, 0.0] | [8.0,6.0,0.0,0.0]
+~~~
+
+<div class="codetabs">
+
+
+Refer to the [VarianceThresholdSelector Scala docs](api/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.html)
+for more details on the API.
+
+{% include_example scala/org/apache/spark/examples/ml/VarianceThresholdSelectorExample.scala %}
+</div>
+ +
+ +Refer to the [VarianceThresholdSelector Java docs](api/java/org/apache/spark/ml/feature/VarianceThresholdSelector.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaVarianceThresholdSelectorExample.java %} +
+ +
+ +Refer to the [VarianceThresholdSelector Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VarianceThresholdSelector) +for more details on the API. + +{% include_example python/ml/variance_threshold_selector_example.py %} +
+
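+
+A minimal sketch of the same setting, assuming an active `SparkSession` named `spark`; unlike the
+selectors above, no label column is involved:
+
+```scala
+import org.apache.spark.ml.feature.VarianceThresholdSelector
+import org.apache.spark.ml.linalg.Vectors
+
+val df = spark.createDataFrame(Seq(
+  (1, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
+  (2, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
+  (3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
+  (4, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)),
+  (5, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)),
+  (6, Vectors.dense(8.0, 9.0, 6.0, 0.0, 0.0, 0.0))
+)).toDF("id", "features")
+
+// Drop every feature whose sample variance is not greater than 8.0.
+val selector = new VarianceThresholdSelector()
+  .setVarianceThreshold(8.0)
+  .setFeaturesCol("features")
+  .setOutputCol("selectedFeatures")
+
+selector.fit(df).transform(df).show()
+```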
+ # Locality Sensitive Hashing [Locality Sensitive Hashing (LSH)](https://en.wikipedia.org/wiki/Locality-sensitive_hashing) is an important class of hashing techniques, which is commonly used in clustering, approximate nearest neighbor search and outlier detection with large datasets. diff --git a/docs/ml-statistics.md b/docs/ml-statistics.md index a3d57ff7d266b..637cdd6c78f10 100644 --- a/docs/ml-statistics.md +++ b/docs/ml-statistics.md @@ -79,7 +79,35 @@ The output will be a DataFrame that contains the correlation matrix of the colum Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically significant, whether this result occurred by chance or not. `spark.ml` currently supports Pearson's -Chi-squared ( $\chi^2$) tests for independence. +Chi-squared ( $\chi^2$) tests for independence, as well as ANOVA test for classification tasks and +F-value test for regression tasks. + +### ANOVATest + +`ANOVATest` computes ANOVA F-values between labels and features for classification tasks. The labels should be categorical +and features should be continuous. + +
+
+Refer to the [`ANOVATest` Scala docs](api/scala/org/apache/spark/ml/stat/ANOVATest$.html) for details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ANOVATestExample.scala %} +
+ +
+Refer to the [`ANOVATest` Java docs](api/java/org/apache/spark/ml/stat/ANOVATest.html) for details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaANOVATestExample.java %} +
+ +
+Refer to the [`ANOVATest` Python docs](api/python/index.html#pyspark.ml.stat.ANOVATest$) for details on the API. + +{% include_example python/ml/anova_test_example.py %} +
+
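+
+As a minimal sketch (assuming an active `SparkSession` named `spark`), `ANOVATest.test` returns a
+DataFrame of test results that can be inspected directly:
+
+```scala
+import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.stat.ANOVATest
+
+// Categorical label, continuous features.
+val df = spark.createDataFrame(Seq(
+  (3.0, Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3)),
+  (2.0, Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1)),
+  (3.0, Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5)),
+  (2.0, Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8))
+)).toDF("label", "features")
+
+// The result contains the p-values, degrees of freedom and F-values per feature.
+val anovaResult = ANOVATest.test(df, "features", "label")
+anovaResult.show(truncate = false)
+```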
+ +### ChiSquareTest `ChiSquareTest` conducts Pearson's independence test for every feature against the label. For each feature, the (feature, label) pairs are converted into a contingency matrix for which @@ -106,6 +134,32 @@ Refer to the [`ChiSquareTest` Python docs](api/python/index.html#pyspark.ml.stat +### FValueTest + +`FValueTest` computes F-values between labels and features for regression tasks. Both the labels + and features should be continuous. + +
+
+
+Refer to the [`FValueTest` Scala docs](api/scala/org/apache/spark/ml/stat/FValueTest$.html) for details on the API.
+
+{% include_example scala/org/apache/spark/examples/ml/FValueTestExample.scala %}
+</div>
+ +
+Refer to the [`FValueTest` Java docs](api/java/org/apache/spark/ml/stat/FValueTest.html) for details on the API.
+
+{% include_example java/org/apache/spark/examples/ml/JavaFValueTestExample.java %}
+</div>
+ +
+Refer to the [`FValueTest` Python docs](api/python/index.html#pyspark.ml.stat.FValueTest$) for details on the API.
+
+{% include_example python/ml/fvalue_test_example.py %}
+</div>
+ +
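+
+A minimal sketch, analogous to the `ANOVATest` one above, assuming an active `SparkSession` named `spark`:
+
+```scala
+import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.stat.FValueTest
+
+// Here both the label and the features are continuous.
+val df = spark.createDataFrame(Seq(
+  (4.6, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
+  (6.6, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
+  (5.1, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
+  (7.6, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0))
+)).toDF("label", "features")
+
+val fValueResult = FValueTest.test(df, "features", "label")
+fValueResult.show(truncate = false)
+```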
+ ## Summarizer We provide vector column summary statistics for `Dataframe` through `Summarizer`. diff --git a/docs/monitoring.md b/docs/monitoring.md index 7e41c9df4efa0..32959b77c4773 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -544,6 +544,24 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/streaming/batches/[batch-id]/operations/[outputOp-id] Details of the given operation and given batch. + + /applications/[app-id]/sql + A list of all queries for a given application. +
+ ?details=[true (default) | false] lists/hides details of Spark plan nodes. +
+ ?planDescription=[true (default) | false] enables/disables Physical planDescription on demand when Physical Plan size is high. +
+ ?offset=[offset]&length=[len] lists queries in the given range. + + + /applications/[app-id]/sql/[execution-id] + Details for the given query. +
+ ?details=[true (default) | false] lists/hides metric details in addition to given query details. +
+ ?planDescription=[true (default) | false] enables/disables Physical planDescription on demand for the given query when Physical Plan size is high. + /applications/[app-id]/environment Environment details of the given application. @@ -715,7 +733,7 @@ A list of the available metrics, with a short description: Executor-level metrics are sent from each executor to the driver as part of the Heartbeat to describe the performance metrics of Executor itself like JVM heap memory, GC information. Executor metric values and their measured memory peak values per executor are exposed via the REST API in JSON format and in Prometheus format. The JSON end point is exposed at: `/applications/[app-id]/executors`, and the Prometheus endpoint at: `/metrics/executors/prometheus`. -The Prometheus endpoint is conditional to a configuration parameter: `spark.ui.prometheus.enabled=true` (the default is `false`). +The Prometheus endpoint is experimental and conditional to a configuration parameter: `spark.ui.prometheus.enabled=true` (the default is `false`). In addition, aggregated per-stage peak values of the executor memory metrics are written to the event log if `spark.eventLog.logStageExecutorMetrics` is true. Executor memory metrics are also exposed via the Spark metrics system based on the Dropwizard metrics library. @@ -963,7 +981,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the * `CSVSink`: Exports metrics data to CSV files at regular intervals. * `JmxSink`: Registers metrics for viewing in a JMX console. * `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data. -* `PrometheusServlet`: Adds a servlet within the existing Spark UI to serve metrics data in Prometheus format. +* `PrometheusServlet`: (Experimental) Adds a servlet within the existing Spark UI to serve metrics data in Prometheus format. * `GraphiteSink`: Sends metrics to a Graphite node. * `Slf4jSink`: Sends metrics to slf4j as log entries. * `StatsdSink`: Sends metrics to a StatsD node. diff --git a/docs/pyspark-migration-guide.md b/docs/pyspark-migration-guide.md index 6f0fbbfb78de8..2c9ea410f217e 100644 --- a/docs/pyspark-migration-guide.md +++ b/docs/pyspark-migration-guide.md @@ -45,6 +45,8 @@ Please refer [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide. - As of Spark 3.0, `Row` field names are no longer sorted alphabetically when constructing with named arguments for Python versions 3.6 and above, and the order of fields will match that as entered. To enable sorted fields by default, as in Spark 2.4, set the environment variable `PYSPARK_ROW_FIELD_SORTING_ENABLED` to `true` for both executors and driver - this environment variable must be consistent on all executors and driver; otherwise, it may cause failures or incorrect answers. For Python versions less than 3.6, the field names will be sorted alphabetically as the only option. +- In Spark 3.0, `pyspark.ml.param.shared.Has*` mixins do not provide any `set*(self, value)` setter methods anymore, use the respective `self.set(self.*, value)` instead. See [SPARK-29093](https://issues.apache.org/jira/browse/SPARK-29093) for details. + ## Upgrading from PySpark 2.3 to 2.4 - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. 
Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index ba99007aaf639..70bfefce475a1 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -360,7 +360,7 @@ Some notes on reading files with Spark: * If using a path on the local filesystem, the file must also be accessible at the same path on worker nodes. Either copy the file to all workers or use a network-mounted shared file system. -* All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. +* All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. When multiple files are read, the order of the partitions depends on the order the files are returned from the filesystem. It may or may not, for example, follow the lexicographic ordering of the files by path. Within a partition, elements are ordered according to their order in the underlying file. * The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 41800266fdd77..ba735cacd4c38 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -811,11 +811,20 @@ See the [configuration page](configuration.html) for information on Spark config spark.kubernetes.driver.annotation.[AnnotationName] (none) - Add the annotation specified by AnnotationName to the driver pod. + Add the Kubernetes annotation specified by AnnotationName to the driver pod. For example, spark.kubernetes.driver.annotation.something=true. 2.3.0 + + spark.kubernetes.driver.service.annotation.[AnnotationName] + (none) + + Add the Kubernetes annotation specified by AnnotationName to the driver service. + For example, spark.kubernetes.driver.service.annotation.something=true. + + 3.0.0 + spark.kubernetes.executor.label.[LabelName] (none) @@ -831,7 +840,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.kubernetes.executor.annotation.[AnnotationName] (none) - Add the annotation specified by AnnotationName to the executor pods. + Add the Kubernetes annotation specified by AnnotationName to the executor pods. For example, spark.kubernetes.executor.annotation.something=true. 2.3.0 diff --git a/docs/sparkr.md b/docs/sparkr.md index d937bc90b6506..d86fa86c89853 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -666,20 +666,15 @@ Apache Arrow is an in-memory columnar data format that is used in Spark to effic ## Ensure Arrow Installed -Arrow R library is available on CRAN as of [ARROW-3204](https://issues.apache.org/jira/browse/ARROW-3204). It can be installed as below. +Arrow R library is available on CRAN and it can be installed as below. 
```bash Rscript -e 'install.packages("arrow", repos="https://cloud.r-project.org/")' ``` +Please refer [the official documentation of Apache Arrow](https://arrow.apache.org/docs/r/) for more detials. -If you need to install old versions, it should be installed directly from Github. You can use `remotes::install_github` as below. - -```bash -Rscript -e 'remotes::install_github("apache/arrow@apache-arrow-0.12.1", subdir = "r")' -``` - -`apache-arrow-0.12.1` is a version tag that can be checked in [Arrow at Github](https://github.com/apache/arrow/releases). You must ensure that Arrow R package is installed and available on all cluster nodes. -The current supported minimum version is 0.12.1; however, this might change between the minor releases since Arrow optimization in SparkR is experimental. +Note that you must ensure that Arrow R package is installed and available on all cluster nodes. +The current supported minimum version is 0.15.1; however, this might change between the minor releases since Arrow optimization in SparkR is experimental. ## Enabling for Conversion to/from R DataFrame, `dapply` and `gapply` diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 699f9acc8c50e..2272c90384847 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -40,8 +40,6 @@ license: | ### DDL Statements - - In Spark 3.0, `CREATE TABLE` without a specific provider uses the value of `spark.sql.sources.default` as its provider. In Spark version 2.4 and below, it was Hive. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.createHiveTableByDefault.enabled` to `true`. - - In Spark 3.0, when inserting a value into a table column with a different data type, the type coercion is performed as per ANSI SQL standard. Certain unreasonable type conversions such as converting `string` to `int` and `double` to `boolean` are disallowed. A runtime exception is thrown if the value is out-of-range for the data type of the column. In Spark version 2.4 and below, type conversions during table insertion are allowed as long as they are valid `Cast`. When inserting an out-of-range value to an integral field, the low-order bits of the value is inserted(the same as Java/Scala numeric type casting). For example, if 257 is inserted to a field of byte type, the result is 1. The behavior is controlled by the option `spark.sql.storeAssignmentPolicy`, with a default value as "ANSI". Setting the option as "Legacy" restores the previous behavior. - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. @@ -963,3 +961,4 @@ Below are the scenarios in which Hive and Spark generate different results: * `SQRT(n)` If n < 0, Hive returns null, Spark SQL returns NaN. * `ACOS(n)` If n < -1 or n > 1, Hive returns null, Spark SQL returns NaN. * `ASIN(n)` If n < -1 or n > 1, Hive returns null, Spark SQL returns NaN. +* `CAST(n AS TIMESTAMP)` If n is integral numbers, Hive treats n as milliseconds, Spark SQL treats n as seconds. diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md index 7cd85b6a9ab4c..5e6f049a51e95 100644 --- a/docs/sql-performance-tuning.md +++ b/docs/sql-performance-tuning.md @@ -179,7 +179,7 @@ SELECT /*+ BROADCAST(r) */ * FROM records r JOIN src s ON r.key = s.key -For more details please refer to the documentation of [Join Hints](sql-ref-syntax-qry-select-hints.html). 
+For more details please refer to the documentation of [Join Hints](sql-ref-syntax-qry-select-hints.html#join-hints). ## Coalesce Hints for SQL Queries @@ -196,6 +196,8 @@ The "REPARTITION_BY_RANGE" hint must have column names and a partition number is SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t SELECT /*+ REPARTITION_BY_RANGE(3, c) */ * FROM t +For more details please refer to the documentation of [Partitioning Hints](sql-ref-syntax-qry-select-hints.html#partitioning-hints). + ## Adaptive Query Execution Adaptive Query Execution (AQE) is an optimization technique in Spark SQL that makes use of the runtime statistics to choose the most efficient query execution plan. AQE is disabled by default. Spark SQL can use the umbrella configuration of `spark.sql.adaptive.enabled` to control whether turn it on/off. As of Spark 3.0, there are three major features in AQE, including coalescing post-shuffle partitions, converting sort-merge join to broadcast join, and skew join optimization. diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 93fb10b24e99f..eab194c71ec79 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -41,10 +41,10 @@ This means that in case an operation causes overflows, the result is the same wi On the other hand, Spark SQL returns null for decimal overflows. When `spark.sql.ansi.enabled` is set to `true` and an overflow occurs in numeric and interval arithmetic operations, it throws an arithmetic exception at runtime. -{% highlight sql %} +```sql -- `spark.sql.ansi.enabled=true` SELECT 2147483647 + 1; - java.lang.ArithmeticException: integer overflow +java.lang.ArithmeticException: integer overflow -- `spark.sql.ansi.enabled=false` SELECT 2147483647 + 1; @@ -53,7 +53,7 @@ SELECT 2147483647 + 1; +----------------+ | -2147483648| +----------------+ -{% endhighlight %} +``` ### Type Conversion @@ -64,15 +64,15 @@ On the other hand, `INSERT INTO` syntax throws an analysis exception when the AN Currently, the ANSI mode affects explicit casting and assignment casting only. In future releases, the behaviour of type coercion might change along with the other two type conversion rules. 
-{% highlight sql %} +```sql -- Examples of explicit casting -- `spark.sql.ansi.enabled=true` SELECT CAST('a' AS INT); - java.lang.NumberFormatException: invalid input syntax for type numeric: a +java.lang.NumberFormatException: invalid input syntax for type numeric: a SELECT CAST(2147483648L AS INT); - java.lang.ArithmeticException: Casting 2147483648 to int causes overflow +java.lang.ArithmeticException: Casting 2147483648 to int causes overflow -- `spark.sql.ansi.enabled=false` (This is a default behaviour) SELECT CAST('a' AS INT); @@ -94,8 +94,8 @@ CREATE TABLE t (v INT); -- `spark.sql.storeAssignmentPolicy=ANSI` INSERT INTO t VALUES ('1'); - org.apache.spark.sql.AnalysisException: Cannot write incompatible data to table '`default`.`t`': - - Cannot safely cast 'v': StringType to IntegerType; +org.apache.spark.sql.AnalysisException: Cannot write incompatible data to table '`default`.`t`': +- Cannot safely cast 'v': string to int; -- `spark.sql.storeAssignmentPolicy=LEGACY` (This is a legacy behaviour until Spark 2.x) INSERT INTO t VALUES ('1'); @@ -105,7 +105,7 @@ SELECT * FROM t; +---+ | 1| +---+ -{% endhighlight %} +``` ### SQL Functions diff --git a/docs/sql-ref-datatypes.md b/docs/sql-ref-datatypes.md index 3f6b6b590f843..f27f1a0ca967f 100644 --- a/docs/sql-ref-datatypes.md +++ b/docs/sql-ref-datatypes.md @@ -240,7 +240,7 @@ Specifically: #### Examples -{% highlight sql %} +```sql SELECT double('infinity') AS col; +--------+ | col| @@ -313,4 +313,4 @@ SELECT COUNT(*), c2 FROM test GROUP BY c2; | 2|-Infinity| | 3| Infinity| +---------+---------+ -{% endhighlight %} \ No newline at end of file +``` diff --git a/docs/sql-ref-datetime-pattern.md b/docs/sql-ref-datetime-pattern.md index b65bb1319fb9d..5859ad82525f2 100644 --- a/docs/sql-ref-datetime-pattern.md +++ b/docs/sql-ref-datetime-pattern.md @@ -30,25 +30,25 @@ Spark uses pattern letters in the following table for date and timestamp parsing |Symbol|Meaning|Presentation|Examples| |------|-------|------------|--------| -|**G**|era|text|AD; Anno Domini; A| +|**G**|era|text|AD; Anno Domini| |**y**|year|year|2020; 20| -|**D**|day-of-year|number|189| -|**M/L**|month-of-year|number/text|7; 07; Jul; July; J| -|**d**|day-of-month|number|28| +|**D**|day-of-year|number(3)|189| +|**M/L**|month-of-year|month|7; 07; Jul; July| +|**d**|day-of-month|number(3)|28| |**Q/q**|quarter-of-year|number/text|3; 03; Q3; 3rd quarter| |**Y**|week-based-year|year|1996; 96| -|**w**|week-of-week-based-year|number|27| -|**W**|week-of-month|number|4| -|**E**|day-of-week|text|Tue; Tuesday; T| -|**u**|localized day-of-week|number/text|2; 02; Tue; Tuesday; T| -|**F**|week-of-month|number|3| -|**a**|am-pm-of-day|text|PM| -|**h**|clock-hour-of-am-pm (1-12)|number|12| -|**K**|hour-of-am-pm (0-11)|number|0| -|**k**|clock-hour-of-day (1-24)|number|0| -|**H**|hour-of-day (0-23)|number|0| -|**m**|minute-of-hour|number|30| -|**s**|second-of-minute|number|55| +|**w**|week-of-week-based-year|number(2)|27| +|**W**|week-of-month|number(1)|4| +|**E**|day-of-week|text|Tue; Tuesday| +|**u**|localized day-of-week|number/text|2; 02; Tue; Tuesday| +|**F**|week-of-month|number(1)|3| +|**a**|am-pm-of-day|am-pm|PM| +|**h**|clock-hour-of-am-pm (1-12)|number(2)|12| +|**K**|hour-of-am-pm (0-11)|number(2)|0| +|**k**|clock-hour-of-day (1-24)|number(2)|0| +|**H**|hour-of-day (0-23)|number(2)|0| +|**m**|minute-of-hour|number(2)|30| +|**s**|second-of-minute|number(2)|55| |**S**|fraction-of-second|fraction|978| |**V**|time-zone ID|zone-id|America/Los_Angeles; Z; -08:30| |**z**|time-zone 
name|zone-name|Pacific Standard Time; PST| @@ -63,9 +63,9 @@ Spark uses pattern letters in the following table for date and timestamp parsing The count of pattern letters determines the format. -- Text: The text style is determined based on the number of pattern letters used. Less than 4 pattern letters will use the short form. Exactly 4 pattern letters will use the full form. Exactly 5 pattern letters will use the narrow form. Six or more letters will fail. +- Text: The text style is determined based on the number of pattern letters used. Less than 4 pattern letters will use the short form. Exactly 4 pattern letters will use the full form. Exactly 5 pattern letters will use the narrow form. 5 or more letters will fail. -- Number: If the count of letters is one, then the value is output using the minimum number of digits and without padding. Otherwise, the count of digits is used as the width of the output field, with the value zero-padded as necessary. The following pattern letters have constraints on the count of letters. Only one letter 'F' can be specified. Up to two letters of 'd', 'H', 'h', 'K', 'k', 'm', and 's' can be specified. Up to three letters of 'D' can be specified. +- Number(n): The n here represents the maximum count of letters this type of datetime pattern can be used. If the count of letters is one, then the value is output using the minimum number of digits and without padding. Otherwise, the count of digits is used as the width of the output field, with the value zero-padded as necessary. - Number/Text: If the count of pattern letters is 3 or greater, use the Text rules above. Otherwise use the Number rules above. @@ -74,7 +74,53 @@ The count of pattern letters determines the format. For formatting, the fraction length would be padded to the number of contiguous 'S' with zeros. Spark supports datetime of micro-of-second precision, which has up to 6 significant digits, but can parse nano-of-second with exceeded part truncated. -- Year: The count of letters determines the minimum field width below which padding is used. If the count of letters is two, then a reduced two digit form is used. For printing, this outputs the rightmost two digits. For parsing, this will parse using the base value of 2000, resulting in a year within the range 2000 to 2099 inclusive. If the count of letters is less than four (but not two), then the sign is only output for negative years. Otherwise, the sign is output if the pad width is exceeded when 'G' is not present. +- Year: The count of letters determines the minimum field width below which padding is used. If the count of letters is two, then a reduced two digit form is used. For printing, this outputs the rightmost two digits. For parsing, this will parse using the base value of 2000, resulting in a year within the range 2000 to 2099 inclusive. If the count of letters is less than four (but not two), then the sign is only output for negative years. Otherwise, the sign is output if the pad width is exceeded when 'G' is not present. 11 or more letters will fail. + +- Month: It follows the rule of Number/Text. The text form is depend on letters - 'M' denotes the 'standard' form, and 'L' is for 'stand-alone' form. These two forms are different only in some certain languages. For example, in Russian, 'Июль' is the stand-alone form of July, and 'Июля' is the standard form. Here are examples for all supported pattern letters: + - `'M'` or `'L'`: Month number in a year starting from 1. There is no difference between 'M' and 'L'. 
Month from 1 to 9 are printed without padding. + ```sql + spark-sql> select date_format(date '1970-01-01', "M"); + 1 + spark-sql> select date_format(date '1970-12-01', "L"); + 12 + ``` + - `'MM'` or `'LL'`: Month number in a year starting from 1. Zero padding is added for month 1-9. + ```sql + spark-sql> select date_format(date '1970-1-01', "LL"); + 01 + spark-sql> select date_format(date '1970-09-01', "MM"); + 09 + ``` + - `'MMM'`: Short textual representation in the standard form. The month pattern should be a part of a date pattern not just a stand-alone month except locales where there is no difference between stand and stand-alone forms like in English. + ```sql + spark-sql> select date_format(date '1970-01-01', "d MMM"); + 1 Jan + spark-sql> select to_csv(named_struct('date', date '1970-01-01'), map('dateFormat', 'dd MMM', 'locale', 'RU')); + 01 янв. + ``` + - `'LLL'`: Short textual representation in the stand-alone form. It should be used to format/parse only months without any other date fields. + ```sql + spark-sql> select date_format(date '1970-01-01', "LLL"); + Jan + spark-sql> select to_csv(named_struct('date', date '1970-01-01'), map('dateFormat', 'LLL', 'locale', 'RU')); + янв. + ``` + - `'MMMM'`: full textual month representation in the standard form. It is used for parsing/formatting months as a part of dates/timestamps. + ```sql + spark-sql> select date_format(date '1970-01-01', "d MMMM"); + 1 January + spark-sql> select to_csv(named_struct('date', date '1970-01-01'), map('dateFormat', 'd MMMM', 'locale', 'RU')); + 1 января + ``` + - `'LLLL'`: full textual month representation in the stand-alone form. The pattern can be used to format/parse only months. + ```sql + spark-sql> select date_format(date '1970-01-01', "LLLL"); + January + spark-sql> select to_csv(named_struct('date', date '1970-01-01'), map('dateFormat', 'LLLL', 'locale', 'RU')); + январь + ``` + +- am-pm: This outputs the am-pm-of-day. Pattern letter count must be 1. - Zone ID(V): This outputs the display the time-zone ID. Pattern letter count must be 2. @@ -90,11 +136,11 @@ The count of pattern letters determines the format. During formatting, all valid data will be output even it is in the optional section. During parsing, the whole section may be missing from the parsed string. An optional section is started by `[` and ended using `]` (or at the end of the pattern). + +- Symbols of 'Y', 'W', 'w', 'E', 'u', 'F', 'q' and 'Q' can only be used for datetime formatting, e.g. `date_format`. They are not allowed used for datetime parsing, e.g. `to_timestamp`. More details for the text style: - Short Form: Short text, typically an abbreviation. For example, day-of-week Monday might output "Mon". - Full Form: Full text, typically the full description. For example, day-of-week Monday might output "Monday". - -- Narrow Form: Narrow text, typically a single letter. For example, day-of-week Monday might output "M". 
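+
+As a small illustration of the month patterns and of the formatting-only symbols described above
+(a sketch only, assuming an active `SparkSession` named `spark` and the default English locale),
+the same behaviour can be exercised from the DataFrame API:
+
+```scala
+import org.apache.spark.sql.functions.{col, date_format}
+
+val df = spark.sql("SELECT date '1970-01-01' AS d")
+
+// 'MMMM' is the standard full text form, 'LLLL' the stand-alone form.
+df.select(
+  date_format(col("d"), "d MMMM"),  // 1 January
+  date_format(col("d"), "LLLL")     // January
+).show()
+
+// 'E' (day-of-week) is a formatting-only symbol: date_format accepts it,
+// while parsing functions such as to_timestamp reject patterns containing it.
+df.select(date_format(col("d"), "E, d MMM yyyy")).show()  // Thu, 1 Jan 1970
+```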
diff --git a/docs/sql-ref-functions-builtin.md b/docs/sql-ref-functions-builtin.md index 1bca68e5f19df..cabb83e09fde9 100644 --- a/docs/sql-ref-functions-builtin.md +++ b/docs/sql-ref-functions-builtin.md @@ -70,7 +70,7 @@ license: | ### JSON Functions {% include_relative generated-json-funcs-table.html %} #### Examples -{% include_relative generated-agg-funcs-examples.html %} +{% include_relative generated-json-funcs-examples.html %} {% break %} {% endif %} {% endfor %} diff --git a/docs/sql-ref-functions-udf-aggregate.md b/docs/sql-ref-functions-udf-aggregate.md index 3fde94d6bc4bf..da3182149410b 100644 --- a/docs/sql-ref-functions-udf-aggregate.md +++ b/docs/sql-ref-functions-udf-aggregate.md @@ -27,46 +27,35 @@ User-Defined Aggregate Functions (UDAFs) are user-programmable routines that act A base class for user-defined aggregations, which can be used in Dataset operations to take all of the elements of a group and reduce them to a single value. - * IN - The input type for the aggregation. - * BUF - The type of the intermediate value of the reduction. - * OUT - The type of the final output result. + ***IN*** - The input type for the aggregation. + + ***BUF*** - The type of the intermediate value of the reduction. + + ***OUT*** - The type of the final output result. + +* **bufferEncoder: Encoder[BUF]** -
-
bufferEncoder: Encoder[BUF]
-
Specifies the Encoder for the intermediate value type. -
-
-
-
finish(reduction: BUF): OUT
-
+ +* **finish(reduction: BUF): OUT** + Transform the output of the reduction. -
-
-
-
merge(b1: BUF, b2: BUF): BUF
-
+ +* **merge(b1: BUF, b2: BUF): BUF** + Merge two intermediate values. -
-
-
-
outputEncoder: Encoder[OUT]
-
+ +* **outputEncoder: Encoder[OUT]** + Specifies the Encoder for the final output value type. -
-
-
-
reduce(b: BUF, a: IN): BUF
-
- Aggregate input value a into current intermediate value. For performance, the function may modify b and return it instead of constructing new object for b. -
-
-
-
zero: BUF
-
+ +* **reduce(b: BUF, a: IN): BUF** + + Aggregate input value `a` into current intermediate value. For performance, the function may modify `b` and return it instead of constructing new object for `b`. + +* **zero: BUF** + The initial value of the intermediate result for this aggregation. -
-
### Examples @@ -95,16 +84,16 @@ For example, a user-defined average for untyped DataFrames can look like: {% include_example untyped_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java%}
-{% highlight sql %} +```sql -- Compile and place UDAF MyAverage in a JAR file called `MyAverage.jar` in /tmp. CREATE FUNCTION myAverage AS 'MyAverage' USING JAR '/tmp/MyAverage.jar'; SHOW USER FUNCTIONS; --- +------------------+ --- | function| --- +------------------+ --- | default.myAverage| --- +------------------+ ++------------------+ +| function| ++------------------+ +| default.myAverage| ++------------------+ CREATE TEMPORARY VIEW employees USING org.apache.spark.sql.json @@ -113,26 +102,26 @@ OPTIONS ( ); SELECT * FROM employees; --- +-------+------+ --- | name|salary| --- +-------+------+ --- |Michael| 3000| --- | Andy| 4500| --- | Justin| 3500| --- | Berta| 4000| --- +-------+------+ ++-------+------+ +| name|salary| ++-------+------+ +|Michael| 3000| +| Andy| 4500| +| Justin| 3500| +| Berta| 4000| ++-------+------+ SELECT myAverage(salary) as average_salary FROM employees; --- +--------------+ --- |average_salary| --- +--------------+ --- | 3750.0| --- +--------------+ -{% endhighlight %} ++--------------+ +|average_salary| ++--------------+ +| 3750.0| ++--------------+ +```
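+
+The `CREATE FUNCTION ... USING JAR` route above is one way to expose an aggregate to SQL. As a
+hedged sketch of an alternative (assuming Spark 3.0 or later, where `org.apache.spark.sql.functions.udaf`
+is available, an active `SparkSession` named `spark`, and the `employees` view created above), a typed
+`Aggregator` such as the `MyAverage` one used in the Scala example can be registered directly:
+
+```scala
+import org.apache.spark.sql.{Encoder, Encoders}
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.functions
+
+case class Average(var sum: Long, var count: Long)
+
+object MyAverage extends Aggregator[Long, Average, Double] {
+  // The initial value of the intermediate result (zero).
+  def zero: Average = Average(0L, 0L)
+  // Fold one input value into the intermediate buffer.
+  def reduce(buffer: Average, data: Long): Average = {
+    buffer.sum += data; buffer.count += 1; buffer
+  }
+  // Merge two intermediate buffers.
+  def merge(b1: Average, b2: Average): Average = {
+    b1.sum += b2.sum; b1.count += b2.count; b1
+  }
+  // Transform the reduced buffer into the final output.
+  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
+  def bufferEncoder: Encoder[Average] = Encoders.product
+  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
+}
+
+// Register the Aggregator as an untyped UDAF so it can be called from SQL.
+spark.udf.register("myAverage", functions.udaf(MyAverage))
+spark.sql("SELECT myAverage(salary) AS average_salary FROM employees").show()
+```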
### Related Statements - * [Scalar User Defined Functions (UDFs)](sql-ref-functions-udf-scalar.html) - * [Integration with Hive UDFs/UDAFs/UDTFs](sql-ref-functions-udf-hive.html) +* [Scalar User Defined Functions (UDFs)](sql-ref-functions-udf-scalar.html) +* [Integration with Hive UDFs/UDAFs/UDTFs](sql-ref-functions-udf-hive.html) diff --git a/docs/sql-ref-functions-udf-hive.md b/docs/sql-ref-functions-udf-hive.md index 7a7129de23836..819c446c411d2 100644 --- a/docs/sql-ref-functions-udf-hive.md +++ b/docs/sql-ref-functions-udf-hive.md @@ -28,7 +28,7 @@ Spark SQL supports integration of Hive UDFs, UDAFs and UDTFs. Similar to Spark U Hive has two UDF interfaces: [UDF](https://github.com/apache/hive/blob/master/udf/src/java/org/apache/hadoop/hive/ql/exec/UDF.java) and [GenericUDF](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java). An example below uses [GenericUDFAbs](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java) derived from `GenericUDF`. -{% highlight sql %} +```sql -- Register `GenericUDFAbs` and use it in Spark SQL. -- Note that, if you use your own programmed one, you need to add a JAR containing it -- into a classpath, @@ -52,12 +52,12 @@ SELECT testUDF(value) FROM t; | 2.0| | 3.0| +--------------+ -{% endhighlight %} +``` An example below uses [GenericUDTFExplode](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTFExplode.java) derived from [GenericUDTF](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java). -{% highlight sql %} +```sql -- Register `GenericUDTFExplode` and use it in Spark SQL CREATE TEMPORARY FUNCTION hiveUDTF AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDTFExplode'; @@ -79,12 +79,12 @@ SELECT hiveUDTF(value) FROM t; | 3| | 4| +---+ -{% endhighlight %} +``` Hive has two UDAF interfaces: [UDAF](https://github.com/apache/hive/blob/master/udf/src/java/org/apache/hadoop/hive/ql/exec/UDAF.java) and [GenericUDAFResolver](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFResolver.java). An example below uses [GenericUDAFSum](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java) derived from `GenericUDAFResolver`. -{% highlight sql %} +```sql -- Register `GenericUDAFSum` and use it in Spark SQL CREATE TEMPORARY FUNCTION hiveUDAF AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum'; @@ -105,4 +105,4 @@ SELECT key, hiveUDAF(value) FROM t GROUP BY key; | b| 3| | a| 3| +---+---------------+ -{% endhighlight %} +``` \ No newline at end of file diff --git a/docs/sql-ref-functions-udf-scalar.md b/docs/sql-ref-functions-udf-scalar.md index 2cb25f275cb59..97f5a89d3fb19 100644 --- a/docs/sql-ref-functions-udf-scalar.md +++ b/docs/sql-ref-functions-udf-scalar.md @@ -26,24 +26,18 @@ User-Defined Functions (UDFs) are user-programmable routines that act on one row ### UserDefinedFunction To define the properties of a user-defined function, the user can use some of the methods defined in this class. -
-
asNonNullable(): UserDefinedFunction
-
+ +* **asNonNullable(): UserDefinedFunction** + Updates UserDefinedFunction to non-nullable. -
-
-
-
asNondeterministic(): UserDefinedFunction
-
+ +* **asNondeterministic(): UserDefinedFunction** + Updates UserDefinedFunction to nondeterministic. -
-
-
-
withName(name: String): UserDefinedFunction
-
+ +* **withName(name: String): UserDefinedFunction** + Updates UserDefinedFunction with a given name. -
-
### Examples @@ -57,5 +51,5 @@ To define the properties of a user-defined function, the user can use some of th ### Related Statements - * [User Defined Aggregate Functions (UDAFs)](sql-ref-functions-udf-aggregate.html) - * [Integration with Hive UDFs/UDAFs/UDTFs](sql-ref-functions-udf-hive.html) +* [User Defined Aggregate Functions (UDAFs)](sql-ref-functions-udf-aggregate.html) +* [Integration with Hive UDFs/UDAFs/UDTFs](sql-ref-functions-udf-hive.html) diff --git a/docs/sql-ref-identifier.md b/docs/sql-ref-identifier.md index 89cde21e6fdb6..f65d491cc2fc4 100644 --- a/docs/sql-ref-identifier.md +++ b/docs/sql-ref-identifier.md @@ -27,41 +27,34 @@ An identifier is a string used to identify a database object such as a table, vi #### Regular Identifier -{% highlight sql %} +```sql { letter | digit | '_' } [ , ... ] -{% endhighlight %} -Note: If `spark.sql.ansi.enabled` is set to true, ANSI SQL reserved keywords cannot be used as identifiers. For more details, please refer to [ANSI Compliance](sql-ref-ansi-compliance.html). +``` +**Note:** If `spark.sql.ansi.enabled` is set to true, ANSI SQL reserved keywords cannot be used as identifiers. For more details, please refer to [ANSI Compliance](sql-ref-ansi-compliance.html). #### Delimited Identifier -{% highlight sql %} +```sql `c [ ... ]` -{% endhighlight %} +``` ### Parameters -
-
letter
-
+* **letter** + Any letter from A-Z or a-z. -
-
-
-
digit
-
+ +* **digit** + Any numeral from 0 to 9. -
-
-
-
c
-
+ +* **c** + Any character from the character set. Use ` to escape special characters (e.g., `). -
-
### Examples -{% highlight sql %} +```sql -- This CREATE TABLE fails with ParseException because of the illegal identifier name a.b CREATE TABLE test (a.b int); org.apache.spark.sql.catalyst.parser.ParseException: @@ -77,4 +70,4 @@ no viable alternative at input 'CREATE TABLE test (`a`b`'(line 1, pos 23) -- This CREATE TABLE works CREATE TABLE test (`a``b` int); -{% endhighlight %} +``` diff --git a/docs/sql-ref-literals.md b/docs/sql-ref-literals.md index 0088f79cb7007..b83f7f0a97c24 100644 --- a/docs/sql-ref-literals.md +++ b/docs/sql-ref-literals.md @@ -35,22 +35,19 @@ A string literal is used to specify a character string value. #### Syntax -{% highlight sql %} -'c [ ... ]' | "c [ ... ]" -{% endhighlight %} +```sql +'char [ ... ]' | "char [ ... ]" +``` -#### Parameters +#### Parameters + +* **char** -
-
c
-
- One character from the character set. Use \ to escape special characters (e.g., ' or \). -
-
+ One character from the character set. Use `\` to escape special characters (e.g., `'` or `\`). -#### Examples +#### Examples -{% highlight sql %} +```sql SELECT 'Hello, World!' AS col; +-------------+ | col| @@ -71,7 +68,7 @@ SELECT 'it\'s $10.' AS col; +---------+ |It's $10.| +---------+ -{% endhighlight %} +``` ### Binary Literal @@ -79,29 +76,26 @@ A binary literal is used to specify a byte sequence value. #### Syntax -{% highlight sql %} -X { 'c [ ... ]' | "c [ ... ]" } -{% endhighlight %} +```sql +X { 'num [ ... ]' | "num [ ... ]" } +``` + +#### Parameters -#### Parameters +* **num** -
-
c
-
- One character from the character set. -
-
+ Any hexadecimal number from 0 to F. -#### Examples +#### Examples -{% highlight sql %} +```sql SELECT X'123456' AS col; +----------+ | col| +----------+ |[12 34 56]| +----------+ -{% endhighlight %} +``` ### Null Literal @@ -109,20 +103,20 @@ A null literal is used to specify a null value. #### Syntax -{% highlight sql %} +```sql NULL -{% endhighlight %} +``` #### Examples -{% highlight sql %} +```sql SELECT NULL AS col; +----+ | col| +----+ |NULL| +----+ -{% endhighlight %} +``` ### Boolean Literal @@ -130,20 +124,20 @@ A boolean literal is used to specify a boolean value. #### Syntax -{% highlight sql %} +```sql TRUE | FALSE -{% endhighlight %} +``` #### Examples -{% highlight sql %} +```sql SELECT TRUE AS col; +----+ | col| +----+ |true| +----+ -{% endhighlight %} +``` ### Numeric Literal @@ -151,48 +145,37 @@ A numeric literal is used to specify a fixed or floating-point number. #### Integral Literal -#### Syntax +##### Syntax -{% highlight sql %} +```sql [ + | - ] digit [ ... ] [ L | S | Y ] -{% endhighlight %} +``` -#### Parameters +##### Parameters + +* **digit** -
-
digit
-
Any numeral from 0 to 9. -
-
-
-
L
-
- Case insensitive, indicates BIGINT, which is a 8-byte signed integer number. -
-
-
-
S
-
- Case insensitive, indicates SMALLINT, which is a 2-byte signed integer number. -
-
-
-
Y
-
- Case insensitive, indicates TINYINT, which is a 1-byte signed integer number. -
-
-
-
default (no postfix)
-
+ +* **L** + + Case insensitive, indicates `BIGINT`, which is an 8-byte signed integer number. + +* **S** + + Case insensitive, indicates `SMALLINT`, which is a 2-byte signed integer number. + +* **Y** + + Case insensitive, indicates `TINYINT`, which is a 1-byte signed integer number. + +* **default (no postfix)** + Indicates a 4-byte signed integer number. -
-
-#### Examples +##### Examples -{% highlight sql %} +```sql SELECT -2147483648 AS col; +-----------+ | col| @@ -220,56 +203,49 @@ SELECT 482S AS col; +---+ |482| +---+ -{% endhighlight %} +``` #### Fractional Literals -#### Syntax +##### Syntax decimal literals: -{% highlight sql %} +```sql decimal_digits { [ BD ] | [ exponent BD ] } | digit [ ... ] [ exponent ] BD -{% endhighlight %} +``` double literals: -{% highlight sql %} +```sql decimal_digits { D | exponent [ D ] } | digit [ ... ] { exponent [ D ] | [ exponent ] D } -{% endhighlight %} +``` While decimal_digits is defined as -{% highlight sql %} +```sql [ + | - ] { digit [ ... ] . [ digit [ ... ] ] | . digit [ ... ] } -{% endhighlight %} +``` and exponent is defined as -{% highlight sql %} +```sql E [ + | - ] digit [ ... ] -{% endhighlight %} +``` -#### Parameters +##### Parameters + +* **digit** -
-
digit
-
Any numeral from 0 to 9. -
-
-
-
D
-
- Case insensitive, indicates DOUBLE, which is a 8-byte double-precision floating point number. -
-
-
-
BD
-
- Case insensitive, indicates DECIMAL, with the total number of digits as precision and the number of digits to right of decimal point as scale. -
-
-#### Examples +* **D** + + Case insensitive, indicates `DOUBLE`, which is an 8-byte double-precision floating point number. + +* **BD** + + Case insensitive, indicates `DECIMAL`, with the total number of digits as precision and the number of digits to right of decimal point as scale. -{% highlight sql %} +##### Examples + +```sql SELECT 12.578 AS col; +------+ | col| @@ -353,7 +329,7 @@ SELECT -3.E-3D AS col; +------+ |-0.003| +------+ -{% endhighlight %} +``` ### Datetime Literal @@ -361,19 +337,19 @@ A Datetime literal is used to specify a datetime value. #### Date Literal -#### Syntax +##### Syntax -{% highlight sql %} +```sql DATE { 'yyyy' | 'yyyy-[m]m' | 'yyyy-[m]m-[d]d' | 'yyyy-[m]m-[d]d[T]' } -{% endhighlight %} -Note: defaults to 01 if month or day is not specified. +``` +**Note:** defaults to `01` if month or day is not specified. -#### Examples +##### Examples -{% highlight sql %} +```sql SELECT DATE '1997' AS col; +----------+ | col| @@ -394,13 +370,13 @@ SELECT DATE '2011-11-11' AS col; +----------+ |2011-11-11| +----------+ -{% endhighlight %} +``` #### Timestamp Literal -#### Syntax +##### Syntax -{% highlight sql %} +```sql TIMESTAMP { 'yyyy' | 'yyyy-[m]m' | 'yyyy-[m]m-[d]d' | @@ -409,27 +385,23 @@ TIMESTAMP { 'yyyy' | 'yyyy-[m]m-[d]d[T][h]h:[m]m[:]' | 'yyyy-[m]m-[d]d[T][h]h:[m]m:[s]s[.]' | 'yyyy-[m]m-[d]d[T][h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]'} -{% endhighlight %} -Note: defaults to 00 if hour, minute or second is not specified.

+``` +**Note:** defaults to `00` if hour, minute or second is not specified. `zone_id` should have one of the forms: - -Note: defaults to the session local timezone (set via spark.sql.session.timeZone) if zone_id is not specified. +* Z - Zulu time zone UTC+0 +* `+|-[h]h:[m]m` +* An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-, and a suffix in the formats: + * `+|-h[h]` + * `+|-hh[:]mm` + * `+|-hh:mm:ss` + * `+|-hhmmss` +* Region-based zone IDs in the form `area/city`, such as `Europe/Paris` -#### Examples +**Note:** defaults to the session local timezone (set via `spark.sql.session.timeZone`) if `zone_id` is not specified. -{% highlight sql %} +##### Examples + +```sql SELECT TIMESTAMP '1997-01-31 09:26:56.123' AS col; +-----------------------+ | col| @@ -450,50 +422,40 @@ SELECT TIMESTAMP '1997-01' AS col; +-------------------+ |1997-01-01 00:00:00| +-------------------+ -{% endhighlight %} +``` ### Interval Literal An interval literal is used to specify a fixed period of time. -#### Syntax -{% highlight sql %} -{ INTERVAL interval_value interval_unit [ interval_value interval_unit ... ] | - INTERVAL 'interval_value interval_unit [ interval_value interval_unit ... ]' | - INTERVAL interval_string_value interval_unit TO interval_unit } -{% endhighlight %} +```sql +INTERVAL interval_value interval_unit [ interval_value interval_unit ... ] | +INTERVAL 'interval_value interval_unit [ interval_value interval_unit ... ]' | +INTERVAL interval_string_value interval_unit TO interval_unit +``` #### Parameters -
-
interval_value
-
- Syntax: - - [ + | - ] number_value | '[ + | - ] number_value' -
-
-
-
-
interval_string_value
-
- year-month/day-time interval string. -
-
-
-
interval_unit
-
- Syntax:
- - YEAR[S] | MONTH[S] | WEEK[S] | DAY[S] | HOUR[S] | MINUTE[S] | SECOND[S] |
- MILLISECOND[S] | MICROSECOND[S] -
-
-
+* **interval_value** + + **Syntax:** + + [ + | - ] number_value | '[ + | - ] number_value' + +* **interval_string_value** + + year-month/day-time interval string. + +* **interval_unit** + + **Syntax:** + + YEAR[S] | MONTH[S] | WEEK[S] | DAY[S] | HOUR[S] | MINUTE[S] | SECOND[S] | + MILLISECOND[S] | MICROSECOND[S] #### Examples -{% highlight sql %} +```sql SELECT INTERVAL 3 YEAR AS col; +-------+ | col| @@ -536,4 +498,4 @@ SELECT INTERVAL '20 15:40:32.99899999' DAY TO SECOND AS col; +---------------------------------------------+ |20 days 15 hours 40 minutes 32.998999 seconds| +---------------------------------------------+ -{% endhighlight %} +``` diff --git a/docs/sql-ref-null-semantics.md b/docs/sql-ref-null-semantics.md index 56b5cde630eaf..fb5d2a312d0e1 100644 --- a/docs/sql-ref-null-semantics.md +++ b/docs/sql-ref-null-semantics.md @@ -79,7 +79,7 @@ one or both operands are `NULL`: ### Examples -{% highlight sql %} +```sql -- Normal comparison operators return `NULL` when one of the operand is `NULL`. SELECT 5 > null AS expression_output; +-----------------+ @@ -111,7 +111,7 @@ SELECT NULL <=> NULL; +-----------------+ | true| +-----------------+ -{% endhighlight %} +``` ### Logical Operators @@ -134,7 +134,7 @@ The following tables illustrate the behavior of logical operators when one or bo ### Examples -{% highlight sql %} +```sql -- Normal comparison operators return `NULL` when one of the operands is `NULL`. SELECT (true OR null) AS expression_output; +-----------------+ @@ -158,7 +158,7 @@ SELECT NOT(null) AS expression_output; +-----------------+ | null| +-----------------+ -{% endhighlight %} +``` ### Expressions @@ -177,7 +177,7 @@ expression are `NULL` and most of the expressions fall in this category. ##### Examples -{% highlight sql %} +```sql SELECT concat('John', null) AS expression_output; +-----------------+ |expression_output| @@ -198,7 +198,7 @@ SELECT to_date(null) AS expression_output; +-----------------+ | null| +-----------------+ -{% endhighlight %} +``` #### Expressions That Can Process Null Value Operands @@ -221,7 +221,7 @@ returns the first non `NULL` value in its list of operands. However, `coalesce` ##### Examples -{% highlight sql %} +```sql SELECT isnull(null) AS expression_output; +-----------------+ |expression_output| @@ -251,7 +251,7 @@ SELECT isnan(null) AS expression_output; +-----------------+ | false| +-----------------+ -{% endhighlight %} +``` #### Builtin Aggregate Expressions @@ -271,7 +271,7 @@ the rules of how `NULL` values are handled by aggregate functions. #### Examples -{% highlight sql %} +```sql -- `count(*)` does not skip `NULL` values. SELECT count(*) FROM person; +--------+ @@ -312,7 +312,7 @@ SELECT max(age) FROM person where 1 = 0; +--------+ | null| +--------+ -{% endhighlight %} +``` ### Condition Expressions in WHERE, HAVING and JOIN Clauses @@ -323,7 +323,7 @@ For all the three operators, a condition expression is a boolean expression and #### Examples -{% highlight sql %} +```sql -- Persons whose age is unknown (`NULL`) are filtered out from the result set. SELECT * FROM person WHERE age > 0; +--------+---+ @@ -391,7 +391,7 @@ SELECT * FROM person p1, person p2 | Marry|null| Marry|null| | Joe| 30| Joe| 30| +--------+----+--------+----+ -{% endhighlight %} +``` ### Aggregate Operator (GROUP BY, DISTINCT) @@ -402,7 +402,7 @@ standard and with other enterprise database management systems. #### Examples -{% highlight sql %} +```sql -- `NULL` values are put in one bucket in `GROUP BY` processing. 
SELECT age, count(*) FROM person GROUP BY age; +----+--------+ @@ -424,7 +424,7 @@ SELECT DISTINCT age FROM person; | 30| | 18| +----+ -{% endhighlight %} +``` ### Sort Operator (ORDER BY Clause) @@ -434,7 +434,7 @@ the `NULL` values are placed at first. #### Examples -{% highlight sql %} +```sql -- `NULL` values are shown at first and other values -- are sorted in ascending way. SELECT age, name FROM person ORDER BY age; @@ -479,7 +479,7 @@ SELECT age, name FROM person ORDER BY age DESC NULLS LAST; |null| Marry| |null| Albert| +----+--------+ -{% endhighlight %} +``` ### Set Operators (UNION, INTERSECT, EXCEPT) @@ -489,7 +489,7 @@ equal unlike the regular `EqualTo`(`=`) operator. #### Examples -{% highlight sql %} +```sql CREATE VIEW unknown_age SELECT * FROM person WHERE age IS NULL; -- Only common rows between two legs of `INTERSECT` are in the @@ -537,7 +537,7 @@ SELECT name, age FROM person | Mike| 18| | Dan| 50| +--------+----+ -{% endhighlight %} +``` ### EXISTS/NOT EXISTS Subquery @@ -554,7 +554,7 @@ semijoins / anti-semijoins without special provisions for null awareness. #### Examples -{% highlight sql %} +```sql -- Even if subquery produces rows with `NULL` values, the `EXISTS` expression -- evaluates to `TRUE` as the subquery produces 1 row. SELECT * FROM person WHERE EXISTS (SELECT null); @@ -591,7 +591,7 @@ SELECT * FROM person WHERE NOT EXISTS (SELECT 1 WHERE 1 = 0); | Marry|null| | Joe| 30| +--------+----+ -{% endhighlight %} +``` ### IN/NOT IN Subquery @@ -617,7 +617,7 @@ and because NOT UNKNOWN is again UNKNOWN. #### Examples -{% highlight sql %} +```sql -- The subquery has only `NULL` value in its result set. Therefore, -- the result of `IN` predicate is UNKNOWN. SELECT * FROM person WHERE age IN (SELECT null); @@ -646,4 +646,4 @@ SELECT * FROM person |name|age| +----+---+ +----+---+ -{% endhighlight %} +``` diff --git a/docs/sql-ref-syntax-aux-analyze-table.md b/docs/sql-ref-syntax-aux-analyze-table.md index f6a6c5f4bc555..8f43d7388d7db 100644 --- a/docs/sql-ref-syntax-aux-analyze-table.md +++ b/docs/sql-ref-syntax-aux-analyze-table.md @@ -25,53 +25,39 @@ The `ANALYZE TABLE` statement collects statistics about the table to be used by ### Syntax -{% highlight sql %} +```sql ANALYZE TABLE table_identifier [ partition_spec ] COMPUTE STATISTICS [ NOSCAN | FOR COLUMNS col [ , ... ] | FOR ALL COLUMNS ] -{% endhighlight %} +``` ### Parameters -
-
table_identifier
-
- Specifies a table name, which may be optionally qualified with a database name.

- Syntax: - - [ database_name. ] table_name - -
-
- -
-
partition_spec
-
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **partition_spec** + An optional parameter that specifies a comma separated list of key and value pairs - for partitions. When specified, partition statistics is returned.

- Syntax: - - PARTITION ( partition_col_name [ = partition_col_val ] [ , ... ] ) - -
-
- -
-
[ NOSCAN | FOR COLUMNS col [ , ... ] | FOR ALL COLUMNS ]
-
-
    -
  • If no analyze option is specified, ANALYZE TABLE collects the table's number of rows and size in bytes.
  • -
  • NOSCAN -
    Collect only the table's size in bytes ( which does not require scanning the entire table ).
  • -
  • FOR COLUMNS col [ , ... ] | FOR ALL COLUMNS -
    Collect column statistics for each column specified, or alternatively for every column, as well as table statistics. -
  • -
-
-
+ for partitions. When specified, partition statistics is returned. + + **Syntax:** `PARTITION ( partition_col_name [ = partition_col_val ] [ , ... ] )` + +* **[ NOSCAN `|` FOR COLUMNS col [ , ... ] `|` FOR ALL COLUMNS ]** + + * If no analyze option is specified, `ANALYZE TABLE` collects the table's number of rows and size in bytes. + * **NOSCAN** + + Collects only the table's size in bytes ( which does not require scanning the entire table ). + * **FOR COLUMNS col [ , ... ] `|` FOR ALL COLUMNS** + + Collects column statistics for each column specified, or alternatively for every column, as well as table statistics. ### Examples -{% highlight sql %} +```sql CREATE TABLE students (name STRING, student_id INT) PARTITIONED BY (student_id); INSERT INTO students PARTITION (student_id = 111111) VALUES ('Mark'); INSERT INTO students PARTITION (student_id = 222222) VALUES ('John'); @@ -135,4 +121,4 @@ DESC EXTENDED students name; | max_col_len| 4| | histogram| NULL| +--------------+----------+ -{% endhighlight %} +``` diff --git a/docs/sql-ref-syntax-aux-cache-cache-table.md b/docs/sql-ref-syntax-aux-cache-cache-table.md index 11f682cc10891..193e209d792b3 100644 --- a/docs/sql-ref-syntax-aux-cache-cache-table.md +++ b/docs/sql-ref-syntax-aux-cache-cache-table.md @@ -26,71 +26,57 @@ This reduces scanning of the original files in future queries. ### Syntax -{% highlight sql %} +```sql CACHE [ LAZY ] TABLE table_identifier [ OPTIONS ( 'storageLevel' [ = ] value ) ] [ [ AS ] query ] -{% endhighlight %} +``` ### Parameters -
-
LAZY
-
Only cache the table when it is first used, instead of immediately.
-
- -
-
table_identifier
-
- Specifies the table or view name to be cached. The table or view name may be optionally qualified with a database name.

- Syntax: - - [ database_name. ] table_name - -
-
- -
-
OPTIONS ( 'storageLevel' [ = ] value )
-
- OPTIONS clause with storageLevel key and value pair. A Warning is issued when a key other than storageLevel is used. The valid options for storageLevel are: -
    -
  • NONE
  • -
  • DISK_ONLY
  • -
  • DISK_ONLY_2
  • -
  • MEMORY_ONLY
  • -
  • MEMORY_ONLY_2
  • -
  • MEMORY_ONLY_SER
  • -
  • MEMORY_ONLY_SER_2
  • -
  • MEMORY_AND_DISK
  • -
  • MEMORY_AND_DISK_2
  • -
  • MEMORY_AND_DISK_SER
  • -
  • MEMORY_AND_DISK_SER_2
  • -
  • OFF_HEAP
  • -
- An Exception is thrown when an invalid value is set for storageLevel. If storageLevel is not explicitly set using OPTIONS clause, the default storageLevel is set to MEMORY_AND_DISK. -
-
- -
-
query
-
A query that produces the rows to be cached. It can be in one of following formats: -
    -
  • a SELECT statement
  • -
  • a TABLE statement
  • -
  • a FROM statement
  • -
-
-
+* **LAZY** + + Only cache the table when it is first used, instead of immediately. + +* **table_identifier** + + Specifies the table or view name to be cached. The table or view name may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **OPTIONS ( 'storageLevel' [ = ] value )** + + `OPTIONS` clause with `storageLevel` key and value pair. A Warning is issued when a key other than `storageLevel` is used. The valid options for `storageLevel` are: + * `NONE` + * `DISK_ONLY` + * `DISK_ONLY_2` + * `MEMORY_ONLY` + * `MEMORY_ONLY_2` + * `MEMORY_ONLY_SER` + * `MEMORY_ONLY_SER_2` + * `MEMORY_AND_DISK` + * `MEMORY_AND_DISK_2` + * `MEMORY_AND_DISK_SER` + * `MEMORY_AND_DISK_SER_2` + * `OFF_HEAP` + + An Exception is thrown when an invalid value is set for `storageLevel`. If `storageLevel` is not explicitly set using `OPTIONS` clause, the default `storageLevel` is set to `MEMORY_AND_DISK`. + +* **query** + + A query that produces the rows to be cached. It can be in one of following formats: + * a `SELECT` statement + * a `TABLE` statement + * a `FROM` statement ### Examples -{% highlight sql %} +```sql CACHE TABLE testCache OPTIONS ('storageLevel' 'DISK_ONLY') SELECT * FROM testData; -{% endhighlight %} +``` ### Related Statements - * [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) - * [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html) - * [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html) - * [REFRESH](sql-ref-syntax-aux-cache-refresh.html) +* [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) +* [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html) +* [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html) +* [REFRESH](sql-ref-syntax-aux-cache-refresh.html) diff --git a/docs/sql-ref-syntax-aux-cache-clear-cache.md b/docs/sql-ref-syntax-aux-cache-clear-cache.md index 47889691148b7..ee33e6a98296d 100644 --- a/docs/sql-ref-syntax-aux-cache-clear-cache.md +++ b/docs/sql-ref-syntax-aux-cache-clear-cache.md @@ -25,19 +25,19 @@ license: | ### Syntax -{% highlight sql %} +```sql CLEAR CACHE -{% endhighlight %} +``` ### Examples -{% highlight sql %} +```sql CLEAR CACHE; -{% endhighlight %} +``` ### Related Statements - * [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) - * [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html) - * [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html) - * [REFRESH](sql-ref-syntax-aux-cache-refresh.html) +* [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) +* [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html) +* [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html) +* [REFRESH](sql-ref-syntax-aux-cache-refresh.html) diff --git a/docs/sql-ref-syntax-aux-cache-refresh.md b/docs/sql-ref-syntax-aux-cache-refresh.md index 25f7ede1d324e..82bc12da5d1ac 100644 --- a/docs/sql-ref-syntax-aux-cache-refresh.md +++ b/docs/sql-ref-syntax-aux-cache-refresh.md @@ -27,32 +27,30 @@ invalidate everything that is cached. ### Syntax -{% highlight sql %} +```sql REFRESH resource_path -{% endhighlight %} +``` ### Parameters -
-  resource_path
-    The path of the resource that is to be refreshed.
-
+* **resource_path** + + The path of the resource that is to be refreshed. ### Examples -{% highlight sql %} +```sql -- The Path is resolved using the datasource's File Index. - CREATE TABLE test(ID INT) using parquet; INSERT INTO test SELECT 1000; CACHE TABLE test; INSERT INTO test SELECT 100; REFRESH "hdfs://path/to/table"; -{% endhighlight %} +``` ### Related Statements - * [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) - * [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) - * [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html) - * [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html) +* [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) +* [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) +* [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html) +* [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html) diff --git a/docs/sql-ref-syntax-aux-cache-uncache-table.md b/docs/sql-ref-syntax-aux-cache-uncache-table.md index 95fd91c3c4807..c5a8fbbe08281 100644 --- a/docs/sql-ref-syntax-aux-cache-uncache-table.md +++ b/docs/sql-ref-syntax-aux-cache-uncache-table.md @@ -26,32 +26,27 @@ underlying entries should already have been brought to cache by previous `CACHE ### Syntax -{% highlight sql %} +```sql UNCACHE TABLE [ IF EXISTS ] table_identifier -{% endhighlight %} +``` ### Parameters -
-  table_identifier
-    Specifies the table or view name to be uncached. The table or view name may be optionally qualified with a database name.
-    Syntax: [ database_name. ] table_name
-
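A one-line illustration of the optional clauses described above; the qualified table name is made up:

```sql
-- IF EXISTS suppresses the error when the (database-qualified) table is unknown or was never cached.
UNCACHE TABLE IF EXISTS salesdb.cached_sales;
```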
+* **table_identifier** + + Specifies the table or view name to be uncached. The table or view name may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` ### Examples -{% highlight sql %} +```sql UNCACHE TABLE t1; -{% endhighlight %} +``` ### Related Statements - * [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) - * [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) - * [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html) - * [REFRESH](sql-ref-syntax-aux-cache-refresh.html) +* [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) +* [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) +* [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html) +* [REFRESH](sql-ref-syntax-aux-cache-refresh.html) diff --git a/docs/sql-ref-syntax-aux-conf-mgmt-reset.md b/docs/sql-ref-syntax-aux-conf-mgmt-reset.md index e7e6dda4e25ee..4caf57a232f89 100644 --- a/docs/sql-ref-syntax-aux-conf-mgmt-reset.md +++ b/docs/sql-ref-syntax-aux-conf-mgmt-reset.md @@ -25,17 +25,17 @@ Reset any runtime configurations specific to the current session which were set ### Syntax -{% highlight sql %} +```sql RESET -{% endhighlight %} +``` ### Examples -{% highlight sql %} +```sql -- Reset any runtime configurations specific to the current session which were set via the SET command to their default values. RESET; -{% endhighlight %} +``` ### Related Statements - * [SET](sql-ref-syntax-aux-conf-mgmt-set.html) +* [SET](sql-ref-syntax-aux-conf-mgmt-set.html) diff --git a/docs/sql-ref-syntax-aux-conf-mgmt-set.md b/docs/sql-ref-syntax-aux-conf-mgmt-set.md index 330a1a6a399ff..f97b7f2a8efed 100644 --- a/docs/sql-ref-syntax-aux-conf-mgmt-set.md +++ b/docs/sql-ref-syntax-aux-conf-mgmt-set.md @@ -25,32 +25,29 @@ The SET command sets a property, returns the value of an existing property or re ### Syntax -{% highlight sql %} +```sql SET SET [ -v ] SET property_key[ = property_value ] -{% endhighlight %} +``` ### Parameters -
-  -v
-    Outputs the key, value and meaning of existing SQLConf properties.
-
-  property_key
-    Returns the value of specified property key.
-
-  property_key=property_value
-    Sets the value for a given property key. If an old value exists for a given property key, then it gets overridden by the new value.
-
+* **-v**
+
+    Outputs the key, value and meaning of existing SQLConf properties.
+
+* **property_key** + + Returns the value of specified property key. + +* **property_key=property_value** + + Sets the value for a given property key. If an old value exists for a given property key, then it gets overridden by the new value. ### Examples -{% highlight sql %} +```sql -- Set a property. SET spark.sql.variable.substitute=false; @@ -67,8 +64,8 @@ SET spark.sql.variable.substitute; +-----------------------------+-----+ |spark.sql.variable.substitute|false| +-----------------------------+-----+ -{% endhighlight %} +``` ### Related Statements - * [RESET](sql-ref-syntax-aux-conf-mgmt-reset.html) +* [RESET](sql-ref-syntax-aux-conf-mgmt-reset.html) diff --git a/docs/sql-ref-syntax-aux-conf-mgmt.md b/docs/sql-ref-syntax-aux-conf-mgmt.md index f5e48ef2fee30..1900fb7f1cb9a 100644 --- a/docs/sql-ref-syntax-aux-conf-mgmt.md +++ b/docs/sql-ref-syntax-aux-conf-mgmt.md @@ -20,4 +20,4 @@ license: | --- * [SET](sql-ref-syntax-aux-conf-mgmt-set.html) - * [UNSET](sql-ref-syntax-aux-conf-mgmt-reset.html) + * [RESET](sql-ref-syntax-aux-conf-mgmt-reset.html) diff --git a/docs/sql-ref-syntax-aux-describe-database.md b/docs/sql-ref-syntax-aux-describe-database.md index 39a40ddac800f..143fa78b205ca 100644 --- a/docs/sql-ref-syntax-aux-describe-database.md +++ b/docs/sql-ref-syntax-aux-describe-database.md @@ -28,23 +28,20 @@ interchangeable. ### Syntax -{% highlight sql %} +```sql { DESC | DESCRIBE } DATABASE [ EXTENDED ] db_name -{% endhighlight %} +``` ### Parameters -
-  db_name
+* **db_name** + Specifies a name of an existing database or an existing schema in the system. If the name does not exist, an exception is thrown. -
-
### Examples -{% highlight sql %} +```sql -- Create employees DATABASE CREATE DATABASE employees COMMENT 'For software companies'; @@ -89,10 +86,10 @@ DESC DATABASE deployment; | Description| Deployment environment| | Location|file:/Users/Temp/deployment.db| +-------------------------+------------------------------+ -{% endhighlight %} +``` ### Related Statements - * [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) - * [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) - * [DESCRIBE QUERY](sql-ref-syntax-aux-describe-query.html) +* [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) +* [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) +* [DESCRIBE QUERY](sql-ref-syntax-aux-describe-query.html) diff --git a/docs/sql-ref-syntax-aux-describe-function.md b/docs/sql-ref-syntax-aux-describe-function.md index 76c9efad2fa7d..a871fb5bfd406 100644 --- a/docs/sql-ref-syntax-aux-describe-function.md +++ b/docs/sql-ref-syntax-aux-describe-function.md @@ -28,29 +28,24 @@ metadata information is returned along with the extended usage information. ### Syntax -{% highlight sql %} +```sql { DESC | DESCRIBE } FUNCTION [ EXTENDED ] function_name -{% endhighlight %} +``` ### Parameters -
-  function_name
+* **function_name** + Specifies a name of an existing function in the system. The function name may be optionally qualified with a database name. If `function_name` is qualified with a database then the function is resolved from the user specified database, otherwise - it is resolved from the current database.

-    Syntax: [ database_name. ] function_name
-
+ it is resolved from the current database. + + **Syntax:** `[ database_name. ] function_name` ### Examples -{% highlight sql %} +```sql -- Describe a builtin scalar function. -- Returns function name, implementing class and usage DESC FUNCTION abs; @@ -107,11 +102,10 @@ DESC FUNCTION EXTENDED explode | 10 | | 20 | +---------------------------------------------------------------+ - -{% endhighlight %} +``` ### Related Statements - * [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) - * [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) - * [DESCRIBE QUERY](sql-ref-syntax-aux-describe-query.html) +* [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) +* [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) +* [DESCRIBE QUERY](sql-ref-syntax-aux-describe-query.html) diff --git a/docs/sql-ref-syntax-aux-describe-query.md b/docs/sql-ref-syntax-aux-describe-query.md index 41e66dcb2e316..b2a74cbd06078 100644 --- a/docs/sql-ref-syntax-aux-describe-query.md +++ b/docs/sql-ref-syntax-aux-describe-query.md @@ -27,38 +27,36 @@ describe the query output. ### Syntax -{% highlight sql %} +```sql { DESC | DESCRIBE } [ QUERY ] input_statement -{% endhighlight %} +``` ### Parameters -
-  QUERY
-    This clause is optional and may be omitted.
-
-  input_statement
+* **QUERY** + This clause is optional and may be omitted. + +* **input_statement** + Specifies a result set producing statement and may be one of the following: -
-    • a SELECT statement
-    • a CTE(Common table expression) statement
-    • an INLINE TABLE statement
-    • a TABLE statement
-    • a FROM statement
- Please refer to select-statement + + * a `SELECT` statement + * a `CTE(Common table expression)` statement + * an `INLINE TABLE` statement + * a `TABLE` statement + * a `FROM` statement` + + Please refer to [select-statement](sql-ref-syntax-qry-select.html) for a detailed syntax of the query parameter. -
-
### Examples -{% highlight sql %} +```sql -- Create table `person` CREATE TABLE person (name STRING , age INT COMMENT 'Age column', address STRING); -- Returns column metadata information for a simple select query -DESCRIBE QUERY select age, sum(age) FROM person GROUP BY age; +DESCRIBE QUERY SELECT age, sum(age) FROM person GROUP BY age; +--------+---------+----------+ |col_name|data_type| comment| +--------+---------+----------+ @@ -103,10 +101,10 @@ DESCRIBE FROM person SELECT age; +--------+---------+----------+ | age| int| Agecolumn| +--------+---------+----------+ -{% endhighlight %} +``` ### Related Statements - * [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) - * [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) - * [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) +* [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) +* [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) +* [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) diff --git a/docs/sql-ref-syntax-aux-describe-table.md b/docs/sql-ref-syntax-aux-describe-table.md index 63bf056d785cc..4b6e1e8c3461e 100644 --- a/docs/sql-ref-syntax-aux-describe-table.md +++ b/docs/sql-ref-syntax-aux-describe-table.md @@ -28,53 +28,43 @@ to return the metadata pertaining to a partition or column respectively. ### Syntax -{% highlight sql %} +```sql { DESC | DESCRIBE } [ TABLE ] [ format ] table_identifier [ partition_spec ] [ col_name ] -{% endhighlight %} +``` ### Parameters -
-  format
+* **format** + Specifies the optional format of describe output. If `EXTENDED` is specified then additional metadata information (such as parent database, owner, and access time) is returned. -
-  table_identifier
-    Specifies a table name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] table_name
-
-  partition_spec
+ +* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **partition_spec** + An optional parameter that specifies a comma separated list of key and value pairs - for partitions. When specified, additional partition metadata is returned.

-    Syntax: PARTITION ( partition_col_name = partition_col_val [ , ... ] )
-
-  col_name
+ for partitions. When specified, additional partition metadata is returned. + + **Syntax:** `PARTITION ( partition_col_name = partition_col_val [ , ... ] )` + +* **col_name** + An optional parameter that specifies the column name that needs to be described. The supplied column name may be optionally qualified. Parameters `partition_spec` and `col_name` are mutually exclusive and can not be specified together. Currently - nested columns are not allowed to be specified.

+ nested columns are not allowed to be specified. - Syntax: - - [ database_name. ] [ table_name. ] column_name - -
-
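To make the partition_spec / col_name distinction above concrete, a short illustrative pair of statements (object names are invented, not taken from the patch):

```sql
-- Metadata for a single partition of a partitioned table.
DESCRIBE TABLE EXTENDED customer PARTITION (state = 'CA');

-- Metadata for one optionally qualified column; this cannot be combined with a partition spec.
DESCRIBE customer salesdb.customer.name;
```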
+ **Syntax:** `[ database_name. ] [ table_name. ] column_name` ### Examples -{% highlight sql %} +```sql -- Creates a table `customer`. Assumes current database is `salesdb`. CREATE TABLE customer( cust_id INT, @@ -183,10 +173,10 @@ DESCRIBE customer salesdb.customer.name; |data_type| string| | comment|Short name| +---------+----------+ -{% endhighlight %} +``` ### Related Statements - * [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) - * [DESCRIBE QUERY](sql-ref-syntax-aux-describe-query.html) - * [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) +* [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) +* [DESCRIBE QUERY](sql-ref-syntax-aux-describe-query.html) +* [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) diff --git a/docs/sql-ref-syntax-aux-refresh-table.md b/docs/sql-ref-syntax-aux-refresh-table.md index 165ca68309f4a..8d4a804f88671 100644 --- a/docs/sql-ref-syntax-aux-refresh-table.md +++ b/docs/sql-ref-syntax-aux-refresh-table.md @@ -27,26 +27,21 @@ lazy manner when the cached table or the query associated with it is executed ag ### Syntax -{% highlight sql %} +```sql REFRESH [TABLE] table_identifier -{% endhighlight %} +``` ### Parameters -
-  table_identifier
-    Specifies a table name, which is either a qualified or unqualified name that designates a table/view. If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
-    Syntax: [ database_name. ] table_name
-
+* **table_identifier** + + Specifies a table name, which is either a qualified or unqualified name that designates a table/view. If no database identifier is provided, it refers to a temporary view or a table/view in the current database. + + **Syntax:** `[ database_name. ] table_name` ### Examples -{% highlight sql %} +```sql -- The cached entries of the table will be refreshed -- The table is resolved from the current database as the table name is unqualified. REFRESH TABLE tbl1; @@ -54,11 +49,11 @@ REFRESH TABLE tbl1; -- The cached entries of the view will be refreshed or invalidated -- The view is resolved from tempDB database, as the view name is qualified. REFRESH TABLE tempDB.view1; -{% endhighlight %} +``` ### Related Statements - * [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) - * [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) - * [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html) - * [REFRESH](sql-ref-syntax-aux-cache-refresh.html) +* [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) +* [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) +* [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html) +* [REFRESH](sql-ref-syntax-aux-cache-refresh.html) diff --git a/docs/sql-ref-syntax-aux-resource-mgmt-add-file.md b/docs/sql-ref-syntax-aux-resource-mgmt-add-file.md index 0028884308890..9203293d0c981 100644 --- a/docs/sql-ref-syntax-aux-resource-mgmt-add-file.md +++ b/docs/sql-ref-syntax-aux-resource-mgmt-add-file.md @@ -25,30 +25,29 @@ license: | ### Syntax -{% highlight sql %} +```sql ADD FILE resource_name -{% endhighlight %} +``` ### Parameters -
-  resource_name
-    The name of the file or directory to be added.
-
+* **resource_name** + + The name of the file or directory to be added. ### Examples -{% highlight sql %} +```sql ADD FILE /tmp/test; ADD FILE "/path/to/file/abc.txt"; ADD FILE '/another/test.txt'; ADD FILE "/path with space/abc.txt"; ADD FILE "/path/to/some/directory"; -{% endhighlight %} +``` ### Related Statements - * [LIST FILE](sql-ref-syntax-aux-resource-mgmt-list-file.html) - * [LIST JAR](sql-ref-syntax-aux-resource-mgmt-list-jar.html) - * [ADD JAR](sql-ref-syntax-aux-resource-mgmt-add-jar.html) +* [LIST FILE](sql-ref-syntax-aux-resource-mgmt-list-file.html) +* [LIST JAR](sql-ref-syntax-aux-resource-mgmt-list-jar.html) +* [ADD JAR](sql-ref-syntax-aux-resource-mgmt-add-jar.html) diff --git a/docs/sql-ref-syntax-aux-resource-mgmt-add-jar.md b/docs/sql-ref-syntax-aux-resource-mgmt-add-jar.md index c4020347c1be0..4694bff99daf5 100644 --- a/docs/sql-ref-syntax-aux-resource-mgmt-add-jar.md +++ b/docs/sql-ref-syntax-aux-resource-mgmt-add-jar.md @@ -25,28 +25,27 @@ license: | ### Syntax -{% highlight sql %} +```sql ADD JAR file_name -{% endhighlight %} +``` ### Parameters -
-  file_name
-    The name of the JAR file to be added. It could be either on a local file system or a distributed file system.
-
+* **file_name** + + The name of the JAR file to be added. It could be either on a local file system or a distributed file system. ### Examples -{% highlight sql %} +```sql ADD JAR /tmp/test.jar; ADD JAR "/path/to/some.jar"; ADD JAR '/some/other.jar'; ADD JAR "/path with space/abc.jar"; -{% endhighlight %} +``` ### Related Statements - * [LIST JAR](sql-ref-syntax-aux-resource-mgmt-list-jar.html) - * [ADD FILE](sql-ref-syntax-aux-resource-mgmt-add-file.html) - * [LIST FILE](sql-ref-syntax-aux-resource-mgmt-list-file.html) +* [LIST JAR](sql-ref-syntax-aux-resource-mgmt-list-jar.html) +* [ADD FILE](sql-ref-syntax-aux-resource-mgmt-add-file.html) +* [LIST FILE](sql-ref-syntax-aux-resource-mgmt-list-file.html) diff --git a/docs/sql-ref-syntax-aux-resource-mgmt-list-file.md b/docs/sql-ref-syntax-aux-resource-mgmt-list-file.md index eec98e1fbffb5..9b9a7df7f612f 100644 --- a/docs/sql-ref-syntax-aux-resource-mgmt-list-file.md +++ b/docs/sql-ref-syntax-aux-resource-mgmt-list-file.md @@ -25,13 +25,13 @@ license: | ### Syntax -{% highlight sql %} +```sql LIST FILE -{% endhighlight %} +``` ### Examples -{% highlight sql %} +```sql ADD FILE /tmp/test; ADD FILE /tmp/test_2; LIST FILE; @@ -42,11 +42,11 @@ file:/private/tmp/test_2 LIST FILE /tmp/test /some/random/file /another/random/file --output file:/private/tmp/test -{% endhighlight %} +``` ### Related Statements - * [ADD FILE](sql-ref-syntax-aux-resource-mgmt-add-file.html) - * [ADD JAR](sql-ref-syntax-aux-resource-mgmt-add-jar.html) - * [LIST JAR](sql-ref-syntax-aux-resource-mgmt-list-jar.html) +* [ADD FILE](sql-ref-syntax-aux-resource-mgmt-add-file.html) +* [ADD JAR](sql-ref-syntax-aux-resource-mgmt-add-jar.html) +* [LIST JAR](sql-ref-syntax-aux-resource-mgmt-list-jar.html) diff --git a/docs/sql-ref-syntax-aux-resource-mgmt-list-jar.md b/docs/sql-ref-syntax-aux-resource-mgmt-list-jar.md index dca4252c90ef2..04aa52c2ad8af 100644 --- a/docs/sql-ref-syntax-aux-resource-mgmt-list-jar.md +++ b/docs/sql-ref-syntax-aux-resource-mgmt-list-jar.md @@ -25,13 +25,13 @@ license: | ### Syntax -{% highlight sql %} +```sql LIST JAR -{% endhighlight %} +``` ### Examples -{% highlight sql %} +```sql ADD JAR /tmp/test.jar; ADD JAR /tmp/test_2.jar; LIST JAR; @@ -42,11 +42,11 @@ spark://192.168.1.112:62859/jars/test_2.jar LIST JAR /tmp/test.jar /some/random.jar /another/random.jar; -- output spark://192.168.1.112:62859/jars/test.jar -{% endhighlight %} +``` ### Related Statements - * [ADD JAR](sql-ref-syntax-aux-resource-mgmt-add-jar.html) - * [ADD FILE](sql-ref-syntax-aux-resource-mgmt-add-file.html) - * [LIST FILE](sql-ref-syntax-aux-resource-mgmt-list-file.html) +* [ADD JAR](sql-ref-syntax-aux-resource-mgmt-add-jar.html) +* [ADD FILE](sql-ref-syntax-aux-resource-mgmt-add-file.html) +* [LIST FILE](sql-ref-syntax-aux-resource-mgmt-list-file.html) diff --git a/docs/sql-ref-syntax-aux-show-columns.md b/docs/sql-ref-syntax-aux-show-columns.md index 7229bba23d2bf..b76db252f1a0f 100644 --- a/docs/sql-ref-syntax-aux-show-columns.md +++ b/docs/sql-ref-syntax-aux-show-columns.md @@ -21,7 +21,7 @@ license: | ### Description -Return the list of columns in a table. If the table does not exist, an exception is thrown. +Returns the list of columns in a table. If the table does not exist, an exception is thrown. 
### Syntax diff --git a/docs/sql-ref-syntax-aux-show-create-table.md b/docs/sql-ref-syntax-aux-show-create-table.md index 47a5290f1d022..ae8c10e2d0178 100644 --- a/docs/sql-ref-syntax-aux-show-create-table.md +++ b/docs/sql-ref-syntax-aux-show-create-table.md @@ -25,26 +25,21 @@ license: | ### Syntax -{% highlight sql %} +```sql SHOW CREATE TABLE table_identifier -{% endhighlight %} +``` ### Parameters -
-  table_identifier
-    Specifies a table or view name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] table_name
-
+* **table_identifier** + + Specifies a table or view name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` ### Examples -{% highlight sql %} +```sql CREATE TABLE test (c INT) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE TBLPROPERTIES ('prop1' = 'value1', 'prop2' = 'value2'); @@ -60,9 +55,9 @@ SHOW CREATE TABLE test; 'prop1' = 'value1', 'prop2' = 'value2') +----------------------------------------------------+ -{% endhighlight %} +``` ### Related Statements - * [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) - * [CREATE VIEW](sql-ref-syntax-ddl-create-view.html) +* [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) +* [CREATE VIEW](sql-ref-syntax-ddl-create-view.html) diff --git a/docs/sql-ref-syntax-aux-show-databases.md b/docs/sql-ref-syntax-aux-show-databases.md index c84898aa81459..44c0fbbef3929 100644 --- a/docs/sql-ref-syntax-aux-show-databases.md +++ b/docs/sql-ref-syntax-aux-show-databases.md @@ -21,35 +21,31 @@ license: | ### Description -Lists the databases that match an optionally supplied string pattern. If no +Lists the databases that match an optionally supplied regular expression pattern. If no pattern is supplied then the command lists all the databases in the system. Please note that the usage of `SCHEMAS` and `DATABASES` are interchangeable and mean the same thing. ### Syntax -{% highlight sql %} +```sql SHOW { DATABASES | SCHEMAS } [ LIKE regex_pattern ] -{% endhighlight %} +``` ### Parameters -
-  regex_pattern
+* **regex_pattern** + Specifies a regular expression pattern that is used to filter the results of the statement. -
-    • Only * and | are allowed as wildcard pattern.
-    • Excluding * and |, the remaining pattern follows the regular expression semantics.
-    • The leading and trailing blanks are trimmed in the input pattern before processing. The pattern match is case-insensitive.
-
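For context, a minimal sketch of the wildcard behaviour described above (the database names are placeholders):

```sql
-- `*` matches any sequence of characters and `|` separates alternatives; matching is case-insensitive.
SHOW DATABASES LIKE 'pay*|default';
```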
+ * Except for `*` and `|` character, the pattern works like a regular expression. + * `*` alone matches 0 or more characters and `|` is used to separate multiple different regular expressions, + any of which can match. + * The leading and trailing blanks are trimmed in the input pattern before processing. The pattern match is case-insensitive. ### Examples -{% highlight sql %} +```sql -- Create database. Assumes a database named `default` already exists in -- the system. CREATE DATABASE payroll_db; @@ -83,10 +79,10 @@ SHOW SCHEMAS; | payments_db| | payroll_db| +------------+ -{% endhighlight %} +``` ### Related Statements - * [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) - * [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) - * [ALTER DATABASE](sql-ref-syntax-ddl-alter-database.html) +* [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) +* [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) +* [ALTER DATABASE](sql-ref-syntax-ddl-alter-database.html) diff --git a/docs/sql-ref-syntax-aux-show-functions.md b/docs/sql-ref-syntax-aux-show-functions.md index 8a6de402c7f20..b4dd72801202e 100644 --- a/docs/sql-ref-syntax-aux-show-functions.md +++ b/docs/sql-ref-syntax-aux-show-functions.md @@ -9,7 +9,7 @@ license: | The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -29,48 +29,42 @@ clause is optional and supported only for compatibility with other systems. ### Syntax -{% highlight sql %} -SHOW [ function_kind ] FUNCTIONS ( [ LIKE ] function_name | regex_pattern ) -{% endhighlight %} +```sql +SHOW [ function_kind ] FUNCTIONS [ [ LIKE ] { function_name | regex_pattern } ] +``` ### Parameters -
-  function_kind
+* **function_kind** + Specifies the name space of the function to be searched upon. The valid name spaces are : -
-    • USER - Looks up the function(s) among the user defined functions.
-    • SYSTEM - Looks up the function(s) among the system defined functions.
-    • ALL - Looks up the function(s) among both user and system defined functions.
-
-  function_name
+ + * **USER** - Looks up the function(s) among the user defined functions. + * **SYSTEM** - Looks up the function(s) among the system defined functions. + * **ALL** - Looks up the function(s) among both user and system defined functions. + +* **function_name** + Specifies a name of an existing function in the system. The function name may be optionally qualified with a database name. If `function_name` is qualified with a database then the function is resolved from the user specified database, otherwise - it is resolved from the current database.

-    Syntax: [database_name.]function_name
-
-  regex_pattern
+ it is resolved from the current database. + + **Syntax:** `[ database_name. ] function_name` + +* **regex_pattern** + Specifies a regular expression pattern that is used to filter the results of the statement. -
-    • Only * and | are allowed as wildcard pattern.
-    • Excluding * and |, the remaining pattern follows the regular expression semantics.
-    • The leading and trailing blanks are trimmed in the input pattern before processing. The pattern match is case-insensitive.
-
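A brief, hedged sketch of how the name spaces above might be combined with a pattern (the patterns are illustrative only):

```sql
-- Search a specific name space, optionally filtered by a pattern.
SHOW USER FUNCTIONS;
SHOW SYSTEM FUNCTIONS LIKE 'trim|ltrim|rtrim';
SHOW ALL FUNCTIONS LIKE 't*';
```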
+ + * Except for `*` and `|` character, the pattern works like a regular expression. + * `*` alone matches 0 or more characters and `|` is used to separate multiple different regular expressions, + any of which can match. + * The leading and trailing blanks are trimmed in the input pattern before processing. The pattern match is case-insensitive. ### Examples -{% highlight sql %} +```sql -- List a system function `trim` by searching both user defined and system -- defined functions. SHOW FUNCTIONS trim; @@ -138,8 +132,8 @@ SHOW FUNCTIONS LIKE 't[a-z][a-z][a-z]'; | tanh| | trim| +--------+ -{% endhighlight %} +``` ### Related Statements - * [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) +* [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) diff --git a/docs/sql-ref-syntax-aux-show-partitions.md b/docs/sql-ref-syntax-aux-show-partitions.md index 592833b23eb09..d93825550413f 100644 --- a/docs/sql-ref-syntax-aux-show-partitions.md +++ b/docs/sql-ref-syntax-aux-show-partitions.md @@ -27,37 +27,28 @@ partition spec. ### Syntax -{% highlight sql %} +```sql SHOW PARTITIONS table_identifier [ partition_spec ] -{% endhighlight %} +``` ### Parameters -
-  table_identifier
-    Specifies a table name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] table_name
-
-  partition_spec
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **partition_spec** + An optional parameter that specifies a comma separated list of key and value pairs - for partitions. When specified, the partitions that match the partition spec are returned.

-    Syntax: PARTITION ( partition_col_name [ = partition_col_val ] [ , ... ] )
-
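To keep the optional partition spec concrete, a single illustrative statement (table and column names are invented):

```sql
-- A partial partition spec lists every partition that matches the supplied key(s).
SHOW PARTITIONS customer PARTITION (state = 'CA');
```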
+ for partitions. When specified, the partitions that match the partition specification are returned. + + **Syntax:** `PARTITION ( partition_col_name = partition_col_val [ , ... ] )` ### Examples -{% highlight sql %} +```sql -- create a partitioned table and insert a few rows. USE salesdb; CREATE TABLE customer(id INT, name STRING) PARTITIONED BY (state STRING, city STRING); @@ -109,11 +100,11 @@ SHOW PARTITIONS customer PARTITION (city = 'San Jose'); +----------------------+ |state=CA/city=San Jose| +----------------------+ -{% endhighlight %} +``` ### Related Statements - * [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) - * [INSERT STATEMENT](sql-ref-syntax-dml-insert.html) - * [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) - * [SHOW TABLE](sql-ref-syntax-aux-show-table.html) +* [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) +* [INSERT STATEMENT](sql-ref-syntax-dml-insert.html) +* [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) +* [SHOW TABLE](sql-ref-syntax-aux-show-table.html) diff --git a/docs/sql-ref-syntax-aux-show-table.md b/docs/sql-ref-syntax-aux-show-table.md index 3f588045790b2..0ce0a3eefa538 100644 --- a/docs/sql-ref-syntax-aux-show-table.md +++ b/docs/sql-ref-syntax-aux-show-table.md @@ -32,42 +32,36 @@ cannot be used with a partition specification. ### Syntax -{% highlight sql %} -SHOW TABLE EXTENDED [ IN | FROM database_name ] LIKE regex_pattern +```sql +SHOW TABLE EXTENDED [ { IN | FROM } database_name ] LIKE regex_pattern [ partition_spec ] -{% endhighlight %} +``` ### Parameters -
-  IN|FROM database_name
+* **{ IN`|`FROM } database_name** + Specifies database name. If not provided, will use the current database. -
-  regex_pattern
+ +* **regex_pattern** + Specifies the regular expression pattern that is used to filter out unwanted tables. -
-    • Except for * and | character, the pattern works like a regular expression.
-    • * alone matches 0 or more characters and | is used to separate multiple different regular expressions, any of which can match.
-    • The leading and trailing blanks are trimmed in the input pattern before processing. The pattern match is case-insensitive.
-
-  partition_spec
+ + * Except for `*` and `|` character, the pattern works like a regular expression. + * `*` alone matches 0 or more characters and `|` is used to separate multiple different regular expressions, + any of which can match. + * The leading and trailing blanks are trimmed in the input pattern before processing. The pattern match is case-insensitive. + +* **partition_spec** + An optional parameter that specifies a comma separated list of key and value pairs - for partitions. Note that a table regex cannot be used with a partition specification.

-    Syntax: PARTITION ( partition_col_name [ = partition_col_val ] [ , ... ] )
-
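As a sketch of the two usage modes described above (assuming a hypothetical `salesdb.employee` table partitioned by `grade`):

```sql
-- Metadata for tables matched by a pattern.
SHOW TABLE EXTENDED IN salesdb LIKE 'employ*';

-- File-system details for one partition; a pattern cannot be combined with a partition spec.
SHOW TABLE EXTENDED IN salesdb LIKE 'employee' PARTITION (grade = 1);
```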
+ for partitions. Note that a table regex cannot be used with a partition specification. + + **Syntax:** `PARTITION ( partition_col_name = partition_col_val [ , ... ] )` ### Examples -{% highlight sql %} +```sql -- Assumes `employee` table created with partitioned by column `grade` CREATE TABLE employee(name STRING, grade INT) PARTITIONED BY (grade); INSERT INTO employee PARTITION (grade = 1) VALUES ('sam'); @@ -152,7 +146,7 @@ SHOW TABLE EXTENDED LIKE `employe*`; +--------+---------+----------+---------------------------------------------------------------+ -- show partition file system details -SHOW TABLE EXTENDED IN `default` LIKE `employee` PARTITION (`grade=1`); +SHOW TABLE EXTENDED IN default LIKE `employee` PARTITION (`grade=1`); +--------+---------+-----------+--------------------------------------------------------------+ |database|tableName|isTemporary| information | +--------+---------+-----------+--------------------------------------------------------------+ @@ -175,12 +169,12 @@ SHOW TABLE EXTENDED IN `default` LIKE `employee` PARTITION (`grade=1`); +--------+---------+-----------+--------------------------------------------------------------+ -- show partition file system details with regex fails as shown below -SHOW TABLE EXTENDED IN `default` LIKE `empl*` PARTITION (`grade=1`); - Error: Error running query: org.apache.spark.sql.catalyst.analysis.NoSuchTableException: - Table or view 'emplo*' not found in database 'default'; (state=,code=0) -{% endhighlight %} +SHOW TABLE EXTENDED IN default LIKE `empl*` PARTITION (`grade=1`); +Error: Error running query: org.apache.spark.sql.catalyst.analysis.NoSuchTableException: + Table or view 'emplo*' not found in database 'default'; (state=,code=0) +``` ### Related Statements - * [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) - * [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) +* [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) +* [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) diff --git a/docs/sql-ref-syntax-aux-show-tables.md b/docs/sql-ref-syntax-aux-show-tables.md index 62eb3ddb18b5c..fef9722a444f8 100644 --- a/docs/sql-ref-syntax-aux-show-tables.md +++ b/docs/sql-ref-syntax-aux-show-tables.md @@ -28,33 +28,28 @@ current database. ### Syntax -{% highlight sql %} +```sql SHOW TABLES [ { FROM | IN } database_name ] [ LIKE regex_pattern ] -{% endhighlight %} +``` ### Parameters -
-  { FROM | IN } database_name
+* **{ FROM `|` IN } database_name** + Specifies the database name from which tables are listed. -
-  regex_pattern
+ +* **regex_pattern** + Specifies the regular expression pattern that is used to filter out unwanted tables. -
-    • Except for * and | character, the pattern works like a regular expression.
-    • * alone matches 0 or more characters and | is used to separate multiple different regular expressions, any of which can match.
-    • The leading and trailing blanks are trimmed in the input pattern before processing. The pattern match is case-insensitive.
-
+ + * Except for `*` and `|` character, the pattern works like a regular expression. + * `*` alone matches 0 or more characters and `|` is used to separate multiple different regular expressions, + any of which can match. + * The leading and trailing blanks are trimmed in the input pattern before processing. The pattern match is case-insensitive. ### Examples -{% highlight sql %} +```sql -- List all tables in default database SHOW TABLES; +--------+---------+-----------+ @@ -101,11 +96,11 @@ SHOW TABLES LIKE 'sam*|suj'; | default| sam1| false| | default| suj| false| +--------+---------+-----------+ -{% endhighlight %} +``` ### Related Statements - * [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) - * [DROP TABLE](sql-ref-syntax-ddl-drop-table.html) - * [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) - * [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) +* [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) +* [DROP TABLE](sql-ref-syntax-ddl-drop-table.html) +* [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) +* [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) diff --git a/docs/sql-ref-syntax-aux-show-tblproperties.md b/docs/sql-ref-syntax-aux-show-tblproperties.md index 662aaad069dd9..5b7ddcbcd9534 100644 --- a/docs/sql-ref-syntax-aux-show-tblproperties.md +++ b/docs/sql-ref-syntax-aux-show-tblproperties.md @@ -26,37 +26,30 @@ a property key. If no key is specified then all the properties are returned. ### Syntax -{% highlight sql %} +```sql SHOW TBLPROPERTIES table_identifier [ ( unquoted_property_key | property_key_as_string_literal ) ] -{% endhighlight %} +``` ### Parameters -
-  table_identifier
+* **table_identifier** + Specifies the table name of an existing table. The table may be optionally qualified - with a database name.

-    Syntax: [ database_name. ] table_name
-
-  unquoted_property_key
+ with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **unquoted_property_key** + Specifies the property key in unquoted form. The key may consists of multiple - parts separated by dot.

-    Syntax: [ key_part1 ] [ .key_part2 ] [ ... ]
-
-  property_key_as_string_literal
+ parts separated by dot. + + **Syntax:** `[ key_part1 ] [ .key_part2 ] [ ... ]` + +* **property_key_as_string_literal** + Specifies a property key value as a string literal. -
-
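A short illustration of the two key forms just described, using a made-up table and property:

```sql
-- The same dot-separated key, once unquoted and once as a string literal.
SHOW TBLPROPERTIES customer (created.date);
SHOW TBLPROPERTIES customer ('created.date');
```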
**Note** - Property value returned by this statement excludes some properties @@ -68,7 +61,7 @@ SHOW TBLPROPERTIES table_identifier ### Examples -{% highlight sql %} +```sql -- create a table `customer` in database `salesdb` USE salesdb; CREATE TABLE customer(cust_code INT, name VARCHAR(100), cust_addr STRING) @@ -110,11 +103,11 @@ SHOW TBLPROPERTIES customer ('created.date'); +----------+ |01-01-2001| +----------+ -{% endhighlight %} +``` ### Related Statements - * [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) - * [ALTER TABLE SET TBLPROPERTIES](sql-ref-syntax-ddl-alter-table.html) - * [SHOW TABLES](sql-ref-syntax-aux-show-tables.html) - * [SHOW TABLE EXTENDED](sql-ref-syntax-aux-show-table.html) +* [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) +* [ALTER TABLE SET TBLPROPERTIES](sql-ref-syntax-ddl-alter-table.html) +* [SHOW TABLES](sql-ref-syntax-aux-show-tables.html) +* [SHOW TABLE EXTENDED](sql-ref-syntax-aux-show-table.html) diff --git a/docs/sql-ref-syntax-aux-show-views.md b/docs/sql-ref-syntax-aux-show-views.md index 29ad6caf140f8..5003c092cabce 100644 --- a/docs/sql-ref-syntax-aux-show-views.md +++ b/docs/sql-ref-syntax-aux-show-views.md @@ -29,30 +29,26 @@ list global temporary views. Note that the command also lists local temporary vi regardless of a given database. ### Syntax -{% highlight sql %} +```sql SHOW VIEWS [ { FROM | IN } database_name ] [ LIKE regex_pattern ] -{% endhighlight %} +``` ### Parameters -
-  { FROM | IN } database_name
+* **{ FROM `|` IN } database_name** + Specifies the database name from which views are listed. -
-  regex_pattern
+ +* **regex_pattern** + Specifies the regular expression pattern that is used to filter out unwanted views. -
-    • Except for * and | character, the pattern works like a regular expression.
-    • * alone matches 0 or more characters and | is used to separate multiple different regular expressions, any of which can match.
-    • The leading and trailing blanks are trimmed in the input pattern before processing. The pattern match is case-insensitive.
-
+ + * Except for `*` and `|` character, the pattern works like a regular expression. + * `*` alone matches 0 or more characters and `|` is used to separate multiple different regular expressions, + any of which can match. + * The leading and trailing blanks are trimmed in the input pattern before processing. The pattern match is case-insensitive. ### Examples -{% highlight sql %} +```sql -- Create views in different databases, also create global/local temp views. CREATE VIEW sam AS SELECT id, salary FROM employee WHERE name = 'sam'; CREATE VIEW sam1 AS SELECT id, salary FROM employee WHERE name = 'sam1'; @@ -61,8 +57,8 @@ USE userdb; CREATE VIEW user1 AS SELECT id, salary FROM default.employee WHERE name = 'user1'; CREATE VIEW user2 AS SELECT id, salary FROM default.employee WHERE name = 'user2'; USE default; -CREATE GLOBAL TEMP VIEW temp1 AS SELECT 1 as col1; -CREATE TEMP VIEW temp2 AS SELECT 1 as col1; +CREATE GLOBAL TEMP VIEW temp1 AS SELECT 1 AS col1; +CREATE TEMP VIEW temp2 AS SELECT 1 AS col1; -- List all views in default database SHOW VIEWS; @@ -112,11 +108,10 @@ SHOW VIEWS LIKE 'sam|suj|temp*'; | default | suj | false | | | temp2 | true | +-------------+------------+--------------+ - -{% endhighlight %} +``` ### Related statements -- [CREATE VIEW](sql-ref-syntax-ddl-create-view.html) -- [DROP VIEW](sql-ref-syntax-ddl-drop-view.html) -- [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) -- [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) +* [CREATE VIEW](sql-ref-syntax-ddl-create-view.html) +* [DROP VIEW](sql-ref-syntax-ddl-drop-view.html) +* [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) +* [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) diff --git a/docs/sql-ref-syntax-aux-show.md b/docs/sql-ref-syntax-aux-show.md index 424fe71370897..9f64ea2d50ae1 100644 --- a/docs/sql-ref-syntax-aux-show.md +++ b/docs/sql-ref-syntax-aux-show.md @@ -20,11 +20,11 @@ license: | --- * [SHOW COLUMNS](sql-ref-syntax-aux-show-columns.html) + * [SHOW CREATE TABLE](sql-ref-syntax-aux-show-create-table.html) * [SHOW DATABASES](sql-ref-syntax-aux-show-databases.html) * [SHOW FUNCTIONS](sql-ref-syntax-aux-show-functions.html) + * [SHOW PARTITIONS](sql-ref-syntax-aux-show-partitions.html) * [SHOW TABLE EXTENDED](sql-ref-syntax-aux-show-table.html) * [SHOW TABLES](sql-ref-syntax-aux-show-tables.html) * [SHOW TBLPROPERTIES](sql-ref-syntax-aux-show-tblproperties.html) - * [SHOW PARTITIONS](sql-ref-syntax-aux-show-partitions.html) - * [SHOW CREATE TABLE](sql-ref-syntax-aux-show-create-table.html) * [SHOW VIEWS](sql-ref-syntax-aux-show-views.html) diff --git a/docs/sql-ref-syntax-ddl-alter-database.md b/docs/sql-ref-syntax-ddl-alter-database.md index 2d5860c2ea920..fbc454e25fb0c 100644 --- a/docs/sql-ref-syntax-ddl-alter-database.md +++ b/docs/sql-ref-syntax-ddl-alter-database.md @@ -29,21 +29,20 @@ for a database and may be used for auditing purposes. ### Syntax -{% highlight sql %} +```sql ALTER { DATABASE | SCHEMA } database_name SET DBPROPERTIES ( property_name = property_value [ , ... ] ) -{% endhighlight %} +``` ### Parameters -
-  database_name
-    Specifies the name of the database to be altered.
-
+* **database_name** + + Specifies the name of the database to be altered. ### Examples -{% highlight sql %} +```sql -- Creates a database named `inventory`. CREATE DATABASE inventory; @@ -60,8 +59,8 @@ DESCRIBE DATABASE EXTENDED inventory; | Location| file:/temp/spark-warehouse/inventory.db| | Properties|((Edit-date,01/01/2001), (Edited-by,John))| +-------------------------+------------------------------------------+ -{% endhighlight %} +``` ### Related Statements - * [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) +* [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) diff --git a/docs/sql-ref-syntax-ddl-alter-table.md b/docs/sql-ref-syntax-ddl-alter-table.md index f81585fef3aae..eb0e9a9f9cf73 100644 --- a/docs/sql-ref-syntax-ddl-alter-table.md +++ b/docs/sql-ref-syntax-ddl-alter-table.md @@ -29,35 +29,25 @@ license: | #### Syntax -{% highlight sql %} +```sql ALTER TABLE table_identifier RENAME TO table_identifier ALTER TABLE table_identifier partition_spec RENAME TO partition_spec -{% endhighlight %} +``` #### Parameters -
-  table_identifier
-    Specifies a table name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] table_name
-
-  partition_spec
-    Partition to be renamed.
-    Syntax: PARTITION ( partition_col_name = partition_col_val [ , ... ] )
-
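A minimal sketch of the two RENAME forms above; the table, database, and partition values are placeholders:

```sql
-- Rename a table, then rename one of its partitions.
ALTER TABLE salesdb.orders RENAME TO salesdb.orders_archive;
ALTER TABLE salesdb.orders_archive PARTITION (dt = '2020-01-01') RENAME TO PARTITION (dt = '2020-01-02');
```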
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **partition_spec** + + Partition to be renamed. + + **Syntax:** `PARTITION ( partition_col_name = partition_col_val [ , ... ] )` ### ADD COLUMNS @@ -65,66 +55,47 @@ ALTER TABLE table_identifier partition_spec RENAME TO partition_spec #### Syntax -{% highlight sql %} +```sql ALTER TABLE table_identifier ADD COLUMNS ( col_spec [ , ... ] ) -{% endhighlight %} +``` #### Parameters -
-  table_identifier
-    Specifies a table name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] table_name
-
-  COLUMNS ( col_spec )
-    Specifies the columns to be added.
-
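For context, a hedged sketch of ADD COLUMNS with invented column names:

```sql
-- Each col_spec is a column name, a data type and an optional comment.
ALTER TABLE salesdb.orders ADD COLUMNS (discount DECIMAL(5, 2) COMMENT 'Promotional discount', channel STRING);
```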
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **COLUMNS ( col_spec )** + + Specifies the columns to be added. ### ALTER OR CHANGE COLUMN -`ALTER TABLE ALTER COLUMN` or `ALTER TABLE CHANGE COLUMN` statement changes column's comment. +`ALTER TABLE ALTER COLUMN` or `ALTER TABLE CHANGE COLUMN` statement changes column's definition. #### Syntax -{% highlight sql %} +```sql ALTER TABLE table_identifier { ALTER | CHANGE } [ COLUMN ] col_spec alterColumnAction -{% endhighlight %} +``` #### Parameters -
-  table_identifier
-    Specifies a table name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] table_name
-
-  COLUMN col_spec
-    Specifies the column to be altered or be changed.
-
-  alterColumnAction
-    Change the comment string.
-    Syntax: COMMENT STRING
-
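A brief illustration of the COMMENT action under both spellings; table and column names are assumptions, not part of the patch:

```sql
-- Both ALTER COLUMN and CHANGE COLUMN accept a COMMENT action on an existing column.
ALTER TABLE salesdb.orders ALTER COLUMN discount COMMENT 'Discount applied at checkout';
ALTER TABLE salesdb.orders CHANGE COLUMN channel COMMENT 'Sales channel the order came from';
```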
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **COLUMNS ( col_spec )** + + Specifies the column to be altered or be changed. + +* **alterColumnAction** + + Change column's definition. ### ADD AND DROP PARTITION @@ -134,34 +105,24 @@ ALTER TABLE table_identifier { ALTER | CHANGE } [ COLUMN ] col_spec alterColumnA ##### Syntax -{% highlight sql %} +```sql ALTER TABLE table_identifier ADD [IF NOT EXISTS] ( partition_spec [ partition_spec ... ] ) -{% endhighlight %} +``` ##### Parameters -
-  table_identifier
-    Specifies a table name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] table_name
-
-  partition_spec
-    Partition to be added.
-    Syntax: PARTITION ( partition_col_name = partition_col_val [ , ... ] )
-
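A minimal sketch of ADD PARTITION with made-up partition values:

```sql
-- Several partitions can be added in one statement; IF NOT EXISTS skips ones that already exist.
ALTER TABLE sales ADD IF NOT EXISTS
    PARTITION (dt = '2020-01-01') PARTITION (dt = '2020-01-02');
```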
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **partition_spec** + + Partition to be added. + + **Syntax:** `PARTITION ( partition_col_name = partition_col_val [ , ... ] )` #### DROP PARTITION @@ -169,33 +130,23 @@ ALTER TABLE table_identifier ADD [IF NOT EXISTS] ##### Syntax -{% highlight sql %} +```sql ALTER TABLE table_identifier DROP [ IF EXISTS ] partition_spec [PURGE] -{% endhighlight %} +``` ##### Parameters -
-  table_identifier
-    Specifies a table name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] table_name
-
-  partition_spec
-    Partition to be dropped.
-    Syntax: PARTITION ( partition_col_name = partition_col_val [ , ... ] )
-
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **partition_spec** + + Partition to be dropped. + + **Syntax:** `PARTITION ( partition_col_name = partition_col_val [ , ... ] )` ### SET AND UNSET @@ -208,30 +159,28 @@ this overrides the old value with the new one. ##### Syntax -{% highlight sql %} +```sql -- Set Table Properties ALTER TABLE table_identifier SET TBLPROPERTIES ( key1 = val1, key2 = val2, ... ) -- Unset Table Properties ALTER TABLE table_identifier UNSET TBLPROPERTIES [ IF EXISTS ] ( key1, key2, ... ) -{% endhighlight %} +``` #### SET SERDE -`ALTER TABLE SET` command is used for setting the SERDE or SERDE properties in Hive tables. If a particular property was already set, -this overrides the old value with the new one. +`ALTER TABLE SET` command is used for setting the SERDE or SERDE properties in Hive tables. If a particular property was already set, this overrides the old value with the new one. ##### Syntax -{% highlight sql %} +```sql -- Set SERDE Properties ALTER TABLE table_identifier [ partition_spec ] SET SERDEPROPERTIES ( key1 = val1, key2 = val2, ... ) ALTER TABLE table_identifier [ partition_spec ] SET SERDE serde_class_name [ WITH SERDEPROPERTIES ( key1 = val1, key2 = val2, ... ) ] - -{% endhighlight %} +``` #### SET LOCATION And SET FILE FORMAT @@ -240,46 +189,35 @@ existing tables. ##### Syntax -{% highlight sql %} +```sql -- Changing File Format ALTER TABLE table_identifier [ partition_spec ] SET FILEFORMAT file_format -- Changing File Location ALTER TABLE table_identifier [ partition_spec ] SET LOCATION 'new_location' -{% endhighlight %} +``` #### Parameters -
-  table_identifier
-    Specifies a table name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] table_name
-
-  partition_spec
-    Specifies the partition on which the property has to be set.
-    Syntax: PARTITION ( partition_col_name = partition_col_val [ , ... ] )
-
-  SERDEPROPERTIES ( key1 = val1, key2 = val2, ... )
-    Specifies the SERDE properties to be set.
-
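As an editorial sketch of the partition-scoped SET variants described above (assuming a hypothetical Hive table `sales` partitioned by `dt`):

```sql
-- Set SERDE properties on one partition, then move another partition's data location.
ALTER TABLE sales PARTITION (dt = '2020-01-01') SET SERDEPROPERTIES ('field.delim' = ',');
ALTER TABLE sales PARTITION (dt = '2020-01-02') SET LOCATION '/data/sales/dt=2020-01-02';
```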
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **partition_spec** + + Specifies the partition on which the property has to be set. + + **Syntax:** `PARTITION ( partition_col_name = partition_col_val [ , ... ] )` + +* **SERDEPROPERTIES ( key1 = val1, key2 = val2, ... )** + + Specifies the SERDE properties to be set. ### Examples -{% highlight sql %} +```sql -- RENAME table DESC student; +-----------------------+---------+-------+ @@ -477,13 +415,19 @@ ALTER TABLE test_tab SET SERDE 'org.apache.hadoop.hive.serde2.columnar.LazyBinar ALTER TABLE dbx.tab1 SET SERDE 'org.apache.hadoop' WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee') -- SET TABLE PROPERTIES -ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('winner' = 'loser') +ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('winner' = 'loser'); + +-- SET TABLE COMMENT Using SET PROPERTIES +ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('comment' = 'A table comment.'); + +-- Alter TABLE COMMENT Using SET PROPERTIES +ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('comment' = 'This is a new comment.'); -- DROP TABLE PROPERTIES -ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('winner') -{% endhighlight %} +ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('winner'); +``` ### Related Statements - * [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) - * [DROP TABLE](sql-ref-syntax-ddl-drop-table.html) +* [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) +* [DROP TABLE](sql-ref-syntax-ddl-drop-table.html) diff --git a/docs/sql-ref-syntax-ddl-alter-view.md b/docs/sql-ref-syntax-ddl-alter-view.md index c2887692949ea..a34e77decf593 100644 --- a/docs/sql-ref-syntax-ddl-alter-view.md +++ b/docs/sql-ref-syntax-ddl-alter-view.md @@ -29,21 +29,16 @@ Renames the existing view. If the new view name already exists in the source dat does not support moving the views across databases. #### Syntax -{% highlight sql %} +```sql ALTER VIEW view_identifier RENAME TO view_identifier -{% endhighlight %} +``` #### Parameters -
-  view_identifier
-    Specifies a view name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] view_name
-
+* **view_identifier** + + Specifies a view name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] view_name` #### SET View Properties Set one or more properties of an existing view. The properties are the key value pairs. If the properties' keys exist, @@ -51,89 +46,70 @@ the values are replaced with the new values. If the properties' keys do not exis the properties. #### Syntax -{% highlight sql %} +```sql ALTER VIEW view_identifier SET TBLPROPERTIES ( property_key = property_val [ , ... ] ) -{% endhighlight %} +``` #### Parameters -
-  view_identifier
-    Specifies a view name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] view_name
-
-  property_key
-    Specifies the property key. The key may consists of multiple parts separated by dot.
-    Syntax: [ key_part1 ] [ .key_part2 ] [ ... ]
-
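A short illustrative statement for the multi-part property keys just described; the view and values are invented:

```sql
-- Property keys may have dot-separated parts; existing keys are overwritten.
ALTER VIEW salesdb.daily_totals SET TBLPROPERTIES ('created.by.user' = 'anne', 'purpose' = 'reporting');
```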
+* **view_identifier** + + Specifies a view name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] view_name` + +* **property_key** + + Specifies the property key. The key may consists of multiple parts separated by dot. + + **Syntax:** `[ key_part1 ] [ .key_part2 ] [ ... ]` #### UNSET View Properties Drop one or more properties of an existing view. If the specified keys do not exist, an exception is thrown. Use `IF EXISTS` to avoid the exception. #### Syntax -{% highlight sql %} +```sql ALTER VIEW view_identifier UNSET TBLPROPERTIES [ IF EXISTS ] ( property_key [ , ... ] ) -{% endhighlight %} +``` #### Parameters -
-  view_identifier
-    Specifies a view name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] view_name
-
-  property_key
-    Specifies the property key. The key may consists of multiple parts separated by dot.
-    Syntax: [ key_part1 ] [ .key_part2 ] [ ... ]
-
+* **view_identifier** + + Specifies a view name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] view_name` + +* **property_key** + + Specifies the property key. The key may consists of multiple parts separated by dot. + + **Syntax:** `[ key_part1 ] [ .key_part2 ] [ ... ]` #### ALTER View AS SELECT `ALTER VIEW view_identifier AS SELECT` statement changes the definition of a view. The `SELECT` statement must be valid, and the `view_identifier` must exist. #### Syntax -{% highlight sql %} +```sql ALTER VIEW view_identifier AS select_statement -{% endhighlight %} +``` Note that `ALTER VIEW` statement does not support `SET SERDE` or `SET SERDEPROPERTIES` properties. #### Parameters -
-  view_identifier
-    Specifies a view name, which may be optionally qualified with a database name.
-    Syntax: [ database_name. ] view_name
-
-  select_statement
-    Specifies the definition of the view. Check select_statement for details.
-
+* **view_identifier** + + Specifies a view name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] view_name` + +* **select_statement** + + Specifies the definition of the view. Check [select_statement](sql-ref-syntax-qry-select.html) for details. ### Examples -{% highlight sql %} +```sql -- Rename only changes the view name. -- The source and target databases of the view have to be the same. -- Use qualified or unqualified name for the source and target view. @@ -218,11 +194,11 @@ DESC TABLE EXTENDED tempdb1.v2; | View Text| select * from tempdb1.v1| | | View Original Text| select * from tempdb1.v1| | +----------------------------+---------------------------+-------+ -{% endhighlight %} +``` ### Related Statements - * [describe-table](sql-ref-syntax-aux-describe-table.html) - * [create-view](sql-ref-syntax-ddl-create-view.html) - * [drop-view](sql-ref-syntax-ddl-drop-view.html) - * [show-views](sql-ref-syntax-aux-show-views.html) +* [describe-table](sql-ref-syntax-aux-describe-table.html) +* [create-view](sql-ref-syntax-ddl-create-view.html) +* [drop-view](sql-ref-syntax-ddl-drop-view.html) +* [show-views](sql-ref-syntax-aux-show-views.html) diff --git a/docs/sql-ref-syntax-ddl-create-database.md b/docs/sql-ref-syntax-ddl-create-database.md index 0ef0dfbdaed2b..9d8bf47844724 100644 --- a/docs/sql-ref-syntax-ddl-create-database.md +++ b/docs/sql-ref-syntax-ddl-create-database.md @@ -25,35 +25,38 @@ Creates a database with the specified name. If database with the same name alrea ### Syntax -{% highlight sql %} +```sql CREATE { DATABASE | SCHEMA } [ IF NOT EXISTS ] database_name [ COMMENT database_comment ] [ LOCATION database_directory ] [ WITH DBPROPERTIES ( property_name = property_value [ , ... ] ) ] -{% endhighlight %} +``` ### Parameters -
-  database_name
-    Specifies the name of the database to be created.
+* **database_name** -
-  IF NOT EXISTS
-    Creates a database with the given name if it doesn't exists. If a database with the same name already exists, nothing will happen.
+ Specifies the name of the database to be created. -
-  database_directory
-    Path of the file system in which the specified database is to be created. If the specified path does not exist in the underlying file system, this command creates a directory with the path. If the location is not specified, the database will be created in the default warehouse directory, whose path is configured by the static configuration spark.sql.warehouse.dir.
+* **IF NOT EXISTS** -
-  database_comment
-    Specifies the description for the database.
+ Creates a database with the given name if it does not exist. If a database with the same name already exists, nothing will happen. -
-  WITH DBPROPERTIES ( property_name=property_value [ , ... ] )
-    Specifies the properties for the database in key-value pairs.
-
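To keep the optional clauses above concrete, a hedged sketch; the path, comment, and properties are placeholders:

```sql
-- All optional clauses together.
CREATE DATABASE IF NOT EXISTS customer_db
    COMMENT 'Customer-facing datasets'
    LOCATION '/warehouse/customer_db.db'
    WITH DBPROPERTIES ('owner' = 'data-eng', 'purpose' = 'reporting');
```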
+* **database_directory** + + Path of the file system in which the specified database is to be created. If the specified path does not exist in the underlying file system, this command creates a directory with the path. If the location is not specified, the database will be created in the default warehouse directory, whose path is configured by the static configuration spark.sql.warehouse.dir. + +* **database_comment** + + Specifies the description for the database. + +* **WITH DBPROPERTIES ( property_name=property_value [ , ... ] )** + + Specifies the properties for the database in key-value pairs. ### Examples -{% highlight sql %} +```sql -- Create database `customer_db`. This throws exception if database with name customer_db -- already exists. CREATE DATABASE customer_db; @@ -76,9 +79,9 @@ DESCRIBE DATABASE EXTENDED customer_db; | Location| hdfs://hacluster/user| | Properties| ((ID,001), (Name,John))| +-------------------------+--------------------------+ -{% endhighlight %} +``` ### Related Statements - * [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) - * [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) +* [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) +* [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) diff --git a/docs/sql-ref-syntax-ddl-create-function.md b/docs/sql-ref-syntax-ddl-create-function.md index e3f21f70f7c18..aa6c1fad7b56b 100644 --- a/docs/sql-ref-syntax-ddl-create-function.md +++ b/docs/sql-ref-syntax-ddl-create-function.md @@ -33,68 +33,59 @@ aggregate functions using Scala, Python and Java APIs. Please refer to ### Syntax -{% highlight sql %} +```sql CREATE [ OR REPLACE ] [ TEMPORARY ] FUNCTION [ IF NOT EXISTS ] function_name AS class_name [ resource_locations ] -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>OR REPLACE</em></code></dt>
-  <dd>If specified, the resources for the function are reloaded. This is mainly useful to pick up any changes made to the implementation of the function. This parameter is mutually exclusive to IF NOT EXISTS and can not be specified together.</dd>
-  <dt><code><em>TEMPORARY</em></code></dt>
-  <dd>Indicates the scope of function being created. When TEMPORARY is specified, the created function is valid and visible in the current session. No persistent entry is made in the catalog for these kind of functions.</dd>
-  <dt><code><em>IF NOT EXISTS</em></code></dt>
-  <dd>If specified, creates the function only when it does not exist. The creation of function succeeds (no error is thrown) if the specified function already exists in the system. This parameter is mutually exclusive to OR REPLACE and can not be specified together.</dd>
-  <dt><code><em>function_name</em></code></dt>
-  <dd>Specifies a name of function to be created. The function name may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] function_name</code></dd>
-  <dt><code><em>class_name</em></code></dt>
-  <dd>Specifies the name of the class that provides the implementation for function to be created. The implementing class should extend one of the base classes as follows:
-  <ul>
-    <li>Should extend UDF or UDAF in org.apache.hadoop.hive.ql.exec package.</li>
-    <li>Should extend AbstractGenericUDAFResolver, GenericUDF, or GenericUDTF in org.apache.hadoop.hive.ql.udf.generic package.</li>
-    <li>Should extend UserDefinedAggregateFunction in org.apache.spark.sql.expressions package.</li>
-  </ul></dd>
-  <dt><code><em>resource_locations</em></code></dt>
-  <dd>Specifies the list of resources that contain the implementation of the function along with its dependencies.<br><br>Syntax: <code>USING { { (JAR | FILE ) resource_uri } , ... }</code></dd>
-</dl>
+* **OR REPLACE**
+
+  If specified, the resources for the function are reloaded. This is mainly useful
+  to pick up any changes made to the implementation of the function. This
+  parameter is mutually exclusive to `IF NOT EXISTS` and cannot be specified together.
+
+* **TEMPORARY**
+
+  Indicates the scope of the function being created. When `TEMPORARY` is specified, the
+  created function is valid and visible only in the current session. No persistent
+  entry is made in the catalog for this kind of function.
+
+* **IF NOT EXISTS**
+
+  If specified, creates the function only when it does not exist. The creation of the
+  function succeeds (no error is thrown) if the specified function already exists in
+  the system. This parameter is mutually exclusive to `OR REPLACE` and cannot be specified together.
+ along with its dependencies. + + **Syntax:** `USING { { (JAR | FILE ) resource_uri } , ... }` ### Examples -{% highlight sql %} +```sql -- 1. Create a simple UDF `SimpleUdf` that increments the supplied integral value by 10. -- import org.apache.hadoop.hive.ql.exec.UDF; -- public class SimpleUdf extends UDF { @@ -166,10 +157,10 @@ SELECT simple_udf(c1) AS function_return_value FROM t1; | 21| | 22| +---------------------+ -{% endhighlight %} +``` ### Related Statements - * [SHOW FUNCTIONS](sql-ref-syntax-aux-show-functions.html) - * [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) - * [DROP FUNCTION](sql-ref-syntax-ddl-drop-function.html) +* [SHOW FUNCTIONS](sql-ref-syntax-aux-show-functions.html) +* [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) +* [DROP FUNCTION](sql-ref-syntax-ddl-drop-function.html) diff --git a/docs/sql-ref-syntax-ddl-create-table-datasource.md b/docs/sql-ref-syntax-ddl-create-table-datasource.md index 54827fd63568d..d334447a91011 100644 --- a/docs/sql-ref-syntax-ddl-create-table-datasource.md +++ b/docs/sql-ref-syntax-ddl-create-table-datasource.md @@ -25,10 +25,10 @@ The `CREATE TABLE` statement defines a new table using a Data Source. ### Syntax -{% highlight sql %} +```sql CREATE TABLE [ IF NOT EXISTS ] table_identifier [ ( col_name1 col_type1 [ COMMENT col_comment1 ], ... ) ] - [ USING data_source ] + USING data_source [ OPTIONS ( key1=val1, key2=val2, ... ) ] [ PARTITIONED BY ( col_name1, col_name2, ... ) ] [ CLUSTERED BY ( col_name3, col_name4, ... ) @@ -38,62 +38,52 @@ CREATE TABLE [ IF NOT EXISTS ] table_identifier [ COMMENT table_comment ] [ TBLPROPERTIES ( key1=val1, key2=val2, ... ) ] [ AS select_statement ] -{% endhighlight %} +``` Note that, the clauses between the USING clause and the AS SELECT clause can come in as any order. For example, you can write COMMENT table_comment after TBLPROPERTIES. ### Parameters -
-<dl>
-  <dt><code><em>table_identifier</em></code></dt>
-  <dd>Specifies a table name, which may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] table_name</code></dd>
-  <dt><code><em>USING data_source</em></code></dt>
-  <dd>Data Source is the input format used to create the table. Data source can be CSV, TXT, ORC, JDBC, PARQUET, etc.</dd>
-  <dt><code><em>PARTITIONED BY</em></code></dt>
-  <dd>Partitions are created on the table, based on the columns specified.</dd>
-  <dt><code><em>CLUSTERED BY</em></code></dt>
-  <dd>Partitions created on the table will be bucketed into fixed buckets based on the column specified for bucketing.<br><br>NOTE: Bucketing is an optimization technique that uses buckets (and bucketing columns) to determine data partitioning and avoid data shuffle.</dd>
-  <dt><code><em>SORTED BY</em></code></dt>
-  <dd>Determines the order in which the data is stored in buckets. Default is Ascending order.</dd>
-  <dt><code><em>LOCATION</em></code></dt>
-  <dd>Path to the directory where table data is stored, which could be a path on distributed storage like HDFS, etc.</dd>
-  <dt><code><em>COMMENT</em></code></dt>
-  <dd>A string literal to describe the table.</dd>
-  <dt><code><em>TBLPROPERTIES</em></code></dt>
-  <dd>A list of key-value pairs that is used to tag the table definition.</dd>
-  <dt><code><em>AS select_statement</em></code></dt>
-  <dd>The table is populated using the data from the select statement.</dd>
-</dl>
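To make the interaction of these clauses concrete, here is a rough, illustrative sketch; the location, property name and data are made up for this example and are not part of the patch itself:

```sql
-- Columns, partitioning, bucketing, a comment and a table property combined.
CREATE TABLE IF NOT EXISTS student (id INT, name STRING, age INT)
    USING PARQUET
    PARTITIONED BY (age)
    CLUSTERED BY (id) SORTED BY (id) INTO 4 BUCKETS
    COMMENT 'Student records'
    TBLPROPERTIES ('created.by.user' = 'some_user')
    LOCATION '/tmp/warehouse/student';

-- CTAS form: the schema comes from the SELECT, so no column list is given.
CREATE TABLE student_copy USING PARQUET AS SELECT * FROM student;
```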
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **USING data_source** + + Data Source is the input format used to create the table. Data source can be CSV, TXT, ORC, JDBC, PARQUET, etc. + +* **PARTITIONED BY** + + Partitions are created on the table, based on the columns specified. + +* **CLUSTERED BY** + + Partitions created on the table will be bucketed into fixed buckets based on the column specified for bucketing. + + **NOTE:** Bucketing is an optimization technique that uses buckets (and bucketing columns) to determine data partitioning and avoid data shuffle. + +* **SORTED BY** + + Determines the order in which the data is stored in buckets. Default is Ascending order. + +* **LOCATION** + + Path to the directory where table data is stored, which could be a path on distributed storage like HDFS, etc. + +* **COMMENT** + + A string literal to describe the table. + +* **TBLPROPERTIES** + + A list of key-value pairs that is used to tag the table definition. + +* **AS select_statement** + + The table is populated using the data from the select statement. ### Data Source Interaction @@ -110,7 +100,7 @@ input query, to make sure the table gets created contains exactly the same data ### Examples -{% highlight sql %} +```sql --Use data source CREATE TABLE student (id INT, name STRING, age INT) USING CSV; @@ -137,9 +127,9 @@ CREATE TABLE student (id INT, name STRING, age INT) USING CSV PARTITIONED BY (age) CLUSTERED BY (Id) INTO 4 buckets; -{% endhighlight %} +``` ### Related Statements - * [CREATE TABLE USING HIVE FORMAT](sql-ref-syntax-ddl-create-table-hiveformat.html) - * [CREATE TABLE LIKE](sql-ref-syntax-ddl-create-table-like.html) +* [CREATE TABLE USING HIVE FORMAT](sql-ref-syntax-ddl-create-table-hiveformat.html) +* [CREATE TABLE LIKE](sql-ref-syntax-ddl-create-table-like.html) diff --git a/docs/sql-ref-syntax-ddl-create-table-hiveformat.md b/docs/sql-ref-syntax-ddl-create-table-hiveformat.md index 06f353ad2f103..38f8856a24e3d 100644 --- a/docs/sql-ref-syntax-ddl-create-table-hiveformat.md +++ b/docs/sql-ref-syntax-ddl-create-table-hiveformat.md @@ -25,7 +25,7 @@ The `CREATE TABLE` statement defines a new table using Hive format. ### Syntax -{% highlight sql %} +```sql CREATE [ EXTERNAL ] TABLE [ IF NOT EXISTS ] table_identifier [ ( col_name1[:] col_type1 [ COMMENT col_comment1 ], ... ) ] [ COMMENT table_comment ] @@ -36,67 +36,54 @@ CREATE [ EXTERNAL ] TABLE [ IF NOT EXISTS ] table_identifier [ LOCATION path ] [ TBLPROPERTIES ( key1=val1, key2=val2, ... ) ] [ AS select_statement ] -{% endhighlight %} +``` Note that, the clauses between the columns definition clause and the AS SELECT clause can come in as any order. For example, you can write COMMENT table_comment after TBLPROPERTIES. ### Parameters -
-<dl>
-  <dt><code><em>table_identifier</em></code></dt>
-  <dd>Specifies a table name, which may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] table_name</code></dd>
-  <dt><code><em>EXTERNAL</em></code></dt>
-  <dd>Table is defined using the path provided as LOCATION, does not use default location for this table.</dd>
-  <dt><code><em>PARTITIONED BY</em></code></dt>
-  <dd>Partitions are created on the table, based on the columns specified.</dd>
-  <dt><code><em>ROW FORMAT</em></code></dt>
-  <dd>SERDE is used to specify a custom SerDe or the DELIMITED clause in order to use the native SerDe.</dd>
-  <dt><code><em>STORED AS</em></code></dt>
-  <dd>File format for table storage, could be TEXTFILE, ORC, PARQUET, etc.</dd>
-  <dt><code><em>LOCATION</em></code></dt>
-  <dd>Path to the directory where table data is stored, which could be a path on distributed storage like HDFS, etc.</dd>
-  <dt><code><em>COMMENT</em></code></dt>
-  <dd>A string literal to describe the table.</dd>
-  <dt><code><em>TBLPROPERTIES</em></code></dt>
-  <dd>A list of key-value pairs that is used to tag the table definition.</dd>
-  <dt><code><em>AS select_statement</em></code></dt>
-  <dd>The table is populated using the data from the select statement.</dd>
-</dl>
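As an illustrative sketch only (the location is hypothetical, and Hive support must be enabled), an external, partitioned table in Hive format could combine the clauses above like this:

```sql
-- The partition column is declared in PARTITIONED BY, not in the column list.
CREATE EXTERNAL TABLE IF NOT EXISTS family (id INT, name STRING)
    COMMENT 'Family members'
    PARTITIONED BY (country STRING)
    ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
    STORED AS TEXTFILE
    LOCATION '/tmp/warehouse/family';
```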
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **EXTERNAL** + + Table is defined using the path provided as LOCATION, does not use default location for this table. + +* **PARTITIONED BY** + + Partitions are created on the table, based on the columns specified. + +* **ROW FORMAT** + + SERDE is used to specify a custom SerDe or the DELIMITED clause in order to use the native SerDe. + +* **STORED AS** + + File format for table storage, could be TEXTFILE, ORC, PARQUET, etc. + +* **LOCATION** + + Path to the directory where table data is stored, which could be a path on distributed storage like HDFS, etc. + +* **COMMENT** + + A string literal to describe the table. + +* **TBLPROPERTIES** + + A list of key-value pairs that is used to tag the table definition. + +* **AS select_statement** + + The table is populated using the data from the select statement. ### Examples -{% highlight sql %} +```sql --Use hive format CREATE TABLE student (id INT, name STRING, age INT) STORED AS ORC; @@ -130,9 +117,9 @@ CREATE TABLE student (id INT, name STRING) CREATE TABLE student (id INT,name STRING) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE; -{% endhighlight %} +``` ### Related Statements - * [CREATE TABLE USING DATASOURCE](sql-ref-syntax-ddl-create-table-datasource.html) - * [CREATE TABLE LIKE](sql-ref-syntax-ddl-create-table-like.html) +* [CREATE TABLE USING DATASOURCE](sql-ref-syntax-ddl-create-table-datasource.html) +* [CREATE TABLE LIKE](sql-ref-syntax-ddl-create-table-like.html) diff --git a/docs/sql-ref-syntax-ddl-create-table-like.md b/docs/sql-ref-syntax-ddl-create-table-like.md index fe1dc4b1ef258..cfb959ca6b23d 100644 --- a/docs/sql-ref-syntax-ddl-create-table-like.md +++ b/docs/sql-ref-syntax-ddl-create-table-like.md @@ -25,57 +25,46 @@ The `CREATE TABLE` statement defines a new table using the definition/metadata o ### Syntax -{% highlight sql %} +```sql CREATE TABLE [IF NOT EXISTS] table_identifier LIKE source_table_identifier USING data_source [ ROW FORMAT row_format ] [ STORED AS file_format ] [ TBLPROPERTIES ( key1=val1, key2=val2, ... ) ] [ LOCATION path ] -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>table_identifier</em></code></dt>
-  <dd>Specifies a table name, which may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] table_name</code></dd>
-  <dt><code><em>USING data_source</em></code></dt>
-  <dd>Data Source is the input format used to create the table. Data source can be CSV, TXT, ORC, JDBC, PARQUET, etc.</dd>
-  <dt><code><em>ROW FORMAT</em></code></dt>
-  <dd>SERDE is used to specify a custom SerDe or the DELIMITED clause in order to use the native SerDe.</dd>
-  <dt><code><em>STORED AS</em></code></dt>
-  <dd>File format for table storage, could be TEXTFILE, ORC, PARQUET, etc.</dd>
-  <dt><code><em>TBLPROPERTIES</em></code></dt>
-  <dd>Table properties that have to be set are specified, such as created.by.user, owner, etc.</dd>
-  <dt><code><em>LOCATION</em></code></dt>
-  <dd>Path to the directory where table data is stored, which could be a path on distributed storage like HDFS, etc. Location to create an external table.</dd>
-</dl>
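For illustration, assuming a source table named `Student` as in the examples further down (the location and property value are made up), the optional clauses can be combined as follows:

```sql
-- Copy only the definition of `Student`; the new table starts out empty.
CREATE TABLE IF NOT EXISTS Student_Archive LIKE Student
    USING PARQUET
    TBLPROPERTIES ('created.by.user' = 'some_user')
    LOCATION '/tmp/warehouse/student_archive';
```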
+* **table_identifier** + + Specifies a table name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` + +* **USING data_source** + + Data Source is the input format used to create the table. Data source can be CSV, TXT, ORC, JDBC, PARQUET, etc. + +* **ROW FORMAT** + + SERDE is used to specify a custom SerDe or the DELIMITED clause in order to use the native SerDe. + +* **STORED AS** + + File format for table storage, could be TEXTFILE, ORC, PARQUET, etc. + +* **TBLPROPERTIES** + + Table properties that have to be set are specified, such as `created.by.user`, `owner`, etc. + +* **LOCATION** + + Path to the directory where table data is stored, which could be a path on distributed storage like HDFS, etc. Location to create an external table. ### Examples -{% highlight sql %} +```sql -- Create table using an existing table CREATE TABLE Student_Dupli like Student; @@ -90,10 +79,10 @@ CREATE TABLE Student_Dupli like Student ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE TBLPROPERTIES ('owner'='xxxx'); -{% endhighlight %} +``` ### Related Statements - * [CREATE TABLE USING DATASOURCE](sql-ref-syntax-ddl-create-table-datasource.html) - * [CREATE TABLE USING HIVE FORMAT](sql-ref-syntax-ddl-create-table-hiveformat.html) +* [CREATE TABLE USING DATASOURCE](sql-ref-syntax-ddl-create-table-datasource.html) +* [CREATE TABLE USING HIVE FORMAT](sql-ref-syntax-ddl-create-table-hiveformat.html) diff --git a/docs/sql-ref-syntax-ddl-create-table.md b/docs/sql-ref-syntax-ddl-create-table.md index b0388adbc9a38..85dc2020e6585 100644 --- a/docs/sql-ref-syntax-ddl-create-table.md +++ b/docs/sql-ref-syntax-ddl-create-table.md @@ -25,11 +25,11 @@ license: | The CREATE statements: - * [CREATE TABLE USING DATA_SOURCE](sql-ref-syntax-ddl-create-table-datasource.html) - * [CREATE TABLE USING HIVE FORMAT](sql-ref-syntax-ddl-create-table-hiveformat.html) - * [CREATE TABLE LIKE](sql-ref-syntax-ddl-create-table-like.html) +* [CREATE TABLE USING DATA_SOURCE](sql-ref-syntax-ddl-create-table-datasource.html) +* [CREATE TABLE USING HIVE FORMAT](sql-ref-syntax-ddl-create-table-hiveformat.html) +* [CREATE TABLE LIKE](sql-ref-syntax-ddl-create-table-like.html) ### Related Statements - * [ALTER TABLE](sql-ref-syntax-ddl-alter-table.html) - * [DROP TABLE](sql-ref-syntax-ddl-drop-table.html) +* [ALTER TABLE](sql-ref-syntax-ddl-alter-table.html) +* [DROP TABLE](sql-ref-syntax-ddl-drop-table.html) diff --git a/docs/sql-ref-syntax-ddl-create-view.md b/docs/sql-ref-syntax-ddl-create-view.md index ba8c1df1223a3..1a9c1f62728e7 100644 --- a/docs/sql-ref-syntax-ddl-create-view.md +++ b/docs/sql-ref-syntax-ddl-create-view.md @@ -27,55 +27,47 @@ a virtual table that has no physical data therefore other operations like ### Syntax -{% highlight sql %} +```sql CREATE [ OR REPLACE ] [ [ GLOBAL ] TEMPORARY ] VIEW [ IF NOT EXISTS ] view_identifier create_view_clauses AS query -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>OR REPLACE</em></code></dt>
-  <dd>If a view of same name already exists, it will be replaced.</dd>
-  <dt><code><em>[ GLOBAL ] TEMPORARY</em></code></dt>
-  <dd>TEMPORARY views are session-scoped and will be dropped when session ends because it skips persisting the definition in the underlying metastore, if any. GLOBAL TEMPORARY views are tied to a system preserved temporary database global_temp.</dd>
-  <dt><code><em>IF NOT EXISTS</em></code></dt>
-  <dd>Creates a view if it does not exists.</dd>
-  <dt><code><em>view_identifier</em></code></dt>
-  <dd>Specifies a view name, which may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] view_name</code></dd>
-  <dt><code><em>create_view_clauses</em></code></dt>
-  <dd>These clauses are optional and order insensitive. It can be of following formats.
-  <ul>
-    <li>[ ( column_name [ COMMENT column_comment ], ... ) ] to specify column-level comments.</li>
-    <li>[ COMMENT view_comment ] to specify view-level comments.</li>
-    <li>[ TBLPROPERTIES ( property_name = property_value [ , ... ] ) ] to add metadata key-value pairs.</li>
-  </ul></dd>
-  <dt><code><em>query</em></code></dt>
-  <dd>A SELECT statement that constructs the view from base tables or other views.</dd>
-</dl>
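Since global temporary views are tied to the system database `global_temp`, they are also queried through that qualifier. A minimal sketch, assuming a hypothetical `orders` table with an `order_date` column:

```sql
CREATE GLOBAL TEMPORARY VIEW IF NOT EXISTS recent_orders
    COMMENT 'Orders placed during the last 7 days'
    AS SELECT * FROM orders WHERE order_date > date_sub(current_date(), 7);

-- The view is resolved against the `global_temp` database, not the current one.
SELECT * FROM global_temp.recent_orders;
```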
+* **OR REPLACE** + + If a view of same name already exists, it will be replaced. + +* **[ GLOBAL ] TEMPORARY** + + TEMPORARY views are session-scoped and will be dropped when session ends + because it skips persisting the definition in the underlying metastore, if any. + GLOBAL TEMPORARY views are tied to a system preserved temporary database `global_temp`. + +* **IF NOT EXISTS** + + Creates a view if it does not exist. + +* **view_identifier** + + Specifies a view name, which may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] view_name` + +* **create_view_clauses** + + These clauses are optional and order insensitive. It can be of following formats. + + * `[ ( column_name [ COMMENT column_comment ], ... ) ]` to specify column-level comments. + * `[ COMMENT view_comment ]` to specify view-level comments. + * `[ TBLPROPERTIES ( property_name = property_value [ , ... ] ) ]` to add metadata key-value pairs. + +* **query** + A [SELECT](sql-ref-syntax-qry-select.html) statement that constructs the view from base tables or other views. ### Examples -{% highlight sql %} +```sql -- Create or replace view for `experienced_employee` with comments. CREATE OR REPLACE VIEW experienced_employee (ID COMMENT 'Unique identification number', Name) @@ -88,10 +80,10 @@ CREATE GLOBAL TEMPORARY VIEW IF NOT EXISTS subscribed_movies AS SELECT mo.member_id, mb.full_name, mo.movie_title FROM movies AS mo INNER JOIN members AS mb ON mo.member_id = mb.id; -{% endhighlight %} +``` ### Related Statements - * [ALTER VIEW](sql-ref-syntax-ddl-alter-view.html) - * [DROP VIEW](sql-ref-syntax-ddl-drop-view.html) - * [SHOW VIEWS](sql-ref-syntax-aux-show-views.html) +* [ALTER VIEW](sql-ref-syntax-ddl-alter-view.html) +* [DROP VIEW](sql-ref-syntax-ddl-drop-view.html) +* [SHOW VIEWS](sql-ref-syntax-aux-show-views.html) diff --git a/docs/sql-ref-syntax-ddl-drop-database.md b/docs/sql-ref-syntax-ddl-drop-database.md index 7467e7a4ad6e7..4a3bc0c68b6d4 100644 --- a/docs/sql-ref-syntax-ddl-drop-database.md +++ b/docs/sql-ref-syntax-ddl-drop-database.md @@ -26,35 +26,31 @@ exception will be thrown if the database does not exist in the system. ### Syntax -{% highlight sql %} +```sql DROP { DATABASE | SCHEMA } [ IF EXISTS ] dbname [ RESTRICT | CASCADE ] -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>DATABASE | SCHEMA</em></code></dt>
-  <dd>DATABASE and SCHEMA mean the same thing, either of them can be used.</dd>
-  <dt><code><em>IF EXISTS</em></code></dt>
-  <dd>If specified, no exception is thrown when the database does not exist.</dd>
-  <dt><code><em>RESTRICT</em></code></dt>
-  <dd>If specified, will restrict dropping a non-empty database and is enabled by default.</dd>
-  <dt><code><em>CASCADE</em></code></dt>
-  <dd>If specified, will drop all the associated tables and functions.</dd>
-</dl>
+* **DATABASE `|` SCHEMA**
+
+  `DATABASE` and `SCHEMA` mean the same thing, either of them can be used.
+
+* **IF EXISTS**
+
+ If specified, no exception is thrown when the database does not exist. + +* **RESTRICT** + + If specified, will restrict dropping a non-empty database and is enabled by default. + +* **CASCADE** + + If specified, will drop all the associated tables and functions. ### Examples -{% highlight sql %} +```sql -- Create `inventory_db` Database CREATE DATABASE inventory_db COMMENT 'This database is used to maintain Inventory'; @@ -63,10 +59,10 @@ DROP DATABASE inventory_db CASCADE; -- Drop the database using IF EXISTS DROP DATABASE IF EXISTS inventory_db CASCADE; -{% endhighlight %} +``` ### Related Statements - * [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) - * [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) - * [SHOW DATABASES](sql-ref-syntax-aux-show-databases.html) +* [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) +* [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html) +* [SHOW DATABASES](sql-ref-syntax-aux-show-databases.html) diff --git a/docs/sql-ref-syntax-ddl-drop-function.md b/docs/sql-ref-syntax-ddl-drop-function.md index 66a405c24e413..bef31d74afcff 100644 --- a/docs/sql-ref-syntax-ddl-drop-function.md +++ b/docs/sql-ref-syntax-ddl-drop-function.md @@ -26,39 +26,32 @@ be thrown if the function does not exist. ### Syntax -{% highlight sql %} +```sql DROP [ TEMPORARY ] FUNCTION [ IF EXISTS ] function_name -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>function_name</em></code></dt>
-  <dd>Specifies the name of an existing function. The function name may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] function_name</code></dd>
-  <dt><code><em>TEMPORARY</em></code></dt>
-  <dd>Should be used to delete the TEMPORARY function.</dd>
-  <dt><code><em>IF EXISTS</em></code></dt>
-  <dd>If specified, no exception is thrown when the function does not exist.</dd>
-</dl>
+* **function_name**
+
+  Specifies the name of an existing function. The function name may be
+  optionally qualified with a database name.
+
+  **Syntax:** `[ database_name. ] function_name`
+
+* **TEMPORARY**
+
+  Should be used to delete the `TEMPORARY` function.
+
+* **IF EXISTS**
+
+ optionally qualified with a database name. + + **Syntax:** `[ database_name. ] function_name` + +* **TEMPORARY** + + Should be used to delete the `TEMPORARY` function. + +* **IF EXISTS** + + If specified, no exception is thrown when the function does not exist. ### Examples -{% highlight sql %} +```sql -- Create a permanent function `test_avg` -CREATE FUNCTION test_avg as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'; +CREATE FUNCTION test_avg AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'; -- List user functions SHOW USER FUNCTIONS; @@ -86,9 +79,9 @@ DROP FUNCTION test_avg; -- Try to drop Permanent function which is not present DROP FUNCTION test_avg; - Error: Error running query: - org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException: - Function 'default.test_avg' not found in database 'default'; (state=,code=0) +Error: Error running query: +org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException: +Function 'default.test_avg' not found in database 'default'; (state=,code=0) -- List the functions after dropping, it should list only temporary function SHOW USER FUNCTIONS; @@ -100,10 +93,10 @@ SHOW USER FUNCTIONS; -- Drop Temporary function DROP TEMPORARY FUNCTION IF EXISTS test_avg; -{% endhighlight %} +``` ### Related Statements - * [CREATE FUNCTION](sql-ref-syntax-ddl-create-function.html) - * [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) - * [SHOW FUNCTION](sql-ref-syntax-aux-show-functions.html) +* [CREATE FUNCTION](sql-ref-syntax-ddl-create-function.html) +* [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) +* [SHOW FUNCTION](sql-ref-syntax-aux-show-functions.html) diff --git a/docs/sql-ref-syntax-ddl-drop-table.md b/docs/sql-ref-syntax-ddl-drop-table.md index c943b922ae812..a15a9928f437d 100644 --- a/docs/sql-ref-syntax-ddl-drop-table.md +++ b/docs/sql-ref-syntax-ddl-drop-table.md @@ -28,49 +28,44 @@ In case of an external table, only the associated metadata information is remove ### Syntax -{% highlight sql %} +```sql DROP TABLE [ IF EXISTS ] table_identifier -{% endhighlight %} +``` ### Parameter -
-<dl>
-  <dt><code><em>IF EXISTS</em></code></dt>
-  <dd>If specified, no exception is thrown when the table does not exists.</dd>
-  <dt><code><em>table_identifier</em></code></dt>
-  <dd>Specifies the table name to be dropped. The table name may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] table_name</code></dd>
-</dl>
+* **IF EXISTS** + + If specified, no exception is thrown when the table does not exist. + +* **table_identifier** + + Specifies the table name to be dropped. The table name may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` ### Examples -{% highlight sql %} +```sql -- Assumes a table named `employeetable` exists. DROP TABLE employeetable; -- Assumes a table named `employeetable` exists in the `userdb` database DROP TABLE userdb.employeetable; --- Assumes a table named `employeetable` does not exists. +-- Assumes a table named `employeetable` does not exist. -- Throws exception DROP TABLE employeetable; - Error: org.apache.spark.sql.AnalysisException: Table or view not found: employeetable; - (state=,code=0) +Error: org.apache.spark.sql.AnalysisException: Table or view not found: employeetable; +(state=,code=0) --- Assumes a table named `employeetable` does not exists,Try with IF EXISTS +-- Assumes a table named `employeetable` does not exist,Try with IF EXISTS -- this time it will not throw exception DROP TABLE IF EXISTS employeetable; -{% endhighlight %} +``` ### Related Statements - * [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) - * [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) - * [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) +* [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) +* [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) +* [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) diff --git a/docs/sql-ref-syntax-ddl-drop-view.md b/docs/sql-ref-syntax-ddl-drop-view.md index ad018b5e6fd5c..5b680d7f907e0 100644 --- a/docs/sql-ref-syntax-ddl-drop-view.md +++ b/docs/sql-ref-syntax-ddl-drop-view.md @@ -25,51 +25,46 @@ license: | ### Syntax -{% highlight sql %} +```sql DROP VIEW [ IF EXISTS ] view_identifier -{% endhighlight %} +``` ### Parameter -
-<dl>
-  <dt><code><em>IF EXISTS</em></code></dt>
-  <dd>If specified, no exception is thrown when the view does not exists.</dd>
-  <dt><code><em>view_identifier</em></code></dt>
-  <dd>Specifies the view name to be dropped. The view name may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] view_name</code></dd>
-</dl>
+* **IF EXISTS** + + If specified, no exception is thrown when the view does not exist. + +* **view_identifier** + + Specifies the view name to be dropped. The view name may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] view_name` ### Examples -{% highlight sql %} +```sql -- Assumes a view named `employeeView` exists. DROP VIEW employeeView; -- Assumes a view named `employeeView` exists in the `userdb` database DROP VIEW userdb.employeeView; --- Assumes a view named `employeeView` does not exists. +-- Assumes a view named `employeeView` does not exist. -- Throws exception DROP VIEW employeeView; - Error: org.apache.spark.sql.AnalysisException: Table or view not found: employeeView; - (state=,code=0) +Error: org.apache.spark.sql.AnalysisException: Table or view not found: employeeView; +(state=,code=0) --- Assumes a view named `employeeView` does not exists,Try with IF EXISTS +-- Assumes a view named `employeeView` does not exist,Try with IF EXISTS -- this time it will not throw exception DROP VIEW IF EXISTS employeeView; -{% endhighlight %} +``` ### Related Statements - * [CREATE VIEW](sql-ref-syntax-ddl-create-view.html) - * [ALTER VIEW](sql-ref-syntax-ddl-alter-view.html) - * [SHOW VIEWS](sql-ref-syntax-aux-show-views.html) - * [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) - * [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) +* [CREATE VIEW](sql-ref-syntax-ddl-create-view.html) +* [ALTER VIEW](sql-ref-syntax-ddl-alter-view.html) +* [SHOW VIEWS](sql-ref-syntax-aux-show-views.html) +* [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) +* [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) diff --git a/docs/sql-ref-syntax-ddl-repair-table.md b/docs/sql-ref-syntax-ddl-repair-table.md index c48b731512ad3..c2ef0a7b7fbe9 100644 --- a/docs/sql-ref-syntax-ddl-repair-table.md +++ b/docs/sql-ref-syntax-ddl-repair-table.md @@ -25,26 +25,21 @@ license: | ### Syntax -{% highlight sql %} +```sql MSCK REPAIR TABLE table_identifier -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>table_identifier</em></code></dt>
-  <dd>Specifies the name of the table to be repaired. The table name may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] table_name</code></dd>
-</dl>
+* **table_identifier** + + Specifies the name of the table to be repaired. The table name may be optionally qualified with a database name. + + **Syntax:** `[ database_name. ] table_name` ### Examples -{% highlight sql %} +```sql -- create a partitioned table from existing data /tmp/namesAndAges.parquet CREATE TABLE t1 (name STRING, age INT) USING parquet PARTITIONED BY (age) LOCATION "/tmp/namesAndAges.parquet"; @@ -66,8 +61,8 @@ SELECT * FROM t1; +-------+---+ | Andy| 30| +-------+---+ -{% endhighlight %} +``` ### Related Statements - * [ALTER TABLE](sql-ref-syntax-ddl-alter-table.html) +* [ALTER TABLE](sql-ref-syntax-ddl-alter-table.html) diff --git a/docs/sql-ref-syntax-ddl-truncate-table.md b/docs/sql-ref-syntax-ddl-truncate-table.md index 820f439f97a4b..6139814a3259a 100644 --- a/docs/sql-ref-syntax-ddl-truncate-table.md +++ b/docs/sql-ref-syntax-ddl-truncate-table.md @@ -27,37 +27,28 @@ in `partition_spec`. If no `partition_spec` is specified it will remove all part ### Syntax -{% highlight sql %} +```sql TRUNCATE TABLE table_identifier [ partition_spec ] -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>table_identifier</em></code></dt>
-  <dd>Specifies a table name, which may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] table_name</code></dd>
-  <dt><code><em>partition_spec</em></code></dt>
-  <dd>An optional parameter that specifies a comma separated list of key and value pairs for partitions.<br><br>Syntax: <code>PARTITION ( partition_col_name = partition_col_val [ , ... ] )</code></dd>
-</dl>
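As a quick illustration of the `partition_spec` parameter, using the partitioned `Student` table from the examples below:

```sql
-- Removes rows only from the partition age = 10; other partitions are untouched.
TRUNCATE TABLE Student PARTITION (age = 10);

-- Without a partition_spec, rows from all partitions are removed.
TRUNCATE TABLE Student;
```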
+ for partitions. + + **Syntax:** `PARTITION ( partition_col_name = partition_col_val [ , ... ] )` ### Examples -{% highlight sql %} +```sql -- Create table Student with partition CREATE TABLE Student (name STRING, rollno INT) PARTITIONED BY (age INT); @@ -89,9 +80,9 @@ SELECT * FROM Student; |name|rollno|age| +----+------+---+ +----+------+---+ -{% endhighlight %} +``` ### Related Statements - * [DROP TABLE](sql-ref-syntax-ddl-drop-table.html) - * [ALTER TABLE](sql-ref-syntax-ddl-alter-table.html) +* [DROP TABLE](sql-ref-syntax-ddl-drop-table.html) +* [ALTER TABLE](sql-ref-syntax-ddl-alter-table.html) diff --git a/docs/sql-ref-syntax-dml-insert-into.md b/docs/sql-ref-syntax-dml-insert-into.md index 924831f7feedd..ed5da2b2d28df 100644 --- a/docs/sql-ref-syntax-dml-insert-into.md +++ b/docs/sql-ref-syntax-dml-insert-into.md @@ -25,57 +25,43 @@ The `INSERT INTO` statement inserts new rows into a table. The inserted rows can ### Syntax -{% highlight sql %} +```sql INSERT INTO [ TABLE ] table_identifier [ partition_spec ] { VALUES ( { value | NULL } [ , ... ] ) [ , ( ... ) ] | query } -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>table_identifier</em></code></dt>
-  <dd>Specifies a table name, which may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] table_name</code></dd>
-  <dt><code><em>partition_spec</em></code></dt>
-  <dd>An optional parameter that specifies a comma separated list of key and value pairs for partitions.<br><br>Syntax: <code>PARTITION ( partition_col_name = partition_col_val [ , ... ] )</code></dd>
-  <dt><code><em>VALUES ( { value | NULL } [ , ... ] ) [ , ( ... ) ]</em></code></dt>
-  <dd>Specifies the values to be inserted. Either an explicitly specified value or a NULL can be inserted. A comma must be used to separate each value in the clause. More than one set of values can be specified to insert multiple rows.</dd>
-  <dt><code><em>query</em></code></dt>
-  <dd>A query that produces the rows to be inserted. It can be in one of following formats:
-  <ul>
-    <li>a SELECT statement</li>
-    <li>a TABLE statement</li>
-    <li>a FROM statement</li>
-  </ul></dd>
-</dl>
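Here is a hedged sketch of the `partition_spec` form, reusing the `students` table (partitioned by `student_id`) from the examples below; with a static partition value, that column is omitted from the VALUES list:

```sql
INSERT INTO students PARTITION (student_id = 444444)
    VALUES ('Dora Williams', '134 Forest Ave, Melo Park');
```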
+ for partitions. + + **Syntax:** `PARTITION ( partition_col_name = partition_col_val [ , ... ] )` + +* **VALUES ( { value `|` NULL } [ , ... ] ) [ , ( ... ) ]** + + Specifies the values to be inserted. Either an explicitly specified value or a NULL can be inserted. + A comma must be used to separate each value in the clause. More than one set of values can be specified to insert multiple rows. + +* **query** + + A query that produces the rows to be inserted. It can be in one of following formats: + * a `SELECT` statement + * a `TABLE` statement + * a `FROM` statement ### Examples #### Single Row Insert Using a VALUES Clause -{% highlight sql %} +```sql CREATE TABLE students (name VARCHAR(64), address VARCHAR(64), student_id INT) USING PARQUET PARTITIONED BY (student_id); @@ -88,11 +74,11 @@ SELECT * FROM students; +---------+---------------------+----------+ |Amy Smith|123 Park Ave,San Jose| 111111| +---------+---------------------+----------+ -{% endhighlight %} +``` #### Multi-Row Insert Using a VALUES Clause -{% highlight sql %} +```sql INSERT INTO students VALUES ('Bob Brown', '456 Taylor St, Cupertino', 222222), ('Cathy Johnson', '789 Race Ave, Palo Alto', 333333); @@ -107,11 +93,11 @@ SELECT * FROM students; +-------------+------------------------+----------+ |Cathy Johnson| 789 Race Ave, Palo Alto| 333333| +--------------+-----------------------+----------+ -{% endhighlight %} +``` #### Insert Using a SELECT Statement -{% highlight sql %} +```sql -- Assuming the persons table has already been created and populated. SELECT * FROM persons; +-------------+-------------------------+---------+ @@ -137,11 +123,11 @@ SELECT * FROM students; +-------------+-------------------------+----------+ |Dora Williams|134 Forest Ave, Melo Park| 444444| +-------------+-------------------------+----------+ -{% endhighlight %} +``` #### Insert Using a TABLE Statement -{% highlight sql %} +```sql -- Assuming the visiting_students table has already been created and populated. SELECT * FROM visiting_students; +-------------+---------------------+----------+ @@ -170,11 +156,11 @@ SELECT * FROM students; +-------------+-------------------------+----------+ |Gordon Martin| 779 Lake Ave, Oxford| 888888| +-------------+-------------------------+----------+ -{% endhighlight %} +``` #### Insert Using a FROM Statement -{% highlight sql %} +```sql -- Assuming the applicants table has already been created and populated. 
SELECT * FROM applicants; +-----------+--------------------------+----------+---------+ @@ -210,10 +196,10 @@ SELECT * FROM students; +-------------+-------------------------+----------+ | Jason Wang| 908 Bird St, Saratoga| 121212| +-------------+-------------------------+----------+ -{% endhighlight %} +``` ### Related Statements - * [INSERT OVERWRITE statement](sql-ref-syntax-dml-insert-overwrite-table.html) - * [INSERT OVERWRITE DIRECTORY statement](sql-ref-syntax-dml-insert-overwrite-directory.html) - * [INSERT OVERWRITE DIRECTORY with Hive format statement](sql-ref-syntax-dml-insert-overwrite-directory-hive.html) +* [INSERT OVERWRITE statement](sql-ref-syntax-dml-insert-overwrite-table.html) +* [INSERT OVERWRITE DIRECTORY statement](sql-ref-syntax-dml-insert-overwrite-directory.html) +* [INSERT OVERWRITE DIRECTORY with Hive format statement](sql-ref-syntax-dml-insert-overwrite-directory-hive.html) diff --git a/docs/sql-ref-syntax-dml-insert-overwrite-directory-hive.md b/docs/sql-ref-syntax-dml-insert-overwrite-directory-hive.md index 3cd2107668fbe..8ed6a3cd1be09 100644 --- a/docs/sql-ref-syntax-dml-insert-overwrite-directory-hive.md +++ b/docs/sql-ref-syntax-dml-insert-overwrite-directory-hive.md @@ -26,56 +26,41 @@ Hive support must be enabled to use this command. The inserted rows can be speci ### Syntax -{% highlight sql %} +```sql INSERT OVERWRITE [ LOCAL ] DIRECTORY directory_path [ ROW FORMAT row_format ] [ STORED AS file_format ] { VALUES ( { value | NULL } [ , ... ] ) [ , ( ... ) ] | query } -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>directory_path</em></code></dt>
-  <dd>Specifies the destination directory. The LOCAL keyword is used to specify that the directory is on the local file system.</dd>
-  <dt><code><em>row_format</em></code></dt>
-  <dd>Specifies the row format for this insert. Valid options are SERDE clause and DELIMITED clause. SERDE clause can be used to specify a custom SerDe for this insert. Alternatively, DELIMITED clause can be used to specify the native SerDe and state the delimiter, escape character, null character, and so on.</dd>
-  <dt><code><em>file_format</em></code></dt>
-  <dd>Specifies the file format for this insert. Valid options are TEXTFILE, SEQUENCEFILE, RCFILE, ORC, PARQUET, and AVRO. You can also specify your own input and output format using INPUTFORMAT and OUTPUTFORMAT. ROW FORMAT SERDE can only be used with TEXTFILE, SEQUENCEFILE, or RCFILE, while ROW FORMAT DELIMITED can only be used with TEXTFILE.</dd>
-  <dt><code><em>VALUES ( { value | NULL } [ , ... ] ) [ , ( ... ) ]</em></code></dt>
-  <dd>Specifies the values to be inserted. Either an explicitly specified value or a NULL can be inserted. A comma must be used to separate each value in the clause. More than one set of values can be specified to insert multiple rows.</dd>
-  <dt><code><em>query</em></code></dt>
-  <dd>A query that produces the rows to be inserted. It can be in one of following formats:
-  <ul>
-    <li>a SELECT statement</li>
-    <li>a TABLE statement</li>
-    <li>a FROM statement</li>
-  </ul></dd>
-</dl>
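For illustration (directory and data are made up), a VALUES-based insert can be combined with `ROW FORMAT DELIMITED`, which is only valid together with `TEXTFILE`:

```sql
INSERT OVERWRITE LOCAL DIRECTORY '/tmp/destination'
    ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
    STORED AS TEXTFILE
    VALUES ('Amy Smith', '123 Park Ave, San Jose', 111111);
```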
+* **directory_path** + + Specifies the destination directory. The `LOCAL` keyword is used to specify that the directory is on the local file system. + +* **row_format** + + Specifies the row format for this insert. Valid options are `SERDE` clause and `DELIMITED` clause. `SERDE` clause can be used to specify a custom `SerDe` for this insert. Alternatively, `DELIMITED` clause can be used to specify the native `SerDe` and state the delimiter, escape character, null character, and so on. + +* **file_format** + + Specifies the file format for this insert. Valid options are `TEXTFILE`, `SEQUENCEFILE`, `RCFILE`, `ORC`, `PARQUET`, and `AVRO`. You can also specify your own input and output format using `INPUTFORMAT` and `OUTPUTFORMAT`. `ROW FORMAT SERDE` can only be used with `TEXTFILE`, `SEQUENCEFILE`, or `RCFILE`, while `ROW FORMAT DELIMITED` can only be used with `TEXTFILE`. + +* **VALUES ( { value `|` NULL } [ , ... ] ) [ , ( ... ) ]** + + Specifies the values to be inserted. Either an explicitly specified value or a NULL can be inserted. + A comma must be used to separate each value in the clause. More than one set of values can be specified to insert multiple rows. + +* **query** + + A query that produces the rows to be inserted. It can be in one of following formats: + * a `SELECT` statement + * a `TABLE` statement + * a `FROM` statement ### Examples -{% highlight sql %} +```sql INSERT OVERWRITE LOCAL DIRECTORY '/tmp/destination' STORED AS orc SELECT * FROM test_table; @@ -83,10 +68,10 @@ INSERT OVERWRITE LOCAL DIRECTORY '/tmp/destination' INSERT OVERWRITE LOCAL DIRECTORY '/tmp/destination' ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' SELECT * FROM test_table; -{% endhighlight %} +``` ### Related Statements - * [INSERT INTO statement](sql-ref-syntax-dml-insert-into.html) - * [INSERT OVERWRITE statement](sql-ref-syntax-dml-insert-overwrite-table.html) - * [INSERT OVERWRITE DIRECTORY statement](sql-ref-syntax-dml-insert-overwrite-directory.html) +* [INSERT INTO statement](sql-ref-syntax-dml-insert-into.html) +* [INSERT OVERWRITE statement](sql-ref-syntax-dml-insert-overwrite-table.html) +* [INSERT OVERWRITE DIRECTORY statement](sql-ref-syntax-dml-insert-overwrite-directory.html) diff --git a/docs/sql-ref-syntax-dml-insert-overwrite-directory.md b/docs/sql-ref-syntax-dml-insert-overwrite-directory.md index 6ce7f50588e32..fd7437d37c909 100644 --- a/docs/sql-ref-syntax-dml-insert-overwrite-directory.md +++ b/docs/sql-ref-syntax-dml-insert-overwrite-directory.md @@ -25,54 +25,42 @@ The `INSERT OVERWRITE DIRECTORY` statement overwrites the existing data in the d ### Syntax -{% highlight sql %} +```sql INSERT OVERWRITE [ LOCAL ] DIRECTORY [ directory_path ] USING file_format [ OPTIONS ( key = val [ , ... ] ) ] { VALUES ( { value | NULL } [ , ... ] ) [ , ( ... ) ] | query } -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>directory_path</em></code></dt>
-  <dd>Specifies the destination directory. It can also be specified in OPTIONS using path. The LOCAL keyword is used to specify that the directory is on the local file system.</dd>
-  <dt><code><em>file_format</em></code></dt>
-  <dd>Specifies the file format to use for the insert. Valid options are TEXT, CSV, JSON, JDBC, PARQUET, ORC, HIVE, LIBSVM, or a fully qualified class name of a custom implementation of org.apache.spark.sql.execution.datasources.FileFormat.</dd>
-  <dt><code><em>OPTIONS ( key = val [ , ... ] )</em></code></dt>
-  <dd>Specifies one or more options for the writing of the file format.</dd>
-  <dt><code><em>VALUES ( { value | NULL } [ , ... ] ) [ , ( ... ) ]</em></code></dt>
-  <dd>Specifies the values to be inserted. Either an explicitly specified value or a NULL can be inserted. A comma must be used to separate each value in the clause. More than one set of values can be specified to insert multiple rows.</dd>
-  <dt><code><em>query</em></code></dt>
-  <dd>A query that produces the rows to be inserted. It can be in one of following formats:
-  <ul>
-    <li>a SELECT statement</li>
-    <li>a TABLE statement</li>
-    <li>a FROM statement</li>
-  </ul></dd>
-</dl>
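A short, illustrative sketch of the `OPTIONS` clause; the directory and the compression codec are arbitrary choices for this example:

```sql
INSERT OVERWRITE DIRECTORY '/tmp/json_destination'
    USING json
    OPTIONS ('compression' 'gzip')
    SELECT * FROM test_table;
```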
+* **directory_path** + + Specifies the destination directory. It can also be specified in `OPTIONS` using `path`. + The `LOCAL` keyword is used to specify that the directory is on the local file system. + +* **file_format** + + Specifies the file format to use for the insert. Valid options are `TEXT`, `CSV`, `JSON`, `JDBC`, `PARQUET`, `ORC`, `HIVE`, `LIBSVM`, or a fully qualified class name of a custom implementation of `org.apache.spark.sql.execution.datasources.FileFormat`. + +* **OPTIONS ( key = val [ , ... ] )** + + Specifies one or more options for the writing of the file format. + +* **VALUES ( { value `|` NULL } [ , ... ] ) [ , ( ... ) ]** + + Specifies the values to be inserted. Either an explicitly specified value or a NULL can be inserted. + A comma must be used to separate each value in the clause. More than one set of values can be specified to insert multiple rows. + +* **query** + + A query that produces the rows to be inserted. It can be in one of following formats: + * a `SELECT` statement + * a `TABLE` statement + * a `FROM` statement ### Examples -{% highlight sql %} +```sql INSERT OVERWRITE DIRECTORY '/tmp/destination' USING parquet OPTIONS (col1 1, col2 2, col3 'test') @@ -82,10 +70,10 @@ INSERT OVERWRITE DIRECTORY USING parquet OPTIONS ('path' '/tmp/destination', col1 1, col2 2, col3 'test') SELECT * FROM test_table; -{% endhighlight %} +``` ### Related Statements - * [INSERT INTO statement](sql-ref-syntax-dml-insert-into.html) - * [INSERT OVERWRITE statement](sql-ref-syntax-dml-insert-overwrite-table.html) - * [INSERT OVERWRITE DIRECTORY with Hive format statement](sql-ref-syntax-dml-insert-overwrite-directory-hive.html) +* [INSERT INTO statement](sql-ref-syntax-dml-insert-into.html) +* [INSERT OVERWRITE statement](sql-ref-syntax-dml-insert-overwrite-table.html) +* [INSERT OVERWRITE DIRECTORY with Hive format statement](sql-ref-syntax-dml-insert-overwrite-directory-hive.html) diff --git a/docs/sql-ref-syntax-dml-insert-overwrite-table.md b/docs/sql-ref-syntax-dml-insert-overwrite-table.md index 5c760f00ed0c4..ecfd060dfd5ee 100644 --- a/docs/sql-ref-syntax-dml-insert-overwrite-table.md +++ b/docs/sql-ref-syntax-dml-insert-overwrite-table.md @@ -25,57 +25,43 @@ The `INSERT OVERWRITE` statement overwrites the existing data in the table using ### Syntax -{% highlight sql %} +```sql INSERT OVERWRITE [ TABLE ] table_identifier [ partition_spec [ IF NOT EXISTS ] ] { VALUES ( { value | NULL } [ , ... ] ) [ , ( ... ) ] | query } -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>table_identifier</em></code></dt>
-  <dd>Specifies a table name, which may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] table_name</code></dd>
-  <dt><code><em>partition_spec</em></code></dt>
-  <dd>An optional parameter that specifies a comma separated list of key and value pairs for partitions.<br><br>Syntax: <code>PARTITION ( partition_col_name [ = partition_col_val ] [ , ... ] )</code></dd>
-  <dt><code><em>VALUES ( { value | NULL } [ , ... ] ) [ , ( ... ) ]</em></code></dt>
-  <dd>Specifies the values to be inserted. Either an explicitly specified value or a NULL can be inserted. A comma must be used to separate each value in the clause. More than one set of values can be specified to insert multiple rows.</dd>
-  <dt><code><em>query</em></code></dt>
-  <dd>A query that produces the rows to be inserted. It can be in one of following formats:
-  <ul>
-    <li>a SELECT statement</li>
-    <li>a TABLE statement</li>
-    <li>a FROM statement</li>
-  </ul></dd>
-</dl>
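As a hedged sketch of the `partition_spec` parameter, reusing the partitioned `students` table from the examples below; only the named partition is overwritten:

```sql
INSERT OVERWRITE students PARTITION (student_id = 222222)
    VALUES ('Brian Reed', '723 Kern Ave, Palo Alto');
```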
+ for partitions. + + **Syntax:** `PARTITION ( partition_col_name [ = partition_col_val ] [ , ... ] )` + +* **VALUES ( { value `|` NULL } [ , ... ] ) [ , ( ... ) ]** + + Specifies the values to be inserted. Either an explicitly specified value or a NULL can be inserted. + A comma must be used to separate each value in the clause. More than one set of values can be specified to insert multiple rows. + +* **query** + + A query that produces the rows to be inserted. It can be in one of following formats: + * a `SELECT` statement + * a `TABLE` statement + * a `FROM` statement ### Examples #### Insert Using a VALUES Clause -{% highlight sql %} +```sql -- Assuming the students table has already been created and populated. SELECT * FROM students; +-------------+-------------------------+----------+ @@ -102,12 +88,11 @@ SELECT * FROM students; |Ashua Hill|456 Erica Ct, Cupertino| 111111| |Brian Reed|723 Kern Ave, Palo Alto| 222222| +----------+-----------------------+----------+ - -{% endhighlight %} +``` #### Insert Using a SELECT Statement -{% highlight sql %} +```sql -- Assuming the persons table has already been created and populated. SELECT * FROM persons; +-------------+-------------------------+---------+ @@ -129,11 +114,11 @@ SELECT * FROM students; +-------------+-------------------------+----------+ |Dora Williams|134 Forest Ave, Melo Park| 222222| +-------------+-------------------------+----------+ -{% endhighlight %} +``` #### Insert Using a TABLE Statement -{% highlight sql %} +```sql -- Assuming the visiting_students table has already been created and populated. SELECT * FROM visiting_students; +-------------+---------------------+----------+ @@ -154,11 +139,11 @@ SELECT * FROM students; +-------------+---------------------+----------+ |Gordon Martin| 779 Lake Ave, Oxford| 888888| +-------------+---------------------+----------+ -{% endhighlight %} +``` #### Insert Using a FROM Statement -{% highlight sql %} +```sql -- Assuming the applicants table has already been created and populated. 
SELECT * FROM applicants; +-----------+--------------------------+----------+---------+ @@ -182,10 +167,10 @@ SELECT * FROM students; +-----------+-------------------------+----------+ | Jason Wang| 908 Bird St, Saratoga| 121212| +-----------+-------------------------+----------+ -{% endhighlight %} +``` ### Related Statements - * [INSERT INTO statement](sql-ref-syntax-dml-insert-into.html) - * [INSERT OVERWRITE DIRECTORY statement](sql-ref-syntax-dml-insert-overwrite-directory.html) - * [INSERT OVERWRITE DIRECTORY with Hive format statement](sql-ref-syntax-dml-insert-overwrite-directory-hive.html) +* [INSERT INTO statement](sql-ref-syntax-dml-insert-into.html) +* [INSERT OVERWRITE DIRECTORY statement](sql-ref-syntax-dml-insert-overwrite-directory.html) +* [INSERT OVERWRITE DIRECTORY with Hive format statement](sql-ref-syntax-dml-insert-overwrite-directory-hive.html) diff --git a/docs/sql-ref-syntax-dml-insert.md b/docs/sql-ref-syntax-dml-insert.md index 2345add2460c8..62f6dee876450 100644 --- a/docs/sql-ref-syntax-dml-insert.md +++ b/docs/sql-ref-syntax-dml-insert.md @@ -21,7 +21,7 @@ license: | The INSERT statements: - * [INSERT INTO statement](sql-ref-syntax-dml-insert-into.html) - * [INSERT OVERWRITE statement](sql-ref-syntax-dml-insert-overwrite-table.html) - * [INSERT OVERWRITE DIRECTORY statement](sql-ref-syntax-dml-insert-overwrite-directory.html) - * [INSERT OVERWRITE DIRECTORY with Hive format statement](sql-ref-syntax-dml-insert-overwrite-directory-hive.html) +* [INSERT INTO statement](sql-ref-syntax-dml-insert-into.html) +* [INSERT OVERWRITE statement](sql-ref-syntax-dml-insert-overwrite-table.html) +* [INSERT OVERWRITE DIRECTORY statement](sql-ref-syntax-dml-insert-overwrite-directory.html) +* [INSERT OVERWRITE DIRECTORY with Hive format statement](sql-ref-syntax-dml-insert-overwrite-directory-hive.html) diff --git a/docs/sql-ref-syntax-dml-load.md b/docs/sql-ref-syntax-dml-load.md index 01ece31bd17fa..9381b4267fb24 100644 --- a/docs/sql-ref-syntax-dml-load.md +++ b/docs/sql-ref-syntax-dml-load.md @@ -25,53 +25,40 @@ license: | ### Syntax -{% highlight sql %} +```sql LOAD DATA [ LOCAL ] INPATH path [ OVERWRITE ] INTO TABLE table_identifier [ partition_spec ] -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>path</em></code></dt>
-  <dd>Path of the file system. It can be either an absolute or a relative path.</dd>
-  <dt><code><em>table_identifier</em></code></dt>
-  <dd>Specifies a table name, which may be optionally qualified with a database name.<br><br>Syntax: <code>[ database_name. ] table_name</code></dd>
-  <dt><code><em>partition_spec</em></code></dt>
-  <dd>An optional parameter that specifies a comma separated list of key and value pairs for partitions.<br><br>Syntax: <code>PARTITION ( partition_col_name = partition_col_val [ , ... ] )</code></dd>
-  <dt><code><em>LOCAL</em></code></dt>
-  <dd>If specified, it causes the INPATH to be resolved against the local file system, instead of the default file system, which is typically a distributed storage.</dd>
-  <dt><code><em>OVERWRITE</em></code></dt>
-  <dd>By default, new data is appended to the table. If OVERWRITE is used, the table is instead overwritten with new data.</dd>
-</dl>
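Combining the optional keywords, and reusing the partitioned `test_load_partition` table from the examples below (the local file path is hypothetical):

```sql
LOAD DATA LOCAL INPATH '/tmp/part_data.csv'
    OVERWRITE INTO TABLE test_load_partition PARTITION (c2 = 2, c3 = 3);
```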
+ for partitions. + + **Syntax:** `PARTITION ( partition_col_name = partition_col_val [ , ... ] )` + +* **LOCAL** + + If specified, it causes the `INPATH` to be resolved against the local file system, instead of the default file system, which is typically a distributed storage. + +* **OVERWRITE** + + By default, new data is appended to the table. If `OVERWRITE` is used, the table is instead overwritten with new data. ### Examples -{% highlight sql %} +```sql -- Example without partition specification. -- Assuming the students table has already been created and populated. SELECT * FROM students; @@ -123,4 +110,4 @@ SELECT * FROM test_load_partition; +---+---+---+ | 1| 2| 3| +---+---+---+ -{% endhighlight %} +``` diff --git a/docs/sql-ref-syntax-dml.md b/docs/sql-ref-syntax-dml.md index 9f75990555f64..fc408e1d38d26 100644 --- a/docs/sql-ref-syntax-dml.md +++ b/docs/sql-ref-syntax-dml.md @@ -21,5 +21,5 @@ license: | Data Manipulation Statements are used to add, change, or delete data. Spark SQL supports the following Data Manipulation Statements: - * [INSERT](sql-ref-syntax-dml-insert.html) - * [LOAD](sql-ref-syntax-dml-load.html) +* [INSERT](sql-ref-syntax-dml-insert.html) +* [LOAD](sql-ref-syntax-dml-load.html) diff --git a/docs/sql-ref-syntax-qry-explain.md b/docs/sql-ref-syntax-qry-explain.md index 298a2edaea1f2..7b84264a28cca 100644 --- a/docs/sql-ref-syntax-qry-explain.md +++ b/docs/sql-ref-syntax-qry-explain.md @@ -26,46 +26,38 @@ By default, this clause provides information about a physical plan only. ### Syntax -{% highlight sql %} +```sql EXPLAIN [ EXTENDED | CODEGEN | COST | FORMATTED ] statement -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>EXTENDED</em></code></dt>
-  <dd>Generates parsed logical plan, analyzed logical plan, optimized logical plan and physical plan. Parsed Logical plan is a unresolved plan that extracted from the query. Analyzed logical plans transforms which translates unresolvedAttribute and unresolvedRelation into fully typed objects. The optimized logical plan transforms through a set of optimization rules, resulting in the physical plan.</dd>
-  <dt><code><em>CODEGEN</em></code></dt>
-  <dd>Generates code for the statement, if any and a physical plan.</dd>
-  <dt><code><em>COST</em></code></dt>
-  <dd>If plan node statistics are available, generates a logical plan and the statistics.</dd>
-  <dt><code><em>FORMATTED</em></code></dt>
-  <dd>Generates two sections: a physical plan outline and node details.</dd>
-  <dt><code><em>statement</em></code></dt>
-  <dd>Specifies a SQL statement to be explained.</dd>
-</dl>
+* **EXTENDED**
+
+  Generates the parsed logical plan, analyzed logical plan, optimized logical plan and physical plan.
+  The parsed logical plan is an unresolved plan extracted from the query.
+  The analyzed logical plan resolves `unresolvedAttribute` and `unresolvedRelation` in the parsed plan into fully typed objects.
+  The optimized logical plan is obtained by applying a set of optimization rules, and is then turned into the physical plan.
### Examples -{% highlight sql %} +```sql -- Default Output EXPLAIN select k, sum(v) from values (1, 2), (1, 3) t(k, v) group by k; +----------------------------------------------------+ @@ -132,4 +124,4 @@ EXPLAIN FORMATTED select k, sum(v) from values (1, 2), (1, 3) t(k, v) group by k Input: [k#19, sum#24L] | +----------------------------------------------------+ -{% endhighlight %} +``` diff --git a/docs/sql-ref-syntax-qry-select-clusterby.md b/docs/sql-ref-syntax-qry-select-clusterby.md index ac1e1ccb00ac9..e3bd2ed926ecc 100644 --- a/docs/sql-ref-syntax-qry-select-clusterby.md +++ b/docs/sql-ref-syntax-qry-select-clusterby.md @@ -21,7 +21,7 @@ license: | ### Description -The CLUSTER BY clause is used to first repartition the data based +The `CLUSTER BY` clause is used to first repartition the data based on the input expressions and then sort the data within each partition. This is semantically equivalent to performing a [DISTRIBUTE BY](sql-ref-syntax-qry-select-distribute-by.html) followed by a @@ -30,22 +30,19 @@ resultant rows are sorted within each partition and does not guarantee a total o ### Syntax -{% highlight sql %} +```sql CLUSTER BY { expression [ , ... ] } -{% endhighlight %} +``` ### Parameters -
-<dl>
-  <dt><code><em>expression</em></code></dt>
-  <dd>Specifies combination of one or more values, operators and SQL functions that results in a value.</dd>
-</dl>
### Examples -{% highlight sql %} +```sql CREATE TABLE person (name STRING, age INT); INSERT INTO person VALUES ('Zen Hui', 25), @@ -90,15 +87,15 @@ SELECT age, name FROM person CLUSTER BY age; | 16|Shone S| | 16| Jack N| +---+-------+ -{% endhighlight %} +``` ### Related Statements - * [SELECT Main](sql-ref-syntax-qry-select.html) - * [WHERE Clause](sql-ref-syntax-qry-select-where.html) - * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) - * [HAVING Clause](sql-ref-syntax-qry-select-having.html) - * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) - * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) - * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) - * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) +* [SELECT Main](sql-ref-syntax-qry-select.html) +* [WHERE Clause](sql-ref-syntax-qry-select-where.html) +* [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) +* [HAVING Clause](sql-ref-syntax-qry-select-having.html) +* [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) +* [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) +* [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) +* [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) diff --git a/docs/sql-ref-syntax-qry-select-cte.md b/docs/sql-ref-syntax-qry-select-cte.md index 2408c884c64b5..351de64a2d026 100644 --- a/docs/sql-ref-syntax-qry-select-cte.md +++ b/docs/sql-ref-syntax-qry-select-cte.md @@ -25,33 +25,28 @@ A common table expression (CTE) defines a temporary result set that a user can r ### Syntax -{% highlight sql %} +```sql WITH common_table_expression [ , ... ] -{% endhighlight %} +``` While `common_table_expression` is defined as -{% highlight sql %} -expression_name [ ( column_name [ , ... ] ) ] [ AS ] ( [ common_table_expression ] query ) -{% endhighlight %} +```sql +expression_name [ ( column_name [ , ... ] ) ] [ AS ] ( query ) +``` ### Parameters -
-<dl>
-  <dt><code><em>expression_name</em></code></dt>
-  <dd>Specifies a name for the common table expression.</dd>
-  <dt><code><em>query</em></code></dt>
-  <dd>A SELECT statement.</dd>
-</dl>
+ +* **query** + + A [SELECT statement](sql-ref-syntax-qry-select.html). ### Examples -{% highlight sql %} +```sql -- CTE with multiple column aliases WITH t(x, y) AS (SELECT 1, 2) SELECT * FROM t WHERE x = 1 AND y = 2; @@ -62,7 +57,7 @@ SELECT * FROM t WHERE x = 1 AND y = 2; +---+---+ -- CTE in CTE definition -WITH t as ( +WITH t AS ( WITH t2 AS (SELECT 1) SELECT * FROM t2 ) @@ -122,8 +117,8 @@ SELECT * FROM t2; +---+ | 2| +---+ -{% endhighlight %} +``` ### Related Statements - * [SELECT](sql-ref-syntax-qry-select.html) +* [SELECT](sql-ref-syntax-qry-select.html) diff --git a/docs/sql-ref-syntax-qry-select-distribute-by.md b/docs/sql-ref-syntax-qry-select-distribute-by.md index 9e2db27ae7161..1fdfb91dad286 100644 --- a/docs/sql-ref-syntax-qry-select-distribute-by.md +++ b/docs/sql-ref-syntax-qry-select-distribute-by.md @@ -21,28 +21,25 @@ license: | ### Description -The DISTRIBUTE BY clause is used to repartition the data based +The `DISTRIBUTE BY` clause is used to repartition the data based on the input expressions. Unlike the [CLUSTER BY](sql-ref-syntax-qry-select-clusterby.html) clause, this does not sort the data within each partition. ### Syntax -{% highlight sql %} +```sql DISTRIBUTE BY { expression [ , ... ] } -{% endhighlight %} +``` ### Parameters -
- expression
+* **expression** + Specifies a combination of one or more values, operators and SQL functions that results in a value. -
-
### Examples -{% highlight sql %} +```sql CREATE TABLE person (name STRING, age INT); INSERT INTO person VALUES ('Zen Hui', 25), @@ -85,15 +82,15 @@ SELECT age, name FROM person DISTRIBUTE BY age; | 16|Shone S| | 16| Jack N| +---+-------+ -{% endhighlight %} +``` ### Related Statements - * [SELECT Main](sql-ref-syntax-qry-select.html) - * [WHERE Clause](sql-ref-syntax-qry-select-where.html) - * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) - * [HAVING Clause](sql-ref-syntax-qry-select-having.html) - * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) - * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) - * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) - * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) +* [SELECT Main](sql-ref-syntax-qry-select.html) +* [WHERE Clause](sql-ref-syntax-qry-select-where.html) +* [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) +* [HAVING Clause](sql-ref-syntax-qry-select-having.html) +* [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) +* [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) +* [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) +* [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) diff --git a/docs/sql-ref-syntax-qry-select-groupby.md b/docs/sql-ref-syntax-qry-select-groupby.md index 22fe782f9eaa7..bd9377ef78df6 100644 --- a/docs/sql-ref-syntax-qry-select-groupby.md +++ b/docs/sql-ref-syntax-qry-select-groupby.md @@ -21,84 +21,79 @@ license: | ### Description -The GROUP BY clause is used to group the rows based on a set of specified grouping expressions and compute aggregations on +The `GROUP BY` clause is used to group the rows based on a set of specified grouping expressions and compute aggregations on the group of rows based on one or more specified aggregate functions. Spark also supports advanced aggregations to do multiple aggregations for the same input record set via `GROUPING SETS`, `CUBE`, `ROLLUP` clauses. When a FILTER clause is attached to an aggregate function, only the matching rows are passed to that function. ### Syntax -{% highlight sql %} +```sql GROUP BY group_expression [ , group_expression [ , ... ] ] [ { WITH ROLLUP | WITH CUBE | GROUPING SETS (grouping_set [ , ...]) } ] GROUP BY GROUPING SETS (grouping_set [ , ...]) -{% endhighlight %} +``` While aggregate functions are defined as -{% highlight sql %} +```sql aggregate_name ( [ DISTINCT ] expression [ , ... ] ) [ FILTER ( WHERE boolean_expression ) ] -{% endhighlight %} +``` ### Parameters -
- GROUPING SETS
+* **GROUPING SETS** + Groups the rows for each subset of the expressions specified in the grouping sets. For example, - GROUP BY GROUPING SETS (warehouse, product) is semantically equivalent - to union of results of GROUP BY warehouse and GROUP BY product. This clause - is a shorthand for a UNION ALL where each leg of the UNION ALL - operator performs aggregation of subset of the columns specified in the GROUPING SETS clause. -
- grouping_set
- A grouping set is specified by zero or more comma-separated expressions in parentheses.

- Syntax: - - ([expression [, ...]]) - -
- grouping_expression
+ `GROUP BY GROUPING SETS (warehouse, product)` is semantically equivalent + to union of results of `GROUP BY warehouse` and `GROUP BY product`. This clause + is a shorthand for a `UNION ALL` where each leg of the `UNION ALL` + operator performs aggregation of subset of the columns specified in the `GROUPING SETS` clause. + +* **grouping_set** + + A grouping set is specified by zero or more comma-separated expressions in parentheses. + + **Syntax:** `( [ expression [ , ... ] ] )` + +* **grouping_expression** + Specifies the critieria based on which the rows are grouped together. The grouping of rows is performed based on result values of the grouping expressions. A grouping expression may be a column alias, a column position or an expression. -
- ROLLUP
+ +* **ROLLUP** + Specifies multiple levels of aggregations in a single statement. This clause is used to compute aggregations - based on multiple grouping sets. ROLLUP is a shorthand for GROUPING SETS. For example, - GROUP BY warehouse, product WITH ROLLUP is equivalent to GROUP BY GROUPING SETS - ((warehouse, product), (warehouse), ()). - The N elements of a ROLLUP specification results in N+1 GROUPING SETS. -
- CUBE
- CUBE clause is used to perform aggregations based on combination of grouping columns specified in the - GROUP BY clause. CUBE is a shorthand for GROUPING SETS. For example, - GROUP BY warehouse, product WITH CUBE is equivalent to GROUP BY GROUPING SETS - ((warehouse, product), (warehouse), (product), ()). - The N elements of a CUBE specification results in 2^N GROUPING SETS. -
- aggregate_name
+ based on multiple grouping sets. `ROLLUP` is a shorthand for `GROUPING SETS`. For example, + `GROUP BY warehouse, product WITH ROLLUP` is equivalent to `GROUP BY GROUPING SETS + ((warehouse, product), (warehouse), ())`. + The N elements of a `ROLLUP` specification results in N+1 `GROUPING SETS`. + +* **CUBE** + + `CUBE` clause is used to perform aggregations based on combination of grouping columns specified in the + `GROUP BY` clause. `CUBE` is a shorthand for `GROUPING SETS`. For example, + `GROUP BY warehouse, product WITH CUBE` is equivalent to `GROUP BY GROUPING SETS + ((warehouse, product), (warehouse), (product), ())`. + The N elements of a `CUBE` specification results in 2^N `GROUPING SETS`. + +* **aggregate_name** + Specifies an aggregate function name (MIN, MAX, COUNT, SUM, AVG, etc.). -
- DISTINCT
+ +* **DISTINCT** + Removes duplicates in input rows before they are passed to aggregate functions. -
- FILTER
- Filters the input rows for which the boolean_expression in the WHERE clause evaluates + +* **FILTER** + + Filters the input rows for which the `boolean_expression` in the `WHERE` clause evaluates to true are passed to the aggregate function; other rows are discarded. -
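As a sketch of the shorthand described above, a `WITH ROLLUP` grouping can be spelled out as explicit grouping sets; the `dealer` table comes from the examples that follow.

```sql
-- These two queries are equivalent: WITH ROLLUP on (city, car_model) expands to
-- GROUPING SETS ((city, car_model), (city), ()).
SELECT city, car_model, sum(quantity) AS sum
    FROM dealer GROUP BY city, car_model WITH ROLLUP;

SELECT city, car_model, sum(quantity) AS sum
    FROM dealer GROUP BY GROUPING SETS ((city, car_model), (city), ());
```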
### Examples -{% highlight sql %} +```sql CREATE TABLE dealer (id INT, city STRING, car_model STRING, quantity INT); INSERT INTO dealer VALUES (100, 'Fremont', 'Honda Civic', 10), @@ -174,106 +169,106 @@ SELECT id, sum(quantity) FILTER ( SELECT city, car_model, sum(quantity) AS sum FROM dealer GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city; -+--------+------------+---+ -| city| car_model|sum| -+--------+------------+---+ -| null| null| 78| -| null| HondaAccord| 33| -| null| HondaCRV| 10| -| null| HondaCivic| 35| -| Dublin| null| 33| -| Dublin| HondaAccord| 10| -| Dublin| HondaCRV| 3| -| Dublin| HondaCivic| 20| -| Fremont| null| 32| -| Fremont| HondaAccord| 15| -| Fremont| HondaCRV| 7| -| Fremont| HondaCivic| 10| -| SanJose| null| 13| -| SanJose| HondaAccord| 8| -| SanJose| HondaCivic| 5| -+--------+------------+---+ ++---------+------------+---+ +| city| car_model|sum| ++---------+------------+---+ +| null| null| 78| +| null| HondaAccord| 33| +| null| HondaCRV| 10| +| null| HondaCivic| 35| +| Dublin| null| 33| +| Dublin| HondaAccord| 10| +| Dublin| HondaCRV| 3| +| Dublin| HondaCivic| 20| +| Fremont| null| 32| +| Fremont| HondaAccord| 15| +| Fremont| HondaCRV| 7| +| Fremont| HondaCivic| 10| +| San Jose| null| 13| +| San Jose| HondaAccord| 8| +| San Jose| HondaCivic| 5| ++---------+------------+---+ -- Alternate syntax for `GROUPING SETS` in which both `GROUP BY` and `GROUPING SETS` -- specifications are present. SELECT city, car_model, sum(quantity) AS sum FROM dealer GROUP BY city, car_model GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model; -+--------+------------+---+ -| city| car_model|sum| -+--------+------------+---+ -| null| null| 78| -| null| HondaAccord| 33| -| null| HondaCRV| 10| -| null| HondaCivic| 35| -| Dublin| null| 33| -| Dublin| HondaAccord| 10| -| Dublin| HondaCRV| 3| -| Dublin| HondaCivic| 20| -| Fremont| null| 32| -| Fremont| HondaAccord| 15| -| Fremont| HondaCRV| 7| -| Fremont| HondaCivic| 10| -| SanJose| null| 13| -| SanJose| HondaAccord| 8| -| SanJose| HondaCivic| 5| -+--------+------------+---+ ++---------+------------+---+ +| city| car_model|sum| ++---------+------------+---+ +| null| null| 78| +| null| HondaAccord| 33| +| null| HondaCRV| 10| +| null| HondaCivic| 35| +| Dublin| null| 33| +| Dublin| HondaAccord| 10| +| Dublin| HondaCRV| 3| +| Dublin| HondaCivic| 20| +| Fremont| null| 32| +| Fremont| HondaAccord| 15| +| Fremont| HondaCRV| 7| +| Fremont| HondaCivic| 10| +| San Jose| null| 13| +| San Jose| HondaAccord| 8| +| San Jose| HondaCivic| 5| ++---------+------------+---+ -- Group by processing with `ROLLUP` clause. 
-- Equivalent GROUP BY GROUPING SETS ((city, car_model), (city), ()) SELECT city, car_model, sum(quantity) AS sum FROM dealer GROUP BY city, car_model WITH ROLLUP ORDER BY city, car_model; -+--------+------------+---+ -| city| car_model|sum| -+--------+------------+---+ -| null| null| 78| -| Dublin| null| 33| -| Dublin| HondaAccord| 10| -| Dublin| HondaCRV| 3| -| Dublin| HondaCivic| 20| -| Fremont| null| 32| -| Fremont| HondaAccord| 15| -| Fremont| HondaCRV| 7| -| Fremont| HondaCivic| 10| -| SanJose| null| 13| -| SanJose| HondaAccord| 8| -| SanJose| HondaCivic| 5| -+--------+------------+---+ ++---------+------------+---+ +| city| car_model|sum| ++---------+------------+---+ +| null| null| 78| +| Dublin| null| 33| +| Dublin| HondaAccord| 10| +| Dublin| HondaCRV| 3| +| Dublin| HondaCivic| 20| +| Fremont| null| 32| +| Fremont| HondaAccord| 15| +| Fremont| HondaCRV| 7| +| Fremont| HondaCivic| 10| +| San Jose| null| 13| +| San Jose| HondaAccord| 8| +| San Jose| HondaCivic| 5| ++---------+------------+---+ -- Group by processing with `CUBE` clause. -- Equivalent GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) SELECT city, car_model, sum(quantity) AS sum FROM dealer GROUP BY city, car_model WITH CUBE ORDER BY city, car_model; -+--------+------------+---+ -| city| car_model|sum| -+--------+------------+---+ -| null| null| 78| -| null| HondaAccord| 33| -| null| HondaCRV| 10| -| null| HondaCivic| 35| -| Dublin| null| 33| -| Dublin| HondaAccord| 10| -| Dublin| HondaCRV| 3| -| Dublin| HondaCivic| 20| -| Fremont| null| 32| -| Fremont| HondaAccord| 15| -| Fremont| HondaCRV| 7| -| Fremont| HondaCivic| 10| -| SanJose| null| 13| -| SanJose| HondaAccord| 8| -| SanJose| HondaCivic| 5| -+--------+------------+---+ -{% endhighlight %} ++---------+------------+---+ +| city| car_model|sum| ++---------+------------+---+ +| null| null| 78| +| null| HondaAccord| 33| +| null| HondaCRV| 10| +| null| HondaCivic| 35| +| Dublin| null| 33| +| Dublin| HondaAccord| 10| +| Dublin| HondaCRV| 3| +| Dublin| HondaCivic| 20| +| Fremont| null| 32| +| Fremont| HondaAccord| 15| +| Fremont| HondaCRV| 7| +| Fremont| HondaCivic| 10| +| San Jose| null| 13| +| San Jose| HondaAccord| 8| +| San Jose| HondaCivic| 5| ++---------+------------+---+ +``` ### Related Statements - * [SELECT Main](sql-ref-syntax-qry-select.html) - * [WHERE Clause](sql-ref-syntax-qry-select-where.html) - * [HAVING Clause](sql-ref-syntax-qry-select-having.html) - * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) - * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) - * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) - * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) - * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) +* [SELECT Main](sql-ref-syntax-qry-select.html) +* [WHERE Clause](sql-ref-syntax-qry-select-where.html) +* [HAVING Clause](sql-ref-syntax-qry-select-having.html) +* [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) +* [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) +* [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) +* [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) +* [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) diff --git a/docs/sql-ref-syntax-qry-select-having.md b/docs/sql-ref-syntax-qry-select-having.md index c8c4f2c38104c..935782c551e1f 100644 --- a/docs/sql-ref-syntax-qry-select-having.md +++ b/docs/sql-ref-syntax-qry-select-having.md @@ -21,39 +21,35 @@ license: | ### Description -The HAVING clause is used to 
filter the results produced by -GROUP BY based on the specified condition. It is often used +The `HAVING` clause is used to filter the results produced by +`GROUP BY` based on the specified condition. It is often used in conjunction with a [GROUP BY](sql-ref-syntax-qry-select-groupby.html) clause. ### Syntax -{% highlight sql %} +```sql HAVING boolean_expression -{% endhighlight %} +``` ### Parameters -
- boolean_expression
- Specifies any expression that evaluates to a result type boolean. Two or +* **boolean_expression** + + Specifies any expression that evaluates to a result type `boolean`. Two or more expressions may be combined together using the logical - operators ( AND, OR ).

- Note
- The expressions specified in the HAVING clause can only refer to:
- 1. Constants
- 2. Expressions that appear in GROUP BY
- 3. Aggregate functions
+ operators ( `AND`, `OR` ). + + **Note** + + The expressions specified in the `HAVING` clause can only refer to: + 1. Constants + 2. Expressions that appear in GROUP BY + 3. Aggregate functions ### Examples -{% highlight sql %} +```sql CREATE TABLE dealer (id INT, city STRING, car_model STRING, quantity INT); INSERT INTO dealer VALUES (100, 'Fremont', 'Honda Civic', 10), @@ -117,15 +113,15 @@ SELECT sum(quantity) AS sum FROM dealer HAVING sum(quantity) > 10; +---+ | 78| +---+ -{% endhighlight %} +``` ### Related Statements - * [SELECT Main](sql-ref-syntax-qry-select.html) - * [WHERE Clause](sql-ref-syntax-qry-select-where.html) - * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) - * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) - * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) - * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) - * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) - * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) +* [SELECT Main](sql-ref-syntax-qry-select.html) +* [WHERE Clause](sql-ref-syntax-qry-select-where.html) +* [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) +* [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) +* [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) +* [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) +* [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) +* [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) diff --git a/docs/sql-ref-syntax-qry-select-hints.md b/docs/sql-ref-syntax-qry-select-hints.md index 16f4f95f90ea1..247ce48e79445 100644 --- a/docs/sql-ref-syntax-qry-select-hints.md +++ b/docs/sql-ref-syntax-qry-select-hints.md @@ -1,7 +1,7 @@ --- layout: global -title: Join Hints -displayTitle: Join Hints +title: Hints +displayTitle: Hints license: | Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with @@ -21,41 +21,106 @@ license: | ### Description -Join Hints allow users to suggest the join strategy that Spark should use. Prior to Spark 3.0, only the `BROADCAST` Join Hint was supported. `MERGE`, `SHUFFLE_HASH` and `SHUFFLE_REPLICATE_NL` Joint Hints support was added in 3.0. When different join strategy hints are specified on both sides of a join, Spark prioritizes hints in the following order: `BROADCAST` over `MERGE` over `SHUFFLE_HASH` over `SHUFFLE_REPLICATE_NL`. When both sides are specified with the `BROADCAST` hint or the `SHUFFLE_HASH` hint, Spark will pick the build side based on the join type and the sizes of the relations. Since a given strategy may not support all join types, Spark is not guaranteed to use the join strategy suggested by the hint. - -### Join Hints Types - -
- BROADCAST
- Suggests that Spark use broadcast join. The join side with the hint will be broadcast regardless of autoBroadcastJoinThreshold. If both sides of the join have the broadcast hints, the one with the smaller size (based on stats) will be broadcast. The aliases for BROADCAST are BROADCASTJOIN and MAPJOIN. -
- MERGE
- Suggests that Spark use shuffle sort merge join. The aliases for MERGE are SHUFFLE_MERGE and MERGEJOIN. -
- SHUFFLE_HASH
- Suggests that Spark use shuffle hash join. If both sides have the shuffle hash hints, Spark chooses the smaller side (based on stats) as the build side. -
- SHUFFLE_REPLICATE_NL
+Hints give users a way to suggest how Spark SQL to use specific approaches to generate its execution plan. + +### Syntax + +```sql +/*+ hint [ , ... ] */ +``` + +### Partitioning Hints + +Partitioning hints allow users to suggest a partitioning stragety that Spark should follow. `COALESCE`, `REPARTITION`, +and `REPARTITION_BY_RANGE` hints are supported and are equivalent to `coalesce`, `repartition`, and +`repartitionByRange` [Dataset APIs](api/scala/org/apache/spark/sql/Dataset.html), respectively. These hints give users +a way to tune performance and control the number of output files in Spark SQL. When multiple partitioning hints are +specified, multiple nodes are inserted into the logical plan, but the leftmost hint is picked by the optimizer. + +#### Partitioning Hints Types + +* **COALESCE** + + The `COALESCE` hint can be used to reduce the number of partitions to the specified number of partitions. It takes a partition number as a parameter. + +* **REPARTITION** + + The `REPARTITION` hint can be used to repartition to the specified number of partitions using the specified partitioning expressions. It takes a partition number, column names, or both as parameters. + +* **REPARTITION_BY_RANGE** + + The `REPARTITION_BY_RANGE` hint can be used to repartition to the specified number of partitions using the specified partitioning expressions. It takes column names and an optional partition number as parameters. + +#### Examples + +```sql +SELECT /*+ COALESCE(3) */ * FROM t; + +SELECT /*+ REPARTITION(3) */ * FROM t; + +SELECT /*+ REPARTITION(c) */ * FROM t; + +SELECT /*+ REPARTITION(3, c) */ * FROM t; + +SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t; + +SELECT /*+ REPARTITION_BY_RANGE(3, c) */ * FROM t; + +-- multiple partitioning hints +EXPLAIN EXTENDED SELECT /*+ REPARTITION(100), COALESCE(500), REPARTITION_BY_RANGE(3, c) */ * FROM t; +== Parsed Logical Plan == +'UnresolvedHint REPARTITION, [100] ++- 'UnresolvedHint COALESCE, [500] + +- 'UnresolvedHint REPARTITION_BY_RANGE, [3, 'c] + +- 'Project [*] + +- 'UnresolvedRelation [t] + +== Analyzed Logical Plan == +name: string, c: int +Repartition 100, true ++- Repartition 500, false + +- RepartitionByExpression [c#30 ASC NULLS FIRST], 3 + +- Project [name#29, c#30] + +- SubqueryAlias spark_catalog.default.t + +- Relation[name#29,c#30] parquet + +== Optimized Logical Plan == +Repartition 100, true ++- Relation[name#29,c#30] parquet + +== Physical Plan == +Exchange RoundRobinPartitioning(100), false, [id=#121] ++- *(1) ColumnarToRow + +- FileScan parquet default.t[name#29,c#30] Batched: true, DataFilters: [], Format: Parquet, + Location: CatalogFileIndex[file:/spark/spark-warehouse/t], PartitionFilters: [], + PushedFilters: [], ReadSchema: struct +``` + +### Join Hints + +Join hints allow users to suggest the join strategy that Spark should use. Prior to Spark 3.0, only the `BROADCAST` Join Hint was supported. `MERGE`, `SHUFFLE_HASH` and `SHUFFLE_REPLICATE_NL` Joint Hints support was added in 3.0. When different join strategy hints are specified on both sides of a join, Spark prioritizes hints in the following order: `BROADCAST` over `MERGE` over `SHUFFLE_HASH` over `SHUFFLE_REPLICATE_NL`. When both sides are specified with the `BROADCAST` hint or the `SHUFFLE_HASH` hint, Spark will pick the build side based on the join type and the sizes of the relations. Since a given strategy may not support all join types, Spark is not guaranteed to use the join strategy suggested by the hint. 
+ +#### Join Hints Types + +* **BROADCAST** + + Suggests that Spark use broadcast join. The join side with the hint will be broadcast regardless of `autoBroadcastJoinThreshold`. If both sides of the join have the broadcast hints, the one with the smaller size (based on stats) will be broadcast. The aliases for `BROADCAST` are `BROADCASTJOIN` and `MAPJOIN`. + +* **MERGE** + + Suggests that Spark use shuffle sort merge join. The aliases for `MERGE` are `SHUFFLE_MERGE` and `MERGEJOIN`. + +* **SHUFFLE_HASH** + + Suggests that Spark use shuffle hash join. If both sides have the shuffle hash hints, Spark chooses the smaller side (based on stats) as the build side. + +* **SHUFFLE_REPLICATE_NL** + Suggests that Spark use shuffle-and-replicate nested loop join. -
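One practical reading of the partitioning hints described above is controlling the number of files a query writes. A hedged sketch, where the `t` and `t_copy` table names are illustrative only:

```sql
-- COALESCE(1) asks Spark to shrink the result to a single partition before the
-- write, which typically produces a single output file for the INSERT.
INSERT OVERWRITE TABLE t_copy
SELECT /*+ COALESCE(1) */ * FROM t;
```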
-### Examples +#### Examples -{% highlight sql %} +```sql -- Join Hints for broadcast join SELECT /*+ BROADCAST(t1) */ * FROM t1 INNER JOIN t2 ON t1.key = t2.key; SELECT /*+ BROADCASTJOIN (t1) */ * FROM t1 left JOIN t2 ON t1.key = t2.key; @@ -78,10 +143,10 @@ SELECT /*+ SHUFFLE_REPLICATE_NL(t1) */ * FROM t1 INNER JOIN t2 ON t1.key = t2.ke -- Spark will issue Warning in the following example -- org.apache.spark.sql.catalyst.analysis.HintErrorLogger: Hint (strategy=merge) -- is overridden by another hint and will not take effect. -SELECT /*+ BROADCAST(t1) */ /*+ MERGE(t1, t2) */ * FROM t1 INNER JOIN t2 ON t1.key = t2.key; -{% endhighlight %} +SELECT /*+ BROADCAST(t1), MERGE(t1, t2) */ * FROM t1 INNER JOIN t2 ON t1.key = t2.key; +``` ### Related Statements - * [JOIN](sql-ref-syntax-qry-select-join.html) - * [SELECT](sql-ref-syntax-qry-select.html) +* [JOIN](sql-ref-syntax-qry-select-join.html) +* [SELECT](sql-ref-syntax-qry-select.html) diff --git a/docs/sql-ref-syntax-qry-select-inline-table.md b/docs/sql-ref-syntax-qry-select-inline-table.md index 9c33cbc679f06..7f0372ea787e5 100644 --- a/docs/sql-ref-syntax-qry-select-inline-table.md +++ b/docs/sql-ref-syntax-qry-select-inline-table.md @@ -25,32 +25,25 @@ An inline table is a temporary table created using a VALUES clause. ### Syntax -{% highlight sql %} +```sql VALUES ( expression [ , ... ] ) [ table_alias ] -{% endhighlight %} +``` ### Parameters -
- expression
+* **expression** + Specifies a combination of one or more values, operators and SQL functions that results in a value. -
- table_alias
- Specifies a temporary name with an optional column name list.

- Syntax: - - [ AS ] table_name [ ( column_name [ , ... ] ) ] - -
-
+ +* **table_alias** + + Specifies a temporary name with an optional column name list. + + **Syntax:** `[ AS ] table_name [ ( column_name [ , ... ] ) ]` ### Examples -{% highlight sql %} +```sql -- single row, without a table alias SELECT * FROM VALUES ("one", 1); +----+----+ @@ -77,8 +70,8 @@ SELECT * FROM VALUES ("one", array(0, 1)), ("two", array(2, 3)) AS data(a, b); |one|[0, 1]| |two|[2, 3]| +---+------+ -{% endhighlight %} +``` -### Related Statement +### Related Statements - * [SELECT](sql-ref-syntax-qry-select.html) +* [SELECT](sql-ref-syntax-qry-select.html) diff --git a/docs/sql-ref-syntax-qry-select-join.md b/docs/sql-ref-syntax-qry-select-join.md index 0b1bb1eb8fd61..09b0efd7b5751 100644 --- a/docs/sql-ref-syntax-qry-select-join.md +++ b/docs/sql-ref-syntax-qry-select-join.md @@ -25,118 +25,95 @@ A SQL join is used to combine rows from two relations based on join criteria. Th ### Syntax -{% highlight sql %} +```sql relation { [ join_type ] JOIN relation [ join_criteria ] | NATURAL join_type JOIN relation } -{% endhighlight %} +``` ### Parameters -
- relation
+* **relation** + Specifies the relation to be joined. -
- join_type
- Specifies the join type.

- Syntax:
- - [ INNER ] - | CROSS - | LEFT [ OUTER ] - | [ LEFT ] SEMI - | RIGHT [ OUTER ] - | FULL [ OUTER ] - | [ LEFT ] ANTI - -
- join_criteria
- Specifies how the rows from one relation will be combined with the rows of another relation.

- Syntax: - - ON boolean_expression | USING ( column_name [ , column_name ... ] ) -

- boolean_expression
- Specifies an expression with a return type of boolean. -
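The `USING` form of `join_criteria` shown in the syntax above can be illustrated with the `employee` and `department` tables from the examples below; this is a sketch that assumes both tables expose a `deptno` column.

```sql
-- USING (deptno) is shorthand for an equality condition on the shared column
-- and returns a single deptno column in the result.
SELECT * FROM employee JOIN department USING (deptno);

-- The equivalent ON spelling keeps both deptno columns.
SELECT * FROM employee JOIN department ON employee.deptno = department.deptno;
```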
+ +* **join_type** + + Specifies the join type. + + **Syntax:** + + `[ INNER ] | CROSS | LEFT [ OUTER ] | [ LEFT ] SEMI | RIGHT [ OUTER ] | FULL [ OUTER ] | [ LEFT ] ANTI` + +* **join_criteria** + + Specifies how the rows from one relation will be combined with the rows of another relation. + + **Syntax:** `ON boolean_expression | USING ( column_name [ , ... ] )` + + `boolean_expression` + + Specifies an expression with a return type of boolean. ### Join Types -#### Inner Join - -
-The inner join is the default join in Spark SQL. It selects rows that have matching values in both relations.

- Syntax:
- - relation [ INNER ] JOIN relation [ join_criteria ] - -
- -#### Left Join - -
-A left join returns all values from the left relation and the matched values from the right relation, or appends NULL if there is no match. It is also referred to as a left outer join.

- Syntax:
- - relation LEFT [ OUTER ] JOIN relation [ join_criteria ] - -
- -#### Right Join - -
-A right join returns all values from the right relation and the matched values from the left relation, or appends NULL if there is no match. It is also referred to as a right outer join.

- Syntax:
- - relation RIGHT [ OUTER ] JOIN relation [ join_criteria ] - -
- -#### Full Join - -
-A full join returns all values from both relations, appending NULL values on the side that does not have a match. It is also referred to as a full outer join.

- Syntax:
- - relation FULL [ OUTER ] JOIN relation [ join_criteria ] - -
- -#### Cross Join - -
-A cross join returns the Cartesian product of two relations.

- Syntax:
- - relation CROSS JOIN relation [ join_criteria ] - -
- -#### Semi Join - -
-A semi join returns values from the left side of the relation that has a match with the right. It is also referred to as a left semi join.

- Syntax:
- - relation [ LEFT ] SEMI JOIN relation [ join_criteria ] - -
- -#### Anti Join - -
-An anti join returns values from the left relation that has no match with the right. It is also referred to as a left anti join.

- Syntax:
- - relation [ LEFT ] ANTI JOIN relation [ join_criteria ] - -
+#### **Inner Join** + +The inner join is the default join in Spark SQL. It selects rows that have matching values in both relations. + +**Syntax:** + +`relation [ INNER ] JOIN relation [ join_criteria ]` + +#### **Left Join** + +A left join returns all values from the left relation and the matched values from the right relation, or appends NULL if there is no match. It is also referred to as a left outer join. + +**Syntax:** + +`relation LEFT [ OUTER ] JOIN relation [ join_criteria ]` + +#### **Right Join** + +A right join returns all values from the right relation and the matched values from the left relation, or appends NULL if there is no match. It is also referred to as a right outer join. + +**Syntax:** + +`relation RIGHT [ OUTER ] JOIN relation [ join_criteria ]` + +#### **Full Join** + +A full join returns all values from both relations, appending NULL values on the side that does not have a match. It is also referred to as a full outer join. + +**Syntax:** + +`relation FULL [ OUTER ] JOIN relation [ join_criteria ]` + +#### **Cross Join** + +A cross join returns the Cartesian product of two relations. + +**Syntax:** + +`relation CROSS JOIN relation [ join_criteria ]` + +#### **Semi Join** + +A semi join returns values from the left side of the relation that has a match with the right. It is also referred to as a left semi join. + +**Syntax:** + +`relation [ LEFT ] SEMI JOIN relation [ join_criteria ]` + +#### **Anti Join** + +An anti join returns values from the left relation that has no match with the right. It is also referred to as a left anti join. + +**Syntax:** + +`relation [ LEFT ] ANTI JOIN relation [ join_criteria ]` ### Examples -{% highlight sql %} +```sql -- Use employee and department tables to demonstrate different type of joins. SELECT * FROM employee; +---+-----+------+ @@ -253,9 +230,9 @@ SELECT * FROM employee ANTI JOIN department ON employee.deptno = department.dept |104| Evan| 4| |106| Amy| 6| +---+-----+------+ -{% endhighlight %} +``` ### Related Statements - * [SELECT](sql-ref-syntax-qry-select.html) - * [Join Hints](sql-ref-syntax-qry-select-hints.html) +* [SELECT](sql-ref-syntax-qry-select.html) +* [Hints](sql-ref-syntax-qry-select-hints.html) diff --git a/docs/sql-ref-syntax-qry-select-like.md b/docs/sql-ref-syntax-qry-select-like.md index 408673c532ddd..feb5eb7b3c80d 100644 --- a/docs/sql-ref-syntax-qry-select-like.md +++ b/docs/sql-ref-syntax-qry-select-like.md @@ -25,38 +25,30 @@ A LIKE predicate is used to search for a specific pattern. ### Syntax -{% highlight sql %} +```sql [ NOT ] { LIKE search_pattern [ ESCAPE esc_char ] | RLIKE regex_pattern } -{% endhighlight %} +``` ### Parameters -
- search_pattern
- Specifies a string pattern to be searched by the LIKE clause. It can contain special pattern-matching characters: -
- % matches zero or more characters.
- _ matches exactly one character.
- esc_char
- Specifies the escape character. The default escape character is \. -
- regex_pattern
- Specifies a regular expression search pattern to be searched by the RLIKE clause. -
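A short illustrative sketch of the difference between the two pattern styles, using the `person` table from the examples below: a `LIKE` pattern must cover the whole value, while `RLIKE` succeeds when the regular expression matches any part of it.

```sql
-- Wildcard pattern: the whole name must match, so '%' is needed at the end.
SELECT * FROM person WHERE name LIKE 'M%';

-- Regular expression: an unanchored match anywhere in the name is enough.
SELECT * FROM person WHERE name RLIKE 'M';
```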
+* **search_pattern** + + Specifies a string pattern to be searched by the `LIKE` clause. It can contain special pattern-matching characters: + + * `%` matches zero or more characters. + * `_` matches exactly one character. + +* **esc_char** + + Specifies the escape character. The default escape character is `\`. + +* **regex_pattern** + + Specifies a regular expression search pattern to be searched by the `RLIKE` clause. ### Examples -{% highlight sql %} +```sql CREATE TABLE person (id INT, name STRING, age INT); INSERT INTO person VALUES (100, 'John', 30), @@ -90,12 +82,11 @@ SELECT * FROM person WHERE name NOT LIKE 'M_ry'; |400| Dan| 50| +---+------+---+ -SELECT * FROM person WHERE name RLIKE '[MD]'; +SELECT * FROM person WHERE name RLIKE 'M+'; +---+----+----+ | id|name| age| +---+----+----+ |300|Mike| 80| -|400| Dan| 50| |200|Mary|null| +---+----+----+ @@ -112,9 +103,9 @@ SELECT * FROM person WHERE name LIKE '%$_%' ESCAPE '$'; +---+------+---+ |500|Evan_W| 16| +---+------+---+ -{% endhighlight %} +``` ### Related Statements - * [SELECT](sql-ref-syntax-qry-select.html) - * [WHERE Clause](sql-ref-syntax-qry-select-where.html) +* [SELECT](sql-ref-syntax-qry-select.html) +* [WHERE Clause](sql-ref-syntax-qry-select-where.html) diff --git a/docs/sql-ref-syntax-qry-select-limit.md b/docs/sql-ref-syntax-qry-select-limit.md index eaeaed068102f..03c4df3cbc442 100644 --- a/docs/sql-ref-syntax-qry-select-limit.md +++ b/docs/sql-ref-syntax-qry-select-limit.md @@ -21,34 +21,31 @@ license: | ### Description -The LIMIT clause is used to constrain the number of rows returned by +The `LIMIT` clause is used to constrain the number of rows returned by the [SELECT](sql-ref-syntax-qry-select.html) statement. In general, this clause is used in conjunction with [ORDER BY](sql-ref-syntax-qry-select-orderby.html) to ensure that the results are deterministic. ### Syntax -{% highlight sql %} +```sql LIMIT { ALL | integer_expression } -{% endhighlight %} +``` ### Parameters -
- ALL
+* **ALL** + If specified, the query returns all the rows. In other words, no limit is applied if this option is specified. -
- integer_expression
+ +* **integer_expression** + Specifies a foldable expression that returns an integer. -
-
### Examples -{% highlight sql %} +```sql CREATE TABLE person (name STRING, age INT); INSERT INTO person VALUES ('Zen Hui', 25), @@ -94,16 +91,16 @@ SELECT name, age FROM person ORDER BY name LIMIT length('SPARK'); -- A non-foldable expression as an input to LIMIT is not allowed. SELECT name, age FROM person ORDER BY name LIMIT length(name); - org.apache.spark.sql.AnalysisException: The limit expression must evaluate to a constant value ... -{% endhighlight %} +org.apache.spark.sql.AnalysisException: The limit expression must evaluate to a constant value ... +``` ### Related Statements - * [SELECT Main](sql-ref-syntax-qry-select.html) - * [WHERE Clause](sql-ref-syntax-qry-select-where.html) - * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) - * [HAVING Clause](sql-ref-syntax-qry-select-having.html) - * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) - * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) - * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) - * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) +* [SELECT Main](sql-ref-syntax-qry-select.html) +* [WHERE Clause](sql-ref-syntax-qry-select-where.html) +* [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) +* [HAVING Clause](sql-ref-syntax-qry-select-having.html) +* [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) +* [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) +* [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) +* [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) diff --git a/docs/sql-ref-syntax-qry-select-orderby.md b/docs/sql-ref-syntax-qry-select-orderby.md index d927177398f7f..85bbe514cdc95 100644 --- a/docs/sql-ref-syntax-qry-select-orderby.md +++ b/docs/sql-ref-syntax-qry-select-orderby.md @@ -21,56 +21,48 @@ license: | ### Description -The ORDER BY clause is used to return the result rows in a sorted manner +The `ORDER BY` clause is used to return the result rows in a sorted manner in the user specified order. Unlike the [SORT BY](sql-ref-syntax-qry-select-sortby.html) clause, this clause guarantees a total order in the output. ### Syntax -{% highlight sql %} +```sql ORDER BY { expression [ sort_direction | nulls_sort_oder ] [ , ... ] } -{% endhighlight %} +``` ### Parameters -
- ORDER BY
- Specifies a comma-separated list of expressions along with optional parameters sort_direction - and nulls_sort_order which are used to sort the rows. -
- sort_direction
+* **ORDER BY** + + Specifies a comma-separated list of expressions along with optional parameters `sort_direction` + and `nulls_sort_order` which are used to sort the rows. + +* **sort_direction** + Optionally specifies whether to sort the rows in ascending or descending - order. The valid values for the sort direction are ASC for ascending - and DESC for descending. If sort direction is not explicitly specified, then by default - rows are sorted ascending.

- Syntax: - - [ ASC | DESC ] - -
- nulls_sort_order
+ order. The valid values for the sort direction are `ASC` for ascending + and `DESC` for descending. If sort direction is not explicitly specified, then by default + rows are sorted ascending. + + **Syntax:** [ ASC `|` DESC ] + +* **nulls_sort_order** + Optionally specifies whether NULL values are returned before/after non-NULL values. If - null_sort_order is not specified, then NULLs sort first if sort order is - ASC and NULLS sort last if sort order is DESC.

- 1. If NULLS FIRST is specified, then NULL values are returned first regardless of the sort order.
- 2. If NULLS LAST is specified, then NULL values are returned last regardless of the sort order.

- Syntax: - - [ NULLS { FIRST | LAST } ] - -
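To illustrate `nulls_sort_order` as described above, a small sketch against the `person` table used in the examples below:

```sql
-- With DESC the default places NULLs last; NULLS FIRST overrides that default.
SELECT name, age FROM person ORDER BY age DESC NULLS FIRST;
```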
+ `null_sort_order` is not specified, then NULLs sort first if sort order is + `ASC` and NULLS sort last if sort order is `DESC`. + + 1. If `NULLS FIRST` is specified, then NULL values are returned first + regardless of the sort order. + 2. If `NULLS LAST` is specified, then NULL values are returned last regardless of + the sort order. + + **Syntax:** `[ NULLS { FIRST | LAST } ]` ### Examples -{% highlight sql %} +```sql CREATE TABLE person (id INT, name STRING, age INT); INSERT INTO person VALUES (100, 'John', 30), @@ -139,15 +131,15 @@ SELECT * FROM person ORDER BY name ASC, age DESC; |200| Mary|null| |300| Mike| 80| +---+-----+----+ -{% endhighlight %} +``` ### Related Statements - * [SELECT Main](sql-ref-syntax-qry-select.html) - * [WHERE Clause](sql-ref-syntax-qry-select-where.html) - * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) - * [HAVING Clause](sql-ref-syntax-qry-select-having.html) - * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) - * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) - * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) - * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) +* [SELECT Main](sql-ref-syntax-qry-select.html) +* [WHERE Clause](sql-ref-syntax-qry-select-where.html) +* [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) +* [HAVING Clause](sql-ref-syntax-qry-select-having.html) +* [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) +* [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) +* [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) +* [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) diff --git a/docs/sql-ref-syntax-qry-sampling.md b/docs/sql-ref-syntax-qry-select-sampling.md similarity index 82% rename from docs/sql-ref-syntax-qry-sampling.md rename to docs/sql-ref-syntax-qry-select-sampling.md index 82f6588e6c504..f9c7b760a239f 100644 --- a/docs/sql-ref-syntax-qry-sampling.md +++ b/docs/sql-ref-syntax-qry-select-sampling.md @@ -26,19 +26,19 @@ The `TABLESAMPLE` statement is used to sample the table. It supports the followi * `TABLESAMPLE`(x `PERCENT`): Sample the table down to the given percentage. Note that percentages are defined as a number between 0 and 100. * `TABLESAMPLE`(`BUCKET` x `OUT OF` y): Sample the table down to a `x` out of `y` fraction. -Note: `TABLESAMPLE` returns the approximate number of rows or fraction requested. +**Note:** `TABLESAMPLE` returns the approximate number of rows or fraction requested. 
### Syntax -{% highlight sql %} -TABLESAMPLE ((integer_expression | decimal_expression) PERCENT) - | TABLESAMPLE (integer_expression ROWS) - | TABLESAMPLE (BUCKET integer_expression OUT OF integer_expression) -{% endhighlight %} +```sql +TABLESAMPLE ({ integer_expression | decimal_expression } PERCENT) + | TABLESAMPLE ( integer_expression ROWS ) + | TABLESAMPLE ( BUCKET integer_expression OUT OF integer_expression ) +``` ### Examples -{% highlight sql %} +```sql SELECT * FROM test; +--+----+ |id|name| @@ -87,8 +87,8 @@ SELECT * FROM test TABLESAMPLE (BUCKET 4 OUT OF 10); | 9|Eric| | 6|Mark| +--+----+ -{% endhighlight %} +``` -### Related Statement +### Related Statements - * [SELECT](sql-ref-syntax-qry-select.html) \ No newline at end of file +* [SELECT](sql-ref-syntax-qry-select.html) \ No newline at end of file diff --git a/docs/sql-ref-syntax-qry-select-setops.md b/docs/sql-ref-syntax-qry-select-setops.md index 98c20941d16bf..8cd12c37fa603 100644 --- a/docs/sql-ref-syntax-qry-select-setops.md +++ b/docs/sql-ref-syntax-qry-select-setops.md @@ -35,13 +35,13 @@ Note that input relations must have the same number of columns and compatible da #### Syntax -{% highlight sql %} +```sql [ ( ] relation [ ) ] EXCEPT | MINUS [ ALL | DISTINCT ] [ ( ] relation [ ) ] -{% endhighlight %} +``` #### Examples -{% highlight sql %} +```sql -- Use number1 and number2 tables to demonstrate set operators in this page. SELECT * FROM number1; +---+ @@ -98,7 +98,7 @@ SELECT c FROM number1 MINUS ALL (SELECT c FROM number2); | 3| | 4| +---+ -{% endhighlight %} +``` ### INTERSECT @@ -106,13 +106,13 @@ SELECT c FROM number1 MINUS ALL (SELECT c FROM number2); #### Syntax -{% highlight sql %} +```sql [ ( ] relation [ ) ] INTERSECT [ ALL | DISTINCT ] [ ( ] relation [ ) ] -{% endhighlight %} +``` #### Examples -{% highlight sql %} +```sql (SELECT c FROM number1) INTERSECT (SELECT c FROM number2); +---+ | c| @@ -137,7 +137,7 @@ SELECT c FROM number1 MINUS ALL (SELECT c FROM number2); | 2| | 2| +---+ -{% endhighlight %} +``` ### UNION @@ -145,13 +145,13 @@ SELECT c FROM number1 MINUS ALL (SELECT c FROM number2); #### Syntax -{% highlight sql %} +```sql [ ( ] relation [ ) ] UNION [ ALL | DISTINCT ] [ ( ] relation [ ) ] -{% endhighlight %} +``` ### Examples -{% highlight sql %} +```sql (SELECT c FROM number1) UNION (SELECT c FROM number2); +---+ | c| @@ -189,8 +189,8 @@ SELECT c FROM number1 UNION ALL (SELECT c FROM number2); | 2| | 2| +---+ -{% endhighlight %} +``` ### Related Statements - * [SELECT Statement](sql-ref-syntax-qry-select.html) +* [SELECT Statement](sql-ref-syntax-qry-select.html) diff --git a/docs/sql-ref-syntax-qry-select-sortby.md b/docs/sql-ref-syntax-qry-select-sortby.md index 1dfa10429709e..554bdb569d005 100644 --- a/docs/sql-ref-syntax-qry-select-sortby.md +++ b/docs/sql-ref-syntax-qry-select-sortby.md @@ -21,58 +21,50 @@ license: | ### Description -The SORT BY clause is used to return the result rows sorted +The `SORT BY` clause is used to return the result rows sorted within each partition in the user specified order. When there is more than one partition -SORT BY may return result that is partially ordered. This is different +`SORT BY` may return result that is partially ordered. This is different than [ORDER BY](sql-ref-syntax-qry-select-orderby.html) clause which guarantees a total order of the output. ### Syntax -{% highlight sql %} +```sql SORT BY { expression [ sort_direction | nulls_sort_order ] [ , ... ] } -{% endhighlight %} +``` ### Parameters -
- SORT BY
- Specifies a comma-separated list of expressions along with optional parameters sort_direction - and nulls_sort_order which are used to sort the rows within each partition. -
- sort_direction
+* **SORT BY** + + Specifies a comma-separated list of expressions along with optional parameters `sort_direction` + and `nulls_sort_order` which are used to sort the rows within each partition. + +* **sort_direction** + Optionally specifies whether to sort the rows in ascending or descending - order. The valid values for the sort direction are ASC for ascending - and DESC for descending. If sort direction is not explicitly specified, then by default - rows are sorted ascending.

- Syntax: - - [ ASC | DESC ] - -
- nulls_sort_order
+ order. The valid values for the sort direction are `ASC` for ascending + and `DESC` for descending. If sort direction is not explicitly specified, then by default + rows are sorted ascending. + + **Syntax:** `[ ASC | DESC ]` + +* **nulls_sort_order** + Optionally specifies whether NULL values are returned before/after non-NULL values. If - null_sort_order is not specified, then NULLs sort first if sort order is - ASC and NULLS sort last if sort order is DESC.

- 1. If NULLS FIRST is specified, then NULL values are returned first regardless of the sort order.
- 2. If NULLS LAST is specified, then NULL values are returned last regardless of the sort order.

- Syntax: - - [ NULLS { FIRST | LAST } ] - -
-
+ `null_sort_order` is not specified, then NULLs sort first if sort order is + `ASC` and NULLS sort last if sort order is `DESC`. + + 1. If `NULLS FIRST` is specified, then NULL values are returned first + regardless of the sort order. + 2. If `NULLS LAST` is specified, then NULL values are returned last regardless of + the sort order. + + **Syntax:** `[ NULLS { FIRST | LAST } ]` ### Examples -{% highlight sql %} +```sql CREATE TABLE person (zip_code INT, name STRING, age INT); INSERT INTO person VALUES (94588, 'Zen Hui', 50), @@ -172,15 +164,15 @@ SELECT /*+ REPARTITION(zip_code) */ name, age, zip_code FROM person | David K| 42| 94511| |Lalit B.|null| 94511| +--------+----+--------+ -{% endhighlight %} +``` ### Related Statements - * [SELECT Main](sql-ref-syntax-qry-select.html) - * [WHERE Clause](sql-ref-syntax-qry-select-where.html) - * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) - * [HAVING Clause](sql-ref-syntax-qry-select-having.html) - * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) - * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) - * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) - * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) +* [SELECT Main](sql-ref-syntax-qry-select.html) +* [WHERE Clause](sql-ref-syntax-qry-select-where.html) +* [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) +* [HAVING Clause](sql-ref-syntax-qry-select-having.html) +* [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) +* [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) +* [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) +* [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) diff --git a/docs/sql-ref-syntax-qry-select-tvf.md b/docs/sql-ref-syntax-qry-select-tvf.md index 1d9505a2bc11b..cc8d7c34645fb 100644 --- a/docs/sql-ref-syntax-qry-select-tvf.md +++ b/docs/sql-ref-syntax-qry-select-tvf.md @@ -25,28 +25,21 @@ A table-valued function (TVF) is a function that returns a relation or a set of ### Syntax -{% highlight sql %} +```sql function_name ( expression [ , ... ] ) [ table_alias ] -{% endhighlight %} +``` ### Parameters -
- expression
+* **expression** + Specifies a combination of one or more values, operators and SQL functions that results in a value. -
- table_alias
- Specifies a temporary name with an optional column name list.

- Syntax: - - [ AS ] table_name [ ( column_name [ , ... ] ) ] - -
-
+ +* **table_alias** + + Specifies a temporary name with an optional column name list. + + **Syntax:** `[ AS ] table_name [ ( column_name [ , ... ] ) ]` ### Supported Table-valued Functions @@ -59,7 +52,7 @@ function_name ( expression [ , ... ] ) [ table_alias ] ### Examples -{% highlight sql %} +```sql -- range call with end SELECT * FROM range(6 + cos(3)); +---+ @@ -105,8 +98,8 @@ SELECT * FROM range(5, 8) AS test; | 6| | 7| +---+ -{% endhighlight %} +``` -### Related Statement +### Related Statements - * [SELECT](sql-ref-syntax-qry-select.html) +* [SELECT](sql-ref-syntax-qry-select.html) diff --git a/docs/sql-ref-syntax-qry-select-usedb.md b/docs/sql-ref-syntax-qry-select-usedb.md index bb95a8e4ddf30..90076e0306125 100644 --- a/docs/sql-ref-syntax-qry-select-usedb.md +++ b/docs/sql-ref-syntax-qry-select-usedb.md @@ -28,32 +28,30 @@ The default database name is 'default'. ### Syntax -{% highlight sql %} +```sql USE database_name -{% endhighlight %} +``` ### Parameter -
- database_name
- Name of the database will be used. If the database does not exist, an exception will be thrown. -
-
+* **database_name** + + Name of the database will be used. If the database does not exist, an exception will be thrown. ### Examples -{% highlight sql %} +```sql -- Use the 'userdb' which exists. USE userdb; -- Use the 'userdb1' which doesn't exist USE userdb1; - Error: org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: Database 'userdb1' not found;(state=,code=0) -{% endhighlight %} +Error: org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: Database 'userdb1' not found; +(state=,code=0) +``` ### Related Statements - * [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) - * [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) - * [CREATE TABLE ](sql-ref-syntax-ddl-create-table.html) +* [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) +* [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) +* [CREATE TABLE ](sql-ref-syntax-ddl-create-table.html) diff --git a/docs/sql-ref-syntax-qry-select-where.md b/docs/sql-ref-syntax-qry-select-where.md index 360313fcfff1c..ca3f5ec7866c6 100644 --- a/docs/sql-ref-syntax-qry-select-where.md +++ b/docs/sql-ref-syntax-qry-select-where.md @@ -21,29 +21,26 @@ license: | ### Description -The WHERE clause is used to limit the results of the FROM +The `WHERE` clause is used to limit the results of the `FROM` clause of a query or a subquery based on the specified condition. ### Syntax -{% highlight sql %} +```sql WHERE boolean_expression -{% endhighlight %} +``` ### Parameters -
- boolean_expression
- Specifies any expression that evaluates to a result type boolean. Two or +* **boolean_expression** + + Specifies any expression that evaluates to a result type `boolean`. Two or more expressions may be combined together using the logical - operators ( AND, OR ). -
-
+ operators ( `AND`, `OR` ). ### Examples -{% highlight sql %} +```sql CREATE TABLE person (id INT, name STRING, age INT); INSERT INTO person VALUES (100, 'John', 30), @@ -116,15 +113,15 @@ SELECT * FROM person AS parent +---+----+----+ |200|Mary|null| +---+----+----+ -{% endhighlight %} +``` ### Related Statements - * [SELECT Main](sql-ref-syntax-qry-select.html) - * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) - * [HAVING Clause](sql-ref-syntax-qry-select-having.html) - * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) - * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) - * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) - * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) - * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) +* [SELECT Main](sql-ref-syntax-qry-select.html) +* [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) +* [HAVING Clause](sql-ref-syntax-qry-select-having.html) +* [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) +* [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) +* [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) +* [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) +* [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) diff --git a/docs/sql-ref-syntax-qry-window.md b/docs/sql-ref-syntax-qry-select-window.md similarity index 80% rename from docs/sql-ref-syntax-qry-window.md rename to docs/sql-ref-syntax-qry-select-window.md index e3762925760e2..a1c2b18b04fce 100644 --- a/docs/sql-ref-syntax-qry-window.md +++ b/docs/sql-ref-syntax-qry-select-window.md @@ -25,67 +25,52 @@ Window functions operate on a group of rows, referred to as a window, and calcul ### Syntax -{% highlight sql %} +```sql window_function OVER ( [ { PARTITION | DISTRIBUTE } BY partition_col_name = partition_col_val ( [ , ... ] ) ] { ORDER | SORT } BY expression [ ASC | DESC ] [ NULLS { FIRST | LAST } ] [ , ... ] [ window_frame ] ) -{% endhighlight %} +``` ### Parameters -
- window_function
- Ranking Functions. Syntax: RANK | DENSE_RANK | PERCENT_RANK | NTILE | ROW_NUMBER
- Analytic Functions. Syntax: CUME_DIST | LAG | LEAD
- Aggregate Functions. Syntax: MAX | MIN | COUNT | SUM | AVG | ...
- Please refer to the Built-in Functions document for a complete list of Spark aggregate functions.
- window_frame
- Specifies which row to start the window on and where to end it.
- Syntax:
- { RANGE | ROWS } { frame_start | BETWEEN frame_start AND frame_end }
- If frame_end is omitted it defaults to CURRENT ROW.

-
    - frame_start and frame_end have the following syntax
    - Syntax:
    - - UNBOUNDED PRECEDING | offset PRECEDING | CURRENT ROW | offset FOLLOWING | UNBOUNDED FOLLOWING -
    - offset:specifies the offset from the position of the current row. -
-
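A sketch of an explicit `window_frame`, using the `employees` table from the examples that follow: a centered moving average over the current row and its immediate neighbours within each department.

```sql
SELECT name, dept, salary,
       avg(salary) OVER (
           PARTITION BY dept
           ORDER BY salary
           ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
       ) AS moving_avg
    FROM employees;
```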
+* **window_function** + + * Ranking Functions + + **Syntax:** `RANK | DENSE_RANK | PERCENT_RANK | NTILE | ROW_NUMBER` + + * Analytic Functions + + **Syntax:** `CUME_DIST | LAG | LEAD` + + * Aggregate Functions + + **Syntax:** `MAX | MIN | COUNT | SUM | AVG | ...` + + Please refer to the [Built-in Aggregation Functions](sql-ref-functions-builtin.html#aggregate-functions) document for a complete list of Spark aggregate functions. + +* **window_frame** + + Specifies which row to start the window on and where to end it. + + **Syntax:** + + `{ RANGE | ROWS } { frame_start | BETWEEN frame_start AND frame_end }` + + * `frame_start` and `frame_end` have the following syntax: + + **Syntax:** + + `UNBOUNDED PRECEDING | offset PRECEDING | CURRENT ROW | offset FOLLOWING | UNBOUNDED FOLLOWING` + + `offset:` specifies the `offset` from the position of the current row. + + **Note:** If `frame_end` is omitted it defaults to `CURRENT ROW`. ### Examples -{% highlight sql %} +```sql CREATE TABLE employees (name STRING, dept STRING, salary INT, age INT); INSERT INTO employees VALUES ("Lisa", "Sales", 10000, 35); @@ -199,8 +184,8 @@ SELECT name, salary, | Jane| Marketing| 29000|29000|35000| | Jeff| Marketing| 35000|29000| 0| +-----+-----------+------+-----+-----+ -{% endhighlight %} +``` ### Related Statements - * [SELECT](sql-ref-syntax-qry-select.html) +* [SELECT](sql-ref-syntax-qry-select.html) diff --git a/docs/sql-ref-syntax-qry-select.md b/docs/sql-ref-syntax-qry-select.md index bc2cc0269124e..987e6479ab20a 100644 --- a/docs/sql-ref-syntax-qry-select.md +++ b/docs/sql-ref-syntax-qry-select.md @@ -28,7 +28,7 @@ of a query along with examples. ### Syntax -{% highlight sql %} +```sql [ WITH with_query [ , ... ] ] select_statement [ { UNION | INTERSECT | EXCEPT } [ ALL | DISTINCT ] select_statement, ... ] [ ORDER BY { expression [ ASC | DESC ] [ NULLS { FIRST | LAST } ] [ , ...] } ] @@ -37,126 +37,125 @@ select_statement [ { UNION | INTERSECT | EXCEPT } [ ALL | DISTINCT ] select_stat [ DISTRIBUTE BY { expression [, ...] } ] [ WINDOW { named_window [ , WINDOW named_window, ... ] } ] [ LIMIT { ALL | expression } ] -{% endhighlight %} +``` While `select_statement` is defined as -{% highlight sql %} +```sql SELECT [ hints , ... ] [ ALL | DISTINCT ] { named_expression [ , ... ] } FROM { from_item [ , ...] } [ WHERE boolean_expression ] [ GROUP BY expression [ , ...] ] [ HAVING boolean_expression ] -{% endhighlight %} +``` ### Parameters -
- with_query
- Specifies the common table expressions (CTEs) before the main query block. +* **with_query** + + Specifies the [common table expressions (CTEs)](sql-ref-syntax-qry-select-cte.html) before the main query block. These table expressions are allowed to be referenced later in the FROM clause. This is useful to abstract out repeated subquery blocks in the FROM clause and improves readability of the query. -
- hints
+ +* **hints** + Hints can be specified to help spark optimizer make better planning decisions. Currently spark supports hints that influence selection of join strategies and repartitioning of the data. -
- ALL
+ +* **ALL** + Select all matching rows from the relation and is enabled by default. -
- DISTINCT
+ +* **DISTINCT** + Select all matching rows from the relation after removing duplicates in results. -
- named_expression
- An expression with an assigned name. In general, it denotes a column expression.

- Syntax: - - expression [AS] [alias] - -
- from_item
- Specifies a source of input for the query. It can be one of the following: -
- 1. Table relation
- 2. Join relation
- 3. Table-value function
- 4. Inline table
- 5. Subquery
- WHERE
- Filters the result of the FROM clause based on the supplied predicates. -
- GROUP BY
- Specifies the expressions that are used to group the rows. This is used in conjunction with aggregate functions - (MIN, MAX, COUNT, SUM, AVG, etc.) to group rows based on the grouping expressions and aggregate values in each group. - When a FILTER clause is attached to an aggregate function, only the matching rows are passed to that function. -
- HAVING
- Specifies the predicates by which the rows produced by GROUP BY are filtered. The HAVING clause is used to - filter rows after the grouping is performed. If HAVING is specified without GROUP BY, it indicates a GROUP BY - without grouping expressions (global aggregate). -
- ORDER BY
- Specifies an ordering of the rows of the complete result set of the query. The output rows are ordered - across the partitions. This parameter is mutually exclusive with SORT BY, - CLUSTER BY and DISTRIBUTE BY and can not be specified together. -
- SORT BY
- Specifies an ordering by which the rows are ordered within each partition. This parameter is mutually - exclusive with ORDER BY and CLUSTER BY and can not be specified together. -
- CLUSTER BY
- Specifies a set of expressions that is used to repartition and sort the rows. Using this clause has - the same effect of using DISTRIBUTE BY and SORT BY together. -
- DISTRIBUTE BY
- Specifies a set of expressions by which the result rows are repartitioned. This parameter is mutually - exclusive with ORDER BY and CLUSTER BY and can not be specified together. -
- LIMIT
- Specifies the maximum number of rows that can be returned by a statement or subquery. This clause - is mostly used in the conjunction with ORDER BY to produce a deterministic result. -
- boolean_expression
- Specifies an expression with a return type of boolean. -
- expression
- Specifies a combination of one or more values, operators, and SQL functions that evaluates to a value. -
- named_window
- Specifies aliases for one or more source window specifications. The source window specifications can - be referenced in the window definitions in the query. -
-
+ +* **named_expression** + + An expression with an assigned name. In general, it denotes a column expression. + + **Syntax:** `expression [AS] [alias]` + + * **from_item** + + Specifies a source of input for the query. It can be one of the following: + * Table relation + * [Join relation](sql-ref-syntax-qry-select-join.html) + * [Table-value function](sql-ref-syntax-qry-select-tvf.html) + * [Inline table](sql-ref-syntax-qry-select-inline-table.html) + * Subquery + + + * **WHERE** + + Filters the result of the FROM clause based on the supplied predicates. + + * **GROUP BY** + + Specifies the expressions that are used to group the rows. This is used in conjunction with aggregate functions + (MIN, MAX, COUNT, SUM, AVG, etc.) to group rows based on the grouping expressions and aggregate values in each group. + When a FILTER clause is attached to an aggregate function, only the matching rows are passed to that function. + + * **HAVING** + + Specifies the predicates by which the rows produced by GROUP BY are filtered. The HAVING clause is used to + filter rows after the grouping is performed. If HAVING is specified without GROUP BY, it indicates a GROUP BY + without grouping expressions (global aggregate). + + * **ORDER BY** + + Specifies an ordering of the rows of the complete result set of the query. The output rows are ordered + across the partitions. This parameter is mutually exclusive with `SORT BY`, + `CLUSTER BY` and `DISTRIBUTE BY` and can not be specified together. + + * **SORT BY** + + Specifies an ordering by which the rows are ordered within each partition. This parameter is mutually + exclusive with `ORDER BY` and `CLUSTER BY` and can not be specified together. + + * **CLUSTER BY** + + Specifies a set of expressions that is used to repartition and sort the rows. Using this clause has + the same effect of using `DISTRIBUTE BY` and `SORT BY` together. + + * **DISTRIBUTE BY** + + Specifies a set of expressions by which the result rows are repartitioned. This parameter is mutually + exclusive with `ORDER BY` and `CLUSTER BY` and can not be specified together. + + * **LIMIT** + + Specifies the maximum number of rows that can be returned by a statement or subquery. This clause + is mostly used in the conjunction with `ORDER BY` to produce a deterministic result. + + * **boolean_expression** + + Specifies an expression with a return type of boolean. + + * **expression** + + Specifies a combination of one or more values, operators, and SQL functions that evaluates to a value. + + * **named_window** + + Specifies aliases for one or more source window specifications. The source window specifications can + be referenced in the widow definitions in the query. 
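To tie these parameters together, here is a minimal Scala sketch (an editor's illustration, not part of this page) that registers a hypothetical `employees` temporary view and runs a single query through `spark.sql`, combining a repartitioning hint, `WHERE`, `GROUP BY` with an aggregate, `HAVING`, `ORDER BY` on the aggregate's alias, and `LIMIT`. The view name, columns, and threshold values are made up for illustration.

```scala
import org.apache.spark.sql.SparkSession

object SelectClausesSketch {
  def main(args: Array[String]): Unit = {
    // local[*] master so the sketch can be run directly; under spark-submit
    // the master would normally be supplied on the command line instead.
    val spark = SparkSession.builder
      .appName("SelectClausesSketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical input data registered as a temporary view.
    Seq(
      ("Lisa", "Sales", 10000), ("Evan", "Sales", 32000),
      ("Jane", "Marketing", 29000), ("Jeff", "Marketing", 35000)
    ).toDF("name", "dept", "salary").createOrReplaceTempView("employees")

    // One statement exercising several of the clauses documented above:
    // a COALESCE repartitioning hint, WHERE, GROUP BY + SUM, HAVING,
    // ORDER BY on the output alias, and LIMIT.
    spark.sql(
      """SELECT /*+ COALESCE(1) */ dept, SUM(salary) AS total_salary
        |FROM employees
        |WHERE salary > 5000
        |GROUP BY dept
        |HAVING SUM(salary) > 20000
        |ORDER BY total_salary DESC
        |LIMIT 10
        |""".stripMargin).show()

    spark.stop()
  }
}
```

Run locally, this prints one row per department whose total salary exceeds 20000, with the highest total first.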
### Related Statements - * [WHERE Clause](sql-ref-syntax-qry-select-where.html) - * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) - * [HAVING Clause](sql-ref-syntax-qry-select-having.html) - * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) - * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) - * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) - * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) - * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) - * [TABLESAMPLE](sql-ref-syntax-qry-sampling.html) - * [JOIN](sql-ref-syntax-qry-select-join.html) - * [SET Operators](sql-ref-syntax-qry-select-setops.html) - * [Common Table Expression](sql-ref-syntax-qry-select-cte.html) +* [WHERE Clause](sql-ref-syntax-qry-select-where.html) +* [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) +* [HAVING Clause](sql-ref-syntax-qry-select-having.html) +* [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) +* [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) +* [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) +* [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) +* [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) +* [Common Table Expression](sql-ref-syntax-qry-select-cte.html) +* [Hints](sql-ref-syntax-qry-select-hints.html) +* [Inline Table](sql-ref-syntax-qry-select-inline-table.html) +* [JOIN](sql-ref-syntax-qry-select-join.html) +* [LIKE Predicate](sql-ref-syntax-qry-select-like.html) +* [Set Operators](sql-ref-syntax-qry-select-setops.html) +* [TABLESAMPLE](sql-ref-syntax-qry-select-sampling.html) +* [Table-valued Function](sql-ref-syntax-qry-select-tvf.html) +* [Window Function](sql-ref-syntax-qry-select-window.html) diff --git a/docs/sql-ref-syntax-qry.md b/docs/sql-ref-syntax-qry.md index 325c9b69f12f9..167c394d0fe49 100644 --- a/docs/sql-ref-syntax-qry.md +++ b/docs/sql-ref-syntax-qry.md @@ -27,20 +27,22 @@ to SELECT are also included in this section. Spark also provides the ability to generate logical and physical plan for a given query using [EXPLAIN](sql-ref-syntax-qry-explain.html) statement. 
- * [WHERE Clause](sql-ref-syntax-qry-select-where.html) - * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) - * [HAVING Clause](sql-ref-syntax-qry-select-having.html) - * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) - * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) - * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) - * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) - * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) - * [JOIN](sql-ref-syntax-qry-select-join.html) - * [Join Hints](sql-ref-syntax-qry-select-hints.html) - * [Set Operators](sql-ref-syntax-qry-select-setops.html) - * [TABLESAMPLE](sql-ref-syntax-qry-sampling.html) - * [Table-valued Function](sql-ref-syntax-qry-select-tvf.html) - * [Inline Table](sql-ref-syntax-qry-select-inline-table.html) - * [Common Table Expression](sql-ref-syntax-qry-select-cte.html) - * [Window Function](sql-ref-syntax-qry-window.html) - * [EXPLAIN Statement](sql-ref-syntax-qry-explain.html) +* [SELECT Statement](sql-ref-syntax-qry-select.html) + * [WHERE Clause](sql-ref-syntax-qry-select-where.html) + * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) + * [HAVING Clause](sql-ref-syntax-qry-select-having.html) + * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) + * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) + * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) + * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) + * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) + * [Common Table Expression](sql-ref-syntax-qry-select-cte.html) + * [Hints](sql-ref-syntax-qry-select-hints.html) + * [Inline Table](sql-ref-syntax-qry-select-inline-table.html) + * [JOIN](sql-ref-syntax-qry-select-join.html) + * [LIKE Predicate](sql-ref-syntax-qry-select-like.html) + * [Set Operators](sql-ref-syntax-qry-select-setops.html) + * [TABLESAMPLE](sql-ref-syntax-qry-select-sampling.html) + * [Table-valued Function](sql-ref-syntax-qry-select-tvf.html) + * [Window Function](sql-ref-syntax-qry-select-window.html) +* [EXPLAIN Statement](sql-ref-syntax-qry-explain.html) diff --git a/docs/sql-ref-syntax.md b/docs/sql-ref-syntax.md index 94bd476ffb7b1..d78a01fd655a2 100644 --- a/docs/sql-ref-syntax.md +++ b/docs/sql-ref-syntax.md @@ -48,15 +48,25 @@ Spark SQL is Apache Spark's module for working with structured data. 
The SQL Syn ### Data Retrieval Statements - * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) - * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) + * [SELECT Statement](sql-ref-syntax-qry-select.html) + * [Common Table Expression](sql-ref-syntax-qry-select-cte.html) + * [CLUSTER BY Clause](sql-ref-syntax-qry-select-clusterby.html) + * [DISTRIBUTE BY Clause](sql-ref-syntax-qry-select-distribute-by.html) + * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) + * [HAVING Clause](sql-ref-syntax-qry-select-having.html) + * [Hints](sql-ref-syntax-qry-select-hints.html) + * [Inline Table](sql-ref-syntax-qry-select-inline-table.html) + * [JOIN](sql-ref-syntax-qry-select-join.html) + * [LIKE Predicate](sql-ref-syntax-qry-select-like.html) + * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) + * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) + * [Set Operators](sql-ref-syntax-qry-select-setops.html) + * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) + * [TABLESAMPLE](sql-ref-syntax-qry-select-sampling.html) + * [Table-valued Function](sql-ref-syntax-qry-select-tvf.html) + * [WHERE Clause](sql-ref-syntax-qry-select-where.html) + * [Window Function](sql-ref-syntax-qry-select-window.html) * [EXPLAIN](sql-ref-syntax-qry-explain.html) - * [GROUP BY Clause](sql-ref-syntax-qry-select-groupby.html) - * [HAVING Clause](sql-ref-syntax-qry-select-having.html) - * [LIMIT Clause](sql-ref-syntax-qry-select-limit.html) - * [ORDER BY Clause](sql-ref-syntax-qry-select-orderby.html) - * [SORT BY Clause](sql-ref-syntax-qry-select-sortby.html) - * [WHERE Clause](sql-ref-syntax-qry-select-where.html) ### Auxiliary Statements @@ -73,6 +83,7 @@ Spark SQL is Apache Spark's module for working with structured data. The SQL Syn * [LIST JAR](sql-ref-syntax-aux-resource-mgmt-list-jar.html) * [REFRESH](sql-ref-syntax-aux-cache-refresh.html) * [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html) + * [RESET](sql-ref-syntax-aux-conf-mgmt-reset.html) * [SET](sql-ref-syntax-aux-conf-mgmt-set.html) * [SHOW COLUMNS](sql-ref-syntax-aux-show-columns.html) * [SHOW CREATE TABLE](sql-ref-syntax-aux-show-create-table.html) @@ -84,4 +95,3 @@ Spark SQL is Apache Spark's module for working with structured data. The SQL Syn * [SHOW TBLPROPERTIES](sql-ref-syntax-aux-show-tblproperties.html) * [SHOW VIEWS](sql-ref-syntax-aux-show-views.html) * [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html) - * [UNSET](sql-ref-syntax-aux-conf-mgmt-reset.html) diff --git a/docs/sql-ref.md b/docs/sql-ref.md index db51fe1978eec..f88026b7abf02 100644 --- a/docs/sql-ref.md +++ b/docs/sql-ref.md @@ -21,19 +21,19 @@ license: | Spark SQL is Apache Spark's module for working with structured data. This guide is a reference for Structured Query Language (SQL) and includes syntax, semantics, keywords, and examples for common SQL usage. 
It contains information for the following topics: + * [ANSI Compliance](sql-ref-ansi-compliance.html) * [Data Types](sql-ref-datatypes.html) + * [Datetime Pattern](sql-ref-datetime-pattern.html) + * [Functions](sql-ref-functions.html) + * [Built-in Functions](sql-ref-functions-builtin.html) + * [Scalar User-Defined Functions (UDFs)](sql-ref-functions-udf-scalar.html) + * [User-Defined Aggregate Functions (UDAFs)](sql-ref-functions-udf-aggregate.html) + * [Integration with Hive UDFs/UDAFs/UDTFs](sql-ref-functions-udf-hive.html) * [Identifiers](sql-ref-identifier.html) * [Literals](sql-ref-literals.html) * [Null Semanitics](sql-ref-null-semantics.html) - * [ANSI Compliance](sql-ref-ansi-compliance.html) * [SQL Syntax](sql-ref-syntax.html) * [DDL Statements](sql-ref-syntax-ddl.html) - * [DML Statements](sql-ref-syntax-ddl.html) + * [DML Statements](sql-ref-syntax-dml.html) * [Data Retrieval Statements](sql-ref-syntax-qry.html) * [Auxiliary Statements](sql-ref-syntax-aux.html) - * [Functions](sql-ref-functions.html) - * [Built-in Functions](sql-ref-functions-builtin.html) - * [Scalar User-Defined Functions (UDFs)](sql-ref-functions-udf-scalar.html) - * [User-Defined Aggregate Functions (UDAFs)](sql-ref-functions-udf-aggregate.html) - * [Integration with Hive UDFs/UDAFs/UDTFs](sql-ref-functions-udf-hive.html) - * [Datetime Pattern](sql-ref-datetime-pattern.html) diff --git a/docs/web-ui.md b/docs/web-ui.md index 3c35dbeec86a2..e2e612cef3e54 100644 --- a/docs/web-ui.md +++ b/docs/web-ui.md @@ -407,6 +407,34 @@ Here is the list of SQL metrics: +## Structured Streaming Tab +When running Structured Streaming jobs in micro-batch mode, a Structured Streaming tab will be +available on the Web UI. The overview page displays some brief statistics for running and completed +queries. Also, you can check the latest exception of a failed query. For detailed statistics, please +click a "run id" in the tables. + +
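As a rough illustration of what feeds this tab (not taken from the PR, and assuming Spark 3.0 or later where the tab exists), the following Scala sketch starts a micro-batch query on the built-in `rate` source and writes it to the `console` sink; while it runs, its run id, input/process rates, and batch durations appear on the Structured Streaming tab. The rows-per-second value and options are arbitrary.

```scala
import org.apache.spark.sql.SparkSession

object StructuredStreamingTabSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("StructuredStreamingTabSketch")
      .master("local[*]") // local master so the Web UI at http://localhost:4040 is easy to open
      .getOrCreate()

    // The built-in "rate" source emits (timestamp, value) rows at a fixed pace,
    // and the default trigger runs the query in micro-batch mode, which is what
    // the Structured Streaming tab reports on.
    val query = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()
      .writeStream
      .format("console")
      .option("truncate", "false")
      .start()

    // Keep the query alive so its statistics keep updating in the Web UI.
    query.awaitTermination()
  }
}
```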

+  [screenshot: Structured Streaming Query Statistics]

+ +The statistics page displays some useful metrics for insight into the status of your streaming +queries. Currently, it contains the following metrics. + +* **Input Rate.** The aggregate (across all sources) rate of data arriving. +* **Process Rate.** The aggregate (across all sources) rate at which Spark is processing data. +* **Input Rows.** The aggregate (across all sources) number of records processed in a trigger. +* **Batch Duration.** The process duration of each batch. +* **Operation Duration.** The amount of time taken to perform various operations in milliseconds. +The tracked operations are listed as follows. + * addBatch: Adds result data of the current batch to the sink. + * getBatch: Gets a new batch of data to process. + * latestOffset: Gets the latest offsets for sources. + * queryPlanning: Generates the execution plan. + * walCommit: Writes the offsets to the metadata log. + +As an early-release version, the statistics page is still under development and will be improved in +future releases. + ## Streaming Tab The web UI includes a Streaming tab if the application uses Spark streaming. This tab displays scheduling delay and processing time for each micro-batch in the data stream, which can be useful diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaANOVASelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaANOVASelectorExample.java new file mode 100644 index 0000000000000..6f24b4571b5e7 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaANOVASelectorExample.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SparkSession; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.feature.ANOVASelector; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +/** + * An example for ANOVASelector. + * Run with + *
+ * bin/run-example ml.JavaANOVASelectorExample
+ * 
+ */ +public class JavaANOVASelectorExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaANOVASelectorExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1, Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3), 3.0), + RowFactory.create(2, Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1), 2.0), + RowFactory.create(3, Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5), 3.0), + RowFactory.create(4, Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8), 2.0), + RowFactory.create(5, Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0), 4.0), + RowFactory.create(6, Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1), 4.0) + ); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + ANOVASelector selector = new ANOVASelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("label") + .setOutputCol("selectedFeatures"); + + Dataset result = selector.fit(df).transform(df); + + System.out.println("ANOVASelector output with top " + selector.getNumTopFeatures() + + " features selected"); + result.show(); + + // $example off$ + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaANOVATestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaANOVATestExample.java index 3b2de1f39cc88..4785dbd34f5d4 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaANOVATestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaANOVATestExample.java @@ -51,7 +51,7 @@ public static void main(String[] args) { List data = Arrays.asList( RowFactory.create(3.0, Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3)), RowFactory.create(2.0, Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1)), - RowFactory.create(1.0, Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5)), + RowFactory.create(3.0, Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5)), RowFactory.create(2.0, Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8)), RowFactory.create(4.0, Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0)), RowFactory.create(4.0, Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1)) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaFValueSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaFValueSelectorExample.java new file mode 100644 index 0000000000000..e8253ff0836cf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaFValueSelectorExample.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SparkSession; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.feature.FValueSelector; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +/** + * An example demonstrating FValueSelector. + * Run with + *
+ * bin/run-example ml.JavaFValueSelectorExample
+ * 
+ */ +public class JavaFValueSelectorExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaFValueSelectorExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0), 4.6), + RowFactory.create(2, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0), 6.6), + RowFactory.create(3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0), 5.1), + RowFactory.create(4, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0), 7.6), + RowFactory.create(5, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0), 9.0), + RowFactory.create(6, Vectors.dense(8.0, 9.0, 6.0, 4.0, 0.0, 0.0), 9.0) + ); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + FValueSelector selector = new FValueSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("label") + .setOutputCol("selectedFeatures"); + + Dataset result = selector.fit(df).transform(df); + + System.out.println("FValueSelector output with top " + selector.getNumTopFeatures() + + " features selected"); + result.show(); + + // $example off$ + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaFValueTestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaFValueTestExample.java index 11861ac8a5110..cda28dbc7e966 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaFValueTestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaFValueTestExample.java @@ -66,7 +66,7 @@ public static void main(String[] args) { Row r = FValueTest.test(df, "features", "label").head(); System.out.println("pValues: " + r.get(0).toString()); System.out.println("degreesOfFreedom: " + r.getList(1).toString()); - System.out.println("fvalue: " + r.get(2).toString()); + System.out.println("fvalues: " + r.get(2).toString()); // $example off$ diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVarianceThresholdSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVarianceThresholdSelectorExample.java new file mode 100644 index 0000000000000..5820a95114eb5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVarianceThresholdSelectorExample.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SparkSession; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.feature.VarianceThresholdSelector; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +/** + * An example for VarianceThresholdSelector. + * Run with + *
+ * bin/run-example ml.JavaVarianceThresholdSelectorExample
+ * 
+ */ +public class JavaVarianceThresholdSelectorExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaVarianceThresholdSelectorExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)), + RowFactory.create(2, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)), + RowFactory.create(3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)), + RowFactory.create(4, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)), + RowFactory.create(5, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)), + RowFactory.create(6, Vectors.dense(8.0, 9.0, 6.0, 0.0, 0.0, 0.0)) + ); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + VarianceThresholdSelector selector = new VarianceThresholdSelector() + .setVarianceThreshold(8.0) + .setFeaturesCol("features") + .setOutputCol("selectedFeatures"); + + Dataset result = selector.fit(df).transform(df); + + System.out.println("Output: Features with variance lower than " + + selector.getVarianceThreshold() + " are removed."); + result.show(); + + // $example off$ + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/anova_selector_example.py b/examples/src/main/python/ml/anova_selector_example.py new file mode 100644 index 0000000000000..f8458f5d6e487 --- /dev/null +++ b/examples/src/main/python/ml/anova_selector_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example for ANOVASelector. 
+Run with: + bin/spark-submit examples/src/main/python/ml/anova_selector_example.py +""" +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on$ +from pyspark.ml.feature import ANOVASelector +from pyspark.ml.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("ANOVASelectorExample")\ + .getOrCreate() + + # $example on$ + df = spark.createDataFrame([ + (1, Vectors.dense([1.7, 4.4, 7.6, 5.8, 9.6, 2.3]), 3.0,), + (2, Vectors.dense([8.8, 7.3, 5.7, 7.3, 2.2, 4.1]), 2.0,), + (3, Vectors.dense([1.2, 9.5, 2.5, 3.1, 8.7, 2.5]), 3.0,), + (4, Vectors.dense([3.7, 9.2, 6.1, 4.1, 7.5, 3.8]), 2.0,), + (5, Vectors.dense([8.9, 5.2, 7.8, 8.3, 5.2, 3.0]), 4.0,), + (6, Vectors.dense([7.9, 8.5, 9.2, 4.0, 9.4, 2.1]), 4.0,)], ["id", "features", "label"]) + + selector = ANOVASelector(numTopFeatures=1, featuresCol="features", + outputCol="selectedFeatures", labelCol="label") + + result = selector.fit(df).transform(df) + + print("ANOVASelector output with top %d features selected" % selector.getNumTopFeatures()) + result.show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/anova_test_example.py b/examples/src/main/python/ml/anova_test_example.py index 3fffdbddf3aca..4119441cdeab6 100644 --- a/examples/src/main/python/ml/anova_test_example.py +++ b/examples/src/main/python/ml/anova_test_example.py @@ -37,7 +37,7 @@ # $example on$ data = [(3.0, Vectors.dense([1.7, 4.4, 7.6, 5.8, 9.6, 2.3])), (2.0, Vectors.dense([8.8, 7.3, 5.7, 7.3, 2.2, 4.1])), - (1.0, Vectors.dense([1.2, 9.5, 2.5, 3.1, 8.7, 2.5])), + (3.0, Vectors.dense([1.2, 9.5, 2.5, 3.1, 8.7, 2.5])), (2.0, Vectors.dense([3.7, 9.2, 6.1, 4.1, 7.5, 3.8])), (4.0, Vectors.dense([8.9, 5.2, 7.8, 8.3, 5.2, 3.0])), (4.0, Vectors.dense([7.9, 8.5, 9.2, 4.0, 9.4, 2.1]))] diff --git a/examples/src/main/python/ml/fvalue_selector_example.py b/examples/src/main/python/ml/fvalue_selector_example.py new file mode 100644 index 0000000000000..3158953a5dfc4 --- /dev/null +++ b/examples/src/main/python/ml/fvalue_selector_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example for FValueSelector. 
+Run with: + bin/spark-submit examples/src/main/python/ml/fvalue_selector_example.py +""" +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on$ +from pyspark.ml.feature import FValueSelector +from pyspark.ml.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("FValueSelectorExample")\ + .getOrCreate() + + # $example on$ + df = spark.createDataFrame([ + (1, Vectors.dense([6.0, 7.0, 0.0, 7.0, 6.0, 0.0]), 4.6,), + (2, Vectors.dense([0.0, 9.0, 6.0, 0.0, 5.0, 9.0]), 6.6,), + (3, Vectors.dense([0.0, 9.0, 3.0, 0.0, 5.0, 5.0]), 5.1,), + (4, Vectors.dense([0.0, 9.0, 8.0, 5.0, 6.0, 4.0]), 7.6,), + (5, Vectors.dense([8.0, 9.0, 6.0, 5.0, 4.0, 4.0]), 9.0,), + (6, Vectors.dense([8.0, 9.0, 6.0, 4.0, 0.0, 0.0]), 9.0,)], ["id", "features", "label"]) + + selector = FValueSelector(numTopFeatures=1, featuresCol="features", + outputCol="selectedFeatures", labelCol="label") + + result = selector.fit(df).transform(df) + + print("FValueSelector output with top %d features selected" % selector.getNumTopFeatures()) + result.show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/fvalue_test_example.py b/examples/src/main/python/ml/fvalue_test_example.py index 4a97bcdc87f11..410b39e4493f8 100644 --- a/examples/src/main/python/ml/fvalue_test_example.py +++ b/examples/src/main/python/ml/fvalue_test_example.py @@ -46,7 +46,7 @@ ftest = FValueTest.test(df, "features", "label").head() print("pValues: " + str(ftest.pValues)) print("degreesOfFreedom: " + str(ftest.degreesOfFreedom)) - print("fvalue: " + str(ftest.fValues)) + print("fvalues: " + str(ftest.fValues)) # $example off$ spark.stop() diff --git a/examples/src/main/python/ml/variance_threshold_selector_example.py b/examples/src/main/python/ml/variance_threshold_selector_example.py new file mode 100644 index 0000000000000..b7edb86653530 --- /dev/null +++ b/examples/src/main/python/ml/variance_threshold_selector_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example for VarianceThresholdSelector. 
+Run with: + bin/spark-submit examples/src/main/python/ml/variance_threshold_selector_example.py +""" +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on$ +from pyspark.ml.feature import VarianceThresholdSelector +from pyspark.ml.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("VarianceThresholdSelectorExample")\ + .getOrCreate() + + # $example on$ + df = spark.createDataFrame([ + (1, Vectors.dense([6.0, 7.0, 0.0, 7.0, 6.0, 0.0])), + (2, Vectors.dense([0.0, 9.0, 6.0, 0.0, 5.0, 9.0])), + (3, Vectors.dense([0.0, 9.0, 3.0, 0.0, 5.0, 5.0])), + (4, Vectors.dense([0.0, 9.0, 8.0, 5.0, 6.0, 4.0])), + (5, Vectors.dense([8.0, 9.0, 6.0, 5.0, 4.0, 4.0])), + (6, Vectors.dense([8.0, 9.0, 6.0, 0.0, 0.0, 0.0]))], ["id", "features"]) + + selector = VarianceThresholdSelector(varianceThreshold=8.0, outputCol="selectedFeatures") + + result = selector.fit(df).transform(df) + + print("Output: Features with variance lower than %f are removed." % + selector.getVarianceThreshold()) + result.show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ANOVASelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ANOVASelectorExample.scala new file mode 100644 index 0000000000000..46803cc78e767 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ANOVASelectorExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ANOVASelector +import org.apache.spark.ml.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example for ANOVASelector. 
+ * Run with + * {{{ + * bin/run-example ml.ANOVASelectorExample + * }}} + */ +object ANOVASelectorExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("ANOVASelectorExample") + .getOrCreate() + import spark.implicits._ + + // $example on$ + val data = Seq( + (1, Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3), 3.0), + (2, Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1), 2.0), + (3, Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5), 3.0), + (4, Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8), 2.0), + (5, Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0), 4.0), + (6, Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1), 4.0) + ) + + val df = spark.createDataset(data).toDF("id", "features", "label") + + val selector = new ANOVASelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("label") + .setOutputCol("selectedFeatures") + + val result = selector.fit(df).transform(df) + + println(s"ANOVASelector output with top ${selector.getNumTopFeatures} features selected") + result.show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ANOVATestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ANOVATestExample.scala index 0cd793f5b7b88..f0b9f23514d93 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ANOVATestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ANOVATestExample.scala @@ -44,7 +44,7 @@ object ANOVATestExample { val data = Seq( (3.0, Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3)), (2.0, Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1)), - (1.0, Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5)), + (3.0, Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5)), (2.0, Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8)), (4.0, Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0)), (4.0, Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1)) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/FValueSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/FValueSelectorExample.scala new file mode 100644 index 0000000000000..914d81b79c997 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/FValueSelectorExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.FValueSelector +import org.apache.spark.ml.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example for FValueSelector. 
+ * Run with + * {{{ + * bin/run-example ml.FValueSelectorExample + * }}} + */ +object FValueSelectorExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("FValueSelectorExample") + .getOrCreate() + import spark.implicits._ + + // $example on$ + val data = Seq( + (1, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0), 4.6), + (2, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0), 6.6), + (3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0), 5.1), + (4, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0), 7.6), + (5, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0), 9.0), + (6, Vectors.dense(8.0, 9.0, 6.0, 4.0, 0.0, 0.0), 9.0) + ) + + val df = spark.createDataset(data).toDF("id", "features", "label") + + val selector = new FValueSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("label") + .setOutputCol("selectedFeatures") + + val result = selector.fit(df).transform(df) + + println(s"FValueSelector output with top ${selector.getNumTopFeatures} features selected") + result.show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/FVlaueTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/FValueTestExample.scala similarity index 100% rename from examples/src/main/scala/org/apache/spark/examples/ml/FVlaueTestExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/FValueTestExample.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VarianceThresholdSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VarianceThresholdSelectorExample.scala new file mode 100644 index 0000000000000..e4185268aa86f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VarianceThresholdSelectorExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VarianceThresholdSelector +import org.apache.spark.ml.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example for VarianceThresholdSelector. 
+ * Run with + * {{{ + * bin/run-example ml.VarianceThresholdSelectorExample + * }}} + */ +object VarianceThresholdSelectorExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("VarianceThresholdSelectorExample") + .getOrCreate() + import spark.implicits._ + + // $example on$ + val data = Seq( + (1, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)), + (2, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)), + (3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)), + (4, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)), + (5, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)), + (6, Vectors.dense(8.0, 9.0, 6.0, 0.0, 0.0, 0.0)) + ) + + val df = spark.createDataset(data).toDF("id", "features") + + val selector = new VarianceThresholdSelector() + .setVarianceThreshold(8.0) + .setFeaturesCol("features") + .setOutputCol("selectedFeatures") + + val result = selector.fit(df).transform(df) + + println(s"Output: Features with variance lower than" + + s" ${selector.getVarianceThreshold} are removed.") + result.show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index f32fe46bb6e1f..1d18594fd349c 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -34,22 +34,33 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY -import org.apache.spark.sql.catalyst.util.RebaseDateTime._ +import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** * A deserializer to deserialize data in avro format to data in catalyst format. */ -class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType, rebaseDateTime: Boolean) { +class AvroDeserializer( + rootAvroType: Schema, + rootCatalystType: DataType, + datetimeRebaseMode: LegacyBehaviorPolicy.Value) { def this(rootAvroType: Schema, rootCatalystType: DataType) { this(rootAvroType, rootCatalystType, - SQLConf.get.getConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_READ)) + LegacyBehaviorPolicy.withName( + SQLConf.get.getConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ))) } private lazy val decimalConversions = new DecimalConversion() + private val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( + datetimeRebaseMode, "Avro") + + private val timestampRebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( + datetimeRebaseMode, "Avro") + private val converter: Any => Any = rootCatalystType match { // A shortcut for empty schema. 
case st: StructType if st.isEmpty => @@ -96,13 +107,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType, rebaseD case (INT, IntegerType) => (updater, ordinal, value) => updater.setInt(ordinal, value.asInstanceOf[Int]) - case (INT, DateType) if rebaseDateTime => (updater, ordinal, value) => - val days = value.asInstanceOf[Int] - val rebasedDays = rebaseJulianToGregorianDays(days) - updater.setInt(ordinal, rebasedDays) - case (INT, DateType) => (updater, ordinal, value) => - updater.setInt(ordinal, value.asInstanceOf[Int]) + updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) case (LONG, LongType) => (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long]) @@ -110,22 +116,13 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType, rebaseD case (LONG, TimestampType) => avroType.getLogicalType match { // For backward compatibility, if the Avro type is Long and it is not logical type // (the `null` case), the value is processed as timestamp type with millisecond precision. - case null | _: TimestampMillis if rebaseDateTime => (updater, ordinal, value) => - val millis = value.asInstanceOf[Long] - val micros = DateTimeUtils.millisToMicros(millis) - val rebasedMicros = rebaseJulianToGregorianMicros(micros) - updater.setLong(ordinal, rebasedMicros) case null | _: TimestampMillis => (updater, ordinal, value) => val millis = value.asInstanceOf[Long] val micros = DateTimeUtils.millisToMicros(millis) - updater.setLong(ordinal, micros) - case _: TimestampMicros if rebaseDateTime => (updater, ordinal, value) => - val micros = value.asInstanceOf[Long] - val rebasedMicros = rebaseJulianToGregorianMicros(micros) - updater.setLong(ordinal, rebasedMicros) + updater.setLong(ordinal, timestampRebaseFunc(micros)) case _: TimestampMicros => (updater, ordinal, value) => val micros = value.asInstanceOf[Long] - updater.setLong(ordinal, micros) + updater.setLong(ordinal, timestampRebaseFunc(micros)) case other => throw new IncompatibleSchemaException( s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.") } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index e69c95b797c73..59d54bc433f8b 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -124,12 +124,12 @@ private[sql] class AvroFileFormat extends FileFormat reader.sync(file.start) val stop = file.start + file.length - val rebaseDateTime = DataSourceUtils.needRebaseDateTime( - reader.asInstanceOf[DataFileReader[_]].getMetaString).getOrElse { - SQLConf.get.getConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_READ) - } + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( + reader.asInstanceOf[DataFileReader[_]].getMetaString, + SQLConf.get.getConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ)) + val deserializer = new AvroDeserializer( - userProvidedSchema.getOrElse(reader.getSchema), requiredSchema, rebaseDateTime) + userProvidedSchema.getOrElse(reader.getSchema), requiredSchema, datetimeRebaseMode) new Iterator[InternalRow] { private[this] var completed = false diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala index 82a568049990e..ac9608c867937 100644 --- 
a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types._ // NOTE: This class is instantiated and used on executor side only, no need to be serializable. @@ -43,12 +44,12 @@ private[avro] class AvroOutputWriter( avroSchema: Schema) extends OutputWriter { // Whether to rebase datetimes from Gregorian to Julian calendar in write - private val rebaseDateTime: Boolean = - SQLConf.get.getConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_WRITE) + private val datetimeRebaseMode = LegacyBehaviorPolicy.withName( + SQLConf.get.getConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE)) // The input rows will never be null. private lazy val serializer = - new AvroSerializer(schema, avroSchema, nullable = false, rebaseDateTime) + new AvroSerializer(schema, avroSchema, nullable = false, datetimeRebaseMode) /** * Overrides the couple of methods responsible for generating the output streams / files so @@ -56,7 +57,11 @@ private[avro] class AvroOutputWriter( */ private val recordWriter: RecordWriter[AvroKey[GenericRecord], NullWritable] = { val fileMeta = Map(SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT) ++ { - if (rebaseDateTime) Some(SPARK_LEGACY_DATETIME -> "") else None + if (datetimeRebaseMode == LegacyBehaviorPolicy.LEGACY) { + Some(SPARK_LEGACY_DATETIME -> "") + } else { + None + } } new SparkAvroKeyOutputFormat(fileMeta.asJava) { diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index c87249e29fbd6..21c5dec6239bd 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -35,8 +35,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.catalyst.util.RebaseDateTime._ +import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types._ /** @@ -46,17 +47,24 @@ class AvroSerializer( rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean, - rebaseDateTime: Boolean) extends Logging { + datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging { def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) { this(rootCatalystType, rootAvroType, nullable, - SQLConf.get.getConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_WRITE)) + LegacyBehaviorPolicy.withName(SQLConf.get.getConf( + SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE))) } def serialize(catalystData: Any): Any = { converter.apply(catalystData) } + private val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInWrite( + datetimeRebaseMode, "Avro") + + private val timestampRebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInWrite( + datetimeRebaseMode, "Avro") + private val converter: Any => 
Any = { val actualAvroType = resolveNullableType(rootAvroType, nullable) val baseConverter = rootCatalystType match { @@ -146,24 +154,16 @@ class AvroSerializer( case (BinaryType, BYTES) => (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) - case (DateType, INT) if rebaseDateTime => - (getter, ordinal) => rebaseGregorianToJulianDays(getter.getInt(ordinal)) - case (DateType, INT) => - (getter, ordinal) => getter.getInt(ordinal) + (getter, ordinal) => dateRebaseFunc(getter.getInt(ordinal)) case (TimestampType, LONG) => avroType.getLogicalType match { // For backward compatibility, if the Avro type is Long and it is not logical type // (the `null` case), output the timestamp value as with millisecond precision. - case null | _: TimestampMillis if rebaseDateTime => (getter, ordinal) => - val micros = getter.getLong(ordinal) - val rebasedMicros = rebaseGregorianToJulianMicros(micros) - DateTimeUtils.microsToMillis(rebasedMicros) case null | _: TimestampMillis => (getter, ordinal) => - DateTimeUtils.microsToMillis(getter.getLong(ordinal)) - case _: TimestampMicros if rebaseDateTime => (getter, ordinal) => - rebaseGregorianToJulianMicros(getter.getLong(ordinal)) - case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal) + DateTimeUtils.microsToMillis(timestampRebaseFunc(getter.getLong(ordinal))) + case _: TimestampMicros => (getter, ordinal) => + timestampRebaseFunc(getter.getLong(ordinal)) case other => throw new IncompatibleSchemaException( s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}") } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index 712aec6acbd56..15918f46a83bb 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -88,12 +88,11 @@ case class AvroPartitionReaderFactory( reader.sync(partitionedFile.start) val stop = partitionedFile.start + partitionedFile.length - val rebaseDateTime = DataSourceUtils.needRebaseDateTime( - reader.asInstanceOf[DataFileReader[_]].getMetaString).getOrElse { - SQLConf.get.getConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_READ) - } + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( + reader.asInstanceOf[DataFileReader[_]].getMetaString, + SQLConf.get.getConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ)) val deserializer = new AvroDeserializer( - userProvidedSchema.getOrElse(reader.getSchema), readDataSchema, rebaseDateTime) + userProvidedSchema.getOrElse(reader.getSchema), readDataSchema, datetimeRebaseMode) val fileReader = new PartitionReader[InternalRow] { private[this] var completed = false diff --git a/external/avro/src/test/resources/before_1582_date_v2_4.avro b/external/avro/src/test/resources/before_1582_date_v2_4.avro deleted file mode 100644 index 96aa7cbf176a5..0000000000000 Binary files a/external/avro/src/test/resources/before_1582_date_v2_4.avro and /dev/null differ diff --git a/external/avro/src/test/resources/before_1582_date_v2_4_5.avro b/external/avro/src/test/resources/before_1582_date_v2_4_5.avro new file mode 100644 index 0000000000000..5c15601f7ee4b Binary files /dev/null and b/external/avro/src/test/resources/before_1582_date_v2_4_5.avro differ diff --git a/external/avro/src/test/resources/before_1582_date_v2_4_6.avro b/external/avro/src/test/resources/before_1582_date_v2_4_6.avro new 
file mode 100644 index 0000000000000..212ea1d5efa5c Binary files /dev/null and b/external/avro/src/test/resources/before_1582_date_v2_4_6.avro differ diff --git a/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro new file mode 100644 index 0000000000000..c3445e3999bc1 Binary files /dev/null and b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro differ diff --git a/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro new file mode 100644 index 0000000000000..96008d2378b1f Binary files /dev/null and b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro differ diff --git a/external/avro/src/test/resources/before_1582_ts_millis_v2_4.avro b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro similarity index 52% rename from external/avro/src/test/resources/before_1582_ts_millis_v2_4.avro rename to external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro index dbaec814eb954..be12a0782073c 100644 Binary files a/external/avro/src/test/resources/before_1582_ts_millis_v2_4.avro and b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro differ diff --git a/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro new file mode 100644 index 0000000000000..262f5dd6e77a4 Binary files /dev/null and b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro differ diff --git a/external/avro/src/test/resources/before_1582_ts_micros_v2_4.avro b/external/avro/src/test/resources/before_1582_ts_micros_v2_4.avro deleted file mode 100644 index efe5e71a58813..0000000000000 Binary files a/external/avro/src/test/resources/before_1582_ts_micros_v2_4.avro and /dev/null differ diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 64d790bc4acd4..c8a1f670bda9e 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -288,7 +288,7 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite """.stripMargin val avroSchema = new Schema.Parser().parse(jsonFormatSchema) val dataType = SchemaConverters.toSqlType(avroSchema).dataType - val deserializer = new AvroDeserializer(avroSchema, dataType, rebaseDateTime = false) + val deserializer = new AvroDeserializer(avroSchema, dataType) def checkDeserialization(data: GenericData.Record, expected: Any): Unit = { assert(checkResult( diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 3e754f02911dc..e2ae489446d85 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.avro import java.io._ import java.net.URL -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Paths, StandardCopyOption} import java.sql.{Date, Timestamp} import java.util.{Locale, UUID} @@ -33,16 +33,17 @@ import 
org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils -import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException} +import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException} import org.apache.spark.sql._ import org.apache.spark.sql.TestingUDT.IntervalData import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.Filter -import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC} +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, LA, UTC} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.{DataSource, FilePartition} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.v2.avro.AvroScan @@ -1528,51 +1529,160 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession { } } + // It generates input files for the test below: + // "SPARK-31183: compatibility with Spark 2.4 in reading dates/timestamps" + ignore("SPARK-31855: generate test files for checking compatibility with Spark 2.4") { + val resourceDir = "external/avro/src/test/resources" + val version = "2_4_6" + def save( + in: Seq[String], + t: String, + dstFile: String, + options: Map[String, String] = Map.empty): Unit = { + withTempDir { dir => + in.toDF("dt") + .select($"dt".cast(t)) + .repartition(1) + .write + .mode("overwrite") + .options(options) + .format("avro") + .save(dir.getCanonicalPath) + Files.copy( + dir.listFiles().filter(_.getName.endsWith(".avro")).head.toPath, + Paths.get(resourceDir, dstFile), + StandardCopyOption.REPLACE_EXISTING) + } + } + withDefaultTimeZone(LA) { + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> LA.getId) { + save( + Seq("1001-01-01"), + "date", + s"before_1582_date_v$version.avro") + save( + Seq("1001-01-01 01:02:03.123"), + "timestamp", + s"before_1582_timestamp_millis_v$version.avro", + // scalastyle:off line.size.limit + Map("avroSchema" -> + s""" + | { + | "namespace": "logical", + | "type": "record", + | "name": "test", + | "fields": [ + | {"name": "dt", "type": ["null", {"type": "long","logicalType": "timestamp-millis"}], "default": null} + | ] + | } + |""".stripMargin)) + // scalastyle:on line.size.limit + save( + Seq("1001-01-01 01:02:03.123456"), + "timestamp", + s"before_1582_timestamp_micros_v$version.avro") + } + } + } + test("SPARK-31183: compatibility with Spark 2.4 in reading dates/timestamps") { // test reading the existing 2.4 files and new 3.0 files (with rebase on/off) together. 
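For context on the modes this compatibility test exercises, a minimal write-side sketch, assuming a running SparkSession named `spark` with the Avro data source on the classpath and hypothetical output paths (this is an illustration, not code from the patch):

  import org.apache.spark.sql.internal.SQLConf
  import spark.implicits._

  val ancient = Seq("1001-01-01").toDF("s").select($"s".cast("date").as("dt"))

  // Default mode is EXCEPTION: writing dates/timestamps before 1582-10-15 fails,
  // surfacing a SparkUpgradeException (asserted by the test below).
  // ancient.write.format("avro").save("/tmp/ancient")

  // CORRECTED: write the proleptic Gregorian values as-is, without rebasing.
  spark.conf.set(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key, "CORRECTED")
  ancient.write.format("avro").mode("overwrite").save("/tmp/ancient_corrected")

  // LEGACY: rebase to the hybrid Julian calendar, reproducing Spark 2.4 output.
  spark.conf.set(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key, "LEGACY")
  ancient.write.format("avro").mode("overwrite").save("/tmp/ancient_legacy")
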
- def checkReadMixedFiles(fileName: String, dt: String, dataStr: String): Unit = { + def checkReadMixedFiles( + fileName: String, + dt: String, + dataStr: String, + checkDefaultLegacyRead: String => Unit): Unit = { withTempPaths(2) { paths => paths.foreach(_.delete()) val path2_4 = getResourceAvroFilePath(fileName) val path3_0 = paths(0).getCanonicalPath val path3_0_rebase = paths(1).getCanonicalPath if (dt == "date") { - val df = Seq(dataStr).toDF("str").select($"str".cast("date").as("date")) - df.write.format("avro").save(path3_0) - withSQLConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_WRITE.key -> "true") { + val df = Seq(dataStr).toDF("str").select($"str".cast("date").as("dt")) + + // By default we should fail to write ancient datetime values. + val e = intercept[SparkException](df.write.format("avro").save(path3_0)) + assert(e.getCause.getCause.getCause.isInstanceOf[SparkUpgradeException]) + checkDefaultLegacyRead(path2_4) + + withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) { + df.write.format("avro").mode("overwrite").save(path3_0) + } + withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> LEGACY.toString) { df.write.format("avro").save(path3_0_rebase) } - checkAnswer( - spark.read.format("avro").load(path2_4, path3_0, path3_0_rebase), - 1.to(3).map(_ => Row(java.sql.Date.valueOf(dataStr)))) + + // For Avro files written by Spark 3.0, we know the writer info and don't need the config + // to guide the rebase behavior. + withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key -> LEGACY.toString) { + checkAnswer( + spark.read.format("avro").load(path2_4, path3_0, path3_0_rebase), + 1.to(3).map(_ => Row(java.sql.Date.valueOf(dataStr)))) + } } else { - val df = Seq(dataStr).toDF("str").select($"str".cast("timestamp").as("ts")) + val df = Seq(dataStr).toDF("str").select($"str".cast("timestamp").as("dt")) val avroSchema = s""" |{ | "type" : "record", | "name" : "test_schema", | "fields" : [ - | {"name": "ts", "type": {"type": "long", "logicalType": "$dt"}} + | {"name": "dt", "type": {"type": "long", "logicalType": "$dt"}} | ] |}""".stripMargin - df.write.format("avro").option("avroSchema", avroSchema).save(path3_0) - withSQLConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_WRITE.key -> "true") { + + // By default we should fail to write ancient datetime values. + val e = intercept[SparkException] { + df.write.format("avro").option("avroSchema", avroSchema).save(path3_0) + } + assert(e.getCause.getCause.getCause.isInstanceOf[SparkUpgradeException]) + checkDefaultLegacyRead(path2_4) + + withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) { + df.write.format("avro").option("avroSchema", avroSchema).mode("overwrite").save(path3_0) + } + withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> LEGACY.toString) { df.write.format("avro").option("avroSchema", avroSchema).save(path3_0_rebase) } - checkAnswer( - spark.read.format("avro").load(path2_4, path3_0, path3_0_rebase), - 1.to(3).map(_ => Row(java.sql.Timestamp.valueOf(dataStr)))) + + // For Avro files written by Spark 3.0, we know the writer info and don't need the config + // to guide the rebase behavior. 
+ withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key -> LEGACY.toString) { + checkAnswer( + spark.read.format("avro").load(path2_4, path3_0, path3_0_rebase), + 1.to(3).map(_ => Row(java.sql.Timestamp.valueOf(dataStr)))) + } } } } - withSQLConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_READ.key -> "true") { - checkReadMixedFiles("before_1582_date_v2_4.avro", "date", "1001-01-01") + def failInRead(path: String): Unit = { + val e = intercept[SparkException](spark.read.format("avro").load(path).collect()) + assert(e.getCause.isInstanceOf[SparkUpgradeException]) + } + def successInRead(path: String): Unit = spark.read.format("avro").load(path).collect() + Seq( + // By default we should fail to read ancient datetime values when Avro files don't + // contain the Spark version. + "2_4_5" -> failInRead _, + "2_4_6" -> successInRead _ + ).foreach { case (version, checkDefaultRead) => + checkReadMixedFiles( + s"before_1582_date_v$version.avro", + "date", + "1001-01-01", + checkDefaultRead) checkReadMixedFiles( - "before_1582_ts_micros_v2_4.avro", "timestamp-micros", "1001-01-01 01:02:03.123456") + s"before_1582_timestamp_micros_v$version.avro", + "timestamp-micros", + "1001-01-01 01:02:03.123456", + checkDefaultRead) checkReadMixedFiles( - "before_1582_ts_millis_v2_4.avro", "timestamp-millis", "1001-01-01 01:02:03.124") + s"before_1582_timestamp_millis_v$version.avro", + "timestamp-millis", + "1001-01-01 01:02:03.123", + checkDefaultRead) } } @@ -1581,7 +1691,7 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession { val nonRebased = "1001-01-07 01:09:05.123456" withTempPath { dir => val path = dir.getAbsolutePath - withSQLConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_WRITE.key -> "true") { + withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> LEGACY.toString) { Seq(tsStr).toDF("tsS") .select($"tsS".cast("timestamp").as("ts")) .write.format("avro") @@ -1589,9 +1699,9 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession { } // The file metadata indicates if it needs rebase or not, so we can always get the correct - // result regardless of the "rebaseInRead" config. - Seq(true, false).foreach { rebase => - withSQLConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_READ.key -> rebase.toString) { + // result regardless of the "rebase mode" config.
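On the read side, the mode only matters for files whose metadata does not identify the writer: per the version checks above, files from Spark 2.4.5 (no version metadata) fail to read under the default mode, while 2.4.6 files (which carry the version) read without any configuration, and the loop that follows confirms that files written by Spark 3.0 itself read correctly under every mode. A minimal sketch for the 2.4.5 case, assuming a `spark` session and a hypothetical file path (not code from the patch):

  import org.apache.spark.sql.internal.SQLConf

  // Under the default (EXCEPTION) mode this read fails, surfacing a SparkUpgradeException:
  // spark.read.format("avro").load("/data/before_1582_v2_4_5.avro").collect()

  // Opt in to interpreting the old values as hybrid Julian dates/timestamps ...
  spark.conf.set(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key, "LEGACY")
  spark.read.format("avro").load("/data/before_1582_v2_4_5.avro").show()
  // ... or take them verbatim as proleptic Gregorian values:
  // spark.conf.set(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key, "CORRECTED")
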
+ Seq(LEGACY, CORRECTED, EXCEPTION).foreach { mode => + withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key -> mode.toString) { checkAnswer( spark.read.schema("ts timestamp").format("avro").load(path), Row(Timestamp.valueOf(rebased))) @@ -1655,7 +1765,7 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession { test("SPARK-31183: rebasing dates in write") { withTempPath { dir => val path = dir.getAbsolutePath - withSQLConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_WRITE.key -> "true") { + withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> LEGACY.toString) { Seq("1001-01-01").toDF("dateS") .select($"dateS".cast("date").as("date")) .write.format("avro") @@ -1663,9 +1773,9 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession { } // The file metadata indicates if it needs rebase or not, so we can always get the correct - // result regardless of the "rebaseInRead" config. - Seq(true, false).foreach { rebase => - withSQLConf(SQLConf.LEGACY_AVRO_REBASE_DATETIME_IN_READ.key -> rebase.toString) { + // result regardless of the "rebase mode" config. + Seq(LEGACY, CORRECTED, EXCEPTION).foreach { mode => + withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key -> mode.toString) { checkAnswer(spark.read.format("avro").load(path), Row(Date.valueOf("1001-01-01"))) } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index a5e5d01152db8..ede58bd26ce34 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -33,10 +33,10 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.read.{Batch, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.connector.write.{BatchWrite, LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.connector.write.{BatchWrite, LogicalWriteInfo, SupportsTruncate, WriteBuilder} import org.apache.spark.sql.connector.write.streaming.StreamingWrite import org.apache.spark.sql.execution.streaming.{Sink, Source} -import org.apache.spark.sql.internal.connector.SimpleTableProvider +import org.apache.spark.sql.internal.connector.{SimpleTableProvider, SupportsStreamingUpdate} import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -394,7 +394,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister () => new KafkaScan(options) override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder { + new WriteBuilder with SupportsTruncate with SupportsStreamingUpdate { private val options = info.options private val inputSchema: StructType = info.schema() private val topic = Option(options.get(TOPIC_OPTION_KEY)).map(_.trim) @@ -410,6 +410,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister assert(inputSchema != null) new KafkaStreamingWrite(topic, producerParams, inputSchema) } + + override def truncate(): WriteBuilder = this + override def update(): WriteBuilder = this } } } diff --git 
a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index a4601b91af0d6..bdad214a91343 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -349,7 +349,8 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } - test("subscribing topic by pattern with topic deletions") { + // TODO (SPARK-31731): re-enable it + ignore("subscribing topic by pattern with topic deletions") { val topicPrefix = newTopic() val topic = topicPrefix + "-seems" val topic2 = topicPrefix + "-bad" diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 32d056140a0d7..e5f3a229622e1 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -179,7 +179,8 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession ("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8))))).toDF) } - test("timestamp provided for starting and ending") { + // TODO (SPARK-31729): re-enable it + ignore("timestamp provided for starting and ending") { val (topic, timestamps) = prepareTimestampRelatedUnitTest // timestamp both presented: starting "first" ending "finalized" diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 4f846199cfbc7..275a8170182fe 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -27,6 +27,7 @@ import javax.security.auth.login.Configuration import scala.collection.JavaConverters._ import scala.io.Source import scala.util.Random +import scala.util.control.NonFatal import com.google.common.io.Files import kafka.api.Request @@ -36,7 +37,7 @@ import kafka.zk.KafkaZkClient import org.apache.hadoop.minikdc.MiniKdc import org.apache.hadoop.security.UserGroupInformation import org.apache.kafka.clients.CommonClientConfigs -import org.apache.kafka.clients.admin.{AdminClient, CreatePartitionsOptions, ListConsumerGroupsResult, NewPartitions, NewTopic} +import org.apache.kafka.clients.admin._ import org.apache.kafka.clients.consumer.KafkaConsumer import org.apache.kafka.clients.producer._ import org.apache.kafka.common.TopicPartition @@ -134,8 +135,30 @@ class KafkaTestUtils( val kdcDir = Utils.createTempDir() val kdcConf = MiniKdc.createConf() kdcConf.setProperty(MiniKdc.DEBUG, "true") - kdc = new MiniKdc(kdcConf, kdcDir) - kdc.start() + // The port for the MiniKdc service is selected in the constructor, but is only bound + // later, in MiniKdc.start() -> MiniKdc.initKDCServer() -> KdcServer.start(). + // In the meantime, another service may grab the port, which + // causes a BindException.
+ // This makes tests that have dedicated JVMs and rely on MiniKdc flaky. + // + // https://issues.apache.org/jira/browse/HADOOP-12656 was fixed in Hadoop 2.8.0. + // + // Since we use Hadoop 2.7.4 by default, the workaround here is to retry the whole + // process within a timeout. + // https://issues.apache.org/jira/browse/SPARK-31631 + eventually(timeout(60.seconds), interval(1.second)) { + try { + kdc = new MiniKdc(kdcConf, kdcDir) + kdc.start() + } catch { + case NonFatal(e) => + if (kdc != null) { + kdc.stop() + kdc = null + } + throw e + } + } // TODO https://issues.apache.org/jira/browse/SPARK-30037 // Need to build spark's own MiniKDC and customize krb5.conf like Kafka rewriteKrb5Conf() diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 925327d9d58e6..72cf3e8118228 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -332,7 +332,8 @@ class DirectKafkaStreamSuite } // Test to verify the offset ranges can be recovered from the checkpoints - test("offset recovery") { + // TODO (SPARK-31722): re-enable it + ignore("offset recovery") { val topic = "recovery" kafkaTestUtils.createTopic(topic) testDir = Utils.createTempDir() @@ -418,8 +419,9 @@ class DirectKafkaStreamSuite ssc.stop() } - // Test to verify the offsets can be recovered from Kafka - test("offset recovery from kafka") { + // Test to verify the offsets can be recovered from Kafka + // TODO (SPARK-31722): re-enable it + ignore("offset recovery from kafka") { val topic = "recoveryfromkafka" kafkaTestUtils.createTopic(topic) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 3d3e7a22e594b..368f177cda828 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -271,7 +271,7 @@ private[spark] object BLAS extends Serializable { } /** - * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
* * @param U the upper triangular part of the matrix packed in an array (column major) */ diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala index 42746b5727029..a08b8af0fcbfd 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala @@ -55,7 +55,7 @@ class MultivariateGaussian @Since("2.0.0") ( */ @transient private lazy val tuple = { val (rootSigmaInv, u) = calculateCovarianceConstants - val rootSigmaInvMat = Matrices.fromBreeze(rootSigmaInv) + val rootSigmaInvMat = Matrices.fromBreeze(rootSigmaInv).toDense val rootSigmaInvMulMu = rootSigmaInvMat.multiply(mean) (rootSigmaInvMat, u, rootSigmaInvMulMu) } @@ -81,6 +81,36 @@ class MultivariateGaussian @Since("2.0.0") ( u - 0.5 * BLAS.dot(v, v) } + private[ml] def pdf(X: Matrix): DenseVector = { + val mat = DenseMatrix.zeros(X.numRows, X.numCols) + pdf(X, mat) + } + + private[ml] def pdf(X: Matrix, mat: DenseMatrix): DenseVector = { + require(!mat.isTransposed) + + BLAS.gemm(1.0, X, rootSigmaInvMat.transpose, 0.0, mat) + val m = mat.numRows + val n = mat.numCols + + val pdfVec = mat.multiply(rootSigmaInvMulMu) + + val blas = BLAS.getBLAS(n) + val squared1 = blas.ddot(n, rootSigmaInvMulMu.values, 1, rootSigmaInvMulMu.values, 1) + + val localU = u + var i = 0 + while (i < m) { + val squared2 = blas.ddot(n, mat.values, i, m, mat.values, i, m) + val dot = pdfVec(i) + val squaredSum = squared1 + squared2 - dot - dot + pdfVec.values(i) = math.exp(localU - 0.5 * squaredSum) + i += 1 + } + + pdfVec + } + /** * Calculate distribution dependent components used for the density function: * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu)) diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala index f2ecff1cc58bd..8652d317a85c4 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala @@ -27,6 +27,7 @@ class MultivariateGaussianSuite extends SparkMLFunSuite { test("univariate") { val x1 = Vectors.dense(0.0) val x2 = Vectors.dense(1.5) + val mat = Matrices.fromVectors(Seq(x1, x2)) val mu = Vectors.dense(0.0) val sigma1 = Matrices.dense(1, 1, Array(1.0)) @@ -35,6 +36,7 @@ class MultivariateGaussianSuite extends SparkMLFunSuite { assert(dist1.logpdf(x2) ~== -2.0439385332046727 absTol 1E-5) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) + assert(dist1.pdf(mat) ~== Vectors.dense(0.39894, 0.12952) absTol 1E-5) val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) @@ -42,11 +44,13 @@ class MultivariateGaussianSuite extends SparkMLFunSuite { assert(dist2.logpdf(x2) ~== -1.893335713764618 absTol 1E-5) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) + assert(dist2.pdf(mat) ~== Vectors.dense(0.19947, 0.15057) absTol 1E-5) } test("multivariate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) + val mat = Matrices.fromVectors(Seq(x1, x2)) val mu = Vectors.dense(0.0, 0.0) val sigma1 = 
Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) @@ -55,6 +59,7 @@ class MultivariateGaussianSuite extends SparkMLFunSuite { assert(dist1.logpdf(x2) ~== -2.8378770664093453 absTol 1E-5) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) + assert(dist1.pdf(mat) ~== Vectors.dense(0.15915, 0.05855) absTol 1E-5) val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) @@ -62,21 +67,25 @@ class MultivariateGaussianSuite extends SparkMLFunSuite { assert(dist2.logpdf(x2) ~== -3.3822607123655732 absTol 1E-5) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) + assert(dist2.pdf(mat) ~== Vectors.dense(0.060155, 0.033971) absTol 1E-5) } test("multivariate degenerate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) + val mat = Matrices.fromVectors(Seq(x1, x2)) val mu = Vectors.dense(0.0, 0.0) val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5) assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5) + assert(dist.pdf(mat) ~== Vectors.dense(0.11254, 0.068259) absTol 1E-5) } test("SPARK-11302") { val x = Vectors.dense(629, 640, 1.7188, 618.19) + val mat = Matrices.fromVectors(Seq(x)) val mu = Vectors.dense( 1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697) val sigma = Matrices.dense(4, 4, Array( @@ -87,5 +96,6 @@ class MultivariateGaussianSuite extends SparkMLFunSuite { val dist = new MultivariateGaussian(mu, sigma) // Agrees with R's dmvnorm: 7.154782e-05 assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9) + assert(dist.pdf(mat) ~== Vectors.dense(7.154782224045512E-5) absTol 1E-5) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index febeba7e13fcb..e0b128e369816 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.Since import org.apache.spark.ml.feature.{Instance, LabeledPoint} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -71,7 +72,7 @@ private[ml] trait PredictorParams extends Params val w = this match { case p: HasWeightCol => if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) { - col($(p.weightCol)).cast(DoubleType) + checkNonNegativeWeight((col($(p.weightCol)).cast(DoubleType))) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 69c35a8a80f52..217398c51b393 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -187,17 +187,15 @@ class LinearSVC @Since("2.2.0") ( val instances = extractInstances(dataset) .setName("training instances") - val (summarizer, labelSummarizer) = if ($(blockSize) == 1) { - if (dataset.storageLevel == StorageLevel.NONE) { - instances.persist(StorageLevel.MEMORY_AND_DISK) - } - Summarizer.getClassificationSummarizers(instances, $(aggregationDepth)) - } else { - // instances will be standardized and converted to blocks, so no need to cache instances. 
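The `checkNonNegativeWeight` guard now wrapping the weight column above (and in the estimators and evaluators further down) is not defined in this hunk; the sketch below is only a guess at its likely shape, a UDF-style helper assumed to live in `org.apache.spark.ml.functions` that rejects negative weights:

  import org.apache.spark.sql.functions.udf

  // Hypothetical stand-in: pass non-negative weights through, fail fast on negative ones.
  val checkNonNegativeWeight = udf { w: Double =>
    require(w >= 0.0, s"illegal weight value: $w. weight must be >= 0.0.")
    w
  }

  // Usage mirroring the pattern in the diff:
  //   val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
  //     checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
  //   } else lit(1.0)
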
- Summarizer.getClassificationSummarizers(instances, $(aggregationDepth), - Seq("mean", "std", "count", "numNonZeros")) + if (dataset.storageLevel == StorageLevel.NONE && $(blockSize) == 1) { + instances.persist(StorageLevel.MEMORY_AND_DISK) } + var requestedMetrics = Seq("mean", "std", "count") + if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros" + val (summarizer, labelSummarizer) = Summarizer + .getClassificationSummarizers(instances, $(aggregationDepth), requestedMetrics) + val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid val numFeatures = summarizer.mean.size @@ -316,7 +314,7 @@ class LinearSVC @Since("2.2.0") ( } val blocks = InstanceBlock.blokify(standardized, $(blockSize)) .persist(StorageLevel.MEMORY_AND_DISK) - .setName(s"training dataset (blockSize=${$(blockSize)})") + .setName(s"training blocks (blockSize=${$(blockSize)})") val getAggregatorFunc = new BlockHingeAggregator($(fitIntercept))(_) val costFun = new RDDLossFunction(blocks, getAggregatorFunc, diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 10cf96180090a..0d1350640c74a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,6 +29,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.feature._ +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg._ import org.apache.spark.ml.optim.aggregator._ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} @@ -41,7 +42,7 @@ import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, Multiclas import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils @@ -517,17 +518,18 @@ class LogisticRegression @Since("1.2.0") ( probabilityCol, regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept, blockSize) - val instances = extractInstances(dataset).setName("training instances") - if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + val instances = extractInstances(dataset) + .setName("training instances") - val (summarizer, labelSummarizer) = if ($(blockSize) == 1) { - Summarizer.getClassificationSummarizers(instances, $(aggregationDepth)) - } else { - // instances will be standardized and converted to blocks, so no need to cache instances. 
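A small note on the `requestedMetrics` construction above (repeated for LogisticRegression just below): `+:=` on a `var` holding a Seq prepends the element, so "numNonZeros" ends up at the head of the list when blockSize != 1. A REPL-style sketch:

  var requestedMetrics = Seq("mean", "std", "count")
  requestedMetrics +:= "numNonZeros"
  // requestedMetrics is now Seq("numNonZeros", "mean", "std", "count")
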
- Summarizer.getClassificationSummarizers(instances, $(aggregationDepth), - Seq("mean", "std", "count", "numNonZeros")) + if (handlePersistence && $(blockSize) == 1) { + instances.persist(StorageLevel.MEMORY_AND_DISK) } + var requestedMetrics = Seq("mean", "std", "count") + if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros" + val (summarizer, labelSummarizer) = Summarizer + .getClassificationSummarizers(instances, $(aggregationDepth), requestedMetrics) + val numFeatures = summarizer.mean.size val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -591,7 +593,7 @@ class LogisticRegression @Since("1.2.0") ( } else { Vectors.dense(if (numClasses == 2) Double.PositiveInfinity else Double.NegativeInfinity) } - if (handlePersistence) instances.unpersist() + if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist() return createModel(dataset, numClasses, coefMatrix, interceptVec, Array.empty) } @@ -650,7 +652,7 @@ class LogisticRegression @Since("1.2.0") ( trainOnBlocks(instances, featuresStd, numClasses, initialCoefWithInterceptMatrix, regularization, optimizer) } - if (handlePersistence) instances.unpersist() + if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist() if (allCoefficients == null) { val msg = s"${optimizer.getClass.getName} failed." @@ -728,6 +730,7 @@ class LogisticRegression @Since("1.2.0") ( objectiveHistory: Array[Double]): LogisticRegressionModel = { val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector, numClasses, checkMultinomial(numClasses))) + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() val logRegSummary = if (numClasses <= 2) { @@ -737,6 +740,7 @@ class LogisticRegression @Since("1.2.0") ( predictionColName, $(labelCol), $(featuresCol), + weightColName, objectiveHistory) } else { new LogisticRegressionTrainingSummaryImpl( @@ -745,6 +749,7 @@ class LogisticRegression @Since("1.2.0") ( predictionColName, $(labelCol), $(featuresCol), + weightColName, objectiveHistory) } model.setSummary(Some(logRegSummary)) @@ -1002,7 +1007,7 @@ class LogisticRegression @Since("1.2.0") ( } val blocks = InstanceBlock.blokify(standardized, $(blockSize)) .persist(StorageLevel.MEMORY_AND_DISK) - .setName(s"training dataset (blockSize=${$(blockSize)})") + .setName(s"training blocks (blockSize=${$(blockSize)})") val getAggregatorFunc = new BlockLogisticAggregator(numFeatures, numClasses, $(fitIntercept), checkMultinomial(numClasses))(_) @@ -1184,14 +1189,15 @@ class LogisticRegressionModel private[spark] ( */ @Since("2.0.0") def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = { + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) // Handle possible missing or invalid prediction columns val (summaryModel, probabilityColName, predictionColName) = findSummaryModel() if (numClasses > 2) { new LogisticRegressionSummaryImpl(summaryModel.transform(dataset), - probabilityColName, predictionColName, $(labelCol), $(featuresCol)) + probabilityColName, predictionColName, $(labelCol), $(featuresCol), weightColName) } else { new BinaryLogisticRegressionSummaryImpl(summaryModel.transform(dataset), - probabilityColName, predictionColName, $(labelCol), $(featuresCol)) + probabilityColName, predictionColName, $(labelCol), $(featuresCol), weightColName) } } @@ -1389,8 +1395,6 @@ object LogisticRegressionModel extends 
MLReadable[LogisticRegressionModel] { /** * Abstraction for logistic regression results for a given model. - * - * Currently, the summary ignores the instance weights. */ sealed trait LogisticRegressionSummary extends Serializable { @@ -1416,12 +1420,28 @@ sealed trait LogisticRegressionSummary extends Serializable { @Since("1.6.0") def featuresCol: String + /** Field in "predictions" which gives the weight of each instance as a vector. */ + @Since("3.1.0") + def weightCol: String + @transient private val multiclassMetrics = { - new MulticlassMetrics( - predictions.select( - col(predictionCol), - col(labelCol).cast(DoubleType)) - .rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) }) + if (predictions.schema.fieldNames.contains(weightCol)) { + new MulticlassMetrics( + predictions.select( + col(predictionCol), + col(labelCol).cast(DoubleType), + checkNonNegativeWeight(col(weightCol).cast(DoubleType))).rdd.map { + case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) + }) + } else { + new MulticlassMetrics( + predictions.select( + col(predictionCol), + col(labelCol).cast(DoubleType), + lit(1.0)).rdd.map { + case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) + }) + } } /** @@ -1537,8 +1557,6 @@ sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary /** * Abstraction for binary logistic regression results for a given model. - * - * Currently, the summary ignores the instance weights. */ sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary { @@ -1547,29 +1565,33 @@ sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary { // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. - @transient private val binaryMetrics = new BinaryClassificationMetrics( - predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map { - case Row(score: Vector, label: Double) => (score(1), label) - }, 100 - ) + @transient private val binaryMetrics = if (predictions.schema.fieldNames.contains(weightCol)) { + new BinaryClassificationMetrics( + predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType), + checkNonNegativeWeight(col(weightCol).cast(DoubleType))).rdd.map { + case Row(score: Vector, label: Double, weight: Double) => (score(1), label, weight) + }, 100 + ) + } else { + new BinaryClassificationMetrics( + predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType), + lit(1.0)).rdd.map { + case Row(score: Vector, label: Double, weight: Double) => (score(1), label, weight) + }, 100 + ) + } /** * Returns the receiver operating characteristic (ROC) curve, * which is a Dataframe having two fields (FPR, TPR) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic - * - * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") /** * Computes the area under the receiver operating characteristic (ROC) curve. - * - * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. 
*/ @Since("1.5.0") lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() @@ -1577,18 +1599,12 @@ sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary { /** * Returns the precision-recall curve, which is a Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. - * - * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. - * - * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val fMeasureByThreshold: DataFrame = { @@ -1599,9 +1615,6 @@ sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary { * Returns a dataframe with two fields (threshold, precision) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. - * - * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val precisionByThreshold: DataFrame = { @@ -1612,9 +1625,6 @@ sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary { * Returns a dataframe with two fields (threshold, recall) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. - * - * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val recallByThreshold: DataFrame = { @@ -1624,8 +1634,6 @@ sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary { /** * Abstraction for binary logistic regression training results. - * Currently, the training summary ignores the training weights except - * for the objective trace. */ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegressionSummary with LogisticRegressionTrainingSummary @@ -1640,6 +1648,7 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre * double. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + * @param weightCol field in "predictions" which gives the weight of each instance as a vector. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ private class LogisticRegressionTrainingSummaryImpl( @@ -1648,9 +1657,10 @@ private class LogisticRegressionTrainingSummaryImpl( predictionCol: String, labelCol: String, featuresCol: String, + weightCol: String, override val objectiveHistory: Array[Double]) extends LogisticRegressionSummaryImpl( - predictions, probabilityCol, predictionCol, labelCol, featuresCol) + predictions, probabilityCol, predictionCol, labelCol, featuresCol, weightCol) with LogisticRegressionTrainingSummary /** @@ -1663,13 +1673,15 @@ private class LogisticRegressionTrainingSummaryImpl( * double. * @param labelCol field in "predictions" which gives the true label of each instance. 
* @param featuresCol field in "predictions" which gives the features of each instance as a vector. + * @param weightCol field in "predictions" which gives the weight of each instance as a vector. */ private class LogisticRegressionSummaryImpl( @transient override val predictions: DataFrame, override val probabilityCol: String, override val predictionCol: String, override val labelCol: String, - override val featuresCol: String) + override val featuresCol: String, + override val weightCol: String) extends LogisticRegressionSummary /** @@ -1682,6 +1694,7 @@ private class LogisticRegressionSummaryImpl( * double. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + * @param weightCol field in "predictions" which gives the weight of each instance as a vector. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ private class BinaryLogisticRegressionTrainingSummaryImpl( @@ -1690,9 +1703,10 @@ private class BinaryLogisticRegressionTrainingSummaryImpl( predictionCol: String, labelCol: String, featuresCol: String, + weightCol: String, override val objectiveHistory: Array[Double]) extends BinaryLogisticRegressionSummaryImpl( - predictions, probabilityCol, predictionCol, labelCol, featuresCol) + predictions, probabilityCol, predictionCol, labelCol, featuresCol, weightCol) with BinaryLogisticRegressionTrainingSummary /** @@ -1705,13 +1719,15 @@ private class BinaryLogisticRegressionTrainingSummaryImpl( * each class as a double. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + * @param weightCol field in "predictions" which gives the weight of each instance as a vector. 
*/ private class BinaryLogisticRegressionSummaryImpl( predictions: DataFrame, probabilityCol: String, predictionCol: String, labelCol: String, - featuresCol: String) + featuresCol: String, + weightCol: String) extends LogisticRegressionSummaryImpl( - predictions, probabilityCol, predictionCol, labelCol, featuresCol) + predictions, probabilityCol, predictionCol, labelCol, featuresCol, weightCol) with BinaryLogisticRegressionSummary diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 5459a0fab9135..e65295dbdaf55 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -22,6 +22,7 @@ import org.json4s.DefaultFormats import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasWeightCol @@ -179,7 +180,7 @@ class NaiveBayes @Since("1.5.0") ( } val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } @@ -259,7 +260,7 @@ class NaiveBayes @Since("1.5.0") ( import spark.implicits._ val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 6c7112b80569f..b09f11dcfe156 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -280,7 +281,7 @@ class BisectingKMeans @Since("2.0.0") ( val handlePersistence = dataset.storageLevel == StorageLevel.NONE val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 1c4560aa5fdd7..18fd220b4ca9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON} import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param._ @@ -43,7 +44,7 @@ import org.apache.spark.storage.StorageLevel */ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter with 
HasFeaturesCol with HasSeed with HasPredictionCol with HasWeightCol with HasProbabilityCol with HasTol - with HasAggregationDepth { + with HasAggregationDepth with HasBlockSize { /** * Number of independent Gaussians in the mixture model. Must be greater than 1. Default: 2. @@ -279,8 +280,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { * @param weights Weights for each Gaussian * @return Probability (partial assignment) for each of the k clusters */ - private[clustering] - def computeProbabilities( + private[clustering] def computeProbabilities( features: Vector, dists: Array[MultivariateGaussian], weights: Array[Double]): Array[Double] = { @@ -375,6 +375,25 @@ class GaussianMixture @Since("2.0.0") ( @Since("3.0.0") def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) + /** + * Set block size for stacking input data in matrices. + * If blockSize == 1, then stacking will be skipped, and each vector is treated individually; + * If blockSize > 1, then vectors will be stacked to blocks, and high-level BLAS routines + * will be used if possible (for example, GEMV instead of DOT, GEMM instead of GEMV). + * Recommended size is between 10 and 1000. An appropriate choice of the block size depends + * on the sparsity and dim of input datasets, the underlying BLAS implementation (for example, + * f2jBLAS, OpenBLAS, intel MKL) and its configuration (for example, number of threads). + * Note that existing BLAS implementations are mainly optimized for dense matrices, if the + * input dataset is sparse, stacking may bring no performance gain, the worse is possible + * performance regression. + * Default is 1. + * + * @group expertSetParam + */ + @Since("3.1.0") + def setBlockSize(value: Int): this.type = set(blockSize, value) + setDefault(blockSize -> 1) + /** * Number of samples per cluster to use when initializing Gaussians. */ @@ -392,34 +411,63 @@ class GaussianMixture @Since("2.0.0") ( s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" + s" matrix is quadratic in the number of features.") - val handlePersistence = dataset.storageLevel == StorageLevel.NONE + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, predictionCol, probabilityCol, weightCol, k, maxIter, + seed, tol, aggregationDepth, blockSize) + instr.logNumFeatures(numFeatures) val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } val instances = dataset.select(DatasetUtils.columnToVector(dataset, $(featuresCol)), w) - .as[(Vector, Double)] - .rdd + .as[(Vector, Double)].rdd + .setName("training instances") - if (handlePersistence) { + if ($(blockSize) == 1 && dataset.storageLevel == StorageLevel.NONE) { instances.persist(StorageLevel.MEMORY_AND_DISK) } - val sc = spark.sparkContext - val numClusters = $(k) + // TODO: SPARK-15785 Support users supplied initial GMM. 
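From a caller's point of view, the new `blockSize` expert parameter documented above is just another estimator param. A minimal usage sketch, with the training DataFrame and its "features" vector column assumed:

  import org.apache.spark.ml.clustering.GaussianMixture

  // blockSize = 1 (the default) keeps the row-at-a-time path; a larger value stacks
  // input vectors into matrices so high-level BLAS routines (GEMM instead of GEMV)
  // can be used during each iteration.
  val gmm = new GaussianMixture()
    .setK(3)
    .setFeaturesCol("features")
    .setBlockSize(64)
  val model = gmm.fit(trainingDf)  // trainingDf is assumed to exist with a "features" column
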
+ val (weights, gaussians) = initRandom(instances, $(k), numFeatures) - instr.logPipelineStage(this) - instr.logDataset(dataset) - instr.logParams(this, featuresCol, predictionCol, probabilityCol, weightCol, k, maxIter, - seed, tol, aggregationDepth) - instr.logNumFeatures(numFeatures) + val (logLikelihood, iteration) = if ($(blockSize) == 1) { + trainOnRows(instances, weights, gaussians, numFeatures, instr) + } else { + val sparsity = 1 - instances.map { case (v, _) => v.numNonzeros.toDouble / v.size }.mean() + instr.logNamedValue("sparsity", sparsity.toString) + if (sparsity > 0.5) { + logWarning(s"sparsity of input dataset is $sparsity, " + + s"which may hurt performance in high-level BLAS.") + } + trainOnBlocks(instances, weights, gaussians, numFeatures, instr) + } + if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist() - // TODO: SPARK-15785 Support users supplied initial GMM. - val (weights, gaussians) = initRandom(instances, numClusters, numFeatures) + val gaussianDists = gaussians.map { case (mean, covVec) => + val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values) + new MultivariateGaussian(mean, cov) + } + + val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)) + .setParent(this) + val summary = new GaussianMixtureSummary(model.transform(dataset), + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration) + instr.logNamedValue("logLikelihood", logLikelihood) + instr.logNamedValue("clusterSizes", summary.clusterSizes) + model.setSummary(Some(summary)) + } + private def trainOnRows( + instances: RDD[(Vector, Double)], + weights: Array[Double], + gaussians: Array[(DenseVector, DenseVector)], + numFeatures: Int, + instr: Instrumentation): (Double, Int) = { + val sc = instances.sparkContext var logLikelihood = Double.MinValue var logLikelihoodPrev = 0.0 @@ -440,7 +488,7 @@ class GaussianMixture @Since("2.0.0") ( val ws = agg.weights.sum if (iteration == 0) weightSumAccum.add(ws) logLikelihoodAccum.add(agg.logLikelihood) - Iterator.tabulate(numClusters) { i => + Iterator.tabulate(bcWeights.value.length) { i => (i, (agg.means(i), agg.covs(i), agg.weights(i), ws)) } } else Iterator.empty @@ -471,21 +519,77 @@ class GaussianMixture @Since("2.0.0") ( instr.logNamedValue(s"logLikelihood@iter$iteration", logLikelihood) iteration += 1 } - if (handlePersistence) { - instances.unpersist() - } - val gaussianDists = gaussians.map { case (mean, covVec) => - val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values) - new MultivariateGaussian(mean, cov) + (logLikelihood, iteration) + } + + private def trainOnBlocks( + instances: RDD[(Vector, Double)], + weights: Array[Double], + gaussians: Array[(DenseVector, DenseVector)], + numFeatures: Int, + instr: Instrumentation): (Double, Int) = { + val blocks = instances.mapPartitions { iter => + iter.grouped($(blockSize)) + .map { seq => (Matrices.fromVectors(seq.map(_._1)), seq.map(_._2).toArray) } + }.persist(StorageLevel.MEMORY_AND_DISK) + .setName(s"training dataset (blockSize=${$(blockSize)})") + + val sc = instances.sparkContext + var logLikelihood = Double.MinValue + var logLikelihoodPrev = 0.0 + + var iteration = 0 + while (iteration < $(maxIter) && math.abs(logLikelihood - logLikelihoodPrev) > $(tol)) { + val weightSumAccum = if (iteration == 0) sc.doubleAccumulator else null + val logLikelihoodAccum = sc.doubleAccumulator + val bcWeights = sc.broadcast(weights) + val bcGaussians = sc.broadcast(gaussians) + + // aggregate the 
cluster contribution for all sample points, + // and then compute the new distributions + blocks.mapPartitions { iter => + if (iter.nonEmpty) { + val agg = new BlockExpectationAggregator(numFeatures, + $(blockSize), bcWeights, bcGaussians) + while (iter.hasNext) { agg.add(iter.next) } + // sum of weights in this partition + val ws = agg.weights.sum + if (iteration == 0) weightSumAccum.add(ws) + logLikelihoodAccum.add(agg.logLikelihood) + agg.meanIter.zip(agg.covIter).zipWithIndex + .map { case ((mean, cov), i) => (i, (mean, cov, agg.weights(i), ws)) } + } else Iterator.empty + }.reduceByKey { case ((mean1, cov1, w1, ws1), (mean2, cov2, w2, ws2)) => + // update the weights, means and covariances for i-th distributions + BLAS.axpy(1.0, mean2, mean1) + BLAS.axpy(1.0, cov2, cov1) + (mean1, cov1, w1 + w2, ws1 + ws2) + }.mapValues { case (mean, cov, w, ws) => + // Create new distributions based on the partial assignments + // (often referred to as the "M" step in literature) + GaussianMixture.updateWeightsAndGaussians(mean, cov, w, ws) + }.collect().foreach { case (i, (weight, gaussian)) => + weights(i) = weight + gaussians(i) = gaussian + } + + bcWeights.destroy() + bcGaussians.destroy() + + if (iteration == 0) { + instr.logNumExamples(weightSumAccum.count) + instr.logSumOfWeights(weightSumAccum.value) + } + + logLikelihoodPrev = logLikelihood // current becomes previous + logLikelihood = logLikelihoodAccum.value // this is the freshly computed log-likelihood + instr.logNamedValue(s"logLikelihood@iter$iteration", logLikelihood) + iteration += 1 } + blocks.unpersist() - val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this) - val summary = new GaussianMixtureSummary(model.transform(dataset), - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration) - instr.logNamedValue("logLikelihood", logLikelihood) - instr.logNamedValue("clusterSizes", summary.clusterSizes) - model.setSummary(Some(summary)) + (logLikelihood, iteration) } @Since("2.0.0") @@ -626,16 +730,15 @@ private class ExpectationAggregator( bcWeights: Broadcast[Array[Double]], bcGaussians: Broadcast[Array[(DenseVector, DenseVector)]]) extends Serializable { - private val k: Int = bcWeights.value.length - private var totalCnt: Long = 0L - private var newLogLikelihood: Double = 0.0 - private lazy val newWeights: Array[Double] = Array.ofDim[Double](k) - private lazy val newMeans: Array[DenseVector] = Array.fill(k)( - new DenseVector(Array.ofDim[Double](numFeatures))) - private lazy val newCovs: Array[DenseVector] = Array.fill(k)( - new DenseVector(Array.ofDim[Double](numFeatures * (numFeatures + 1) / 2))) + private val k = bcWeights.value.length + private var totalCnt = 0L + private var newLogLikelihood = 0.0 + private val covSize = numFeatures * (numFeatures + 1) / 2 + private lazy val newWeights = Array.ofDim[Double](k) + @transient private lazy val newMeans = Array.fill(k)(Vectors.zeros(numFeatures).toDense) + @transient private lazy val newCovs = Array.fill(k)(Vectors.zeros(covSize).toDense) - @transient private lazy val oldGaussians = { + @transient private lazy val gaussians = { bcGaussians.value.map { case (mean, covVec) => val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values) new MultivariateGaussian(mean, cov) @@ -656,19 +759,19 @@ private class ExpectationAggregator( * Add a new training instance to this ExpectationAggregator, update the weights, * means and covariances for each distributions, and update the log likelihood. 
* - * @param weightedVector The instance of data point to be added. + * @param instance The instance of data point to be added. * @return This ExpectationAggregator object. */ - def add(weightedVector: (Vector, Double)): this.type = { - val (instance: Vector, weight: Double) = weightedVector + def add(instance: (Vector, Double)): this.type = { + val (vector: Vector, weight: Double) = instance val localWeights = bcWeights.value - val localOldGaussians = oldGaussians + val localGaussians = gaussians val prob = new Array[Double](k) var probSum = 0.0 var i = 0 while (i < k) { - val p = EPSILON + localWeights(i) * localOldGaussians(i).pdf(instance) + val p = EPSILON + localWeights(i) * localGaussians(i).pdf(vector) prob(i) = p probSum += p i += 1 @@ -682,42 +785,128 @@ private class ExpectationAggregator( while (i < k) { val w = prob(i) / probSum * weight localNewWeights(i) += w - BLAS.axpy(w, instance, localNewMeans(i)) - BLAS.spr(w, instance, localNewCovs(i)) + BLAS.axpy(w, vector, localNewMeans(i)) + BLAS.spr(w, vector, localNewCovs(i)) i += 1 } totalCnt += 1 this } +} + + +/** + * BlockExpectationAggregator computes the partial expectation results. + * + * @param numFeatures The number of features. + * @param bcWeights The broadcast weights for each Gaussian distribution in the mixture. + * @param bcGaussians The broadcast array of Multivariate Gaussian (Normal) Distribution + * in the mixture. Note only upper triangular part of the covariance + * matrix of each distribution is stored as dense vector (column major) + * in order to reduce shuffled data size. + */ +private class BlockExpectationAggregator( + numFeatures: Int, + blockSize: Int, + bcWeights: Broadcast[Array[Double]], + bcGaussians: Broadcast[Array[(DenseVector, DenseVector)]]) extends Serializable { + + private val k = bcWeights.value.length + private var totalCnt = 0L + private var newLogLikelihood = 0.0 + private val covSize = numFeatures * (numFeatures + 1) / 2 + private lazy val newWeights = Array.ofDim[Double](k) + @transient private lazy val newMeansMat = DenseMatrix.zeros(numFeatures, k) + @transient private lazy val newCovsMat = DenseMatrix.zeros(covSize, k) + @transient private lazy val auxiliaryProbMat = DenseMatrix.zeros(blockSize, k) + @transient private lazy val auxiliaryPDFMat = DenseMatrix.zeros(blockSize, numFeatures) + @transient private lazy val auxiliaryCovVec = Vectors.zeros(covSize).toDense + + @transient private lazy val gaussians = { + bcGaussians.value.map { case (mean, covVec) => + val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values) + new MultivariateGaussian(mean, cov) + } + } + + def count: Long = totalCnt + + def logLikelihood: Double = newLogLikelihood + + def weights: Array[Double] = newWeights + + def meanIter: Iterator[DenseVector] = newMeansMat.colIter.map(_.toDense) + + def covIter: Iterator[DenseVector] = newCovsMat.colIter.map(_.toDense) /** - * Merge another ExpectationAggregator, update the weights, means and covariances - * for each distributions, and update the log likelihood. - * (Note that it's in place merging; as a result, `this` object will be modified.) + * Add a new training instance block to this BlockExpectationAggregator, update the weights, + * means and covariances for each distributions, and update the log likelihood. * - * @param other The other ExpectationAggregator to be merged. - * @return This ExpectationAggregator object. + * @param block The instance block of data point to be added. + * @return This BlockExpectationAggregator object. 
*/ - def merge(other: ExpectationAggregator): this.type = { - if (other.count != 0) { - totalCnt += other.totalCnt - - val localThisNewWeights = this.newWeights - val localOtherNewWeights = other.newWeights - val localThisNewMeans = this.newMeans - val localOtherNewMeans = other.newMeans - val localThisNewCovs = this.newCovs - val localOtherNewCovs = other.newCovs - var i = 0 - while (i < k) { - localThisNewWeights(i) += localOtherNewWeights(i) - BLAS.axpy(1.0, localOtherNewMeans(i), localThisNewMeans(i)) - BLAS.axpy(1.0, localOtherNewCovs(i), localThisNewCovs(i)) - i += 1 - } - newLogLikelihood += other.newLogLikelihood + def add(block: (Matrix, Array[Double])): this.type = { + val (matrix: Matrix, weights: Array[Double]) = block + require(matrix.isTransposed) + val size = matrix.numRows + require(weights.length == size) + + val blas1 = BLAS.getBLAS(size) + val blas2 = BLAS.getBLAS(k) + + val probMat = if (blockSize == size) auxiliaryProbMat else DenseMatrix.zeros(size, k) + require(!probMat.isTransposed) + java.util.Arrays.fill(probMat.values, EPSILON) + + val pdfMat = if (blockSize == size) auxiliaryPDFMat else DenseMatrix.zeros(size, numFeatures) + var j = 0 + while (j < k) { + val pdfVec = gaussians(j).pdf(matrix, pdfMat) + blas1.daxpy(size, bcWeights.value(j), pdfVec.values, 0, 1, probMat.values, j * size, 1) + j += 1 + } + + var i = 0 + while (i < size) { + val weight = weights(i) + val probSum = blas2.dasum(k, probMat.values, i, size) + blas2.dscal(k, weight / probSum, probMat.values, i, size) + blas2.daxpy(k, 1.0, probMat.values, i, size, newWeights, 0, 1) + newLogLikelihood += math.log(probSum) * weight + i += 1 + } + + BLAS.gemm(1.0, matrix.transpose, probMat, 1.0, newMeansMat) + + // compute the cov vector for each row vector + val covVec = auxiliaryCovVec + val covVecIter = matrix match { + case dm: DenseMatrix => + Iterator.tabulate(size) { i => + java.util.Arrays.fill(covVec.values, 0.0) + // when input block is dense, directly use nativeBLAS to avoid array copy + BLAS.nativeBLAS.dspr("U", numFeatures, 1.0, dm.values, i * numFeatures, 1, + covVec.values, 0) + covVec + } + + case sm: SparseMatrix => + sm.rowIter.map { vec => + java.util.Arrays.fill(covVec.values, 0.0) + BLAS.spr(1.0, vec, covVec) + covVec + } } + + covVecIter.zipWithIndex.foreach { case (covVec, i) => + BLAS.nativeBLAS.dger(covSize, k, 1.0, covVec.values, 0, 1, + probMat.values, i, size, newCovsMat.values, 0, covSize) + } + + totalCnt += size + this } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a42c920e24987..806015b633c23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, PipelineStage} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -336,7 +337,7 @@ class KMeans @Since("1.5.0") ( val handlePersistence = dataset.storageLevel == StorageLevel.NONE val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 82b8e14f010af..52be22f714981 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -98,6 +99,24 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("2.0.0") override def evaluate(dataset: Dataset[_]): Double = { + val metrics = getMetrics(dataset) + val metric = $(metricName) match { + case "areaUnderROC" => metrics.areaUnderROC() + case "areaUnderPR" => metrics.areaUnderPR() + } + metrics.unpersist() + metric + } + + /** + * Get a BinaryClassificationMetrics, which can be used to get binary classification + * metrics such as areaUnderROC and areaUnderPR. + * + * @param dataset a dataset that contains labels/observations and predictions. + * @return BinaryClassificationMetrics + */ + @Since("3.1.0") + def getMetrics(dataset: Dataset[_]): BinaryClassificationMetrics = { val schema = dataset.schema SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT)) SchemaUtils.checkNumericType(schema, $(labelCol)) @@ -113,19 +132,13 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType), if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) - else col($(weightCol)).cast(DoubleType)).rdd.map { + else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map { case Row(rawPrediction: Vector, label: Double, weight: Double) => (rawPrediction(1), label, weight) case Row(rawPrediction: Double, label: Double, weight: Double) => (rawPrediction, label, weight) } - val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights, $(numBins)) - val metric = $(metricName) match { - case "areaUnderROC" => metrics.areaUnderROC() - case "areaUnderPR" => metrics.areaUnderPR() - } - metrics.unpersist() - metric + new BinaryClassificationMetrics(scoreAndLabelsWithWeights, $(numBins)) } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 641a1eb5f61db..fa2c25a5912a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -17,15 +17,13 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.SparkContext import org.apache.spark.annotation.Since -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util._ -import org.apache.spark.sql.{Column, DataFrame, Dataset} -import org.apache.spark.sql.functions.{avg, col, udf} +import org.apache.spark.sql.Dataset +import 
org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType /** @@ -38,7 +36,8 @@ import org.apache.spark.sql.types.DoubleType */ @Since("2.3.0") class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasFeaturesCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasFeaturesCol with HasWeightCol + with DefaultParamsWritable { @Since("2.3.0") def this() = this(Identifiable.randomUID("cluEval")) @@ -57,6 +56,10 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) + /** @group setParam */ + @Since("3.1.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + /** * param for metric name in evaluation * (supports `"silhouette"` (default)) @@ -102,557 +105,62 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") override def evaluate(dataset: Dataset[_]): Double = { - SchemaUtils.validateVectorCompatibleColumn(dataset.schema, $(featuresCol)) - SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) + val metrics = getMetrics(dataset) - val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol)) - val df = dataset.select(col($(predictionCol)), - vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata)) - - ($(metricName), $(distanceMeasure)) match { - case ("silhouette", "squaredEuclidean") => - SquaredEuclideanSilhouette.computeSilhouetteScore( - df, $(predictionCol), $(featuresCol)) - case ("silhouette", "cosine") => - CosineSilhouette.computeSilhouetteScore(df, $(predictionCol), $(featuresCol)) - case (mn, dm) => - throw new IllegalArgumentException(s"No support for metric $mn, distance $dm") + $(metricName) match { + case ("silhouette") => metrics.silhouette + case (other) => + throw new IllegalArgumentException(s"No support for metric $other") } } - @Since("3.0.0") - override def toString: String = { - s"ClusteringEvaluator: uid=$uid, metricName=${$(metricName)}, " + - s"distanceMeasure=${$(distanceMeasure)}" - } -} - - -@Since("2.3.0") -object ClusteringEvaluator - extends DefaultParamsReadable[ClusteringEvaluator] { - - @Since("2.3.0") - override def load(path: String): ClusteringEvaluator = super.load(path) - -} - - -private[evaluation] abstract class Silhouette { - - /** - * It computes the Silhouette coefficient for a point. - */ - def pointSilhouetteCoefficient( - clusterIds: Set[Double], - pointClusterId: Double, - pointClusterNumOfPoints: Long, - averageDistanceToCluster: (Double) => Double): Double = { - if (pointClusterNumOfPoints == 1) { - // Single-element clusters have silhouette 0 - 0.0 - } else { - // Here we compute the average dissimilarity of the current point to any cluster of which the - // point is not a member. - // The cluster with the lowest average dissimilarity - i.e. the nearest cluster to the current - // point - is said to be the "neighboring cluster". 
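A hedged usage sketch of the ClusteringEvaluator changes earlier in this file's hunk (the predictions DataFrame and its column names are assumed; setWeightCol and getMetrics are the additions made here):

    import org.apache.spark.ml.evaluation.ClusteringEvaluator

    // predictions: output of a clustering model's transform(), with "features",
    // "prediction" and a user-supplied "weight" column.
    val evaluator = new ClusteringEvaluator()
      .setFeaturesCol("features")
      .setPredictionCol("prediction")
      .setWeightCol("weight")
      .setDistanceMeasure("cosine")   // or the default "squaredEuclidean"
    val silhouette = evaluator.evaluate(predictions)
    // equivalently, via the intermediate metrics object:
    val metrics = evaluator.getMetrics(predictions)
    println(metrics.silhouette)
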
- val otherClusterIds = clusterIds.filter(_ != pointClusterId) - val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min - // adjustment for excluding the node itself from the computation of the average dissimilarity - val currentClusterDissimilarity = - averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints / - (pointClusterNumOfPoints - 1) - if (currentClusterDissimilarity < neighboringClusterDissimilarity) { - 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) - } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) { - (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 - } else { - 0.0 - } - } - } - - /** - * Compute the mean Silhouette values of all samples. - */ - def overallScore(df: DataFrame, scoreColumn: Column): Double = { - df.select(avg(scoreColumn)).collect()(0).getDouble(0) - } -} - -/** - * SquaredEuclideanSilhouette computes the average of the - * Silhouette over all the data of the dataset, which is - * a measure of how appropriately the data have been clustered. - * - * The Silhouette for each point `i` is defined as: - * - *
- * $$ - * s_{i} = \frac{b_{i}-a_{i}}{max\{a_{i},b_{i}\}} - * $$ - *
- * - * which can be rewritten as - * - *
- * $$ - * s_{i}= \begin{cases} - * 1-\frac{a_{i}}{b_{i}} & \text{if } a_{i} \leq b_{i} \\ - * \frac{b_{i}}{a_{i}}-1 & \text{if } a_{i} \gt b_{i} \end{cases} - * $$ - *
- * - * where `$a_{i}$` is the average dissimilarity of `i` with all other data - * within the same cluster, `$b_{i}$` is the lowest average dissimilarity - * of `i` to any other cluster, of which `i` is not a member. - * `$a_{i}$` can be interpreted as how well `i` is assigned to its cluster - * (the smaller the value, the better the assignment), while `$b_{i}$` is - * a measure of how well `i` has not been assigned to its "neighboring cluster", - * ie. the nearest cluster to `i`. - * - * Unfortunately, the naive implementation of the algorithm requires to compute - * the distance of each couple of points in the dataset. Since the computation of - * the distance measure takes `D` operations - if `D` is the number of dimensions - * of each point, the computational complexity of the algorithm is `O(N^2^*D)`, where - * `N` is the cardinality of the dataset. Of course this is not scalable in `N`, - * which is the critical number in a Big Data context. - * - * The algorithm which is implemented in this object, instead, is an efficient - * and parallel implementation of the Silhouette using the squared Euclidean - * distance measure. - * - * With this assumption, the total distance of the point `X` - * to the points `$C_{i}$` belonging to the cluster `$\Gamma$` is: - * - *
- * $$ - * \sum\limits_{i=1}^N d(X, C_{i} ) = - * \sum\limits_{i=1}^N \Big( \sum\limits_{j=1}^D (x_{j}-c_{ij})^2 \Big) - * = \sum\limits_{i=1}^N \Big( \sum\limits_{j=1}^D x_{j}^2 + - * \sum\limits_{j=1}^D c_{ij}^2 -2\sum\limits_{j=1}^D x_{j}c_{ij} \Big) - * = \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}^2 + - * \sum\limits_{i=1}^N \sum\limits_{j=1}^D c_{ij}^2 - * -2 \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}c_{ij} - * $$ - *
- * - * where `$x_{j}$` is the `j`-th dimension of the point `X` and - * `$c_{ij}$` is the `j`-th dimension of the `i`-th point in cluster `$\Gamma$`. - * - * Then, the first term of the equation can be rewritten as: - * - *
- * $$ - * \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}^2 = N \xi_{X} \text{ , - * with } \xi_{X} = \sum\limits_{j=1}^D x_{j}^2 - * $$ - *
- * - * where `$\xi_{X}$` is fixed for each point and it can be precomputed. - * - * Moreover, the second term is fixed for each cluster too, - * thus we can name it `$\Psi_{\Gamma}$` - * - *
- * $$ - * \sum\limits_{i=1}^N \sum\limits_{j=1}^D c_{ij}^2 = - * \sum\limits_{i=1}^N \xi_{C_{i}} = \Psi_{\Gamma} - * $$ - *
- * - * Last, the third element becomes - * - *
- * $$ - * \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}c_{ij} = - * \sum\limits_{j=1}^D \Big(\sum\limits_{i=1}^N c_{ij} \Big) x_{j} - * $$ - *
- * - * thus defining the vector - * - *
- * $$ - * Y_{\Gamma}:Y_{\Gamma j} = \sum\limits_{i=1}^N c_{ij} , j=0, ..., D - * $$ - *
- * - * which is fixed for each cluster `$\Gamma$`, we have - * - *
- * $$ - * \sum\limits_{j=1}^D \Big(\sum\limits_{i=1}^N c_{ij} \Big) x_{j} = - * \sum\limits_{j=1}^D Y_{\Gamma j} x_{j} - * $$ - *
- * - * In this way, the previous equation becomes - * - *
- * $$ - * N\xi_{X} + \Psi_{\Gamma} - 2 \sum\limits_{j=1}^D Y_{\Gamma j} x_{j} - * $$ - *
- * - * and the average distance of a point to a cluster can be computed as - * - *
- * $$ - * \frac{\sum\limits_{i=1}^N d(X, C_{i} )}{N} = - * \frac{N\xi_{X} + \Psi_{\Gamma} - 2 \sum\limits_{j=1}^D Y_{\Gamma j} x_{j}}{N} = - * \xi_{X} + \frac{\Psi_{\Gamma} }{N} - 2 \frac{\sum\limits_{j=1}^D Y_{\Gamma j} x_{j}}{N} - * $$ - *
- * - * Thus, it is enough to precompute: the constant `$\xi_{X}$` for each point `X`; the - * constants `$\Psi_{\Gamma}$`, `N` and the vector `$Y_{\Gamma}$` for - * each cluster `$\Gamma$`. - * - * In the implementation, the precomputed values for the clusters - * are distributed among the worker nodes via broadcasted variables, - * because we can assume that the clusters are limited in number and - * anyway they are much fewer than the points. - * - * The main strengths of this algorithm are the low computational complexity - * and the intrinsic parallelism. The precomputed information for each point - * and for each cluster can be computed with a computational complexity - * which is `O(N/W)`, where `N` is the number of points in the dataset and - * `W` is the number of worker nodes. After that, every point can be - * analyzed independently of the others. - * - * For every point we need to compute the average distance to all the clusters. - * Since the formula above requires `O(D)` operations, this phase has a - * computational complexity which is `O(C*D*N/W)` where `C` is the number of - * clusters (which we assume quite low), `D` is the number of dimensions, - * `N` is the number of points in the dataset and `W` is the number - * of worker nodes. - */ -private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { - - private[this] var kryoRegistrationPerformed: Boolean = false - /** - * This method registers the class - * [[org.apache.spark.ml.evaluation.SquaredEuclideanSilhouette.ClusterStats]] - * for kryo serialization. + * Get a ClusteringMetrics, which can be used to get clustering metrics such as + * silhouette score. * - * @param sc `SparkContext` to be used + * @param dataset a dataset that contains labels/observations and predictions. + * @return ClusteringMetrics */ - def registerKryoClasses(sc: SparkContext): Unit = { - if (!kryoRegistrationPerformed) { - sc.getConf.registerKryoClasses( - Array( - classOf[SquaredEuclideanSilhouette.ClusterStats] - ) - ) - kryoRegistrationPerformed = true + @Since("3.1.0") + def getMetrics(dataset: Dataset[_]): ClusteringMetrics = { + val schema = dataset.schema + SchemaUtils.validateVectorCompatibleColumn(schema, $(featuresCol)) + SchemaUtils.checkNumericType(schema, $(predictionCol)) + if (isDefined(weightCol)) { + SchemaUtils.checkNumericType(schema, $(weightCol)) } - } - case class ClusterStats(featureSum: Vector, squaredNormSum: Double, numOfPoints: Long) - - /** - * The method takes the input dataset and computes the aggregated values - * about a cluster which are needed by the algorithm. - * - * @param df The DataFrame which contains the input data - * @param predictionCol The name of the column which contains the predicted cluster id - * for the point. - * @param featuresCol The name of the column which contains the feature vector of the point. - * @return A [[scala.collection.immutable.Map]] which associates each cluster id - * to a [[ClusterStats]] object (which contains the precomputed values `N`, - * `$\Psi_{\Gamma}$` and `$Y_{\Gamma}$` for a cluster). 
- */ - def computeClusterStats( - df: DataFrame, - predictionCol: String, - featuresCol: String): Map[Double, ClusterStats] = { - val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol) - val clustersStatsRDD = df.select( - col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm")) - .rdd - .map { row => (row.getDouble(0), (row.getAs[Vector](1), row.getDouble(2))) } - .aggregateByKey[(DenseVector, Double, Long)]((Vectors.zeros(numFeatures).toDense, 0.0, 0L))( - seqOp = { - case ( - (featureSum: DenseVector, squaredNormSum: Double, numOfPoints: Long), - (features, squaredNorm) - ) => - BLAS.axpy(1.0, features, featureSum) - (featureSum, squaredNormSum + squaredNorm, numOfPoints + 1) - }, - combOp = { - case ( - (featureSum1, squaredNormSum1, numOfPoints1), - (featureSum2, squaredNormSum2, numOfPoints2) - ) => - BLAS.axpy(1.0, featureSum2, featureSum1) - (featureSum1, squaredNormSum1 + squaredNormSum2, numOfPoints1 + numOfPoints2) - } - ) - - clustersStatsRDD - .collectAsMap() - .mapValues { - case (featureSum: DenseVector, squaredNormSum: Double, numOfPoints: Long) => - SquaredEuclideanSilhouette.ClusterStats(featureSum, squaredNormSum, numOfPoints) - } - .toMap - } + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) - /** - * It computes the Silhouette coefficient for a point. - * - * @param broadcastedClustersMap A map of the precomputed values for each cluster. - * @param point The [[org.apache.spark.ml.linalg.Vector]] representing the current point. - * @param clusterId The id of the cluster the current point belongs to. - * @param squaredNorm The `$\Xi_{X}$` (which is the squared norm) precomputed for the point. - * @return The Silhouette for the point. - */ - def computeSilhouetteCoefficient( - broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]], - point: Vector, - clusterId: Double, - squaredNorm: Double): Double = { - - def compute(targetClusterId: Double): Double = { - val clusterStats = broadcastedClustersMap.value(targetClusterId) - val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum) - - squaredNorm + - clusterStats.squaredNormSum / clusterStats.numOfPoints - - 2 * pointDotClusterFeaturesSum / clusterStats.numOfPoints + val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol)) + val df = if (!isDefined(weightCol) || $(weightCol).isEmpty) { + dataset.select(col($(predictionCol)), + vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata), + lit(1.0).as(weightColName)) + } else { + dataset.select(col($(predictionCol)), + vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata), + checkNonNegativeWeight(col(weightColName).cast(DoubleType))) } - pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, - clusterId, - broadcastedClustersMap.value(clusterId).numOfPoints, - compute) + val metrics = new ClusteringMetrics(df) + metrics.setDistanceMeasure($(distanceMeasure)) + metrics } - /** - * Compute the Silhouette score of the dataset using squared Euclidean distance measure. - * - * @param dataset The input dataset (previously clustered) on which compute the Silhouette. - * @param predictionCol The name of the column which contains the predicted cluster id - * for the point. - * @param featuresCol The name of the column which contains the feature vector of the point. - * @return The average of the Silhouette values of the clustered data. 
- */ - def computeSilhouetteScore( - dataset: Dataset[_], - predictionCol: String, - featuresCol: String): Double = { - SquaredEuclideanSilhouette.registerKryoClasses(dataset.sparkSession.sparkContext) - - val squaredNormUDF = udf { - features: Vector => math.pow(Vectors.norm(features, 2.0), 2.0) - } - val dfWithSquaredNorm = dataset.withColumn("squaredNorm", squaredNormUDF(col(featuresCol))) - - // compute aggregate values for clusters needed by the algorithm - val clustersStatsMap = SquaredEuclideanSilhouette - .computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol) - - // Silhouette is reasonable only when the number of clusters is greater then 1 - assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") - - val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) - - val computeSilhouetteCoefficientUDF = udf { - computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double) - } - - val silhouetteScore = overallScore(dfWithSquaredNorm, - computeSilhouetteCoefficientUDF(col(featuresCol), col(predictionCol).cast(DoubleType), - col("squaredNorm"))) - - bClustersStatsMap.destroy() - - silhouetteScore + @Since("3.0.0") + override def toString: String = { + s"ClusteringEvaluator: uid=$uid, metricName=${$(metricName)}, " + + s"distanceMeasure=${$(distanceMeasure)}" } } -/** - * The algorithm which is implemented in this object, instead, is an efficient and parallel - * implementation of the Silhouette using the cosine distance measure. The cosine distance - * measure is defined as `1 - s` where `s` is the cosine similarity between two points. - * - * The total distance of the point `X` to the points `$C_{i}$` belonging to the cluster `$\Gamma$` - * is: - * - *
- * $$ - * \sum\limits_{i=1}^N d(X, C_{i} ) = - * \sum\limits_{i=1}^N \Big( 1 - \frac{\sum\limits_{j=1}^D x_{j}c_{ij} }{ \|X\|\|C_{i}\|} \Big) - * = \sum\limits_{i=1}^N 1 - \sum\limits_{i=1}^N \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} - * \frac{c_{ij}}{\|C_{i}\|} - * = N - \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} \Big( \sum\limits_{i=1}^N - * \frac{c_{ij}}{\|C_{i}\|} \Big) - * $$ - *
- * - * where `$x_{j}$` is the `j`-th dimension of the point `X` and `$c_{ij}$` is the `j`-th dimension - * of the `i`-th point in cluster `$\Gamma$`. - * - * Then, we can define the vector: - * - *
- * $$ - * \xi_{X} : \xi_{X i} = \frac{x_{i}}{\|X\|}, i = 1, ..., D - * $$ - *
- * - * which can be precomputed for each point and the vector - * - *
- * $$ - * \Omega_{\Gamma} : \Omega_{\Gamma i} = \sum\limits_{j=1}^N \xi_{C_{j}i}, i = 1, ..., D - * $$ - *
- * - * which can be precomputed too for each cluster `$\Gamma$` by its points `$C_{i}$`. - * - * With these definitions, the numerator becomes: - * - *
- * $$ - * N - \sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j} - * $$ - *
- * - * Thus the average distance of a point `X` to the points of the cluster `$\Gamma$` is: - * - *
- * $$ - * 1 - \frac{\sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j}}{N} - * $$ - *
- * - * In the implementation, the precomputed values for the clusters are distributed among the worker - * nodes via broadcasted variables, because we can assume that the clusters are limited in number. - * - * The main strengths of this algorithm are the low computational complexity and the intrinsic - * parallelism. The precomputed information for each point and for each cluster can be computed - * with a computational complexity which is `O(N/W)`, where `N` is the number of points in the - * dataset and `W` is the number of worker nodes. After that, every point can be analyzed - * independently from the others. - * - * For every point we need to compute the average distance to all the clusters. Since the formula - * above requires `O(D)` operations, this phase has a computational complexity which is - * `O(C*D*N/W)` where `C` is the number of clusters (which we assume quite low), `D` is the number - * of dimensions, `N` is the number of points in the dataset and `W` is the number of worker - * nodes. - */ -private[evaluation] object CosineSilhouette extends Silhouette { - - private[this] val normalizedFeaturesColName = "normalizedFeatures" - - /** - * The method takes the input dataset and computes the aggregated values - * about a cluster which are needed by the algorithm. - * - * @param df The DataFrame which contains the input data - * @param predictionCol The name of the column which contains the predicted cluster id - * for the point. - * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a - * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). - */ - def computeClusterStats( - df: DataFrame, - featuresCol: String, - predictionCol: String): Map[Double, (Vector, Long)] = { - val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol) - val clustersStatsRDD = df.select( - col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName)) - .rdd - .map { row => (row.getDouble(0), row.getAs[Vector](1)) } - .aggregateByKey[(DenseVector, Long)]((Vectors.zeros(numFeatures).toDense, 0L))( - seqOp = { - case ((normalizedFeaturesSum: DenseVector, numOfPoints: Long), (normalizedFeatures)) => - BLAS.axpy(1.0, normalizedFeatures, normalizedFeaturesSum) - (normalizedFeaturesSum, numOfPoints + 1) - }, - combOp = { - case ((normalizedFeaturesSum1, numOfPoints1), (normalizedFeaturesSum2, numOfPoints2)) => - BLAS.axpy(1.0, normalizedFeaturesSum2, normalizedFeaturesSum1) - (normalizedFeaturesSum1, numOfPoints1 + numOfPoints2) - } - ) - - clustersStatsRDD - .collectAsMap() - .toMap - } - - /** - * It computes the Silhouette coefficient for a point. - * - * @param broadcastedClustersMap A map of the precomputed values for each cluster. - * @param normalizedFeatures The [[org.apache.spark.ml.linalg.Vector]] representing the - * normalized features of the current point. - * @param clusterId The id of the cluster the current point belongs to. 
- */ - def computeSilhouetteCoefficient( - broadcastedClustersMap: Broadcast[Map[Double, (Vector, Long)]], - normalizedFeatures: Vector, - clusterId: Double): Double = { - - def compute(targetClusterId: Double): Double = { - val (normalizedFeatureSum, numOfPoints) = broadcastedClustersMap.value(targetClusterId) - 1 - BLAS.dot(normalizedFeatures, normalizedFeatureSum) / numOfPoints - } - - pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, - clusterId, - broadcastedClustersMap.value(clusterId)._2, - compute) - } - - /** - * Compute the Silhouette score of the dataset using the cosine distance measure. - * - * @param dataset The input dataset (previously clustered) on which compute the Silhouette. - * @param predictionCol The name of the column which contains the predicted cluster id - * for the point. - * @param featuresCol The name of the column which contains the feature vector of the point. - * @return The average of the Silhouette values of the clustered data. - */ - def computeSilhouetteScore( - dataset: Dataset[_], - predictionCol: String, - featuresCol: String): Double = { - val normalizeFeatureUDF = udf { - features: Vector => { - val norm = Vectors.norm(features, 2.0) - BLAS.scal(1.0 / norm, features) - features - } - } - val dfWithNormalizedFeatures = dataset.withColumn(normalizedFeaturesColName, - normalizeFeatureUDF(col(featuresCol))) - - // compute aggregate values for clusters needed by the algorithm - val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol, - predictionCol) - - // Silhouette is reasonable only when the number of clusters is greater then 1 - assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") - - val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) - - val computeSilhouetteCoefficientUDF = udf { - computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double) - } - - val silhouetteScore = overallScore(dfWithNormalizedFeatures, - computeSilhouetteCoefficientUDF(col(normalizedFeaturesColName), - col(predictionCol).cast(DoubleType))) +@Since("2.3.0") +object ClusteringEvaluator + extends DefaultParamsReadable[ClusteringEvaluator] { - bClustersStatsMap.destroy() + @Since("2.3.0") + override def load(path: String): ClusteringEvaluator = super.load(path) - silhouetteScore - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala new file mode 100644 index 0000000000000..a785d063f1476 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala @@ -0,0 +1,593 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.sql.{Column, DataFrame, Dataset} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType + + +/** + * Metrics for clustering, which expects two input columns: prediction and label. + */ +@Since("3.1.0") +class ClusteringMetrics private[spark](dataset: Dataset[_]) { + + private var distanceMeasure: String = "squaredEuclidean" + + def getDistanceMeasure: String = distanceMeasure + + def setDistanceMeasure(value: String) : Unit = distanceMeasure = value + + /** + * Returns the silhouette score + */ + @Since("3.1.0") + lazy val silhouette: Double = { + val columns = dataset.columns.toSeq + if (distanceMeasure.equalsIgnoreCase("squaredEuclidean")) { + SquaredEuclideanSilhouette.computeSilhouetteScore( + dataset, columns(0), columns(1), columns(2)) + } else { + CosineSilhouette.computeSilhouetteScore(dataset, columns(0), columns(1), columns(2)) + } + } +} + + +private[evaluation] abstract class Silhouette { + + /** + * It computes the Silhouette coefficient for a point. + */ + def pointSilhouetteCoefficient( + clusterIds: Set[Double], + pointClusterId: Double, + weightSum: Double, + weight: Double, + averageDistanceToCluster: (Double) => Double): Double = { + if (weightSum == weight) { + // Single-element clusters have silhouette 0 + 0.0 + } else { + // Here we compute the average dissimilarity of the current point to any cluster of which the + // point is not a member. + // The cluster with the lowest average dissimilarity - i.e. the nearest cluster to the current + // point - is said to be the "neighboring cluster". + val otherClusterIds = clusterIds.filter(_ != pointClusterId) + val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min + // adjustment for excluding the node itself from the computation of the average dissimilarity + val currentClusterDissimilarity = + averageDistanceToCluster(pointClusterId) * weightSum / + (weightSum - weight) + if (currentClusterDissimilarity < neighboringClusterDissimilarity) { + 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) + } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) { + (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 + } else { + 0.0 + } + } + } + + /** + * Compute the mean Silhouette values of all samples. + */ + def overallScore(df: DataFrame, scoreColumn: Column, weightColumn: Column): Double = { + df.select(sum(scoreColumn * weightColumn) / sum(weightColumn)).collect()(0).getDouble(0) + } +} + +/** + * SquaredEuclideanSilhouette computes the average of the + * Silhouette over all the data of the dataset, which is + * a measure of how appropriately the data have been clustered. + * + * The Silhouette for each point `i` is defined as: + * + *
+ * $$ + * s_{i} = \frac{b_{i}-a_{i}}{max\{a_{i},b_{i}\}} + * $$ + *
+ * + * which can be rewritten as + * + *
+ * $$ + * s_{i}= \begin{cases} + * 1-\frac{a_{i}}{b_{i}} & \text{if } a_{i} \leq b_{i} \\ + * \frac{b_{i}}{a_{i}}-1 & \text{if } a_{i} \gt b_{i} \end{cases} + * $$ + *
+ * + * where `$a_{i}$` is the average dissimilarity of `i` with all other data + * within the same cluster, `$b_{i}$` is the lowest average dissimilarity + * of `i` to any other cluster, of which `i` is not a member. + * `$a_{i}$` can be interpreted as how well `i` is assigned to its cluster + * (the smaller the value, the better the assignment), while `$b_{i}$` is + * a measure of how well `i` has not been assigned to its "neighboring cluster", + * ie. the nearest cluster to `i`. + * + * Unfortunately, the naive implementation of the algorithm requires to compute + * the distance of each couple of points in the dataset. Since the computation of + * the distance measure takes `D` operations - if `D` is the number of dimensions + * of each point, the computational complexity of the algorithm is `O(N^2^*D)`, where + * `N` is the cardinality of the dataset. Of course this is not scalable in `N`, + * which is the critical number in a Big Data context. + * + * The algorithm which is implemented in this object, instead, is an efficient + * and parallel implementation of the Silhouette using the squared Euclidean + * distance measure. + * + * With this assumption, the total distance of the point `X` + * to the points `$C_{i}$` belonging to the cluster `$\Gamma$` is: + * + *
+ * $$ + * \sum\limits_{i=1}^N d(X, C_{i} ) = + * \sum\limits_{i=1}^N \Big( \sum\limits_{j=1}^D (x_{j}-c_{ij})^2 \Big) + * = \sum\limits_{i=1}^N \Big( \sum\limits_{j=1}^D x_{j}^2 + + * \sum\limits_{j=1}^D c_{ij}^2 -2\sum\limits_{j=1}^D x_{j}c_{ij} \Big) + * = \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}^2 + + * \sum\limits_{i=1}^N \sum\limits_{j=1}^D c_{ij}^2 + * -2 \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}c_{ij} + * $$ + *
+ * + * where `$x_{j}$` is the `j`-th dimension of the point `X` and + * `$c_{ij}$` is the `j`-th dimension of the `i`-th point in cluster `$\Gamma$`. + * + * Then, the first term of the equation can be rewritten as: + * + *
+ * $$ + * \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}^2 = N \xi_{X} \text{ , + * with } \xi_{X} = \sum\limits_{j=1}^D x_{j}^2 + * $$ + *
+ * + * where `$\xi_{X}$` is fixed for each point and it can be precomputed. + * + * Moreover, the second term is fixed for each cluster too, + * thus we can name it `$\Psi_{\Gamma}$` + * + *
+ * $$ + * \sum\limits_{i=1}^N \sum\limits_{j=1}^D c_{ij}^2 = + * \sum\limits_{i=1}^N \xi_{C_{i}} = \Psi_{\Gamma} + * $$ + *
+ * + * Last, the third element becomes + * + *
+ * $$ + * \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}c_{ij} = + * \sum\limits_{j=1}^D \Big(\sum\limits_{i=1}^N c_{ij} \Big) x_{j} + * $$ + *
+ * + * thus defining the vector + * + *
+ * $$ + * Y_{\Gamma}:Y_{\Gamma j} = \sum\limits_{i=1}^N c_{ij} , j=0, ..., D + * $$ + *
+ * + * which is fixed for each cluster `$\Gamma$`, we have + * + *
+ * $$ + * \sum\limits_{j=1}^D \Big(\sum\limits_{i=1}^N c_{ij} \Big) x_{j} = + * \sum\limits_{j=1}^D Y_{\Gamma j} x_{j} + * $$ + *
+ * + * In this way, the previous equation becomes + * + *
+ * $$ + * N\xi_{X} + \Psi_{\Gamma} - 2 \sum\limits_{j=1}^D Y_{\Gamma j} x_{j} + * $$ + *
+ * + * and the average distance of a point to a cluster can be computed as + * + *
+ * $$ + * \frac{\sum\limits_{i=1}^N d(X, C_{i} )}{N} = + * \frac{N\xi_{X} + \Psi_{\Gamma} - 2 \sum\limits_{j=1}^D Y_{\Gamma j} x_{j}}{N} = + * \xi_{X} + \frac{\Psi_{\Gamma} }{N} - 2 \frac{\sum\limits_{j=1}^D Y_{\Gamma j} x_{j}}{N} + * $$ + *
+ * + * Thus, it is enough to precompute: the constant `$\xi_{X}$` for each point `X`; the + * constants `$\Psi_{\Gamma}$`, `N` and the vector `$Y_{\Gamma}$` for + * each cluster `$\Gamma$`. + * + * In the implementation, the precomputed values for the clusters + * are distributed among the worker nodes via broadcasted variables, + * because we can assume that the clusters are limited in number and + * anyway they are much fewer than the points. + * + * The main strengths of this algorithm are the low computational complexity + * and the intrinsic parallelism. The precomputed information for each point + * and for each cluster can be computed with a computational complexity + * which is `O(N/W)`, where `N` is the number of points in the dataset and + * `W` is the number of worker nodes. After that, every point can be + * analyzed independently of the others. + * + * For every point we need to compute the average distance to all the clusters. + * Since the formula above requires `O(D)` operations, this phase has a + * computational complexity which is `O(C*D*N/W)` where `C` is the number of + * clusters (which we assume quite low), `D` is the number of dimensions, + * `N` is the number of points in the dataset and `W` is the number + * of worker nodes. + */ +private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { + + private[this] var kryoRegistrationPerformed: Boolean = false + + /** + * This method registers the class + * [[org.apache.spark.ml.evaluation.SquaredEuclideanSilhouette.ClusterStats]] + * for kryo serialization. + * + * @param sc `SparkContext` to be used + */ + def registerKryoClasses(sc: SparkContext): Unit = { + if (!kryoRegistrationPerformed) { + sc.getConf.registerKryoClasses( + Array( + classOf[SquaredEuclideanSilhouette.ClusterStats] + ) + ) + kryoRegistrationPerformed = true + } + } + + case class ClusterStats(featureSum: Vector, squaredNormSum: Double, weightSum: Double) + + /** + * The method takes the input dataset and computes the aggregated values + * about a cluster which are needed by the algorithm. + * + * @param df The DataFrame which contains the input data + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @param featuresCol The name of the column which contains the feature vector of the point. + * @param weightCol The name of the column which contains the instance weight. + * @return A [[scala.collection.immutable.Map]] which associates each cluster id + * to a [[ClusterStats]] object (which contains the precomputed values `N`, + * `$\Psi_{\Gamma}$` and `$Y_{\Gamma}$` for a cluster). 
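A small self-contained check of the identity exploited above, in plain Scala rather than the code in this file (illustrative numbers, unit weights): the mean squared Euclidean distance from a point to a cluster can be recovered from `$\xi_{X}$`, `$\Psi_{\Gamma}$` and `$Y_{\Gamma}$` alone:

    val x = Array(1.0, 2.0)
    val cluster = Seq(Array(0.0, 0.0), Array(2.0, 2.0), Array(4.0, 0.0))

    def sq(v: Array[Double]) = v.map(t => t * t).sum
    def dot(a: Array[Double], b: Array[Double]) = a.zip(b).map { case (p, q) => p * q }.sum

    // naive mean of squared distances
    val naive = cluster.map(c => sq(x.zip(c).map { case (p, q) => p - q })).sum / cluster.size
    // same value from the precomputed statistics
    val psi = cluster.map(sq).sum                              // Psi_Gamma
    val featureSum = cluster.transpose.map(_.sum).toArray      // Y_Gamma
    val fast = sq(x) + psi / cluster.size - 2 * dot(x, featureSum) / cluster.size
    assert(math.abs(naive - fast) < 1e-12)                     // both equal 19/3
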
+ */ + def computeClusterStats( + df: DataFrame, + predictionCol: String, + featuresCol: String, + weightCol: String): Map[Double, ClusterStats] = { + val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol) + val clustersStatsRDD = df.select( + col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"), col(weightCol)) + .rdd + .map { row => (row.getDouble(0), (row.getAs[Vector](1), row.getDouble(2), row.getDouble(3))) } + .aggregateByKey + [(DenseVector, Double, Double)]((Vectors.zeros(numFeatures).toDense, 0.0, 0.0))( + seqOp = { + case ( + (featureSum: DenseVector, squaredNormSum: Double, weightSum: Double), + (features, squaredNorm, weight) + ) => + BLAS.axpy(weight, features, featureSum) + (featureSum, squaredNormSum + squaredNorm * weight, weightSum + weight) + }, + combOp = { + case ( + (featureSum1, squaredNormSum1, weightSum1), + (featureSum2, squaredNormSum2, weightSum2) + ) => + BLAS.axpy(1.0, featureSum2, featureSum1) + (featureSum1, squaredNormSum1 + squaredNormSum2, weightSum1 + weightSum2) + } + ) + + clustersStatsRDD + .collectAsMap() + .mapValues { + case (featureSum: DenseVector, squaredNormSum: Double, weightSum: Double) => + SquaredEuclideanSilhouette.ClusterStats(featureSum, squaredNormSum, weightSum) + } + .toMap + } + + /** + * It computes the Silhouette coefficient for a point. + * + * @param broadcastedClustersMap A map of the precomputed values for each cluster. + * @param point The [[org.apache.spark.ml.linalg.Vector]] representing the current point. + * @param clusterId The id of the cluster the current point belongs to. + * @param weight The instance weight of the current point. + * @param squaredNorm The `$\Xi_{X}$` (which is the squared norm) precomputed for the point. + * @return The Silhouette for the point. + */ + def computeSilhouetteCoefficient( + broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]], + point: Vector, + clusterId: Double, + weight: Double, + squaredNorm: Double): Double = { + + def compute(targetClusterId: Double): Double = { + val clusterStats = broadcastedClustersMap.value(targetClusterId) + val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum) + + squaredNorm + + clusterStats.squaredNormSum / clusterStats.weightSum - + 2 * pointDotClusterFeaturesSum / clusterStats.weightSum + } + + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId).weightSum, + weight, + compute) + } + + /** + * Compute the Silhouette score of the dataset using squared Euclidean distance measure. + * + * @param dataset The input dataset (previously clustered) on which compute the Silhouette. + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @param featuresCol The name of the column which contains the feature vector of the point. + * @param weightCol The name of the column which contains instance weight. + * @return The average of the Silhouette values of the clustered data. 
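The "average of the Silhouette values" returned here is now a weighted average: overallScore computes sum(score * weight) / sum(weight). A tiny illustrative sketch, with made-up values:

    // (silhouette_i, weight_i) pairs
    val scored = Seq((0.8, 1.0), (0.4, 2.0), (-0.1, 1.0))
    val overall = scored.map { case (s, w) => s * w }.sum / scored.map(_._2).sum
    // (0.8 + 0.8 - 0.1) / 4.0 = 0.375
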
+ */ + def computeSilhouetteScore( + dataset: Dataset[_], + predictionCol: String, + featuresCol: String, + weightCol: String): Double = { + SquaredEuclideanSilhouette.registerKryoClasses(dataset.sparkSession.sparkContext) + + val squaredNormUDF = udf { + features: Vector => math.pow(Vectors.norm(features, 2.0), 2.0) + } + val dfWithSquaredNorm = dataset.withColumn("squaredNorm", squaredNormUDF(col(featuresCol))) + + // compute aggregate values for clusters needed by the algorithm + val clustersStatsMap = SquaredEuclideanSilhouette + .computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol, weightCol) + + // Silhouette is reasonable only when the number of clusters is greater then 1 + assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") + + val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) + + val computeSilhouetteCoefficientUDF = udf { + computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double, _: Double) + } + + val silhouetteScore = overallScore(dfWithSquaredNorm, + computeSilhouetteCoefficientUDF(col(featuresCol), col(predictionCol).cast(DoubleType), + col(weightCol), col("squaredNorm")), col(weightCol)) + + bClustersStatsMap.destroy() + + silhouetteScore + } +} + + +/** + * The algorithm which is implemented in this object, instead, is an efficient and parallel + * implementation of the Silhouette using the cosine distance measure. The cosine distance + * measure is defined as `1 - s` where `s` is the cosine similarity between two points. + * + * The total distance of the point `X` to the points `$C_{i}$` belonging to the cluster `$\Gamma$` + * is: + * + *
+ * $$ + * \sum\limits_{i=1}^N d(X, C_{i} ) = + * \sum\limits_{i=1}^N \Big( 1 - \frac{\sum\limits_{j=1}^D x_{j}c_{ij} }{ \|X\|\|C_{i}\|} \Big) + * = \sum\limits_{i=1}^N 1 - \sum\limits_{i=1}^N \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} + * \frac{c_{ij}}{\|C_{i}\|} + * = N - \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} \Big( \sum\limits_{i=1}^N + * \frac{c_{ij}}{\|C_{i}\|} \Big) + * $$ + *
+ * + * where `$x_{j}$` is the `j`-th dimension of the point `X` and `$c_{ij}$` is the `j`-th dimension + * of the `i`-th point in cluster `$\Gamma$`. + * + * Then, we can define the vector: + * + *
+ * $$ + * \xi_{X} : \xi_{X i} = \frac{x_{i}}{\|X\|}, i = 1, ..., D + * $$ + *
+ * + * which can be precomputed for each point and the vector + * + *
+ * $$ + * \Omega_{\Gamma} : \Omega_{\Gamma i} = \sum\limits_{j=1}^N \xi_{C_{j}i}, i = 1, ..., D + * $$ + *
+ * + * which can be precomputed too for each cluster `$\Gamma$` by its points `$C_{i}$`. + * + * With these definitions, the numerator becomes: + * + *
+ * $$ + * N - \sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j} + * $$ + *
+ * + * Thus the average distance of a point `X` to the points of the cluster `$\Gamma$` is: + * + *
+ * $$ + * 1 - \frac{\sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j}}{N} + * $$ + *
+ * + * In the implementation, the precomputed values for the clusters are distributed among the worker + * nodes via broadcasted variables, because we can assume that the clusters are limited in number. + * + * The main strengths of this algorithm are the low computational complexity and the intrinsic + * parallelism. The precomputed information for each point and for each cluster can be computed + * with a computational complexity which is `O(N/W)`, where `N` is the number of points in the + * dataset and `W` is the number of worker nodes. After that, every point can be analyzed + * independently from the others. + * + * For every point we need to compute the average distance to all the clusters. Since the formula + * above requires `O(D)` operations, this phase has a computational complexity which is + * `O(C*D*N/W)` where `C` is the number of clusters (which we assume quite low), `D` is the number + * of dimensions, `N` is the number of points in the dataset and `W` is the number of worker + * nodes. + */ +private[evaluation] object CosineSilhouette extends Silhouette { + + private[this] val normalizedFeaturesColName = "normalizedFeatures" + + /** + * The method takes the input dataset and computes the aggregated values + * about a cluster which are needed by the algorithm. + * + * @param df The DataFrame which contains the input data + * @param featuresCol The name of the column which contains the feature vector of the point. + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @param weightCol The name of the column which contains the instance weight. + * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a + * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). + */ + def computeClusterStats( + df: DataFrame, + featuresCol: String, + predictionCol: String, + weightCol: String): Map[Double, (Vector, Double)] = { + val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol) + val clustersStatsRDD = df.select( + col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName), col(weightCol)) + .rdd + .map { row => (row.getDouble(0), (row.getAs[Vector](1), row.getDouble(2))) } + .aggregateByKey[(DenseVector, Double)]((Vectors.zeros(numFeatures).toDense, 0.0))( + seqOp = { + case ((normalizedFeaturesSum: DenseVector, weightSum: Double), + (normalizedFeatures, weight)) => + BLAS.axpy(weight, normalizedFeatures, normalizedFeaturesSum) + (normalizedFeaturesSum, weightSum + weight) + }, + combOp = { + case ((normalizedFeaturesSum1, weightSum1), (normalizedFeaturesSum2, weightSum2)) => + BLAS.axpy(1.0, normalizedFeaturesSum2, normalizedFeaturesSum1) + (normalizedFeaturesSum1, weightSum1 + weightSum2) + } + ) + + clustersStatsRDD + .collectAsMap() + .toMap + } + + /** + * It computes the Silhouette coefficient for a point. + * + * @param broadcastedClustersMap A map of the precomputed values for each cluster. + * @param normalizedFeatures The [[org.apache.spark.ml.linalg.Vector]] representing the + * normalized features of the current point. + * @param clusterId The id of the cluster the current point belongs to. + * @param weight The instance weight of the current point. 
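As with the squared Euclidean case, a small self-contained check (illustrative numbers, unit weights, plain Scala) of the normalized-features identity that CosineSilhouette relies on: the mean cosine distance from a point to a cluster equals `$1 - \xi_{X} \cdot \Omega_{\Gamma} / N$`:

    def norm(v: Array[Double]) = math.sqrt(v.map(t => t * t).sum)
    def dot(a: Array[Double], b: Array[Double]) = a.zip(b).map { case (p, q) => p * q }.sum
    def normalize(v: Array[Double]) = { val n = norm(v); v.map(_ / n) }

    val x = Array(1.0, 2.0)
    val cluster = Seq(Array(1.0, 0.0), Array(0.0, 3.0), Array(2.0, 2.0))

    val naive = cluster.map(c => 1.0 - dot(x, c) / (norm(x) * norm(c))).sum / cluster.size
    val omega = cluster.map(normalize).transpose.map(_.sum).toArray   // Omega_Gamma
    val fast = 1.0 - dot(normalize(x), omega) / cluster.size
    assert(math.abs(naive - fast) < 1e-12)
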
+ */ + def computeSilhouetteCoefficient( + broadcastedClustersMap: Broadcast[Map[Double, (Vector, Double)]], + normalizedFeatures: Vector, + clusterId: Double, + weight: Double): Double = { + + def compute(targetClusterId: Double): Double = { + val (normalizedFeatureSum, numOfPoints) = broadcastedClustersMap.value(targetClusterId) + 1 - BLAS.dot(normalizedFeatures, normalizedFeatureSum) / numOfPoints + } + + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId)._2, + weight, + compute) + } + + /** + * Compute the Silhouette score of the dataset using the cosine distance measure. + * + * @param dataset The input dataset (previously clustered) on which compute the Silhouette. + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @param featuresCol The name of the column which contains the feature vector of the point. + * @param weightCol The name of the column which contains the instance weight. + * @return The average of the Silhouette values of the clustered data. + */ + def computeSilhouetteScore( + dataset: Dataset[_], + predictionCol: String, + featuresCol: String, + weightCol: String): Double = { + val normalizeFeatureUDF = udf { + features: Vector => { + val norm = Vectors.norm(features, 2.0) + BLAS.scal(1.0 / norm, features) + features + } + } + val dfWithNormalizedFeatures = dataset.withColumn(normalizedFeaturesColName, + normalizeFeatureUDF(col(featuresCol))) + + // compute aggregate values for clusters needed by the algorithm + val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol, + predictionCol, weightCol) + + // Silhouette is reasonable only when the number of clusters is greater then 1 + assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") + + val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) + + val computeSilhouetteCoefficientUDF = udf { + computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double) + } + + val silhouetteScore = overallScore(dfWithNormalizedFeatures, + computeSilhouetteCoefficientUDF(col(normalizedFeaturesColName), + col(predictionCol).cast(DoubleType), col(weightCol)), col(weightCol)) + + bClustersStatsMap.destroy() + + silhouetteScore + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 1d6540e970383..3d77792c4fc88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -153,19 +154,51 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("2.0.0") override def evaluate(dataset: Dataset[_]): Double = { + val metrics = getMetrics(dataset) + $(metricName) match { + case "f1" => metrics.weightedFMeasure + case "accuracy" => metrics.accuracy + case "weightedPrecision" => metrics.weightedPrecision + case "weightedRecall" => metrics.weightedRecall + case "weightedTruePositiveRate" => metrics.weightedTruePositiveRate + 
case "weightedFalsePositiveRate" => metrics.weightedFalsePositiveRate + case "weightedFMeasure" => metrics.weightedFMeasure($(beta)) + case "truePositiveRateByLabel" => metrics.truePositiveRate($(metricLabel)) + case "falsePositiveRateByLabel" => metrics.falsePositiveRate($(metricLabel)) + case "precisionByLabel" => metrics.precision($(metricLabel)) + case "recallByLabel" => metrics.recall($(metricLabel)) + case "fMeasureByLabel" => metrics.fMeasure($(metricLabel), $(beta)) + case "hammingLoss" => metrics.hammingLoss + case "logLoss" => metrics.logLoss($(eps)) + } + } + + /** + * Get a MulticlassMetrics, which can be used to get multiclass classification + * metrics such as accuracy, weightedPrecision, etc. + * + * @param dataset a dataset that contains labels/observations and predictions. + * @return MulticlassMetrics + */ + @Since("3.1.0") + def getMetrics(dataset: Dataset[_]): MulticlassMetrics = { val schema = dataset.schema SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkNumericType(schema, $(labelCol)) val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } - val rdd = if ($(metricName) == "logLoss") { + if ($(metricName) == "logLoss") { // probabilityCol is only needed to compute logloss - require(isDefined(probabilityCol) && $(probabilityCol).nonEmpty) + require(schema.fieldNames.contains($(probabilityCol)), + "probabilityCol is needed to compute logloss") + } + + val rdd = if (schema.fieldNames.contains($(probabilityCol))) { val p = DatasetUtils.columnToVector(dataset, $(probabilityCol)) dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), w, p) .rdd.map { @@ -179,23 +212,7 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid } } - val metrics = new MulticlassMetrics(rdd) - $(metricName) match { - case "f1" => metrics.weightedFMeasure - case "accuracy" => metrics.accuracy - case "weightedPrecision" => metrics.weightedPrecision - case "weightedRecall" => metrics.weightedRecall - case "weightedTruePositiveRate" => metrics.weightedTruePositiveRate - case "weightedFalsePositiveRate" => metrics.weightedFalsePositiveRate - case "weightedFMeasure" => metrics.weightedFMeasure($(beta)) - case "truePositiveRateByLabel" => metrics.truePositiveRate($(metricLabel)) - case "falsePositiveRateByLabel" => metrics.falsePositiveRate($(metricLabel)) - case "precisionByLabel" => metrics.precision($(metricLabel)) - case "recallByLabel" => metrics.recall($(metricLabel)) - case "fMeasureByLabel" => metrics.fMeasure($(metricLabel), $(beta)) - case "hammingLoss" => metrics.hammingLoss - case "logLoss" => metrics.logLoss($(eps)) - } + new MulticlassMetrics(rdd) } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.scala index a8db5452bd56c..1a82ac7a9472f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.scala @@ -98,18 +98,7 @@ class MultilabelClassificationEvaluator @Since("3.0.0") (@Since("3.0.0") overrid @Since("3.0.0") override def evaluate(dataset: Dataset[_]): Double = { - val schema = dataset.schema - SchemaUtils.checkColumnTypes(schema, $(predictionCol), - Seq(ArrayType(DoubleType, false), 
ArrayType(DoubleType, true))) - SchemaUtils.checkColumnTypes(schema, $(labelCol), - Seq(ArrayType(DoubleType, false), ArrayType(DoubleType, true))) - - val predictionAndLabels = - dataset.select(col($(predictionCol)), col($(labelCol))) - .rdd.map { row => - (row.getSeq[Double](0).toArray, row.getSeq[Double](1).toArray) - } - val metrics = new MultilabelMetrics(predictionAndLabels) + val metrics = getMetrics(dataset) $(metricName) match { case "subsetAccuracy" => metrics.subsetAccuracy case "accuracy" => metrics.accuracy @@ -126,6 +115,29 @@ class MultilabelClassificationEvaluator @Since("3.0.0") (@Since("3.0.0") overrid } } + /** + * Get a MultilabelMetrics, which can be used to get multilabel classification + * metrics such as accuracy, precision, precisionByLabel, etc. + * + * @param dataset a dataset that contains labels/observations and predictions. + * @return MultilabelMetrics + */ + @Since("3.1.0") + def getMetrics(dataset: Dataset[_]): MultilabelMetrics = { + val schema = dataset.schema + SchemaUtils.checkColumnTypes(schema, $(predictionCol), + Seq(ArrayType(DoubleType, false), ArrayType(DoubleType, true))) + SchemaUtils.checkColumnTypes(schema, $(labelCol), + Seq(ArrayType(DoubleType, false), ArrayType(DoubleType, true))) + + val predictionAndLabels = + dataset.select(col($(predictionCol)), col($(labelCol))) + .rdd.map { row => + (row.getSeq[Double](0).toArray, row.getSeq[Double](1).toArray) + } + new MultilabelMetrics(predictionAndLabels) + } + @Since("3.0.0") override def isLargerBetter: Boolean = { $(metricName) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala index c5dea6c177e21..82dda4109771d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala @@ -95,6 +95,25 @@ class RankingEvaluator @Since("3.0.0") (@Since("3.0.0") override val uid: String @Since("3.0.0") override def evaluate(dataset: Dataset[_]): Double = { + val metrics = getMetrics(dataset) + $(metricName) match { + case "meanAveragePrecision" => metrics.meanAveragePrecision + case "meanAveragePrecisionAtK" => metrics.meanAveragePrecisionAt($(k)) + case "precisionAtK" => metrics.precisionAt($(k)) + case "ndcgAtK" => metrics.ndcgAt($(k)) + case "recallAtK" => metrics.recallAt($(k)) + } + } + + /** + * Get a RankingMetrics, which can be used to get ranking metrics + * such as meanAveragePrecision, meanAveragePrecisionAtK, etc. + * + * @param dataset a dataset that contains labels/observations and predictions. 
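A hedged usage sketch of the getMetrics variant added here (the predictions DataFrame is assumed; both the label and prediction columns are array<double>): build the RankingMetrics once, then read several metrics from it:

    import org.apache.spark.ml.evaluation.RankingEvaluator

    val evaluator = new RankingEvaluator()
      .setLabelCol("label")             // array<double>: the relevant items
      .setPredictionCol("prediction")   // array<double>: the ranked recommendations
    val metrics = evaluator.getMetrics(predictions)
    println(metrics.meanAveragePrecision)
    println(metrics.ndcgAt(10))
    println(metrics.precisionAt(10))
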
+ * @return RankingMetrics + */ + @Since("3.1.0") + def getMetrics(dataset: Dataset[_]): RankingMetrics[Double] = { val schema = dataset.schema SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(ArrayType(DoubleType, false), ArrayType(DoubleType, true))) @@ -106,14 +125,7 @@ class RankingEvaluator @Since("3.0.0") (@Since("3.0.0") override val uid: String .rdd.map { row => (row.getSeq[Double](0).toArray, row.getSeq[Double](1).toArray) } - val metrics = new RankingMetrics[Double](predictionAndLabels) - $(metricName) match { - case "meanAveragePrecision" => metrics.meanAveragePrecision - case "meanAveragePrecisionAtK" => metrics.meanAveragePrecisionAt($(k)) - case "precisionAtK" => metrics.precisionAt($(k)) - case "ndcgAtK" => metrics.ndcgAt($(k)) - case "recallAtK" => metrics.recallAt($(k)) - } + new RankingMetrics[Double](predictionAndLabels) } @Since("3.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 18a8dda0c76ef..f0b7c345c3285 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} @@ -97,24 +98,37 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("2.0.0") override def evaluate(dataset: Dataset[_]): Double = { + val metrics = getMetrics(dataset) + $(metricName) match { + case "rmse" => metrics.rootMeanSquaredError + case "mse" => metrics.meanSquaredError + case "r2" => metrics.r2 + case "mae" => metrics.meanAbsoluteError + case "var" => metrics.explainedVariance + } + } + + /** + * Get a RegressionMetrics, which can be used to get regression + * metrics such as rootMeanSquaredError, meanSquaredError, etc. + * + * @param dataset a dataset that contains labels/observations and predictions. 
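RankingEvaluator gains the same kind of escape hatch. A minimal sketch (not part of the patch), assuming a spark-shell style SparkSession named spark; the array-typed prediction/label rows below are made up for illustration:

import org.apache.spark.ml.evaluation.RankingEvaluator

// Each row: the predicted ranking and the truly relevant items, both as Array[Double].
val df = spark.createDataFrame(Seq(
  (Array(1.0, 2.0, 3.0), Array(1.0, 3.0)),
  (Array(4.0, 1.0, 2.0), Array(4.0))
)).toDF("prediction", "label")

val evaluator = new RankingEvaluator()
  .setPredictionCol("prediction")
  .setLabelCol("label")
  .setK(2)

// The returned RankingMetrics[Double] exposes all ranking metrics at once.
val metrics = evaluator.getMetrics(df)
println(metrics.meanAveragePrecision)
println(metrics.precisionAt(2))
println(metrics.ndcgAt(2))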
+ * @return RegressionMetrics + */ + @Since("3.1.0") + def getMetrics(dataset: Dataset[_]): RegressionMetrics = { val schema = dataset.schema SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType)) SchemaUtils.checkNumericType(schema, $(labelCol)) val predictionAndLabelsWithWeights = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType), - if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) + else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))) .rdd .map { case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } - val metrics = new RegressionMetrics(predictionAndLabelsWithWeights, $(throughOrigin)) - $(metricName) match { - case "rmse" => metrics.rootMeanSquaredError - case "mse" => metrics.meanSquaredError - case "r2" => metrics.r2 - case "mae" => metrics.meanAbsoluteError - case "var" => metrics.explainedVariance - } + new RegressionMetrics(predictionAndLabelsWithWeights, $(throughOrigin)) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FValueSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FValueSelector.scala index af975edf7c049..d177555c5ddee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FValueSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FValueSelector.scala @@ -150,7 +150,7 @@ class FValueSelectorModel private[ml]( @Since("3.1.0") override def toString: String = { - s"FValueModel: uid=$uid, numSelectedFeatures=${selectedFeatures.length}" + s"FValueSelectorModel: uid=$uid, numSelectedFeatures=${selectedFeatures.length}" } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 80bf85936aace..d2bb013448aae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -42,14 +42,17 @@ import org.apache.spark.util.VersionUtils.majorMinorVersion * otherwise the features will not be mapped evenly to the columns. 
*/ @Since("1.2.0") -class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) +class HashingTF @Since("3.0.0") private[ml] ( + @Since("1.4.0") override val uid: String, + @Since("3.1.0") val hashFuncVersion: Int) extends Transformer with HasInputCol with HasOutputCol with HasNumFeatures with DefaultParamsWritable { - private var hashFunc: Any => Int = FeatureHasher.murmur3Hash - @Since("1.2.0") - def this() = this(Identifiable.randomUID("hashingTF")) + def this() = this(Identifiable.randomUID("hashingTF"), HashingTF.SPARK_3_MURMUR3_HASH) + + @Since("1.4.0") + def this(uid: String) = this(uid, hashFuncVersion = HashingTF.SPARK_3_MURMUR3_HASH) /** @group setParam */ @Since("1.4.0") @@ -122,7 +125,12 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) */ @Since("3.0.0") def indexOf(term: Any): Int = { - Utils.nonNegativeMod(hashFunc(term), $(numFeatures)) + val hashValue = hashFuncVersion match { + case HashingTF.SPARK_2_MURMUR3_HASH => OldHashingTF.murmur3Hash(term) + case HashingTF.SPARK_3_MURMUR3_HASH => FeatureHasher.murmur3Hash(term) + case _ => throw new IllegalArgumentException("Illegal hash function version setting.") + } + Utils.nonNegativeMod(hashValue, $(numFeatures)) } @Since("1.4.1") @@ -132,27 +140,41 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def toString: String = { s"HashingTF: uid=$uid, binary=${$(binary)}, numFeatures=${$(numFeatures)}" } + + @Since("3.0.0") + override def save(path: String): Unit = { + require(hashFuncVersion == HashingTF.SPARK_3_MURMUR3_HASH, + "Cannot save model which is loaded from lower version spark saved model. We can address " + + "it by (1) use old spark version to save the model, or (2) use new version spark to " + + "re-train the pipeline.") + super.save(path) + } } @Since("1.6.0") object HashingTF extends DefaultParamsReadable[HashingTF] { + private[ml] val SPARK_2_MURMUR3_HASH = 1 + private[ml] val SPARK_3_MURMUR3_HASH = 2 + private class HashingTFReader extends MLReader[HashingTF] { private val className = classOf[HashingTF].getName override def load(path: String): HashingTF = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - val hashingTF = new HashingTF(metadata.uid) - metadata.getAndSetParams(hashingTF) // We support loading old `HashingTF` saved by previous Spark versions. // Previous `HashingTF` uses `mllib.feature.HashingTF.murmur3Hash`, but new `HashingTF` uses // `ml.Feature.FeatureHasher.murmur3Hash`. 
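For context, a short sketch (not part of the patch) of how the versioned hash surfaces to users: a HashingTF constructed in Spark 3.x uses the new FeatureHasher-based murmur3 hash, while instances loaded from Spark 2.x metadata keep the legacy hash so bucket indices stay compatible. indexOf reports the bucket under whichever hash version is active:

import org.apache.spark.ml.feature.HashingTF

val hashingTF = new HashingTF()      // defaults to the Spark 3 murmur3 hash
  .setInputCol("words")
  .setOutputCol("features")
  .setNumFeatures(1 << 18)

// The bucket this term maps to; an instance loaded from a Spark 2.x save may give a
// different index because it hashes with the legacy function.
println(hashingTF.indexOf("spark"))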
val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion) - if (majorVersion < 3) { - hashingTF.hashFunc = OldHashingTF.murmur3Hash + val hashFuncVersion = if (majorVersion < 3) { + SPARK_2_MURMUR3_HASH + } else { + SPARK_3_MURMUR3_HASH } + val hashingTF = new HashingTF(metadata.uid, hashFuncVersion = hashFuncVersion) + metadata.getAndSetParams(hashingTF) hashingTF } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 216d99d01f2f7..4eedfc4dc0efa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -236,6 +236,18 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui private def getDistinctSplits(splits: Array[Double]): Array[Double] = { splits(0) = Double.NegativeInfinity splits(splits.length - 1) = Double.PositiveInfinity + + // 0.0 and -0.0 are distinct values, array.distinct will preserve both of them. + // but 0.0 > -0.0 is False which will break the parameter validation checking. + // and in scala <= 2.12, there's bug which will cause array.distinct generate + // non-deterministic results when array contains both 0.0 and -0.0 + // So that here we should first normalize all 0.0 and -0.0 to be 0.0 + // See https://github.com/scala/bug/issues/11995 + for (i <- 0 until splits.length) { + if (splits(i) == -0.0) { + splits(i) = 0.0 + } + } val distinctSplits = splits.distinct if (splits.length != distinctSplits.length) { log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index c6b1b29a6d9bc..7434b1adb2ff2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -281,7 +281,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { values } - private[ml] def getTransformFunc( + private[spark] def getTransformFunc( shift: Array[Double], scale: Array[Double], withShift: Boolean, diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 30700122665de..7bc5e56aaebf2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -233,7 +233,7 @@ object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns) case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException( s"""Can not infer column lengths with handleInvalid = "keep". 
Consider using VectorSizeHint - |to add metadata for columns: ${columns.mkString("[", ", ", "]")}.""" + |to add metadata for columns: ${missingColumns.mkString("[", ", ", "]")}.""" .stripMargin.replaceAll("\n", " ")) case (_, _) => Map.empty } diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala index 0f03231079866..a0b6d11a46be9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala @@ -71,4 +71,10 @@ object functions { ) } } + + private[ml] def checkNonNegativeWeight = udf { + value: Double => + require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.") + value + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/AFTAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/AFTAggregator.scala index 6482c619e6c05..8a5d7fe34e7a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/AFTAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/AFTAggregator.scala @@ -155,8 +155,102 @@ private[ml] class AFTAggregator( } gradientSumArray(dim - 2) += { if (fitIntercept) multiplier else 0.0 } gradientSumArray(dim - 1) += delta + multiplier * sigma * epsilon - weightSum += 1.0 + + this + } +} + + +/** + * BlockAFTAggregator computes the gradient and loss as used in AFT survival regression + * for blocks in sparse or dense matrix in an online fashion. + * + * Two BlockAFTAggregators can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * NOTE: The feature values are expected to be standardized before computation. + * + * @param bcCoefficients The coefficients corresponding to the features. + * @param fitIntercept Whether to fit an intercept term. + */ +private[ml] class BlockAFTAggregator( + fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector]) + extends DifferentiableLossAggregator[(Matrix, Array[Double], Array[Double]), + BlockAFTAggregator] { + + protected override val dim: Int = bcCoefficients.value.size + private val numFeatures = dim - 2 + + @transient private lazy val coefficientsArray = bcCoefficients.value match { + case DenseVector(values) => values + case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" + + s" but got type ${bcCoefficients.value.getClass}.") + } + + @transient private lazy val linear = Vectors.dense(coefficientsArray.take(numFeatures)) + + /** + * Add a new training instance block to this BlockAFTAggregator, and update the loss and + * gradient of the objective function. + * + * @return This BlockAFTAggregator object. + */ + def add(block: (Matrix, Array[Double], Array[Double])): this.type = { + val (matrix, labels, censors) = block + require(matrix.isTransposed) + require(numFeatures == matrix.numCols, s"Dimensions mismatch when adding new " + + s"instance. 
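The new private[ml] checkNonNegativeWeight helper is not callable from user code, but its effect is easy to reproduce. A sketch (not part of the patch) of an equivalent standalone UDF, assuming a spark-shell style SparkSession named spark:

import org.apache.spark.sql.functions.{col, udf}

// Mirrors the helper added in ml/functions.scala: pass weights through, fail fast on < 0.
val checkWeight = udf { value: Double =>
  require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")
  value
}

val df = spark.range(3).select(col("id").cast("double").as("weight"))
df.select(checkWeight(col("weight"))).show()   // all weights >= 0, passes
// A negative weight would now raise an error at evaluation/training time instead of
// silently skewing the result:
// spark.sql("SELECT -1.0 AS weight").select(checkWeight(col("weight"))).show()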
Expecting $numFeatures but got ${matrix.numCols}.") + require(labels.forall(_ > 0.0), "The lifetime or label should be greater than 0.") + + val size = matrix.numRows + require(labels.length == size && censors.length == size) + + val intercept = coefficientsArray(dim - 2) + // sigma is the scale parameter of the AFT model + val sigma = math.exp(coefficientsArray(dim - 1)) + + // vec here represents margins + val vec = if (fitIntercept) { + Vectors.dense(Array.fill(size)(intercept)).toDense + } else { + Vectors.zeros(size).toDense + } + BLAS.gemv(1.0, matrix, linear, 1.0, vec) + + // in-place convert margins to gradient scales + // then, vec represents gradient scales + var i = 0 + var sigmaGradSum = 0.0 + while (i < size) { + val ti = labels(i) + val delta = censors(i) + val margin = vec(i) + val epsilon = (math.log(ti) - margin) / sigma + val expEpsilon = math.exp(epsilon) + lossSum += delta * math.log(sigma) - delta * epsilon + expEpsilon + val multiplier = (delta - expEpsilon) / sigma + vec.values(i) = multiplier + sigmaGradSum += delta + multiplier * sigma * epsilon + i += 1 + } + + matrix match { + case dm: DenseMatrix => + BLAS.nativeBLAS.dgemv("N", dm.numCols, dm.numRows, 1.0, dm.values, dm.numCols, + vec.values, 1, 1.0, gradientSumArray, 1) + + case sm: SparseMatrix => + val linearGradSumVec = Vectors.zeros(numFeatures).toDense + BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec) + BLAS.getBLAS(numFeatures).daxpy(numFeatures, 1.0, linearGradSumVec.values, 1, + gradientSumArray, 1) + } + + if (fitIntercept) gradientSumArray(dim - 2) += vec.values.sum + gradientSumArray(dim - 1) += sigmaGradSum + weightSum += size + this } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala index 8a1a41b2950c1..59ecc038e5569 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml.optim.aggregator import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.feature.Instance -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.feature.{Instance, InstanceBlock} +import org.apache.spark.ml.linalg._ /** * HuberAggregator computes the gradient and loss for a huber loss function, @@ -74,15 +74,12 @@ private[ml] class HuberAggregator( extends DifferentiableLossAggregator[Instance, HuberAggregator] { protected override val dim: Int = bcParameters.value.size - private val numFeatures: Int = if (fitIntercept) dim - 2 else dim - 1 - private val sigma: Double = bcParameters.value(dim - 1) - private val intercept: Double = if (fitIntercept) { - bcParameters.value(dim - 2) - } else { - 0.0 - } + private val numFeatures = if (fitIntercept) dim - 2 else dim - 1 + private val sigma = bcParameters.value(dim - 1) + private val intercept = if (fitIntercept) bcParameters.value(dim - 2) else 0.0 + // make transient so we do not serialize between aggregation stages - @transient private lazy val coefficients = bcParameters.value.toArray.slice(0, numFeatures) + @transient private lazy val coefficients = bcParameters.value.toArray.take(numFeatures) /** * Add a new training instance to this HuberAggregator, and update the loss and gradient @@ -150,3 +147,101 @@ private[ml] class HuberAggregator( } } } + + +/** + * BlockHuberAggregator computes the gradient and loss for Huber loss function + * as used in 
linear regression for blocks in sparse or dense matrix in an online fashion. + * + * Two BlockHuberAggregators can be merged together to have a summary of loss and gradient + * of the corresponding joint dataset. + * + * NOTE: The feature values are expected to be standardized before computation. + * + * @param fitIntercept Whether to fit an intercept term. + */ +private[ml] class BlockHuberAggregator( + fitIntercept: Boolean, + epsilon: Double)(bcParameters: Broadcast[Vector]) + extends DifferentiableLossAggregator[InstanceBlock, BlockHuberAggregator] { + + protected override val dim: Int = bcParameters.value.size + private val numFeatures = if (fitIntercept) dim - 2 else dim - 1 + private val sigma = bcParameters.value(dim - 1) + private val intercept = if (fitIntercept) bcParameters.value(dim - 2) else 0.0 + // make transient so we do not serialize between aggregation stages + @transient private lazy val linear = Vectors.dense(bcParameters.value.toArray.take(numFeatures)) + + /** + * Add a new training instance block to this BlockHuberAggregator, and update the loss and + * gradient of the objective function. + * + * @param block The instance block of data point to be added. + * @return This BlockHuberAggregator object. + */ + def add(block: InstanceBlock): BlockHuberAggregator = { + require(block.matrix.isTransposed) + require(numFeatures == block.numFeatures, s"Dimensions mismatch when adding new " + + s"instance. Expecting $numFeatures but got ${block.numFeatures}.") + require(block.weightIter.forall(_ >= 0), + s"instance weights ${block.weightIter.mkString("[", ",", "]")} has to be >= 0.0") + + if (block.weightIter.forall(_ == 0)) return this + val size = block.size + + // vec here represents margins or dotProducts + val vec = if (fitIntercept) { + Vectors.dense(Array.fill(size)(intercept)).toDense + } else { + Vectors.zeros(size).toDense + } + BLAS.gemv(1.0, block.matrix, linear, 1.0, vec) + + // in-place convert margins to multipliers + // then, vec represents multipliers + var sigmaGradSum = 0.0 + var i = 0 + while (i < size) { + val weight = block.getWeight(i) + if (weight > 0) { + weightSum += weight + val label = block.getLabel(i) + val margin = vec(i) + val linearLoss = label - margin + + if (math.abs(linearLoss) <= sigma * epsilon) { + lossSum += 0.5 * weight * (sigma + math.pow(linearLoss, 2.0) / sigma) + val linearLossDivSigma = linearLoss / sigma + val multiplier = -1.0 * weight * linearLossDivSigma + vec.values(i) = multiplier + sigmaGradSum += 0.5 * weight * (1.0 - math.pow(linearLossDivSigma, 2.0)) + } else { + lossSum += 0.5 * weight * + (sigma + 2.0 * epsilon * math.abs(linearLoss) - sigma * epsilon * epsilon) + val sign = if (linearLoss >= 0) -1.0 else 1.0 + val multiplier = weight * sign * epsilon + vec.values(i) = multiplier + sigmaGradSum += 0.5 * weight * (1.0 - epsilon * epsilon) + } + } else { vec.values(i) = 0.0 } + i += 1 + } + + block.matrix match { + case dm: DenseMatrix => + BLAS.nativeBLAS.dgemv("N", dm.numCols, dm.numRows, 1.0, dm.values, dm.numCols, + vec.values, 1, 1.0, gradientSumArray, 1) + + case sm: SparseMatrix => + val linearGradSumVec = Vectors.zeros(numFeatures).toDense + BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec) + BLAS.getBLAS(numFeatures).daxpy(numFeatures, 1.0, linearGradSumVec.values, 1, + gradientSumArray, 1) + } + + gradientSumArray(dim - 1) += sigmaGradSum + if (fitIntercept) gradientSumArray(dim - 2) += vec.values.sum + + this + } +} diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala index 7a5806dc24aee..fa3bda00d802d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml.optim.aggregator import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.feature.Instance -import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.feature.{Instance, InstanceBlock} +import org.apache.spark.ml.linalg._ /** * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, @@ -222,3 +222,92 @@ private[ml] class LeastSquaresAggregator( } } } + + +/** + * BlockLeastSquaresAggregator computes the gradient and loss for LeastSquares loss function + * as used in linear regression for blocks in sparse or dense matrix in an online fashion. + * + * Two BlockLeastSquaresAggregators can be merged together to have a summary of loss and gradient + * of the corresponding joint dataset. + * + * NOTE: The feature values are expected to be standardized before computation. + * + * @param bcCoefficients The coefficients corresponding to the features. + * @param fitIntercept Whether to fit an intercept term. + */ +private[ml] class BlockLeastSquaresAggregator( + labelStd: Double, + labelMean: Double, + fitIntercept: Boolean, + bcFeaturesStd: Broadcast[Array[Double]], + bcFeaturesMean: Broadcast[Array[Double]])(bcCoefficients: Broadcast[Vector]) + extends DifferentiableLossAggregator[InstanceBlock, BlockLeastSquaresAggregator] { + require(labelStd > 0.0, s"${this.getClass.getName} requires the label standard " + + s"deviation to be positive.") + + private val numFeatures = bcFeaturesStd.value.length + protected override val dim: Int = numFeatures + // make transient so we do not serialize between aggregation stages + @transient private lazy val effectiveCoefAndOffset = { + val coefficientsArray = bcCoefficients.value.toArray.clone() + val featuresMean = bcFeaturesMean.value + val featuresStd = bcFeaturesStd.value + var sum = 0.0 + var i = 0 + val len = coefficientsArray.length + while (i < len) { + if (featuresStd(i) != 0.0) { + sum += coefficientsArray(i) / featuresStd(i) * featuresMean(i) + } else { + coefficientsArray(i) = 0.0 + } + i += 1 + } + val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 + (Vectors.dense(coefficientsArray), offset) + } + // do not use tuple assignment above because it will circumvent the @transient tag + @transient private lazy val effectiveCoefficientsVec = effectiveCoefAndOffset._1 + @transient private lazy val offset = effectiveCoefAndOffset._2 + + /** + * Add a new training instance block to this BlockLeastSquaresAggregator, and update the loss + * and gradient of the objective function. + * + * @param block The instance block of data point to be added. + * @return This BlockLeastSquaresAggregator object. + */ + def add(block: InstanceBlock): BlockLeastSquaresAggregator = { + require(block.matrix.isTransposed) + require(numFeatures == block.numFeatures, s"Dimensions mismatch when adding new " + + s"instance. 
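The Block*Aggregators all follow the same pattern: stack a group of standardized instances into one row-major matrix so that a single matrix-vector product (BLAS GEMV) yields every margin in the block at once. A tiny, self-contained illustration of the idea (not part of the patch, and in plain Scala rather than the private BLAS helpers):

// 3 stacked instances with 2 features each, one coefficient vector, one intercept.
val block = Array(
  Array(1.0, 2.0),
  Array(3.0, 4.0),
  Array(5.0, 6.0))
val coef = Array(0.5, -1.0)
val intercept = 0.5

// margins = block * coef + intercept, i.e. what one gemv call computes per block,
// replacing one dot product per instance.
val margins = block.map(row => row.zip(coef).map { case (x, w) => x * w }.sum + intercept)
println(margins.mkString(", "))   // -1.0, -2.0, -3.0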
Expecting $numFeatures but got ${block.numFeatures}.") + require(block.weightIter.forall(_ >= 0), + s"instance weights ${block.weightIter.mkString("[", ",", "]")} has to be >= 0.0") + + if (block.weightIter.forall(_ == 0)) return this + val size = block.size + + // vec here represents diffs + val vec = new DenseVector(Array.tabulate(size)(i => offset - block.getLabel(i) / labelStd)) + BLAS.gemv(1.0, block.matrix, effectiveCoefficientsVec, 1.0, vec) + + // in-place convert diffs to multipliers + // then, vec represents multipliers + var i = 0 + while (i < size) { + val weight = block.getWeight(i) + val diff = vec(i) + lossSum += weight * diff * diff / 2 + weightSum += weight + val multiplier = weight * diff + vec.values(i) = multiplier + i += 1 + } + + val gradSumVec = new DenseVector(gradientSumArray) + BLAS.gemv(1.0, block.matrix.transpose, vec, 1.0, gradSumVec) + + this + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 8cc5f864de1e0..2c30e44b93467 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -27,8 +27,9 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.PredictorParams -import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} -import org.apache.spark.ml.optim.aggregator.AFTAggregator +import org.apache.spark.ml.feature.StandardScalerModel +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.optim.aggregator._ import org.apache.spark.ml.optim.loss.RDDLossFunction import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -46,7 +47,8 @@ import org.apache.spark.storage.StorageLevel * Params for accelerated failure time (AFT) regression. */ private[regression] trait AFTSurvivalRegressionParams extends PredictorParams - with HasMaxIter with HasTol with HasFitIntercept with HasAggregationDepth with Logging { + with HasMaxIter with HasTol with HasFitIntercept with HasAggregationDepth with HasBlockSize + with Logging { /** * Param for censor column name. @@ -183,6 +185,25 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) setDefault(aggregationDepth -> 2) + /** + * Set block size for stacking input data in matrices. + * If blockSize == 1, then stacking will be skipped, and each vector is treated individually; + * If blockSize > 1, then vectors will be stacked to blocks, and high-level BLAS routines + * will be used if possible (for example, GEMV instead of DOT, GEMM instead of GEMV). + * Recommended size is between 10 and 1000. An appropriate choice of the block size depends + * on the sparsity and dim of input datasets, the underlying BLAS implementation (for example, + * f2jBLAS, OpenBLAS, intel MKL) and its configuration (for example, number of threads). + * Note that existing BLAS implementations are mainly optimized for dense matrices, if the + * input dataset is sparse, stacking may bring no performance gain, the worse is possible + * performance regression. + * Default is 1. 
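A usage sketch for the new expert parameter (not part of the patch), assuming a spark-shell style SparkSession named spark; the toy rows follow the usual AFT example layout of label, censor and features:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression

val training = spark.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")

val aft = new AFTSurvivalRegression()
  .setQuantileProbabilities(Array(0.3, 0.6))
  .setQuantilesCol("quantiles")
  .setBlockSize(64)   // stack up to 64 instances per block so gradients use GEMV

val model = aft.fit(training)
println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept}")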
+ * + * @group expertSetParam + */ + @Since("3.1.0") + def setBlockSize(value: Int): this.type = set(blockSize, value) + setDefault(blockSize -> 1) + /** * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, * and put it in an RDD with strong types. @@ -197,39 +218,50 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S override protected def train( dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, censorCol, predictionCol, quantilesCol, + fitIntercept, maxIter, tol, aggregationDepth, blockSize) + instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length) + val instances = extractAFTPoints(dataset) - val handlePersistence = dataset.storageLevel == StorageLevel.NONE - if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + .setName("training instances") - val featuresSummarizer = instances.treeAggregate( - Summarizer.createSummarizerBuffer("mean", "std", "count"))( + if ($(blockSize) == 1 && dataset.storageLevel == StorageLevel.NONE) { + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + + var requestedMetrics = Seq("mean", "std", "count") + if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros" + val summarizer = instances.treeAggregate( + Summarizer.createSummarizerBuffer(requestedMetrics: _*))( seqOp = (c: SummarizerBuffer, v: AFTPoint) => c.add(v.features), combOp = (c1: SummarizerBuffer, c2: SummarizerBuffer) => c1.merge(c2), depth = $(aggregationDepth) ) - val featuresStd = featuresSummarizer.std.toArray + val featuresStd = summarizer.std.toArray val numFeatures = featuresStd.length - - instr.logPipelineStage(this) - instr.logDataset(dataset) - instr.logParams(this, labelCol, featuresCol, censorCol, predictionCol, quantilesCol, - fitIntercept, maxIter, tol, aggregationDepth) - instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length) instr.logNumFeatures(numFeatures) - instr.logNumExamples(featuresSummarizer.count) + instr.logNumExamples(summarizer.count) + if ($(blockSize) > 1) { + val scale = 1.0 / summarizer.count / numFeatures + val sparsity = 1 - summarizer.numNonzeros.toArray.map(_ * scale).sum + instr.logNamedValue("sparsity", sparsity.toString) + if (sparsity > 0.5) { + instr.logWarning(s"sparsity of input dataset is $sparsity, " + + s"which may hurt performance in high-level BLAS.") + } + } if (!$(fitIntercept) && (0 until numFeatures).exists { i => - featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { + featuresStd(i) == 0.0 && summarizer.mean(i) != 0.0 }) { instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. 
This behavior is different from R survival::survreg.") } - val bcFeaturesStd = instances.context.broadcast(featuresStd) - val getAggregatorFunc = new AFTAggregator(bcFeaturesStd, $(fitIntercept))(_) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) - val costFun = new RDDLossFunction(instances, getAggregatorFunc, None, $(aggregationDepth)) /* The parameters vector has three parts: @@ -239,36 +271,86 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S */ val initialParameters = Vectors.zeros(numFeatures + 2) + val (rawCoefficients, objectiveHistory) = if ($(blockSize) == 1) { + trainOnRows(instances, featuresStd, optimizer, initialParameters) + } else { + trainOnBlocks(instances, featuresStd, optimizer, initialParameters) + } + if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist() + + if (rawCoefficients == null) { + val msg = s"${optimizer.getClass.getName} failed." + instr.logError(msg) + throw new SparkException(msg) + } + + val coefficientArray = Array.tabulate(numFeatures) { i => + if (featuresStd(i) != 0) rawCoefficients(i) / featuresStd(i) else 0.0 + } + val coefficients = Vectors.dense(coefficientArray) + val intercept = rawCoefficients(numFeatures) + val scale = math.exp(rawCoefficients(numFeatures + 1)) + new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + } + + private def trainOnRows( + instances: RDD[AFTPoint], + featuresStd: Array[Double], + optimizer: BreezeLBFGS[BDV[Double]], + initialParameters: Vector): (Array[Double], Array[Double]) = { + val bcFeaturesStd = instances.context.broadcast(featuresStd) + val getAggregatorFunc = new AFTAggregator(bcFeaturesStd, $(fitIntercept))(_) + val costFun = new RDDLossFunction(instances, getAggregatorFunc, None, $(aggregationDepth)) + val states = optimizer.iterations(new CachedDiffFunction(costFun), initialParameters.asBreeze.toDenseVector) - val parameters = { - val arrayBuilder = mutable.ArrayBuilder.make[Double] - var state: optimizer.State = null - while (states.hasNext) { - state = states.next() - arrayBuilder += state.adjustedValue - } - if (state == null) { - val msg = s"${optimizer.getClass.getName} failed." 
- throw new SparkException(msg) + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + bcFeaturesStd.destroy() + + (if (state != null) state.x.toArray else null, arrayBuilder.result) + } + + private def trainOnBlocks( + instances: RDD[AFTPoint], + featuresStd: Array[Double], + optimizer: BreezeLBFGS[BDV[Double]], + initialParameters: Vector): (Array[Double], Array[Double]) = { + val bcFeaturesStd = instances.context.broadcast(featuresStd) + val blocks = instances.mapPartitions { iter => + val inverseStd = bcFeaturesStd.value.map { std => if (std != 0) 1.0 / std else 0.0 } + val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true) + iter.grouped($(blockSize)).map { seq => + val matrix = Matrices.fromVectors(seq.map(point => func(point.features))) + val labels = seq.map(_.label).toArray + val censors = seq.map(_.censor).toArray + (matrix, labels, censors) } - state.x.toArray.clone() } + blocks.persist(StorageLevel.MEMORY_AND_DISK) + .setName(s"training blocks (blockSize=${$(blockSize)})") - bcFeaturesStd.destroy() - if (handlePersistence) instances.unpersist() + val getAggregatorFunc = new BlockAFTAggregator($(fitIntercept))(_) + val costFun = new RDDLossFunction(blocks, getAggregatorFunc, None, $(aggregationDepth)) - val rawCoefficients = parameters.take(numFeatures) - var i = 0 - while (i < numFeatures) { - rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } - i += 1 + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialParameters.asBreeze.toDenseVector) + + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue } - val coefficients = Vectors.dense(rawCoefficients) - val intercept = parameters(numFeatures) - val scale = math.exp(parameters(numFeatures + 1)) - new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + blocks.unpersist() + bcFeaturesStd.destroy() + + (if (state != null) state.x.toArray else null, arrayBuilder.result) } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index fa41a98749f32..0ee895a95a288 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.attribute._ import org.apache.spark.ml.feature.{Instance, OffsetInstance} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ @@ -399,7 +400,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val "GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " + "set to false. To fit a model with 0 features, fitIntercept must be set to true." 
) - val w = if (!hasWeightCol) lit(1.0) else col($(weightCol)) + val w = if (!hasWeightCol) lit(1.0) else checkNonNegativeWeight(col($(weightCol))) val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index fe4de57de60f2..ec2640e9ef225 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -87,11 +88,11 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { col($(featuresCol)) } - val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0) + val w = + if (hasWeightCol) checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) else lit(1.0) dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map { - case Row(label: Double, feature: Double, weight: Double) => - (label, feature, weight) + case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 355055b7a9f73..bcf9b7c0426cd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.regression import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} -import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction, FirstOrderMinimizer, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT import org.apache.hadoop.fs.Path @@ -28,10 +28,11 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{PipelineStage, PredictorParams} +import org.apache.spark.ml.feature._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ import org.apache.spark.ml.optim.WeightedLeastSquares -import org.apache.spark.ml.optim.aggregator.{HuberAggregator, LeastSquaresAggregator} +import org.apache.spark.ml.optim.aggregator._ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ @@ -42,6 +43,7 @@ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} 
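With weight columns now routed through checkNonNegativeWeight, a negative weight fails fast at fit time rather than silently distorting the model. A sketch (not part of the patch) using GeneralizedLinearRegression, assuming a spark-shell style SparkSession named spark and made-up data:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression

val df = spark.createDataFrame(Seq(
  (1.0, 1.0, Vectors.dense(0.0, 5.0)),
  (0.5, 2.0, Vectors.dense(1.0, 2.0)),
  (1.5, 1.0, Vectors.dense(2.0, 1.0))
)).toDF("label", "weight", "features")

val glr = new GeneralizedLinearRegression()
  .setFamily("gaussian")
  .setLink("identity")
  .setWeightCol("weight")   // any value < 0 in this column now raises an error

val model = glr.fit(df)
println(model.coefficients)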
import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -54,7 +56,7 @@ import org.apache.spark.util.VersionUtils.majorMinorVersion private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver - with HasAggregationDepth with HasLoss { + with HasAggregationDepth with HasLoss with HasBlockSize { import LinearRegression._ @@ -315,49 +317,52 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String def setEpsilon(value: Double): this.type = set(epsilon, value) setDefault(epsilon -> 1.35) - override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr => - // Extract the number of features before deciding optimization solver. - val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol)) - - val instances = extractInstances(dataset) + /** + * Set block size for stacking input data in matrices. + * If blockSize == 1, then stacking will be skipped, and each vector is treated individually; + * If blockSize > 1, then vectors will be stacked to blocks, and high-level BLAS routines + * will be used if possible (for example, GEMV instead of DOT, GEMM instead of GEMV). + * Recommended size is between 10 and 1000. An appropriate choice of the block size depends + * on the sparsity and dim of input datasets, the underlying BLAS implementation (for example, + * f2jBLAS, OpenBLAS, intel MKL) and its configuration (for example, number of threads). + * Note that existing BLAS implementations are mainly optimized for dense matrices, if the + * input dataset is sparse, stacking may bring no performance gain, the worse is possible + * performance regression. + * Default is 1. + * + * @group expertSetParam + */ + @Since("3.1.0") + def setBlockSize(value: Int): this.type = set(blockSize, value) + setDefault(blockSize -> 1) + override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr => instr.logPipelineStage(this) instr.logDataset(dataset) instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, solver, tol, elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth, loss, - epsilon) + epsilon, blockSize) + + // Extract the number of features before deciding optimization solver. + val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol)) instr.logNumFeatures(numFeatures) if ($(loss) == SquaredError && (($(solver) == Auto && numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal)) { - // For low dimensional data, WeightedLeastSquares is more efficient since the - // training algorithm only requires one pass through the data. (SPARK-10668) - - val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), - elasticNetParam = $(elasticNetParam), $(standardization), true, - solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) - val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) - // When it is trained by WeightedLeastSquares, training summary does not - // attach returned model. 
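A usage sketch for blockSize on LinearRegression (not part of the patch), assuming a spark-shell style SparkSession named spark and made-up data. The solver is forced to l-bfgs so the iterative, aggregator-based path, where block stacking applies, is actually taken:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression

val training = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

val lr = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.1)
  .setSolver("l-bfgs")   // skip the normal-equation shortcut used for small dense data
  .setBlockSize(2)       // stack 2 instances per matrix block

val model = lr.fit(training)
println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept}")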
- val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) - val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol() - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - $(featuresCol), - summaryModel, - model.diagInvAtWA.toArray, - model.objectiveHistory) - - return lrModel.setSummary(Some(trainingSummary)) + return trainWithNormal(dataset, instr) } - val handlePersistence = dataset.storageLevel == StorageLevel.NONE - if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + val instances = extractInstances(dataset) + .setName("training instances") - val (featuresSummarizer, ySummarizer) = - Summarizer.getRegressionSummarizers(instances, $(aggregationDepth)) + if (dataset.storageLevel == StorageLevel.NONE && $(blockSize) == 1) { + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + + var requestedMetrics = Seq("mean", "std", "count") + if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros" + val (featuresSummarizer, ySummarizer) = Summarizer + .getRegressionSummarizers(instances, $(aggregationDepth), requestedMetrics) val yMean = ySummarizer.mean(0) val rawYStd = ySummarizer.std(0) @@ -366,40 +371,20 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean) instr.logNamedValue(Instrumentation.loggerTags.varianceOfLabels, rawYStd) instr.logSumOfWeights(featuresSummarizer.weightSum) + if ($(blockSize) > 1) { + val scale = 1.0 / featuresSummarizer.count / numFeatures + val sparsity = 1 - featuresSummarizer.numNonzeros.toArray.map(_ * scale).sum + instr.logNamedValue("sparsity", sparsity.toString) + if (sparsity > 0.5) { + instr.logWarning(s"sparsity of input dataset is $sparsity, " + + s"which may hurt performance in high-level BLAS.") + } + } if (rawYStd == 0.0) { if ($(fitIntercept) || yMean == 0.0) { - // If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with - // zero coefficient; as a result, training is not needed. - // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of - // the fitIntercept. - if (yMean == 0.0) { - instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + - s"coefficients and the intercept will all be zero; as a result, training is not " + - s"needed.") - } else { - instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + - s"will be zeros and the intercept will be the mean of the label; as a result, " + - s"training is not needed.") - } - if (handlePersistence) instances.unpersist() - val coefficients = Vectors.sparse(numFeatures, Seq.empty) - val intercept = yMean - - val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - $(featuresCol), - model, - Array(0D), - Array(0D)) - - return model.setSummary(Some(trainingSummary)) + if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist() + return trainWithConstantLabel(dataset, instr, numFeatures, yMean) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. 
" + "Model cannot be regularized.") @@ -413,8 +398,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean) val featuresMean = featuresSummarizer.mean.toArray val featuresStd = featuresSummarizer.std.toArray - val bcFeaturesMean = instances.context.broadcast(featuresMean) - val bcFeaturesStd = instances.context.broadcast(featuresStd) if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { @@ -437,21 +420,105 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures Some(new L2Regularization(effectiveL2RegParam, shouldApply, if ($(standardization)) None else Some(getFeaturesStd))) - } else { - None - } + } else None - val costFun = $(loss) match { + val optimizer = createOptimizer(effectiveRegParam, effectiveL1RegParam, + numFeatures, featuresStd) + + val initialValues = $(loss) match { case SquaredError => - val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept), - bcFeaturesStd, bcFeaturesMean)(_) - new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth)) + Vectors.zeros(numFeatures) case Huber => - val getAggregatorFunc = new HuberAggregator($(fitIntercept), $(epsilon), bcFeaturesStd)(_) - new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth)) + val dim = if ($(fitIntercept)) numFeatures + 2 else numFeatures + 1 + Vectors.dense(Array.fill(dim)(1.0)) + } + + val (parameters, objectiveHistory) = if ($(blockSize) == 1) { + trainOnRows(instances, yMean, yStd, featuresMean, featuresStd, + initialValues, regularization, optimizer) + } else { + trainOnBlocks(instances, yMean, yStd, featuresMean, featuresStd, + initialValues, regularization, optimizer) + } + if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist() + + if (parameters == null) { + val msg = s"${optimizer.getClass.getName} failed." + instr.logError(msg) + throw new SparkException(msg) + } + + val model = createModel(parameters, yMean, yStd, featuresMean, featuresStd) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), + model, Array(0.0), objectiveHistory) + model.setSummary(Some(trainingSummary)) + } + + private def trainWithNormal( + dataset: Dataset[_], + instr: Instrumentation): LinearRegressionModel = { + // For low dimensional data, WeightedLeastSquares is more efficient since the + // training algorithm only requires one pass through the data. (SPARK-10668) + + val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), + elasticNetParam = $(elasticNetParam), $(standardization), true, + solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) + val instances = extractInstances(dataset) + .setName("training instances") + val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) + // When it is trained by WeightedLeastSquares, training summary does not + // attach returned model. 
+ val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) + val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol() + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), + summaryModel, model.diagInvAtWA.toArray, model.objectiveHistory) + + lrModel.setSummary(Some(trainingSummary)) + } + + private def trainWithConstantLabel( + dataset: Dataset[_], + instr: Instrumentation, + numFeatures: Int, + yMean: Double): LinearRegressionModel = { + // If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with + // zero coefficient; as a result, training is not needed. + // Also, if rawYStd==0 and yMean==0, all the coefficients are zero regardless of + // the fitIntercept. + if (yMean == 0.0) { + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") + } else { + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") } + val coefficients = Vectors.sparse(numFeatures, Seq.empty) + val intercept = yMean - val optimizer = $(loss) match { + val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), + model, Array(0.0), Array(0.0)) + + model.setSummary(Some(trainingSummary)) + } + + private def createOptimizer( + effectiveRegParam: Double, + effectiveL1RegParam: Double, + numFeatures: Int, + featuresStd: Array[Double]) = { + $(loss) match { case SquaredError => if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) @@ -479,105 +546,162 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val upperBounds = BDV[Double](Array.fill(dim)(Double.MaxValue)) new BreezeLBFGSB(lowerBounds, upperBounds, $(maxIter), 10, $(tol)) } + } - val initialValues = $(loss) match { + private def trainOnRows( + instances: RDD[Instance], + yMean: Double, + yStd: Double, + featuresMean: Array[Double], + featuresStd: Array[Double], + initialValues: Vector, + regularization: Option[L2Regularization], + optimizer: FirstOrderMinimizer[BDV[Double], DiffFunction[BDV[Double]]]) = { + val bcFeaturesMean = instances.context.broadcast(featuresMean) + val bcFeaturesStd = instances.context.broadcast(featuresStd) + + val costFun = $(loss) match { case SquaredError => - Vectors.zeros(numFeatures) + val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept), + bcFeaturesStd, bcFeaturesMean)(_) + new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth)) case Huber => - val dim = if ($(fitIntercept)) numFeatures + 2 else numFeatures + 1 - Vectors.dense(Array.fill(dim)(1.0)) + val getAggregatorFunc = new HuberAggregator($(fitIntercept), $(epsilon), bcFeaturesStd)(_) + new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth)) } val states = optimizer.iterations(new CachedDiffFunction(costFun), initialValues.asBreeze.toDenseVector) - val 
(coefficients, intercept, scale, objectiveHistory) = { - /* - Note that in Linear Regression, the objective history (loss + regularization) returned - from optimizer is computed in the scaled space given by the following formula. -
- $$ - L &= 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2 - + regTerms \\ - $$ -
- */ - val arrayBuilder = mutable.ArrayBuilder.make[Double] - var state: optimizer.State = null - while (states.hasNext) { - state = states.next() - arrayBuilder += state.adjustedValue - } - if (state == null) { - val msg = s"${optimizer.getClass.getName} failed." - instr.logError(msg) - throw new SparkException(msg) - } + /* + Note that in Linear Regression, the objective history (loss + regularization) returned + from optimizer is computed in the scaled space given by the following formula. +
+ $$ + L &= 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2 + + regTerms \\ + $$ +
+ */ + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } - bcFeaturesMean.destroy() - bcFeaturesStd.destroy() + bcFeaturesMean.destroy() + bcFeaturesStd.destroy() - val parameters = state.x.toArray.clone() + (if (state == null) null else state.x.toArray, arrayBuilder.result) + } - /* - The coefficients are trained in the scaled space; we're converting them back to - the original space. - */ - val rawCoefficients: Array[Double] = $(loss) match { - case SquaredError => parameters - case Huber => parameters.slice(0, numFeatures) - } + private def trainOnBlocks( + instances: RDD[Instance], + yMean: Double, + yStd: Double, + featuresMean: Array[Double], + featuresStd: Array[Double], + initialValues: Vector, + regularization: Option[L2Regularization], + optimizer: FirstOrderMinimizer[BDV[Double], DiffFunction[BDV[Double]]]) = { + val bcFeaturesMean = instances.context.broadcast(featuresMean) + val bcFeaturesStd = instances.context.broadcast(featuresStd) - var i = 0 - val len = rawCoefficients.length - val multiplier = $(loss) match { - case SquaredError => yStd - case Huber => 1.0 - } - while (i < len) { - rawCoefficients(i) *= { if (featuresStd(i) != 0.0) multiplier / featuresStd(i) else 0.0 } - i += 1 - } + val standardized = instances.mapPartitions { iter => + val inverseStd = bcFeaturesStd.value.map { std => if (std != 0) 1.0 / std else 0.0 } + val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true) + iter.map { case Instance(label, weight, vec) => Instance(label, weight, func(vec)) } + } + val blocks = InstanceBlock.blokify(standardized, $(blockSize)) + .persist(StorageLevel.MEMORY_AND_DISK) + .setName(s"training blocks (blockSize=${$(blockSize)})") - val interceptValue: Double = if ($(fitIntercept)) { - $(loss) match { - case SquaredError => - /* - The intercept of squared error in R's GLMNET is computed using closed form - after the coefficients are converged. See the following discussion for detail. - http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet - */ - yMean - dot(Vectors.dense(rawCoefficients), Vectors.dense(featuresMean)) - case Huber => parameters(numFeatures) - } - } else { - 0.0 - } + val costFun = $(loss) match { + case SquaredError => + val getAggregatorFunc = new BlockLeastSquaresAggregator(yStd, yMean, $(fitIntercept), + bcFeaturesStd, bcFeaturesMean)(_) + new RDDLossFunction(blocks, getAggregatorFunc, regularization, $(aggregationDepth)) + case Huber => + val getAggregatorFunc = new BlockHuberAggregator($(fitIntercept), $(epsilon))(_) + new RDDLossFunction(blocks, getAggregatorFunc, regularization, $(aggregationDepth)) + } - val scaleValue: Double = $(loss) match { - case SquaredError => 1.0 - case Huber => parameters.last - } + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialValues.asBreeze.toDenseVector) - (Vectors.dense(rawCoefficients).compressed, interceptValue, scaleValue, arrayBuilder.result()) + /* + Note that in Linear Regression, the objective history (loss + regularization) returned + from optimizer is computed in the scaled space given by the following formula. +
+ $$ + L &= 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2 + + regTerms \\ + $$ +
+ */ + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue } - if (handlePersistence) instances.unpersist() + blocks.unpersist() + bcFeaturesMean.destroy() + bcFeaturesStd.destroy() - val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept, scale)) - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + (if (state == null) null else state.x.toArray, arrayBuilder.result) + } - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - $(featuresCol), - model, - Array(0D), - objectiveHistory) + private def createModel( + parameters: Array[Double], + yMean: Double, + yStd: Double, + featuresMean: Array[Double], + featuresStd: Array[Double]): LinearRegressionModel = { + val numFeatures = featuresStd.length + /* + The coefficients are trained in the scaled space; we're converting them back to + the original space. + */ + val rawCoefficients = $(loss) match { + case SquaredError => parameters.clone() + case Huber => parameters.take(numFeatures) + } - model.setSummary(Some(trainingSummary)) + var i = 0 + val len = rawCoefficients.length + val multiplier = $(loss) match { + case SquaredError => yStd + case Huber => 1.0 + } + while (i < len) { + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) multiplier / featuresStd(i) else 0.0 } + i += 1 + } + + val intercept = if ($(fitIntercept)) { + $(loss) match { + case SquaredError => + /* + The intercept of squared error in R's GLMNET is computed using closed form + after the coefficients are converged. See the following discussion for detail. + http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + */ + yMean - dot(Vectors.dense(rawCoefficients), Vectors.dense(featuresMean)) + case Huber => parameters(numFeatures) + } + } else 0.0 + + val scale = $(loss) match { + case SquaredError => 1.0 + case Huber => parameters.last + } + + val coefficients = Vectors.dense(rawCoefficients).compressed + + copyValues(new LinearRegressionModel(uid, coefficients, intercept, scale)) } @Since("1.4.0") @@ -655,7 +779,7 @@ class LinearRegressionModel private[ml] ( // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, - $(labelCol), $(featuresCol), summaryModel, Array(0D)) + $(labelCol), $(featuresCol), summaryModel, Array(0.0)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 4230b495fa5d5..4db518bd4f9ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -208,9 +208,10 @@ object Summarizer extends Logging { /** Get regression feature and label summarizers for provided data. 
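The block-based trainers ask the summarizer for an extra statistic, numNonZeros, to estimate input sparsity before deciding whether stacking is worthwhile. The same statistics are reachable through the public Summarizer API; a sketch (not part of the patch), assuming a spark-shell style SparkSession named spark:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.functions.col

val df = spark.createDataFrame(Seq(
  (Vectors.dense(1.0, 0.0, 3.0), 1.0),
  (Vectors.dense(0.0, 0.0, 5.0), 2.0)
)).toDF("features", "weight")

val stats = df.select(
  Summarizer.metrics("mean", "std", "count", "numNonZeros")
    .summary(col("features"), col("weight")).as("stats"))

// numNonZeros per feature is what the trainers turn into a sparsity estimate.
stats.select("stats.mean", "stats.numNonZeros").show(false)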
*/ private[ml] def getRegressionSummarizers( instances: RDD[Instance], - aggregationDepth: Int = 2): (SummarizerBuffer, SummarizerBuffer) = { + aggregationDepth: Int = 2, + requested: Seq[String] = Seq("mean", "std", "count")) = { instances.treeAggregate( - (Summarizer.createSummarizerBuffer("mean", "std"), + (Summarizer.createSummarizerBuffer(requested: _*), Summarizer.createSummarizerBuffer("mean", "std", "count")))( seqOp = (c: (SummarizerBuffer, SummarizerBuffer), instance: Instance) => (c._1.add(instance.features, instance.weight), @@ -223,7 +224,7 @@ object Summarizer extends Logging { } /** Get classification feature and label summarizers for provided data. */ - private[ml] def getClassificationSummarizers( + private[spark] def getClassificationSummarizers( instances: RDD[Instance], aggregationDepth: Int = 2, requested: Seq[String] = Seq("mean", "std", "count")) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 050ebb0fa4fbd..1a91801a9da28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -283,7 +283,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[_ <: Product]) (loss * weight, weight) case other => - throw new IllegalArgumentException(s"Expected quadruples, got $other") + throw new IllegalArgumentException(s"Invalid RDD value for MulticlassMetrics.logLoss. " + + s"Expected quadruples, got $other") }.treeReduce { case ((l1, w1), (l2, w2)) => (l1 + l2, w1 + w2) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 933a63b40fcf8..25e9697d64855 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -288,6 +288,67 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { val mlorSummary = mlorModel.evaluate(smallMultinomialDataset) assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummary]) assert(mlorSummary.isInstanceOf[LogisticRegressionSummary]) + + // verify instance weight works + val lr2 = new LogisticRegression() + .setFamily("binomial") + .setMaxIter(1) + .setWeightCol("weight") + + val smallBinaryDatasetWithWeight = + smallBinaryDataset.select(col("label"), col("features"), lit(2.5).as("weight")) + + val smallMultinomialDatasetWithWeight = + smallMultinomialDataset.select(col("label"), col("features"), lit(10.0).as("weight")) + + val blorModel2 = lr2.fit(smallBinaryDatasetWithWeight) + assert(blorModel2.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + assert(blorModel2.summary.asBinary.isInstanceOf[BinaryLogisticRegressionSummary]) + assert(blorModel2.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + + val mlorModel2 = lr2.setFamily("multinomial").fit(smallMultinomialDatasetWithWeight) + assert(mlorModel2.summary.isInstanceOf[LogisticRegressionTrainingSummary]) + withClue("cannot get binary summary for multiclass model") { + intercept[RuntimeException] { + mlorModel.binarySummary + } + } + withClue("cannot cast summary to binary summary multiclass model") { + intercept[RuntimeException] { + mlorModel.summary.asBinary + } + } + + val mlorBinaryModel2 = 
lr2.setFamily("multinomial").fit(smallBinaryDatasetWithWeight) + assert(mlorBinaryModel2.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + assert(mlorBinaryModel2.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + + val blorSummary2 = blorModel2.evaluate(smallBinaryDatasetWithWeight) + val mlorSummary2 = mlorModel2.evaluate(smallMultinomialDatasetWithWeight) + assert(blorSummary2.isInstanceOf[BinaryLogisticRegressionSummary]) + assert(mlorSummary2.isInstanceOf[LogisticRegressionSummary]) + + assert(blorSummary.accuracy ~== blorSummary2.accuracy relTol 1e-6) + assert(blorSummary.weightedPrecision ~== blorSummary2.weightedPrecision relTol 1e-6) + assert(blorSummary.weightedRecall ~== blorSummary2.weightedRecall relTol 1e-6) + assert(blorSummary.asBinary.areaUnderROC ~== blorSummary2.asBinary.areaUnderROC relTol 1e-6) + + assert(blorModel.summary.asBinary.accuracy ~== + blorModel2.summary.asBinary.accuracy relTol 1e-6) + assert(blorModel.summary.asBinary.weightedPrecision ~== + blorModel2.summary.asBinary.weightedPrecision relTol 1e-6) + assert(blorModel.summary.asBinary.weightedRecall ~== + blorModel2.summary.asBinary.weightedRecall relTol 1e-6) + assert(blorModel.summary.asBinary.asBinary.areaUnderROC ~== + blorModel2.summary.asBinary.areaUnderROC relTol 1e-6) + + assert(mlorSummary.accuracy ~== mlorSummary2.accuracy relTol 1e-6) + assert(mlorSummary.weightedPrecision ~== mlorSummary2.weightedPrecision relTol 1e-6) + assert(mlorSummary.weightedRecall ~== mlorSummary2.weightedRecall relTol 1e-6) + + assert(mlorModel.summary.accuracy ~== mlorModel2.summary.accuracy relTol 1e-6) + assert(mlorModel.summary.weightedPrecision ~== mlorModel2.summary.weightedPrecision relTol 1e-6) + assert(mlorModel.summary.weightedRecall ~==mlorModel2.summary.weightedRecall relTol 1e-6) } test("setThreshold, getThreshold") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index b35f964c959bf..d848d5a5ee452 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -285,6 +285,17 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { testClusteringModelSingleProbabilisticPrediction(model, model.predictProbability, dataset, model.getFeaturesCol, model.getProbabilityCol) } + + test("GMM on blocks") { + Seq(dataset, sparseDataset, denseDataset, rDataset).foreach { dataset => + val gmm = new GaussianMixture().setK(k).setMaxIter(20).setBlockSize(1).setSeed(seed) + val model = gmm.fit(dataset) + Seq(2, 4, 8, 16, 32).foreach { blockSize => + val model2 = gmm.setBlockSize(blockSize).fit(dataset) + modelEquals(model, model2) + } + } + } } object GaussianMixtureSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index 83b213ab51d43..008bf0e108e13 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -102,4 +102,27 @@ class BinaryClassificationEvaluatorSuite val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction") MLTestingUtils.checkNumericTypes(evaluator, spark) } + + 
test("getMetrics") { + val weightCol = "weight" + // get metric with weight column + val evaluator = new BinaryClassificationEvaluator() + .setWeightCol(weightCol) + val vectorDF = Seq( + (0.0, Vectors.dense(2.5, 12), 1.0), + (1.0, Vectors.dense(1, 3), 1.0), + (0.0, Vectors.dense(10, 2), 1.0) + ).toDF("label", "rawPrediction", weightCol) + + val metrics = evaluator.getMetrics(vectorDF) + val roc = metrics.areaUnderROC() + val pr = metrics.areaUnderPR() + + // default = areaUnderROC + assert(evaluator.evaluate(vectorDF) == roc) + + // areaUnderPR + evaluator.setMetricName("areaUnderPR") + assert(evaluator.evaluate(vectorDF) == pr) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 6cf3b1deeac93..d4c620adc2e3c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.ml.evaluation import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.lit class ClusteringEvaluatorSuite @@ -145,4 +146,60 @@ class ClusteringEvaluatorSuite assert(evaluator.evaluate(twoSingleItemClusters) === 0.0) } + test("getMetrics") { + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + + val metrics1 = evaluator.getMetrics(irisDataset) + val silhouetteScoreEuclidean = metrics1.silhouette + + assert(evaluator.evaluate(irisDataset) == silhouetteScoreEuclidean) + + evaluator.setDistanceMeasure("cosine") + val metrics2 = evaluator.getMetrics(irisDataset) + val silhouetteScoreCosin = metrics2.silhouette + + assert(evaluator.evaluate(irisDataset) == silhouetteScoreCosin) + } + + test("test weight support") { + Seq("squaredEuclidean", "cosine").foreach { distanceMeasure => + val evaluator1 = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + .setDistanceMeasure(distanceMeasure) + + val evaluator2 = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + .setDistanceMeasure(distanceMeasure) + .setWeightCol("weight") + + Seq(0.25, 1.0, 10.0, 99.99).foreach { w => + var score1 = evaluator1.evaluate(irisDataset) + var score2 = evaluator2.evaluate(irisDataset.withColumn("weight", lit(w))) + assert(score1 ~== score2 relTol 1e-6) + + score1 = evaluator1.evaluate(newIrisDataset) + score2 = evaluator2.evaluate(newIrisDataset.withColumn("weight", lit(w))) + assert(score1 ~== score2 relTol 1e-6) + } + } + } + + test("single-element clusters with weight") { + val singleItemClusters = spark.createDataFrame(spark.sparkContext.parallelize(Array( + (0.0, Vectors.dense(5.1, 3.5, 1.4, 0.2), 6.0), + (1.0, Vectors.dense(7.0, 3.2, 4.7, 1.4), 0.25), + (2.0, Vectors.dense(6.3, 3.3, 6.0, 2.5), 9.99)))).toDF("label", "features", "weight") + Seq("squaredEuclidean", "cosine").foreach { distanceMeasure => + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + 
.setPredictionCol("label") + .setDistanceMeasure(distanceMeasure) + .setWeightCol("weight") + assert(evaluator.evaluate(singleItemClusters) === 0.0) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 5b5212abdf7cc..3dfd860a5b9d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -80,4 +80,33 @@ class MulticlassClassificationEvaluatorSuite .setMetricName("logLoss") assert(evaluator.evaluate(df) ~== 0.9682005730687164 absTol 1e-5) } + + test("getMetrics") { + val predictionAndLabels = Seq((0.0, 0.0), (0.0, 1.0), + (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), + (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)).toDF("prediction", "label") + + val evaluator = new MulticlassClassificationEvaluator() + + val metrics = evaluator.getMetrics(predictionAndLabels) + val f1 = metrics.weightedFMeasure + val accuracy = metrics.accuracy + val precisionByLabel = metrics.precision(evaluator.getMetricLabel) + + // default = f1 + assert(evaluator.evaluate(predictionAndLabels) == f1) + + // accuracy + evaluator.setMetricName("accuracy") + assert(evaluator.evaluate(predictionAndLabels) == accuracy) + + // precisionByLabel + evaluator.setMetricName("precisionByLabel") + assert(evaluator.evaluate(predictionAndLabels) == precisionByLabel) + + // truePositiveRateByLabel + evaluator.setMetricName("truePositiveRateByLabel").setMetricLabel(1.0) + assert(evaluator.evaluate(predictionAndLabels) == + metrics.truePositiveRate(evaluator.getMetricLabel)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluatorSuite.scala index f41fc04a5faed..520103d6aed92 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluatorSuite.scala @@ -59,4 +59,52 @@ class MultilabelClassificationEvaluatorSuite .setMetricName("precisionByLabel") testDefaultReadWrite(evaluator) } + + test("getMetrics") { + val scoreAndLabels = Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array.empty[Double], Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))).toDF("prediction", "label") + + val evaluator = new MultilabelClassificationEvaluator() + + val metrics = evaluator.getMetrics(scoreAndLabels) + val f1 = metrics.f1Measure + val accuracy = metrics.accuracy + val precision = metrics.precision + val recall = metrics.recall + val hammingLoss = metrics.hammingLoss + val precisionByLabel = metrics.precision(evaluator.getMetricLabel) + + // default = f1 + assert(evaluator.evaluate(scoreAndLabels) == f1) + + // accuracy + evaluator.setMetricName("accuracy") + assert(evaluator.evaluate(scoreAndLabels) == accuracy) + + // precision + evaluator.setMetricName("precision") + assert(evaluator.evaluate(scoreAndLabels) == precision) + + // recall + evaluator.setMetricName("recall") + assert(evaluator.evaluate(scoreAndLabels) == recall) + + // hammingLoss + evaluator.setMetricName("hammingLoss") + 
assert(evaluator.evaluate(scoreAndLabels) == hammingLoss) + + // precisionByLabel + evaluator.setMetricName("precisionByLabel") + assert(evaluator.evaluate(scoreAndLabels) == precisionByLabel) + + // truePositiveRateByLabel + evaluator.setMetricName("recallByLabel").setMetricLabel(1.0) + assert(evaluator.evaluate(scoreAndLabels) == + metrics.recall(evaluator.getMetricLabel)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala index 02d26d7eb351f..b3457981a08e9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala @@ -59,4 +59,42 @@ class RankingEvaluatorSuite .setK(2) assert(evaluator.evaluate(scoreAndLabels) ~== 1.0 / 3 absTol 1e-5) } + + test("getMetrics") { + val scoreAndLabels = Seq( + (Array(1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0), + Array(1.0, 2.0, 3.0, 4.0, 5.0)), + (Array(4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0), + Array(1.0, 2.0, 3.0)), + (Array(1.0, 2.0, 3.0, 4.0, 5.0), Array.empty[Double]) + ).toDF("prediction", "label") + + val evaluator = new RankingEvaluator().setK(5) + + val metrics = evaluator.getMetrics(scoreAndLabels) + val meanAveragePrecision = metrics.meanAveragePrecision + val meanAveragePrecisionAtK = metrics.meanAveragePrecisionAt(evaluator.getK) + val precisionAtK = metrics.precisionAt(evaluator.getK) + val ndcgAtK = metrics.ndcgAt(evaluator.getK) + val recallAtK = metrics.recallAt(evaluator.getK) + + // default = meanAveragePrecision + assert(evaluator.evaluate(scoreAndLabels) == meanAveragePrecision) + + // meanAveragePrecisionAtK + evaluator.setMetricName("meanAveragePrecisionAtK") + assert(evaluator.evaluate(scoreAndLabels) == meanAveragePrecisionAtK) + + // precisionAtK + evaluator.setMetricName("precisionAtK") + assert(evaluator.evaluate(scoreAndLabels) == precisionAtK) + + // ndcgAtK + evaluator.setMetricName("ndcgAtK") + assert(evaluator.evaluate(scoreAndLabels) == ndcgAtK) + + // recallAtK + evaluator.setMetricName("recallAtK") + assert(evaluator.evaluate(scoreAndLabels) == recallAtK) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index f4f858c3e92dc..5ee161ce8dd33 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -93,4 +93,37 @@ class RegressionEvaluatorSuite test("should support all NumericType labels and not support other types") { MLTestingUtils.checkNumericTypes(new RegressionEvaluator, spark) } + + test("getMetrics") { + val dataset = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1) + .map(_.asML).toDF() + + val trainer = new LinearRegression + val model = trainer.fit(dataset) + val predictions = model.transform(dataset) + + val evaluator = new RegressionEvaluator() + + val metrics = evaluator.getMetrics(predictions) + val rmse = metrics.rootMeanSquaredError + val r2 = metrics.r2 + val mae = metrics.meanAbsoluteError + val variance = metrics.explainedVariance + + // default = rmse + assert(evaluator.evaluate(predictions) == rmse) + + // r2 score + evaluator.setMetricName("r2") + assert(evaluator.evaluate(predictions) == r2) + + // mae + 
evaluator.setMetricName("mae") + assert(evaluator.evaluate(predictions) == mae) + + // var + evaluator.setMetricName("var") + assert(evaluator.evaluate(predictions) == variance) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 722302e5a165f..8fd192fa56500 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -100,6 +100,10 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest { val metadata = spark.read.json(s"$hashingTFPath/metadata") val sparkVersionStr = metadata.select("sparkVersion").first().getString(0) assert(sparkVersionStr == "2.4.4") + + intercept[IllegalArgumentException] { + loadedHashingTF.save(hashingTFPath) + } } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6f6ab26cbac43..682b87a0f68d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -512,4 +512,22 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { assert(observedNumBuckets === numBuckets, "Observed number of buckets does not equal expected number of buckets.") } + + test("[SPARK-31676] QuantileDiscretizer raise error parameter splits given invalid value") { + import scala.util.Random + val rng = new Random(3) + + val a1 = Array.tabulate(200)(_ => rng.nextDouble * 2.0 - 1.0) ++ + Array.fill(20)(0.0) ++ Array.fill(20)(-0.0) + + val df1 = sc.parallelize(a1, 2).toDF("id") + + val qd = new QuantileDiscretizer() + .setInputCol("id") + .setOutputCol("out") + .setNumBuckets(200) + .setRelativeError(0.0) + + qd.fit(df1) // assert no exception raised here. 
+ } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index a4d388fd321db..4957f6f1f46a7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -261,4 +261,15 @@ class VectorAssemblerSuite val output = vectorAssembler.transform(dfWithNullsAndNaNs) assert(output.select("a").limit(1).collect().head == Row(Vectors.sparse(0, Seq.empty))) } + + test("SPARK-31671: should give explicit error message when can not infer column lengths") { + val df = Seq( + (Vectors.dense(1.0), Vectors.dense(2.0)) + ).toDF("n1", "n2") + val hintedDf = new VectorSizeHint().setInputCol("n1").setSize(1).transform(df) + val assembler = new VectorAssembler() + .setInputCols(Array("n1", "n2")).setOutputCol("features") + assert(!intercept[RuntimeException](assembler.setHandleInvalid("keep").transform(hintedDf)) + .getMessage.contains("n1"), "should only show no vector size columns' name") + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala index e2b417882403e..425a5eb26ab67 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala @@ -63,22 +63,7 @@ class HingeAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { new HingeAggregator(bcFeaturesStd, fitIntercept)(bcCoefficients) } - private def standardize(instances: Array[Instance]): Array[Instance] = { - val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) - val stdArray = featuresSummarizer.std.toArray - val numFeatures = stdArray.length - instances.map { case Instance(label, weight, features) => - val standardized = Array.ofDim[Double](numFeatures) - features.foreachNonZero { (i, v) => - val std = stdArray(i) - if (std != 0) standardized(i) = v / std - } - Instance(label, weight, Vectors.dense(standardized).compressed) - } - } - - /** Get summary statistics for some data and create a new BlockHingeAggregator. */ + /** Get summary statistics for some data and create a new BlockHingeAggregator. 
*/ private def getNewBlockAggregator( coefficients: Vector, fitIntercept: Boolean): BlockHingeAggregator = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala index f5de41695a47e..d64c4227faf85 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.optim.aggregator import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, InstanceBlock} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TestingUtils._ @@ -28,6 +28,7 @@ class HuberAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var instances: Array[Instance] = _ @transient var instancesConstantFeature: Array[Instance] = _ @transient var instancesConstantFeatureFiltered: Array[Instance] = _ + @transient var standardizedInstances: Array[Instance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -46,6 +47,7 @@ class HuberAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { Instance(1.0, 0.5, Vectors.dense(1.0)), Instance(2.0, 0.3, Vectors.dense(0.5)) ) + standardizedInstances = standardize(instances) } /** Get summary statistics for some data and create a new HuberAggregator. */ @@ -61,6 +63,15 @@ class HuberAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { new HuberAggregator(fitIntercept, epsilon, bcFeaturesStd)(bcParameters) } + /** Get summary statistics for some data and create a new BlockHingeAggregator. 
*/ + private def getNewBlockAggregator( + parameters: Vector, + fitIntercept: Boolean, + epsilon: Double): BlockHuberAggregator = { + val bcParameters = spark.sparkContext.broadcast(parameters) + new BlockHuberAggregator(fitIntercept, epsilon)(bcParameters) + } + test("aggregator add method should check input size") { val parameters = Vectors.dense(1.0, 2.0, 3.0, 4.0) val agg = getNewAggregator(instances, parameters, fitIntercept = true, epsilon = 1.35) @@ -147,6 +158,23 @@ class HuberAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(loss ~== agg.loss relTol 0.01) assert(gradient ~== agg.gradient relTol 0.01) + + Seq(1, 2, 4).foreach { blockSize => + val blocks1 = standardizedInstances + .grouped(blockSize) + .map(seq => InstanceBlock.fromInstances(seq)) + .toArray + val blocks2 = blocks1.map { block => + new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) + } + + Seq(blocks1, blocks2).foreach { blocks => + val blockAgg = getNewBlockAggregator(parameters, fitIntercept = true, epsilon) + blocks.foreach(blockAgg.add) + assert(agg.loss ~== blockAgg.loss relTol 1e-9) + assert(agg.gradient ~== blockAgg.gradient relTol 1e-9) + } + } } test("check with zero standard deviation") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala index 03ed323c9a387..ebce90b59c89c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.optim.aggregator import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, InstanceBlock} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TestingUtils._ @@ -28,6 +28,7 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte @transient var instances: Array[Instance] = _ @transient var instancesConstantFeature: Array[Instance] = _ @transient var instancesConstantLabel: Array[Instance] = _ + @transient var standardizedInstances: Array[Instance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -46,6 +47,7 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), Instance(1.0, 0.3, Vectors.dense(4.0, 0.5)) ) + standardizedInstances = standardize(instances) } /** Get summary statistics for some data and create a new LeastSquaresAggregator. */ @@ -66,6 +68,24 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte bcFeaturesMean)(bcCoefficients) } + /** Get summary statistics for some data and create a new BlockHingeAggregator. 
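The block-aggregator checks added in these suites all follow one recipe: standardize the instances, group them into InstanceBlocks of a given size (plus a sparse-row-major variant), feed the blocks to the corresponding Block*Aggregator, and require its loss and gradient to match the row-wise aggregator. A condensed sketch of that recipe, where agg and blockAgg stand in for the row-wise and block aggregators built by the surrounding helpers and blockSize is any positive value:

    val blocks = standardizedInstances
      .grouped(blockSize)
      .map(seq => InstanceBlock.fromInstances(seq))   // dense blocks; toSparseRowMajor covers the sparse path
      .toArray
    blocks.foreach(blockAgg.add)                      // the block aggregator consumes whole blocks
    assert(agg.loss ~== blockAgg.loss relTol 1e-9)
    assert(agg.gradient ~== blockAgg.gradient relTol 1e-9)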
*/ + private def getNewBlockAggregator( + instances: Array[Instance], + coefficients: Vector, + fitIntercept: Boolean): BlockLeastSquaresAggregator = { + val (featuresSummarizer, ySummarizer) = + Summarizer.getRegressionSummarizers(sc.parallelize(instances)) + val yStd = ySummarizer.std(0) + val yMean = ySummarizer.mean(0) + val featuresStd = featuresSummarizer.std.toArray + val bcFeaturesStd = spark.sparkContext.broadcast(featuresStd) + val featuresMean = featuresSummarizer.mean + val bcFeaturesMean = spark.sparkContext.broadcast(featuresMean.toArray) + val bcCoefficients = spark.sparkContext.broadcast(coefficients) + new BlockLeastSquaresAggregator(yStd, yMean, fitIntercept, bcFeaturesStd, + bcFeaturesMean)(bcCoefficients) + } + test("aggregator add method input size") { val coefficients = Vectors.dense(1.0, 2.0) val agg = getNewAggregator(instances, coefficients, fitIntercept = true) @@ -142,6 +162,23 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte BLAS.scal(1.0 / weightSum, expectedGradient) assert(agg.loss ~== (expectedLoss.sum / weightSum) relTol 1e-5) assert(agg.gradient ~== expectedGradient relTol 1e-5) + + Seq(1, 2, 4).foreach { blockSize => + val blocks1 = standardizedInstances + .grouped(blockSize) + .map(seq => InstanceBlock.fromInstances(seq)) + .toArray + val blocks2 = blocks1.map { block => + new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) + } + + Seq(blocks1, blocks2).foreach { blocks => + val blockAgg = getNewBlockAggregator(instances, coefficients, fitIntercept = true) + blocks.foreach(blockAgg.add) + assert(agg.loss ~== blockAgg.loss relTol 1e-9) + assert(agg.gradient ~== blockAgg.gradient relTol 1e-9) + } + } } test("check with zero standard deviation") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala index a8fdcc38d13bb..e3e39c691b8a3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala @@ -79,21 +79,6 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { new BlockLogisticAggregator(numFeatures, numClasses, fitIntercept, multinomial)(bcCoefficients) } - private def standardize(instances: Array[Instance]): Array[Instance] = { - val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) - val stdArray = featuresSummarizer.std.toArray - val numFeatures = stdArray.length - instances.map { case Instance(label, weight, features) => - val standardized = Array.ofDim[Double](numFeatures) - features.foreachNonZero { (i, v) => - val std = stdArray(i) - if (std != 0) standardized(i) = v / std - } - Instance(label, weight, Vectors.dense(standardized).compressed) - } - } - test("aggregator add method input size") { val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0) val interceptArray = Array(4.0, 2.0, -3.0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 6cc73e040e82c..a66143ab12e49 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -428,6 +428,22 @@ class AFTSurvivalRegressionSuite 
extends MLTest with DefaultReadWriteTest { val trainer = new AFTSurvivalRegression() trainer.fit(dataset) } + + test("AFTSurvivalRegression on blocks") { + val quantileProbabilities = Array(0.1, 0.5, 0.9) + for (dataset <- Seq(datasetUnivariate, datasetUnivariateScaled, datasetMultivariate)) { + val aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + val model = aft.fit(dataset) + Seq(4, 16, 64).foreach { blockSize => + val model2 = aft.setBlockSize(blockSize).fit(dataset) + assert(model.coefficients ~== model2.coefficients relTol 1e-9) + assert(model.intercept ~== model2.intercept relTol 1e-9) + assert(model.scale ~== model2.scale relTol 1e-9) + } + } + } } object AFTSurvivalRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 82d984933d815..df9a66b49fe48 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -660,6 +660,26 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe testPredictionModelSinglePrediction(model, datasetWithDenseFeature) } + test("LinearRegression on blocks") { + for (dataset <- Seq(datasetWithDenseFeature, datasetWithStrongNoise, + datasetWithDenseFeatureWithoutIntercept, datasetWithSparseFeature, datasetWithWeight, + datasetWithWeightConstantLabel, datasetWithWeightZeroLabel, datasetWithOutlier); + fitIntercept <- Seq(true, false); + loss <- Seq("squaredError", "huber")) { + val lir = new LinearRegression() + .setFitIntercept(fitIntercept) + .setLoss(loss) + .setMaxIter(3) + val model = lir.fit(dataset) + Seq(4, 16, 64).foreach { blockSize => + val model2 = lir.setBlockSize(blockSize).fit(dataset) + assert(model.intercept ~== model2.intercept relTol 1e-9) + assert(model.coefficients ~== model2.coefficients relTol 1e-9) + assert(model.scale ~== model2.scale relTol 1e-9) + } + } + } + test("linear regression model with constant label") { /* R code: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index f9a3cd088314e..5eb128abacdb9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -22,6 +22,8 @@ import java.io.File import org.scalatest.Suite import org.apache.spark.SparkContext +import org.apache.spark.ml.feature._ +import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TempDirectory import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits} import org.apache.spark.util.Utils @@ -66,4 +68,13 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => protected object testImplicits extends SQLImplicits { protected override def _sqlContext: SQLContext = self.spark.sqlContext } + + private[spark] def standardize(instances: Array[Instance]): Array[Instance] = { + val (featuresSummarizer, _) = + Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + val inverseStd = featuresSummarizer.std.toArray + .map { std => if (std != 0) 1.0 / std else 0.0 } + val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true) + instances.map { case Instance(label, weight, vec) => Instance(label, weight, 
func(vec)) } + } } diff --git a/pom.xml b/pom.xml index 7d4c3c49bb2ca..b3f7b7db1a79a 100644 --- a/pom.xml +++ b/pom.xml @@ -167,7 +167,7 @@ true 1.9.13 2.10.0 - 1.1.7.3 + 1.1.7.5 1.1.2 1.10 2.4 @@ -179,7 +179,7 @@ 2.6.2 4.1.17 14.0.1 - 3.0.16 + 3.1.2 2.30 2.10.5 3.5.2 @@ -204,6 +204,9 @@ org.fusesource.leveldbjni ${java.home} + + + org.apache.spark.tags.ChromeUITest @@ -243,6 +246,7 @@ things breaking. --> ${session.executionRootDirectory} + 1g @@ -661,7 +665,7 @@ com.github.luben zstd-jni - 1.4.4-3 + 1.4.5-2 com.clearspring.analytics @@ -1357,6 +1361,10 @@ com.zaxxer HikariCP-java7 + + com.microsoft.sqlserver + mssql-jdbc + @@ -2512,10 +2520,11 @@ false false true + ${spark.test.webdriver.chrome.driver} __not_used__ - ${test.exclude.tags} + ${test.exclude.tags},${test.default.exclude.tags} ${test.include.tags} @@ -3037,6 +3046,7 @@ 3.2.0 2.13.0 + 2.5 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index e2228b0b45075..57fbb125dc470 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,11 @@ object MimaExcludes { // after: class ChiSqSelector extends PSelector // false positive, no binary incompatibility ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelector") + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelector"), + + //[SPARK-31840] Add instance weight support in LogisticRegressionSummary + // weightCol in org.apache.spark.ml.classification.LogisticRegressionSummary is present only in current version + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightCol") ) // Exclude rules for 3.0.x diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 65937ad3cefe3..eb12f2f1f6ab7 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -623,7 +623,6 @@ object KubernetesIntegrationTests { object DependencyOverrides { lazy val settings = Seq( dependencyOverrides += "com.google.guava" % "guava" % "14.0.1", - dependencyOverrides += "commons-io" % "commons-io" % "2.4", dependencyOverrides += "xerces" % "xercesImpl" % "2.12.0", dependencyOverrides += "jline" % "jline" % "2.14.6", dependencyOverrides += "org.apache.avro" % "avro" % "1.8.2") @@ -967,6 +966,9 @@ object TestSettings { "2.12" } */ + + private val defaultExcludedTags = Seq("org.apache.spark.tags.ChromeUITest") + lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, @@ -1004,6 +1006,10 @@ object TestSettings { sys.props.get("test.exclude.tags").map { tags => tags.split(",").flatMap { tag => Seq("-l", tag) }.toSeq }.getOrElse(Nil): _*), + testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, + sys.props.get("test.default.exclude.tags").map(tags => tags.split(",").toSeq) + .map(tags => tags.filter(!_.trim.isEmpty)).getOrElse(defaultExcludedTags) + .flatMap(tag => Seq("-l", tag)): _*), testOptions in Test += Tests.Argument(TestFrameworks.JUnit, sys.props.get("test.exclude.tags").map { tags => Seq("--exclude-categories=" + tags) diff --git a/python/docs/index.rst b/python/docs/index.rst index 0e7b62361802a..6e059264e6bbb 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -16,6 +16,7 @@ Contents: pyspark.streaming pyspark.ml pyspark.mllib + pyspark.resource Core classes: diff --git a/python/docs/pyspark.resource.rst b/python/docs/pyspark.resource.rst new file mode 
100644 index 0000000000000..7f3a79b9e5b52 --- /dev/null +++ b/python/docs/pyspark.resource.rst @@ -0,0 +1,11 @@ +pyspark.resource module +======================= + +Module Contents +--------------- + +.. automodule:: pyspark.resource + :members: + :undoc-members: + :inherited-members: + diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index 0df12c49ad033..402d6ce9eb016 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -11,6 +11,7 @@ Subpackages pyspark.streaming pyspark.ml pyspark.mllib + pyspark.resource Contents -------- diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index b69562e845920..406ada701941a 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -1,8 +1,8 @@ pyspark.sql module ================== -Module Context --------------- +Module Contents +--------------- .. automodule:: pyspark.sql :members: diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 70c0b27a6aa33..ee153af18c88c 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -54,7 +54,6 @@ from pyspark.storagelevel import StorageLevel from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast -from pyspark.resourceinformation import ResourceInformation from pyspark.serializers import MarshalSerializer, PickleSerializer from pyspark.status import * from pyspark.taskcontext import TaskContext, BarrierTaskContext, BarrierTaskInfo @@ -119,5 +118,5 @@ def wrapper(self, *args, **kwargs): "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext", - "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", "ResourceInformation", + "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", ] diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6cc343e3e495c..96353bb9228d5 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -25,6 +25,7 @@ from tempfile import NamedTemporaryFile from py4j.protocol import Py4JError +from py4j.java_gateway import is_instance_of from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -35,7 +36,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream from pyspark.storagelevel import StorageLevel -from pyspark.resourceinformation import ResourceInformation +from pyspark.resource.information import ResourceInformation from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.traceback_utils import CallSite, first_spark_call from pyspark.status import StatusTracker @@ -864,8 +865,21 @@ def union(self, rdds): first_jrdd_deserializer = rdds[0]._jrdd_deserializer if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): rdds = [x._reserialize() for x in rdds] - cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD - jrdds = SparkContext._gateway.new_array(cls, len(rdds)) + gw = SparkContext._gateway + jvm = SparkContext._jvm + jrdd_cls = jvm.org.apache.spark.api.java.JavaRDD + jpair_rdd_cls = jvm.org.apache.spark.api.java.JavaPairRDD + jdouble_rdd_cls = jvm.org.apache.spark.api.java.JavaDoubleRDD + if is_instance_of(gw, rdds[0]._jrdd, jrdd_cls): + cls = jrdd_cls + elif is_instance_of(gw, rdds[0]._jrdd, jpair_rdd_cls): + cls = jpair_rdd_cls + elif 
is_instance_of(gw, rdds[0]._jrdd, jdouble_rdd_cls): + cls = jdouble_rdd_cls + else: + cls_name = rdds[0]._jrdd.getClass().getCanonicalName() + raise TypeError("Unsupported Java RDD class %s" % cls_name) + jrdds = gw.new_array(cls, len(rdds)) for i in range(0, len(rdds)): jrdds[i] = rdds[i]._jrdd return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d635be1d8db80..734c393db2a26 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -197,7 +197,7 @@ class _JavaClassificationModel(ClassificationModel, JavaPredictionModel): """ Java Model produced by a ``Classifier``. Classes are indexed {0, 1, ..., numClasses - 1}. - To be mixed in with class:`pyspark.ml.JavaModel` + To be mixed in with :class:`pyspark.ml.JavaModel` """ @property @@ -662,6 +662,8 @@ class LogisticRegression(_JavaProbabilisticClassifier, _LogisticRegressionParams DenseVector([-1.080..., -0.646...]) >>> blorModel.intercept 3.112... + >>> blorModel.evaluate(bdf).accuracy == blorModel.summary.accuracy + True >>> data_path = "data/mllib/sample_multiclass_classification_data.txt" >>> mdf = spark.read.format("libsvm").load(data_path) >>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial") @@ -932,7 +934,10 @@ def evaluate(self, dataset): if not isinstance(dataset, DataFrame): raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) java_blr_summary = self._call_java("evaluate", dataset) - return BinaryLogisticRegressionSummary(java_blr_summary) + if self.numClasses <= 2: + return BinaryLogisticRegressionSummary(java_blr_summary) + else: + return LogisticRegressionSummary(java_blr_summary) class LogisticRegressionSummary(JavaWrapper): @@ -985,6 +990,15 @@ def featuresCol(self): """ return self._call_java("featuresCol") + @property + @since("3.1.0") + def weightCol(self): + """ + Field in "predictions" which gives the weight of each instance + as a vector. + """ + return self._call_java("weightCol") + @property @since("2.3.0") def labels(self): diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 984ca411167d3..54a184bc081ee 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -98,7 +98,8 @@ def numIter(self): @inherit_doc class _GaussianMixtureParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, - HasProbabilityCol, HasTol, HasAggregationDepth, HasWeightCol): + HasProbabilityCol, HasTol, HasAggregationDepth, HasWeightCol, + HasBlockSize): """ Params for :py:class:`GaussianMixture` and :py:class:`GaussianMixtureModel`. 
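The clustering.py hunks around here expose the new blockSize param on the Python GaussianMixture (default 1, plus a setBlockSize setter), mirroring the Scala estimator whose behaviour is pinned by the "GMM on blocks" test earlier in this diff. A minimal Scala sketch of opting into block-based fitting; trainingDF with a "features" vector column is assumed, and the comments describe intent rather than guaranteed speedups:

    import org.apache.spark.ml.clustering.GaussianMixture

    val gmm = new GaussianMixture()
      .setK(3)
      .setMaxIter(20)
      .setBlockSize(64)   // 1 (the default) keeps per-row updates; larger values stack rows into blocks
    val model = gmm.fit(trainingDF)   // per the new test, results should match a blockSize = 1 fit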
@@ -243,6 +244,8 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav >>> gm.getMaxIter() 30 >>> model = gm.fit(df) + >>> model.getBlockSize() + 1 >>> model.getAggregationDepth() 2 >>> model.getFeaturesCol() @@ -327,16 +330,16 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, probabilityCol="probability", tol=0.01, maxIter=100, seed=None, - aggregationDepth=2, weightCol=None): + aggregationDepth=2, weightCol=None, blockSize=1): """ __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \ - aggregationDepth=2, weightCol=None) + aggregationDepth=2, weightCol=None, blockSize=1) """ super(GaussianMixture, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.GaussianMixture", self.uid) - self._setDefault(k=2, tol=0.01, maxIter=100, aggregationDepth=2) + self._setDefault(k=2, tol=0.01, maxIter=100, aggregationDepth=2, blockSize=1) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -347,11 +350,11 @@ def _create_model(self, java_model): @since("2.0.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, probabilityCol="probability", tol=0.01, maxIter=100, seed=None, - aggregationDepth=2, weightCol=None): + aggregationDepth=2, weightCol=None, blockSize=1): """ setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \ - aggregationDepth=2, weightCol=None) + aggregationDepth=2, weightCol=None, blockSize=1) Sets params for GaussianMixture. """ @@ -421,6 +424,13 @@ def setAggregationDepth(self, value): """ return self._set(aggregationDepth=value) + @since("3.1.0") + def setBlockSize(self, value): + """ + Sets the value of :py:attr:`blockSize`. + """ + return self._set(blockSize=value) + class GaussianMixtureSummary(ClusteringSummary): """ @@ -792,7 +802,7 @@ def computeCost(self, dataset): Computes the sum of squared distances between the input points and their corresponding cluster centers. - ..note:: Deprecated in 3.0.0. It will be removed in future versions. Use + .. note:: Deprecated in 3.0.0. It will be removed in future versions. Use ClusteringEvaluator instead. You can also get the cost on the training dataset in the summary. """ diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 265f02c1a03ac..a69a57f588571 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -654,7 +654,7 @@ def setParams(self, predictionCol="prediction", labelCol="label", @inherit_doc -class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, +class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWeightCol, JavaMLReadable, JavaMLWritable): """ Evaluator for Clustering results, which expects two input @@ -677,6 +677,18 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, ClusteringEvaluator... >>> evaluator.evaluate(dataset) 0.9079... + >>> featureAndPredictionsWithWeight = map(lambda x: (Vectors.dense(x[0]), x[1], x[2]), + ... [([0.0, 0.5], 0.0, 2.5), ([0.5, 0.0], 0.0, 2.5), ([10.0, 11.0], 1.0, 2.5), + ... ([10.5, 11.5], 1.0, 2.5), ([1.0, 1.0], 0.0, 2.5), ([8.0, 6.0], 1.0, 2.5)]) + >>> dataset = spark.createDataFrame( + ... 
featureAndPredictionsWithWeight, ["features", "prediction", "weight"]) + >>> evaluator = ClusteringEvaluator() + >>> evaluator.setPredictionCol("prediction") + ClusteringEvaluator... + >>> evaluator.setWeightCol("weight") + ClusteringEvaluator... + >>> evaluator.evaluate(dataset) + 0.9079... >>> ce_path = temp_path + "/ce" >>> evaluator.save(ce_path) >>> evaluator2 = ClusteringEvaluator.load(ce_path) @@ -694,10 +706,10 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, @keyword_only def __init__(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette", distanceMeasure="squaredEuclidean"): + metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None): """ __init__(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette", distanceMeasure="squaredEuclidean") + metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None) """ super(ClusteringEvaluator, self).__init__() self._java_obj = self._new_java_obj( @@ -709,10 +721,10 @@ def __init__(self, predictionCol="prediction", featuresCol="features", @keyword_only @since("2.3.0") def setParams(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette", distanceMeasure="squaredEuclidean"): + metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None): """ setParams(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette", distanceMeasure="squaredEuclidean") + metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None) Sets params for clustering evaluator. """ kwargs = self._input_kwargs @@ -758,6 +770,13 @@ def setPredictionCol(self, value): """ return self._set(predictionCol=value) + @since("3.1.0") + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + @inherit_doc class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 7acf8ce595840..498629cea846c 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,7 +27,8 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm from pyspark.ml.common import inherit_doc -__all__ = ['Binarizer', +__all__ = ['ANOVASelector', 'ANOVASelectorModel', + 'Binarizer', 'BucketedRandomProjectionLSH', 'BucketedRandomProjectionLSHModel', 'Bucketizer', 'ChiSqSelector', 'ChiSqSelectorModel', @@ -35,6 +36,7 @@ 'DCT', 'ElementwiseProduct', 'FeatureHasher', + 'FValueSelector', 'FValueSelectorModel', 'HashingTF', 'IDF', 'IDFModel', 'Imputer', 'ImputerModel', @@ -5034,15 +5036,15 @@ def __str__(self): return "RFormulaModel(%s) (uid=%s)" % (resolvedFormula, self.uid) -class _ChiSqSelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol): +class _SelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol): """ - Params for :py:class:`ChiSqSelector` and :py:class:`ChiSqSelectorModel`. + Params for :py:class:`Selector` and :py:class:`SelectorModel`. - .. versionadded:: 3.0.0 + .. versionadded:: 3.1.0 """ selectorType = Param(Params._dummy(), "selectorType", - "The selector type of the ChisqSelector. " + + "The selector type. 
" + "Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.", typeConverter=TypeConverters.toString) @@ -5108,8 +5110,210 @@ def getFwe(self): return self.getOrDefault(self.fwe) +class _Selector(JavaEstimator, _SelectorParams, JavaMLReadable, JavaMLWritable): + """ + Mixin for Selectors. + """ + + @since("2.1.0") + def setSelectorType(self, value): + """ + Sets the value of :py:attr:`selectorType`. + """ + return self._set(selectorType=value) + + @since("2.0.0") + def setNumTopFeatures(self, value): + """ + Sets the value of :py:attr:`numTopFeatures`. + Only applicable when selectorType = "numTopFeatures". + """ + return self._set(numTopFeatures=value) + + @since("2.1.0") + def setPercentile(self, value): + """ + Sets the value of :py:attr:`percentile`. + Only applicable when selectorType = "percentile". + """ + return self._set(percentile=value) + + @since("2.1.0") + def setFpr(self, value): + """ + Sets the value of :py:attr:`fpr`. + Only applicable when selectorType = "fpr". + """ + return self._set(fpr=value) + + @since("2.2.0") + def setFdr(self, value): + """ + Sets the value of :py:attr:`fdr`. + Only applicable when selectorType = "fdr". + """ + return self._set(fdr=value) + + @since("2.2.0") + def setFwe(self, value): + """ + Sets the value of :py:attr:`fwe`. + Only applicable when selectorType = "fwe". + """ + return self._set(fwe=value) + + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + +class _SelectorModel(JavaModel, _SelectorParams): + """ + Mixin for Selector models. + """ + + @since("3.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + @property + @since("2.0.0") + def selectedFeatures(self): + """ + List of indices to select (filter). + """ + return self._call_java("selectedFeatures") + + @inherit_doc -class ChiSqSelector(JavaEstimator, _ChiSqSelectorParams, JavaMLReadable, JavaMLWritable): +class ANOVASelector(_Selector, JavaMLReadable, JavaMLWritable): + """ + ANOVA F-value Classification selector, which selects continuous features to use for predicting + a categorical label. + The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`, + `fdr`, `fwe`. + + * `numTopFeatures` chooses a fixed number of top features according to a F value + classification test. + + * `percentile` is similar but chooses a fraction of all features + instead of a fixed number. + + * `fpr` chooses all features whose p-values are below a threshold, + thus controlling the false positive rate of selection. + + * `fdr` uses the `Benjamini-Hochberg procedure `_ + to choose all features whose false discovery rate is below a threshold. + + * `fwe` chooses all features whose p-values are below a threshold. The threshold is scaled by + 1/numFeatures, thus controlling the family-wise error rate of selection. + + By default, the selection method is `numTopFeatures`, with the default number of top features + set to 50. + + + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame( + ... 
[(Vectors.dense([1.7, 4.4, 7.6, 5.8, 9.6, 2.3]), 3.0), + ... (Vectors.dense([8.8, 7.3, 5.7, 7.3, 2.2, 4.1]), 2.0), + ... (Vectors.dense([1.2, 9.5, 2.5, 3.1, 8.7, 2.5]), 1.0), + ... (Vectors.dense([3.7, 9.2, 6.1, 4.1, 7.5, 3.8]), 2.0), + ... (Vectors.dense([8.9, 5.2, 7.8, 8.3, 5.2, 3.0]), 4.0), + ... (Vectors.dense([7.9, 8.5, 9.2, 4.0, 9.4, 2.1]), 4.0)], + ... ["features", "label"]) + >>> selector = ANOVASelector(numTopFeatures=1, outputCol="selectedFeatures") + >>> model = selector.fit(df) + >>> model.getFeaturesCol() + 'features' + >>> model.setFeaturesCol("features") + ANOVASelectorModel... + >>> model.transform(df).head().selectedFeatures + DenseVector([7.6]) + >>> model.selectedFeatures + [2] + >>> anovaSelectorPath = temp_path + "/anova-selector" + >>> selector.save(anovaSelectorPath) + >>> loadedSelector = ANOVASelector.load(anovaSelectorPath) + >>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures() + True + >>> modelPath = temp_path + "/anova-selector-model" + >>> model.save(modelPath) + >>> loadedModel = ANOVASelectorModel.load(modelPath) + >>> loadedModel.selectedFeatures == model.selectedFeatures + True + + .. versionadded:: 3.1.0 + """ + + @keyword_only + def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, + fdr=0.05, fwe=0.05): + """ + __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ + labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \ + fdr=0.05, fwe=0.05) + """ + super(ANOVASelector, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ANOVASelector", self.uid) + self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, + fpr=0.05, fdr=0.05, fwe=0.05) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("3.1.0") + def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, + fdr=0.05, fwe=0.05): + """ + setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ + labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \ + fdr=0.05, fwe=0.05) + Sets params for this ANOVASelector. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return ANOVASelectorModel(java_model) + + +class ANOVASelectorModel(_SelectorModel, JavaMLReadable, JavaMLWritable): + """ + Model fitted by :py:class:`ANOVASelector`. + + .. versionadded:: 3.1.0 + """ + + +@inherit_doc +class ChiSqSelector(_Selector, JavaMLReadable, JavaMLWritable): """ Chi-Squared feature selection, which selects categorical features to use for predicting a categorical label. @@ -5195,103 +5399,119 @@ def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, kwargs = self._input_kwargs return self._set(**kwargs) - @since("2.1.0") - def setSelectorType(self, value): - """ - Sets the value of :py:attr:`selectorType`. - """ - return self._set(selectorType=value) + def _create_model(self, java_model): + return ChiSqSelectorModel(java_model) - @since("2.0.0") - def setNumTopFeatures(self, value): - """ - Sets the value of :py:attr:`numTopFeatures`. - Only applicable when selectorType = "numTopFeatures". - """ - return self._set(numTopFeatures=value) - @since("2.1.0") - def setPercentile(self, value): - """ - Sets the value of :py:attr:`percentile`. 
- Only applicable when selectorType = "percentile". - """ - return self._set(percentile=value) +class ChiSqSelectorModel(_SelectorModel, JavaMLReadable, JavaMLWritable): + """ + Model fitted by :py:class:`ChiSqSelector`. - @since("2.1.0") - def setFpr(self, value): - """ - Sets the value of :py:attr:`fpr`. - Only applicable when selectorType = "fpr". - """ - return self._set(fpr=value) + .. versionadded:: 2.0.0 + """ - @since("2.2.0") - def setFdr(self, value): - """ - Sets the value of :py:attr:`fdr`. - Only applicable when selectorType = "fdr". - """ - return self._set(fdr=value) - @since("2.2.0") - def setFwe(self, value): - """ - Sets the value of :py:attr:`fwe`. - Only applicable when selectorType = "fwe". - """ - return self._set(fwe=value) +@inherit_doc +class FValueSelector(_Selector, JavaMLReadable, JavaMLWritable): + """ + F Value Regression feature selector, which selects continuous features to use for predicting a + continuous label. + The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`, + `fdr`, `fwe`. - def setFeaturesCol(self, value): - """ - Sets the value of :py:attr:`featuresCol`. - """ - return self._set(featuresCol=value) + * `numTopFeatures` chooses a fixed number of top features according to a F value + regression test. - def setOutputCol(self, value): - """ - Sets the value of :py:attr:`outputCol`. - """ - return self._set(outputCol=value) + * `percentile` is similar but chooses a fraction of all features + instead of a fixed number. - def setLabelCol(self, value): - """ - Sets the value of :py:attr:`labelCol`. - """ - return self._set(labelCol=value) + * `fpr` chooses all features whose p-values are below a threshold, + thus controlling the false positive rate of selection. - def _create_model(self, java_model): - return ChiSqSelectorModel(java_model) + * `fdr` uses the `Benjamini-Hochberg procedure `_ + to choose all features whose false discovery rate is below a threshold. + * `fwe` chooses all features whose p-values are below a threshold. The threshold is scaled by + 1/numFeatures, thus controlling the family-wise error rate of selection. -class ChiSqSelectorModel(JavaModel, _ChiSqSelectorParams, JavaMLReadable, JavaMLWritable): - """ - Model fitted by :py:class:`ChiSqSelector`. + By default, the selection method is `numTopFeatures`, with the default number of top features + set to 50. - .. versionadded:: 2.0.0 + + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame( + ... [(Vectors.dense([6.0, 7.0, 0.0, 7.0, 6.0, 0.0]), 4.6), + ... (Vectors.dense([0.0, 9.0, 6.0, 0.0, 5.0, 9.0]), 6.6), + ... (Vectors.dense([0.0, 9.0, 3.0, 0.0, 5.0, 5.0]), 5.1), + ... (Vectors.dense([0.0, 9.0, 8.0, 5.0, 6.0, 4.0]), 7.6), + ... (Vectors.dense([8.0, 9.0, 6.0, 5.0, 4.0, 4.0]), 9.0), + ... (Vectors.dense([8.0, 9.0, 6.0, 4.0, 0.0, 0.0]), 9.0)], + ... ["features", "label"]) + >>> selector = FValueSelector(numTopFeatures=1, outputCol="selectedFeatures") + >>> model = selector.fit(df) + >>> model.getFeaturesCol() + 'features' + >>> model.setFeaturesCol("features") + FValueSelectorModel... 
+ >>> model.transform(df).head().selectedFeatures + DenseVector([0.0]) + >>> model.selectedFeatures + [2] + >>> fvalueSelectorPath = temp_path + "/fvalue-selector" + >>> selector.save(fvalueSelectorPath) + >>> loadedSelector = FValueSelector.load(fvalueSelectorPath) + >>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures() + True + >>> modelPath = temp_path + "/fvalue-selector-model" + >>> model.save(modelPath) + >>> loadedModel = FValueSelectorModel.load(modelPath) + >>> loadedModel.selectedFeatures == model.selectedFeatures + True + + .. versionadded:: 3.1.0 """ - @since("3.0.0") - def setFeaturesCol(self, value): + @keyword_only + def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, + fdr=0.05, fwe=0.05): """ - Sets the value of :py:attr:`featuresCol`. + __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ + labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \ + fdr=0.05, fwe=0.05) """ - return self._set(featuresCol=value) + super(FValueSelector, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.FValueSelector", self.uid) + self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, + fpr=0.05, fdr=0.05, fwe=0.05) + kwargs = self._input_kwargs + self.setParams(**kwargs) - @since("3.0.0") - def setOutputCol(self, value): + @keyword_only + @since("3.1.0") + def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, + fdr=0.05, fwe=0.05): """ - Sets the value of :py:attr:`outputCol`. + setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ + labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \ + fdr=0.05, fwe=0.05) + Sets params for this FValueSelector. """ - return self._set(outputCol=value) + kwargs = self._input_kwargs + return self._set(**kwargs) - @property - @since("2.0.0") - def selectedFeatures(self): - """ - List of indices to select (filter). - """ - return self._call_java("selectedFeatures") + def _create_model(self, java_model): + return FValueSelectorModel(java_model) + + +class FValueSelectorModel(_SelectorModel, JavaMLReadable, JavaMLWritable): + """ + Model fitted by :py:class:`FValueSelector`. + + .. versionadded:: 3.1.0 + """ @inherit_doc diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f227fe058dcbd..b58255ea12afc 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -77,7 +77,7 @@ class _JavaRegressor(Regressor, JavaPredictor): class _JavaRegressionModel(RegressionModel, JavaPredictionModel): """ Java Model produced by a ``_JavaRegressor``. - To be mixed in with class:`pyspark.ml.JavaModel` + To be mixed in with :class:`pyspark.ml.JavaModel` .. versionadded:: 3.0.0 """ @@ -87,7 +87,7 @@ class _JavaRegressionModel(RegressionModel, JavaPredictionModel): class _LinearRegressionParams(_PredictorParams, HasRegParam, HasElasticNetParam, HasMaxIter, HasTol, HasFitIntercept, HasStandardization, HasWeightCol, HasSolver, - HasAggregationDepth, HasLoss): + HasAggregationDepth, HasLoss, HasBlockSize): """ Params for :py:class:`LinearRegression` and :py:class:`LinearRegressionModel`. @@ -155,6 +155,8 @@ class LinearRegression(_JavaRegressor, _LinearRegressionParams, JavaMLWritable, LinearRegressionModel... 
>>> model.getMaxIter() 5 + >>> model.getBlockSize() + 1 >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> abs(model.predict(test0.head().features) - (-1.0)) < 0.001 True @@ -194,17 +196,18 @@ class LinearRegression(_JavaRegressor, _LinearRegressionParams, JavaMLWritable, def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, standardization=True, solver="auto", weightCol=None, aggregationDepth=2, - loss="squaredError", epsilon=1.35): + loss="squaredError", epsilon=1.35, blockSize=1): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \ - loss="squaredError", epsilon=1.35) + loss="squaredError", epsilon=1.35, blockSize=1) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid) - self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, loss="squaredError", epsilon=1.35) + self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, loss="squaredError", epsilon=1.35, + blockSize=1) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -213,12 +216,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, standardization=True, solver="auto", weightCol=None, aggregationDepth=2, - loss="squaredError", epsilon=1.35): + loss="squaredError", epsilon=1.35, blockSize=1): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \ - loss="squaredError", epsilon=1.35) + loss="squaredError", epsilon=1.35, blockSize=1) Sets params for linear regression. """ kwargs = self._input_kwargs @@ -294,6 +297,13 @@ def setLoss(self, value): """ return self._set(lossType=value) + @since("3.1.0") + def setBlockSize(self, value): + """ + Sets the value of :py:attr:`blockSize`. + """ + return self._set(blockSize=value) + class LinearRegressionModel(_JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary): @@ -1597,7 +1607,7 @@ def evaluateEachIteration(self, dataset, loss): class _AFTSurvivalRegressionParams(_PredictorParams, HasMaxIter, HasTol, HasFitIntercept, - HasAggregationDepth): + HasAggregationDepth, HasBlockSize): """ Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`. @@ -1664,6 +1674,8 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams, 10 >>> aftsr.clear(aftsr.maxIter) >>> model = aftsr.fit(df) + >>> model.getBlockSize() + 1 >>> model.setFeaturesCol("features") AFTSurvivalRegressionModel... 
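
A minimal sketch of the `blockSize` parameter wired through above, assuming an active `SparkSession` named `spark`; the value 1 used here is just the default that `_setDefault` installs:

```python
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import LinearRegression

df = spark.createDataFrame(
    [(1.0, Vectors.dense(1.0)),
     (0.0, Vectors.dense(0.0))],
    ["label", "features"])

# blockSize can be passed to the constructor or set via the new setter.
lr = LinearRegression(maxIter=5, regParam=0.0, blockSize=1)
lr.setBlockSize(1)

model = lr.fit(df)
print(model.getBlockSize())   # 1, matching the default above
```
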
>>> model.predict(Vectors.dense(6.3)) @@ -1700,19 +1712,19 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams, def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), - quantilesCol=None, aggregationDepth=2): + quantilesCol=None, aggregationDepth=2, blockSize=1): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None, aggregationDepth=2) + quantilesCol=None, aggregationDepth=2, blockSize=1) """ super(AFTSurvivalRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid) self._setDefault(censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], - maxIter=100, tol=1E-6) + maxIter=100, tol=1E-6, blockSize=1) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1721,12 +1733,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), - quantilesCol=None, aggregationDepth=2): + quantilesCol=None, aggregationDepth=2, blockSize=1): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None, aggregationDepth=2): + quantilesCol=None, aggregationDepth=2, blockSize=1): """ kwargs = self._input_kwargs return self._set(**kwargs) @@ -1783,6 +1795,13 @@ def setAggregationDepth(self, value): """ return self._set(aggregationDepth=value) + @since("3.1.0") + def setBlockSize(self, value): + """ + Sets the value of :py:attr:`blockSize`. + """ + return self._set(blockSize=value) + class AFTSurvivalRegressionModel(_JavaRegressionModel, _AFTSurvivalRegressionParams, JavaMLWritable, JavaMLReadable): diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 058146935ed91..70de8425613ec 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -38,7 +38,7 @@ class ChiSquareTest(object): """ @staticmethod @since("2.2.0") - def test(dataset, featuresCol, labelCol): + def test(dataset, featuresCol, labelCol, flatten=False): """ Perform a Pearson's independence test using dataset. @@ -49,14 +49,24 @@ def test(dataset, featuresCol, labelCol): Name of features column in dataset, of type `Vector` (`VectorUDT`). :param labelCol: Name of label column in dataset, of any numerical type. + :param flatten: if True, flattens the returned dataframe. :return: DataFrame containing the test result for every feature against the label. 
- This DataFrame will contain a single Row with the following fields: + If flatten is True, this DataFrame will contain one row per feature with the following + fields: + - `featureIndex: int` + - `pValue: float` + - `degreesOfFreedom: int` + - `statistic: float` + If flatten is False, this DataFrame will contain a single Row with the following fields: - `pValues: Vector` - - `degreesOfFreedom: Array[Int]` + - `degreesOfFreedom: Array[int]` - `statistics: Vector` Each of these fields has one value per feature. + .. versionchanged:: 3.1.0 + Added optional ``flatten`` argument. + >>> from pyspark.ml.linalg import Vectors >>> from pyspark.ml.stat import ChiSquareTest >>> dataset = [[0, Vectors.dense([0, 0, 1])], @@ -67,10 +77,14 @@ def test(dataset, featuresCol, labelCol): >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') >>> chiSqResult.select("degreesOfFreedom").collect()[0] Row(degreesOfFreedom=[3, 1, 0]) + >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label', True) + >>> row = chiSqResult.orderBy("featureIndex").collect() + >>> row[0].statistic + 4.0 """ sc = SparkContext._active_spark_context javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest - args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol)] + args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol, flatten)] return _java2py(sc, javaTestObj.test(*args)) @@ -419,7 +433,7 @@ class ANOVATest(object): """ @staticmethod @since("3.1.0") - def test(dataset, featuresCol, labelCol): + def test(dataset, featuresCol, labelCol, flatten=False): """ Perform an ANOVA test using dataset. @@ -429,11 +443,18 @@ def test(dataset, featuresCol, labelCol): Name of features column in dataset, of type `Vector` (`VectorUDT`). :param labelCol: Name of label column in dataset, of any numerical type. + :param flatten: if True, flattens the returned dataframe. :return: DataFrame containing the test result for every feature against the label. - This DataFrame will contain a single Row with the following fields: + If flatten is True, this DataFrame will contain one row per feature with the following + fields: + - `featureIndex: int` + - `pValue: float` + - `degreesOfFreedom: int` + - `fValue: float` + If flatten is False, this DataFrame will contain a single Row with the following fields: - `pValues: Vector` - - `degreesOfFreedom: Array[Long]` + - `degreesOfFreedom: Array[int]` - `fValues: Vector` Each of these fields has one value per feature. @@ -454,10 +475,14 @@ def test(dataset, featuresCol, labelCol): DenseVector([4.0264, 18.4713, 3.4659, 1.9042, 0.5532, 0.512]) >>> row[0].pValues DenseVector([0.3324, 0.1623, 0.3551, 0.456, 0.689, 0.7029]) + >>> anovaResult = ANOVATest.test(dataset, 'features', 'label', True) + >>> row = anovaResult.orderBy("featureIndex").collect() + >>> row[0].fValue + 4.026438671875297 """ sc = SparkContext._active_spark_context javaTestObj = _jvm().org.apache.spark.ml.stat.ANOVATest - args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol)] + args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol, flatten)] return _java2py(sc, javaTestObj.test(*args)) @@ -469,7 +494,7 @@ class FValueTest(object): """ @staticmethod @since("3.1.0") - def test(dataset, featuresCol, labelCol): + def test(dataset, featuresCol, labelCol, flatten=False): """ Perform a F Regression test using dataset. @@ -479,11 +504,18 @@ def test(dataset, featuresCol, labelCol): Name of features column in dataset, of type `Vector` (`VectorUDT`). 
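
A short sketch of the new `flatten` option, assuming an active `SparkSession` named `spark`; with `flatten=True` the result contains one row per feature instead of a single vector-valued row:

```python
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import ChiSquareTest

dataset = spark.createDataFrame(
    [[0, Vectors.dense([0, 0, 1])],
     [0, Vectors.dense([1, 0, 1])],
     [1, Vectors.dense([2, 1, 1])],
     [1, Vectors.dense([3, 1, 1])]],
    ["label", "features"])

# Default (flatten=False): a single row with vector-valued columns.
ChiSquareTest.test(dataset, "features", "label").show(truncate=False)

# flatten=True: one row per feature with scalar pValue/statistic columns.
(ChiSquareTest.test(dataset, "features", "label", flatten=True)
 .orderBy("featureIndex")
 .show())
```

The same keyword works for `ANOVATest.test` and `FValueTest.test`, which gain the identical `flatten` argument below.
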
:param labelCol: Name of label column in dataset, of any numerical type. + :param flatten: if True, flattens the returned dataframe. :return: DataFrame containing the test result for every feature against the label. - This DataFrame will contain a single Row with the following fields: + If flatten is True, this DataFrame will contain one row per feature with the following + fields: + - `featureIndex: int` + - `pValue: float` + - `degreesOfFreedom: int` + - `fValue: float` + If flatten is False, this DataFrame will contain a single Row with the following fields: - `pValues: Vector` - - `degreesOfFreedom: Array[Long]` + - `degreesOfFreedom: Array[int]` - `fValues: Vector` Each of these fields has one value per feature. @@ -504,10 +536,14 @@ def test(dataset, featuresCol, labelCol): DenseVector([3.741, 7.5807, 142.0684, 34.9849, 0.4112, 0.0539]) >>> row[0].pValues DenseVector([0.1928, 0.1105, 0.007, 0.0274, 0.5871, 0.838]) + >>> fValueResult = FValueTest.test(dataset, 'features', 'label', True) + >>> row = fValueResult.orderBy("featureIndex").collect() + >>> row[0].fValue + 3.7409548308350593 """ sc = SparkContext._active_spark_context javaTestObj = _jvm().org.apache.spark.ml.stat.FValueTest - args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol)] + args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol, flatten)] return _java2py(sc, javaTestObj.test(*args)) diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py index 61f9f18f65b21..1b2b1914cc036 100644 --- a/python/pyspark/ml/tests/test_param.py +++ b/python/pyspark/ml/tests/test_param.py @@ -366,7 +366,8 @@ def test_java_params(self): for name, cls in inspect.getmembers(module, inspect.isclass): if not name.endswith('Model') and not name.endswith('Params') \ and issubclass(cls, JavaParams) and not inspect.isabstract(cls) \ - and not re.match("_?Java", name) and name != '_LSH': + and not re.match("_?Java", name) and name != '_LSH' \ + and name != '_Selector': # NOTE: disable check_params_exist until there is parity with Scala API check_params(self, cls(), check_params_exist=False) diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index 1d19ebf9a34a0..b5054095d190b 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -21,7 +21,8 @@ if sys.version > '3': basestring = str -from pyspark.ml.classification import LogisticRegression +from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \ + LogisticRegressionSummary from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans from pyspark.ml.linalg import Vectors from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression @@ -149,6 +150,7 @@ def test_binary_logistic_regression_summary(self): # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, BinaryLogisticRegressionSummary)) self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) def test_multiclass_logistic_regression_summary(self): @@ -187,6 +189,8 @@ def test_multiclass_logistic_regression_summary(self): # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = 
model.evaluate(df) + self.assertTrue(isinstance(sameSummary, LogisticRegressionSummary)) + self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary)) self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) def test_gaussian_mixture_summary(self): diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 35ad5518e1c1f..aac2b38d3f57d 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -563,6 +563,7 @@ def loadParamsInstance(path, sc): class HasTrainingSummary(object): """ Base class for models that provides Training summary. + .. versionadded:: 3.0.0 """ diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py index 2077809a043f1..f57de83bae64d 100644 --- a/python/pyspark/mllib/tests/test_streaming_algorithms.py +++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py @@ -463,7 +463,7 @@ def condition(): return True return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - eventually(condition) + eventually(condition, timeout=180.0) if __name__ == "__main__": diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 1a0ce42dc4e4f..f0f9cda4672b1 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -372,7 +372,7 @@ def save(self, sc, path): * human-readable (JSON) model metadata to path/metadata/ * Parquet formatted data to path/data/ - The model may be loaded using py:meth:`Loader.load`. + The model may be loaded using :py:meth:`Loader.load`. :param sc: Spark context used to save model data. :param path: Path specifying the directory in which to save @@ -412,7 +412,7 @@ class Loader(object): def load(cls, sc, path): """ Load a model from the given path. The model should have been - saved using py:meth:`Saveable.save`. + saved using :py:meth:`Saveable.save`. :param sc: Spark context used for loading model files. :param path: Path specifying the directory to which the model diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d0ac000ba3208..db0c1971cd2fe 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -47,9 +47,8 @@ from pyspark.statcounter import StatCounter from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel -from pyspark.resource.executorrequests import ExecutorResourceRequests -from pyspark.resource.resourceprofile import ResourceProfile -from pyspark.resource.taskrequests import TaskResourceRequests +from pyspark.resource.requests import ExecutorResourceRequests, TaskResourceRequests +from pyspark.resource.profile import ResourceProfile from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy diff --git a/python/pyspark/resource/__init__.py b/python/pyspark/resource/__init__.py index 89070ec4adc7e..b5f4c4a6b1825 100644 --- a/python/pyspark/resource/__init__.py +++ b/python/pyspark/resource/__init__.py @@ -18,12 +18,13 @@ """ APIs to let users manipulate resource requirements. 
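
With the module reshuffle above, the request and profile classes move into `pyspark.resource.requests` and `pyspark.resource.profile`, and the package `__init__` re-exports them. A small sketch of the import paths user code and internal callers such as `rdd.py` are expected to use after this change:

```python
# Public entry points, re-exported by pyspark/resource/__init__.py.
from pyspark.resource import (
    ExecutorResourceRequest, ExecutorResourceRequests,
    TaskResourceRequest, TaskResourceRequests,
    ResourceProfile, ResourceProfileBuilder, ResourceInformation,
)

# Internal callers import from the concrete modules instead
# (this mirrors the updated imports in pyspark/rdd.py).
from pyspark.resource.requests import ExecutorResourceRequests, TaskResourceRequests
from pyspark.resource.profile import ResourceProfile
```
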
""" -from pyspark.resource.executorrequests import ExecutorResourceRequest, ExecutorResourceRequests -from pyspark.resource.taskrequests import TaskResourceRequest, TaskResourceRequests -from pyspark.resource.resourceprofilebuilder import ResourceProfileBuilder -from pyspark.resource.resourceprofile import ResourceProfile +from pyspark.resource.information import ResourceInformation +from pyspark.resource.requests import TaskResourceRequest, TaskResourceRequests, \ + ExecutorResourceRequest, ExecutorResourceRequests +from pyspark.resource.profile import ResourceProfile, ResourceProfileBuilder __all__ = [ "TaskResourceRequest", "TaskResourceRequests", "ExecutorResourceRequest", - "ExecutorResourceRequests", "ResourceProfile", "ResourceProfileBuilder", + "ExecutorResourceRequests", "ResourceProfile", "ResourceInformation", + "ResourceProfileBuilder", ] diff --git a/python/pyspark/resourceinformation.py b/python/pyspark/resource/information.py similarity index 89% rename from python/pyspark/resourceinformation.py rename to python/pyspark/resource/information.py index aaed21374b6ee..b0e41cced85b5 100644 --- a/python/pyspark/resourceinformation.py +++ b/python/pyspark/resource/information.py @@ -26,8 +26,10 @@ class ResourceInformation(object): One example is GPUs, where the addresses would be the indices of the GPUs - @param name the name of the resource - @param addresses an array of strings describing the addresses of the resource + :param name: the name of the resource + :param addresses: an array of strings describing the addresses of the resource + + .. versionadded:: 3.0.0 """ def __init__(self, name, addresses): diff --git a/python/pyspark/resource/resourceprofilebuilder.py b/python/pyspark/resource/profile.py similarity index 69% rename from python/pyspark/resource/resourceprofilebuilder.py rename to python/pyspark/resource/profile.py index 67654289d500f..3f6ae1ddd5e30 100644 --- a/python/pyspark/resource/resourceprofilebuilder.py +++ b/python/pyspark/resource/profile.py @@ -15,10 +15,61 @@ # limitations under the License. # -from pyspark.resource.executorrequests import ExecutorResourceRequest,\ - ExecutorResourceRequests -from pyspark.resource.resourceprofile import ResourceProfile -from pyspark.resource.taskrequests import TaskResourceRequest, TaskResourceRequests +from pyspark.resource.requests import TaskResourceRequest, TaskResourceRequests, \ + ExecutorResourceRequests, ExecutorResourceRequest + + +class ResourceProfile(object): + + """ + .. note:: Evolving + + Resource profile to associate with an RDD. A :class:`pyspark.resource.ResourceProfile` + allows the user to specify executor and task requirements for an RDD that will get + applied during a stage. This allows the user to change the resource requirements between + stages. This is meant to be immutable so user cannot change it after building. + + .. 
versionadded:: 3.1.0 + """ + + def __init__(self, _java_resource_profile=None, _exec_req={}, _task_req={}): + if _java_resource_profile is not None: + self._java_resource_profile = _java_resource_profile + else: + self._java_resource_profile = None + self._executor_resource_requests = _exec_req + self._task_resource_requests = _task_req + + @property + def id(self): + if self._java_resource_profile is not None: + return self._java_resource_profile.id() + else: + raise RuntimeError("SparkContext must be created to get the id, get the id " + "after adding the ResourceProfile to an RDD") + + @property + def taskResources(self): + if self._java_resource_profile is not None: + taskRes = self._java_resource_profile.taskResourcesJMap() + result = {} + for k, v in taskRes.items(): + result[k] = TaskResourceRequest(v.resourceName(), v.amount()) + return result + else: + return self._task_resource_requests + + @property + def executorResources(self): + if self._java_resource_profile is not None: + execRes = self._java_resource_profile.executorResourcesJMap() + result = {} + for k, v in execRes.items(): + result[k] = ExecutorResourceRequest(v.resourceName(), v.amount(), + v.discoveryScript(), v.vendor()) + return result + else: + return self._executor_resource_requests class ResourceProfileBuilder(object): diff --git a/python/pyspark/resource/executorrequests.py b/python/pyspark/resource/requests.py similarity index 70% rename from python/pyspark/resource/executorrequests.py rename to python/pyspark/resource/requests.py index 91a195c94b6e5..56ad6e8be9bcb 100644 --- a/python/pyspark/resource/executorrequests.py +++ b/python/pyspark/resource/requests.py @@ -15,7 +15,6 @@ # limitations under the License. # -from pyspark.resource.taskrequests import TaskResourceRequest from pyspark.util import _parse_memory @@ -167,3 +166,89 @@ def requests(self): return result else: return self._executor_resources + + +class TaskResourceRequest(object): + """ + .. note:: Evolving + + A task resource request. This is used in conjuntion with the + :class:`pyspark.resource.ResourceProfile` to programmatically specify the resources + needed for an RDD that will be applied at the stage level. The amount is specified + as a Double to allow for saying you want more than 1 task per resource. Valid values + are less than or equal to 0.5 or whole numbers. + Use :class:`pyspark.resource.TaskResourceRequests` class as a convenience API. + + :param resourceName: Name of the resource + :param amount: Amount requesting as a Double to support fractional resource requests. + Valid values are less than or equal to 0.5 or whole numbers. + + .. versionadded:: 3.1.0 + """ + def __init__(self, resourceName, amount): + self._name = resourceName + self._amount = float(amount) + + @property + def resourceName(self): + return self._name + + @property + def amount(self): + return self._amount + + +class TaskResourceRequests(object): + + """ + .. note:: Evolving + + A set of task resource requests. This is used in conjuntion with the + :class:`pyspark.resource.ResourceProfileBuilder` to programmatically specify the resources + needed for an RDD that will be applied at the stage level. + + .. 
versionadded:: 3.1.0 + """ + + _CPUS = "cpus" + + def __init__(self, _jvm=None, _requests=None): + from pyspark import SparkContext + _jvm = _jvm or SparkContext._jvm + if _jvm is not None: + self._java_task_resource_requests = \ + SparkContext._jvm.org.apache.spark.resource.TaskResourceRequests() + if _requests is not None: + for k, v in _requests.items(): + if k == self._CPUS: + self._java_task_resource_requests.cpus(int(v.amount)) + else: + self._java_task_resource_requests.resource(v.resourceName, v.amount) + else: + self._java_task_resource_requests = None + self._task_resources = {} + + def cpus(self, amount): + if self._java_task_resource_requests is not None: + self._java_task_resource_requests.cpus(amount) + else: + self._task_resources[self._CPUS] = TaskResourceRequest(self._CPUS, amount) + return self + + def resource(self, resourceName, amount): + if self._java_task_resource_requests is not None: + self._java_task_resource_requests.resource(resourceName, float(amount)) + else: + self._task_resources[resourceName] = TaskResourceRequest(resourceName, amount) + return self + + @property + def requests(self): + if self._java_task_resource_requests is not None: + result = {} + taskRes = self._java_task_resource_requests.requestsJMap() + for k, v in taskRes.items(): + result[k] = TaskResourceRequest(v.resourceName(), v.amount()) + return result + else: + return self._task_resources diff --git a/python/pyspark/resource/resourceprofile.py b/python/pyspark/resource/resourceprofile.py deleted file mode 100644 index 59e9ccb4b6ea0..0000000000000 --- a/python/pyspark/resource/resourceprofile.py +++ /dev/null @@ -1,72 +0,0 @@ - -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from pyspark.resource.taskrequests import TaskResourceRequest -from pyspark.resource.executorrequests import ExecutorResourceRequest - - -class ResourceProfile(object): - - """ - .. note:: Evolving - - Resource profile to associate with an RDD. A :class:`pyspark.resource.ResourceProfile` - allows the user to specify executor and task requirements for an RDD that will get - applied during a stage. This allows the user to change the resource requirements between - stages. This is meant to be immutable so user doesn't change it after building. - - .. 
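
A small sketch of the request builders defined above. It assumes no `SparkContext` has been started, so `SparkContext._jvm` is `None` and the pure-Python branch (plain dicts) is exercised; the `ExecutorResourceRequests` builder methods are not shown in this hunk, so that line is illustrative only:

```python
from pyspark.resource import ExecutorResourceRequests, TaskResourceRequests

# Without an active SparkContext the Python-side branch is taken and the
# requests are kept in self._task_resources rather than a JVM object.
task_reqs = TaskResourceRequests().cpus(2).resource("gpu", 1)
print(task_reqs.requests["cpus"].amount)        # 2.0 (amounts are stored as floats)
print(task_reqs.requests["gpu"].resourceName)   # 'gpu'

# ExecutorResourceRequests exposes analogous builders (cores, memory, ...);
# their bodies live outside this hunk.
exec_reqs = ExecutorResourceRequests().cores(4).memory("2g")

# A ResourceProfile is normally assembled from both via ResourceProfileBuilder,
# whose definition follows in profile.py.
```
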
versionadded:: 3.1.0 - """ - - def __init__(self, _java_resource_profile=None, _exec_req={}, _task_req={}): - if _java_resource_profile is not None: - self._java_resource_profile = _java_resource_profile - else: - self._java_resource_profile = None - self._executor_resource_requests = _exec_req - self._task_resource_requests = _task_req - - @property - def id(self): - if self._java_resource_profile is not None: - return self._java_resource_profile.id() - else: - raise RuntimeError("SparkContext must be created to get the id, get the id " - "after adding the ResourceProfile to an RDD") - - @property - def taskResources(self): - if self._java_resource_profile is not None: - taskRes = self._java_resource_profile.taskResourcesJMap() - result = {} - for k, v in taskRes.items(): - result[k] = TaskResourceRequest(v.resourceName(), v.amount()) - return result - else: - return self._task_resource_requests - - @property - def executorResources(self): - if self._java_resource_profile is not None: - execRes = self._java_resource_profile.executorResourcesJMap() - result = {} - for k, v in execRes.items(): - result[k] = ExecutorResourceRequest(v.resourceName(), v.amount(), - v.discoveryScript(), v.vendor()) - return result - else: - return self._executor_resource_requests diff --git a/python/pyspark/resource/taskrequests.py b/python/pyspark/resource/taskrequests.py deleted file mode 100644 index e8dca98d14b61..0000000000000 --- a/python/pyspark/resource/taskrequests.py +++ /dev/null @@ -1,102 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -class TaskResourceRequest(object): - """ - .. note:: Evolving - - A task resource request. This is used in conjuntion with the - :class:`pyspark.resource.ResourceProfile` to programmatically specify the resources - needed for an RDD that will be applied at the stage level. The amount is specified - as a Double to allow for saying you want more then 1 task per resource. Valid values - are less than or equal to 0.5 or whole numbers. - Use :class:`pyspark.resource.TaskResourceRequests` class as a convenience API. - - :param resourceName: Name of the resource - :param amount: Amount requesting as a Double to support fractional resource requests. - Valid values are less than or equal to 0.5 or whole numbers. - - .. versionadded:: 3.1.0 - """ - def __init__(self, resourceName, amount): - self._name = resourceName - self._amount = float(amount) - - @property - def resourceName(self): - return self._name - - @property - def amount(self): - return self._amount - - -class TaskResourceRequests(object): - - """ - .. note:: Evolving - - A set of task resource requests. 
This is used in conjuntion with the - :class:`pyspark.resource.ResourceProfileBuilder` to programmatically specify the resources - needed for an RDD that will be applied at the stage level. - - .. versionadded:: 3.1.0 - """ - - _CPUS = "cpus" - - def __init__(self, _jvm=None, _requests=None): - from pyspark import SparkContext - _jvm = _jvm or SparkContext._jvm - if _jvm is not None: - self._java_task_resource_requests = \ - SparkContext._jvm.org.apache.spark.resource.TaskResourceRequests() - if _requests is not None: - for k, v in _requests.items(): - if k == self._CPUS: - self._java_task_resource_requests.cpus(int(v.amount)) - else: - self._java_task_resource_requests.resource(v.resourceName, v.amount) - else: - self._java_task_resource_requests = None - self._task_resources = {} - - def cpus(self, amount): - if self._java_task_resource_requests is not None: - self._java_task_resource_requests.cpus(amount) - else: - self._task_resources[self._CPUS] = TaskResourceRequest(self._CPUS, amount) - return self - - def resource(self, resourceName, amount): - if self._java_task_resource_requests is not None: - self._java_task_resource_requests.resource(resourceName, float(amount)) - else: - self._task_resources[resourceName] = TaskResourceRequest(resourceName, amount) - return self - - @property - def requests(self): - if self._java_task_resource_requests is not None: - result = {} - taskRes = self._java_task_resource_requests.requestsJMap() - for k, v in taskRes.items(): - result[k] = TaskResourceRequest(v.resourceName(), v.amount()) - return result - else: - return self._task_resources diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 78b574685327c..03e3b9ca4bd05 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -276,6 +276,8 @@ def explain(self, extended=None, mode=None): """Prints the (logical and physical) plans to the console for debugging purpose. :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. + When this is a string without specifying the ``mode``, it works as the mode is + specified. :param mode: specifies the expected output format of plans. * ``simple``: Print only a physical plan. @@ -306,12 +308,17 @@ def explain(self, extended=None, mode=None): Output [2]: [age#0, name#1] ... + >>> df.explain("cost") + == Optimized Logical Plan == + ...Statistics... + ... + .. versionchanged:: 3.0.0 Added optional argument `mode` to specify the expected output format of plans. 
""" if extended is not None and mode is not None: - raise Exception("extended and mode can not be specified simultaneously") + raise Exception("extended and mode should not be set together.") # For the no argument case: df.explain() is_no_argument = extended is None and mode is None @@ -319,18 +326,22 @@ def explain(self, extended=None, mode=None): # For the cases below: # explain(True) # explain(extended=False) - is_extended_case = extended is not None and isinstance(extended, bool) + is_extended_case = isinstance(extended, bool) and mode is None - # For the mode specified: df.explain(mode="formatted") - is_mode_case = mode is not None and isinstance(mode, basestring) + # For the case when extended is mode: + # df.explain("formatted") + is_extended_as_mode = isinstance(extended, basestring) and mode is None - if not is_no_argument and not (is_extended_case or is_mode_case): - if extended is not None: - err_msg = "extended (optional) should be provided as bool" \ - ", got {0}".format(type(extended)) - else: # For mode case - err_msg = "mode (optional) should be provided as str, got {0}".format(type(mode)) - raise TypeError(err_msg) + # For the mode specified: + # df.explain(mode="formatted") + is_mode_case = extended is None and isinstance(mode, basestring) + + if not (is_no_argument or is_extended_case or is_extended_as_mode or is_mode_case): + argtypes = [ + str(type(arg)) for arg in [extended, mode] if arg is not None] + raise TypeError( + "extended (optional) and mode (optional) should be a string " + "and bool; however, got [%s]." % ", ".join(argtypes)) # Sets an explain mode depending on a given argument if is_no_argument: @@ -339,6 +350,8 @@ def explain(self, extended=None, mode=None): explain_mode = "extended" if extended else "simple" elif is_mode_case: explain_mode = mode + elif is_extended_as_mode: + explain_mode = extended print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode)) @@ -2138,7 +2151,7 @@ def drop(self, *cols): @ignore_unicode_prefix def toDF(self, *cols): - """Returns a new class:`DataFrame` that with new specified column names + """Returns a new :class:`DataFrame` that with new specified column names :param cols: list of new column names (string) @@ -2150,9 +2163,9 @@ def toDF(self, *cols): @since(3.0) def transform(self, func): - """Returns a new class:`DataFrame`. Concise syntax for chaining custom transformations. + """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations. - :param func: a function that takes and returns a class:`DataFrame`. + :param func: a function that takes and returns a :class:`DataFrame`. >>> from pyspark.sql.functions import col >>> df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"]) @@ -2219,6 +2232,20 @@ def semanticHash(self): """ return self._jdf.semanticHash() + @since(3.1) + def inputFiles(self): + """ + Returns a best-effort snapshot of the files that compose this :class:`DataFrame`. + This method simply asks each constituent BaseRelation for its respective files and + takes the union of all results. Depending on the source relations, this may not find + all input files. Duplicates are removed. 
+ + >>> df = spark.read.load("examples/src/main/resources/people.json", format="json") + >>> len(df.inputFiles()) + 1 + """ + return list(self._jdf.inputFiles()) + where = copy_func( filter, sinceversion=1.3, diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 4dd15d14b9c53..ff0b10a9306cf 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -154,6 +154,9 @@ def create_array(s, t): # Ensure timestamp series are in expected form for Spark internal representation if t is not None and pa.types.is_timestamp(t): s = _check_series_convert_timestamps_internal(s, self._timezone) + elif type(s.dtype) == pd.CategoricalDtype: + # Note: This can be removed once minimum pyarrow version is >= 0.16.1 + s = s.astype(s.dtypes.categories.dtype) try: array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck) except pa.ArrowException as e: diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index d1edf3f9c47c1..4b70c8a2e95e1 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -114,6 +114,8 @@ def from_arrow_type(at): return StructType( [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) for field in at]) + elif types.is_dictionary(at): + spark_type = from_arrow_type(at.value_type) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) return spark_type diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6ad6377288ec5..336345e383729 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -223,15 +223,15 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + * ``PERMISSIVE``: when it meets a corrupted record, puts the malformed string \ into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \ fields to ``null``. To keep corrupt records, an user can set a string type \ field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ schema does not have the field, it drops corrupt records during parsing. \ When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ field in an output schema. - * ``DROPMALFORMED`` : ignores the whole corrupted records. - * ``FAILFAST`` : throws an exception when it meets corrupted records. + * ``DROPMALFORMED``: ignores the whole corrupted records. + * ``FAILFAST``: throws an exception when it meets corrupted records. :param columnNameOfCorruptRecord: allows renaming the new field having malformed string created by ``PERMISSIVE`` mode. This overrides @@ -470,7 +470,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non be controlled by ``spark.sql.csv.parser.columnPruning.enabled`` (enabled by default). - * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + * ``PERMISSIVE``: when it meets a corrupted record, puts the malformed string \ into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \ fields to ``null``. To keep corrupt records, an user can set a string type \ field named ``columnNameOfCorruptRecord`` in an user-defined schema. 
If a \ @@ -479,8 +479,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non When it meets a record having fewer tokens than the length of the schema, \ sets ``null`` to extra fields. When the record has more tokens than the \ length of the schema, it drops extra tokens. - * ``DROPMALFORMED`` : ignores the whole corrupted records. - * ``FAILFAST`` : throws an exception when it meets corrupted records. + * ``DROPMALFORMED``: ignores the whole corrupted records. + * ``FAILFAST``: throws an exception when it meets corrupted records. :param columnNameOfCorruptRecord: allows renaming the new field having malformed string created by ``PERMISSIVE`` mode. This overrides @@ -830,7 +830,7 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): def insertInto(self, tableName, overwrite=None): """Inserts the content of the :class:`DataFrame` to the specified table. - It requires that the schema of the class:`DataFrame` is the same as the + It requires that the schema of the :class:`DataFrame` is the same as the schema of the table. Optionally overwriting any existing data. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 05cf331d897a2..2450a4c93c460 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -461,15 +461,15 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + * ``PERMISSIVE``: when it meets a corrupted record, puts the malformed string \ into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \ fields to ``null``. To keep corrupt records, an user can set a string type \ field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ schema does not have the field, it drops corrupt records during parsing. \ When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ field in an output schema. - * ``DROPMALFORMED`` : ignores the whole corrupted records. - * ``FAILFAST`` : throws an exception when it meets corrupted records. + * ``DROPMALFORMED``: ignores the whole corrupted records. + * ``FAILFAST``: throws an exception when it meets corrupted records. :param columnNameOfCorruptRecord: allows renaming the new field having malformed string created by ``PERMISSIVE`` mode. This overrides @@ -707,7 +707,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + * ``PERMISSIVE``: when it meets a corrupted record, puts the malformed string \ into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \ fields to ``null``. To keep corrupt records, an user can set a string type \ field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ @@ -716,8 +716,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non When it meets a record having fewer tokens than the length of the schema, \ sets ``null`` to extra fields. When the record has more tokens than the \ length of the schema, it drops extra tokens. - * ``DROPMALFORMED`` : ignores the whole corrupted records. 
- * ``FAILFAST`` : throws an exception when it meets corrupted records. + * ``DROPMALFORMED``: ignores the whole corrupted records. + * ``FAILFAST``: throws an exception when it meets corrupted records. :param columnNameOfCorruptRecord: allows renaming the new field having malformed string created by ``PERMISSIVE`` mode. This overrides @@ -795,11 +795,11 @@ def outputMode(self, outputMode): Options include: - * `append`:Only the new rows in the streaming DataFrame/Dataset will be written to + * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to the sink - * `complete`:All the rows in the streaming DataFrame/Dataset will be written to the sink + * `complete`: All the rows in the streaming DataFrame/Dataset will be written to the sink every time these is some updates - * `update`:only the rows that were updated in the streaming DataFrame/Dataset will be + * `update`: only the rows that were updated in the streaming DataFrame/Dataset will be written to the sink every time there are some updates. If the query doesn't contain aggregations, it will be equivalent to `append` mode. @@ -1170,11 +1170,11 @@ def start(self, path=None, format=None, outputMode=None, partitionBy=None, query :param outputMode: specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. - * `append`:Only the new rows in the streaming DataFrame/Dataset will be written to the + * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to the sink - * `complete`:All the rows in the streaming DataFrame/Dataset will be written to the sink - every time these is some updates - * `update`:only the rows that were updated in the streaming DataFrame/Dataset will be + * `complete`: All the rows in the streaming DataFrame/Dataset will be written to the + sink every time these is some updates + * `update`: only the rows that were updated in the streaming DataFrame/Dataset will be written to the sink every time there are some updates. If the query doesn't contain aggregations, it will be equivalent to `append` mode. 
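
For reference, a minimal streaming sketch exercising the output modes documented above; it assumes an active `SparkSession` named `spark` and uses the built-in `rate` source and `console` sink:

```python
stream = spark.readStream.format("rate").option("rowsPerSecond", 1).load()

# append: only new rows are written; valid here since there is no aggregation.
append_query = (stream.writeStream
                .outputMode("append")
                .format("console")
                .start())

# complete/update require an aggregation, e.g. a running global count.
counts = stream.groupBy().count()
complete_query = (counts.writeStream
                  .outputMode("complete")
                  .format("console")
                  .start())

append_query.stop()
complete_query.stop()
```
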
:param partitionBy: names of partitioning columns diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 004c79f290213..c59765dd79eb9 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -415,6 +415,33 @@ def run_test(num_records, num_parts, max_records, use_delay=False): for case in cases: run_test(*case) + def test_createDateFrame_with_category_type(self): + pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]}) + pdf["B"] = pdf["A"].astype('category') + category_first_element = dict(enumerate(pdf['B'].cat.categories))[0] + + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}): + arrow_df = self.spark.createDataFrame(pdf) + arrow_type = arrow_df.dtypes[1][1] + result_arrow = arrow_df.toPandas() + arrow_first_category_element = result_arrow["B"][0] + + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): + df = self.spark.createDataFrame(pdf) + spark_type = df.dtypes[1][1] + result_spark = df.toPandas() + spark_first_category_element = result_spark["B"][0] + + assert_frame_equal(result_spark, result_arrow) + + # ensure original category elements are string + self.assertIsInstance(category_first_element, str) + # spark data frame and arrow execution mode enabled data frame type must match pandas + self.assertEqual(spark_type, 'string') + self.assertEqual(arrow_type, 'string') + self.assertIsInstance(arrow_first_category_element, str) + self.assertIsInstance(spark_first_category_element, str) + @unittest.skipIf( not have_pandas or not have_pyarrow, diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py index d4a476dd36371..3b1b638ed4aa6 100644 --- a/python/pyspark/sql/tests/test_context.py +++ b/python/pyspark/sql/tests/test_context.py @@ -228,7 +228,7 @@ def test_datetime_functions(self): from datetime import date df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol") parse_result = df.select(functions.to_date(functions.col("dateCol"))).first() - self.assertEquals(date(2017, 1, 22), parse_result['to_date(`dateCol`)']) + self.assertEquals(date(2017, 1, 22), parse_result['to_date(dateCol)']) def test_unbounded_frames(self): from pyspark.sql import functions as F diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 9861178158f85..062e61663a332 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -17,6 +17,8 @@ import os import pydoc +import shutil +import tempfile import time import unittest @@ -820,6 +822,22 @@ def test_same_semantics_error(self): with self.assertRaisesRegexp(ValueError, "should be of DataFrame.*int"): self.spark.range(10).sameSemantics(1) + def test_input_files(self): + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + self.spark.range(1, 100, 1, 10).write.parquet(tpath) + # read parquet file and get the input files list + input_files_list = self.spark.read.parquet(tpath).inputFiles() + + # input files list should contain 10 entries + self.assertEquals(len(input_files_list), 10) + # all file paths in list must contain tpath + for file_path in input_files_list: + self.assertTrue(tpath in file_path) + finally: + shutil.rmtree(tpath) + class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is diff --git a/python/pyspark/sql/tests/test_pandas_udf.py 
b/python/pyspark/sql/tests/test_pandas_udf.py index 4218f5cfc401f..7fa65f0e792b9 100644 --- a/python/pyspark/sql/tests/test_pandas_udf.py +++ b/python/pyspark/sql/tests/test_pandas_udf.py @@ -19,14 +19,12 @@ from pyspark.sql.functions import udf, pandas_udf, PandasUDFType from pyspark.sql.types import * -from pyspark.sql.utils import ParseException +from pyspark.sql.utils import ParseException, PythonException from pyspark.rdd import PythonEvalType from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message from pyspark.testing.utils import QuietTest -from py4j.protocol import Py4JJavaError - @unittest.skipIf( not have_pandas or not have_pyarrow, @@ -157,14 +155,14 @@ def foofoo(x, y): # plain udf (test for SPARK-23754) self.assertRaisesRegexp( - Py4JJavaError, + PythonException, exc_message, df.withColumn('v', udf(foo)('id')).collect ) # pandas scalar udf self.assertRaisesRegexp( - Py4JJavaError, + PythonException, exc_message, df.withColumn( 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') @@ -173,7 +171,7 @@ def foofoo(x, y): # pandas grouped map self.assertRaisesRegexp( - Py4JJavaError, + PythonException, exc_message, df.groupBy('id').apply( pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) @@ -181,7 +179,7 @@ def foofoo(x, y): ) self.assertRaisesRegexp( - Py4JJavaError, + PythonException, exc_message, df.groupBy('id').apply( pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) @@ -190,7 +188,7 @@ def foofoo(x, y): # pandas grouped agg self.assertRaisesRegexp( - Py4JJavaError, + PythonException, exc_message, df.groupBy('id').agg( pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 7260e80e2cfca..2d38efd39f902 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -897,6 +897,24 @@ def test_timestamp_dst(self): result = df.withColumn('time', foo_udf(df.time)) self.assertEquals(df.collect(), result.collect()) + def test_udf_category_type(self): + + @pandas_udf('string') + def to_category_func(x): + return x.astype('category') + + pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]}) + df = self.spark.createDataFrame(pdf) + df = df.withColumn("B", to_category_func(df['A'])) + result_spark = df.toPandas() + + spark_type = df.dtypes[1][1] + # spark data frame and arrow execution mode enabled data frame type must match pandas + self.assertEqual(spark_type, 'string') + + # Check result of column 'B' must be equal to column 'A' in type and values + pd.testing.assert_series_equal(result_spark["A"], result_spark["B"], check_names=False) + @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") def test_type_annotation(self): # Regression test to check if type hints can be used. See SPARK-23569. diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 147ac3325efd9..27adc2372ec0f 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -18,8 +18,19 @@ import py4j import sys +from pyspark import SparkContext + if sys.version_info.major >= 3: unicode = str + # Disable exception chaining (PEP 3134) in captured exceptions + # in order to hide JVM stacktace. 
+ exec(""" +def raise_from(e): + raise e from None +""") +else: + def raise_from(e): + raise e class CapturedException(Exception): @@ -29,7 +40,11 @@ def __init__(self, desc, stackTrace, cause=None): self.cause = convert_exception(cause) if cause is not None else None def __str__(self): + sql_conf = SparkContext._jvm.org.apache.spark.sql.internal.SQLConf.get() + debug_enabled = sql_conf.pysparkJVMStacktraceEnabled() desc = self.desc + if debug_enabled: + desc = desc + "\nJVM stacktrace:\n%s" % self.stackTrace # encode unicode instance for python2 for human readable description if sys.version_info.major < 3 and isinstance(desc, unicode): return str(desc.encode('utf-8')) @@ -67,6 +82,12 @@ class QueryExecutionException(CapturedException): """ +class PythonException(CapturedException): + """ + Exceptions thrown from Python workers. + """ + + class UnknownException(CapturedException): """ None of the above exceptions. @@ -75,21 +96,33 @@ class UnknownException(CapturedException): def convert_exception(e): s = e.toString() - stackTrace = '\n\t at '.join(map(lambda x: x.toString(), e.getStackTrace())) c = e.getCause() + + jvm = SparkContext._jvm + jwriter = jvm.java.io.StringWriter() + e.printStackTrace(jvm.java.io.PrintWriter(jwriter)) + stacktrace = jwriter.toString() if s.startswith('org.apache.spark.sql.AnalysisException: '): - return AnalysisException(s.split(': ', 1)[1], stackTrace, c) + return AnalysisException(s.split(': ', 1)[1], stacktrace, c) if s.startswith('org.apache.spark.sql.catalyst.analysis'): - return AnalysisException(s.split(': ', 1)[1], stackTrace, c) + return AnalysisException(s.split(': ', 1)[1], stacktrace, c) if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '): - return ParseException(s.split(': ', 1)[1], stackTrace, c) + return ParseException(s.split(': ', 1)[1], stacktrace, c) if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '): - return StreamingQueryException(s.split(': ', 1)[1], stackTrace, c) + return StreamingQueryException(s.split(': ', 1)[1], stacktrace, c) if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '): - return QueryExecutionException(s.split(': ', 1)[1], stackTrace, c) + return QueryExecutionException(s.split(': ', 1)[1], stacktrace, c) if s.startswith('java.lang.IllegalArgumentException: '): - return IllegalArgumentException(s.split(': ', 1)[1], stackTrace, c) - return UnknownException(s, stackTrace, c) + return IllegalArgumentException(s.split(': ', 1)[1], stacktrace, c) + if c is not None and ( + c.toString().startswith('org.apache.spark.api.python.PythonException: ') + # To make sure this only catches Python UDFs. + and any(map(lambda v: "org.apache.spark.sql.execution.python" in v.toString(), + c.getStackTrace()))): + msg = ("\n An exception was thrown from Python worker in the executor. " + "The below is the Python worker stacktrace.\n%s" % c.getMessage()) + return PythonException(msg, stacktrace) + return UnknownException(s, stacktrace, c) def capture_sql_exception(f): @@ -99,7 +132,9 @@ def deco(*a, **kw): except py4j.protocol.Py4JJavaError as e: converted = convert_exception(e.java_exception) if not isinstance(converted, UnknownException): - raise converted + # Hide where the exception came from that shows a non-Pythonic + # JVM exception message. 
+ raise_from(converted) else: raise return deco diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 769121c19ff4d..6199611940dc9 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -17,7 +17,7 @@ from __future__ import print_function -from py4j.java_gateway import java_import +from py4j.java_gateway import java_import, is_instance_of from pyspark import RDD, SparkConf from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer @@ -341,8 +341,17 @@ def union(self, *dstreams): raise ValueError("All DStreams should have same serializer") if len(set(s._slideDuration for s in dstreams)) > 1: raise ValueError("All DStreams should have same slide duration") - cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream - jdstreams = SparkContext._gateway.new_array(cls, len(dstreams)) + jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream + jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream + gw = SparkContext._gateway + if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls): + cls = jdstream_cls + elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls): + cls = jpair_dstream_cls + else: + cls_name = dstreams[0]._jdstream.getClass().getCanonicalName() + raise TypeError("Unsupported Java DStream class %s" % cls_name) + jdstreams = gw.new_array(cls, len(dstreams)) for i in range(0, len(dstreams)): jdstreams[i] = dstreams[i]._jdstream return DStream(self._jssc.union(jdstreams), self, dstreams[0]._jrdd_deserializer) diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index 62ad4221d7078..6c5b818056f2d 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -168,6 +168,17 @@ def test_zip_chaining(self): set([(x, (x, x)) for x in 'abc']) ) + def test_union_pair_rdd(self): + # SPARK-31788: test if pair RDDs can be combined by union. 
+ rdd = self.sc.parallelize([1, 2]) + pair_rdd = rdd.zip(rdd) + unionRDD = self.sc.union([pair_rdd, pair_rdd]) + self.assertEqual( + set(unionRDD.collect()), + set([(1, 1), (2, 2), (1, 1), (2, 2)]) + ) + self.assertEqual(unionRDD.count(), 4) + def test_deleting_input_files(self): # Regression test for SPARK-1025 tempFile = tempfile.NamedTemporaryFile(delete=False) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 988941e7550b9..5f4a8a2d2db1f 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -36,7 +36,7 @@ from pyspark.java_gateway import local_connect_and_auth from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.files import SparkFiles -from pyspark.resourceinformation import ResourceInformation +from pyspark.resource import ResourceInformation from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ diff --git a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala index 4795306692f7a..e11a54bc88070 100644 --- a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala @@ -380,6 +380,67 @@ class SingletonReplSuite extends SparkFunSuite { assertDoesNotContain("Exception", output) } + test("SPARK-31399: should clone+clean line object w/ non-serializable state in ClosureCleaner") { + // Test ClosureCleaner when a closure captures the enclosing `this` REPL line object, and that + // object contains an unused non-serializable field. + // Specifically, the closure in this test case contains a directly nested closure, and the + // capture is triggered by the inner closure. + // `ns` should be nulled out, but `topLevelValue` should stay intact. + + // Can't use :paste mode because PipedOutputStream/PipedInputStream doesn't work well with the + // EOT control character (i.e. Ctrl+D). + // Just write things on a single line to emulate :paste mode. + + // NOTE: in order for this test case to trigger the intended scenario, the following three + // variables need to be in the same "input", which will make the REPL pack them into the + // same REPL line object: + // - ns: a non-serializable state, not accessed by the closure; + // - topLevelValue: a serializable state, accessed by the closure; + // - closure: the starting closure, captures the enclosing REPL line object. + val output = runInterpreter( + """ + |class NotSerializableClass(val x: Int) + |val ns = new NotSerializableClass(42); val topLevelValue = "someValue"; val closure = + |(j: Int) => { + | (1 to j).flatMap { x => + | (1 to x).map { y => y + topLevelValue } + | } + |} + |val r = sc.parallelize(0 to 2).map(closure).collect + """.stripMargin) + assertContains("r: Array[scala.collection.immutable.IndexedSeq[String]] = " + + "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output) + assertDoesNotContain("Exception", output) + } + + test("SPARK-31399: ClosureCleaner should discover indirectly nested closure in inner class") { + // Similar to the previous test case, but with indirect closure nesting instead. + // There's still nested closures involved, but the inner closure is indirectly nested in the + // outer closure, with a level of inner class in between them. 
+ // This changes how the inner closure references/captures the outer closure/enclosing `this` + // REPL line object, and covers a different code path in inner closure discovery. + + // `ns` should be nulled out, but `topLevelValue` should stay intact. + + val output = runInterpreter( + """ + |class NotSerializableClass(val x: Int) + |val ns = new NotSerializableClass(42); val topLevelValue = "someValue"; val closure = + |(j: Int) => { + | class InnerFoo { + | val innerClosure = (x: Int) => (1 to x).map { y => y + topLevelValue } + | } + | val innerFoo = new InnerFoo + | (1 to j).flatMap(innerFoo.innerClosure) + |} + |val r = sc.parallelize(0 to 2).map(closure).collect + """.stripMargin) + assertContains("r: Array[scala.collection.immutable.IndexedSeq[String]] = " + + "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output) + assertDoesNotContain("Array(Vector(), Vector(1null), Vector(1null, 1null, 2null)", output) + assertDoesNotContain("Exception", output) + } + test("newProductSeqEncoder with REPL defined class") { val output = runInterpreter( """ diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index b527816015c63..c1a7dafb69c46 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -30,7 +30,7 @@ kubernetes - 4.7.1 + 4.9.2 diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index fa58b98fba04c..274b859fef96d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -419,6 +419,7 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." + val KUBERNETES_DRIVER_SERVICE_ANNOTATION_PREFIX = "spark.kubernetes.driver.service.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." val KUBERNETES_DRIVER_VOLUMES_PREFIX = "spark.kubernetes.driver.volumes." 
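Editor's note: the new spark.kubernetes.driver.service.annotation. prefix defined above is consumed by KubernetesDriverConf.serviceAnnotations and applied to the driver's headless service in the hunks that follow. A minimal sketch of how a user configuration might look, assuming the usual SparkConf entry point; the Prometheus-style annotation keys are placeholders, not anything Spark requires:

    import org.apache.spark.SparkConf

    // Every key under the prefix ends up as an annotation on the driver service's metadata.
    val conf = new SparkConf()
      .set("spark.kubernetes.driver.service.annotation.prometheus.io/scrape", "true")
      .set("spark.kubernetes.driver.service.annotation.prometheus.io/port", "4040")

Passing the same keys with spark-submit --conf should behave identically, since both paths feed the same SparkConf.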
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index a6fc8519108c6..6bd7fa81c0e37 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -110,6 +110,11 @@ private[spark] class KubernetesDriverConf( KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) } + def serviceAnnotations: Map[String, String] = { + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, + KUBERNETES_DRIVER_SERVICE_ANNOTATION_PREFIX) + } + override def secretNamesToMountPaths: Map[String, String] = { KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala index 925bcdf3e637f..1e9c60c871479 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala @@ -69,6 +69,7 @@ private[spark] class DriverServiceFeatureStep( val driverService = new ServiceBuilder() .withNewMetadata() .withName(resolvedServiceName) + .addToAnnotations(kubernetesConf.serviceAnnotations.asJava) .endMetadata() .withNewSpec() .withClusterIP("None") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala index 518158a783492..d6871a6c2866a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala @@ -48,6 +48,7 @@ object KubernetesTestConf { labels: Map[String, String] = Map.empty, environment: Map[String, String] = Map.empty, annotations: Map[String, String] = Map.empty, + serviceAnnotations: Map[String, String] = Map.empty, secretEnvNamesToKeyRefs: Map[String, String] = Map.empty, secretNamesToMountPaths: Map[String, String] = Map.empty, volumes: Seq[KubernetesVolumeSpec] = Seq.empty, @@ -60,6 +61,7 @@ object KubernetesTestConf { setPrefixedConfigs(conf, KUBERNETES_DRIVER_LABEL_PREFIX, labels) setPrefixedConfigs(conf, KUBERNETES_DRIVER_ENV_PREFIX, environment) setPrefixedConfigs(conf, KUBERNETES_DRIVER_ANNOTATION_PREFIX, annotations) + setPrefixedConfigs(conf, KUBERNETES_DRIVER_SERVICE_ANNOTATION_PREFIX, serviceAnnotations) setPrefixedConfigs(conf, KUBERNETES_DRIVER_SECRETS_PREFIX, secretNamesToMountPaths) setPrefixedConfigs(conf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX, secretEnvNamesToKeyRefs) setVolumeSpecs(conf, KUBERNETES_DRIVER_VOLUMES_PREFIX, volumes) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index 9068289bab581..18afd10395566 100644 --- 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -38,6 +38,9 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite { private val DRIVER_LABELS = Map( "label1key" -> "label1value", "label2key" -> "label2value") + private val DRIVER_SERVICE_ANNOTATIONS = Map( + "annotation1key" -> "annotation1value", + "annotation2key" -> "annotation2value") test("Headless service has a port for the driver RPC, the block manager and driver ui.") { val sparkConf = new SparkConf(false) @@ -46,7 +49,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite { .set(UI_PORT, 4080) val kconf = KubernetesTestConf.createDriverConf( sparkConf = sparkConf, - labels = DRIVER_LABELS) + labels = DRIVER_LABELS, + serviceAnnotations = DRIVER_SERVICE_ANNOTATIONS) val configurationStep = new DriverServiceFeatureStep(kconf) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) @@ -79,7 +83,9 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite { } test("Ports should resolve to defaults in SparkConf and in the service.") { - val kconf = KubernetesTestConf.createDriverConf(labels = DRIVER_LABELS) + val kconf = KubernetesTestConf.createDriverConf( + labels = DRIVER_LABELS, + serviceAnnotations = DRIVER_SERVICE_ANNOTATIONS) val configurationStep = new DriverServiceFeatureStep(kconf) val resolvedService = configurationStep .getAdditionalKubernetesResources() @@ -164,6 +170,9 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite { DRIVER_LABELS.foreach { case (k, v) => assert(service.getSpec.getSelector.get(k) === v) } + DRIVER_SERVICE_ANNOTATIONS.foreach { case (k, v) => + assert(service.getMetadata.getAnnotations.get(k) === v) + } assert(service.getSpec.getPorts.size() === 3) val driverServicePorts = service.getSpec.getPorts.asScala assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 8c683e85dd5e2..894e1e4978178 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.Fabric8Aliases._ import org.apache.spark.resource.ResourceProfileManager import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.{ExecutorKilled, TaskSchedulerImpl} +import org.apache.spark.scheduler.{ExecutorKilled, LiveListenerBus, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils.TEST_SPARK_APP_ID @@ -87,7 +87,8 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ private var 
schedulerBackendUnderTest: KubernetesClusterSchedulerBackend = _ - private val resourceProfileManager = new ResourceProfileManager(sparkConf) + private val listenerBus = new LiveListenerBus(new SparkConf()) + private val resourceProfileManager = new ResourceProfileManager(sparkConf, listenerBus) before { MockitoAnnotations.initMocks(this) diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index 18b91916208d6..6409c227ec287 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -17,6 +17,10 @@ To run tests with Java 11 instead of Java 8, use `--java-image-tag` to specify t ./dev/dev-run-integration-tests.sh --java-image-tag 11-jre-slim +To run tests with Hadoop 3.2 instead of Hadoop 2.7, use `--hadoop-profile`. + + ./dev/dev-run-integration-tests.sh --hadoop-profile hadoop-3.2 + The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should run with a minimum of 4 CPUs and 6G of memory: diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 292abe91d35b6..9c03a97ef15d5 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -35,6 +35,7 @@ CONTEXT= INCLUDE_TAGS="k8s" EXCLUDE_TAGS= JAVA_VERSION="8" +HADOOP_PROFILE="hadoop-2.7" MVN="$TEST_ROOT_DIR/build/mvn" SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version 2>/dev/null\ @@ -112,6 +113,10 @@ while (( "$#" )); do JAVA_VERSION="$2" shift ;; + --hadoop-profile) + HADOOP_PROFILE="$2" + shift + ;; *) echo "Unexpected command line flag $2 $1." 
exit 1 @@ -171,4 +176,4 @@ properties+=( -Dlog4j.logger.org.apache.spark=DEBUG ) -$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pscala-$SCALA_VERSION -Pkubernetes -Pkubernetes-integration-tests ${properties[@]} +$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pscala-$SCALA_VERSION -P$HADOOP_PROFILE -Pkubernetes -Pkubernetes-integration-tests ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 7a889c427b41e..503540403f5ec 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -29,7 +29,7 @@ 1.3.0 1.4.0 - 4.7.1 + 4.9.2 3.2.2 1.0 kubernetes-integration-tests @@ -77,12 +77,6 @@ spark-tags_${scala.binary.version} test-jar - - com.amazonaws - aws-java-sdk - 1.7.4 - test - @@ -121,6 +115,9 @@ --spark-tgz ${spark.kubernetes.test.sparkTgz} + + --test-exclude-tags + "${test.exclude.tags}" @@ -186,4 +183,31 @@ + + + hadoop-2.7 + + true + + + + com.amazonaws + aws-java-sdk + 1.7.4 + test + + + + + hadoop-3.2 + + + com.amazonaws + aws-java-sdk-bundle + 1.11.375 + test + + + + diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index ab906604fce06..beda56cf37c94 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -25,6 +25,8 @@ IMAGE_REPO="docker.io/kubespark" IMAGE_TAG="N/A" JAVA_IMAGE_TAG="8-jre-slim" SPARK_TGZ="N/A" +MVN="$TEST_ROOT_DIR/build/mvn" +EXCLUDE_TAGS="" # Parse arguments while (( "$#" )); do @@ -57,6 +59,10 @@ while (( "$#" )); do SPARK_TGZ="$2" shift ;; + --test-exclude-tags) + EXCLUDE_TAGS="$2" + shift + ;; *) break ;; @@ -84,7 +90,11 @@ fi # If there is a specific Spark image skip building and extraction/copy if [[ $IMAGE_TAG == "N/A" ]]; then - IMAGE_TAG=$(uuidgen); + VERSION=$("$MVN" help:evaluate -Dexpression=project.version \ + | grep -v "INFO"\ + | grep -v "WARNING"\ + | tail -n 1) + IMAGE_TAG=${VERSION}_$(uuidgen) cd $SPARK_INPUT_DIR # OpenJDK base-image tag (e.g. 8-jre-slim, 11-jre-slim) @@ -94,7 +104,10 @@ then LANGUAGE_BINDING_BUILD_ARGS="-p $DOCKER_FILE_BASE_PATH/bindings/python/Dockerfile" # Build SparkR image - LANGUAGE_BINDING_BUILD_ARGS="$LANGUAGE_BINDING_BUILD_ARGS -R $DOCKER_FILE_BASE_PATH/bindings/R/Dockerfile" + tags=(${EXCLUDE_TAGS//,/ }) + if [[ ! ${tags[@]} =~ "r" ]]; then + LANGUAGE_BINDING_BUILD_ARGS="$LANGUAGE_BINDING_BUILD_ARGS -R $DOCKER_FILE_BASE_PATH/bindings/R/Dockerfile" + fi # Unset SPARK_HOME to let the docker-image-tool script detect SPARK_HOME. Otherwise, it cannot # indicate the unpacked directory as its home. See SPARK-28550. 
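Editor's note: the --test-exclude-tags plumbing above (pom.xml and setup-integration-test-env.sh) pairs with the new "r" ScalaTest tag that KubernetesSuite and RTestsSuite introduce below. A rough sketch of the tagging pattern, assuming SparkFunSuite and ScalaTest are on the test classpath; ExampleSuite and RTag are illustrative names only:

    import org.scalatest.Tag

    import org.apache.spark.SparkFunSuite

    // Stand-in for KubernetesSuite.rTestTag: tests carrying this tag can be skipped
    // with -Dtest.exclude.tags=r, the same property the integration-test pom now
    // forwards to setup-integration-test-env.sh so the SparkR image build is skipped too.
    object RTag extends Tag("r")

    class ExampleSuite extends SparkFunSuite {
      test("check that depends on the SparkR image", RTag) {
        assert(1 + 1 === 2)
      }
    }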
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala index 6b340f2558cca..76221e46d1cf4 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala @@ -111,6 +111,6 @@ private[spark] object BasicTestsSuite { val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" val REMOTE_PAGE_RANK_DATA_FILE = - "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" + "https://raw.githubusercontent.com/apache/spark/master/data/mllib/pagerank_data.txt" val REMOTE_PAGE_RANK_FILE_NAME = "pagerank_data.txt" } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala index 2d90c06e36390..e712b95cdbcea 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala @@ -148,6 +148,11 @@ private[spark] trait DepsTestsSuite { k8sSuite: KubernetesSuite => } test("Launcher client dependencies", k8sTestTag, MinikubeTag) { + val packages = if (Utils.isHadoop3) { + "org.apache.hadoop:hadoop-aws:3.2.0" + } else { + "com.amazonaws:aws-java-sdk:1.7.4,org.apache.hadoop:hadoop-aws:2.7.6" + } val fileName = Utils.createTempFile(FILE_CONTENTS, HOST_PATH) try { setupMinioStorage() @@ -164,8 +169,7 @@ private[spark] trait DepsTestsSuite { k8sSuite: KubernetesSuite => .set("spark.kubernetes.file.upload.path", s"s3a://$BUCKET") .set("spark.files", s"$HOST_PATH/$fileName") .set("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") - .set("spark.jars.packages", "com.amazonaws:aws-java-sdk:" + - "1.7.4,org.apache.hadoop:hadoop-aws:2.7.6") + .set("spark.jars.packages", packages) .set("spark.driver.extraJavaOptions", "-Divy.cache.dir=/tmp -Divy.home=/tmp") createS3Bucket(ACCESS_KEY, SECRET_KEY, minioUrlStr) runSparkRemoteCheckAndVerifyCompletion(appResource = examplesJar, diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 4de7e70c1f409..65a2f1ff79697 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -476,6 +476,7 @@ class KubernetesSuite extends SparkFunSuite private[spark] object KubernetesSuite { val k8sTestTag = Tag("k8s") + val rTestTag = Tag("r") val MinikubeTag = Tag("minikube") val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" val SPARK_DFS_READ_WRITE_TEST = "org.apache.spark.examples.DFSReadWriteTest" diff --git 
a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala index e81562a923228..b7c8886a15ae7 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.deploy.k8s.integrationtest private[spark] trait RTestsSuite { k8sSuite: KubernetesSuite => import RTestsSuite._ - import KubernetesSuite.k8sTestTag + import KubernetesSuite.{k8sTestTag, rTestTag} - test("Run SparkR on simple dataframe.R example", k8sTestTag) { + test("Run SparkR on simple dataframe.R example", k8sTestTag, rTestTag) { sparkAppConf.set("spark.kubernetes.container.image", rImage) runSparkApplicationAndVerifyCompletion( appResource = SPARK_R_DATAFRAME_TEST, diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala index 9f85805b9d315..0000a94725763 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala @@ -21,13 +21,16 @@ import java.nio.file.{Files, Path} import java.util.concurrent.CountDownLatch import scala.collection.JavaConverters._ +import scala.util.Try import io.fabric8.kubernetes.client.dsl.ExecListener import okhttp3.Response import org.apache.commons.io.output.ByteArrayOutputStream +import org.apache.hadoop.util.VersionInfo import org.apache.spark.{SPARK_VERSION, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.util.{Utils => SparkUtils} object Utils extends Logging { @@ -131,4 +134,8 @@ object Utils extends Logging { s"under spark home test dir ${sparkHomeDir.toAbsolutePath}!") } } + + def isHadoop3(): Boolean = { + VersionInfo.getVersion.startsWith("3") + } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 43cd7458ef55b..9a6a43914bca3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -42,6 +42,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -863,10 +864,22 @@ object ApplicationMaster extends Logging { val ugi = sparkConf.get(PRINCIPAL) match { // We only need to log in with the keytab in cluster mode. In client mode, the driver // handles the user keytab. 
- case Some(principal) if amArgs.userClass != null => + case Some(principal) if master.isClusterMode => val originalCreds = UserGroupInformation.getCurrentUser().getCredentials() SparkHadoopUtil.get.loginUserFromKeytab(principal, sparkConf.get(KEYTAB).orNull) val newUGI = UserGroupInformation.getCurrentUser() + + if (master.appAttemptId == null || master.appAttemptId.getAttemptId > 1) { + // Re-obtain delegation tokens if this is not a first attempt, as they might be outdated + // as of now. Add the fresh tokens on top of the original user's credentials (overwrite). + // Set the context class loader so that the token manager has access to jars + // distributed by the user. + Utils.withContextClassLoader(master.userClassLoader) { + val credentialManager = new HadoopDelegationTokenManager(sparkConf, yarnConf, null) + credentialManager.obtainDelegationTokens(originalCreds) + } + } + // Transfer the original user's tokens to the new user, since it may contain needed tokens // (such as those user to connect to YARN). newUGI.addCredentials(originalCreds) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala index fa8c9610220c8..339d3715a7316 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -103,7 +103,14 @@ private[spark] class YarnAllocatorBlacklistTracker( refreshBlacklistedNodes() } - def isAllNodeBlacklisted: Boolean = currentBlacklistedYarnNodes.size >= numClusterNodes + def isAllNodeBlacklisted: Boolean = { + if (numClusterNodes <= 0) { + logWarning("No available nodes reported, please check Resource Manager.") + false + } else { + currentBlacklistedYarnNodes.size >= numClusterNodes + } + } private def refreshBlacklistedNodes(): Unit = { removeExpiredYarnBlacklistedNodes() diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index b335e7fc04f53..9c5c376ce5357 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -215,6 +215,8 @@ class ClientSuite extends SparkFunSuite with Matchers { } test("specify a more specific type for the application") { + // TODO (SPARK-31733) Make this test case pass with Hadoop-3.2 + assume(!isYarnResourceTypesAvailable) // When the type exceeds 20 characters will be truncated by yarn val appTypes = Map( 1 -> ("", ""), diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala index 727851747e088..3c9209c292418 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.deploy.yarn -import scala.collection.JavaConverters._ +import java.io.{PrintWriter, StringWriter} + import scala.collection.mutable.{HashMap, HashSet, Set} import org.apache.hadoop.yarn.api.records._ @@ -46,7 +47,11 @@ class 
LocalityPlacementStrategySuite extends SparkFunSuite { thread.start() thread.join() - assert(error === null) + if (error != null) { + val errors = new StringWriter() + error.printStackTrace(new PrintWriter(errors)) + fail(s"StackOverflowError should not be thrown; however, got:\n\n$errors") + } } private def runTest(): Unit = { @@ -57,7 +62,6 @@ class LocalityPlacementStrategySuite extends SparkFunSuite { // goal is to create enough requests for localized containers (so there should be many // tasks on several hosts that have no allocated containers). - val resource = Resource.newInstance(8 * 1024, 4) val strategy = new LocalityPreferredContainerPlacementStrategy(new SparkConf(), yarnConf, new MockResolver()) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala index e260c0b60b87c..632c66d77b707 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala @@ -39,7 +39,7 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers { val allMetrics = Set( "openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis", "blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections", - "numRegisteredConnections", "numCaughtExceptions") + "numCaughtExceptions") metrics.getMetrics.keySet().asScala should be (allMetrics) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala index 381a93580f961..1a5a099217f55 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -23,9 +23,13 @@ import java.nio.file.attribute.PosixFilePermission._ import java.util.EnumSet import scala.annotation.tailrec +import scala.collection.JavaConverters._ import scala.concurrent.duration._ +import com.codahale.metrics.MetricSet import org.apache.hadoop.fs.Path +import org.apache.hadoop.metrics2.impl.MetricsSystemImpl +import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem import org.apache.hadoop.service.ServiceStateException import org.apache.hadoop.yarn.api.records.ApplicationId import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -381,4 +385,27 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd s1.secretsFile should be (null) } + test("SPARK-31646: metrics should be registered into Node Manager's metrics system") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + + val metricsSource = DefaultMetricsSystem.instance.asInstanceOf[MetricsSystemImpl] + .getSource("sparkShuffleService").asInstanceOf[YarnShuffleServiceMetrics] + val metricSetRef = classOf[YarnShuffleServiceMetrics].getDeclaredField("metricSet") + metricSetRef.setAccessible(true) + val metrics = metricSetRef.get(metricsSource).asInstanceOf[MetricSet].getMetrics + + assert(metrics.keySet().asScala == Set( + "blockTransferRateBytes", + "numActiveConnections", + "numCaughtExceptions", + "numRegisteredConnections", + "openBlockRequestLatencyMillis", + "registeredExecutorsSize", + "registerExecutorRequestLatencyMillis", + 
"shuffle-server.usedDirectMemory", + "shuffle-server.usedHeapMemory" + )) + } + } diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 81f2fd40a706f..e563f7bff1667 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -23,6 +23,7 @@ # # SPARK_CONF_DIR Alternate conf dir. Default is ${SPARK_HOME}/conf. # SPARK_LOG_DIR Where log files are stored. ${SPARK_HOME}/logs by default. +# SPARK_LOG_MAX_FILES Max log files of Spark daemons can rotate to. Default is 5. # SPARK_MASTER host:path where spark code should be rsync'd from # SPARK_PID_DIR The pid files are stored. /tmp by default. # SPARK_IDENT_STRING A string representing this instance of spark. $USER by default @@ -74,10 +75,16 @@ shift spark_rotate_log () { log=$1; - num=5; - if [ -n "$2" ]; then - num=$2 + + if [[ -z ${SPARK_LOG_MAX_FILES} ]]; then + num=5 + elif [[ ${SPARK_LOG_MAX_FILES} -gt 0 ]]; then + num=${SPARK_LOG_MAX_FILES} + else + echo "Error: SPARK_LOG_MAX_FILES must be a positive number, but got ${SPARK_LOG_MAX_FILES}" + exit -1 fi + if [ -f "$log" ]; then # rotate logs while [ $num -gt 1 ]; do prev=`expr $num - 1` diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index e49bc07f18c59..b03e6372a8eae 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -29,12 +29,6 @@ grammar SqlBase; */ public boolean legacy_exponent_literal_as_decimal_enabled = false; - /** - * When false, CREATE TABLE syntax without a provider will use - * the value of spark.sql.sources.default as its provider. - */ - public boolean legacy_create_hive_table_by_default_enabled = false; - /** * Verify whether current token is a valid decimal token (which contains dot). * Returns true if the character that follows the token is not a digit or letter or underscore. @@ -123,12 +117,7 @@ statement (RESTRICT | CASCADE)? #dropNamespace | SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)? (LIKE? pattern=STRING)? #showNamespaces - | {!legacy_create_hive_table_by_default_enabled}? - createTableHeader ('(' colTypeList ')')? tableProvider? - createTableClauses - (AS? query)? #createTable - | {legacy_create_hive_table_by_default_enabled}? - createTableHeader ('(' colTypeList ')')? tableProvider + | createTableHeader ('(' colTypeList ')')? tableProvider createTableClauses (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? @@ -1815,7 +1804,7 @@ fragment LETTER ; SIMPLE_COMMENT - : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? 
-> channel(HIDDEN) ; BRACKETED_COMMENT diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 4487a2d7f4358..5b17f1d65f1bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -571,14 +571,10 @@ trait Row extends Serializable { case (s: String, _) => JString(s) case (b: Array[Byte], BinaryType) => JString(Base64.getEncoder.encodeToString(b)) - case (d: LocalDate, _) => - JString(dateFormatter.format(DateTimeUtils.localDateToDays(d))) - case (d: Date, _) => - JString(dateFormatter.format(DateTimeUtils.fromJavaDate(d))) - case (i: Instant, _) => - JString(timestampFormatter.format(DateTimeUtils.instantToMicros(i))) - case (t: Timestamp, _) => - JString(timestampFormatter.format(DateTimeUtils.fromJavaTimestamp(t))) + case (d: LocalDate, _) => JString(dateFormatter.format(d)) + case (d: Date, _) => JString(dateFormatter.format(d)) + case (i: Instant, _) => JString(timestampFormatter.format(i)) + case (t: Timestamp, _) => JString(timestampFormatter.format(t)) case (i: CalendarInterval, _) => JString(i.toString) case (a: Array[_], ArrayType(elementType, _)) => iteratorToJsonArray(a.iterator, elementType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index e55c25c4b0c54..701e4e3483c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -161,6 +161,10 @@ object DeserializerBuildHelper { case _: StructType => expr case _: ArrayType => expr case _: MapType => expr + case _: DecimalType => + // For Scala/Java `BigDecimal`, we accept decimal types of any valid precision/scale. + // Here we use the `DecimalType` object to indicate it. + UpCast(expr, DecimalType, walkedTypePath.getPaths) case _ => UpCast(expr, expected, walkedTypePath.getPaths) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 29796abb2d4fe..91ac0a6f404cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -557,18 +557,13 @@ class Analyzer( } } - /* - * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets. - */ - private def constructAggregate( + private def getFinalGroupByExpressions( selectedGroupByExprs: Seq[Seq[Expression]], - groupByExprs: Seq[Expression], - aggregationExprs: Seq[NamedExpression], - child: LogicalPlan): LogicalPlan = { + groupByExprs: Seq[Expression]): Seq[Expression] = { // In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and // can be null. In such case, we derive the groupByExprs from the user supplied values for // grouping sets. - val finalGroupByExpressions = if (groupByExprs == Nil) { + if (groupByExprs == Nil) { selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) => // Only unique expressions are included in the group by expressions and is determined // based on their semantic equality. Example. 
grouping sets ((a * b), (b * a)) results @@ -582,6 +577,17 @@ class Analyzer( } else { groupByExprs } + } + + /* + * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets. + */ + private def constructAggregate( + selectedGroupByExprs: Seq[Seq[Expression]], + groupByExprs: Seq[Expression], + aggregationExprs: Seq[NamedExpression], + child: LogicalPlan): LogicalPlan = { + val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs) if (finalGroupByExpressions.size > GroupingID.dataType.defaultSize * 8) { throw new AnalysisException( @@ -619,8 +625,70 @@ class Analyzer( } } - // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + private def tryResolveHavingCondition(h: UnresolvedHaving): LogicalPlan = { + val aggForResolving = h.child match { + // For CUBE/ROLLUP expressions, to avoid resolving repeatedly, here we delete them from + // groupingExpressions for condition resolving. + case a @ Aggregate(Seq(c @ Cube(groupByExprs)), _, _) => + a.copy(groupingExpressions = groupByExprs) + case a @ Aggregate(Seq(r @ Rollup(groupByExprs)), _, _) => + a.copy(groupingExpressions = groupByExprs) + case g: GroupingSets => + Aggregate( + getFinalGroupByExpressions(g.selectedGroupByExprs, g.groupByExprs), + g.aggregations, g.child) + } + // Try resolving the condition of the filter as though it is in the aggregate clause + val resolvedInfo = + ResolveAggregateFunctions.resolveFilterCondInAggregate(h.havingCondition, aggForResolving) + + // Push the aggregate expressions into the aggregate (if any). + if (resolvedInfo.nonEmpty) { + val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get + val newChild = h.child match { + case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => + constructAggregate( + cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child) + case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => + constructAggregate( + rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child) + case x: GroupingSets => + constructAggregate( + x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child) + } + + // Since the exprId of extraAggExprs will be changed in the constructed aggregate, and the + // aggregateExpressions keeps the input order. So here we build an exprMap to resolve the + // condition again. + val exprMap = extraAggExprs.zip( + newChild.asInstanceOf[Aggregate].aggregateExpressions.takeRight( + extraAggExprs.length)).toMap + val newCond = resolvedHavingCond.transform { + case ne: NamedExpression if exprMap.contains(ne) => exprMap(ne) + } + Project(newChild.output.dropRight(extraAggExprs.length), + Filter(newCond, newChild)) + } else { + h + } + } + + // This require transformDown to resolve having condition when generating aggregate node for + // CUBE/ROLLUP/GROUPING SETS. This also replace grouping()/grouping_id() in resolved + // Filter/Sort. 
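+    // Editor's sketch, not part of the patch: with the UnresolvedHaving handling in this
+    // rule, a HAVING clause that references an aggregate over a CUBE/ROLLUP/GROUPING SETS
+    // child is expected to resolve, e.g. (illustrative table and column names):
+    //
+    //   SELECT k, SUM(v) AS total
+    //   FROM t
+    //   GROUP BY k WITH CUBE
+    //   HAVING SUM(v) > 10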
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + case h @ UnresolvedHaving( + _, agg @ Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, _)) + if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) => + tryResolveHavingCondition(h) + case h @ UnresolvedHaving( + _, agg @ Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, _)) + if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) => + tryResolveHavingCondition(h) + case h @ UnresolvedHaving(_, g: GroupingSets) + if g.childrenResolved && g.expressions.forall(_.resolved) => + tryResolveHavingCondition(h) + case a if !a.childrenResolved => a // be sure all of the children are resolved. // Ensure group by expressions and aggregate expressions have been resolved. @@ -971,7 +1039,7 @@ class Analyzer( private def lookupRelation(identifier: Seq[String]): Option[LogicalPlan] = { expandRelationName(identifier) match { case SessionCatalogAndIdentifier(catalog, ident) => - def loaded = CatalogV2Util.loadTable(catalog, ident).map { + lazy val loaded = CatalogV2Util.loadTable(catalog, ident).map { case v1Table: V1Table => v1SessionCatalog.getRelation(v1Table.v1Table) case table => @@ -980,7 +1048,12 @@ class Analyzer( DataSourceV2Relation.create(table, Some(catalog), Some(ident))) } val key = catalog.name +: ident.namespace :+ ident.name - Option(AnalysisContext.get.relationCache.getOrElseUpdate(key, loaded.orNull)) + AnalysisContext.get.relationCache.get(key).map(_.transform { + case multi: MultiInstanceRelation => multi.newInstance() + }).orElse { + loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) + loaded + } case _ => None } } @@ -1428,7 +1501,7 @@ class Analyzer( } // Skip the having clause here, this will be handled in ResolveAggregateFunctions. - case h: AggregateWithHaving => h + case h: UnresolvedHaving => h case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}") @@ -2073,7 +2146,7 @@ class Analyzer( // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly // resolve the having condition expression, here we skip resolving it in ResolveReferences // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519. - case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved => + case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved => resolveHaving(Filter(cond, agg), agg) case f @ Filter(_, agg: Aggregate) if agg.resolved => @@ -2149,13 +2222,13 @@ class Analyzer( condition.find(_.isInstanceOf[AggregateExpression]).isDefined } - def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = { - // Try resolving the condition of the filter as though it is in the aggregate clause + def resolveFilterCondInAggregate( + filterCond: Expression, agg: Aggregate): Option[(Seq[NamedExpression], Expression)] = { try { val aggregatedCondition = Aggregate( agg.groupingExpressions, - Alias(filter.condition, "havingCondition")() :: Nil, + Alias(filterCond, "havingCondition")() :: Nil, agg.child) val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = @@ -2187,22 +2260,33 @@ class Analyzer( alias.toAttribute } } - - // Push the aggregate expressions into the aggregate (if any). 
if (aggregateExpressions.nonEmpty) { - Project(agg.output, - Filter(transformedAggregateFilter, - agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions))) + Some(aggregateExpressions, transformedAggregateFilter) } else { - filter + None } } else { - filter + None } } catch { - // Attempting to resolve in the aggregate can result in ambiguity. When this happens, - // just return the original plan. - case ae: AnalysisException => filter + // Attempting to resolve in the aggregate can result in ambiguity. When this happens, + // just return None and the caller side will return the original plan. + case ae: AnalysisException => None + } + } + + def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = { + // Try resolving the condition of the filter as though it is in the aggregate clause + val resolvedInfo = resolveFilterCondInAggregate(filter.condition, agg) + + // Push the aggregate expressions into the aggregate (if any). + if (resolvedInfo.nonEmpty) { + val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get + Project(agg.output, + Filter(resolvedHavingCond, + agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions))) + } else { + filter } } } @@ -2631,12 +2715,12 @@ class Analyzer( case Filter(condition, _) if hasWindowFunction(condition) => failAnalysis("It is not allowed to use window functions inside WHERE clause") - case AggregateWithHaving(condition, _) if hasWindowFunction(condition) => + case UnresolvedHaving(condition, _) if hasWindowFunction(condition) => failAnalysis("It is not allowed to use window functions inside HAVING clause") // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. - case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) + case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) if child.resolved && hasWindowFunction(aggregateExprs) && a.expressions.forall(_.resolved) => @@ -3095,15 +3179,29 @@ class Analyzer( case p => p transformExpressions { case u @ UpCast(child, _, _) if !child.resolved => u - case UpCast(child, dt: AtomicType, _) + case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] => + throw new AnalysisException( + s"UpCast only support DecimalType as AbstractDataType yet, but got: $target") + + case UpCast(child, target, walkedTypePath) if target == DecimalType + && child.dataType.isInstanceOf[DecimalType] => + assert(walkedTypePath.nonEmpty, + "object DecimalType should only be used inside ExpressionEncoder") + + // SPARK-31750: if we want to upcast to the general decimal type, and the `child` is + // already decimal type, we can remove the `Upcast` and accept any precision/scale. + // This can happen for cases like `spark.read.parquet("/tmp/file").as[BigDecimal]`. 
+ child + + case UpCast(child, target: AtomicType, _) if SQLConf.get.getConf(SQLConf.LEGACY_LOOSE_UPCAST) && child.dataType == StringType => - Cast(child, dt.asNullable) + Cast(child, target.asNullable) - case UpCast(child, dataType, walkedTypePath) if !Cast.canUpCast(child.dataType, dataType) => - fail(child, dataType, walkedTypePath) + case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) => + fail(child, u.dataType, walkedTypePath) - case UpCast(child, dataType, _) => Cast(child, dataType.asNullable) + case u @ UpCast(child, _, _) => Cast(child, u.dataType.asNullable) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c7d0eba0964cc..e2559d4c07297 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -253,7 +253,7 @@ object FunctionRegistry { expression[Log2]("log2"), expression[Log]("ln"), expression[Remainder]("mod", true), - expression[UnaryMinus]("negative"), + expression[UnaryMinus]("negative", true), expression[Pi]("pi"), expression[Pmod]("pmod"), expression[UnaryPositive]("positive"), @@ -339,7 +339,7 @@ object FunctionRegistry { expression[GetJsonObject]("get_json_object"), expression[InitCap]("initcap"), expression[StringInstr]("instr"), - expression[Lower]("lcase"), + expression[Lower]("lcase", true), expression[Length]("length"), expression[Levenshtein]("levenshtein"), expression[Like]("like"), @@ -350,7 +350,7 @@ object FunctionRegistry { expression[StringTrimLeft]("ltrim"), expression[JsonTuple]("json_tuple"), expression[ParseUrl]("parse_url"), - expression[StringLocate]("position"), + expression[StringLocate]("position", true), expression[FormatString]("printf", true), expression[RegExpExtract]("regexp_extract"), expression[RegExpReplace]("regexp_replace"), @@ -424,6 +424,9 @@ object FunctionRegistry { expression[MakeInterval]("make_interval"), expression[DatePart]("date_part"), expression[Extract]("extract"), + expression[SecondsToTimestamp]("timestamp_seconds"), + expression[MillisToTimestamp]("timestamp_millis"), + expression[MicrosToTimestamp]("timestamp_micros"), // collection functions expression[CreateArray]("array"), @@ -488,6 +491,7 @@ object FunctionRegistry { expression[InputFileBlockLength]("input_file_block_length"), expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), expression[CurrentDatabase]("current_database"), + expression[CurrentCatalog]("current_catalog"), expression[CallMethodViaReflection]("reflect"), expression[CallMethodViaReflection]("java_method", true), expression[SparkVersion]("version"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index d970bf466fb81..3484108a5503f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -61,6 +61,7 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: + IntegralDivision :: ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: @@ -684,6 +685,23 @@ object TypeCoercion { } } + /** + * The DIV operator always returns long-type value. 
+ * This rule cast the integral inputs to long type, to avoid overflow during calculation. + */ + object IntegralDivision extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case e if !e.childrenResolved => e + case d @ IntegralDivide(left, right) => + IntegralDivide(mayCastToLong(left), mayCastToLong(right)) + } + + private def mayCastToLong(expr: Expression): Expression = expr.dataType match { + case _: ByteType | _: ShortType | _: IntegerType => Cast(expr, LongType) + case _ => expr + } + } + /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 806cdeb95cca4..b28be042c43f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -540,11 +540,12 @@ case class UnresolvedOrdinal(ordinal: Int) } /** - * Represents unresolved aggregate with having clause, it is turned by the analyzer into a Filter. + * Represents unresolved having clause, the child for it can be Aggregate, GroupingSets, Rollup + * and Cube. It is turned by the analyzer into a Filter. */ -case class AggregateWithHaving( +case class UnresolvedHaving( havingCondition: Expression, - child: Aggregate) + child: LogicalPlan) extends UnaryNode { override lazy val resolved: Boolean = false override def output: Seq[Attribute] = child.output diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index d02776b5d86f8..4e63ee7428d72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -525,7 +525,7 @@ object CatalogColumnStat extends Logging { TimestampFormatter( format = "yyyy-MM-dd HH:mm:ss.SSSSSS", zoneId = ZoneOffset.UTC, - needVarLengthSecondFraction = isParsing) + isParsing = isParsing) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index a7c243537acb7..f0df18da8eed6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -35,7 +35,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { options.zoneId, options.locale, legacyFormat = FAST_DATE_FORMAT, - needVarLengthSecondFraction = true) + isParsing = true) private val decimalParser = if (options.locale == Locale.US) { // Special handling the default locale for backward compatibility diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index 4990da2bf3797..a3ee129cd6d64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -47,12 +47,13 @@ class UnivocityGenerator( options.zoneId, options.locale, legacyFormat = FAST_DATE_FORMAT, - 
needVarLengthSecondFraction = false) + isParsing = false) private val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, options.locale, - legacyFormat = FAST_DATE_FORMAT) + legacyFormat = FAST_DATE_FORMAT, + isParsing = false) private def makeConverter(dataType: DataType): ValueConverter = dataType match { case DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 8e87a82769471..3898eca79478e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -85,17 +85,18 @@ class UnivocityParser( // We preallocate it avoid unnecessary allocations. private val noRows = None - private val timestampFormatter = TimestampFormatter( + private lazy val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, options.locale, legacyFormat = FAST_DATE_FORMAT, - needVarLengthSecondFraction = true) - private val dateFormatter = DateFormatter( + isParsing = true) + private lazy val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, options.locale, - legacyFormat = FAST_DATE_FORMAT) + legacyFormat = FAST_DATE_FORMAT, + isParsing = true) private val csvFilters = new CSVFilters(filters, requiredSchema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index f135f50493ed8..26f5bee72092c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} -import java.time.LocalDate +import java.time.{Instant, LocalDate} import scala.language.implicitConversions @@ -152,6 +152,7 @@ package object dsl { implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d) implicit def decimalToLiteral(d: Decimal): Literal = Literal(d) implicit def timestampToLiteral(t: Timestamp): Literal = Literal(t) + implicit def instantToLiteral(i: Instant): Literal = Literal(i) implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a) implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute = @@ -368,7 +369,7 @@ package object dsl { groupingExprs: Expression*)( aggregateExprs: Expression*)( havingCondition: Expression): LogicalPlan = { - AggregateWithHaving(havingCondition, + UnresolvedHaving(havingCondition, groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate]) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index fa615d71a61a0..ef70915a5c969 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -511,7 +511,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case DateType => buildCast[Int](_, d => null) case TimestampType if ansiEnabled => - buildCast[Long](_, t => LongExactNumeric.toInt(timestampToLong(t))) + buildCast[Long](_, t => { + val longValue = timestampToLong(t) + if (longValue == longValue.toInt) { + longValue.toInt + } else { + 
throw new ArithmeticException(s"Casting $t to int causes overflow") + } + }) case TimestampType => buildCast[Long](_, t => timestampToLong(t).toInt) case x: NumericType if ansiEnabled => @@ -1735,8 +1742,16 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St /** * Cast the child expression to the target data type, but will throw error if the cast might * truncate, e.g. long -> int, timestamp -> data. + * + * Note: `target` is `AbstractDataType`, so that we can put `object DecimalType`, which means + * we accept `DecimalType` with any valid precision/scale. */ -case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String] = Nil) +case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: Seq[String] = Nil) extends UnaryExpression with Unevaluable { override lazy val resolved = false + + def dataType: DataType = target match { + case DecimalType => DecimalType.SYSTEM_DEFAULT + case _ => target.asInstanceOf[DataType] + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f29ece2e03b08..18cc648e57d71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -323,6 +324,19 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable { // two `RuntimeReplaceable` are considered to be semantically equal if their "child" expressions // are semantically equal. override lazy val canonicalized: Expression = child.canonicalized + + /** + * Only used to generate SQL representation of this expression. + * + * Implementations should override this with original parameters + */ + def exprsReplaced: Seq[Expression] + + override def sql: String = mkString(exprsReplaced.map(_.sql)) + + def mkString(childrenString: Seq[String]): String = { + prettyName + childrenString.mkString("(", ", ", ")") + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index caacb71814f17..f7fe467cea830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -63,6 +63,7 @@ case class TimeWindow( override def dataType: DataType = new StructType() .add(StructField("start", TimestampType)) .add(StructField("end", TimestampType)) + override def prettyName: String = "window" // This expression is replaced in the analyzer. 
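// --- Editor's aside: illustrative sketch only, not part of the diff. ---
// It mirrors the new default `RuntimeReplaceable.sql` added in Expression.scala above:
// the SQL string is now assembled from `exprsReplaced` via `mkString`, so replaceable
// expressions no longer need to hand-roll their own `sql`. The object and value names
// below are hypothetical stand-ins, not Spark classes.
object RuntimeReplaceableSqlSketch {
  // Default composition: prettyName(child1, child2, ...)
  def mkString(prettyName: String, childrenSQL: Seq[String]): String =
    prettyName + childrenSQL.mkString("(", ", ", ")")

  def main(args: Array[String]): Unit = {
    // A to_date-style call with an optional format argument present
    println(mkString("to_date", Seq("`t`", "'yyyy-MM-dd'")))                // to_date(`t`, 'yyyy-MM-dd')
    // Extract overrides mkString later in this diff to join with " FROM " instead of ", "
    println("extract" + Seq("YEAR", "`ts`").mkString("(", " FROM ", ")"))   // extract(YEAR FROM `ts`)
  }
}
// --- End of editor's aside; the diff continues below. ---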
override lazy val resolved = false @@ -143,7 +144,7 @@ object TimeWindow { case class PreciseTimestampConversion( child: Expression, fromType: DataType, - toType: DataType) extends UnaryExpression with ExpectsInputTypes { + toType: DataType) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(fromType) override def dataType: DataType = toType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala index 2e202240923c3..6d3d3dafe16e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala @@ -31,7 +31,6 @@ abstract class MaxMinBy extends DeclarativeAggregate { def valueExpr: Expression def orderingExpr: Expression - protected def funcName: String // The predicate compares two ordering values. protected def predicate(oldExpr: Expression, newExpr: Expression): Expression // The arithmetic expression returns greatest/least value of all parameters. @@ -46,7 +45,7 @@ abstract class MaxMinBy extends DeclarativeAggregate { override def dataType: DataType = valueExpr.dataType override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(orderingExpr.dataType, s"function $funcName") + TypeUtils.checkForOrderingExpr(orderingExpr.dataType, s"function $prettyName") // The attributes used to keep extremum (max or min) and associated aggregated values. private lazy val extremumOrdering = @@ -101,7 +100,8 @@ abstract class MaxMinBy extends DeclarativeAggregate { group = "agg_funcs", since = "3.0.0") case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy { - override protected def funcName: String = "max_by" + + override def prettyName: String = "max_by" override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression = oldExpr > newExpr @@ -120,7 +120,8 @@ case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin group = "agg_funcs", since = "3.0.0") case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy { - override protected def funcName: String = "min_by" + + override def prettyName: String = "min_by" override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression = oldExpr < newExpr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index d2daaac72fc85..6e850267100fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -62,38 +62,74 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)() + private lazy val zero = Literal.default(sumDataType) - override lazy val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributes = resultType match { + case _: DecimalType => sum :: isEmpty :: Nil + case _ => sum :: Nil + } - 
override lazy val initialValues: Seq[Expression] = Seq( - /* sum = */ Literal.create(null, sumDataType) - ) + override lazy val initialValues: Seq[Expression] = resultType match { + case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType)) + case _ => Seq(Literal(null, resultType)) + } override lazy val updateExpressions: Seq[Expression] = { if (child.nullable) { - Seq( - /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - ) + val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + resultType match { + case _: DecimalType => + Seq(updateSumExpr, isEmpty && child.isNull) + case _ => Seq(updateSumExpr) + } } else { - Seq( - /* sum = */ - coalesce(sum, zero) + child.cast(sumDataType) - ) + val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType) + resultType match { + case _: DecimalType => + Seq(updateSumExpr, Literal(false, BooleanType)) + case _ => Seq(updateSumExpr) + } } } + /** + * For decimal type: + * If isEmpty is false and sum is null, then it means we have had an overflow. + * + * The update of the sum works as follows: + * Check if either portion of left.sum or right.sum has overflowed. + * If it has, then the sum value will remain null. + * If it did not overflow, then add sum.left and sum.right. + * + * isEmpty: Set to false if either one of the left or right is set to false. This + * means we have seen at least one value that was not null. + */ override lazy val mergeExpressions: Seq[Expression] = { - Seq( - /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) - ) + val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + resultType match { + case _: DecimalType => + val inputOverflow = !isEmpty.right && sum.right.isNull + val bufferOverflow = !isEmpty.left && sum.left.isNull + Seq( + If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr), + isEmpty.left && isEmpty.right) + case _ => Seq(mergeSumExpr) + } } + /** + * If isEmpty is true, then it means there were no values to begin with or all the values + * were null, so the result will be null. + * If isEmpty is false, then if sum is null that means an overflow has happened. + * So now, if ANSI mode is enabled, throw an exception; if not, return null. + * If sum is not null, then return the sum.
+ */ override lazy val evaluateExpression: Expression = resultType match { - case d: DecimalType => CheckOverflow(sum, d, !SQLConf.get.ansiEnabled) + case d: DecimalType => + If(isEmpty, Literal.create(null, sumDataType), + CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)) case _ => sum } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 6a64819aabb48..7c521838447d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -86,7 +86,12 @@ case class UnaryMinus(child: Expression) extends UnaryExpression case _ => numeric.negate(input) } - override def sql: String = s"(- ${child.sql})" + override def sql: String = { + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("-") match { + case "-" => s"(- ${child.sql})" + case funcName => s"$funcName(${child.sql})" + } + } } @ExpressionDescription( @@ -407,7 +412,7 @@ case class IntegralDivide( left: Expression, right: Expression) extends DivModLike { - override def inputType: AbstractDataType = TypeCollection(IntegralType, DecimalType) + override def inputType: AbstractDataType = TypeCollection(LongType, DecimalType) override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 72a8f7e99729b..342b14eaa3390 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -127,7 +127,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme > SELECT _FUNC_ 0; -1 """) -case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class BitwiseNot(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) @@ -164,7 +165,8 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp 0 """, since = "3.0.0") -case class BitwiseCount(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class BitwiseCount(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegralType, BooleanType)) @@ -172,6 +174,8 @@ case class BitwiseCount(child: Expression) extends UnaryExpression with ExpectsI override def toString: String = s"bit_count($child)" + override def prettyName: String = "bit_count" + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.dataType match { case BooleanType => defineCodeGen(ctx, ev, c => s"if ($c) 1 else 0") case _ => defineCodeGen(ctx, ev, c => s"java.lang.Long.bitCount($c)") @@ -184,6 +188,4 @@ case class BitwiseCount(child: Expression) extends UnaryExpression with ExpectsI case IntegerType => java.lang.Long.bitCount(input.asInstanceOf[Int]) case LongType => java.lang.Long.bitCount(input.asInstanceOf[Long]) } - - override def sql: String = s"bit_count(${child.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1cc7836e93d35..817dd948f1a6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -27,8 +27,9 @@ import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} -import org.codehaus.commons.compiler.CompileException -import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, InternalCompilerException, SimpleCompiler} +import org.codehaus.commons.compiler.{CompileException, InternalCompilerException} +import org.codehaus.commons.compiler.util.reflect.ByteArrayClassLoader +import org.codehaus.janino.{ClassBodyEvaluator, SimpleCompiler} import org.codehaus.janino.util.ClassFile import org.apache.spark.{TaskContext, TaskKilledException} @@ -1419,9 +1420,10 @@ object CodeGenerator extends Logging { private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): ByteCodeStats = { // First retrieve the generated classes. val classes = { - val resultField = classOf[SimpleCompiler].getDeclaredField("result") - resultField.setAccessible(true) - val loader = resultField.get(evaluator).asInstanceOf[ByteArrayClassLoader] + val scField = classOf[ClassBodyEvaluator].getDeclaredField("sc") + scField.setAccessible(true) + val compiler = scField.get(evaluator).asInstanceOf[SimpleCompiler] + val loader = compiler.getClassLoader.asInstanceOf[ByteArrayClassLoader] val classesField = loader.getClass.getDeclaredField("classes") classesField.setAccessible(true) classesField.get(loader).asInstanceOf[JavaMap[String, Array[Byte]]].asScala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4fd68dcfe5156..b32e9ee05f1ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -141,7 +141,7 @@ object Size { """, group = "map_funcs") case class MapKeys(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -332,7 +332,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI """, group = "map_funcs") case class MapValues(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -361,7 +361,8 @@ case class MapValues(child: Expression) """, group = "map_funcs", since = "3.0.0") -case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class MapEntries(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -649,7 +650,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres """, group = "map_funcs", since = "2.4.0") -case class MapFromEntries(child: Expression) extends UnaryExpression { +case class MapFromEntries(child: 
Expression) extends UnaryExpression with NullIntolerant { @transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { @@ -873,7 +874,7 @@ object ArraySortLike { group = "array_funcs") // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ArraySortLike { + extends BinaryExpression with ArraySortLike with NullIntolerant { def this(e: Expression) = this(e, Literal(true)) @@ -1017,7 +1018,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) Reverse logic for arrays is available since 2.4.0. """ ) -case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Reverse(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { // Input types are utilized by type coercion in ImplicitTypeCasts. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) @@ -1086,7 +1088,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI """, group = "array_funcs") case class ArrayContains(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BooleanType @@ -1185,7 +1187,7 @@ case class ArrayContains(left: Expression, right: Expression) since = "2.4.0") // scalastyle:off line.size.limit case class ArraysOverlap(left: Expression, right: Expression) - extends BinaryArrayExpressionWithImplicitCast { + extends BinaryArrayExpressionWithImplicitCast with NullIntolerant { override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => @@ -1410,7 +1412,7 @@ case class ArraysOverlap(left: Expression, right: Expression) since = "2.4.0") // scalastyle:on line.size.limit case class Slice(x: Expression, start: Expression, length: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = x.dataType @@ -1688,7 +1690,8 @@ case class ArrayJoin( """, group = "array_funcs", since = "2.4.0") -case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class ArrayMin(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def nullable: Boolean = true @@ -1755,7 +1758,8 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast """, group = "array_funcs", since = "2.4.0") -case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class ArrayMax(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def nullable: Boolean = true @@ -1831,7 +1835,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast group = "array_funcs", since = "2.4.0") case class ArrayPosition(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -1909,7 +1913,7 @@ case class ArrayPosition(left: Expression, right: Expression) """, since = "2.4.0") case class ElementAt(left: Expression, right: 
Expression) - extends GetMapValueUtil with GetArrayItemUtil { + extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant { @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType @@ -2245,7 +2249,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio """, group = "array_funcs", since = "2.4.0") -case class Flatten(child: Expression) extends UnaryExpression { +case class Flatten(child: Expression) extends UnaryExpression with NullIntolerant { private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] @@ -2884,7 +2888,7 @@ case class ArrayRepeat(left: Expression, right: Expression) group = "array_funcs", since = "2.4.0") case class ArrayRemove(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = left.dataType @@ -3081,7 +3085,7 @@ trait ArraySetLike { group = "array_funcs", since = "2.4.0") case class ArrayDistinct(child: Expression) - extends UnaryExpression with ArraySetLike with ExpectsInputTypes { + extends UnaryExpression with ArraySetLike with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) @@ -3219,7 +3223,8 @@ case class ArrayDistinct(child: Expression) /** * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ -trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike { +trait ArrayBinaryLike + extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant { override protected def dt: DataType = dataType override protected def et: DataType = elementType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 858c91a4d8e86..1b4a705e804f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ @@ -255,7 +255,7 @@ object CreateMap { {1.0:"2",3.0:"4"} """, since = "2.4.0") case class MapFromArrays(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) @@ -311,7 +311,12 @@ case object NamePlaceholder extends LeafExpression with Unevaluable { /** * Returns a Row containing the evaluation of all children expressions. */ -object CreateStruct extends FunctionBuilder { +object CreateStruct { + /** + * Returns a named struct with generated names or using the names when available. + * It should not be used for `struct` expressions or functions explicitly called + * by users. 
+ */ def apply(children: Seq[Expression]): CreateNamedStruct = { CreateNamedStruct(children.zipWithIndex.flatMap { case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) @@ -320,12 +325,23 @@ object CreateStruct extends FunctionBuilder { }) } + /** + * Returns a named struct with a pretty SQL name. It will show the pretty SQL string + * in its output column name as if `struct(...)` was called. Should be + * used for `struct` expressions or functions explicitly called by users. + */ + def create(children: Seq[Expression]): CreateNamedStruct = { + val expr = CreateStruct(children) + expr.setTagValue(FUNC_ALIAS, "struct") + expr + } + /** * Entry to use in the function registry. */ val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = { val info: ExpressionInfo = new ExpressionInfo( - "org.apache.spark.sql.catalyst.expressions.NamedStruct", + classOf[CreateNamedStruct].getCanonicalName, null, "struct", "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.", @@ -335,7 +351,7 @@ object CreateStruct extends FunctionBuilder { "", "", "") - ("struct", (info, this)) + ("struct", (info, this.create)) } } @@ -433,7 +449,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { """.stripMargin, isNull = FalseLiteral) } - override def prettyName: String = "named_struct" + // There is an alias set at `CreateStruct.create`. If there is an alias, + // this is the struct function explicitly called by a user and we should + // respect it in the SQL string as `struct(...)`. + override def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("named_struct") + + override def sql: String = getTagValue(FUNC_ALIAS).map { alias => + val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ") + s"$alias($childrenSQL)" + }.getOrElse(super.sql) } /** @@ -452,7 +476,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { since = "2.0.1") // scalastyle:on line.size.limit case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression) - extends TernaryExpression with ExpectsInputTypes { + extends TernaryExpression with ExpectsInputTypes with NullIntolerant { def this(child: Expression, pairDelim: Expression) = { this(child, pairDelim, Literal(":")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 5140db90c5954..f9ccf3c8c811f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -211,7 +211,8 @@ case class StructsToCsv( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes + with NullIntolerant { override def nullable: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7dfa5fa0bf841..c5ead9412a438 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, Tim import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -197,7 +198,7 @@ case class CurrentBatchTimestamp( group = "datetime_funcs", since = "1.5.0") case class DateAdd(startDate: Expression, days: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = days @@ -233,7 +234,7 @@ case class DateAdd(startDate: Expression, days: Expression) group = "datetime_funcs", since = "1.5.0") case class DateSub(startDate: Expression, days: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = days @@ -265,7 +266,8 @@ case class DateSub(startDate: Expression, days: Expression) group = "datetime_funcs", since = "1.5.0") case class Hour(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -297,7 +299,8 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None) group = "datetime_funcs", since = "1.5.0") case class Minute(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -329,7 +332,8 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None) group = "datetime_funcs", since = "1.5.0") case class Second(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -352,7 +356,8 @@ case class Second(child: Expression, timeZoneId: Option[String] = None) } case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -384,7 +389,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = No """, group = "datetime_funcs", since = "1.5.0") -case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfYear(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { 
override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -400,6 +406,83 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas } } +abstract class NumberToTimestampBase extends UnaryExpression + with ExpectsInputTypes with NullIntolerant { + + protected def upScaleFactor: Long + + override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) + + override def dataType: DataType = TimestampType + + override def nullSafeEval(input: Any): Any = { + Math.multiplyExact(input.asInstanceOf[Number].longValue(), upScaleFactor) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (upScaleFactor == 1) { + defineCodeGen(ctx, ev, c => c) + } else { + defineCodeGen(ctx, ev, c => s"java.lang.Math.multiplyExact($c, ${upScaleFactor}L)") + } + } +} + +@ExpressionDescription( + usage = "_FUNC_(seconds) - Creates timestamp from the number of seconds since UTC epoch.", + examples = """ + Examples: + > SELECT _FUNC_(1230219000); + 2008-12-25 07:30:00 + """, + group = "datetime_funcs", + since = "3.1.0") +case class SecondsToTimestamp(child: Expression) + extends NumberToTimestampBase { + + override def upScaleFactor: Long = MICROS_PER_SECOND + + override def prettyName: String = "timestamp_seconds" +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(milliseconds) - Creates timestamp from the number of milliseconds since UTC epoch.", + examples = """ + Examples: + > SELECT _FUNC_(1230219000123); + 2008-12-25 07:30:00.123 + """, + group = "datetime_funcs", + since = "3.1.0") +// scalastyle:on line.size.limit +case class MillisToTimestamp(child: Expression) + extends NumberToTimestampBase { + + override def upScaleFactor: Long = MICROS_PER_MILLIS + + override def prettyName: String = "timestamp_millis" +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(microseconds) - Creates timestamp from the number of microseconds since UTC epoch.", + examples = """ + Examples: + > SELECT _FUNC_(1230219000123123); + 2008-12-25 07:30:00.123123 + """, + group = "datetime_funcs", + since = "3.1.0") +// scalastyle:on line.size.limit +case class MicrosToTimestamp(child: Expression) + extends NumberToTimestampBase { + + override def upScaleFactor: Long = 1L + + override def prettyName: String = "timestamp_micros" +} + @ExpressionDescription( usage = "_FUNC_(date) - Returns the year component of the date/timestamp.", examples = """ @@ -409,7 +492,8 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas """, group = "datetime_funcs", since = "1.5.0") -case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Year(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -425,7 +509,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu } } -case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class YearOfWeek(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -450,7 +535,8 @@ case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCa """, group = "datetime_funcs", since = "1.5.0") -case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Quarter(child: Expression) + 
extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -475,7 +561,8 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI """, group = "datetime_funcs", since = "1.5.0") -case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Month(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -499,7 +586,8 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp 30 """, since = "1.5.0") -case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfMonth(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -569,7 +657,7 @@ case class WeekDay(child: Expression) extends DayWeek { } } -abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { +abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -587,7 +675,8 @@ abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { group = "datetime_funcs", since = "1.5.0") // scalastyle:on line.size.limit -case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class WeekOfYear(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -626,7 +715,8 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa since = "1.5.0") // scalastyle:on line.size.limit case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(left: Expression, right: Expression) = this(left, right, None) @@ -644,7 +734,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti format.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT, - needVarLengthSecondFraction = false) + isParsing = false) } } else None } @@ -655,7 +745,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti format.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT, - needVarLengthSecondFraction = false) + isParsing = false) } else { formatter.get } @@ -800,8 +890,9 @@ abstract class ToTimestamp constFormat.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT, - needVarLengthSecondFraction = true) + isParsing = true) } catch { + case e: SparkUpgradeException => throw e case NonFatal(_) => null } @@ -838,7 +929,7 @@ abstract class ToTimestamp formatString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT, - needVarLengthSecondFraction = true) + isParsing = true) .parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor } catch { case e: SparkUpgradeException => throw e @@ -981,8 +1072,9 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ constFormat.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT, - needVarLengthSecondFraction = false) + isParsing = false) } catch { + case e: SparkUpgradeException => throw e case NonFatal(_) => null 
} @@ -998,6 +1090,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ try { UTF8String.fromString(formatter.format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) } catch { + case e: SparkUpgradeException => throw e case NonFatal(_) => null } } @@ -1012,9 +1105,10 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ f.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT, - needVarLengthSecondFraction = false) + isParsing = false) .format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) } catch { + case e: SparkUpgradeException => throw e case NonFatal(_) => null } } @@ -1072,7 +1166,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ """, group = "datetime_funcs", since = "1.5.0") -case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class LastDay(startDate: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def child: Expression = startDate override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -1110,7 +1205,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC since = "1.5.0") // scalastyle:on line.size.limit case class NextDay(startDate: Expression, dayOfWeek: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = dayOfWeek @@ -1166,7 +1261,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) * Adds an interval to timestamp. */ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { def this(start: Expression, interval: Expression) = this(start, interval, None) @@ -1205,8 +1300,9 @@ case class DatetimeSub( start: Expression, interval: Expression, child: Expression) extends RuntimeReplaceable { + override def exprsReplaced: Seq[Expression] = Seq(start, interval) override def toString: String = s"$start - $interval" - override def sql: String = s"${start.sql} - ${interval.sql}" + override def mkString(childrenString: Seq[String]): String = childrenString.mkString(" - ") } /** @@ -1223,7 +1319,7 @@ case class DateAddInterval( interval: Expression, timeZoneId: Option[String] = None, ansiEnabled: Boolean = SQLConf.get.ansiEnabled) - extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression { + extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression with NullIntolerant { override def left: Expression = start override def right: Expression = interval @@ -1297,7 +1393,7 @@ case class DateAddInterval( since = "1.5.0") // scalastyle:on line.size.limit case class FromUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override def dataType: DataType = TimestampType @@ -1357,7 +1453,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class AddMonths(startDate: Expression, numMonths: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + 
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = numMonths @@ -1411,7 +1507,8 @@ case class MonthsBetween( date2: Expression, roundOff: Expression, timeZoneId: Option[String] = None) - extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) @@ -1469,7 +1566,7 @@ case class MonthsBetween( since = "1.5.0") // scalastyle:on line.size.limit case class ToUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override def dataType: DataType = TimestampType @@ -1553,14 +1650,8 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr this(left, None, Cast(left, DateType)) } + override def exprsReplaced: Seq[Expression] = left +: format.toSeq override def flatArguments: Iterator[Any] = Iterator(left, format) - override def sql: String = { - if (format.isDefined) { - s"$prettyName(${left.sql}, ${format.get.sql})" - } else { - s"$prettyName(${left.sql})" - } - } override def prettyName: String = "to_date" } @@ -1601,13 +1692,7 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child: def this(left: Expression) = this(left, None, Cast(left, TimestampType)) override def flatArguments: Iterator[Any] = Iterator(left, format) - override def sql: String = { - if (format.isDefined) { - s"$prettyName(${left.sql}, ${format.get.sql})" - } else { - s"$prettyName(${left.sql})" - } - } + override def exprsReplaced: Seq[Expression] = left +: format.toSeq override def prettyName: String = "to_timestamp" override def dataType: DataType = TimestampType @@ -1835,7 +1920,7 @@ case class TruncTimestamp( group = "datetime_funcs", since = "1.5.0") case class DateDiff(endDate: Expression, startDate: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = endDate override def right: Expression = startDate @@ -1889,7 +1974,7 @@ private case class GetTimestamp( group = "datetime_funcs", since = "3.0.0") case class MakeDate(year: Expression, month: Expression, day: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def children: Seq[Expression] = Seq(year, month, day) override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType) @@ -1960,7 +2045,8 @@ case class MakeTimestamp( sec: Expression, timezone: Option[Expression] = None, timeZoneId: Option[String] = None) - extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this( year: Expression, @@ -2161,7 +2247,8 @@ case class DatePart(field: Expression, source: Expression, child: Expression) } override def flatArguments: Iterator[Any] = Iterator(field, source) - override def sql: String = s"$prettyName(${field.sql}, ${source.sql})" + override def exprsReplaced: Seq[Expression] = Seq(field, 
source) + + override def prettyName: String = "date_part" } @@ -2221,8 +2308,12 @@ case class Extract(field: Expression, source: Expression, child: Expression) } override def flatArguments: Iterator[Any] = Iterator(field, source) - override def sql: String = s"$prettyName(${field.sql} FROM ${source.sql})" - override def prettyName: String = "extract" + + override def exprsReplaced: Seq[Expression] = Seq(field, source) + + override def mkString(childrenString: Seq[String]): String = { + prettyName + childrenString.mkString("(", " FROM ", ")") + } } /** @@ -2231,7 +2322,7 @@ case class Extract(field: Expression, source: Expression, child: Expression) * between the given timestamps. */ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = endTimestamp override def right: Expression = startTimestamp @@ -2252,7 +2343,7 @@ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expressi * Returns the interval from the `left` date (inclusive) to the `right` date (exclusive). */ case class SubtractDates(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType) override def dataType: DataType = CalendarIntervalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 9014ebfe2f96a..7e4560ab8161b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.types._ * Note: this expression is internal and created only by the optimizer, * we don't need to do type check for it. */ -case class UnscaledValue(child: Expression) extends UnaryExpression { +case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant { override def dataType: DataType = LongType override def toString: String = s"UnscaledValue($child)" @@ -49,7 +50,7 @@ case class MakeDecimal( child: Expression, precision: Int, scale: Int, - nullOnOverflow: Boolean) extends UnaryExpression { + nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant { def this(child: Expression, precision: Int, scale: Int) = { this(child, precision, scale, !SQLConf.get.ansiEnabled) @@ -144,3 +145,54 @@ case class CheckOverflow( override def sql: String = child.sql } + +// A variant of `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`.
+case class CheckOverflowInSum( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean) extends UnaryExpression { + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.") + } else { + value.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val nullHandling = if (nullOnOverflow) { + "" + } else { + s""" + |throw new ArithmeticException("Overflow in sum of decimals."); + |""".stripMargin + } + val code = code""" + |${childGen.code} + |boolean ${ev.isNull} = ${childGen.isNull}; + |Decimal ${ev.value} = null; + |if (${childGen.isNull}) { + | $nullHandling + |} else { + | ${ev.value} = ${childGen.value}.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow); + | ${ev.isNull} = ${ev.value} == null; + |} + |""".stripMargin + + ev.copy(code = code) + } + + override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)" + + override def sql: String = child.sql +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 4c8c58ae232f4..5e21b58f070ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -53,7 +53,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} > SELECT _FUNC_('Spark'); 8cde774d6f7333752ed72cacddb05126 """) -case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Md5(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -89,7 +90,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput """) // scalastyle:on line.size.limit case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def dataType: DataType = StringType override def nullable: Boolean = true @@ -160,7 +161,8 @@ case class Sha2(left: Expression, right: Expression) > SELECT _FUNC_('Spark'); 85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c """) -case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Sha1(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -187,7 +189,8 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu > SELECT _FUNC_('Spark'); 1557323817 """) -case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Crc32(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 1a569a7b89fe1..baab224691bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -31,7 +31,7 @@ abstract class ExtractIntervalPart( val dataType: DataType, func: CalendarInterval => Any, funcName: String) - extends UnaryExpression with ExpectsInputTypes with Serializable { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType) @@ -82,7 +82,7 @@ object ExtractIntervalPart { abstract class IntervalNumOperation( interval: Expression, num: Expression) - extends BinaryExpression with ImplicitCastInputTypes with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def left: Expression = interval override def right: Expression = num @@ -160,7 +160,7 @@ case class MakeInterval( hours: Expression, mins: Expression, secs: Expression) - extends SeptenaryExpression with ImplicitCastInputTypes { + extends SeptenaryExpression with ImplicitCastInputTypes with NullIntolerant { def this( years: Expression, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 205e5271517c3..f4568f860ac0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -519,7 +519,8 @@ case class JsonToStructs( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes + with NullIntolerant { // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. 
In particular, the parquet-mr encoder @@ -638,7 +639,8 @@ case class StructsToJson( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback + with ExpectsInputTypes with NullIntolerant { override def nullable: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 66e6334e3a450..fe8ea2a3c6733 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -57,7 +57,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(val f: Double => Double, name: String) - extends UnaryExpression with Serializable with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -111,7 +111,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -324,7 +324,7 @@ case class Acosh(child: Expression) -16 """) case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) @@ -452,7 +452,8 @@ object Factorial { > SELECT _FUNC_(5); 120 """) -case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Factorial(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -491,7 +492,9 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas > SELECT _FUNC_(1); 0.0 """) -case class Log(child: Expression) extends UnaryLogExpression(StrictMath.log, "LOG") +case class Log(child: Expression) extends UnaryLogExpression(StrictMath.log, "LOG") { + override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("ln") +} @ExpressionDescription( usage = "_FUNC_(expr) - Returns the logarithm of `expr` with base 2.", @@ -546,6 +549,7 @@ case class Log1p(child: Expression) extends UnaryLogExpression(StrictMath.log1p, // scalastyle:on line.size.limit case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" + override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rint") } 
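The two prettyName overrides added just above follow the same tag-based pattern: the expression keeps its canonical behavior, and only the name it was registered under (stored via FunctionRegistry.FUNC_ALIAS) changes what is displayed. The following is a minimal, self-contained Scala sketch of that fallback; the names SimpleTag and MiniExpr are hypothetical stand-ins, not Spark's actual TreeNode tag machinery.

object PrettyNameSketch {
  // A tiny stand-in for a tree-node tag; the type parameter documents the stored value type.
  final case class SimpleTag[T](name: String)

  // Hypothetical alias tag, analogous in spirit to FunctionRegistry.FUNC_ALIAS.
  val FuncAlias: SimpleTag[String] = SimpleTag[String]("functionAliasName")

  // A tiny stand-in for an expression that can carry tags.
  class MiniExpr(canonicalName: String) {
    private val tags = scala.collection.mutable.Map.empty[SimpleTag[_], Any]
    def setTagValue[T](tag: SimpleTag[T], value: T): Unit = tags(tag) = value
    def getTagValue[T](tag: SimpleTag[T]): Option[T] = tags.get(tag).map(_.asInstanceOf[T])
    // Mirrors the shape of `getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("ln")` above.
    def prettyName: String = getTagValue(FuncAlias).getOrElse(canonicalName)
  }

  def main(args: Array[String]): Unit = {
    val asLn = new MiniExpr("ln")   // registered under the canonical name
    val asLog = new MiniExpr("ln")  // registered under an alias
    asLog.setTagValue(FuncAlias, "log")
    println(asLn.prettyName)        // prints: ln
    println(asLog.prettyName)       // prints: log
  }
}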
@ExpressionDescription( @@ -732,7 +736,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia """) // scalastyle:on line.size.limit case class Bin(child: Expression) - extends UnaryExpression with Serializable with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType @@ -831,7 +835,8 @@ object Hex { > SELECT _FUNC_('Spark SQL'); 537061726B2053514C """) -case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Hex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, BinaryType, StringType)) @@ -866,7 +871,8 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput > SELECT decode(_FUNC_('537061726B2053514C'), 'UTF-8'); Spark SQL """) -case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Unhex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -952,7 +958,7 @@ case class Pow(left: Expression, right: Expression) 4 """) case class ShiftLeft(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -986,7 +992,7 @@ case class ShiftLeft(left: Expression, right: Expression) 2 """) case class ShiftRight(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -1020,7 +1026,7 @@ case class ShiftRight(left: Expression, right: Expression) 2 """) case class ShiftRightUnsigned(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8ce3ddd30a69e..617ddcb69eab0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -116,6 +116,24 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { override def prettyName: String = "current_database" } +/** + * Returns the current catalog. 
+ */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current catalog.", + examples = """ + Examples: + > SELECT _FUNC_(); + spark_catalog + """, + since = "3.1.0") +case class CurrentCatalog() extends LeafExpression with Unevaluable { + override def dataType: DataType = StringType + override def foldable: Boolean = true + override def nullable: Boolean = false + override def prettyName: String = "current_catalog" +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = """_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.""", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index f54d5f167856c..09ae2186b2429 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -138,7 +138,7 @@ case class IfNull(left: Expression, right: Expression, child: Expression) } override def flatArguments: Iterator[Any] = Iterator(left, right) - override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" + override def exprsReplaced: Seq[Expression] = Seq(left, right) } @@ -158,7 +158,7 @@ case class NullIf(left: Expression, right: Expression, child: Expression) } override def flatArguments: Iterator[Any] = Iterator(left, right) - override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" + override def exprsReplaced: Seq[Expression] = Seq(left, right) } @@ -177,7 +177,7 @@ case class Nvl(left: Expression, right: Expression, child: Expression) extends R } override def flatArguments: Iterator[Any] = Iterator(left, right) - override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" + override def exprsReplaced: Seq[Expression] = Seq(left, right) } @@ -199,7 +199,7 @@ case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child: } override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3) - override def sql: String = s"$prettyName(${expr1.sql}, ${expr2.sql}, ${expr3.sql})" + override def exprsReplaced: Seq[Expression] = Seq(expr1, expr2, expr3) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 3f60ca388a807..28924fac48eef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -283,7 +283,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress """, since = "1.5.0") case class StringSplit(str: Expression, regex: Expression, limit: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def 
dataType: DataType = ArrayType(StringType) override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -325,7 +325,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { // last regex in string, we will update the pattern iff regexp value changed. @transient private var lastRegex: UTF8String = _ @@ -433,7 +433,7 @@ object RegExpExtract { """, since = "1.5.0") case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) // last regex in string, we will update the pattern iff regexp value changed. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 82b1e5f0998b0..334a079fc1892 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -334,7 +334,7 @@ trait String2StringExpression extends ImplicitCastInputTypes { """, since = "1.0.1") case class Upper(child: Expression) - extends UnaryExpression with String2StringExpression { + extends UnaryExpression with String2StringExpression with NullIntolerant { // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -356,7 +356,8 @@ case class Upper(child: Expression) sparksql """, since = "1.0.1") -case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { +case class Lower(child: Expression) + extends UnaryExpression with String2StringExpression with NullIntolerant { // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -365,6 +366,9 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") } + + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("lower") } /** A base trait for functions that compare two strings, returning a boolean. 
*/ @@ -432,7 +436,7 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate since = "2.3.0") // scalastyle:on line.size.limit case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(srcExpr: Expression, searchExpr: Expression) = { this(srcExpr, searchExpr, Literal("")) @@ -598,7 +602,7 @@ object StringTranslate { since = "1.5.0") // scalastyle:on line.size.limit case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { @transient private var lastMatching: UTF8String = _ @transient private var lastReplace: UTF8String = _ @@ -663,7 +667,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac since = "1.5.0") // scalastyle:on line.size.limit case class FindInSet(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { + with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -1035,7 +1039,7 @@ case class StringTrimRight( since = "1.5.0") // scalastyle:on line.size.limit case class StringInstr(str: Expression, substr: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = str override def right: Expression = substr @@ -1077,7 +1081,7 @@ case class StringInstr(str: Expression, substr: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -1182,7 +1186,8 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) """) } - override def prettyName: String = "locate" + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("locate") } /** @@ -1205,7 +1210,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) """, since = "1.5.0") case class StringLPad(str: Expression, len: Expression, pad: Expression = Literal(" ")) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { this(str, len, Literal(" ")) @@ -1246,7 +1251,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera """, since = "1.5.0") case class StringRPad(str: Expression, len: Expression, pad: Expression = Literal(" ")) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { this(str, len, Literal(" ")) @@ -1536,7 +1541,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC Spark Sql """, since = "1.5.0") -case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class InitCap(child: Expression) + extends 
UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[DataType] = Seq(StringType) override def dataType: DataType = StringType @@ -1563,7 +1569,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI """, since = "1.5.0") case class StringRepeat(str: Expression, times: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = str override def right: Expression = times @@ -1593,7 +1599,7 @@ case class StringRepeat(str: Expression, times: Expression) """, since = "1.5.0") case class StringSpace(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -1695,7 +1701,7 @@ case class Right(str: Expression, len: Expression, child: Expression) extends Ru } override def flatArguments: Iterator[Any] = Iterator(str, len) - override def sql: String = s"$prettyName(${str.sql}, ${len.sql})" + override def exprsReplaced: Seq[Expression] = Seq(str, len) } /** @@ -1717,7 +1723,7 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run } override def flatArguments: Iterator[Any] = Iterator(str, len) - override def sql: String = s"$prettyName(${str.sql}, ${len.sql})" + override def exprsReplaced: Seq[Expression] = Seq(str, len) } /** @@ -1738,7 +1744,8 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run """, since = "1.5.0") // scalastyle:on line.size.limit -case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Length(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1766,7 +1773,8 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn 72 """, since = "2.3.0") -case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class BitLength(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1797,7 +1805,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas 9 """, since = "2.3.0") -case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class OctetLength(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1828,7 +1837,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC """, since = "1.5.0") case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { + with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -1853,7 +1862,8 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres M460 """, since = "1.5.0") -case class SoundEx(child: Expression) 
extends UnaryExpression with ExpectsInputTypes { +case class SoundEx(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -1879,7 +1889,8 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT 50 """, since = "1.5.0") -case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Ascii(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -1921,7 +1932,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp """, since = "2.3.0") // scalastyle:on line.size.limit -case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Chr(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(LongType) @@ -1964,7 +1976,8 @@ case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInput U3BhcmsgU1FM """, since = "1.5.0") -case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Base64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -1992,7 +2005,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn Spark SQL """, since = "1.5.0") -case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class UnBase64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -2024,7 +2038,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast since = "1.5.0") // scalastyle:on line.size.limit case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = bin override def right: Expression = charset @@ -2064,7 +2078,7 @@ case class Decode(bin: Expression, charset: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = value override def right: Expression = charset @@ -2108,7 +2122,7 @@ case class Encode(value: Expression, charset: Expression) """, since = "1.5.0") case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = x override def right: Expression = d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index 55e06cb9e8471..e08a10ecac71c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -30,7 +30,8 @@ import org.apache.spark.unsafe.types.UTF8String * * This is not the world's most efficient implementation due to type conversion, but works. */ -abstract class XPathExtract extends BinaryExpression with ExpectsInputTypes with CodegenFallback { +abstract class XPathExtract + extends BinaryExpression with ExpectsInputTypes with CodegenFallback with NullIntolerant { override def left: Expression = xml override def right: Expression = path diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index f3938feef0a35..fb0ca323af1ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -83,12 +83,13 @@ private[sql] class JacksonGenerator( options.zoneId, options.locale, legacyFormat = FAST_DATE_FORMAT, - needVarLengthSecondFraction = false) + isParsing = false) private val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, options.locale, - legacyFormat = FAST_DATE_FORMAT) + legacyFormat = FAST_DATE_FORMAT, + isParsing = false) private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index a52c3450e83df..e038f777c7a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -56,17 +56,18 @@ class JacksonParser( private val factory = options.buildJsonFactory() - private val timestampFormatter = TimestampFormatter( + private lazy val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, options.locale, legacyFormat = FAST_DATE_FORMAT, - needVarLengthSecondFraction = true) - private val dateFormatter = DateFormatter( + isParsing = true) + private lazy val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, options.locale, - legacyFormat = FAST_DATE_FORMAT) + legacyFormat = FAST_DATE_FORMAT, + isParsing = true) /** * Create a converter which converts the JSON documents held by the `JsonParser` @@ -456,6 +457,7 @@ class JacksonParser( } } } catch { + case e: SparkUpgradeException => throw e case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) => // JSON parser currently doesn't support partial results for corrupted records. 
// For such records, all fields other than the field configured by diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 56b12784fd214..de396a4c63458 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -43,7 +43,7 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { options.zoneId, options.locale, legacyFormat = FAST_DATE_FORMAT, - needVarLengthSecondFraction = true) + isParsing = true) /** * Infer the type of a collection of json records in three stages: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e59e3b999aa7f..f1a307b1c2cc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -133,7 +133,7 @@ abstract class Optimizer(catalogManager: CatalogManager) ReplaceExpressions, RewriteNonCorrelatedExists, ComputeCurrentTime, - GetCurrentDatabase(catalogManager), + GetCurrentDatabaseAndCatalog(catalogManager), RewriteDistinctAggregates, ReplaceDeduplicateWithAggregate) :: ////////////////////////////////////////////////////////////////////////////////////////// @@ -223,7 +223,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateView.ruleName :: ReplaceExpressions.ruleName :: ComputeCurrentTime.ruleName :: - GetCurrentDatabase(catalogManager).ruleName :: + GetCurrentDatabaseAndCatalog(catalogManager).ruleName :: RewriteDistinctAggregates.ruleName :: ReplaceDeduplicateWithAggregate.ruleName :: ReplaceIntersectWithSemiJoin.ruleName :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 80d85827657fd..6c9bb6db06d86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -91,15 +91,21 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } -/** Replaces the expression of CurrentDatabase with the current database name. */ -case class GetCurrentDatabase(catalogManager: CatalogManager) extends Rule[LogicalPlan] { +/** + * Replaces the expression of CurrentDatabase with the current database name. + * Replaces the expression of CurrentCatalog with the current catalog name. 
+ */ +case class GetCurrentDatabaseAndCatalog(catalogManager: CatalogManager) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ val currentNamespace = catalogManager.currentNamespace.quoted + val currentCatalog = catalogManager.currentCatalog.name() plan transformAllExpressions { case CurrentDatabase() => Literal.create(currentNamespace, StringType) + case CurrentCatalog() => + Literal.create(currentCatalog, StringType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index b65221c236bfe..85c6600685bd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -208,3 +208,161 @@ object ExtractPythonUDFFromJoinCondition extends Rule[LogicalPlan] with Predicat } } } + +sealed abstract class BuildSide + +case object BuildRight extends BuildSide + +case object BuildLeft extends BuildSide + +trait JoinSelectionHelper { + + def getBroadcastBuildSide( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + hint: JoinHint, + hintOnly: Boolean, + conf: SQLConf): Option[BuildSide] = { + val buildLeft = if (hintOnly) { + hintToBroadcastLeft(hint) + } else { + canBroadcastBySize(left, conf) && !hintToNotBroadcastLeft(hint) + } + val buildRight = if (hintOnly) { + hintToBroadcastRight(hint) + } else { + canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint) + } + getBuildSide( + canBuildLeft(joinType) && buildLeft, + canBuildRight(joinType) && buildRight, + left, + right + ) + } + + def getShuffleHashJoinBuildSide( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + hint: JoinHint, + hintOnly: Boolean, + conf: SQLConf): Option[BuildSide] = { + val buildLeft = if (hintOnly) { + hintToShuffleHashJoinLeft(hint) + } else { + canBuildLocalHashMapBySize(left, conf) && muchSmaller(left, right) + } + val buildRight = if (hintOnly) { + hintToShuffleHashJoinRight(hint) + } else { + canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left) + } + getBuildSide( + canBuildLeft(joinType) && buildLeft, + canBuildRight(joinType) && buildRight, + left, + right + ) + } + + def getSmallerSide(left: LogicalPlan, right: LogicalPlan): BuildSide = { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + } + + /** + * Matches a plan whose output should be small enough to be used in broadcast join. 
+   */
+  def canBroadcastBySize(plan: LogicalPlan, conf: SQLConf): Boolean = {
+    plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold
+  }
+
+  def canBuildLeft(joinType: JoinType): Boolean = {
+    joinType match {
+      case _: InnerLike | RightOuter => true
+      case _ => false
+    }
+  }
+
+  def canBuildRight(joinType: JoinType): Boolean = {
+    joinType match {
+      case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true
+      case _ => false
+    }
+  }
+
+  def hintToBroadcastLeft(hint: JoinHint): Boolean = {
+    hint.leftHint.exists(_.strategy.contains(BROADCAST))
+  }
+
+  def hintToBroadcastRight(hint: JoinHint): Boolean = {
+    hint.rightHint.exists(_.strategy.contains(BROADCAST))
+  }
+
+  def hintToNotBroadcastLeft(hint: JoinHint): Boolean = {
+    hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH))
+  }
+
+  def hintToNotBroadcastRight(hint: JoinHint): Boolean = {
+    hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH))
+  }
+
+  def hintToShuffleHashJoinLeft(hint: JoinHint): Boolean = {
+    hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH))
+  }
+
+  def hintToShuffleHashJoinRight(hint: JoinHint): Boolean = {
+    hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH))
+  }
+
+  def hintToSortMergeJoin(hint: JoinHint): Boolean = {
+    hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) ||
+      hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE))
+  }
+
+  def hintToShuffleReplicateNL(hint: JoinHint): Boolean = {
+    hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) ||
+      hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))
+  }
+
+  private def getBuildSide(
+      canBuildLeft: Boolean,
+      canBuildRight: Boolean,
+      left: LogicalPlan,
+      right: LogicalPlan): Option[BuildSide] = {
+    if (canBuildLeft && canBuildRight) {
+      // Returns the smaller side based on its estimated physical size, if we want to build
+      // both sides.
+      Some(getSmallerSide(left, right))
+    } else if (canBuildLeft) {
+      Some(BuildLeft)
+    } else if (canBuildRight) {
+      Some(BuildRight)
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Matches a plan whose single partition should be small enough to build a hash table.
+   *
+   * Note: this assumes that the number of partitions is fixed; additional work is required
+   * if it's dynamic.
+   */
+  private def canBuildLocalHashMapBySize(plan: LogicalPlan, conf: SQLConf): Boolean = {
+    plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
+  }
+
+  /**
+   * Returns whether plan `a` is much smaller (3X) than plan `b`.
+   *
+   * The cost of building a hash map is higher than sorting, so we should only build a hash map
+   * on a table that is much smaller than the other one. Since we do not have statistics for the
+   * number of rows, the size in bytes is used here as the estimation.
+ */ + private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { + a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 97750f467adbc..03571a740df3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -629,12 +629,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case p: Predicate => p case e => Cast(e, BooleanType) } - plan match { - case aggregate: Aggregate => - AggregateWithHaving(predicate, aggregate) - case _ => - Filter(predicate, plan) - } + UnresolvedHaving(predicate, plan) } /** @@ -1539,7 +1534,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create a [[CreateStruct]] expression. */ override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) { - CreateStruct(ctx.argument.asScala.map(expression)) + CreateStruct.create(ctx.argument.asScala.map(expression)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 590193bddafb5..fab282f15f215 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -98,7 +98,6 @@ abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Log lexer.addErrorListener(ParseErrorListener) lexer.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced lexer.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled - lexer.legacy_create_hive_table_by_default_enabled = conf.createHiveTableByDefaultEnabled lexer.SQL_standard_keyword_behavior = conf.ansiEnabled val tokenStream = new CommonTokenStream(lexer) @@ -108,7 +107,6 @@ abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Log parser.addErrorListener(ParseErrorListener) parser.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled - parser.legacy_create_hive_table_by_default_enabled = conf.createHiveTableByDefaultEnabled parser.SQL_standard_keyword_behavior = conf.ansiEnabled try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala index 0f79c1a6a751d..6d225ad9b7645 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala @@ -29,14 +29,20 @@ import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ sealed trait DateFormatter extends Serializable { def parse(s: String): Int // returns days since epoch + def format(days: Int): String + def format(date: Date): String + def format(localDate: LocalDate): String + + def validatePatternString(): Unit } class Iso8601DateFormatter( pattern: String, zoneId: ZoneId, locale: Locale, - legacyFormat: LegacyDateFormats.LegacyDateFormat) + legacyFormat: LegacyDateFormats.LegacyDateFormat, + isParsing: Boolean) extends DateFormatter with DateTimeFormatterHelper { 
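+  // The new `isParsing` flag distinguishes parsing from formatting, so the parsing-only
+  // pattern restrictions enforced in DateTimeFormatterHelper (see `unsupportedLettersForParsing`
+  // further below) apply only when strings are read, not when values are formatted.
+  // `validatePatternString()` lets `DateFormatter.getFormatter` force the underlying formatter
+  // eagerly, so an invalid pattern fails at construction time instead of on the first record.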
@transient @@ -50,28 +56,44 @@ class Iso8601DateFormatter( val specialDate = convertSpecialDate(s.trim, zoneId) specialDate.getOrElse { try { - val localDate = LocalDate.parse(s, formatter) + val localDate = toLocalDate(formatter.parse(s)) localDateToDays(localDate) } catch checkDiffResult(s, legacyFormatter.parse) } } + override def format(localDate: LocalDate): String = { + localDate.format(formatter) + } + override def format(days: Int): String = { - LocalDate.ofEpochDay(days).format(formatter) + format(LocalDate.ofEpochDay(days)) + } + + override def format(date: Date): String = { + legacyFormatter.format(date) + } + + override def validatePatternString(): Unit = { + try { + formatter + } catch checkLegacyFormatter(pattern, legacyFormatter.validatePatternString) } } trait LegacyDateFormatter extends DateFormatter { def parseToDate(s: String): Date - def formatDate(d: Date): String override def parse(s: String): Int = { fromJavaDate(new java.sql.Date(parseToDate(s).getTime)) } override def format(days: Int): String = { - val date = DateTimeUtils.toJavaDate(days) - formatDate(date) + format(DateTimeUtils.toJavaDate(days)) + } + + override def format(localDate: LocalDate): String = { + format(localDateToDays(localDate)) } } @@ -79,14 +101,17 @@ class LegacyFastDateFormatter(pattern: String, locale: Locale) extends LegacyDat @transient private lazy val fdf = FastDateFormat.getInstance(pattern, locale) override def parseToDate(s: String): Date = fdf.parse(s) - override def formatDate(d: Date): String = fdf.format(d) + override def format(d: Date): String = fdf.format(d) + override def validatePatternString(): Unit = fdf } class LegacySimpleDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter { @transient private lazy val sdf = new SimpleDateFormat(pattern, locale) override def parseToDate(s: String): Date = sdf.parse(s) - override def formatDate(d: Date): String = sdf.format(d) + override def format(d: Date): String = sdf.format(d) + override def validatePatternString(): Unit = sdf + } object DateFormatter { @@ -100,12 +125,15 @@ object DateFormatter { format: Option[String], zoneId: ZoneId, locale: Locale = defaultLocale, - legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT): DateFormatter = { + legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT, + isParsing: Boolean = true): DateFormatter = { val pattern = format.getOrElse(defaultPattern) if (SQLConf.get.legacyTimeParserPolicy == LEGACY) { getLegacyFormatter(pattern, zoneId, locale, legacyFormat) } else { - new Iso8601DateFormatter(pattern, zoneId, locale, legacyFormat) + val df = new Iso8601DateFormatter(pattern, zoneId, locale, legacyFormat, isParsing) + df.validatePatternString() + df } } @@ -126,8 +154,9 @@ object DateFormatter { format: String, zoneId: ZoneId, locale: Locale, - legacyFormat: LegacyDateFormat): DateFormatter = { - getFormatter(Some(format), zoneId, locale, legacyFormat) + legacyFormat: LegacyDateFormat, + isParsing: Boolean): DateFormatter = { + getFormatter(Some(format), zoneId, locale, legacyFormat, isParsing) } def apply(format: String, zoneId: ZoneId): DateFormatter = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala index 05ec23f7ad479..eeb56aa9821ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.time._ import java.time.chrono.IsoChronology -import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder, DateTimeParseException, ResolverStyle} +import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder, ResolverStyle} import java.time.temporal.{ChronoField, TemporalAccessor, TemporalQueries} import java.util.Locale @@ -31,17 +31,60 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ trait DateTimeFormatterHelper { + private def getOrDefault(accessor: TemporalAccessor, field: ChronoField, default: Int): Int = { + if (accessor.isSupported(field)) { + accessor.get(field) + } else { + default + } + } + + protected def toLocalDate(accessor: TemporalAccessor): LocalDate = { + val localDate = accessor.query(TemporalQueries.localDate()) + // If all the date fields are specified, return the local date directly. + if (localDate != null) return localDate + + // Users may want to parse only a few datetime fields from a string and extract these fields + // later, and we should provide default values for missing fields. + // To be compatible with Spark 2.4, we pick 1970 as the default value of year. + val year = getOrDefault(accessor, ChronoField.YEAR, 1970) + val month = getOrDefault(accessor, ChronoField.MONTH_OF_YEAR, 1) + val day = getOrDefault(accessor, ChronoField.DAY_OF_MONTH, 1) + LocalDate.of(year, month, day) + } + + private def toLocalTime(accessor: TemporalAccessor): LocalTime = { + val localTime = accessor.query(TemporalQueries.localTime()) + // If all the time fields are specified, return the local time directly. + if (localTime != null) return localTime + + val hour = if (accessor.isSupported(ChronoField.HOUR_OF_DAY)) { + accessor.get(ChronoField.HOUR_OF_DAY) + } else if (accessor.isSupported(ChronoField.HOUR_OF_AMPM)) { + // When we reach here, it means am/pm is not specified. Here we assume it's am. + // All of CLOCK_HOUR_OF_AMPM(h)/HOUR_OF_DAY(H)/CLOCK_HOUR_OF_DAY(k)/HOUR_OF_AMPM(K) will + // be resolved to HOUR_OF_AMPM here, we do not need to handle them separately + accessor.get(ChronoField.HOUR_OF_AMPM) + } else if (accessor.isSupported(ChronoField.AMPM_OF_DAY) && + accessor.get(ChronoField.AMPM_OF_DAY) == 1) { + // When reach here, the `hour` part is missing, and PM is specified. + // None of CLOCK_HOUR_OF_AMPM(h)/HOUR_OF_DAY(H)/CLOCK_HOUR_OF_DAY(k)/HOUR_OF_AMPM(K) is + // specified + 12 + } else { + 0 + } + val minute = getOrDefault(accessor, ChronoField.MINUTE_OF_HOUR, 0) + val second = getOrDefault(accessor, ChronoField.SECOND_OF_MINUTE, 0) + val nanoSecond = getOrDefault(accessor, ChronoField.NANO_OF_SECOND, 0) + LocalTime.of(hour, minute, second, nanoSecond) + } + // Converts the parsed temporal object to ZonedDateTime. It sets time components to zeros // if they does not exist in the parsed object. - protected def toZonedDateTime( - temporalAccessor: TemporalAccessor, - zoneId: ZoneId): ZonedDateTime = { - // Parsed input might not have time related part. In that case, time component is set to zeros. - val parsedLocalTime = temporalAccessor.query(TemporalQueries.localTime) - val localTime = if (parsedLocalTime == null) LocalTime.MIDNIGHT else parsedLocalTime - // Parsed input must have date component. At least, year must present in temporalAccessor. 
-    val localDate = temporalAccessor.query(TemporalQueries.localDate)
-
+  protected def toZonedDateTime(accessor: TemporalAccessor, zoneId: ZoneId): ZonedDateTime = {
+    val localDate = toLocalDate(accessor)
+    val localTime = toLocalTime(accessor)
     ZonedDateTime.of(localDate, localTime, zoneId)
   }
@@ -54,9 +97,9 @@ trait DateTimeFormatterHelper {
   protected def getOrCreateFormatter(
       pattern: String,
       locale: Locale,
-      needVarLengthSecondFraction: Boolean = false): DateTimeFormatter = {
-    val newPattern = convertIncompatiblePattern(pattern)
-    val useVarLen = needVarLengthSecondFraction && newPattern.contains('S')
+      isParsing: Boolean = false): DateTimeFormatter = {
+    val newPattern = convertIncompatiblePattern(pattern, isParsing)
+    val useVarLen = isParsing && newPattern.contains('S')
     val key = (newPattern, locale, useVarLen)
     var formatter = cache.getIfPresent(key)
     if (formatter == null) {
@@ -72,19 +115,43 @@ trait DateTimeFormatterHelper {
   // DateTimeParseException will address by the caller side.
   protected def checkDiffResult[T](
       s: String, legacyParseFunc: String => T): PartialFunction[Throwable, T] = {
-    case e: DateTimeParseException if SQLConf.get.legacyTimeParserPolicy == EXCEPTION =>
-      val res = try {
-        Some(legacyParseFunc(s))
+    case e: DateTimeException if SQLConf.get.legacyTimeParserPolicy == EXCEPTION =>
+      try {
+        legacyParseFunc(s)
       } catch {
-        case _: Throwable => None
+        case _: Throwable => throw e
      }
-      if (res.nonEmpty) {
-        throw new SparkUpgradeException("3.0", s"Fail to parse '$s' in the new parser. You can " +
-          s"set ${SQLConf.LEGACY_TIME_PARSER_POLICY.key} to LEGACY to restore the behavior " +
-          s"before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string.", e)
-      } else {
-        throw e
+      throw new SparkUpgradeException("3.0", s"Fail to parse '$s' in the new parser. You can " +
+        s"set ${SQLConf.LEGACY_TIME_PARSER_POLICY.key} to LEGACY to restore the behavior " +
+        s"before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string.", e)
+  }
+
+  /**
+   * When the new DateTimeFormatter fails to initialize because of an invalid datetime pattern,
+   * it throws an IllegalArgumentException. If the pattern can be recognized by the legacy
+   * formatter, a SparkUpgradeException is raised to tell users to either restore the previous
+   * behavior via the LEGACY policy or follow our guide to correct their pattern. Otherwise,
+   * the original IllegalArgumentException will be thrown.
+   *
+   * @param pattern the date time pattern
+   * @param tryLegacyFormatter a function that forces a legacy datetime formatter to be
+   *                           initialized, used to check whether the legacy formatter can
+   *                           recognize the pattern
+   */
+  protected def checkLegacyFormatter(
+      pattern: String,
+      tryLegacyFormatter: => Unit): PartialFunction[Throwable, DateTimeFormatter] = {
+    case e: IllegalArgumentException =>
+      try {
+        tryLegacyFormatter
+      } catch {
+        case _: Throwable => throw e
      }
+      throw new SparkUpgradeException("3.0", s"Fail to recognize '$pattern' pattern in the" +
+        s" DateTimeFormatter. 1) You can set ${SQLConf.LEGACY_TIME_PARSER_POLICY.key} to LEGACY" +
+        s" to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern" +
+        s" with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html",
+        e)
   }
 }
@@ -101,10 +168,6 @@ private object DateTimeFormatterHelper {
   def toFormatter(builder: DateTimeFormatterBuilder, locale: Locale): DateTimeFormatter = {
     builder
-      .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1)
-      .parseDefaulting(ChronoField.DAY_OF_MONTH, 1)
-      .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0)
-      .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0)
       .toFormatter(locale)
       .withChronology(IsoChronology.INSTANCE)
       .withResolverStyle(ResolverStyle.STRICT)
@@ -162,7 +225,32 @@ private object DateTimeFormatterHelper {
     toFormatter(builder, TimestampFormatter.defaultLocale)
   }

+  private final val bugInStandAloneForm = {
+    // Java 8 has a bug for stand-alone form. See https://bugs.openjdk.java.net/browse/JDK-8114833
+    // Note: we only check the US locale so that it's a static check. It can produce false
+    // negatives, as some locales are not affected by the bug. Since `L`/`q` are rarely used, we
+    // choose not to complicate the check here.
+    // TODO: remove it when we drop Java 8 support.
+    val formatter = DateTimeFormatter.ofPattern("LLL qqq", Locale.US)
+    formatter.format(LocalDate.of(2000, 1, 1)) == "1 1"
+  }
   final val unsupportedLetters = Set('A', 'c', 'e', 'n', 'N', 'p')
+  // SPARK-31892: The week-based date fields are rarely used and really confusing when parsing
+  // values to datetime, especially when they are mixed with other non-week-based fields.
+  // The quarter fields are also parsed strangely, e.g. when the pattern contains `yMd` and can
+  // be directly resolved, the `q` does check whether the month is valid, but if the date fields
+  // are incomplete, e.g. `yM`, the check will be bypassed.
+  final val unsupportedLettersForParsing = Set('Y', 'W', 'w', 'E', 'u', 'F', 'q', 'Q')
+  final val unsupportedPatternLengths = {
+    // SPARK-31771: Disable Narrow-form TextStyle to avoid silent data changes, as it is
+    // Full-form in 2.4.
+    Seq("G", "M", "L", "E", "u", "Q", "q").map(_ * 5) ++
+    // SPARK-31867: Disable year patterns longer than 10, which cause the Java time library to
+    // throw an unchecked `ArrayIndexOutOfBoundsException` from the `NumberPrinterParser` when
+    // formatting. That makes it difficult for the call side to handle exceptions and easily
+    // leads to silent data changes because the exceptions are suppressed.
+    Seq("y", "Y").map(_ * 11)
+  }.toSet

   /**
    * In Spark 3.0, we switch to the Proleptic Gregorian calendar and use DateTimeFormatter for
@@ -172,7 +260,7 @@ private object DateTimeFormatterHelper {
    * @param pattern The input pattern.
    * @return The pattern for new parser
    */
-  def convertIncompatiblePattern(pattern: String): String = {
+  def convertIncompatiblePattern(pattern: String, isParsing: Boolean = false): String = {
     val eraDesignatorContained = pattern.split("'").zipWithIndex.exists {
       case (patternPart, index) =>
         // Text can be quoted using single quotes, we only check the non-quote parts.
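To make the effect of the two new check lists concrete, here is a small, self-contained Scala sketch (the object name PatternCheckSketch and its helpers are hypothetical, not Spark code) of how a validator can reject week-based letters only when parsing while rejecting over-long letter runs unconditionally, mirroring the checks the hunk below adds to convertIncompatiblePattern.

object PatternCheckSketch {
  // Letters rejected only when the pattern is used for parsing (cf. unsupportedLettersForParsing).
  private val parseOnlyRejected = Set('Y', 'W', 'w', 'E', 'u', 'F', 'q', 'Q')
  // Letter runs rejected unconditionally (cf. unsupportedPatternLengths):
  // 5+ of the text styles, 11+ of the year fields.
  private val rejectedRuns =
    (Seq("G", "M", "L", "E", "u", "Q", "q").map(_ * 5) ++ Seq("y", "Y").map(_ * 11)).toSet

  def check(pattern: String, isParsing: Boolean): Unit = {
    // Text inside single quotes is literal, so only the even-indexed (unquoted) parts are checked.
    (pattern + " ").split("'").zipWithIndex.foreach { case (part, index) =>
      if (index % 2 == 0) {
        for (c <- part if isParsing && parseOnlyRejected.contains(c)) {
          throw new IllegalArgumentException(s"Illegal pattern character: $c")
        }
        for (run <- rejectedRuns if part.contains(run)) {
          throw new IllegalArgumentException(s"Too many pattern letters: ${run.head}")
        }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    check("yyyy-MM-dd HH:mm:ss", isParsing = true)                   // accepted
    println(scala.util.Try(check("YYYY-'W'ww", isParsing = true)))   // Failure: 'Y' rejected for parsing
    println(scala.util.Try(check("YYYY-'W'ww", isParsing = false)))  // Success: formatting still allowed
    println(scala.util.Try(check("yyyyyyyyyyy", isParsing = false))) // Failure: run of 11 'y' letters
  }
}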
@@ -181,9 +269,19 @@ private object DateTimeFormatterHelper { (pattern + " ").split("'").zipWithIndex.map { case (patternPart, index) => if (index % 2 == 0) { - for (c <- patternPart if unsupportedLetters.contains(c)) { + for (c <- patternPart if unsupportedLetters.contains(c) || + (isParsing && unsupportedLettersForParsing.contains(c))) { throw new IllegalArgumentException(s"Illegal pattern character: $c") } + for (style <- unsupportedPatternLengths if patternPart.contains(style)) { + throw new IllegalArgumentException(s"Too many pattern letters: ${style.head}") + } + if (bugInStandAloneForm && (patternPart.contains("LLL") || patternPart.contains("qqq"))) { + throw new IllegalArgumentException("Java 8 has a bug to support stand-alone " + + "form (3 or more 'L' or 'q' in the pattern string). Please use 'M' or 'Q' instead, " + + "or upgrade your Java version. For more details, please read " + + "https://bugs.openjdk.java.net/browse/JDK-8114833") + } // The meaning of 'u' was day number of week in SimpleDateFormat, it was changed to year // in DateTimeFormatter. Substitute 'u' to 'e' and use DateTimeFormatter to parse the // string. If parsable, return the result; otherwise, fall back to 'u', and then use the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala index a1b87e8e02351..cc75340cd8fcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala @@ -146,6 +146,8 @@ object RebaseDateTime { -354226, -317702, -244653, -208129, -171605, -141436, -141435, -141434, -141433, -141432, -141431, -141430, -141429, -141428, -141427) + final val lastSwitchGregorianDay: Int = gregJulianDiffSwitchDay.last + // The first days of Common Era (CE) which is mapped to the '0001-01-01' date // in Proleptic Gregorian calendar. 
private final val gregorianCommonEraStartDay = gregJulianDiffSwitchDay(0) @@ -295,7 +297,7 @@ object RebaseDateTime { } // The switch time point after which all diffs between Gregorian and Julian calendars // across all time zones are zero - private final val lastSwitchGregorianTs: Long = getLastSwitchTs(gregJulianRebaseMap) + final val lastSwitchGregorianTs: Long = getLastSwitchTs(gregJulianRebaseMap) private final val gregorianStartTs = LocalDateTime.of(gregorianStartDate, LocalTime.MIDNIGHT) private final val julianEndTs = LocalDateTime.of( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index dc06fa9d6f1c4..97ecc430af4a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -50,7 +50,11 @@ sealed trait TimestampFormatter extends Serializable { @throws(classOf[DateTimeParseException]) @throws(classOf[DateTimeException]) def parse(s: String): Long + def format(us: Long): String + def format(ts: Timestamp): String + def format(instant: Instant): String + def validatePatternString(): Unit } class Iso8601TimestampFormatter( @@ -84,9 +88,23 @@ class Iso8601TimestampFormatter( } } + override def format(instant: Instant): String = { + formatter.withZone(zoneId).format(instant) + } + override def format(us: Long): String = { val instant = DateTimeUtils.microsToInstant(us) - formatter.withZone(zoneId).format(instant) + format(instant) + } + + override def format(ts: Timestamp): String = { + legacyFormatter.format(ts) + } + + override def validatePatternString(): Unit = { + try { + formatter + } catch checkLegacyFormatter(pattern, legacyFormatter.validatePatternString) } } @@ -100,10 +118,49 @@ class Iso8601TimestampFormatter( */ class FractionTimestampFormatter(zoneId: ZoneId) extends Iso8601TimestampFormatter( - "", zoneId, TimestampFormatter.defaultLocale, needVarLengthSecondFraction = false) { + TimestampFormatter.defaultPattern, + zoneId, + TimestampFormatter.defaultLocale, + LegacyDateFormats.FAST_DATE_FORMAT, + needVarLengthSecondFraction = false) { @transient override protected lazy val formatter = DateTimeFormatterHelper.fractionFormatter + + // The new formatter will omit the trailing 0 in the timestamp string, but the legacy formatter + // can't. Here we use the legacy formatter to format the given timestamp up to seconds fractions, + // and custom implementation to format the fractional part without trailing zeros. + override def format(ts: Timestamp): String = { + val formatted = legacyFormatter.format(ts) + var nanos = ts.getNanos + if (nanos == 0) { + formatted + } else { + // Formats non-zero seconds fraction w/o trailing zeros. For example: + // formatted = '2020-05:27 15:55:30' + // nanos = 001234000 + // Counts the length of the fractional part: 001234000 -> 6 + var fracLen = 9 + while (nanos % 10 == 0) { + nanos /= 10 + fracLen -= 1 + } + // Places `nanos` = 1234 after '2020-05:27 15:55:30.' + val fracOffset = formatted.length + 1 + val totalLen = fracOffset + fracLen + // The buffer for the final result: '2020-05:27 15:55:30.001234' + val buf = new Array[Char](totalLen) + formatted.getChars(0, formatted.length, buf, 0) + buf(formatted.length) = '.' 
+ var i = totalLen + do { + i -= 1 + buf(i) = ('0' + (nanos % 10)).toChar + nanos /= 10 + } while (i > fracOffset) + new String(buf) + } + } } /** @@ -149,7 +206,7 @@ class LegacyFastTimestampFormatter( fastDateFormat.getTimeZone, fastDateFormat.getPattern.count(_ == 'S')) - def parse(s: String): SQLTimestamp = { + override def parse(s: String): SQLTimestamp = { cal.clear() // Clear the calendar because it can be re-used many times if (!fastDateFormat.parse(s, new ParsePosition(0), cal)) { throw new IllegalArgumentException(s"'$s' is an invalid timestamp") @@ -160,12 +217,26 @@ class LegacyFastTimestampFormatter( rebaseJulianToGregorianMicros(julianMicros) } - def format(timestamp: SQLTimestamp): String = { + override def format(timestamp: SQLTimestamp): String = { val julianMicros = rebaseGregorianToJulianMicros(timestamp) cal.setTimeInMillis(Math.floorDiv(julianMicros, MICROS_PER_SECOND) * MILLIS_PER_SECOND) cal.setMicros(Math.floorMod(julianMicros, MICROS_PER_SECOND)) fastDateFormat.format(cal) } + + override def format(ts: Timestamp): String = { + if (ts.getNanos == 0) { + fastDateFormat.format(ts) + } else { + format(fromJavaTimestamp(ts)) + } + } + + override def format(instant: Instant): String = { + format(instantToMicros(instant)) + } + + override def validatePatternString(): Unit = fastDateFormat } class LegacySimpleTimestampFormatter( @@ -187,6 +258,16 @@ class LegacySimpleTimestampFormatter( override def format(us: Long): String = { sdf.format(toJavaTimestamp(us)) } + + override def format(ts: Timestamp): String = { + sdf.format(ts) + } + + override def format(instant: Instant): String = { + format(instantToMicros(instant)) + } + + override def validatePatternString(): Unit = sdf } object LegacyDateFormats extends Enumeration { @@ -206,13 +287,15 @@ object TimestampFormatter { zoneId: ZoneId, locale: Locale = defaultLocale, legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT, - needVarLengthSecondFraction: Boolean = false): TimestampFormatter = { + isParsing: Boolean = false): TimestampFormatter = { val pattern = format.getOrElse(defaultPattern) if (SQLConf.get.legacyTimeParserPolicy == LEGACY) { getLegacyFormatter(pattern, zoneId, locale, legacyFormat) } else { - new Iso8601TimestampFormatter( - pattern, zoneId, locale, legacyFormat, needVarLengthSecondFraction) + val tf = new Iso8601TimestampFormatter( + pattern, zoneId, locale, legacyFormat, isParsing) + tf.validatePatternString() + tf } } @@ -236,23 +319,23 @@ object TimestampFormatter { zoneId: ZoneId, locale: Locale, legacyFormat: LegacyDateFormat, - needVarLengthSecondFraction: Boolean): TimestampFormatter = { - getFormatter(Some(format), zoneId, locale, legacyFormat, needVarLengthSecondFraction) + isParsing: Boolean): TimestampFormatter = { + getFormatter(Some(format), zoneId, locale, legacyFormat, isParsing) } def apply( format: String, zoneId: ZoneId, legacyFormat: LegacyDateFormat, - needVarLengthSecondFraction: Boolean): TimestampFormatter = { - getFormatter(Some(format), zoneId, defaultLocale, legacyFormat, needVarLengthSecondFraction) + isParsing: Boolean): TimestampFormatter = { + getFormatter(Some(format), zoneId, defaultLocale, legacyFormat, isParsing) } def apply( format: String, zoneId: ZoneId, - needVarLengthSecondFraction: Boolean = false): TimestampFormatter = { - getFormatter(Some(format), zoneId, needVarLengthSecondFraction = needVarLengthSecondFraction) + isParsing: Boolean = false): TimestampFormatter = { + getFormatter(Some(format), zoneId, isParsing = isParsing) } def apply(zoneId: ZoneId): 
TimestampFormatter = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index a0fec29075e40..a5f0b239d6086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -134,6 +134,8 @@ package object util extends Logging { PrettyAttribute(usePrettyExpression(e.child).sql + "." + name, e.dataType) case e: GetArrayStructFields => PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType) + case r: RuntimeReplaceable => + PrettyAttribute(r.mkString(r.exprsReplaced.map(toPrettySQL)), r.dataType) } def quoteIdentifier(name: String): String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c4922b56f0756..3a41b0553db54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -495,7 +495,7 @@ object SQLConf { .version("3.0.0") .intConf .checkValue(_ > 0, "The skew factor must be positive.") - .createWithDefault(10) + .createWithDefault(5) val SKEW_JOIN_SKEWED_PARTITION_THRESHOLD = buildConf("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes") @@ -1784,6 +1784,15 @@ object SQLConf { .version("3.0.0") .fallbackConf(ARROW_EXECUTION_ENABLED) + val PYSPARK_JVM_STACKTRACE_ENABLED = + buildConf("spark.sql.pyspark.jvmStacktrace.enabled") + .doc("When true, it shows the JVM stacktrace in the user-facing PySpark exception " + + "together with Python stacktrace. By default, it is disabled and hides JVM stacktrace " + + "and shows a Python-friendly exception only.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + val ARROW_SPARKR_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.sparkr.enabled") .doc("When true, make use of Apache Arrow for columnar data transfers in SparkR. " + @@ -2228,15 +2237,6 @@ object SQLConf { .booleanConf .createWithDefault(false) - val LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED = - buildConf("spark.sql.legacy.createHiveTableByDefault.enabled") - .internal() - .doc("When set to true, CREATE TABLE syntax without a provider will use hive " + - s"instead of the value of ${DEFAULT_DATA_SOURCE_NAME.key}.") - .version("3.0.0") - .booleanConf - .createWithDefault(false) - val LEGACY_BUCKETED_TABLE_SCAN_OUTPUT_ORDERING = buildConf("spark.sql.legacy.bucketedTableScan.outputOrdering") .internal() @@ -2528,57 +2528,72 @@ object SQLConf { .booleanConf .createWithDefault(false) - val LEGACY_PARQUET_REBASE_DATETIME_IN_WRITE = - buildConf("spark.sql.legacy.parquet.rebaseDateTimeInWrite.enabled") + val LEGACY_PARQUET_REBASE_MODE_IN_WRITE = + buildConf("spark.sql.legacy.parquet.datetimeRebaseModeInWrite") .internal() - .doc("When true, rebase dates/timestamps from Proleptic Gregorian calendar " + - "to the hybrid calendar (Julian + Gregorian) in write. 
" + - "The rebasing is performed by converting micros/millis/days to " + - "a local date/timestamp in the source calendar, interpreting the resulted date/" + - "timestamp in the target calendar, and getting the number of micros/millis/days " + - "since the epoch 1970-01-01 00:00:00Z.") + .doc("When LEGACY, Spark will rebase dates/timestamps from Proleptic Gregorian calendar " + + "to the legacy hybrid (Julian + Gregorian) calendar when writing Parquet files. " + + "When CORRECTED, Spark will not do rebase and write the dates/timestamps as it is. " + + "When EXCEPTION, which is the default, Spark will fail the writing if it sees " + + "ancient dates/timestamps that are ambiguous between the two calendars.") .version("3.0.0") - .booleanConf - .createWithDefault(false) - - val LEGACY_PARQUET_REBASE_DATETIME_IN_READ = - buildConf("spark.sql.legacy.parquet.rebaseDateTimeInRead.enabled") - .internal() - .doc("When true, rebase dates/timestamps " + - "from the hybrid calendar to Proleptic Gregorian calendar in read. " + - "The rebasing is performed by converting micros/millis/days to " + - "a local date/timestamp in the source calendar, interpreting the resulted date/" + - "timestamp in the target calendar, and getting the number of micros/millis/days " + - "since the epoch 1970-01-01 00:00:00Z.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + + val LEGACY_PARQUET_REBASE_MODE_IN_READ = + buildConf("spark.sql.legacy.parquet.datetimeRebaseModeInRead") + .internal() + .doc("When LEGACY, Spark will rebase dates/timestamps from the legacy hybrid (Julian + " + + "Gregorian) calendar to Proleptic Gregorian calendar when reading Parquet files. " + + "When CORRECTED, Spark will not do rebase and read the dates/timestamps as it is. " + + "When EXCEPTION, which is the default, Spark will fail the reading if it sees " + + "ancient dates/timestamps that are ambiguous between the two calendars. This config is " + + "only effective if the writer info (like Spark, Hive) of the Parquet files is unknown.") .version("3.0.0") - .booleanConf - .createWithDefault(false) + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) - val LEGACY_AVRO_REBASE_DATETIME_IN_WRITE = - buildConf("spark.sql.legacy.avro.rebaseDateTimeInWrite.enabled") + val LEGACY_AVRO_REBASE_MODE_IN_WRITE = + buildConf("spark.sql.legacy.avro.datetimeRebaseModeInWrite") .internal() - .doc("When true, rebase dates/timestamps from Proleptic Gregorian calendar " + - "to the hybrid calendar (Julian + Gregorian) in write. " + - "The rebasing is performed by converting micros/millis/days to " + - "a local date/timestamp in the source calendar, interpreting the resulted date/" + - "timestamp in the target calendar, and getting the number of micros/millis/days " + - "since the epoch 1970-01-01 00:00:00Z.") + .doc("When LEGACY, Spark will rebase dates/timestamps from Proleptic Gregorian calendar " + + "to the legacy hybrid (Julian + Gregorian) calendar when writing Avro files. " + + "When CORRECTED, Spark will not do rebase and write the dates/timestamps as it is. 
" + + "When EXCEPTION, which is the default, Spark will fail the writing if it sees " + + "ancient dates/timestamps that are ambiguous between the two calendars.") .version("3.0.0") - .booleanConf - .createWithDefault(false) + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + + val LEGACY_AVRO_REBASE_MODE_IN_READ = + buildConf("spark.sql.legacy.avro.datetimeRebaseModeInRead") + .internal() + .doc("When LEGACY, Spark will rebase dates/timestamps from the legacy hybrid (Julian + " + + "Gregorian) calendar to Proleptic Gregorian calendar when reading Avro files. " + + "When CORRECTED, Spark will not do rebase and read the dates/timestamps as it is. " + + "When EXCEPTION, which is the default, Spark will fail the reading if it sees " + + "ancient dates/timestamps that are ambiguous between the two calendars. This config is " + + "only effective if the writer info (like Spark, Hive) of the Avro files is unknown.") + .version("3.0.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) - val LEGACY_AVRO_REBASE_DATETIME_IN_READ = - buildConf("spark.sql.legacy.avro.rebaseDateTimeInRead.enabled") + val SCRIPT_TRANSFORMATION_EXIT_TIMEOUT = + buildConf("spark.sql.scriptTransformation.exitTimeoutInSeconds") .internal() - .doc("When true, rebase dates/timestamps " + - "from the hybrid calendar to Proleptic Gregorian calendar in read. " + - "The rebasing is performed by converting micros/millis/days to " + - "a local date/timestamp in the source calendar, interpreting the resulted date/" + - "timestamp in the target calendar, and getting the number of micros/millis/days " + - "since the epoch 1970-01-01 00:00:00Z.") + .doc("Timeout for executor to wait for the termination of transformation script when EOF.") .version("3.0.0") - .booleanConf - .createWithDefault(false) + .timeConf(TimeUnit.SECONDS) + .checkValue(_ > 0, "The timeout value must be positive") + .createWithDefault(10L) /** * Holds information about keys that have been deprecated. @@ -3057,6 +3072,8 @@ class SQLConf extends Serializable with Logging { def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED) + def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED) + def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED) def arrowPySparkFallbackEnabled: Boolean = getConf(ARROW_PYSPARK_FALLBACK_ENABLED) @@ -3138,9 +3155,6 @@ class SQLConf extends Serializable with Logging { def allowNegativeScaleOfDecimalEnabled: Boolean = getConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED) - def createHiveTableByDefaultEnabled: Boolean = - getConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED) - def truncateTableIgnorePermissionAcl: Boolean = getConf(SQLConf.TRUNCATE_TABLE_IGNORE_PERMISSION_ACL) @@ -3162,10 +3176,6 @@ class SQLConf extends Serializable with Logging { def integerGroupingIdEnabled: Boolean = getConf(SQLConf.LEGACY_INTEGER_GROUPING_ID) - def parquetRebaseDateTimeInReadEnabled: Boolean = { - getConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_READ) - } - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsStreamingUpdate.scala similarity index 60% rename from sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationUtils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsStreamingUpdate.scala index f4c4b04bada2a..32be74a345c5a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsStreamingUpdate.scala @@ -15,20 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.thriftserver +package org.apache.spark.sql.internal.connector -import org.apache.spark.sql.catalyst.catalog.CatalogTableType -import org.apache.spark.sql.catalyst.catalog.CatalogTableType.{EXTERNAL, MANAGED, VIEW} +import org.apache.spark.sql.connector.write.WriteBuilder -/** - * Utils for metadata operations. - */ -private[hive] trait SparkMetadataOperationUtils { - - def tableTypeString(tableType: CatalogTableType): String = tableType match { - case EXTERNAL | MANAGED => "TABLE" - case VIEW => "VIEW" - case t => - throw new IllegalArgumentException(s"Unknown table type is found: $t") - } +// An internal `WriteBuilder` mixin to support UPDATE streaming output mode. +// TODO: design an official API for streaming output mode UPDATE. +trait SupportsStreamingUpdate extends WriteBuilder { + def update(): WriteBuilder } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 7449a28e069d2..fe8d7efc9dc12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -457,7 +457,7 @@ object DataType { case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == STRICT => if (!Cast.canUpCast(w, r)) { - addError(s"Cannot safely cast '$context': $w to $r") + addError(s"Cannot safely cast '$context': ${w.catalogString} to ${r.catalogString}") false } else { true @@ -467,7 +467,7 @@ object DataType { case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == ANSI => if (!Cast.canANSIStoreAssign(w, r)) { - addError(s"Cannot safely cast '$context': $w to $r") + addError(s"Cannot safely cast '$context': ${w.catalogString} to ${r.catalogString}") false } else { true @@ -477,7 +477,8 @@ object DataType { true case (w, r) => - addError(s"Cannot write '$context': $w is incompatible with $r") + addError(s"Cannot write '$context': " + + s"${w.catalogString} is incompatible with ${r.catalogString}") false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index a7c20c34d78bc..6a5bdc4f6fc3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -18,13 +18,16 @@ package org.apache.spark.sql import java.math.MathContext +import java.sql.{Date, Timestamp} +import java.time.{Instant, LocalDate, LocalDateTime, ZoneId} import scala.collection.mutable -import scala.util.Random +import scala.util.{Random, Try} import 
org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY +import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_MILLIS, MILLIS_PER_DAY} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval /** @@ -161,36 +164,89 @@ object RandomDataGenerator { }) case BooleanType => Some(() => rand.nextBoolean()) case DateType => - val generator = - () => { - var milliseconds = rand.nextLong() % 253402329599999L - // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT - // for "0001-01-01 00:00:00.000000". We need to find a - // number that is greater or equals to this number as a valid timestamp value. - while (milliseconds < -62135740800000L) { - // 253402329599999L is the number of milliseconds since - // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". - milliseconds = rand.nextLong() % 253402329599999L - } - DateTimeUtils.toJavaDate((milliseconds / MILLIS_PER_DAY).toInt) + def uniformDaysRand(rand: Random): Int = { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater or equals to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". + milliseconds = rand.nextLong() % 253402329599999L } - Some(generator) + (milliseconds / MILLIS_PER_DAY).toInt + } + val specialDates = Seq( + "0001-01-01", // the fist day of Common Era + "1582-10-15", // the cutover date from Julian to Gregorian calendar + "1970-01-01", // the epoch date + "9999-12-31" // the last supported date according to SQL standard + ) + if (SQLConf.get.getConf(SQLConf.DATETIME_JAVA8API_ENABLED)) { + randomNumeric[LocalDate]( + rand, + (rand: Random) => LocalDate.ofEpochDay(uniformDaysRand(rand)), + specialDates.map(LocalDate.parse)) + } else { + randomNumeric[java.sql.Date]( + rand, + (rand: Random) => { + val date = DateTimeUtils.toJavaDate(uniformDaysRand(rand)) + // The generated `date` is based on the hybrid calendar Julian + Gregorian since + // 1582-10-15 but it should be valid in Proleptic Gregorian calendar too which is used + // by Spark SQL since version 3.0 (see SPARK-26651). We try to convert `date` to + // a local date in Proleptic Gregorian calendar to satisfy this requirement. Some + // years are leap years in Julian calendar but not in Proleptic Gregorian calendar. + // As the consequence of that, 29 February of such years might not exist in Proleptic + // Gregorian calendar. When this happens, we shift the date by one day. + Try { date.toLocalDate; date }.getOrElse(new Date(date.getTime + MILLIS_PER_DAY)) + }, + specialDates.map(java.sql.Date.valueOf)) + } case TimestampType => - val generator = - () => { - var milliseconds = rand.nextLong() % 253402329599999L - // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT - // for "0001-01-01 00:00:00.000000". We need to find a - // number that is greater or equals to this number as a valid timestamp value. 
- while (milliseconds < -62135740800000L) { - // 253402329599999L is the number of milliseconds since - // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". - milliseconds = rand.nextLong() % 253402329599999L - } - // DateTimeUtils.toJavaTimestamp takes microsecond. - DateTimeUtils.toJavaTimestamp(milliseconds * 1000) + def uniformMicorsRand(rand: Random): Long = { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater or equals to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". + milliseconds = rand.nextLong() % 253402329599999L } - Some(generator) + milliseconds * MICROS_PER_MILLIS + } + val specialTs = Seq( + "0001-01-01 00:00:00", // the fist timestamp of Common Era + "1582-10-15 23:59:59", // the cutover date from Julian to Gregorian calendar + "1970-01-01 00:00:00", // the epoch timestamp + "9999-12-31 23:59:59" // the last supported timestamp according to SQL standard + ) + if (SQLConf.get.getConf(SQLConf.DATETIME_JAVA8API_ENABLED)) { + randomNumeric[Instant]( + rand, + (rand: Random) => DateTimeUtils.microsToInstant(uniformMicorsRand(rand)), + specialTs.map { s => + val ldt = LocalDateTime.parse(s.replace(" ", "T")) + ldt.atZone(ZoneId.systemDefault()).toInstant + }) + } else { + randomNumeric[java.sql.Timestamp]( + rand, + (rand: Random) => { + // DateTimeUtils.toJavaTimestamp takes microsecond. + val ts = DateTimeUtils.toJavaTimestamp(uniformMicorsRand(rand)) + // The generated `ts` is based on the hybrid calendar Julian + Gregorian since + // 1582-10-15 but it should be valid in Proleptic Gregorian calendar too which is used + // by Spark SQL since version 3.0 (see SPARK-26651). We try to convert `ts` to + // a local timestamp in Proleptic Gregorian calendar to satisfy this requirement. Some + // years are leap years in Julian calendar but not in Proleptic Gregorian calendar. + // As the consequence of that, 29 February of such years might not exist in Proleptic + // Gregorian calendar. When this happens, we shift the timestamp `ts` by one day. + Try { ts.toLocalDateTime; ts }.getOrElse(new Timestamp(ts.getTime + MILLIS_PER_DAY)) + }, + specialTs.map(java.sql.Timestamp.valueOf)) + } case CalendarIntervalType => Some(() => { val months = rand.nextInt(1000) val days = rand.nextInt(10000) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index 3e62ca069e9ea..cb335e5f435a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -24,30 +24,36 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** * Tests of [[RandomDataGenerator]]. 
*/ -class RandomDataGeneratorSuite extends SparkFunSuite { +class RandomDataGeneratorSuite extends SparkFunSuite with SQLHelper { /** * Tests random data generation for the given type by using it to generate random values then * converting those values into their Catalyst equivalents using CatalystTypeConverters. */ def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { - val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) - val generator = RandomDataGenerator.forType(dataType, nullable, new Random(33)).getOrElse { - fail(s"Random data generator was not defined for $dataType") - } - if (nullable) { - assert(Iterator.fill(100)(generator()).contains(null)) - } else { - assert(!Iterator.fill(100)(generator()).contains(null)) - } - for (_ <- 1 to 10) { - val generatedValue = generator() - toCatalyst(generatedValue) + Seq(false, true).foreach { java8Api => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) { + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) + val generator = RandomDataGenerator.forType(dataType, nullable, new Random(33)).getOrElse { + fail(s"Random data generator was not defined for $dataType") + } + if (nullable) { + assert(Iterator.fill(100)(generator()).contains(null)) + } else { + assert(!Iterator.fill(100)(generator()).contains(null)) + } + for (_ <- 1 to 10) { + val generatedValue = generator() + toCatalyst(generatedValue) + } + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala index c01dea96fe2de..e466d558db1ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -21,7 +21,7 @@ import java.net.URI import java.util.Locale import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, AttributeReference, Cast, Expression, LessThanOrEqual, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, AttributeReference, Cast, LessThanOrEqual, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy @@ -143,7 +143,7 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write", "'table-name'", - "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + "Cannot safely cast", "'x'", "'y'", "double to float")) } test("byName: multiple field errors are reported") { @@ -160,7 +160,7 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", - "Cannot safely cast", "'x'", "DoubleType to FloatType", + "Cannot safely cast", "'x'", "double to float", "Cannot write nullable values to non-null column", "'x'", "Cannot find data for output column", "'y'")) } @@ -176,7 +176,7 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write", "'table-name'", - "Cannot 
safely cast", "'x'", "'y'", "DoubleType to FloatType")) + "Cannot safely cast", "'x'", "'y'", "double to float")) } test("byPosition: multiple field errors are reported") { @@ -194,7 +194,7 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", "Cannot write nullable values to non-null column", "'x'", - "Cannot safely cast", "'x'", "DoubleType to FloatType")) + "Cannot safely cast", "'x'", "double to float")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index e37555f1c0ec3..1ea1ddb8bbd08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1559,6 +1559,30 @@ class TypeCoercionSuite extends AnalysisTest { Literal.create(null, DecimalType.SYSTEM_DEFAULT))) } } + + test("SPARK-31761: byte, short and int should be cast to long for IntegralDivide's datatype") { + val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + // Casts Byte to Long + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toByte, 1.toByte), + IntegralDivide(Cast(2.toByte, LongType), Cast(1.toByte, LongType))) + // Casts Short to Long + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1.toShort), + IntegralDivide(Cast(2.toShort, LongType), Cast(1.toShort, LongType))) + // Casts Integer to Long + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1), + IntegralDivide(Cast(2, LongType), Cast(1, LongType))) + // should not be any change for Long data types + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1L), IntegralDivide(2L, 1L)) + // one of the operand is byte + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1.toByte), + IntegralDivide(2L, Cast(1.toByte, LongType))) + // one of the operand is short + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1L), + IntegralDivide(Cast(2.toShort, LongType), 1L)) + // one of the operand is int + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1L), + IntegralDivide(Cast(2, LongType), 1L)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 0e9fcc980aabb..822008007ebbc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -325,7 +325,7 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { assert(parser.makeConverter("t", TimestampType).apply("2020-1-12 12:3:45") == date(2020, 1, 12, 12, 3, 45, 0)) assert(parser.makeConverter("t", DateType).apply("2020-1-12") == - days(2020, 1, 12, 0, 0, 0)) + days(2020, 1, 12)) // The legacy format allows arbitrary length of second fraction. 
assert(parser.makeConverter("t", TimestampType).apply("2020-1-12 12:3:45.1") == date(2020, 1, 12, 12, 3, 45, 100000)) @@ -333,22 +333,22 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { date(2020, 1, 12, 12, 3, 45, 123400)) // The legacy format allow date string to end with T or space, with arbitrary string assert(parser.makeConverter("t", DateType).apply("2020-1-12T") == - days(2020, 1, 12, 0, 0, 0)) + days(2020, 1, 12)) assert(parser.makeConverter("t", DateType).apply("2020-1-12Txyz") == - days(2020, 1, 12, 0, 0, 0)) + days(2020, 1, 12)) assert(parser.makeConverter("t", DateType).apply("2020-1-12 ") == - days(2020, 1, 12, 0, 0, 0)) + days(2020, 1, 12)) assert(parser.makeConverter("t", DateType).apply("2020-1-12 xyz") == - days(2020, 1, 12, 0, 0, 0)) + days(2020, 1, 12)) // The legacy format ignores the "GMT" from the string assert(parser.makeConverter("t", TimestampType).apply("2020-1-12 12:3:45GMT") == date(2020, 1, 12, 12, 3, 45, 0)) assert(parser.makeConverter("t", TimestampType).apply("GMT2020-1-12 12:3:45") == date(2020, 1, 12, 12, 3, 45, 0)) assert(parser.makeConverter("t", DateType).apply("2020-1-12GMT") == - days(2020, 1, 12, 0, 0, 0)) + days(2020, 1, 12)) assert(parser.makeConverter("t", DateType).apply("GMT2020-1-12") == - days(2020, 1, 12, 0, 0, 0)) + days(2020, 1, 12)) } val options = new CSVOptions(Map.empty[String, String], false, "UTC") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 48f4ef5051fb3..577814b9c6696 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -22,9 +22,9 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -247,6 +247,13 @@ class EncoderResolutionSuite extends PlanTest { """.stripMargin.trim + " of the field in the target object") } + test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { + val encoder = ExpressionEncoder[Seq[BigDecimal]] + val attr = Seq(AttributeReference("a", ArrayType(DecimalType(38, 0)))()) + // Before SPARK-31750, it will fail because Decimal(38, 0) can not be casted to Decimal(38, 18) + testFromRow(encoder, attr, InternalRow(ArrayData.toArrayData(Array(Decimal(1.0))))) + } + // test for leaf types castSuccess[Int, Long] castSuccess[java.sql.Date, java.sql.Timestamp] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index c1158e001a780..fd24f058f357c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -377,23 +377,27 @@ class 
RowEncoderSuite extends CodegenInterpretedPlanTest { private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { - val encoder = RowEncoder(schema).resolveAndBind() - val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get - - var input: Row = null - try { - for (_ <- 1 to 5) { - input = inputGenerator.apply().asInstanceOf[Row] - val convertedBack = roundTrip(encoder, input) - assert(input == convertedBack) + Seq(false, true).foreach { java8Api => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) { + val encoder = RowEncoder(schema).resolveAndBind() + val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get + + var input: Row = null + try { + for (_ <- 1 to 5) { + input = inputGenerator.apply().asInstanceOf[Row] + val convertedBack = roundTrip(encoder, input) + assert(input == convertedBack) + } + } catch { + case e: Exception => + fail( + s""" + |schema: ${schema.simpleString} + |input: ${input} + """.stripMargin, e) + } } - } catch { - case e: Exception => - fail( - s""" - |schema: ${schema.simpleString} - |input: ${input} - """.stripMargin, e) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 675f85f9e82ea..f05598aeb5353 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -173,13 +173,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("/ (Divide) for integral type") { - checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0L) - checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0L) - checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0L) + test("/ (Divide) for Long type") { checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0L) - checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0L) - checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0L) checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index ee94f3587b55c..6af995cab64fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -240,7 +240,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkCast(1.5, "1.5") checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) - checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) } test("cast from string") { @@ -1299,6 +1298,18 @@ class CastSuite extends CastSuiteBase { } } } + + test("cast a timestamp before the epoch 1970-01-01 00:00:00Z") { + withDefaultTimeZone(UTC) { + val negativeTs = Timestamp.valueOf("1900-05-05 18:34:56.1") + assert(negativeTs.getTime < 0) + val expectedSecs = Math.floorDiv(negativeTs.getTime, MILLIS_PER_SECOND) + checkEvaluation(cast(negativeTs, ByteType), expectedSecs.toByte) + checkEvaluation(cast(negativeTs, ShortType), expectedSecs.toShort) + 
checkEvaluation(cast(negativeTs, IntegerType), expectedSecs.toInt) + checkEvaluation(cast(negativeTs, LongType), expectedSecs) + } + } } /** @@ -1341,4 +1352,17 @@ class AnsiCastSuite extends CastSuiteBase { cast("abc.com", dataType), "invalid input") } } + + test("cast a timestamp before the epoch 1970-01-01 00:00:00Z") { + def errMsg(t: String): String = s"Casting -2198208303900000 to $t causes overflow" + withDefaultTimeZone(UTC) { + val negativeTs = Timestamp.valueOf("1900-05-05 18:34:56.1") + assert(negativeTs.getTime < 0) + val expectedSecs = Math.floorDiv(negativeTs.getTime, MILLIS_PER_SECOND) + checkExceptionInExpression[ArithmeticException](cast(negativeTs, ByteType), errMsg("byte")) + checkExceptionInExpression[ArithmeticException](cast(negativeTs, ShortType), errMsg("short")) + checkExceptionInExpression[ArithmeticException](cast(negativeTs, IntegerType), errMsg("int")) + checkEvaluation(cast(negativeTs, LongType), expectedSecs) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 6e8397d12da78..c038b7a9d476a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -267,7 +267,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Test escaping of format GenerateUnsafeProjection.generate( - DateFormatClass(Literal(ts), Literal("\"quote"), JST_OPT) :: Nil) + DateFormatClass(Literal(ts), Literal("\""), JST_OPT) :: Nil) // SPARK-28072 The codegen path should work checkEvaluation( @@ -792,7 +792,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } // Test escaping of format - GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote")) :: Nil) + GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote"), UTC_OPT) :: Nil) } test("unix_timestamp") { @@ -862,7 +862,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // Test escaping of format GenerateUnsafeProjection.generate( - UnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil) + UnixTimestamp(Literal("2015-07-24"), Literal("\"quote"), UTC_OPT) :: Nil) } test("to_unix_timestamp") { @@ -940,7 +940,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // Test escaping of format GenerateUnsafeProjection.generate( - ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil) + ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote"), UTC_OPT) :: Nil) } test("datediff") { @@ -1146,4 +1146,65 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal("yyyy-MM-dd'T'HH:mm:ss.SSSz")), "Fail to parse") } } + + test("SPARK-31710:Adds TIMESTAMP_SECONDS, " + + "TIMESTAMP_MILLIS and TIMESTAMP_MICROS functions") { + checkEvaluation(SecondsToTimestamp(Literal(1230219000)), 1230219000L * MICROS_PER_SECOND) + checkEvaluation(SecondsToTimestamp(Literal(-1230219000)), -1230219000L * MICROS_PER_SECOND) + checkEvaluation(SecondsToTimestamp(Literal(null, IntegerType)), null) + checkEvaluation(MillisToTimestamp(Literal(1230219000123L)), 1230219000123L * MICROS_PER_MILLIS) + checkEvaluation(MillisToTimestamp( + Literal(-1230219000123L)), -1230219000123L * MICROS_PER_MILLIS) + checkEvaluation(MillisToTimestamp(Literal(null, 
IntegerType)), null) + checkEvaluation(MicrosToTimestamp(Literal(1230219000123123L)), 1230219000123123L) + checkEvaluation(MicrosToTimestamp(Literal(-1230219000123123L)), -1230219000123123L) + checkEvaluation(MicrosToTimestamp(Literal(null, IntegerType)), null) + checkExceptionInExpression[ArithmeticException]( + SecondsToTimestamp(Literal(1230219000123123L)), "long overflow") + checkExceptionInExpression[ArithmeticException]( + SecondsToTimestamp(Literal(-1230219000123123L)), "long overflow") + checkExceptionInExpression[ArithmeticException]( + MillisToTimestamp(Literal(92233720368547758L)), "long overflow") + checkExceptionInExpression[ArithmeticException]( + MillisToTimestamp(Literal(-92233720368547758L)), "long overflow") + } + + test("Disable week-based date fields and quarter fields for parsing") { + + def checkSparkUpgrade(c: Char): Unit = { + checkExceptionInExpression[SparkUpgradeException]( + new ParseToTimestamp(Literal("1"), Literal(c.toString)).child, "3.0") + checkExceptionInExpression[SparkUpgradeException]( + new ParseToDate(Literal("1"), Literal(c.toString)).child, "3.0") + checkExceptionInExpression[SparkUpgradeException]( + ToUnixTimestamp(Literal("1"), Literal(c.toString)), "3.0") + checkExceptionInExpression[SparkUpgradeException]( + UnixTimestamp(Literal("1"), Literal(c.toString)), "3.0") + } + + def checkNullify(c: Char): Unit = { + checkEvaluation(new ParseToTimestamp(Literal("1"), Literal(c.toString)).child, null) + checkEvaluation(new ParseToDate(Literal("1"), Literal(c.toString)).child, null) + checkEvaluation(ToUnixTimestamp(Literal("1"), Literal(c.toString)), null) + checkEvaluation(UnixTimestamp(Literal("1"), Literal(c.toString)), null) + } + + Seq('Y', 'W', 'w', 'E', 'u', 'F').foreach { l => + checkSparkUpgrade(l) + } + + Seq('q', 'Q').foreach { l => + checkNullify(l) + } + } + + + test("SPARK-31896: Handle am-pm timestamp parsing when hour is missing") { + checkEvaluation( + new ParseToTimestamp(Literal("PM"), Literal("a")).child, + Timestamp.valueOf("1970-01-01 12:00:00.0")) + checkEvaluation( + new ParseToTimestamp(Literal("11:11 PM"), Literal("mm:ss a")).child, + Timestamp.valueOf("1970-01-01 12:11:11.0")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala index f2696849d7753..4e6976f76ea5f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp -import java.util.TimeZone import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -35,15 +34,7 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val l1 = Literal.create(20132983L, LongType) val l2 = Literal.create(-20132983L, LongType) val millis = 1524954911000L - // Explicitly choose a time zone, since Date objects can create different values depending on - // local time zone of the machine on which the test is running - val oldDefaultTZ = TimeZone.getDefault - val d1 = try { - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - Literal.create(new java.sql.Date(millis), DateType) - } finally { - TimeZone.setDefault(oldDefaultTZ) - } + val d1 = Literal.create(new java.sql.Date(millis), DateType) val t1 = 
Literal.create(new Timestamp(millis), TimestampType) val f1 = Literal.create(0.7788229f, FloatType) val f2 = Literal.create(-0.7788229f, FloatType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala new file mode 100644 index 0000000000000..3513cfa14808f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, JoinHint, NO_BROADCAST_HASH, SHUFFLE_HASH} +import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan +import org.apache.spark.sql.internal.SQLConf + +class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { + + private val left = StatsTestPlan( + outputList = Seq('a.int, 'b.int, 'c.int), + rowCount = 20000000, + size = Some(20000000), + attributeStats = AttributeMap(Seq())) + + private val right = StatsTestPlan( + outputList = Seq('d.int), + rowCount = 1000, + size = Some(1000), + attributeStats = AttributeMap(Seq())) + + private val hintBroadcast = Some(HintInfo(Some(BROADCAST))) + private val hintNotToBroadcast = Some(HintInfo(Some(NO_BROADCAST_HASH))) + private val hintShuffleHash = Some(HintInfo(Some(SHUFFLE_HASH))) + + test("getBroadcastBuildSide (hintOnly = true) return BuildLeft with only a left hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(hintBroadcast, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildLeft)) + } + + test("getBroadcastBuildSide (hintOnly = true) return BuildRight with only a right hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, hintBroadcast), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getBroadcastBuildSide (hintOnly = true) return smaller side with both having hints") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(hintBroadcast, hintBroadcast), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getBroadcastBuildSide (hintOnly = true) return None when no side has a hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === None) + } 
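+
+  // For context, the BROADCAST and SHUFFLE_HASH hints exercised by these tests are the ones a user
+  // supplies through the DataFrame or SQL hint APIs before JoinSelectionHelper consults them.
+  // A minimal sketch, assuming a local Spark 3.0+ session; the session setup and DataFrame names
+  // below are illustrative only, not part of this suite:
+  //
+  //   import org.apache.spark.sql.SparkSession
+  //   import org.apache.spark.sql.functions.broadcast
+  //
+  //   val spark = SparkSession.builder().master("local[*]").getOrCreate()
+  //   val largeDf = spark.range(1000000L).toDF("id")
+  //   val smallDf = spark.range(100L).toDF("id")
+  //
+  //   // broadcast() on the right side surfaces as JoinHint(None, Some(HintInfo(Some(BROADCAST))))
+  //   largeDf.join(broadcast(smallDf), Seq("id")).explain()
+  //   // hint("shuffle_hash") surfaces as HintInfo(Some(SHUFFLE_HASH));
+  //   // SQL equivalent: SELECT /*+ SHUFFLE_HASH(s) */ * FROM large l JOIN small s ON l.id = s.id
+  //   largeDf.join(smallDf.hint("shuffle_hash"), Seq("id")).explain()
+  //
+  // With hintOnly = true the helper honors only such explicit hints; with hintOnly = false it falls
+  // back to size-based decisions such as canBroadcastBySize, as the remaining tests cover.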
+ + test("getBroadcastBuildSide (hintOnly = false) return BuildRight when right is broadcastable") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = false, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getBroadcastBuildSide (hintOnly = false) return None when right has no broadcast hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, hintNotToBroadcast ), + hintOnly = false, + SQLConf.get + ) + assert(broadcastSide === None) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildLeft with only a left hint") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(hintShuffleHash, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildLeft)) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildRight with only a right hint") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(None, hintShuffleHash), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return smaller side when both have hints") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(hintShuffleHash, hintShuffleHash), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return None when no side has a hint") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === None) + } + + test("getShuffleHashJoinBuildSide (hintOnly = false) return BuildRight when right is smaller") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = false, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getSmallerSide should return BuildRight") { + assert(getSmallerSide(left, right) === BuildRight) + } + + test("canBroadcastBySize should return true if the plan size is less than 10MB") { + assert(canBroadcastBySize(left, SQLConf.get) === false) + assert(canBroadcastBySize(right, SQLConf.get) === true) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 9c31f07f293be..6499b5d8e7974 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -2196,21 +2196,20 @@ class DDLParserSuite extends AnalysisTest { CommentOnTable(UnresolvedTable(Seq("a", "b", "c")), "xYz")) } - test("create table - without using") { - withSQLConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED.key -> "false") { - val sql = "CREATE TABLE 1m.2g(a INT)" - val expectedTableSpec = TableSpec( - Seq("1m", "2g"), - Some(new StructType().add("a", IntegerType)), - Seq.empty[Transform], - None, - Map.empty[String, String], - None, - Map.empty[String, String], - None, - None) + // TODO: ignored by SPARK-31707, restore the test after create table syntax unification + ignore("create table - without using") { + val sql = "CREATE TABLE 1m.2g(a INT)" + val expectedTableSpec = TableSpec( + Seq("1m", "2g"), + Some(new StructType().add("a", IntegerType)), + 
Seq.empty[Transform], + None, + Map.empty[String, String], + None, + Map.empty[String, String], + None, + None) - testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) - } + testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index bec35ae458763..88afcb10d9c20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -55,11 +55,16 @@ class PlanParserSuite extends AnalysisTest { With(plan, ctes) } - test("single comment") { + test("single comment case one") { val plan = table("a").select(star()) assertEqual("-- single comment\nSELECT * FROM a", plan) } + test("single comment case two") { + val plan = table("a").select(star()) + assertEqual("-- single comment\\\nwith line continuity\nSELECT * FROM a", plan) + } + test("bracketed comment case one") { val plan = table("a").select(star()) assertEqual( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelperSuite.scala index 817e503584324..c68bdacb13af7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelperSuite.scala @@ -40,6 +40,23 @@ class DateTimeFormatterHelperSuite extends SparkFunSuite { val e = intercept[IllegalArgumentException](convertIncompatiblePattern(s"yyyy-MM-dd $l G")) assert(e.getMessage === s"Illegal pattern character: $l") } + unsupportedLettersForParsing.foreach { l => + val e = intercept[IllegalArgumentException] { + convertIncompatiblePattern(s"$l", isParsing = true) + } + assert(e.getMessage === s"Illegal pattern character: $l") + assert(convertIncompatiblePattern(s"$l").nonEmpty) + } + unsupportedPatternLengths.foreach { style => + val e1 = intercept[IllegalArgumentException] { + convertIncompatiblePattern(s"yyyy-MM-dd $style") + } + assert(e1.getMessage === s"Too many pattern letters: ${style.head}") + val e2 = intercept[IllegalArgumentException] { + convertIncompatiblePattern(s"yyyy-MM-dd $style${style.head}") + } + assert(e2.getMessage === s"Too many pattern letters: ${style.head}") + } assert(convertIncompatiblePattern("yyyy-MM-dd uuuu") === "uuuu-MM-dd eeee") assert(convertIncompatiblePattern("yyyy-MM-dd EEEE") === "uuuu-MM-dd EEEE") assert(convertIncompatiblePattern("yyyy-MM-dd'e'HH:mm:ss") === "uuuu-MM-dd'e'HH:mm:ss") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala index bf9e8f71ba1c9..66aef1b4b6cb0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala @@ -88,12 +88,8 @@ object DateTimeTestUtils { def days( year: Int, month: Byte = 1, - day: Byte = 1, - hour: Byte = 0, - minute: Byte = 0, - sec: Byte = 0): Int = { - val micros = date(year, month, day, hour, minute, sec) - TimeUnit.MICROSECONDS.toDays(micros).toInt + day: Byte = 1): Int = { + LocalDate.of(year, month, 
day).toEpochDay.toInt } // Returns microseconds since epoch for current date and give time diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 807ec7dafb568..4883bef8c0886 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -386,13 +386,13 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { } test("date add months") { - val input = days(1997, 2, 28, 10, 30) + val input = days(1997, 2, 28) assert(dateAddMonths(input, 36) === days(2000, 2, 28)) assert(dateAddMonths(input, -13) === days(1996, 1, 28)) } test("date add interval with day precision") { - val input = days(1997, 2, 28, 10, 30) + val input = days(1997, 2, 28) assert(dateAddInterval(input, new CalendarInterval(36, 0, 0)) === days(2000, 2, 28)) assert(dateAddInterval(input, new CalendarInterval(36, 47, 0)) === days(2000, 4, 15)) assert(dateAddInterval(input, new CalendarInterval(-13, 0, 0)) === days(1996, 1, 28)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala index c47332f5d9fcb..1a262d646ca10 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -80,7 +80,7 @@ class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBa test("Check NullType is incompatible with all other types") { allNonNullTypes.foreach { t => assertSingleError(NullType, t, "nulls", s"Should not allow writing None to type $t") { err => - assert(err.contains(s"incompatible with $t")) + assert(err.contains(s"incompatible with ${t.catalogString}")) } } } @@ -145,12 +145,12 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase test("Conversions between timestamp and long are not allowed") { assertSingleError(LongType, TimestampType, "longToTimestamp", "Should not allow long to timestamp") { err => - assert(err.contains("Cannot safely cast 'longToTimestamp': LongType to TimestampType")) + assert(err.contains("Cannot safely cast 'longToTimestamp': bigint to timestamp")) } assertSingleError(TimestampType, LongType, "timestampToLong", "Should not allow timestamp to long") { err => - assert(err.contains("Cannot safely cast 'timestampToLong': TimestampType to LongType")) + assert(err.contains("Cannot safely cast 'timestampToLong': timestamp to bigint")) } } @@ -209,8 +209,8 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { s"Should not allow writing $w to $r because cast is not safe") { err => assert(err.contains("'t'"), "Should include the field name context") assert(err.contains("Cannot safely cast"), "Should identify unsafe cast") - assert(err.contains(s"$w"), "Should include write type") - assert(err.contains(s"$r"), "Should include read type") + assert(err.contains(s"${w.catalogString}"), "Should include write type") + assert(err.contains(s"${r.catalogString}"), "Should include read type") } } } @@ -413,7 +413,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => 
assert(errs(0).contains("'top.a.element'"), "Should identify bad type") assert(errs(0).contains("Cannot safely cast")) - assert(errs(0).contains("StringType to DoubleType")) + assert(errs(0).contains("string to double")) assert(errs(1).contains("'top.a'"), "Should identify bad type") assert(errs(1).contains("Cannot write nullable elements to array of non-nulls")) @@ -430,11 +430,11 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { assert(errs(5).contains("'top.m.key'"), "Should identify bad type") assert(errs(5).contains("Cannot safely cast")) - assert(errs(5).contains("StringType to LongType")) + assert(errs(5).contains("string to bigint")) assert(errs(6).contains("'top.m.value'"), "Should identify bad type") assert(errs(6).contains("Cannot safely cast")) - assert(errs(6).contains("BooleanType to FloatType")) + assert(errs(6).contains("boolean to float")) assert(errs(7).contains("'top.m'"), "Should identify bad type") assert(errs(7).contains("Cannot write nullable values to map of non-nulls")) @@ -452,7 +452,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { assert(errs(11).contains("'top.x'"), "Should identify bad type") assert(errs(11).contains("Cannot safely cast")) - assert(errs(11).contains("StringType to IntegerType")) + assert(errs(11).contains("string to int")) assert(errs(12).contains("'top'"), "Should identify bad type") assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala index 5e2b6a7c7fafe..22a1396d5efdc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala @@ -17,18 +17,19 @@ package org.apache.spark.sql.util -import java.time.{DateTimeException, LocalDate, ZoneOffset} +import java.time.{DateTimeException, LocalDate} import org.apache.spark.{SparkFunSuite, SparkUpgradeException} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, localDateToDays} +import org.apache.spark.sql.catalyst.util.{DateFormatter, LegacyDateFormats} +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy class DateFormatterSuite extends SparkFunSuite with SQLHelper { test("parsing dates") { - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + outstandingTimezonesIds.foreach { timeZone => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { val formatter = DateFormatter(getZoneId(timeZone)) val daysSinceEpoch = formatter.parse("2018-12-02") @@ -38,11 +39,14 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper { } test("format dates") { - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + outstandingTimezonesIds.foreach { timeZone => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { val formatter = DateFormatter(getZoneId(timeZone)) - val date = formatter.format(17867) - assert(date === "2018-12-02") + val (days, expected) = (17867, "2018-12-02") + val date = formatter.format(days) + assert(date === expected) + assert(formatter.format(daysToLocalDate(days)) === expected) + 
assert(formatter.format(toJavaDate(days)) === expected) } } } @@ -62,16 +66,18 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper { "2018-12-12", "2038-01-01", "5010-11-17").foreach { date => - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + outstandingTimezonesIds.foreach { timeZone => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { val formatter = DateFormatter( DateFormatter.defaultPattern, getZoneId(timeZone), DateFormatter.defaultLocale, - legacyFormat) + legacyFormat, + isParsing = false) val days = formatter.parse(date) - val formatted = formatter.format(days) - assert(date === formatted) + assert(date === formatter.format(days)) + assert(date === formatter.format(daysToLocalDate(days))) + assert(date === formatter.format(toJavaDate(days))) } } } @@ -95,13 +101,14 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper { 17877, 24837, 1110657).foreach { days => - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + outstandingTimezonesIds.foreach { timeZone => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { val formatter = DateFormatter( DateFormatter.defaultPattern, getZoneId(timeZone), DateFormatter.defaultLocale, - legacyFormat) + legacyFormat, + isParsing = false) val date = formatter.format(days) val parsed = formatter.parse(date) assert(days === parsed) @@ -114,14 +121,14 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper { } test("parsing date without explicit day") { - val formatter = DateFormatter("yyyy MMM", ZoneOffset.UTC) + val formatter = DateFormatter("yyyy MMM", UTC) val daysSinceEpoch = formatter.parse("2018 Dec") - assert(daysSinceEpoch === LocalDate.of(2018, 12, 1).toEpochDay) + assert(daysSinceEpoch === days(2018, 12, 1)) } test("formatting negative years with default pattern") { - val epochDays = LocalDate.of(-99, 1, 1).toEpochDay.toInt - assert(DateFormatter(ZoneOffset.UTC).format(epochDays) === "-0099-01-01") + val epochDays = days(-99, 1, 1) + assert(DateFormatter(UTC).format(epochDays) === "-0099-01-01") } test("special date values") { @@ -138,8 +145,8 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper { } test("SPARK-30958: parse date with negative year") { - val formatter1 = DateFormatter("yyyy-MM-dd", ZoneOffset.UTC) - assert(formatter1.parse("-1234-02-22") === localDateToDays(LocalDate.of(-1234, 2, 22))) + val formatter1 = DateFormatter("yyyy-MM-dd", UTC) + assert(formatter1.parse("-1234-02-22") === days(-1234, 2, 22)) def assertParsingError(f: => Unit): Unit = { intercept[Exception](f) match { @@ -151,29 +158,45 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper { } // "yyyy" with "G" can't parse negative year or year 0000. 
- val formatter2 = DateFormatter("G yyyy-MM-dd", ZoneOffset.UTC) + val formatter2 = DateFormatter("G yyyy-MM-dd", UTC) assertParsingError(formatter2.parse("BC -1234-02-22")) assertParsingError(formatter2.parse("AD 0000-02-22")) - assert(formatter2.parse("BC 1234-02-22") === localDateToDays(LocalDate.of(-1233, 2, 22))) - assert(formatter2.parse("AD 1234-02-22") === localDateToDays(LocalDate.of(1234, 2, 22))) + assert(formatter2.parse("BC 1234-02-22") === days(-1233, 2, 22)) + assert(formatter2.parse("AD 1234-02-22") === days(1234, 2, 22)) } test("SPARK-31557: rebasing in legacy formatters/parsers") { withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> LegacyBehaviorPolicy.LEGACY.toString) { LegacyDateFormats.values.foreach { legacyFormat => - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + outstandingTimezonesIds.foreach { timeZone => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { val formatter = DateFormatter( DateFormatter.defaultPattern, getZoneId(timeZone), DateFormatter.defaultLocale, - legacyFormat) + legacyFormat, + isParsing = false) assert(LocalDate.ofEpochDay(formatter.parse("1000-01-01")) === LocalDate.of(1000, 1, 1)) + assert(formatter.format(LocalDate.of(1000, 1, 1)) === "1000-01-01") assert(formatter.format(localDateToDays(LocalDate.of(1000, 1, 1))) === "1000-01-01") + assert(formatter.format(java.sql.Date.valueOf("1000-01-01")) === "1000-01-01") } } } } } + + test("missing date fields") { + val formatter = DateFormatter("HH", UTC) + val daysSinceEpoch = formatter.parse("20") + assert(daysSinceEpoch === days(1970, 1, 1)) + } + + test("missing year field with invalid date") { + val formatter = DateFormatter("MM-dd", UTC) + // The date parser in 2.4 accepts 1970-02-29 and turns it into 1970-03-01, so we should get a + // SparkUpgradeException here. 
+ intercept[SparkUpgradeException](formatter.parse("02-29")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala index 5d27a6b8cce1e..1530ac4e24da2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.util -import java.time.{DateTimeException, Instant, LocalDateTime, LocalTime, ZoneOffset} +import java.time.{DateTimeException, Instant, LocalDateTime, LocalTime} import java.util.concurrent.TimeUnit import org.scalatest.Matchers import org.apache.spark.{SparkFunSuite, SparkUpgradeException} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, LegacyDateFormats, TimestampFormatter} -import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{CET, PST, UTC} +import org.apache.spark.sql.catalyst.util.{LegacyDateFormats, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy @@ -44,11 +44,11 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers "Antarctica/Vostok" -> 1543723872001234L, "Asia/Hong_Kong" -> 1543716672001234L, "Europe/Amsterdam" -> 1543741872001234L) - DateTimeTestUtils.outstandingTimezonesIds.foreach { zoneId => + outstandingTimezonesIds.foreach { zoneId => val formatter = TimestampFormatter( "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", - DateTimeUtils.getZoneId(zoneId), - needVarLengthSecondFraction = true) + getZoneId(zoneId), + isParsing = true) val microsSinceEpoch = formatter.parse(localDate) assert(microsSinceEpoch === expectedMicros(zoneId)) } @@ -57,20 +57,29 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers test("format timestamps using time zones") { val microsSinceEpoch = 1543745472001234L val expectedTimestamp = Map( - "UTC" -> "2018-12-02T10:11:12.001234", - PST.getId -> "2018-12-02T02:11:12.001234", - CET.getId -> "2018-12-02T11:11:12.001234", - "Africa/Dakar" -> "2018-12-02T10:11:12.001234", - "America/Los_Angeles" -> "2018-12-02T02:11:12.001234", - "Antarctica/Vostok" -> "2018-12-02T16:11:12.001234", - "Asia/Hong_Kong" -> "2018-12-02T18:11:12.001234", - "Europe/Amsterdam" -> "2018-12-02T11:11:12.001234") - DateTimeTestUtils.outstandingTimezonesIds.foreach { zoneId => - val formatter = TimestampFormatter( - "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", - DateTimeUtils.getZoneId(zoneId)) - val timestamp = formatter.format(microsSinceEpoch) - assert(timestamp === expectedTimestamp(zoneId)) + "UTC" -> "2018-12-02 10:11:12.001234", + PST.getId -> "2018-12-02 02:11:12.001234", + CET.getId -> "2018-12-02 11:11:12.001234", + "Africa/Dakar" -> "2018-12-02 10:11:12.001234", + "America/Los_Angeles" -> "2018-12-02 02:11:12.001234", + "Antarctica/Vostok" -> "2018-12-02 16:11:12.001234", + "Asia/Hong_Kong" -> "2018-12-02 18:11:12.001234", + "Europe/Amsterdam" -> "2018-12-02 11:11:12.001234") + outstandingTimezonesIds.foreach { zoneId => + Seq( + TimestampFormatter( + "yyyy-MM-dd HH:mm:ss.SSSSSS", + getZoneId(zoneId), + // Test only FAST_DATE_FORMAT because other legacy formats don't support formatting + // in microsecond precision. 
+ LegacyDateFormats.FAST_DATE_FORMAT, + isParsing = false), + TimestampFormatter.getFractionFormatter(getZoneId(zoneId))).foreach { formatter => + val timestamp = formatter.format(microsSinceEpoch) + assert(timestamp === expectedTimestamp(zoneId)) + assert(formatter.format(microsToInstant(microsSinceEpoch)) === expectedTimestamp(zoneId)) + assert(formatter.format(toJavaTimestamp(microsSinceEpoch)) === expectedTimestamp(zoneId)) + } } } @@ -86,10 +95,10 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers 1543749753123456L, 2177456523456789L, 11858049903010203L).foreach { micros => - DateTimeTestUtils.outstandingZoneIds.foreach { zoneId => + outstandingZoneIds.foreach { zoneId => val timestamp = TimestampFormatter(pattern, zoneId).format(micros) val parsed = TimestampFormatter( - pattern, zoneId, needVarLengthSecondFraction = true).parse(timestamp) + pattern, zoneId, isParsing = true).parse(timestamp) assert(micros === parsed) } } @@ -107,10 +116,10 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers "2018-12-02T11:22:33.123456", "2039-01-01T01:02:03.456789", "2345-10-07T22:45:03.010203").foreach { timestamp => - DateTimeTestUtils.outstandingZoneIds.foreach { zoneId => + outstandingZoneIds.foreach { zoneId => val pattern = "yyyy-MM-dd'T'HH:mm:ss.SSSSSS" val micros = TimestampFormatter( - pattern, zoneId, needVarLengthSecondFraction = true).parse(timestamp) + pattern, zoneId, isParsing = true).parse(timestamp) val formatted = TimestampFormatter(pattern, zoneId).format(micros) assert(timestamp === formatted) } @@ -118,27 +127,39 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers } test("case insensitive parsing of am and pm") { - val formatter = TimestampFormatter("yyyy MMM dd hh:mm:ss a", ZoneOffset.UTC) + val formatter = TimestampFormatter("yyyy MMM dd hh:mm:ss a", UTC) val micros = formatter.parse("2009 Mar 20 11:30:01 am") - assert(micros === TimeUnit.SECONDS.toMicros( - LocalDateTime.of(2009, 3, 20, 11, 30, 1).toEpochSecond(ZoneOffset.UTC))) + assert(micros === date(2009, 3, 20, 11, 30, 1)) } test("format fraction of second") { - val formatter = TimestampFormatter.getFractionFormatter(ZoneOffset.UTC) - assert(formatter.format(0) === "1970-01-01 00:00:00") - assert(formatter.format(1) === "1970-01-01 00:00:00.000001") - assert(formatter.format(1000) === "1970-01-01 00:00:00.001") - assert(formatter.format(900000) === "1970-01-01 00:00:00.9") - assert(formatter.format(1000000) === "1970-01-01 00:00:01") + val formatter = TimestampFormatter.getFractionFormatter(UTC) + Seq( + -999999 -> "1969-12-31 23:59:59.000001", + -999900 -> "1969-12-31 23:59:59.0001", + -1 -> "1969-12-31 23:59:59.999999", + 0 -> "1970-01-01 00:00:00", + 1 -> "1970-01-01 00:00:00.000001", + 1000 -> "1970-01-01 00:00:00.001", + 900000 -> "1970-01-01 00:00:00.9", + 1000000 -> "1970-01-01 00:00:01").foreach { case (micros, tsStr) => + assert(formatter.format(micros) === tsStr) + assert(formatter.format(microsToInstant(micros)) === tsStr) + withDefaultTimeZone(UTC) { + assert(formatter.format(toJavaTimestamp(micros)) === tsStr) + } + } } test("formatting negative years with default pattern") { - val instant = LocalDateTime.of(-99, 1, 1, 0, 0, 0) - .atZone(ZoneOffset.UTC) - .toInstant - val micros = DateTimeUtils.instantToMicros(instant) - assert(TimestampFormatter(ZoneOffset.UTC).format(micros) === "-0099-01-01 00:00:00") + val instant = LocalDateTime.of(-99, 1, 1, 0, 0, 0).atZone(UTC).toInstant + val micros = 
instantToMicros(instant) + assert(TimestampFormatter(UTC).format(micros) === "-0099-01-01 00:00:00") + assert(TimestampFormatter(UTC).format(instant) === "-0099-01-01 00:00:00") + withDefaultTimeZone(UTC) { // toJavaTimestamp depends on the default time zone + assert(TimestampFormatter("yyyy-MM-dd HH:mm:SS G", UTC).format(toJavaTimestamp(micros)) + === "0100-01-01 00:00:00 BC") + } } test("special timestamp values") { @@ -162,11 +183,10 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers } test("parsing timestamp strings with various seconds fractions") { - DateTimeTestUtils.outstandingZoneIds.foreach { zoneId => + outstandingZoneIds.foreach { zoneId => def check(pattern: String, input: String, reference: String): Unit = { - val formatter = TimestampFormatter(pattern, zoneId, needVarLengthSecondFraction = true) - val expected = DateTimeUtils.stringToTimestamp( - UTF8String.fromString(reference), zoneId).get + val formatter = TimestampFormatter(pattern, zoneId, isParsing = true) + val expected = stringToTimestamp(UTF8String.fromString(reference), zoneId).get val actual = formatter.parse(input) assert(actual === expected) } @@ -200,11 +220,10 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers } test("formatting timestamp strings up to microsecond precision") { - DateTimeTestUtils.outstandingZoneIds.foreach { zoneId => + outstandingZoneIds.foreach { zoneId => def check(pattern: String, input: String, expected: String): Unit = { val formatter = TimestampFormatter(pattern, zoneId) - val timestamp = DateTimeUtils.stringToTimestamp( - UTF8String.fromString(input), zoneId).get + val timestamp = stringToTimestamp(UTF8String.fromString(input), zoneId).get val actual = formatter.format(timestamp) assert(actual === expected) } @@ -240,9 +259,8 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers } test("SPARK-30958: parse timestamp with negative year") { - val formatter1 = TimestampFormatter("yyyy-MM-dd HH:mm:ss", ZoneOffset.UTC, true) - assert(formatter1.parse("-1234-02-22 02:22:22") === instantToMicros( - LocalDateTime.of(-1234, 2, 22, 2, 22, 22).toInstant(ZoneOffset.UTC))) + val formatter1 = TimestampFormatter("yyyy-MM-dd HH:mm:ss", UTC, true) + assert(formatter1.parse("-1234-02-22 02:22:22") === date(-1234, 2, 22, 2, 22, 22)) def assertParsingError(f: => Unit): Unit = { intercept[Exception](f) match { @@ -258,32 +276,37 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers assertParsingError(formatter2.parse("BC -1234-02-22 02:22:22")) assertParsingError(formatter2.parse("AC 0000-02-22 02:22:22")) - assert(formatter2.parse("BC 1234-02-22 02:22:22") === instantToMicros( - LocalDateTime.of(-1233, 2, 22, 2, 22, 22).toInstant(ZoneOffset.UTC))) - assert(formatter2.parse("AD 1234-02-22 02:22:22") === instantToMicros( - LocalDateTime.of(1234, 2, 22, 2, 22, 22).toInstant(ZoneOffset.UTC))) + assert(formatter2.parse("BC 1234-02-22 02:22:22") === date(-1233, 2, 22, 2, 22, 22)) + assert(formatter2.parse("AD 1234-02-22 02:22:22") === date(1234, 2, 22, 2, 22, 22)) } test("SPARK-31557: rebasing in legacy formatters/parsers") { withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> LegacyBehaviorPolicy.LEGACY.toString) { - LegacyDateFormats.values.foreach { legacyFormat => - DateTimeTestUtils.outstandingZoneIds.foreach { zoneId => - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zoneId.getId) { - DateTimeTestUtils.withDefaultTimeZone(zoneId) { - withClue(s"${zoneId.getId} legacyFormat = 
$legacyFormat") { - val formatter = TimestampFormatter( + outstandingZoneIds.foreach { zoneId => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zoneId.getId) { + withDefaultTimeZone(zoneId) { + withClue(s"zoneId = ${zoneId.getId}") { + val formatters = LegacyDateFormats.values.map { legacyFormat => + TimestampFormatter( TimestampFormatter.defaultPattern, zoneId, TimestampFormatter.defaultLocale, legacyFormat, - needVarLengthSecondFraction = false) + isParsing = false) + }.toSeq :+ TimestampFormatter.getFractionFormatter(zoneId) + formatters.foreach { formatter => assert(microsToInstant(formatter.parse("1000-01-01 01:02:03")) .atZone(zoneId) .toLocalDateTime === LocalDateTime.of(1000, 1, 1, 1, 2, 3)) + assert(formatter.format( + LocalDateTime.of(1000, 1, 1, 1, 2, 3).atZone(zoneId).toInstant) === + "1000-01-01 01:02:03") assert(formatter.format(instantToMicros( LocalDateTime.of(1000, 1, 1, 1, 2, 3) .atZone(zoneId).toInstant)) === "1000-01-01 01:02:03") + assert(formatter.format(java.sql.Timestamp.valueOf("1000-01-01 01:02:03")) === + "1000-01-01 01:02:03") } } } @@ -291,4 +314,119 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers } } } + + test("parsing hour with various patterns") { + def createFormatter(pattern: String): TimestampFormatter = { + // Use `SIMPLE_DATE_FORMAT`, so that the legacy parser also fails with invalid value range. + TimestampFormatter(pattern, UTC, LegacyDateFormats.SIMPLE_DATE_FORMAT, false) + } + + withClue("HH") { + val formatter = createFormatter("yyyy-MM-dd HH") + + val micros1 = formatter.parse("2009-12-12 00") + assert(micros1 === date(2009, 12, 12)) + + val micros2 = formatter.parse("2009-12-12 15") + assert(micros2 === date(2009, 12, 12, 15)) + + intercept[DateTimeException](formatter.parse("2009-12-12 24")) + } + + withClue("kk") { + val formatter = createFormatter("yyyy-MM-dd kk") + + intercept[DateTimeException](formatter.parse("2009-12-12 00")) + + val micros1 = formatter.parse("2009-12-12 15") + assert(micros1 === date(2009, 12, 12, 15)) + + val micros2 = formatter.parse("2009-12-12 24") + assert(micros2 === date(2009, 12, 12)) + } + + withClue("KK") { + val formatter = createFormatter("yyyy-MM-dd KK a") + + val micros1 = formatter.parse("2009-12-12 00 am") + assert(micros1 === date(2009, 12, 12)) + + // For `KK`, "12:00:00 am" is the same as "00:00:00 pm". + val micros2 = formatter.parse("2009-12-12 12 am") + assert(micros2 === date(2009, 12, 12, 12)) + + val micros3 = formatter.parse("2009-12-12 00 pm") + assert(micros3 === date(2009, 12, 12, 12)) + + intercept[DateTimeException](formatter.parse("2009-12-12 12 pm")) + } + + withClue("hh") { + val formatter = createFormatter("yyyy-MM-dd hh a") + + intercept[DateTimeException](formatter.parse("2009-12-12 00 am")) + + val micros1 = formatter.parse("2009-12-12 12 am") + assert(micros1 === date(2009, 12, 12)) + + intercept[DateTimeException](formatter.parse("2009-12-12 00 pm")) + + val micros2 = formatter.parse("2009-12-12 12 pm") + assert(micros2 === date(2009, 12, 12, 12)) + } + } + + test("missing date fields") { + val formatter = TimestampFormatter("HH:mm:ss", UTC) + val micros = formatter.parse("11:30:01") + assert(micros === date(1970, 1, 1, 11, 30, 1)) + } + + test("missing year field with invalid date") { + // Use `SIMPLE_DATE_FORMAT`, so that the legacy parser also fails with invalid date. 
+ val formatter = TimestampFormatter("MM-dd", UTC, LegacyDateFormats.SIMPLE_DATE_FORMAT, false) + withDefaultTimeZone(UTC)(intercept[DateTimeException](formatter.parse("02-29"))) + } + + test("missing am/pm field") { + Seq("HH", "hh", "KK", "kk").foreach { hour => + val formatter = TimestampFormatter(s"yyyy $hour:mm:ss", UTC) + val micros = formatter.parse("2009 11:30:01") + assert(micros === date(2009, 1, 1, 11, 30, 1)) + } + } + + test("missing time fields") { + val formatter = TimestampFormatter("yyyy HH", UTC) + val micros = formatter.parse("2009 11") + assert(micros === date(2009, 1, 1, 11)) + } + + test("missing hour field") { + val f1 = TimestampFormatter("mm:ss a", UTC) + val t1 = f1.parse("30:01 PM") + assert(t1 === date(1970, 1, 1, 12, 30, 1)) + val t2 = f1.parse("30:01 AM") + assert(t2 === date(1970, 1, 1, 0, 30, 1)) + val f2 = TimestampFormatter("mm:ss", UTC) + val t3 = f2.parse("30:01") + assert(t3 === date(1970, 1, 1, 0, 30, 1)) + val f3 = TimestampFormatter("a", UTC) + val t4 = f3.parse("PM") + assert(t4 === date(1970, 1, 1, 12)) + val t5 = f3.parse("AM") + assert(t5 === date(1970)) + } + + test("explicitly forbidden datetime patterns") { + // not support by the legacy one too + Seq("QQQQQ", "qqqqq", "A", "c", "e", "n", "N", "p").foreach { pattern => + intercept[IllegalArgumentException](TimestampFormatter(pattern, UTC).format(0)) + } + // supported by the legacy one, then we will suggest users with SparkUpgradeException + Seq("GGGGG", "MMMMM", "LLLLL", "EEEEE", "uuuuu", "aa", "aaa", "y" * 11, "y" * 11) + .foreach { pattern => + intercept[SparkUpgradeException](TimestampFormatter(pattern, UTC).format(0)) + } + } } diff --git a/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt b/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt index 147a77ff098d0..0e82b632793d2 100644 --- a/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt @@ -2,66 +2,66 @@ Benchmark to measure CSV read/write performance ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parsing quoted values: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -One quoted string 24907 29374 NaN 0.0 498130.5 1.0X +One quoted string 46568 46683 198 0.0 931358.6 1.0X -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Wide rows with 1000 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 1000 columns 62811 63690 1416 0.0 62811.4 1.0X -Select 100 columns 23839 24064 230 0.0 23839.5 2.6X -Select one column 19936 20641 827 0.1 19936.4 3.2X -count() 4174 4380 206 0.2 4174.4 15.0X -Select 100 columns, one bad input field 41015 42380 1688 0.0 41015.4 1.5X -Select 100 columns, corrupt record field 46281 46338 93 0.0 46280.5 1.4X +Select 1000 columns 129836 130796 1404 0.0 129836.0 1.0X +Select 100 columns 40444 40679 
261 0.0 40443.5 3.2X +Select one column 33429 33475 73 0.0 33428.6 3.9X +count() 7967 8047 73 0.1 7966.7 16.3X +Select 100 columns, one bad input field 90639 90832 266 0.0 90638.6 1.4X +Select 100 columns, corrupt record field 109023 109084 74 0.0 109023.3 1.2X -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Count a dataset with 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 10 columns + count() 10810 10997 163 0.9 1081.0 1.0X -Select 1 column + count() 7608 7641 47 1.3 760.8 1.4X -count() 2415 2462 77 4.1 241.5 4.5X +Select 10 columns + count() 20685 20707 35 0.5 2068.5 1.0X +Select 1 column + count() 13096 13149 49 0.8 1309.6 1.6X +count() 3994 4001 7 2.5 399.4 5.2X -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Write dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Create a dataset of timestamps 874 914 37 11.4 87.4 1.0X -to_csv(timestamp) 7051 7223 250 1.4 705.1 0.1X -write timestamps to files 6712 6741 31 1.5 671.2 0.1X -Create a dataset of dates 909 945 35 11.0 90.9 1.0X -to_csv(date) 4222 4231 8 2.4 422.2 0.2X -write dates to files 3799 3813 14 2.6 379.9 0.2X +Create a dataset of timestamps 2169 2203 32 4.6 216.9 1.0X +to_csv(timestamp) 14401 14591 168 0.7 1440.1 0.2X +write timestamps to files 13209 13276 59 0.8 1320.9 0.2X +Create a dataset of dates 2231 2248 17 4.5 223.1 1.0X +to_csv(date) 10406 10473 68 1.0 1040.6 0.2X +write dates to files 7970 7976 9 1.3 797.0 0.3X -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Read dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -read timestamp text from files 1342 1364 35 7.5 134.2 1.0X -read timestamps from files 20300 20473 247 0.5 2030.0 0.1X -infer timestamps from files 40705 40744 54 0.2 4070.5 0.0X -read date text from files 1146 1151 6 8.7 114.6 1.2X -read date from files 12278 12408 117 0.8 1227.8 0.1X -infer date from files 12734 12872 220 0.8 1273.4 0.1X -timestamp strings 1467 1482 15 6.8 146.7 0.9X -parse timestamps from Dataset[String] 21708 22234 477 0.5 2170.8 0.1X -infer timestamps from Dataset[String] 42357 43253 922 0.2 4235.7 0.0X -date strings 1512 1532 18 6.6 151.2 0.9X -parse dates from Dataset[String] 13436 13470 33 0.7 1343.6 0.1X -from_csv(timestamp) 20390 20486 95 0.5 2039.0 0.1X -from_csv(date) 12592 12693 139 0.8 1259.2 0.1X +read timestamp text from files 2387 2391 6 4.2 238.7 1.0X +read timestamps from files 53503 53593 124 0.2 5350.3 0.0X +infer timestamps from files 107988 108668 647 0.1 10798.8 0.0X +read date text from files 2121 2133 12 4.7 212.1 1.1X 
+read date from files 29983 30039 48 0.3 2998.3 0.1X +infer date from files 30196 30436 218 0.3 3019.6 0.1X +timestamp strings 3098 3109 10 3.2 309.8 0.8X +parse timestamps from Dataset[String] 63331 63426 84 0.2 6333.1 0.0X +infer timestamps from Dataset[String] 124003 124463 490 0.1 12400.3 0.0X +date strings 3423 3429 11 2.9 342.3 0.7X +parse dates from Dataset[String] 34235 34314 76 0.3 3423.5 0.1X +from_csv(timestamp) 60829 61600 668 0.2 6082.9 0.0X +from_csv(date) 33047 33173 139 0.3 3304.7 0.1X -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Filters pushdown: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -w/o filters 12535 12606 67 0.0 125348.8 1.0X -pushdown disabled 12611 12672 91 0.0 126112.9 1.0X -w/ filters 1093 1099 11 0.1 10928.3 11.5X +w/o filters 28752 28765 16 0.0 287516.5 1.0X +pushdown disabled 28856 28880 22 0.0 288556.3 1.0X +w/ filters 1714 1731 15 0.1 17137.3 16.8X diff --git a/sql/core/benchmarks/CSVBenchmark-results.txt b/sql/core/benchmarks/CSVBenchmark-results.txt index 498ca4caa0e45..a3af46c037bf9 100644 --- a/sql/core/benchmarks/CSVBenchmark-results.txt +++ b/sql/core/benchmarks/CSVBenchmark-results.txt @@ -2,66 +2,66 @@ Benchmark to measure CSV read/write performance ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parsing quoted values: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -One quoted string 24073 24109 33 0.0 481463.5 1.0X +One quoted string 45457 45731 344 0.0 909136.8 1.0X -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Wide rows with 1000 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 1000 columns 58415 59611 2071 0.0 58414.8 1.0X -Select 100 columns 22568 23020 594 0.0 22568.0 2.6X -Select one column 18995 19058 99 0.1 18995.0 3.1X -count() 5301 5332 30 0.2 5300.9 11.0X -Select 100 columns, one bad input field 39736 40153 361 0.0 39736.1 1.5X -Select 100 columns, corrupt record field 47195 47826 590 0.0 47195.2 1.2X +Select 1000 columns 129646 130527 1412 0.0 129646.3 1.0X +Select 100 columns 42444 42551 119 0.0 42444.0 3.1X +Select one column 35415 35428 20 0.0 35414.6 3.7X +count() 11114 11128 16 0.1 11113.6 11.7X +Select 100 columns, one bad input field 93353 93670 275 0.0 93352.6 1.4X +Select 100 columns, corrupt record field 113569 113952 373 0.0 113568.8 1.1X -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 
4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Count a dataset with 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 10 columns + count() 9884 9904 25 1.0 988.4 1.0X -Select 1 column + count() 6794 6835 46 1.5 679.4 1.5X -count() 2060 2065 5 4.9 206.0 4.8X +Select 10 columns + count() 18498 18589 87 0.5 1849.8 1.0X +Select 1 column + count() 11078 11095 27 0.9 1107.8 1.7X +count() 3928 3950 22 2.5 392.8 4.7X -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Write dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Create a dataset of timestamps 717 732 18 14.0 71.7 1.0X -to_csv(timestamp) 6994 7100 121 1.4 699.4 0.1X -write timestamps to files 6417 6435 27 1.6 641.7 0.1X -Create a dataset of dates 827 855 24 12.1 82.7 0.9X -to_csv(date) 4408 4438 32 2.3 440.8 0.2X -write dates to files 3738 3758 28 2.7 373.8 0.2X +Create a dataset of timestamps 1933 1940 11 5.2 193.3 1.0X +to_csv(timestamp) 18078 18243 255 0.6 1807.8 0.1X +write timestamps to files 12668 12786 134 0.8 1266.8 0.2X +Create a dataset of dates 2196 2201 5 4.6 219.6 0.9X +to_csv(date) 9583 9597 21 1.0 958.3 0.2X +write dates to files 7091 7110 20 1.4 709.1 0.3X -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Read dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -read timestamp text from files 1121 1176 52 8.9 112.1 1.0X -read timestamps from files 21298 21366 105 0.5 2129.8 0.1X -infer timestamps from files 41008 41051 39 0.2 4100.8 0.0X -read date text from files 962 967 5 10.4 96.2 1.2X -read date from files 11749 11772 22 0.9 1174.9 0.1X -infer date from files 12426 12459 29 0.8 1242.6 0.1X -timestamp strings 1508 1519 9 6.6 150.8 0.7X -parse timestamps from Dataset[String] 21674 21997 455 0.5 2167.4 0.1X -infer timestamps from Dataset[String] 42141 42230 105 0.2 4214.1 0.0X -date strings 1694 1701 8 5.9 169.4 0.7X -parse dates from Dataset[String] 12929 12951 25 0.8 1292.9 0.1X -from_csv(timestamp) 20603 20786 166 0.5 2060.3 0.1X -from_csv(date) 12325 12338 12 0.8 1232.5 0.1X +read timestamp text from files 2166 2177 10 4.6 216.6 1.0X +read timestamps from files 53212 53402 281 0.2 5321.2 0.0X +infer timestamps from files 109788 110372 570 0.1 10978.8 0.0X +read date text from files 1921 1929 8 5.2 192.1 1.1X +read date from files 25470 25499 25 0.4 2547.0 0.1X +infer date from files 27201 27342 134 0.4 2720.1 0.1X +timestamp strings 3638 3653 19 2.7 363.8 0.6X +parse timestamps from Dataset[String] 61894 62532 555 0.2 6189.4 0.0X +infer timestamps from Dataset[String] 125171 125430 236 0.1 12517.1 0.0X +date strings 3736 3749 14 2.7 373.6 0.6X +parse dates from Dataset[String] 30787 30829 43 0.3 3078.7 0.1X +from_csv(timestamp) 60842 61035 209 0.2 6084.2 0.0X 
+from_csv(date) 30123 30196 95 0.3 3012.3 0.1X -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Filters pushdown: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -w/o filters 12455 12474 22 0.0 124553.8 1.0X -pushdown disabled 12462 12486 29 0.0 124624.9 1.0X -w/ filters 1073 1092 18 0.1 10727.6 11.6X +w/o filters 28985 29042 80 0.0 289852.9 1.0X +pushdown disabled 29080 29146 58 0.0 290799.4 1.0X +w/ filters 2072 2084 17 0.0 20722.3 14.0X diff --git a/sql/core/benchmarks/DateTimeBenchmark-jdk11-results.txt b/sql/core/benchmarks/DateTimeBenchmark-jdk11-results.txt index 61ca342a0d559..f4ed8ce4afaea 100644 --- a/sql/core/benchmarks/DateTimeBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/DateTimeBenchmark-jdk11-results.txt @@ -6,18 +6,18 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz datetime +/- interval: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date + interval(m) 1496 1569 104 6.7 149.6 1.0X -date + interval(m, d) 1514 1526 17 6.6 151.4 1.0X -date + interval(m, d, ms) 6231 6253 30 1.6 623.1 0.2X -date - interval(m) 1481 1487 9 6.8 148.1 1.0X -date - interval(m, d) 1550 1552 2 6.5 155.0 1.0X -date - interval(m, d, ms) 6269 6272 4 1.6 626.9 0.2X -timestamp + interval(m) 3017 3056 54 3.3 301.7 0.5X -timestamp + interval(m, d) 3146 3148 3 3.2 314.6 0.5X -timestamp + interval(m, d, ms) 3446 3460 20 2.9 344.6 0.4X -timestamp - interval(m) 3045 3059 19 3.3 304.5 0.5X -timestamp - interval(m, d) 3147 3164 25 3.2 314.7 0.5X -timestamp - interval(m, d, ms) 3425 3442 25 2.9 342.5 0.4X +date + interval(m) 1660 1745 120 6.0 166.0 1.0X +date + interval(m, d) 1672 1685 19 6.0 167.2 1.0X +date + interval(m, d, ms) 6462 6481 27 1.5 646.2 0.3X +date - interval(m) 1456 1480 35 6.9 145.6 1.1X +date - interval(m, d) 1501 1509 11 6.7 150.1 1.1X +date - interval(m, d, ms) 6457 6466 12 1.5 645.7 0.3X +timestamp + interval(m) 2941 2944 4 3.4 294.1 0.6X +timestamp + interval(m, d) 3008 3012 6 3.3 300.8 0.6X +timestamp + interval(m, d, ms) 3329 3333 6 3.0 332.9 0.5X +timestamp - interval(m) 2964 2982 26 3.4 296.4 0.6X +timestamp - interval(m, d) 3030 3039 13 3.3 303.0 0.5X +timestamp - interval(m, d, ms) 3312 3313 1 3.0 331.2 0.5X ================================================================================================ @@ -28,92 +28,92 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz cast to timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -cast to timestamp wholestage off 332 336 5 30.1 33.2 1.0X -cast to timestamp wholestage on 333 344 10 30.0 33.3 1.0X +cast to timestamp wholestage off 333 334 0 30.0 33.3 1.0X +cast to timestamp wholestage on 349 368 12 28.6 34.9 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz year of timestamp: Best Time(ms) Avg 
Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -year of timestamp wholestage off 1246 1257 16 8.0 124.6 1.0X -year of timestamp wholestage on 1209 1218 12 8.3 120.9 1.0X +year of timestamp wholestage off 1229 1229 1 8.1 122.9 1.0X +year of timestamp wholestage on 1218 1223 5 8.2 121.8 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz quarter of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -quarter of timestamp wholestage off 1608 1616 11 6.2 160.8 1.0X -quarter of timestamp wholestage on 1540 1552 10 6.5 154.0 1.0X +quarter of timestamp wholestage off 1593 1594 2 6.3 159.3 1.0X +quarter of timestamp wholestage on 1515 1529 14 6.6 151.5 1.1X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz month of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -month of timestamp wholestage off 1242 1246 6 8.1 124.2 1.0X -month of timestamp wholestage on 1202 1212 11 8.3 120.2 1.0X +month of timestamp wholestage off 1222 1246 34 8.2 122.2 1.0X +month of timestamp wholestage on 1207 1232 31 8.3 120.7 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz weekofyear of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -weekofyear of timestamp wholestage off 1879 1885 8 5.3 187.9 1.0X -weekofyear of timestamp wholestage on 1832 1845 10 5.5 183.2 1.0X +weekofyear of timestamp wholestage off 2453 2455 2 4.1 245.3 1.0X +weekofyear of timestamp wholestage on 2357 2380 22 4.2 235.7 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz day of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -day of timestamp wholestage off 1236 1239 4 8.1 123.6 1.0X -day of timestamp wholestage on 1206 1219 17 8.3 120.6 1.0X +day of timestamp wholestage off 1216 1219 5 8.2 121.6 1.0X +day of timestamp wholestage on 1205 1221 25 8.3 120.5 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz dayofyear of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -dayofyear of timestamp wholestage off 1308 1309 1 7.6 130.8 1.0X -dayofyear of timestamp wholestage on 1239 1255 15 8.1 123.9 1.1X +dayofyear of timestamp wholestage off 1268 1274 9 7.9 126.8 1.0X +dayofyear of timestamp wholestage on 1253 1268 10 8.0 125.3 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz dayofmonth of timestamp: Best Time(ms) 
Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -dayofmonth of timestamp wholestage off 1259 1263 5 7.9 125.9 1.0X -dayofmonth of timestamp wholestage on 1201 1205 5 8.3 120.1 1.0X +dayofmonth of timestamp wholestage off 1223 1224 1 8.2 122.3 1.0X +dayofmonth of timestamp wholestage on 1231 1246 14 8.1 123.1 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz dayofweek of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -dayofweek of timestamp wholestage off 1406 1410 6 7.1 140.6 1.0X -dayofweek of timestamp wholestage on 1387 1402 15 7.2 138.7 1.0X +dayofweek of timestamp wholestage off 1398 1406 12 7.2 139.8 1.0X +dayofweek of timestamp wholestage on 1387 1399 15 7.2 138.7 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz weekday of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -weekday of timestamp wholestage off 1355 1367 18 7.4 135.5 1.0X -weekday of timestamp wholestage on 1311 1321 10 7.6 131.1 1.0X +weekday of timestamp wholestage off 1327 1333 9 7.5 132.7 1.0X +weekday of timestamp wholestage on 1329 1333 4 7.5 132.9 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz hour of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -hour of timestamp wholestage off 996 997 2 10.0 99.6 1.0X -hour of timestamp wholestage on 930 936 6 10.7 93.0 1.1X +hour of timestamp wholestage off 1005 1016 15 9.9 100.5 1.0X +hour of timestamp wholestage on 934 940 4 10.7 93.4 1.1X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz minute of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -minute of timestamp wholestage off 1005 1012 10 9.9 100.5 1.0X -minute of timestamp wholestage on 949 952 3 10.5 94.9 1.1X +minute of timestamp wholestage off 1003 1009 8 10.0 100.3 1.0X +minute of timestamp wholestage on 934 938 7 10.7 93.4 1.1X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz second of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -second of timestamp wholestage off 1013 1014 1 9.9 101.3 1.0X -second of timestamp wholestage on 933 934 2 10.7 93.3 1.1X +second of timestamp wholestage off 997 998 2 10.0 99.7 1.0X +second of timestamp wholestage on 925 935 8 10.8 92.5 1.1X ================================================================================================ @@ -124,15 +124,15 @@ OpenJDK 64-Bit Server VM 
11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz current_date: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -current_date wholestage off 291 293 2 34.3 29.1 1.0X -current_date wholestage on 280 284 3 35.7 28.0 1.0X +current_date wholestage off 297 297 0 33.7 29.7 1.0X +current_date wholestage on 280 282 2 35.7 28.0 1.1X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz current_timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -current_timestamp wholestage off 311 324 18 32.1 31.1 1.0X -current_timestamp wholestage on 275 364 85 36.3 27.5 1.1X +current_timestamp wholestage off 307 337 43 32.6 30.7 1.0X +current_timestamp wholestage on 260 284 29 38.4 26.0 1.2X ================================================================================================ @@ -143,43 +143,43 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz cast to date: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -cast to date wholestage off 1077 1079 3 9.3 107.7 1.0X -cast to date wholestage on 1018 1030 14 9.8 101.8 1.1X +cast to date wholestage off 1066 1073 10 9.4 106.6 1.0X +cast to date wholestage on 997 1003 6 10.0 99.7 1.1X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz last_day: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -last_day wholestage off 1257 1260 4 8.0 125.7 1.0X -last_day wholestage on 1218 1227 14 8.2 121.8 1.0X +last_day wholestage off 1238 1242 6 8.1 123.8 1.0X +last_day wholestage on 1259 1272 12 7.9 125.9 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz next_day: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -next_day wholestage off 1140 1141 1 8.8 114.0 1.0X -next_day wholestage on 1067 1076 11 9.4 106.7 1.1X +next_day wholestage off 1116 1138 32 9.0 111.6 1.0X +next_day wholestage on 1052 1063 11 9.5 105.2 1.1X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_add: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_add wholestage off 1062 1064 3 9.4 106.2 1.0X -date_add wholestage on 1046 1055 11 9.6 104.6 1.0X +date_add wholestage off 1048 1049 1 9.5 104.8 1.0X +date_add wholestage on 1035 1039 3 9.7 103.5 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_sub: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per 
Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_sub wholestage off 1082 1083 1 9.2 108.2 1.0X -date_sub wholestage on 1047 1056 12 9.6 104.7 1.0X +date_sub wholestage off 1119 1127 11 8.9 111.9 1.0X +date_sub wholestage on 1028 1039 7 9.7 102.8 1.1X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz add_months: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -add_months wholestage off 1430 1431 1 7.0 143.0 1.0X -add_months wholestage on 1441 1446 8 6.9 144.1 1.0X +add_months wholestage off 1421 1421 0 7.0 142.1 1.0X +add_months wholestage on 1423 1434 11 7.0 142.3 1.0X ================================================================================================ @@ -190,8 +190,8 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz format date: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -format date wholestage off 5442 5549 150 1.8 544.2 1.0X -format date wholestage on 5529 5655 236 1.8 552.9 1.0X +format date wholestage off 5293 5296 5 1.9 529.3 1.0X +format date wholestage on 5143 5157 19 1.9 514.3 1.0X ================================================================================================ @@ -202,8 +202,8 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz from_unixtime: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -from_unixtime wholestage off 7416 7440 34 1.3 741.6 1.0X -from_unixtime wholestage on 7372 7391 17 1.4 737.2 1.0X +from_unixtime wholestage off 7136 7136 1 1.4 713.6 1.0X +from_unixtime wholestage on 7049 7068 29 1.4 704.9 1.0X ================================================================================================ @@ -214,15 +214,15 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz from_utc_timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -from_utc_timestamp wholestage off 1316 1320 6 7.6 131.6 1.0X -from_utc_timestamp wholestage on 1268 1272 4 7.9 126.8 1.0X +from_utc_timestamp wholestage off 1325 1329 6 7.5 132.5 1.0X +from_utc_timestamp wholestage on 1269 1273 4 7.9 126.9 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to_utc_timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to_utc_timestamp wholestage off 1653 1657 6 6.0 165.3 1.0X -to_utc_timestamp wholestage on 1594 1599 4 6.3 159.4 1.0X +to_utc_timestamp wholestage off 1684 1691 10 5.9 168.4 1.0X +to_utc_timestamp wholestage on 1641 1648 9 6.1 164.1 1.0X 
================================================================================================ @@ -233,29 +233,29 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz cast interval: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -cast interval wholestage off 341 343 3 29.4 34.1 1.0X -cast interval wholestage on 279 282 1 35.8 27.9 1.2X +cast interval wholestage off 343 346 4 29.1 34.3 1.0X +cast interval wholestage on 281 282 1 35.6 28.1 1.2X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz datediff: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -datediff wholestage off 1862 1865 4 5.4 186.2 1.0X -datediff wholestage on 1769 1783 15 5.7 176.9 1.1X +datediff wholestage off 1831 1840 13 5.5 183.1 1.0X +datediff wholestage on 1759 1769 15 5.7 175.9 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz months_between: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -months_between wholestage off 5594 5599 7 1.8 559.4 1.0X -months_between wholestage on 5498 5508 11 1.8 549.8 1.0X +months_between wholestage off 5729 5747 25 1.7 572.9 1.0X +months_between wholestage on 5710 5720 9 1.8 571.0 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz window: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -window wholestage off 2044 2127 117 0.5 2044.3 1.0X -window wholestage on 48057 48109 54 0.0 48056.9 0.0X +window wholestage off 2183 2189 9 0.5 2182.6 1.0X +window wholestage on 46835 46944 88 0.0 46834.8 0.0X ================================================================================================ @@ -266,134 +266,134 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc YEAR: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc YEAR wholestage off 2540 2542 3 3.9 254.0 1.0X -date_trunc YEAR wholestage on 2486 2507 29 4.0 248.6 1.0X +date_trunc YEAR wholestage off 2668 2672 5 3.7 266.8 1.0X +date_trunc YEAR wholestage on 2719 2731 9 3.7 271.9 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc YYYY: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc YYYY wholestage off 2542 2543 3 3.9 254.2 1.0X -date_trunc YYYY wholestage on 2491 2498 9 4.0 249.1 1.0X +date_trunc YYYY wholestage off 2672 2677 8 3.7 267.2 1.0X +date_trunc YYYY wholestage on 2710 2726 12 3.7 271.0 
1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc YY: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc YY wholestage off 2545 2569 35 3.9 254.5 1.0X -date_trunc YY wholestage on 2487 2493 4 4.0 248.7 1.0X +date_trunc YY wholestage off 2670 2673 4 3.7 267.0 1.0X +date_trunc YY wholestage on 2711 2720 7 3.7 271.1 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc MON: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc MON wholestage off 2590 2590 1 3.9 259.0 1.0X -date_trunc MON wholestage on 2506 2520 12 4.0 250.6 1.0X +date_trunc MON wholestage off 2674 2674 0 3.7 267.4 1.0X +date_trunc MON wholestage on 2667 2677 10 3.7 266.7 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc MONTH: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc MONTH wholestage off 2595 2603 11 3.9 259.5 1.0X -date_trunc MONTH wholestage on 2505 2516 12 4.0 250.5 1.0X +date_trunc MONTH wholestage off 2675 2686 16 3.7 267.5 1.0X +date_trunc MONTH wholestage on 2667 2674 6 3.7 266.7 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc MM: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc MM wholestage off 2605 2612 10 3.8 260.5 1.0X -date_trunc MM wholestage on 2501 2515 11 4.0 250.1 1.0X +date_trunc MM wholestage off 2673 2674 1 3.7 267.3 1.0X +date_trunc MM wholestage on 2664 2669 4 3.8 266.4 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc DAY: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc DAY wholestage off 2225 2229 5 4.5 222.5 1.0X -date_trunc DAY wholestage on 2184 2196 9 4.6 218.4 1.0X +date_trunc DAY wholestage off 2281 2288 10 4.4 228.1 1.0X +date_trunc DAY wholestage on 2302 2312 8 4.3 230.2 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc DD: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc DD wholestage off 2232 2236 6 4.5 223.2 1.0X -date_trunc DD wholestage on 2183 2190 6 4.6 218.3 1.0X +date_trunc DD wholestage off 2281 2283 3 4.4 228.1 1.0X +date_trunc DD wholestage on 2291 2302 11 4.4 229.1 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc HOUR: Best Time(ms) Avg 
Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc HOUR wholestage off 2194 2199 7 4.6 219.4 1.0X -date_trunc HOUR wholestage on 2160 2166 5 4.6 216.0 1.0X +date_trunc HOUR wholestage off 2331 2332 1 4.3 233.1 1.0X +date_trunc HOUR wholestage on 2290 2304 11 4.4 229.0 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc MINUTE: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc MINUTE wholestage off 390 396 9 25.7 39.0 1.0X -date_trunc MINUTE wholestage on 331 337 7 30.2 33.1 1.2X +date_trunc MINUTE wholestage off 379 385 9 26.4 37.9 1.0X +date_trunc MINUTE wholestage on 371 376 5 27.0 37.1 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc SECOND: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc SECOND wholestage off 375 381 8 26.7 37.5 1.0X -date_trunc SECOND wholestage on 332 346 14 30.1 33.2 1.1X +date_trunc SECOND wholestage off 375 376 1 26.7 37.5 1.0X +date_trunc SECOND wholestage on 370 376 8 27.0 37.0 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc WEEK: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc WEEK wholestage off 2439 2443 6 4.1 243.9 1.0X -date_trunc WEEK wholestage on 2390 2409 32 4.2 239.0 1.0X +date_trunc WEEK wholestage off 2597 2604 10 3.9 259.7 1.0X +date_trunc WEEK wholestage on 2591 2605 13 3.9 259.1 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc QUARTER: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc QUARTER wholestage off 3290 3292 4 3.0 329.0 1.0X -date_trunc QUARTER wholestage on 3214 3218 3 3.1 321.4 1.0X +date_trunc QUARTER wholestage off 3501 3511 14 2.9 350.1 1.0X +date_trunc QUARTER wholestage on 3477 3489 9 2.9 347.7 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc year: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc year wholestage off 308 310 3 32.5 30.8 1.0X -trunc year wholestage on 289 293 6 34.7 28.9 1.1X +trunc year wholestage off 332 334 3 30.1 33.2 1.0X +trunc year wholestage on 332 346 17 30.1 33.2 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc yyyy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -trunc yyyy wholestage off 309 311 3 32.4 30.9 1.0X -trunc yyyy wholestage on 289 294 7 34.6 28.9 1.1X +trunc yyyy wholestage off 331 331 0 30.2 33.1 1.0X +trunc yyyy wholestage on 336 339 4 29.8 33.6 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc yy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc yy wholestage off 311 311 0 32.2 31.1 1.0X -trunc yy wholestage on 288 294 7 34.7 28.8 1.1X +trunc yy wholestage off 330 342 17 30.3 33.0 1.0X +trunc yy wholestage on 333 337 3 30.0 33.3 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc mon: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc mon wholestage off 313 313 0 32.0 31.3 1.0X -trunc mon wholestage on 287 290 2 34.8 28.7 1.1X +trunc mon wholestage off 334 335 1 30.0 33.4 1.0X +trunc mon wholestage on 333 347 9 30.0 33.3 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc month: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc month wholestage off 310 310 0 32.3 31.0 1.0X -trunc month wholestage on 287 290 2 34.8 28.7 1.1X +trunc month wholestage off 332 333 1 30.1 33.2 1.0X +trunc month wholestage on 333 340 7 30.0 33.3 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc mm: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc mm wholestage off 311 312 1 32.1 31.1 1.0X -trunc mm wholestage on 287 296 9 34.8 28.7 1.1X +trunc mm wholestage off 328 336 11 30.5 32.8 1.0X +trunc mm wholestage on 333 343 11 30.0 33.3 1.0X ================================================================================================ @@ -404,36 +404,36 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to timestamp str: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to timestamp str wholestage off 169 170 1 5.9 168.9 1.0X -to timestamp str wholestage on 161 168 11 6.2 161.0 1.0X +to timestamp str wholestage off 170 171 1 5.9 170.1 1.0X +to timestamp str wholestage on 172 174 2 5.8 171.6 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to_timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to_timestamp wholestage off 1360 1361 1 0.7 1359.6 1.0X -to_timestamp wholestage on 1362 
1366 6 0.7 1362.0 1.0X +to_timestamp wholestage off 1437 1439 3 0.7 1437.0 1.0X +to_timestamp wholestage on 1288 1292 5 0.8 1288.1 1.1X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to_unix_timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to_unix_timestamp wholestage off 1343 1346 4 0.7 1342.6 1.0X -to_unix_timestamp wholestage on 1356 1359 2 0.7 1356.2 1.0X +to_unix_timestamp wholestage off 1352 1353 2 0.7 1352.0 1.0X +to_unix_timestamp wholestage on 1314 1319 5 0.8 1314.4 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to date str: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to date str wholestage off 227 230 4 4.4 227.0 1.0X -to date str wholestage on 299 302 3 3.3 299.0 0.8X +to date str wholestage off 211 215 6 4.7 210.7 1.0X +to date str wholestage on 217 217 1 4.6 216.5 1.0X OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to_date: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to_date wholestage off 3413 3440 38 0.3 3413.0 1.0X -to_date wholestage on 3392 3402 12 0.3 3392.3 1.0X +to_date wholestage off 3281 3295 20 0.3 3280.9 1.0X +to_date wholestage on 3223 3239 17 0.3 3222.8 1.0X ================================================================================================ @@ -444,14 +444,14 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz To/from Java's date-time: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -From java.sql.Date 410 415 7 12.2 82.0 1.0X -From java.time.LocalDate 332 333 1 15.1 66.4 1.2X -Collect java.sql.Date 1891 2542 829 2.6 378.1 0.2X -Collect java.time.LocalDate 1630 2138 441 3.1 326.0 0.3X -From java.sql.Timestamp 254 259 6 19.7 50.9 1.6X -From java.time.Instant 302 306 4 16.6 60.3 1.4X -Collect longs 1134 1265 117 4.4 226.8 0.4X -Collect java.sql.Timestamp 1441 1458 16 3.5 288.1 0.3X -Collect java.time.Instant 1680 1928 253 3.0 336.0 0.2X +From java.sql.Date 446 447 1 11.2 89.1 1.0X +From java.time.LocalDate 354 356 1 14.1 70.8 1.3X +Collect java.sql.Date 2722 3091 495 1.8 544.4 0.2X +Collect java.time.LocalDate 1786 1836 60 2.8 357.2 0.2X +From java.sql.Timestamp 275 287 19 18.2 55.0 1.6X +From java.time.Instant 325 328 3 15.4 65.0 1.4X +Collect longs 1300 1321 25 3.8 260.0 0.3X +Collect java.sql.Timestamp 1450 1557 102 3.4 290.0 0.3X +Collect java.time.Instant 1499 1599 87 3.3 299.9 0.3X diff --git a/sql/core/benchmarks/DateTimeBenchmark-results.txt b/sql/core/benchmarks/DateTimeBenchmark-results.txt index 7586295778bd8..7a9aa4badfeb7 100644 --- a/sql/core/benchmarks/DateTimeBenchmark-results.txt +++ b/sql/core/benchmarks/DateTimeBenchmark-results.txt @@ -6,18 +6,18 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) 
Xeon(R) CPU E5-2670 v2 @ 2.50GHz datetime +/- interval: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date + interval(m) 1638 1701 89 6.1 163.8 1.0X -date + interval(m, d) 1785 1790 7 5.6 178.5 0.9X -date + interval(m, d, ms) 6229 6270 58 1.6 622.9 0.3X -date - interval(m) 1500 1503 4 6.7 150.0 1.1X -date - interval(m, d) 1764 1766 3 5.7 176.4 0.9X -date - interval(m, d, ms) 6428 6446 25 1.6 642.8 0.3X -timestamp + interval(m) 2719 2722 4 3.7 271.9 0.6X -timestamp + interval(m, d) 3011 3021 14 3.3 301.1 0.5X -timestamp + interval(m, d, ms) 3405 3412 9 2.9 340.5 0.5X -timestamp - interval(m) 2759 2764 7 3.6 275.9 0.6X -timestamp - interval(m, d) 3094 3112 25 3.2 309.4 0.5X -timestamp - interval(m, d, ms) 3388 3392 5 3.0 338.8 0.5X +date + interval(m) 1555 1634 113 6.4 155.5 1.0X +date + interval(m, d) 1774 1797 33 5.6 177.4 0.9X +date + interval(m, d, ms) 6293 6335 59 1.6 629.3 0.2X +date - interval(m) 1461 1468 10 6.8 146.1 1.1X +date - interval(m, d) 1741 1741 0 5.7 174.1 0.9X +date - interval(m, d, ms) 6503 6518 21 1.5 650.3 0.2X +timestamp + interval(m) 2384 2385 1 4.2 238.4 0.7X +timestamp + interval(m, d) 2683 2684 2 3.7 268.3 0.6X +timestamp + interval(m, d, ms) 2987 3001 19 3.3 298.7 0.5X +timestamp - interval(m) 2391 2395 5 4.2 239.1 0.7X +timestamp - interval(m, d) 2674 2684 14 3.7 267.4 0.6X +timestamp - interval(m, d, ms) 3005 3007 3 3.3 300.5 0.5X ================================================================================================ @@ -28,92 +28,92 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz cast to timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -cast to timestamp wholestage off 319 323 6 31.4 31.9 1.0X -cast to timestamp wholestage on 304 311 8 32.9 30.4 1.0X +cast to timestamp wholestage off 313 320 10 31.9 31.3 1.0X +cast to timestamp wholestage on 325 341 18 30.8 32.5 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz year of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -year of timestamp wholestage off 1234 1239 6 8.1 123.4 1.0X -year of timestamp wholestage on 1229 1244 22 8.1 122.9 1.0X +year of timestamp wholestage off 1216 1216 1 8.2 121.6 1.0X +year of timestamp wholestage on 1226 1243 13 8.2 122.6 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz quarter of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -quarter of timestamp wholestage off 1440 1445 7 6.9 144.0 1.0X -quarter of timestamp wholestage on 1358 1361 3 7.4 135.8 1.1X +quarter of timestamp wholestage off 1417 1421 5 7.1 141.7 1.0X +quarter of timestamp wholestage on 1358 1365 8 7.4 135.8 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz month of timestamp: Best Time(ms) Avg Time(ms) 
Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -month of timestamp wholestage off 1239 1240 1 8.1 123.9 1.0X -month of timestamp wholestage on 1221 1239 26 8.2 122.1 1.0X +month of timestamp wholestage off 1219 1220 1 8.2 121.9 1.0X +month of timestamp wholestage on 1222 1227 7 8.2 122.2 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz weekofyear of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -weekofyear of timestamp wholestage off 1926 1934 11 5.2 192.6 1.0X -weekofyear of timestamp wholestage on 1901 1911 10 5.3 190.1 1.0X +weekofyear of timestamp wholestage off 1950 1950 0 5.1 195.0 1.0X +weekofyear of timestamp wholestage on 1890 1899 8 5.3 189.0 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz day of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -day of timestamp wholestage off 1225 1229 6 8.2 122.5 1.0X -day of timestamp wholestage on 1217 1225 7 8.2 121.7 1.0X +day of timestamp wholestage off 1212 1213 2 8.3 121.2 1.0X +day of timestamp wholestage on 1216 1227 13 8.2 121.6 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz dayofyear of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -dayofyear of timestamp wholestage off 1290 1295 7 7.8 129.0 1.0X -dayofyear of timestamp wholestage on 1262 1270 7 7.9 126.2 1.0X +dayofyear of timestamp wholestage off 1282 1284 3 7.8 128.2 1.0X +dayofyear of timestamp wholestage on 1269 1274 5 7.9 126.9 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz dayofmonth of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -dayofmonth of timestamp wholestage off 1239 1239 1 8.1 123.9 1.0X -dayofmonth of timestamp wholestage on 1215 1222 8 8.2 121.5 1.0X +dayofmonth of timestamp wholestage off 1214 1219 7 8.2 121.4 1.0X +dayofmonth of timestamp wholestage on 1216 1224 6 8.2 121.6 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz dayofweek of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -dayofweek of timestamp wholestage off 1421 1422 2 7.0 142.1 1.0X -dayofweek of timestamp wholestage on 1379 1388 8 7.3 137.9 1.0X +dayofweek of timestamp wholestage off 1403 1430 39 7.1 140.3 1.0X +dayofweek of timestamp wholestage on 1378 1386 8 7.3 137.8 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz weekday of timestamp: Best Time(ms) Avg 
Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -weekday of timestamp wholestage off 1349 1351 2 7.4 134.9 1.0X -weekday of timestamp wholestage on 1320 1327 8 7.6 132.0 1.0X +weekday of timestamp wholestage off 1344 1353 13 7.4 134.4 1.0X +weekday of timestamp wholestage on 1316 1322 5 7.6 131.6 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz hour of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -hour of timestamp wholestage off 1024 1024 0 9.8 102.4 1.0X -hour of timestamp wholestage on 921 929 11 10.9 92.1 1.1X +hour of timestamp wholestage off 992 1000 10 10.1 99.2 1.0X +hour of timestamp wholestage on 960 962 3 10.4 96.0 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz minute of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -minute of timestamp wholestage off 977 982 6 10.2 97.7 1.0X -minute of timestamp wholestage on 927 929 2 10.8 92.7 1.1X +minute of timestamp wholestage off 989 1000 16 10.1 98.9 1.0X +minute of timestamp wholestage on 965 974 13 10.4 96.5 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz second of timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -second of timestamp wholestage off 987 989 3 10.1 98.7 1.0X -second of timestamp wholestage on 923 926 5 10.8 92.3 1.1X +second of timestamp wholestage off 974 977 5 10.3 97.4 1.0X +second of timestamp wholestage on 959 966 8 10.4 95.9 1.0X ================================================================================================ @@ -124,15 +124,15 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz current_date: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -current_date wholestage off 303 311 12 33.0 30.3 1.0X -current_date wholestage on 266 271 5 37.5 26.6 1.1X +current_date wholestage off 281 282 2 35.6 28.1 1.0X +current_date wholestage on 294 300 5 34.0 29.4 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz current_timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -current_timestamp wholestage off 297 297 1 33.7 29.7 1.0X -current_timestamp wholestage on 264 272 7 37.8 26.4 1.1X +current_timestamp wholestage off 282 296 19 35.4 28.2 1.0X +current_timestamp wholestage on 304 331 31 32.9 30.4 0.9X ================================================================================================ @@ -143,43 +143,43 @@ OpenJDK 64-Bit Server VM 
1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz cast to date: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -cast to date wholestage off 1062 1063 2 9.4 106.2 1.0X -cast to date wholestage on 1007 1021 20 9.9 100.7 1.1X +cast to date wholestage off 1060 1061 1 9.4 106.0 1.0X +cast to date wholestage on 1021 1026 10 9.8 102.1 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz last_day: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -last_day wholestage off 1262 1265 5 7.9 126.2 1.0X -last_day wholestage on 1244 1256 14 8.0 124.4 1.0X +last_day wholestage off 1278 1280 3 7.8 127.8 1.0X +last_day wholestage on 1560 1566 6 6.4 156.0 0.8X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz next_day: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -next_day wholestage off 1119 1121 2 8.9 111.9 1.0X -next_day wholestage on 1057 1063 6 9.5 105.7 1.1X +next_day wholestage off 1091 1093 3 9.2 109.1 1.0X +next_day wholestage on 1070 1076 9 9.3 107.0 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_add: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_add wholestage off 1054 1059 7 9.5 105.4 1.0X -date_add wholestage on 1037 1069 52 9.6 103.7 1.0X +date_add wholestage off 1041 1047 8 9.6 104.1 1.0X +date_add wholestage on 1044 1050 4 9.6 104.4 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_sub: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_sub wholestage off 1054 1056 4 9.5 105.4 1.0X -date_sub wholestage on 1036 1040 4 9.7 103.6 1.0X +date_sub wholestage off 1038 1040 3 9.6 103.8 1.0X +date_sub wholestage on 1057 1061 4 9.5 105.7 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz add_months: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -add_months wholestage off 1408 1421 19 7.1 140.8 1.0X -add_months wholestage on 1434 1440 7 7.0 143.4 1.0X +add_months wholestage off 1401 1401 1 7.1 140.1 1.0X +add_months wholestage on 1438 1442 4 7.0 143.8 1.0X ================================================================================================ @@ -190,8 +190,8 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz format date: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -format date wholestage off 5937 6169 328 1.7 593.7 1.0X -format date wholestage on 5836 5878 74 1.7 583.6 1.0X +format date wholestage off 5482 5803 454 1.8 548.2 1.0X +format date wholestage on 5502 5518 9 1.8 550.2 1.0X ================================================================================================ @@ -202,8 +202,8 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz from_unixtime: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -from_unixtime wholestage off 8904 8914 14 1.1 890.4 1.0X -from_unixtime wholestage on 8918 8936 13 1.1 891.8 1.0X +from_unixtime wholestage off 8538 8553 22 1.2 853.8 1.0X +from_unixtime wholestage on 8545 8552 6 1.2 854.5 1.0X ================================================================================================ @@ -214,15 +214,15 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz from_utc_timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -from_utc_timestamp wholestage off 1110 1112 3 9.0 111.0 1.0X -from_utc_timestamp wholestage on 1115 1119 3 9.0 111.5 1.0X +from_utc_timestamp wholestage off 1094 1099 8 9.1 109.4 1.0X +from_utc_timestamp wholestage on 1109 1114 5 9.0 110.9 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to_utc_timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to_utc_timestamp wholestage off 1524 1525 1 6.6 152.4 1.0X -to_utc_timestamp wholestage on 1450 1458 14 6.9 145.0 1.1X +to_utc_timestamp wholestage off 1466 1469 4 6.8 146.6 1.0X +to_utc_timestamp wholestage on 1401 1408 7 7.1 140.1 1.0X ================================================================================================ @@ -233,29 +233,29 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz cast interval: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -cast interval wholestage off 341 342 1 29.3 34.1 1.0X -cast interval wholestage on 285 294 7 35.1 28.5 1.2X +cast interval wholestage off 332 332 0 30.1 33.2 1.0X +cast interval wholestage on 315 324 10 31.7 31.5 1.1X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz datediff: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -datediff wholestage off 1874 1881 10 5.3 187.4 1.0X -datediff wholestage on 1785 1791 3 5.6 178.5 1.0X +datediff wholestage off 1796 1802 8 5.6 179.6 1.0X +datediff wholestage on 1758 1764 10 5.7 175.8 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 
4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz months_between: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -months_between wholestage off 5038 5042 5 2.0 503.8 1.0X -months_between wholestage on 4979 4987 8 2.0 497.9 1.0X +months_between wholestage off 4833 4836 4 2.1 483.3 1.0X +months_between wholestage on 4777 4780 2 2.1 477.7 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz window: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -window wholestage off 1716 1841 177 0.6 1716.2 1.0X -window wholestage on 46024 46063 27 0.0 46024.1 0.0X +window wholestage off 1812 1908 136 0.6 1811.7 1.0X +window wholestage on 46279 46376 74 0.0 46278.8 0.0X ================================================================================================ @@ -266,134 +266,134 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc YEAR: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc YEAR wholestage off 2428 2429 2 4.1 242.8 1.0X -date_trunc YEAR wholestage on 2451 2469 12 4.1 245.1 1.0X +date_trunc YEAR wholestage off 2367 2368 1 4.2 236.7 1.0X +date_trunc YEAR wholestage on 2321 2334 22 4.3 232.1 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc YYYY: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc YYYY wholestage off 2423 2426 3 4.1 242.3 1.0X -date_trunc YYYY wholestage on 2454 2462 8 4.1 245.4 1.0X +date_trunc YYYY wholestage off 2330 2334 5 4.3 233.0 1.0X +date_trunc YYYY wholestage on 2326 2332 5 4.3 232.6 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc YY: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc YY wholestage off 2421 2441 28 4.1 242.1 1.0X -date_trunc YY wholestage on 2453 2461 9 4.1 245.3 1.0X +date_trunc YY wholestage off 2334 2335 1 4.3 233.4 1.0X +date_trunc YY wholestage on 2315 2324 6 4.3 231.5 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc MON: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc MON wholestage off 2425 2427 3 4.1 242.5 1.0X -date_trunc MON wholestage on 2431 2438 9 4.1 243.1 1.0X +date_trunc MON wholestage off 2327 2330 4 4.3 232.7 1.0X +date_trunc MON wholestage on 2279 2289 12 4.4 227.9 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc MONTH: Best Time(ms) Avg 
Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc MONTH wholestage off 2427 2433 8 4.1 242.7 1.0X -date_trunc MONTH wholestage on 2429 2435 4 4.1 242.9 1.0X +date_trunc MONTH wholestage off 2330 2332 2 4.3 233.0 1.0X +date_trunc MONTH wholestage on 2277 2284 6 4.4 227.7 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc MM: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc MM wholestage off 2425 2431 9 4.1 242.5 1.0X -date_trunc MM wholestage on 2430 2435 4 4.1 243.0 1.0X +date_trunc MM wholestage off 2328 2329 2 4.3 232.8 1.0X +date_trunc MM wholestage on 2279 2284 4 4.4 227.9 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc DAY: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc DAY wholestage off 2117 2119 4 4.7 211.7 1.0X -date_trunc DAY wholestage on 2036 2118 174 4.9 203.6 1.0X +date_trunc DAY wholestage off 1974 1984 14 5.1 197.4 1.0X +date_trunc DAY wholestage on 1914 1922 7 5.2 191.4 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc DD: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc DD wholestage off 2116 2119 5 4.7 211.6 1.0X -date_trunc DD wholestage on 2035 2043 10 4.9 203.5 1.0X +date_trunc DD wholestage off 1967 1976 12 5.1 196.7 1.0X +date_trunc DD wholestage on 1913 1917 4 5.2 191.3 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc HOUR: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc HOUR wholestage off 2013 2014 2 5.0 201.3 1.0X -date_trunc HOUR wholestage on 2077 2088 13 4.8 207.7 1.0X +date_trunc HOUR wholestage off 1970 1970 0 5.1 197.0 1.0X +date_trunc HOUR wholestage on 1945 1946 2 5.1 194.5 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc MINUTE: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc MINUTE wholestage off 363 368 8 27.6 36.3 1.0X -date_trunc MINUTE wholestage on 321 326 7 31.2 32.1 1.1X +date_trunc MINUTE wholestage off 361 361 1 27.7 36.1 1.0X +date_trunc MINUTE wholestage on 331 336 4 30.2 33.1 1.1X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc SECOND: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ 
-date_trunc SECOND wholestage off 365 366 0 27.4 36.5 1.0X -date_trunc SECOND wholestage on 319 332 16 31.4 31.9 1.1X +date_trunc SECOND wholestage off 360 361 1 27.8 36.0 1.0X +date_trunc SECOND wholestage on 335 348 15 29.8 33.5 1.1X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc WEEK: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc WEEK wholestage off 2371 2376 7 4.2 237.1 1.0X -date_trunc WEEK wholestage on 2314 2322 8 4.3 231.4 1.0X +date_trunc WEEK wholestage off 2232 2236 6 4.5 223.2 1.0X +date_trunc WEEK wholestage on 2225 2232 6 4.5 222.5 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz date_trunc QUARTER: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -date_trunc QUARTER wholestage off 3334 3335 1 3.0 333.4 1.0X -date_trunc QUARTER wholestage on 3286 3291 7 3.0 328.6 1.0X +date_trunc QUARTER wholestage off 3083 3086 4 3.2 308.3 1.0X +date_trunc QUARTER wholestage on 3073 3086 16 3.3 307.3 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc year: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc year wholestage off 303 304 2 33.0 30.3 1.0X -trunc year wholestage on 283 291 5 35.3 28.3 1.1X +trunc year wholestage off 321 321 0 31.1 32.1 1.0X +trunc year wholestage on 299 303 5 33.5 29.9 1.1X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc yyyy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc yyyy wholestage off 324 330 8 30.9 32.4 1.0X -trunc yyyy wholestage on 283 291 9 35.3 28.3 1.1X +trunc yyyy wholestage off 323 327 5 30.9 32.3 1.0X +trunc yyyy wholestage on 299 302 3 33.4 29.9 1.1X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc yy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc yy wholestage off 304 305 3 32.9 30.4 1.0X -trunc yy wholestage on 283 302 28 35.3 28.3 1.1X +trunc yy wholestage off 315 315 1 31.8 31.5 1.0X +trunc yy wholestage on 299 304 4 33.4 29.9 1.1X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc mon: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc mon wholestage off 315 319 6 31.7 31.5 1.0X -trunc mon wholestage on 284 287 5 35.3 28.4 1.1X +trunc mon wholestage off 320 321 1 31.2 32.0 1.0X +trunc mon wholestage on 299 307 10 33.4 29.9 1.1X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on 
Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc month: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc month wholestage off 305 314 13 32.8 30.5 1.0X -trunc month wholestage on 283 292 14 35.3 28.3 1.1X +trunc month wholestage off 316 317 1 31.6 31.6 1.0X +trunc month wholestage on 299 302 5 33.5 29.9 1.1X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz trunc mm: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -trunc mm wholestage off 301 301 0 33.2 30.1 1.0X -trunc mm wholestage on 285 290 7 35.1 28.5 1.1X +trunc mm wholestage off 313 313 1 32.0 31.3 1.0X +trunc mm wholestage on 298 302 4 33.5 29.8 1.0X ================================================================================================ @@ -404,36 +404,36 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to timestamp str: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to timestamp str wholestage off 218 220 3 4.6 218.4 1.0X -to timestamp str wholestage on 213 216 6 4.7 212.5 1.0X +to timestamp str wholestage off 217 217 0 4.6 217.3 1.0X +to timestamp str wholestage on 209 212 2 4.8 209.5 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to_timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to_timestamp wholestage off 1838 1842 5 0.5 1838.1 1.0X -to_timestamp wholestage on 1952 1971 11 0.5 1952.2 0.9X +to_timestamp wholestage off 1676 1677 2 0.6 1675.6 1.0X +to_timestamp wholestage on 1599 1606 8 0.6 1599.5 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to_unix_timestamp: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to_unix_timestamp wholestage off 1987 1988 1 0.5 1986.9 1.0X -to_unix_timestamp wholestage on 1944 1948 3 0.5 1944.2 1.0X +to_unix_timestamp wholestage off 1582 1589 9 0.6 1582.1 1.0X +to_unix_timestamp wholestage on 1634 1637 3 0.6 1633.8 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to date str: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -to date str wholestage off 263 264 0 3.8 263.5 1.0X -to date str wholestage on 263 265 2 3.8 262.6 1.0X +to date str wholestage off 275 282 9 3.6 275.0 1.0X +to date str wholestage on 264 265 2 3.8 263.5 1.0X OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz to_date: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -to_date wholestage off 3560 3567 11 0.3 3559.7 1.0X -to_date wholestage on 3525 3534 10 0.3 3524.8 1.0X +to_date wholestage off 3170 3188 25 0.3 3170.1 1.0X +to_date wholestage on 3134 3143 10 0.3 3134.3 1.0X ================================================================================================ @@ -444,14 +444,14 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz To/from Java's date-time: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -From java.sql.Date 405 416 16 12.3 81.0 1.0X -From java.time.LocalDate 344 352 14 14.5 68.8 1.2X -Collect java.sql.Date 1622 2553 1372 3.1 324.4 0.2X -Collect java.time.LocalDate 1464 1482 20 3.4 292.8 0.3X -From java.sql.Timestamp 248 258 15 20.2 49.6 1.6X -From java.time.Instant 237 243 7 21.1 47.4 1.7X -Collect longs 1252 1341 109 4.0 250.5 0.3X -Collect java.sql.Timestamp 1515 1516 2 3.3 302.9 0.3X -Collect java.time.Instant 1379 1490 96 3.6 275.8 0.3X +From java.sql.Date 407 413 7 12.3 81.5 1.0X +From java.time.LocalDate 340 344 5 14.7 68.1 1.2X +Collect java.sql.Date 1700 2658 1422 2.9 340.0 0.2X +Collect java.time.LocalDate 1473 1494 30 3.4 294.6 0.3X +From java.sql.Timestamp 252 266 13 19.8 50.5 1.6X +From java.time.Instant 236 243 7 21.1 47.3 1.7X +Collect longs 1280 1337 79 3.9 256.1 0.3X +Collect java.sql.Timestamp 1485 1501 15 3.4 297.0 0.3X +Collect java.time.Instant 1441 1465 37 3.5 288.1 0.3X diff --git a/sql/core/benchmarks/JsonBenchmark-jdk11-results.txt b/sql/core/benchmarks/JsonBenchmark-jdk11-results.txt index 03bc334471e56..d0cd591da4c94 100644 --- a/sql/core/benchmarks/JsonBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/JsonBenchmark-jdk11-results.txt @@ -3,110 +3,110 @@ Benchmark for performance of JSON parsing ================================================================================================ Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz JSON schema inferring: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 46010 46118 113 2.2 460.1 1.0X -UTF-8 is set 54407 55427 1718 1.8 544.1 0.8X +No encoding 68879 68993 116 1.5 688.8 1.0X +UTF-8 is set 115270 115602 455 0.9 1152.7 0.6X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz count a short column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 26614 28220 1461 3.8 266.1 1.0X -UTF-8 is set 42765 43400 550 2.3 427.6 0.6X +No encoding 47452 47538 113 2.1 474.5 1.0X +UTF-8 is set 77330 77354 30 1.3 773.3 0.6X Preparing data for benchmarking ... 
-Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz count a wide column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 35696 35821 113 0.3 3569.6 1.0X -UTF-8 is set 55441 56176 1037 0.2 5544.1 0.6X +No encoding 60470 60900 534 0.2 6047.0 1.0X +UTF-8 is set 104733 104931 189 0.1 10473.3 0.6X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz select wide row: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 61514 62968 NaN 0.0 123027.2 1.0X -UTF-8 is set 72096 72933 1162 0.0 144192.7 0.9X +No encoding 130302 131072 976 0.0 260604.6 1.0X +UTF-8 is set 150860 151284 377 0.0 301720.1 0.9X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select a subset of 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 10 columns 9859 9913 79 1.0 985.9 1.0X -Select 1 column 10981 11003 36 0.9 1098.1 0.9X +Select 10 columns 18619 18684 99 0.5 1861.9 1.0X +Select 1 column 24227 24270 38 0.4 2422.7 0.8X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz creation of JSON parser per line: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Short column without encoding 3555 3579 27 2.8 355.5 1.0X -Short column with UTF-8 5204 5227 35 1.9 520.4 0.7X -Wide column without encoding 60458 60637 164 0.2 6045.8 0.1X -Wide column with UTF-8 77544 78111 551 0.1 7754.4 0.0X +Short column without encoding 7947 7971 21 1.3 794.7 1.0X +Short column with UTF-8 12700 12753 58 0.8 1270.0 0.6X +Wide column without encoding 92632 92955 463 0.1 9263.2 0.1X +Wide column with UTF-8 147013 147170 188 0.1 14701.3 0.1X Preparing data for benchmarking ... 
-Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz JSON functions: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Text read 342 346 3 29.2 34.2 1.0X -from_json 7123 7318 179 1.4 712.3 0.0X -json_tuple 9843 9957 132 1.0 984.3 0.0X -get_json_object 7827 8046 194 1.3 782.7 0.0X +Text read 713 734 19 14.0 71.3 1.0X +from_json 22019 22429 456 0.5 2201.9 0.0X +json_tuple 27987 28047 74 0.4 2798.7 0.0X +get_json_object 21468 21870 350 0.5 2146.8 0.0X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Dataset of json strings: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Text read 1856 1884 32 26.9 37.1 1.0X -schema inferring 16734 16900 153 3.0 334.7 0.1X -parsing 14884 15203 470 3.4 297.7 0.1X +Text read 2887 2910 24 17.3 57.7 1.0X +schema inferring 31793 31843 43 1.6 635.9 0.1X +parsing 36791 37104 294 1.4 735.8 0.1X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Json files in the per-line mode: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Text read 5932 6148 228 8.4 118.6 1.0X -Schema inferring 20836 21938 1086 2.4 416.7 0.3X -Parsing without charset 18134 18661 457 2.8 362.7 0.3X -Parsing with UTF-8 27734 28069 378 1.8 554.7 0.2X +Text read 10570 10611 45 4.7 211.4 1.0X +Schema inferring 48729 48763 41 1.0 974.6 0.2X +Parsing without charset 35490 35648 141 1.4 709.8 0.3X +Parsing with UTF-8 63853 63994 163 0.8 1277.1 0.2X -Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Write dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Create a dataset of timestamps 889 914 28 11.2 88.9 1.0X -to_json(timestamp) 7920 8172 353 1.3 792.0 0.1X -write timestamps to files 6726 6822 129 1.5 672.6 0.1X -Create a dataset of dates 953 963 12 10.5 95.3 0.9X -to_json(date) 5370 5705 320 1.9 537.0 0.2X -write dates to files 4109 4166 52 2.4 410.9 0.2X +Create a dataset of timestamps 2187 2190 5 4.6 218.7 1.0X +to_json(timestamp) 16262 16503 323 0.6 1626.2 0.1X +write timestamps to files 11679 11692 12 0.9 1167.9 0.2X +Create a dataset of dates 2297 2310 12 4.4 229.7 1.0X +to_json(date) 10904 10956 46 0.9 1090.4 0.2X +write dates to files 6610 6645 35 1.5 661.0 0.3X -Java HotSpot(TM) 64-Bit Server VM 
11.0.5+10-LTS on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Read dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -read timestamp text from files 1614 1675 55 6.2 161.4 1.0X -read timestamps from files 16640 16858 209 0.6 1664.0 0.1X -infer timestamps from files 33239 33388 227 0.3 3323.9 0.0X -read date text from files 1310 1340 44 7.6 131.0 1.2X -read date from files 9470 9513 41 1.1 947.0 0.2X -timestamp strings 1303 1342 47 7.7 130.3 1.2X -parse timestamps from Dataset[String] 17650 18073 380 0.6 1765.0 0.1X -infer timestamps from Dataset[String] 32623 34065 1330 0.3 3262.3 0.0X -date strings 1864 1871 7 5.4 186.4 0.9X -parse dates from Dataset[String] 10914 11316 482 0.9 1091.4 0.1X -from_json(timestamp) 21102 21990 929 0.5 2110.2 0.1X -from_json(date) 15275 15961 598 0.7 1527.5 0.1X +read timestamp text from files 2524 2530 9 4.0 252.4 1.0X +read timestamps from files 41002 41052 59 0.2 4100.2 0.1X +infer timestamps from files 84621 84939 526 0.1 8462.1 0.0X +read date text from files 2292 2302 9 4.4 229.2 1.1X +read date from files 16954 16976 21 0.6 1695.4 0.1X +timestamp strings 3067 3077 13 3.3 306.7 0.8X +parse timestamps from Dataset[String] 48690 48971 243 0.2 4869.0 0.1X +infer timestamps from Dataset[String] 97463 97786 338 0.1 9746.3 0.0X +date strings 3952 3956 3 2.5 395.2 0.6X +parse dates from Dataset[String] 24210 24241 30 0.4 2421.0 0.1X +from_json(timestamp) 71710 72242 629 0.1 7171.0 0.0X +from_json(date) 42465 42481 13 0.2 4246.5 0.1X diff --git a/sql/core/benchmarks/JsonBenchmark-results.txt b/sql/core/benchmarks/JsonBenchmark-results.txt index 0f188c4cdea56..46d2410fb47c3 100644 --- a/sql/core/benchmarks/JsonBenchmark-results.txt +++ b/sql/core/benchmarks/JsonBenchmark-results.txt @@ -3,110 +3,110 @@ Benchmark for performance of JSON parsing ================================================================================================ Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz JSON schema inferring: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 38998 41002 NaN 2.6 390.0 1.0X -UTF-8 is set 61231 63282 1854 1.6 612.3 0.6X +No encoding 63981 64044 56 1.6 639.8 1.0X +UTF-8 is set 112672 113350 962 0.9 1126.7 0.6X Preparing data for benchmarking ... 
-Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz count a short column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 28272 28338 70 3.5 282.7 1.0X -UTF-8 is set 58681 62243 1517 1.7 586.8 0.5X +No encoding 51256 51449 180 2.0 512.6 1.0X +UTF-8 is set 83694 83859 148 1.2 836.9 0.6X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz count a wide column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 44026 51829 1329 0.2 4402.6 1.0X -UTF-8 is set 65839 68596 500 0.2 6583.9 0.7X +No encoding 58440 59097 569 0.2 5844.0 1.0X +UTF-8 is set 102746 102883 198 0.1 10274.6 0.6X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz select wide row: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 72144 74820 NaN 0.0 144287.6 1.0X -UTF-8 is set 69571 77888 NaN 0.0 139142.3 1.0X +No encoding 128982 129304 356 0.0 257965.0 1.0X +UTF-8 is set 147247 147415 231 0.0 294494.1 0.9X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select a subset of 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 10 columns 9502 9604 106 1.1 950.2 1.0X -Select 1 column 11861 11948 109 0.8 1186.1 0.8X +Select 10 columns 18837 19048 331 0.5 1883.7 1.0X +Select 1 column 24707 24723 14 0.4 2470.7 0.8X Preparing data for benchmarking ... 
-Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz creation of JSON parser per line: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Short column without encoding 3830 3846 15 2.6 383.0 1.0X -Short column with UTF-8 5538 5543 7 1.8 553.8 0.7X -Wide column without encoding 66899 69158 NaN 0.1 6689.9 0.1X -Wide column with UTF-8 90052 93235 NaN 0.1 9005.2 0.0X +Short column without encoding 8218 8234 17 1.2 821.8 1.0X +Short column with UTF-8 12374 12438 107 0.8 1237.4 0.7X +Wide column without encoding 136918 137298 345 0.1 13691.8 0.1X +Wide column with UTF-8 176961 177142 257 0.1 17696.1 0.0X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz JSON functions: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Text read 659 674 13 15.2 65.9 1.0X -from_json 7676 7943 405 1.3 767.6 0.1X -json_tuple 9881 10172 273 1.0 988.1 0.1X -get_json_object 7949 8055 119 1.3 794.9 0.1X +Text read 1268 1278 12 7.9 126.8 1.0X +from_json 23348 23479 176 0.4 2334.8 0.1X +json_tuple 29606 30221 1024 0.3 2960.6 0.0X +get_json_object 21898 22148 226 0.5 2189.8 0.1X Preparing data for benchmarking ... -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Dataset of json strings: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Text read 3314 3326 17 15.1 66.3 1.0X -schema inferring 16549 17037 484 3.0 331.0 0.2X -parsing 15138 15283 172 3.3 302.8 0.2X +Text read 5887 5944 49 8.5 117.7 1.0X +schema inferring 46696 47054 312 1.1 933.9 0.1X +parsing 32336 32450 129 1.5 646.7 0.2X Preparing data for benchmarking ... 
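For readers skimming the numbers above, the `JSON functions` rows compare a plain text scan against the three JSON-extraction paths Spark SQL offers. A minimal sketch of the kind of expressions those rows exercise (not part of this patch; the data and schema are made up for illustration):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{from_json, get_json_object, json_tuple}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object JsonFunctionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("json-functions-sketch").getOrCreate()
    import spark.implicits._

    // Tiny stand-in for the benchmark's generated Dataset[String] of JSON records.
    val ds = Seq("""{"a": 1, "b": "x"}""", """{"a": 2, "b": "y"}""").toDS()
    val schema = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))

    ds.select(from_json($"value", schema).as("parsed")).count()   // "from_json" row
    ds.select(json_tuple($"value", "a", "b")).count()             // "json_tuple" row
    ds.select(get_json_object($"value", "$.a")).count()           // "get_json_object" row

    spark.stop()
  }
}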
-Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Json files in the per-line mode: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Text read 5136 5446 268 9.7 102.7 1.0X -Schema inferring 19864 20568 1191 2.5 397.3 0.3X -Parsing without charset 17535 17888 329 2.9 350.7 0.3X -Parsing with UTF-8 25609 25758 218 2.0 512.2 0.2X +Text read 9756 9769 11 5.1 195.1 1.0X +Schema inferring 51318 51433 108 1.0 1026.4 0.2X +Parsing without charset 43609 43743 118 1.1 872.2 0.2X +Parsing with UTF-8 60775 60844 106 0.8 1215.5 0.2X -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Write dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Create a dataset of timestamps 784 790 7 12.8 78.4 1.0X -to_json(timestamp) 8005 8055 50 1.2 800.5 0.1X -write timestamps to files 6515 6559 45 1.5 651.5 0.1X -Create a dataset of dates 854 881 24 11.7 85.4 0.9X -to_json(date) 5187 5194 7 1.9 518.7 0.2X -write dates to files 3663 3684 22 2.7 366.3 0.2X +Create a dataset of timestamps 1998 2015 17 5.0 199.8 1.0X +to_json(timestamp) 18156 18317 263 0.6 1815.6 0.1X +write timestamps to files 12912 12917 5 0.8 1291.2 0.2X +Create a dataset of dates 2209 2270 53 4.5 220.9 0.9X +to_json(date) 9433 9489 90 1.1 943.3 0.2X +write dates to files 6915 6923 8 1.4 691.5 0.3X -Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.4 -Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Read dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -read timestamp text from files 1297 1316 26 7.7 129.7 1.0X -read timestamps from files 16915 17723 963 0.6 1691.5 0.1X -infer timestamps from files 33967 34304 360 0.3 3396.7 0.0X -read date text from files 1095 1100 7 9.1 109.5 1.2X -read date from files 8376 8513 209 1.2 837.6 0.2X -timestamp strings 1807 1816 8 5.5 180.7 0.7X -parse timestamps from Dataset[String] 18189 18242 74 0.5 1818.9 0.1X -infer timestamps from Dataset[String] 37906 38547 571 0.3 3790.6 0.0X -date strings 2191 2194 4 4.6 219.1 0.6X -parse dates from Dataset[String] 11593 11625 33 0.9 1159.3 0.1X -from_json(timestamp) 22589 22650 101 0.4 2258.9 0.1X -from_json(date) 16479 16619 159 0.6 1647.9 0.1X +read timestamp text from files 2395 2412 17 4.2 239.5 1.0X +read timestamps from files 47269 47334 89 0.2 4726.9 0.1X +infer timestamps from files 91806 91851 67 0.1 9180.6 0.0X +read date text from files 2118 2133 13 4.7 211.8 1.1X +read date from files 17267 17340 115 0.6 1726.7 0.1X +timestamp strings 3906 3935 26 2.6 390.6 0.6X +parse timestamps from Dataset[String] 52244 52534 279 0.2 5224.4 0.0X +infer timestamps from Dataset[String] 100488 100714 198 0.1 10048.8 0.0X +date 
strings 4572 4584 12 2.2 457.2 0.5X +parse dates from Dataset[String] 26749 26768 17 0.4 2674.9 0.1X +from_json(timestamp) 71414 71867 556 0.1 7141.4 0.0X +from_json(date) 45322 45549 250 0.2 4532.2 0.1X diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 7ae60f22aa790..3e409ab9a50a1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.catalyst.util.RebaseDateTime; +import org.apache.spark.sql.execution.datasources.DataSourceUtils; import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; @@ -102,14 +103,14 @@ public class VectorizedColumnReader { // The timezone conversion to apply to int96 timestamps. Null if no conversion. private final ZoneId convertTz; private static final ZoneId UTC = ZoneOffset.UTC; - private final boolean rebaseDateTime; + private final String datetimeRebaseMode; public VectorizedColumnReader( ColumnDescriptor descriptor, OriginalType originalType, PageReader pageReader, ZoneId convertTz, - boolean rebaseDateTime) throws IOException { + String datetimeRebaseMode) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; this.convertTz = convertTz; @@ -132,7 +133,9 @@ public VectorizedColumnReader( if (totalValueCount == 0) { throw new IOException("totalValueCount == 0"); } - this.rebaseDateTime = rebaseDateTime; + assert "LEGACY".equals(datetimeRebaseMode) || "EXCEPTION".equals(datetimeRebaseMode) || + "CORRECTED".equals(datetimeRebaseMode); + this.datetimeRebaseMode = datetimeRebaseMode; } /** @@ -152,6 +155,52 @@ private boolean next() throws IOException { return definitionLevelColumn.nextInt() == maxDefLevel; } + private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName) { + boolean isSupported = false; + switch (typeName) { + case INT32: + isSupported = originalType != OriginalType.DATE || "CORRECTED".equals(datetimeRebaseMode); + break; + case INT64: + if (originalType == OriginalType.TIMESTAMP_MICROS) { + isSupported = "CORRECTED".equals(datetimeRebaseMode); + } else { + isSupported = originalType != OriginalType.TIMESTAMP_MILLIS; + } + break; + case FLOAT: + case DOUBLE: + case BINARY: + isSupported = true; + break; + } + return isSupported; + } + + static int rebaseDays(int julianDays, final boolean failIfRebase) { + if (failIfRebase) { + if (julianDays < RebaseDateTime.lastSwitchJulianDay()) { + throw DataSourceUtils.newRebaseExceptionInRead("Parquet"); + } else { + return julianDays; + } + } else { + return RebaseDateTime.rebaseJulianToGregorianDays(julianDays); + } + } + + static long rebaseMicros(long julianMicros, final boolean failIfRebase) { + if (failIfRebase) { + if (julianMicros < RebaseDateTime.lastSwitchJulianTs()) { + throw DataSourceUtils.newRebaseExceptionInRead("Parquet"); + } else { + return julianMicros; + } + } else { + return RebaseDateTime.rebaseJulianToGregorianMicros(julianMicros); + } + } + /** * Reads `total` values from this columnReader into column. 
*/ @@ -181,13 +230,7 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process // the values to add microseconds precision. - if (column.hasDictionary() || (rowId == 0 && - (typeName == PrimitiveType.PrimitiveTypeName.INT32 || - (typeName == PrimitiveType.PrimitiveTypeName.INT64 && - originalType != OriginalType.TIMESTAMP_MILLIS) || - typeName == PrimitiveType.PrimitiveTypeName.FLOAT || - typeName == PrimitiveType.PrimitiveTypeName.DOUBLE || - typeName == PrimitiveType.PrimitiveTypeName.BINARY))) { + if (column.hasDictionary() || (rowId == 0 && isLazyDecodingSupported(typeName))) { // Column vector supports lazy decoding of dictionary values so just set the dictionary. // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some // non-dictionary encoded values have already been added). @@ -266,7 +309,8 @@ private void decodeDictionaryIds( switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) { case INT32: if (column.dataType() == DataTypes.IntegerType || - DecimalType.is32BitDecimalType(column.dataType())) { + DecimalType.is32BitDecimalType(column.dataType()) || + (column.dataType() == DataTypes.DateType && "CORRECTED".equals(datetimeRebaseMode))) { for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { column.putInt(i, dictionary.decodeToInt(dictionaryIds.getDictId(i))); @@ -284,6 +328,14 @@ private void decodeDictionaryIds( column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getDictId(i))); } } + } else if (column.dataType() == DataTypes.DateType) { + final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); + for (int i = rowId; i < rowId + num; ++i) { + if (!column.isNullAt(i)) { + int julianDays = dictionary.decodeToInt(dictionaryIds.getDictId(i)); + column.putInt(i, rebaseDays(julianDays, failIfRebase)); + } + } } else { throw constructConvertNotSupportedException(descriptor, column); } @@ -292,17 +344,37 @@ private void decodeDictionaryIds( case INT64: if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType()) || - originalType == OriginalType.TIMESTAMP_MICROS) { + (originalType == OriginalType.TIMESTAMP_MICROS && + "CORRECTED".equals(datetimeRebaseMode))) { for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i))); } } } else if (originalType == OriginalType.TIMESTAMP_MILLIS) { + if ("CORRECTED".equals(datetimeRebaseMode)) { + for (int i = rowId; i < rowId + num; ++i) { + if (!column.isNullAt(i)) { + long gregorianMillis = dictionary.decodeToLong(dictionaryIds.getDictId(i)); + column.putLong(i, DateTimeUtils.millisToMicros(gregorianMillis)); + } + } + } else { + final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); + for (int i = rowId; i < rowId + num; ++i) { + if (!column.isNullAt(i)) { + long julianMillis = dictionary.decodeToLong(dictionaryIds.getDictId(i)); + long julianMicros = DateTimeUtils.millisToMicros(julianMillis); + column.putLong(i, rebaseMicros(julianMicros, failIfRebase)); + } + } + } + } else if (originalType == OriginalType.TIMESTAMP_MICROS) { + final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { - column.putLong(i, - DateTimeUtils.millisToMicros(dictionary.decodeToLong(dictionaryIds.getDictId(i)))); + long julianMicros = 
dictionary.decodeToLong(dictionaryIds.getDictId(i)); + column.putLong(i, rebaseMicros(julianMicros, failIfRebase)); } } } else { @@ -422,12 +494,13 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) throw defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.DateType ) { - if (rebaseDateTime) { - defColumn.readIntegersWithRebase( + if ("CORRECTED".equals(datetimeRebaseMode)) { + defColumn.readIntegers( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - defColumn.readIntegers( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); + defColumn.readIntegersWithRebase( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, failIfRebase); } } else { throw constructConvertNotSupportedException(descriptor, column); @@ -441,27 +514,29 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) thro defColumn.readLongs( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (originalType == OriginalType.TIMESTAMP_MICROS) { - if (rebaseDateTime) { - defColumn.readLongsWithRebase( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - } else { + if ("CORRECTED".equals(datetimeRebaseMode)) { defColumn.readLongs( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); + defColumn.readLongsWithRebase( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, failIfRebase); } } else if (originalType == OriginalType.TIMESTAMP_MILLIS) { - if (rebaseDateTime) { + if ("CORRECTED".equals(datetimeRebaseMode)) { for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { - long micros = DateTimeUtils.millisToMicros(dataColumn.readLong()); - column.putLong(rowId + i, RebaseDateTime.rebaseJulianToGregorianMicros(micros)); + column.putLong(rowId + i, DateTimeUtils.millisToMicros(dataColumn.readLong())); } else { column.putNull(rowId + i); } } } else { + final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { - column.putLong(rowId + i, DateTimeUtils.millisToMicros(dataColumn.readLong())); + long julianMicros = DateTimeUtils.millisToMicros(dataColumn.readLong()); + column.putLong(rowId + i, rebaseMicros(julianMicros, failIfRebase)); } else { column.putNull(rowId + i); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index c9590b97ce9cd..b40cc154d76fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -89,9 +89,9 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa private final ZoneId convertTz; /** - * true if need to rebase date/timestamp from Julian to Proleptic Gregorian calendar. + * The mode of rebasing date/timestamp from Julian to Proleptic Gregorian calendar. */ - private final boolean rebaseDateTime; + private final String datetimeRebaseMode; /** * columnBatch object that is used for batch decoding. 
This is created on first use and triggers @@ -122,16 +122,16 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa private final MemoryMode MEMORY_MODE; public VectorizedParquetRecordReader( - ZoneId convertTz, boolean rebaseDateTime, boolean useOffHeap, int capacity) { + ZoneId convertTz, String datetimeRebaseMode, boolean useOffHeap, int capacity) { this.convertTz = convertTz; - this.rebaseDateTime = rebaseDateTime; + this.datetimeRebaseMode = datetimeRebaseMode; MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; this.capacity = capacity; } // For test only. public VectorizedParquetRecordReader(boolean useOffHeap, int capacity) { - this(null, false, useOffHeap, capacity); + this(null, "CORRECTED", useOffHeap, capacity); } /** @@ -321,7 +321,7 @@ private void checkEndOfRowGroup() throws IOException { for (int i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; columnReaders[i] = new VectorizedColumnReader(columns.get(i), types.get(i).getOriginalType(), - pages.getPageReader(columns.get(i)), convertTz, rebaseDateTime); + pages.getPageReader(columns.get(i)), convertTz, datetimeRebaseMode); } totalCountLoadedSoFar += pages.getRowCount(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 2ed2e11b60c03..eddbf39178e9a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -21,13 +21,14 @@ import java.nio.ByteOrder; import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.ParquetDecodingException; + import org.apache.spark.sql.catalyst.util.RebaseDateTime; +import org.apache.spark.sql.execution.datasources.DataSourceUtils; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; -import org.apache.parquet.column.values.ValuesReader; -import org.apache.parquet.io.api.Binary; - /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. */ @@ -86,7 +87,8 @@ public final void readIntegers(int total, WritableColumnVector c, int rowId) { // iterates the values twice: check if we need to rebase first, then go to the optimized branch // if rebase is not needed. 
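Tying the pieces above together: the old boolean `rebaseDateTime` becomes a three-valued `datetimeRebaseMode`. CORRECTED trusts the stored values as proleptic Gregorian, LEGACY rebases ancient Julian values, and EXCEPTION refuses to guess and fails when a pre-switch value is encountered. The plain reader keeps the two-pass approach described in the comment above: scan the batch once to see whether any value predates the calendar switch, and only then take the slower rebase (or failure) path. A rough Scala sketch of that decision logic, using the `RebaseDateTime` helpers the patch already relies on; this is an illustration, not the patch's Java implementation:

import org.apache.spark.sql.catalyst.util.RebaseDateTime

object DatetimeRebaseSketch {
  // Rough equivalent of VectorizedColumnReader.rebaseDays for the three modes.
  def rebaseDays(days: Int, mode: String): Int = mode match {
    case "CORRECTED" => days  // already proleptic Gregorian, nothing to do
    case "LEGACY"    => RebaseDateTime.rebaseJulianToGregorianDays(days)
    case "EXCEPTION" =>
      if (days < RebaseDateTime.lastSwitchJulianDay) {
        throw new IllegalStateException(
          "Found a date before the 1582-10-15 calendar switch; set the rebase mode " +
            "to LEGACY or CORRECTED to read it")
      } else days
    case other => throw new IllegalArgumentException(s"Unknown rebase mode: $other")
  }

  // Check-then-rebase: one cheap scan to detect ancient values, then either a
  // straight copy or the per-value rebase loop, mirroring readIntegersWithRebase.
  def readDays(batch: Array[Int], mode: String): Array[Int] = {
    val anyAncient = batch.exists(_ < RebaseDateTime.lastSwitchJulianDay)
    if (anyAncient) batch.map(rebaseDays(_, mode)) else batch
  }
}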
@Override - public final void readIntegersWithRebase(int total, WritableColumnVector c, int rowId) { + public final void readIntegersWithRebase( + int total, WritableColumnVector c, int rowId, boolean failIfRebase) { int requiredBytes = total * 4; ByteBuffer buffer = getBuffer(requiredBytes); boolean rebase = false; @@ -94,8 +96,12 @@ public final void readIntegersWithRebase(int total, WritableColumnVector c, int rebase |= buffer.getInt(buffer.position() + i * 4) < RebaseDateTime.lastSwitchJulianDay(); } if (rebase) { - for (int i = 0; i < total; i += 1) { - c.putInt(rowId + i, RebaseDateTime.rebaseJulianToGregorianDays(buffer.getInt())); + if (failIfRebase) { + throw DataSourceUtils.newRebaseExceptionInRead("Parquet"); + } else { + for (int i = 0; i < total; i += 1) { + c.putInt(rowId + i, RebaseDateTime.rebaseJulianToGregorianDays(buffer.getInt())); + } } } else { if (buffer.hasArray()) { @@ -128,7 +134,8 @@ public final void readLongs(int total, WritableColumnVector c, int rowId) { // iterates the values twice: check if we need to rebase first, then go to the optimized branch // if rebase is not needed. @Override - public final void readLongsWithRebase(int total, WritableColumnVector c, int rowId) { + public final void readLongsWithRebase( + int total, WritableColumnVector c, int rowId, boolean failIfRebase) { int requiredBytes = total * 8; ByteBuffer buffer = getBuffer(requiredBytes); boolean rebase = false; @@ -136,8 +143,12 @@ public final void readLongsWithRebase(int total, WritableColumnVector c, int row rebase |= buffer.getLong(buffer.position() + i * 8) < RebaseDateTime.lastSwitchJulianTs(); } if (rebase) { - for (int i = 0; i < total; i += 1) { - c.putLong(rowId + i, RebaseDateTime.rebaseJulianToGregorianMicros(buffer.getLong())); + if (failIfRebase) { + throw DataSourceUtils.newRebaseExceptionInRead("Parquet"); + } else { + for (int i = 0; i < total; i += 1) { + c.putLong(rowId + i, RebaseDateTime.rebaseJulianToGregorianMicros(buffer.getLong())); + } } } else { if (buffer.hasArray()) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 4d72a33fcf774..24347a4e3a0c5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.datasources.parquet; +import java.io.IOException; +import java.nio.ByteBuffer; + import org.apache.parquet.Preconditions; import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesUtils; @@ -26,12 +29,8 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; -import org.apache.spark.sql.catalyst.util.RebaseDateTime; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; -import java.io.IOException; -import java.nio.ByteBuffer; - /** * A values reader for Parquet's run-length encoded data. 
This is based off of the version in * parquet-mr with these changes: @@ -211,7 +210,8 @@ public void readIntegersWithRebase( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) throws IOException { + VectorizedValuesReader data, + final boolean failIfRebase) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -219,7 +219,7 @@ public void readIntegersWithRebase( switch (mode) { case RLE: if (currentValue == level) { - data.readIntegersWithRebase(n, c, rowId); + data.readIntegersWithRebase(n, c, rowId, failIfRebase); } else { c.putNulls(rowId, n); } @@ -227,8 +227,8 @@ public void readIntegersWithRebase( case PACKED: for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { - c.putInt(rowId + i, - RebaseDateTime.rebaseJulianToGregorianDays(data.readInteger())); + int julianDays = data.readInteger(); + c.putInt(rowId + i, VectorizedColumnReader.rebaseDays(julianDays, failIfRebase)); } else { c.putNull(rowId + i); } @@ -387,7 +387,8 @@ public void readLongsWithRebase( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) throws IOException { + VectorizedValuesReader data, + final boolean failIfRebase) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -395,7 +396,7 @@ public void readLongsWithRebase( switch (mode) { case RLE: if (currentValue == level) { - data.readLongsWithRebase(n, c, rowId); + data.readLongsWithRebase(n, c, rowId, failIfRebase); } else { c.putNulls(rowId, n); } @@ -403,8 +404,8 @@ public void readLongsWithRebase( case PACKED: for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { - c.putLong(rowId + i, - RebaseDateTime.rebaseJulianToGregorianMicros(data.readLong())); + long julianMicros = data.readLong(); + c.putLong(rowId + i, VectorizedColumnReader.rebaseMicros(julianMicros, failIfRebase)); } else { c.putNull(rowId + i); } @@ -584,7 +585,8 @@ public void readIntegers(int total, WritableColumnVector c, int rowId) { } @Override - public void readIntegersWithRebase(int total, WritableColumnVector c, int rowId) { + public void readIntegersWithRebase( + int total, WritableColumnVector c, int rowId, boolean failIfRebase) { throw new UnsupportedOperationException("only readInts is valid."); } @@ -604,7 +606,8 @@ public void readLongs(int total, WritableColumnVector c, int rowId) { } @Override - public void readLongsWithRebase(int total, WritableColumnVector c, int rowId) { + public void readLongsWithRebase( + int total, WritableColumnVector c, int rowId, boolean failIfRebase) { throw new UnsupportedOperationException("only readInts is valid."); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index 809ac44cc8272..35db8f235ed60 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -40,9 +40,9 @@ public interface VectorizedValuesReader { void readBooleans(int total, WritableColumnVector c, int rowId); void readBytes(int total, WritableColumnVector c, int rowId); void readIntegers(int total, WritableColumnVector c, int rowId); - void readIntegersWithRebase(int total, WritableColumnVector c, int rowId); + void 
readIntegersWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase); void readLongs(int total, WritableColumnVector c, int rowId); - void readLongsWithRebase(int total, WritableColumnVector c, int rowId); + void readLongsWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase); void readFloats(int total, WritableColumnVector c, int rowId); void readDoubles(int total, WritableColumnVector c, int rowId); void readBinary(int total, WritableColumnVector c, int rowId); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index be597edecba98..60a60377d8a3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.io.Closeable import java.util.concurrent.TimeUnit._ -import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag @@ -49,7 +49,6 @@ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.{CallSite, Utils} - /** * The entry point to programming Spark with the Dataset and DataFrame API. * @@ -940,15 +939,7 @@ object SparkSession extends Logging { options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } setDefaultSession(session) setActiveSession(session) - - // Register a successfully instantiated context to the singleton. This should be at the - // end of the class definition so that the singleton is updated only if there is no - // exception in the construction of the instance. - sparkContext.addSparkListener(new SparkListener { - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - defaultSession.set(null) - } - }) + registerContextListener(sparkContext) } return session @@ -1064,6 +1055,20 @@ object SparkSession extends Logging { // Private methods from now on //////////////////////////////////////////////////////////////////////////////////////// + private val listenerRegistered: AtomicBoolean = new AtomicBoolean(false) + + /** Register the AppEnd listener onto the Context */ + private def registerContextListener(sparkContext: SparkContext): Unit = { + if (!listenerRegistered.get()) { + sparkContext.addSparkListener(new SparkListener { + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { + defaultSession.set(null) + } + }) + listenerRegistered.set(true) + } + } + /** The active SparkSession for the current thread. */ private val activeThreadSession = new InheritableThreadLocal[SparkSession] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 66996498ffd3b..0ae39cf8560e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -55,10 +55,12 @@ trait DataSourceScanExec extends LeafExecNode { // Metadata that describes more details of this scan. 
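Stepping back briefly to the SparkSession change above: getOrCreate previously attached a fresh application-end listener every time it built a session, so drivers that repeatedly created sessions accumulated listeners. The fix routes registration through a method guarded by an AtomicBoolean so the listener is added at most once. A minimal sketch of that guard (written with compareAndSet for compactness, whereas the patch checks get() and then calls set(true); the callback is a placeholder):

import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}

object RegisterOnceSketch {
  private val listenerRegistered = new AtomicBoolean(false)

  // Adds the application-end listener at most once, no matter how many sessions
  // are created against the same SparkContext.
  def registerContextListener(sc: SparkContext)(onEnd: () => Unit): Unit = {
    if (listenerRegistered.compareAndSet(false, true)) {
      sc.addSparkListener(new SparkListener {
        override def onApplicationEnd(end: SparkListenerApplicationEnd): Unit = onEnd()
      })
    }
  }
}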
protected def metadata: Map[String, String] + protected val maxMetadataValueLength = 100 + override def simpleString(maxFields: Int): String = { val metadataEntries = metadata.toSeq.sorted.map { case (key, value) => - key + ": " + StringUtils.abbreviate(redact(value), 100) + key + ": " + StringUtils.abbreviate(redact(value), maxMetadataValueLength) } val metadataStr = truncatedString(metadataEntries, " ", ", ", "", maxFields) redact( @@ -335,7 +337,8 @@ case class FileSourceScanExec( def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") val location = relation.location val locationDesc = - location.getClass.getSimpleName + seqToString(location.rootPaths) + location.getClass.getSimpleName + + Utils.buildLocationMetadata(location.rootPaths, maxMetadataValueLength) val metadata = Map( "Format" -> relation.fileFormat.toString, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 1a84db1970449..9f99bf5011569 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import java.time.{Instant, LocalDate} +import java.time.{Instant, LocalDate, ZoneOffset} import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, LegacyDateFormats, TimestampFormatter} import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedCommandExec, ShowTablesCommand, ShowViewsCommand} import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTablesExec} import org.apache.spark.sql.internal.SQLConf @@ -72,21 +72,33 @@ object HiveResult { } } - private def zoneId = DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone) - private def dateFormatter = DateFormatter(zoneId) - private def timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId) + // We can create the date formatter only once because it does not depend on Spark's + // session time zone controlled by the SQL config `spark.sql.session.timeZone`. + // The `zoneId` parameter is used only in parsing of special date values like `now`, + // `yesterday` and etc. but not in date formatting. While formatting of: + // - `java.time.LocalDate`, zone id is not used by `DateTimeFormatter` at all. + // - `java.sql.Date`, the date formatter delegates formatting to the legacy formatter + // which uses the default system time zone `TimeZone.getDefault`. This works correctly + // due to `DateTimeUtils.toJavaDate` which is based on the system time zone too. + private val dateFormatter = DateFormatter( + format = DateFormatter.defaultPattern, + // We can set any time zone id. UTC was taken for simplicity. + zoneId = ZoneOffset.UTC, + locale = DateFormatter.defaultLocale, + // Use `FastDateFormat` as the legacy formatter because it is thread-safe. + legacyFormat = LegacyDateFormats.FAST_DATE_FORMAT, + isParsing = false) + private def timestampFormatter = TimestampFormatter.getFractionFormatter( + DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)) /** Formats a datum (based on the given data type) and returns the string representation. 
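The HiveResult comment above argues that a single DateFormatter can be shared because formatting never consults the session time zone: a LocalDate has no zone, and java.sql.Date falls back to the JVM default zone through the legacy formatter. A tiny self-contained illustration of the first point, using only java.time (not code from the patch):

import java.time.LocalDate
import java.time.format.DateTimeFormatter
import java.util.Locale

object DateFormatZoneSketch {
  def main(args: Array[String]): Unit = {
    // A LocalDate carries no zone, so a date-only pattern formats identically no
    // matter which zone id the surrounding session is configured with; the zone
    // only matters when parsing special values such as "now" or "yesterday".
    val fmt = DateTimeFormatter.ofPattern("yyyy-MM-dd", Locale.US)
    println(fmt.format(LocalDate.of(2020, 5, 15))) // prints 2020-05-15
  }
}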
*/ def toHiveString(a: (Any, DataType), nested: Boolean = false): String = a match { case (null, _) => if (nested) "null" else "NULL" case (b, BooleanType) => b.toString - case (d: Date, DateType) => dateFormatter.format(DateTimeUtils.fromJavaDate(d)) - case (ld: LocalDate, DateType) => - dateFormatter.format(DateTimeUtils.localDateToDays(ld)) - case (t: Timestamp, TimestampType) => - timestampFormatter.format(DateTimeUtils.fromJavaTimestamp(t)) - case (i: Instant, TimestampType) => - timestampFormatter.format(DateTimeUtils.instantToMicros(i)) + case (d: Date, DateType) => dateFormatter.format(d) + case (ld: LocalDate, DateType) => dateFormatter.format(ld) + case (t: Timestamp, TimestampType) => timestampFormatter.format(t) + case (i: Instant, TimestampType) => timestampFormatter.format(i) case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => decimal.toPlainString case (n, _: NumericType) => n.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 99bc45fa9e9e8..89915d254883d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -22,6 +22,7 @@ import java.util.UUID import org.apache.hadoop.fs.Path +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} @@ -50,7 +51,7 @@ import org.apache.spark.util.Utils class QueryExecution( val sparkSession: SparkSession, val logical: LogicalPlan, - val tracker: QueryPlanningTracker = new QueryPlanningTracker) { + val tracker: QueryPlanningTracker = new QueryPlanningTracker) extends Logging { // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner @@ -82,17 +83,30 @@ class QueryExecution( sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker) } - lazy val sparkPlan: SparkPlan = executePhase(QueryPlanningTracker.PLANNING) { - // Clone the logical plan here, in case the planner rules change the states of the logical plan. - QueryExecution.createSparkPlan(sparkSession, planner, optimizedPlan.clone()) + private def assertOptimized(): Unit = optimizedPlan + + lazy val sparkPlan: SparkPlan = { + // We need to materialize the optimizedPlan here because sparkPlan is also tracked under + // the planning phase + assertOptimized() + executePhase(QueryPlanningTracker.PLANNING) { + // Clone the logical plan here, in case the planner rules change the states of the logical + // plan. + QueryExecution.createSparkPlan(sparkSession, planner, optimizedPlan.clone()) + } } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = executePhase(QueryPlanningTracker.PLANNING) { - // clone the plan to avoid sharing the plan instance between different stages like analyzing, - // optimizing and planning. - QueryExecution.prepareForExecution(preparations, sparkPlan.clone()) + lazy val executedPlan: SparkPlan = { + // We need to materialize the optimizedPlan here, before tracking the planning phase, to ensure + // that the optimization time is not counted as part of the planning phase. 
+ assertOptimized() + executePhase(QueryPlanningTracker.PLANNING) { + // clone the plan to avoid sharing the plan instance between different stages like analyzing, + // optimizing and planning. + QueryExecution.prepareForExecution(preparations, sparkPlan.clone()) + } } /** @@ -120,26 +134,42 @@ class QueryExecution( tracker.measurePhase(phase)(block) } - def simpleString: String = simpleString(false) - - def simpleString(formatted: Boolean): String = withRedaction { + def simpleString: String = { val concat = new PlanStringConcat() - concat.append("== Physical Plan ==\n") + simpleString(false, SQLConf.get.maxToStringFields, concat.append) + withRedaction { + concat.toString + } + } + + private def simpleString( + formatted: Boolean, + maxFields: Int, + append: String => Unit): Unit = { + append("== Physical Plan ==\n") if (formatted) { try { - ExplainUtils.processPlan(executedPlan, concat.append) + ExplainUtils.processPlan(executedPlan, append) } catch { - case e: AnalysisException => concat.append(e.toString) - case e: IllegalArgumentException => concat.append(e.toString) + case e: AnalysisException => append(e.toString) + case e: IllegalArgumentException => append(e.toString) } } else { - QueryPlan.append(executedPlan, concat.append, verbose = false, addSuffix = false) + QueryPlan.append(executedPlan, + append, verbose = false, addSuffix = false, maxFields = maxFields) } - concat.append("\n") - concat.toString + append("\n") } def explainString(mode: ExplainMode): String = { + val concat = new PlanStringConcat() + explainString(mode, SQLConf.get.maxToStringFields, concat.append) + withRedaction { + concat.toString + } + } + + private def explainString(mode: ExplainMode, maxFields: Int, append: String => Unit): Unit = { val queryExecution = if (logical.isStreaming) { // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. 
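The assertOptimized() call above exists because sparkPlan and executedPlan are lazy vals whose bodies run inside the PLANNING phase timer; if the optimizer had not been triggered yet, its runtime would be charged to planning. Forcing optimizedPlan first keeps the two phases separate in the tracker. A stripped-down sketch of that interaction, with a toy timer standing in for QueryPlanningTracker (names are illustrative):

import scala.collection.mutable

// Toy stand-in for QueryPlanningTracker.measurePhase: accumulates wall-clock time per phase.
class PhaseTimer {
  val millisByPhase = mutable.Map.empty[String, Long].withDefaultValue(0L)
  def measure[T](phase: String)(body: => T): T = {
    val start = System.nanoTime()
    try body finally {
      millisByPhase(phase) += (System.nanoTime() - start) / 1000000L
    }
  }
}

class ToyQueryExecution(timer: PhaseTimer) {
  lazy val optimizedPlan: String =
    timer.measure("optimization") { Thread.sleep(50); "optimized" }

  // Plays the role of assertOptimized(): force optimization before the planning timer starts.
  private def assertOptimized(): Unit = optimizedPlan

  lazy val sparkPlan: String = {
    assertOptimized()
    timer.measure("planning") { s"physical($optimizedPlan)" }
  }
}

Evaluating `new ToyQueryExecution(timer).sparkPlan` and then reading `timer.millisByPhase` shows the ~50 ms of optimization recorded under "optimization" rather than "planning", which is the separation the patch restores.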
@@ -152,19 +182,19 @@ class QueryExecution( mode match { case SimpleMode => - queryExecution.simpleString + queryExecution.simpleString(false, maxFields, append) case ExtendedMode => - queryExecution.toString + queryExecution.toString(maxFields, append) case CodegenMode => try { - org.apache.spark.sql.execution.debug.codegenString(queryExecution.executedPlan) + org.apache.spark.sql.execution.debug.writeCodegen(append, queryExecution.executedPlan) } catch { - case e: AnalysisException => e.toString + case e: AnalysisException => append(e.toString) } case CostMode => - queryExecution.stringWithStats + queryExecution.stringWithStats(maxFields, append) case FormattedMode => - queryExecution.simpleString(formatted = true) + queryExecution.simpleString(formatted = true, maxFields = maxFields, append) } } @@ -191,27 +221,39 @@ class QueryExecution( override def toString: String = withRedaction { val concat = new PlanStringConcat() - writePlans(concat.append, SQLConf.get.maxToStringFields) - concat.toString + toString(SQLConf.get.maxToStringFields, concat.append) + withRedaction { + concat.toString + } + } + + private def toString(maxFields: Int, append: String => Unit): Unit = { + writePlans(append, maxFields) } - def stringWithStats: String = withRedaction { + def stringWithStats: String = { val concat = new PlanStringConcat() + stringWithStats(SQLConf.get.maxToStringFields, concat.append) + withRedaction { + concat.toString + } + } + + private def stringWithStats(maxFields: Int, append: String => Unit): Unit = { val maxFields = SQLConf.get.maxToStringFields // trigger to compute stats for logical plans try { optimizedPlan.stats } catch { - case e: AnalysisException => concat.append(e.toString + "\n") + case e: AnalysisException => append(e.toString + "\n") } // only show optimized logical plan and physical plan - concat.append("== Optimized Logical Plan ==\n") - QueryPlan.append(optimizedPlan, concat.append, verbose = true, addSuffix = true, maxFields) - concat.append("\n== Physical Plan ==\n") - QueryPlan.append(executedPlan, concat.append, verbose = true, addSuffix = false, maxFields) - concat.append("\n") - concat.toString + append("== Optimized Logical Plan ==\n") + QueryPlan.append(optimizedPlan, append, verbose = true, addSuffix = true, maxFields) + append("\n== Physical Plan ==\n") + QueryPlan.append(executedPlan, append, verbose = true, addSuffix = false, maxFields) + append("\n") } /** @@ -248,19 +290,26 @@ class QueryExecution( /** * Dumps debug information about query execution into the specified file. * + * @param path path of the file the debug info is written to. * @param maxFields maximum number of fields converted to string representation. + * @param explainMode the explain mode to be used to generate the string + * representation of the plan. 
*/ - def toFile(path: String, maxFields: Int = Int.MaxValue): Unit = { + def toFile( + path: String, + maxFields: Int = Int.MaxValue, + explainMode: Option[String] = None): Unit = { val filePath = new Path(path) val fs = filePath.getFileSystem(sparkSession.sessionState.newHadoopConf()) val writer = new BufferedWriter(new OutputStreamWriter(fs.create(filePath))) - val append = (s: String) => { - writer.write(s) - } try { - writePlans(append, maxFields) - writer.write("\n== Whole Stage Codegen ==\n") - org.apache.spark.sql.execution.debug.writeCodegen(writer.write, executedPlan) + val mode = explainMode.map(ExplainMode.fromString(_)).getOrElse(ExtendedMode) + explainString(mode, maxFields, writer.write) + if (mode != CodegenMode) { + writer.write("\n== Whole Stage Codegen ==\n") + org.apache.spark.sql.execution.debug.writeCodegen(writer.write, executedPlan) + } + log.info(s"Debug information was written at: $filePath") } finally { writer.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 357820a9d63d0..db587dd98685e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo import org.apache.spark.sql.internal.SQLConf @@ -56,6 +57,7 @@ private[execution] object SparkPlanInfo { case ReusedSubqueryExec(child) => child :: Nil case a: AdaptiveSparkPlanExec => a.executedPlan :: Nil case stage: QueryStageExec => stage.plan :: Nil + case inMemTab: InMemoryTableScanExec => inMemTab.relation.cachedPlan :: Nil case _ => plan.children ++ plan.subqueries } val metrics = plan.metrics.toSeq.map { case (key, metric) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 12a1a1e7fc16e..302aae08d588b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper, NormalizeFloatingNumbers} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.python._ import 
org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlan @@ -135,93 +134,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Supports both equi-joins and non-equi-joins. * Supports only inner like joins. */ - object JoinSelection extends Strategy with PredicateHelper { - - /** - * Matches a plan whose output should be small enough to be used in broadcast join. - */ - private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold - } - - /** - * Matches a plan whose single partition should be small enough to build a hash table. - * - * Note: this assume that the number of partition is fixed, requires additional work if it's - * dynamic. - */ - private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = { - plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions - } - - /** - * Returns whether plan a is much smaller (3X) than plan b. - * - * The cost to build hash map is higher than sorting, we should only build hash map on a table - * that is much smaller than other one. Since we does not have the statistic for number of rows, - * use the size of bytes here as estimation. - */ - private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { - a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes - } - - private def canBuildRight(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true - case _ => false - } - - private def canBuildLeft(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | RightOuter => true - case _ => false - } - - private def getBuildSide( - wantToBuildLeft: Boolean, - wantToBuildRight: Boolean, - left: LogicalPlan, - right: LogicalPlan): Option[BuildSide] = { - if (wantToBuildLeft && wantToBuildRight) { - // returns the smaller side base on its estimated physical size, if we want to build the - // both sides. - Some(getSmallerSide(left, right)) - } else if (wantToBuildLeft) { - Some(BuildLeft) - } else if (wantToBuildRight) { - Some(BuildRight) - } else { - None - } - } - - private def getSmallerSide(left: LogicalPlan, right: LogicalPlan) = { - if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft - } - - private def hintToBroadcastLeft(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(BROADCAST)) - } - - private def hintToBroadcastRight(hint: JoinHint): Boolean = { - hint.rightHint.exists(_.strategy.contains(BROADCAST)) - } - - private def hintToShuffleHashLeft(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) - } - - private def hintToShuffleHashRight(hint: JoinHint): Boolean = { - hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) - } - - private def hintToSortMergeJoin(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) || - hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE)) - } - - private def hintToShuffleReplicateNL(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) || - hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) - } + object JoinSelection extends Strategy + with PredicateHelper + with JoinSelectionHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { @@ -245,33 +160,31 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // 5. 
Pick broadcast nested loop join as the final solution. It may OOM but we don't have // other choice. case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) => - def createBroadcastHashJoin(buildLeft: Boolean, buildRight: Boolean) = { - val wantToBuildLeft = canBuildLeft(joinType) && buildLeft - val wantToBuildRight = canBuildRight(joinType) && buildRight - getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide => - Seq(joins.BroadcastHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - planLater(left), - planLater(right))) + def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = { + getBroadcastBuildSide(left, right, joinType, hint, onlyLookingAtHint, conf).map { + buildSide => + Seq(joins.BroadcastHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + planLater(left), + planLater(right))) } } - def createShuffleHashJoin(buildLeft: Boolean, buildRight: Boolean) = { - val wantToBuildLeft = canBuildLeft(joinType) && buildLeft - val wantToBuildRight = canBuildRight(joinType) && buildRight - getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide => - Seq(joins.ShuffledHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - planLater(left), - planLater(right))) + def createShuffleHashJoin(onlyLookingAtHint: Boolean) = { + getShuffleHashJoinBuildSide(left, right, joinType, hint, onlyLookingAtHint, conf).map { + buildSide => + Seq(joins.ShuffledHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + planLater(left), + planLater(right))) } } @@ -293,14 +206,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def createJoinWithoutHint() = { - createBroadcastHashJoin( - canBroadcast(left) && !hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH)), - canBroadcast(right) && !hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH))) + createBroadcastHashJoin(false) .orElse { if (!conf.preferSortMergeJoin) { - createShuffleHashJoin( - canBuildLocalHashMap(left) && muchSmaller(left, right), - canBuildLocalHashMap(right) && muchSmaller(right, left)) + createShuffleHashJoin(false) } else { None } @@ -315,9 +224,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - createBroadcastHashJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint)) + createBroadcastHashJoin(true) .orElse { if (hintToSortMergeJoin(hint)) createSortMergeJoin() else None } - .orElse(createShuffleHashJoin(hintToShuffleHashLeft(hint), hintToShuffleHashRight(hint))) + .orElse(createShuffleHashJoin(true)) .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } .getOrElse(createJoinWithoutHint()) @@ -374,7 +283,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def createJoinWithoutHint() = { - createBroadcastNLJoin(canBroadcast(left), canBroadcast(right)) + createBroadcastNLJoin(canBroadcastBySize(left, conf), canBroadcastBySize(right, conf)) .orElse(createCartesianProduct()) .getOrElse { // This join could be very slow or OOM diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 32308063a11d3..bc924e6978ddc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -156,7 
+156,7 @@ case class AdaptiveSparkPlanExec( var currentLogicalPlan = currentPhysicalPlan.logicalLink.get var result = createQueryStages(currentPhysicalPlan) val events = new LinkedBlockingQueue[StageMaterializationEvent]() - val errors = new mutable.ArrayBuffer[SparkException]() + val errors = new mutable.ArrayBuffer[Throwable]() var stagesToReplace = Seq.empty[QueryStageExec] while (!result.allChildStagesMaterialized) { currentPhysicalPlan = result.newPlan @@ -176,9 +176,7 @@ case class AdaptiveSparkPlanExec( }(AdaptiveSparkPlanExec.executionContext) } catch { case e: Throwable => - val ex = new SparkException( - s"Early failed query stage found: ${stage.treeString}", e) - cleanUpAndThrowException(Seq(ex), Some(stage.id)) + cleanUpAndThrowException(Seq(e), Some(stage.id)) } } } @@ -191,10 +189,9 @@ case class AdaptiveSparkPlanExec( events.drainTo(rem) (Seq(nextMsg) ++ rem.asScala).foreach { case StageSuccess(stage, res) => - stage.resultOption = Some(res) + stage.resultOption.set(Some(res)) case StageFailure(stage, ex) => - errors.append( - new SparkException(s"Failed to materialize query stage: ${stage.treeString}.", ex)) + errors.append(ex) } // In case of errors, we cancel all running stages and throw exception. @@ -328,11 +325,11 @@ case class AdaptiveSparkPlanExec( context.stageCache.get(e.canonicalized) match { case Some(existingStage) if conf.exchangeReuseEnabled => val stage = reuseQueryStage(existingStage, e) - // This is a leaf stage and is not materialized yet even if the reused exchange may has - // been completed. It will trigger re-optimization later and stage materialization will - // finish in instant if the underlying exchange is already completed. + val isMaterialized = stage.resultOption.get().isDefined CreateStageResult( - newPlan = stage, allChildStagesMaterialized = false, newStages = Seq(stage)) + newPlan = stage, + allChildStagesMaterialized = isMaterialized, + newStages = if (isMaterialized) Seq.empty else Seq(stage)) case _ => val result = createQueryStages(e.child) @@ -349,10 +346,11 @@ case class AdaptiveSparkPlanExec( newStage = reuseQueryStage(queryStage, e) } } - - // We've created a new stage, which is obviously not ready yet. - CreateStageResult(newPlan = newStage, - allChildStagesMaterialized = false, newStages = Seq(newStage)) + val isMaterialized = newStage.resultOption.get().isDefined + CreateStageResult( + newPlan = newStage, + allChildStagesMaterialized = isMaterialized, + newStages = if (isMaterialized) Seq.empty else Seq(newStage)) } else { CreateStageResult(newPlan = newPlan, allChildStagesMaterialized = false, newStages = result.newStages) @@ -361,7 +359,7 @@ case class AdaptiveSparkPlanExec( case q: QueryStageExec => CreateStageResult(newPlan = q, - allChildStagesMaterialized = q.resultOption.isDefined, newStages = Seq.empty) + allChildStagesMaterialized = q.resultOption.get().isDefined, newStages = Seq.empty) case _ => if (plan.children.isEmpty) { @@ -527,8 +525,8 @@ case class AdaptiveSparkPlanExec( val planDescriptionMode = ExplainMode.fromString(conf.uiExplainMode) context.session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate( executionId, - SQLExecution.getQueryExecution(executionId).explainString(planDescriptionMode), - SparkPlanInfo.fromSparkPlan(this))) + context.qe.explainString(planDescriptionMode), + SparkPlanInfo.fromSparkPlan(context.qe.executedPlan))) } } @@ -537,31 +535,28 @@ case class AdaptiveSparkPlanExec( * materialization errors and stage cancellation errors. 
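One behavioural change above is easy to miss: when an exchange can reuse an existing query stage, the stage is now reported as already materialized if its resultOption holds a value, instead of always being handed back for re-materialization. A hedged sketch of that bookkeeping with simplified stand-in types (ToyStage and CreateStageResult here are not Spark classes):

import java.util.concurrent.atomic.AtomicReference

// Simplified stand-ins for QueryStageExec and AdaptiveSparkPlanExec's CreateStageResult.
final case class ToyStage(id: Int, resultOption: AtomicReference[Option[Any]])
final case class CreateStageResult(
    newPlan: ToyStage,
    allChildStagesMaterialized: Boolean,
    newStages: Seq[ToyStage])

object StageReuseSketch {
  // A stage that already has a result needs no further materialization; otherwise it is
  // returned as a new stage to schedule, mirroring the reuse branch in createQueryStages.
  def wrap(stage: ToyStage): CreateStageResult = {
    val isMaterialized = stage.resultOption.get().isDefined
    CreateStageResult(
      newPlan = stage,
      allChildStagesMaterialized = isMaterialized,
      newStages = if (isMaterialized) Seq.empty else Seq(stage))
  }
}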
*/ private def cleanUpAndThrowException( - errors: Seq[SparkException], + errors: Seq[Throwable], earlyFailedStage: Option[Int]): Unit = { - val runningStages = currentPhysicalPlan.collect { + currentPhysicalPlan.foreach { // earlyFailedStage is the stage which failed before calling doMaterialize, // so we should avoid calling cancel on it to re-trigger the failure again. - case s: QueryStageExec if !earlyFailedStage.contains(s.id) => s - } - val cancelErrors = new mutable.ArrayBuffer[SparkException]() - try { - runningStages.foreach { s => + case s: QueryStageExec if !earlyFailedStage.contains(s.id) => try { s.cancel() } catch { case NonFatal(t) => - cancelErrors.append( - new SparkException(s"Failed to cancel query stage: ${s.treeString}", t)) + logError(s"Exception in cancelling query stage: ${s.treeString}", t) } - } - } finally { - val ex = new SparkException( - "Adaptive execution failed due to stage materialization failures.", errors.head) - errors.tail.foreach(ex.addSuppressed) - cancelErrors.foreach(ex.addSuppressed) - throw ex + case _ => + } + val e = if (errors.size == 1) { + errors.head + } else { + val se = new SparkException("Multiple failures in stage materialization.", errors.head) + errors.tail.foreach(se.addSuppressed) + se } + throw e } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala index 6aa34497c9ea6..84c65df31a7c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.internal.SQLConf * avoid many small reduce tasks that hurt performance. 
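The reworked cleanUpAndThrowException above also changes how failures surface: a single failure is rethrown untouched, while multiple failures are wrapped once and the rest attached via addSuppressed, so no stack trace is dropped. A small generic sketch of that pattern (a plain RuntimeException stands in for SparkException to keep the snippet dependency-free):

object FailureAggregationSketch {
  // Rethrow one failure as-is; fold several into a single carrier with the extras suppressed.
  def throwCombined(errors: Seq[Throwable]): Nothing = {
    require(errors.nonEmpty, "no failures to report")
    val e = if (errors.size == 1) {
      errors.head
    } else {
      val combined = new RuntimeException("Multiple failures in stage materialization.", errors.head)
      errors.tail.foreach(combined.addSuppressed)
      combined
    }
    throw e
  }
}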
*/ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPlan] { - import CoalesceShufflePartitions._ private def conf = session.sessionState.conf override def apply(plan: SparkPlan): SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala index 0f2868e41cc39..aba83b1337109 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf case class DemoteBroadcastHashJoin(conf: SQLConf) extends Rule[LogicalPlan] { private def shouldDemote(plan: LogicalPlan): Boolean = plan match { - case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.resultOption.isDefined + case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.resultOption.get().isDefined && stage.mapStats.isDefined => val mapStats = stage.mapStats.get val partitionCnt = mapStats.bytesByPartitionId.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala index d60c3ca72f6f6..ac98342277bc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, BuildLeft, BuildRight} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} /** * Strategy for plans containing [[LogicalQueryStage]] nodes: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala index 5416fde222cb6..3620f27058af2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.adaptive +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.internal.SQLConf /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 
396c9c9d6b4e5..b5d287ca7ac79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.commons.io.FileUtils -import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} +import org.apache.spark.{MapOutputTrackerMaster, SparkEnv} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ @@ -70,9 +70,9 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { size > conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD) } - private def medianSize(stats: MapOutputStatistics): Long = { - val numPartitions = stats.bytesByPartitionId.length - val bytes = stats.bytesByPartitionId.sorted + private def medianSize(sizes: Seq[Long]): Long = { + val numPartitions = sizes.length + val bytes = sizes.sorted numPartitions match { case _ if (numPartitions % 2 == 0) => math.max((bytes(numPartitions / 2) + bytes(numPartitions / 2 - 1)) / 2, 1) @@ -163,16 +163,16 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { if supportedJoinTypes.contains(joinType) => assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length) val numPartitions = left.partitionsWithSizes.length - // We use the median size of the original shuffle partitions to detect skewed partitions. - val leftMedSize = medianSize(left.mapStats) - val rightMedSize = medianSize(right.mapStats) + // Use the median size of the actual (coalesced) partition sizes to detect skewed partitions. + val leftMedSize = medianSize(left.partitionsWithSizes.map(_._2)) + val rightMedSize = medianSize(right.partitionsWithSizes.map(_._2)) logDebug( s""" |Optimizing skewed join. 
|Left side partitions size info: - |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)} + |${getSizeInfo(leftMedSize, left.partitionsWithSizes.map(_._2))} |Right side partitions size info: - |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)} + |${getSizeInfo(rightMedSize, right.partitionsWithSizes.map(_._2))} """.stripMargin) val canSplitLeft = canSplitLeftSide(joinType) val canSplitRight = canSplitRightSide(joinType) @@ -291,17 +291,15 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { private object ShuffleStage { def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match { case s: ShuffleQueryStageExec if s.mapStats.isDefined => - val mapStats = s.mapStats.get - val sizes = mapStats.bytesByPartitionId + val sizes = s.mapStats.get.bytesByPartitionId val partitions = sizes.zipWithIndex.map { case (size, i) => CoalescedPartitionSpec(i, i + 1) -> size } - Some(ShuffleStageInfo(s, mapStats, partitions)) + Some(ShuffleStageInfo(s, partitions)) case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs) if s.mapStats.isDefined && partitionSpecs.nonEmpty => - val mapStats = s.mapStats.get - val sizes = mapStats.bytesByPartitionId + val sizes = s.mapStats.get.bytesByPartitionId val partitions = partitionSpecs.map { case spec @ CoalescedPartitionSpec(start, end) => var sum = 0L @@ -314,7 +312,7 @@ private object ShuffleStage { case other => throw new IllegalArgumentException( s"Expect CoalescedPartitionSpec but got $other") } - Some(ShuffleStageInfo(s, mapStats, partitions)) + Some(ShuffleStageInfo(s, partitions)) case _ => None } @@ -322,5 +320,4 @@ private object ShuffleStage { private case class ShuffleStageInfo( shuffleStage: ShuffleQueryStageExec, - mapStats: MapOutputStatistics, partitionsWithSizes: Seq[(CoalescedPartitionSpec, Long)]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index f414f854b92ae..4e83b4344fbf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.adaptive import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference import scala.concurrent.{Future, Promise} @@ -25,6 +26,7 @@ import org.apache.spark.{FutureAction, MapOutputStatistics, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning @@ -82,7 +84,7 @@ abstract class QueryStageExec extends LeafExecNode { /** * Compute the statistics of the query stage if executed, otherwise None. */ - def computeStats(): Option[Statistics] = resultOption.map { _ => + def computeStats(): Option[Statistics] = resultOption.get().map { _ => // Metrics `dataSize` are available in both `ShuffleExchangeExec` and `BroadcastExchangeExec`. 
val exchange = plan match { case r: ReusedExchangeExec => r.child @@ -94,7 +96,9 @@ abstract class QueryStageExec extends LeafExecNode { @transient @volatile - private[adaptive] var resultOption: Option[Any] = None + protected var _resultOption = new AtomicReference[Option[Any]](None) + + private[adaptive] def resultOption: AtomicReference[Option[Any]] = _resultOption override def output: Seq[Attribute] = plan.output override def outputPartitioning: Partitioning = plan.outputPartitioning @@ -147,14 +151,16 @@ case class ShuffleQueryStageExec( throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString) } - override def doMaterialize(): Future[Any] = { + override def doMaterialize(): Future[Any] = attachTree(this, "execute") { shuffle.mapOutputStatisticsFuture } override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = { - ShuffleQueryStageExec( + val reuse = ShuffleQueryStageExec( newStageId, ReusedExchangeExec(newOutput, shuffle)) + reuse._resultOption = this._resultOption + reuse } override def cancel(): Unit = { @@ -171,8 +177,8 @@ case class ShuffleQueryStageExec( * this method returns None, as there is no map statistics. */ def mapStats: Option[MapOutputStatistics] = { - assert(resultOption.isDefined, "ShuffleQueryStageExec should already be ready") - val stats = resultOption.get.asInstanceOf[MapOutputStatistics] + assert(resultOption.get().isDefined, "ShuffleQueryStageExec should already be ready") + val stats = resultOption.get().get.asInstanceOf[MapOutputStatistics] Option(stats) } } @@ -212,9 +218,11 @@ case class BroadcastQueryStageExec( } override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = { - BroadcastQueryStageExec( + val reuse = BroadcastQueryStageExec( newStageId, ReusedExchangeExec(newOutput, broadcast)) + reuse._resultOption = this._resultOption + reuse } override def cancel(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index f506bdddc16b5..f1e053f7fb2a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge} import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode} /** @@ -40,4 +40,28 @@ trait BaseAggregateExec extends UnaryExecNode { |${ExplainUtils.generateFieldString("Results", resultExpressions)} |""".stripMargin } + + protected def inputAttributes: Seq[Attribute] = { + val modes = aggregateExpressions.map(_.mode).distinct + if (modes.contains(Final) || modes.contains(PartialMerge)) { + // SPARK-31620: when planning aggregates, the partial aggregate uses aggregate function's + // `inputAggBufferAttributes` as its output. And Final and PartialMerge aggregate rely on the + // output to bind references for `DeclarativeAggregate.mergeExpressions`. 
But if we copy the + // aggregate function somehow after aggregate planning, like `PlanSubqueries`, the + // `DeclarativeAggregate` will be replaced by a new instance with new + // `inputAggBufferAttributes` and `mergeExpressions`. Then Final and PartialMerge aggregate + // can't bind the `mergeExpressions` with the output of the partial aggregate, as they use + // the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead, + // we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`. + val aggAttrs = aggregateExpressions + // there're exactly four cases needs `inputAggBufferAttributes` from child according to the + // agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final, + // Partial -> PartialMerge, PartialMerge -> PartialMerge. + .filter(a => a.mode == Final || a.mode == PartialMerge).map(_.aggregateFunction) + .flatMap(_.inputAggBufferAttributes) + child.output.dropRight(aggAttrs.length) ++ aggAttrs + } else { + child.output + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8af17ed0e1639..9c07ea10a87e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -129,7 +129,7 @@ case class HashAggregateExec( resultExpressions, (expressions, inputSchema) => MutableProjection.create(expressions, inputSchema), - child.output, + inputAttributes, iter, testFallbackStartsAt, numOutputRows, @@ -334,7 +334,7 @@ case class HashAggregateExec( private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes // To individually generate code for each aggregate function, an element in `updateExprs` holds // all the expressions for the buffer of an aggregation function. val updateExprs = aggregateExpressions.map { e => @@ -931,7 +931,7 @@ case class HashAggregateExec( } } - val inputAttr = aggregateBufferAttributes ++ child.output + val inputAttr = aggregateBufferAttributes ++ inputAttributes // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when // generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while // generating input columns, we use `currentVars`. 
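
The `inputAttributes` helper above (SPARK-31620) boils down to an attribute-rewiring step: when any aggregate expression runs in Final or PartialMerge mode, the trailing aggregate-buffer attributes in the child's output are swapped for the buffer attributes of the current function instances, so merge expressions still bind by expression id even after the functions have been copied (for example by `PlanSubqueries`). The following minimal, self-contained Scala sketch mirrors that logic with illustrative stand-ins; `Attr` and `AggExpr` are not Spark classes, just toy models of the relevant pieces.

object InputAttributesSketch {
  sealed trait Mode
  case object Partial extends Mode
  case object PartialMerge extends Mode
  case object Final extends Mode

  // Stand-in for an attribute; only the exprId matters for reference binding.
  final case class Attr(name: String, exprId: Long)

  // Stand-in for an aggregate expression carrying its function's buffer attributes.
  final case class AggExpr(mode: Mode, inputAggBufferAttributes: Seq[Attr])

  // Mirror of the fix: for Final/PartialMerge, drop the stale buffer attributes at the
  // tail of the child's output and append the buffer attributes of the current
  // (possibly re-copied) function instances.
  def inputAttributes(childOutput: Seq[Attr], aggExprs: Seq[AggExpr]): Seq[Attr] = {
    val modes = aggExprs.map(_.mode).distinct
    if (modes.contains(Final) || modes.contains(PartialMerge)) {
      val aggAttrs = aggExprs
        .filter(a => a.mode == Final || a.mode == PartialMerge)
        .flatMap(_.inputAggBufferAttributes)
      childOutput.dropRight(aggAttrs.length) ++ aggAttrs
    } else {
      childOutput
    }
  }

  def main(args: Array[String]): Unit = {
    val grouping = Seq(Attr("key", exprId = 1L))
    val staleBuffer = Seq(Attr("sum", exprId = 2L))  // exposed by the partial aggregate
    val freshBuffer = Seq(Attr("sum", exprId = 9L))  // after the function was copied
    val finalAgg = AggExpr(Final, freshBuffer)
    // Binding against the rewired attributes sees exprId 9, matching the copied function.
    println(inputAttributes(grouping ++ staleBuffer, Seq(finalAgg)))
  }
}

With this rewiring in place, the aggregate operators can pass `inputAttributes` instead of `child.output` to their aggregation iterators, which is exactly what the HashAggregateExec hunks above and the ObjectHashAggregateExec/SortAggregateExec hunks below do.
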
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 3fb58eb2cc8ba..f1c0719ff8948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -123,7 +123,7 @@ case class ObjectHashAggregateExec( resultExpressions, (expressions, inputSchema) => MutableProjection.create(expressions, inputSchema), - child.output, + inputAttributes, iter, fallbackCountThreshold, numOutputRows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 9610eab82c7cb..ba0c3517a1a14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -88,7 +88,7 @@ case class SortAggregateExec( val outputIter = new SortBasedAggregationIterator( partIndex, groupingExpressions, - child.output, + inputAttributes, iter, aggregateExpressions, aggregateAttributes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala index 614d6c2846bfa..136f7c47f5341 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala @@ -76,6 +76,8 @@ class DetectAmbiguousSelfJoin(conf: SQLConf) extends Rule[LogicalPlan] { // We always remove the special metadata from `AttributeReference` at the end of this rule, so // Dataset column reference only exists in the root node via Dataset transformations like // `Dataset#select`. + if (plan.find(_.isInstanceOf[Join]).isEmpty) return stripColumnReferenceMetadataInPlan(plan) + val colRefAttrs = plan.expressions.flatMap(_.collect { case a: AttributeReference if isColumnReference(a) => a }) @@ -153,6 +155,10 @@ class DetectAmbiguousSelfJoin(conf: SQLConf) extends Rule[LogicalPlan] { } } + stripColumnReferenceMetadataInPlan(plan) + } + + private def stripColumnReferenceMetadataInPlan(plan: LogicalPlan): LogicalPlan = { plan.transformExpressions { case a: AttributeReference if isColumnReference(a) => // Remove the special metadata from this `AttributeReference`, as the detection is done. 
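
The `DetectAmbiguousSelfJoin` hunk above adds a fast path: if the plan contains no `Join` node at all, the rule skips the ambiguity analysis entirely and only strips the Dataset column-reference metadata, now shared with the normal path through `stripColumnReferenceMetadataInPlan`. Below is a minimal sketch of that short-circuit shape, using toy tree nodes rather than Catalyst's `LogicalPlan`; the names are illustrative only.

object SelfJoinFastPathSketch {
  sealed trait Node {
    def children: Seq[Node]
    // Pre-order search, analogous in spirit to TreeNode.find.
    def find(p: Node => Boolean): Option[Node] =
      if (p(this)) Some(this) else children.view.flatMap(_.find(p)).headOption
  }
  final case class Leaf(name: String) extends Node { val children: Seq[Node] = Nil }
  final case class Project(child: Node) extends Node { val children: Seq[Node] = Seq(child) }
  final case class Join(left: Node, right: Node) extends Node { val children: Seq[Node] = Seq(left, right) }

  // Placeholder for the metadata-stripping transform shared by both paths.
  private def stripColumnReferenceMetadata(plan: Node): Node = plan

  def apply(plan: Node): Node = {
    // Fast path: no Join node means there is no self-join to disambiguate.
    if (plan.find(_.isInstanceOf[Join]).isEmpty) return stripColumnReferenceMetadata(plan)
    // ... the expensive column-reference checks would run here ...
    stripColumnReferenceMetadata(plan)
  }

  def main(args: Array[String]): Unit = {
    println(apply(Project(Leaf("t"))))                  // takes the fast path
    println(apply(Join(Leaf("t"), Project(Leaf("t"))))) // runs the full analysis
  }
}

The point of the early return is to avoid collecting column references and walking the plan for the common case of join-free plans.
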
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index 45a9b1a808cf3..abb74d8d09ec6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -23,9 +23,12 @@ import org.apache.hadoop.fs.Path import org.json4s.NoTypeHints import org.json4s.jackson.Serialization +import org.apache.spark.SparkUpgradeException import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.util.RebaseDateTime import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -84,17 +87,107 @@ object DataSourceUtils { case _ => false } - def needRebaseDateTime(lookupFileMeta: String => String): Option[Boolean] = { + def datetimeRebaseMode( + lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { - return Some(false) + return LegacyBehaviorPolicy.CORRECTED } - // If there is no version, we return None and let the caller side to decide. + // If there is no version, we return the mode specified by the config. Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to // rebase the datetime values. // Files written by Spark 3.0 and latter may also need the rebase if they were written with - // the "rebaseInWrite" config enabled. - version < "3.0.0" || lookupFileMeta(SPARK_LEGACY_DATETIME) != null + // the "LEGACY" rebase mode. + if (version < "3.0.0" || lookupFileMeta(SPARK_LEGACY_DATETIME) != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + + def newRebaseExceptionInRead(format: String): SparkUpgradeException = { + val config = if (format == "Parquet") { + SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key + } else if (format == "Avro") { + SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key + } else { + throw new IllegalStateException("unrecognized format " + format) } + new SparkUpgradeException("3.0", "reading dates before 1582-10-15 or timestamps before " + + s"1900-01-01T00:00:00Z from $format files can be ambiguous, as the files may be written by " + + "Spark 2.x or legacy versions of Hive, which uses a legacy hybrid calendar that is " + + "different from Spark 3.0+'s Proleptic Gregorian calendar. See more details in " + + s"SPARK-31404. You can set $config to 'LEGACY' to rebase the datetime values w.r.t. " + + s"the calendar difference during reading. 
Or set $config to 'CORRECTED' to read the " + + "datetime values as it is.", null) + } + + def newRebaseExceptionInWrite(format: String): SparkUpgradeException = { + val config = if (format == "Parquet") { + SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key + } else if (format == "Avro") { + SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key + } else { + throw new IllegalStateException("unrecognized format " + format) + } + new SparkUpgradeException("3.0", "writing dates before 1582-10-15 or timestamps before " + + s"1900-01-01T00:00:00Z into $format files can be dangerous, as the files may be read by " + + "Spark 2.x or legacy versions of Hive later, which uses a legacy hybrid calendar that is " + + "different from Spark 3.0+'s Proleptic Gregorian calendar. See more details in " + + s"SPARK-31404. You can set $config to 'LEGACY' to rebase the datetime values w.r.t. " + + "the calendar difference during writing, to get maximum interoperability. Or set " + + s"$config to 'CORRECTED' to write the datetime values as it is, if you are 100% sure that " + + "the written files will only be read by Spark 3.0+ or other systems that use Proleptic " + + "Gregorian calendar.", null) + } + + def creteDateRebaseFuncInRead( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => days: Int => + if (days < RebaseDateTime.lastSwitchJulianDay) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def creteDateRebaseFuncInWrite( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => days: Int => + if (days < RebaseDateTime.lastSwitchGregorianDay) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def creteTimestampRebaseFuncInRead( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Long => Long = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => micros: Long => + if (micros < RebaseDateTime.lastSwitchJulianTs) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianMicros + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } + + def creteTimestampRebaseFuncInWrite( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Long => Long = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => micros: Long => + if (micros < RebaseDateTime.lastSwitchGregorianTs) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianMicros + case LegacyBehaviorPolicy.CORRECTED => identity[Long] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 542c996a5342d..fc59336d6107c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -21,7 +21,7 @@ import java.io.{FileNotFoundException, IOException} import org.apache.parquet.io.ParquetDecodingException -import 
org.apache.spark.{Partition => RDDPartition, TaskContext} +import org.apache.spark.{Partition => RDDPartition, SparkUpgradeException, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession @@ -178,7 +178,9 @@ class FileScanRDD( s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}" throw new QueryExecutionException(message, e) case e: ParquetDecodingException => - if (e.getMessage.contains("Can not read value at")) { + if (e.getCause.isInstanceOf[SparkUpgradeException]) { + throw e.getCause + } else if (e.getMessage.contains("Can not read value at")) { val message = "Encounter error while reading parquet files. " + "One possible cause: Parquet column cannot be converted in the " + "corresponding files. Details: " diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 292ac6db04baf..f7e225b0cdc96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -134,7 +134,7 @@ object PartitioningUtils { val timestampFormatter = TimestampFormatter( timestampPartitionPattern, zoneId, - needVarLengthSecondFraction = true) + isParsing = true) // First, we need to parse every partition's path and see if we can find partition values. val (partitionValues, optDiscoveredBasePaths) = paths.map { path => parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala index 1e9e713e2c4d4..cf9729639c03c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala @@ -48,7 +48,7 @@ private[sql] class DB2ConnectionProvider(driver: Driver, options: JDBCOptions) result } - override def setAuthenticationConfigIfNeeded(): Unit = { + override def setAuthenticationConfigIfNeeded(): Unit = SecurityConfigurationLock.synchronized { val (parent, configEntry) = getConfigWithAppEntry() if (configEntry == null || configEntry.isEmpty) { setAuthenticationConfig(parent) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala index 2b2496f27aa8c..8e3381077cbbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala @@ -27,7 +27,7 @@ private[jdbc] class MariaDBConnectionProvider(driver: Driver, options: JDBCOptio "Krb5ConnectorContext" } - override def setAuthenticationConfigIfNeeded(): Unit = { + override def setAuthenticationConfigIfNeeded(): Unit = SecurityConfigurationLock.synchronized { val (parent, configEntry) = getConfigWithAppEntry() /** * Couple of things to mention here: diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala index f36f7d76be087..73034dcb9c2e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala @@ -30,7 +30,7 @@ private[jdbc] class PostgresConnectionProvider(driver: Driver, options: JDBCOpti properties.getProperty("jaasApplicationName", "pgjdbc") } - override def setAuthenticationConfigIfNeeded(): Unit = { + override def setAuthenticationConfigIfNeeded(): Unit = SecurityConfigurationLock.synchronized { val (parent, configEntry) = getConfigWithAppEntry() if (configEntry == null || configEntry.isEmpty) { setAuthenticationConfig(parent) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala index 1b54e9509b9eb..fa75fc8c28fbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala @@ -26,6 +26,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.util.SecurityUtils +/** + * Some of the secure connection providers modify global JVM security configuration. + * In order to avoid race the modification must be synchronized with this. + */ +private[connection] object SecurityConfigurationLock + private[jdbc] abstract class SecureConnectionProvider(driver: Driver, options: JDBCOptions) extends BasicConnectionProvider(driver, options) with Logging { override def getConnection(): Connection = { @@ -40,7 +46,8 @@ private[jdbc] abstract class SecureConnectionProvider(driver: Driver, options: J /** * Sets database specific authentication configuration when needed. If configuration already set - * then later calls must be no op. + * then later calls must be no op. When the global JVM security configuration changed then the + * related code parts must be synchronized properly. 
*/ def setAuthenticationConfigIfNeeded(): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 851cc51466a91..8a6c4dce75f30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} -import org.apache.spark.sql.internal.connector.SimpleTableProvider +import org.apache.spark.sql.internal.connector.{SimpleTableProvider, SupportsStreamingUpdate} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -52,8 +52,10 @@ private[noop] object NoopTable extends Table with SupportsWrite { } } -private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate { +private[noop] object NoopWriteBuilder extends WriteBuilder + with SupportsTruncate with SupportsStreamingUpdate { override def truncate(): WriteBuilder = this + override def update(): WriteBuilder = this override def buildForBatch(): BatchWrite = NoopBatchWrite override def buildForStreaming(): StreamingWrite = NoopStreamingWrite } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index c6d9ddf370e22..71874104fcf4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -300,10 +300,9 @@ class ParquetFileFormat None } - val rebaseDateTime = DataSourceUtils.needRebaseDateTime( - footerFileMetaData.getKeyValueMetaData.get).getOrElse { - SQLConf.get.getConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_READ) - } + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + SQLConf.get.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ)) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = @@ -318,7 +317,7 @@ class ParquetFileFormat if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader( convertTz.orNull, - rebaseDateTime, + datetimeRebaseMode.toString, enableOffHeapColumnVector && taskContext.isDefined, capacity) val iter = new RecordReaderIterator(vectorizedReader) @@ -337,7 +336,7 @@ class ParquetFileFormat logDebug(s"Falling back to parquet-mr") // ParquetRecordReader returns InternalRow val readSupport = new ParquetReadSupport( - convertTz, enableVectorizedReader = false, rebaseDateTime) + convertTz, enableVectorizedReader = false, datetimeRebaseMode) val reader = if (pushed.isDefined && enableRecordFilter) { val parquetFilter = FilterCompat.get(pushed.get, null) new ParquetRecordReader[InternalRow](readSupport, parquetFilter) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index d89186af8c8e5..491977c61d3cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Timestamp} -import java.time.LocalDate +import java.time.{Instant, LocalDate} import java.util.Locale import scala.collection.JavaConverters.asScalaBufferConverter @@ -129,6 +129,11 @@ class ParquetFilters( case ld: LocalDate => DateTimeUtils.localDateToDays(ld) } + private def timestampToMicros(v: Any): JLong = v match { + case i: Instant => DateTimeUtils.instantToMicros(i) + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + } + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() @@ -148,6 +153,12 @@ class ParquetFilters( Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) } + private def timestampToMillis(v: Any): JLong = { + val micros = timestampToMicros(v) + val millis = DateTimeUtils.microsToMillis(micros) + millis.asInstanceOf[JLong] + } + private val makeEq: PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { case ParquetBooleanType => @@ -179,12 +190,11 @@ class ParquetFilters( case ParquetTimestampMicrosType if pushDownTimestamp => (n: Array[String], v: Any) => FilterApi.eq( longColumn(n), - Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) - .asInstanceOf[JLong]).orNull) + Option(v).map(timestampToMicros).orNull) case ParquetTimestampMillisType if pushDownTimestamp => (n: Array[String], v: Any) => FilterApi.eq( longColumn(n), - Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + Option(v).map(timestampToMillis).orNull) case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => (n: Array[String], v: Any) => FilterApi.eq( @@ -230,12 +240,11 @@ class ParquetFilters( case ParquetTimestampMicrosType if pushDownTimestamp => (n: Array[String], v: Any) => FilterApi.notEq( longColumn(n), - Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) - .asInstanceOf[JLong]).orNull) + Option(v).map(timestampToMicros).orNull) case ParquetTimestampMillisType if pushDownTimestamp => (n: Array[String], v: Any) => FilterApi.notEq( longColumn(n), - Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + Option(v).map(timestampToMillis).orNull) case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => (n: Array[String], v: Any) => FilterApi.notEq( @@ -273,13 +282,9 @@ class ParquetFilters( (n: Array[String], v: Any) => FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt( - longColumn(n), - DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => 
FilterApi.lt( - longColumn(n), - v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => (n: Array[String], v: Any) => @@ -314,13 +319,9 @@ class ParquetFilters( (n: Array[String], v: Any) => FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq( - longColumn(n), - DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq( - longColumn(n), - v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => (n: Array[String], v: Any) => @@ -355,13 +356,9 @@ class ParquetFilters( (n: Array[String], v: Any) => FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt( - longColumn(n), - DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt( - longColumn(n), - v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => (n: Array[String], v: Any) => @@ -396,13 +393,9 @@ class ParquetFilters( (n: Array[String], v: Any) => FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq( - longColumn(n), - DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq( - longColumn(n), - v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => (n: Array[String], v: Any) => @@ -476,7 +469,7 @@ class ParquetFilters( case ParquetDateType => value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] case ParquetTimestampMicrosType | ParquetTimestampMillisType => - value.isInstanceOf[Timestamp] + value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) => isDecimalMatched(value, decimalMeta) case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 28165e0bbecde..a30d1c26b3b2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -32,6 +32,7 @@ import 
org.apache.parquet.schema.Type.Repetition import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types._ /** @@ -53,7 +54,7 @@ import org.apache.spark.sql.types._ class ParquetReadSupport( val convertTz: Option[ZoneId], enableVectorizedReader: Boolean, - rebaseDateTime: Boolean) + datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends ReadSupport[InternalRow] with Logging { private var catalystRequestedSchema: StructType = _ @@ -61,7 +62,7 @@ class ParquetReadSupport( // We need a zero-arg constructor for SpecificParquetRecordReaderBase. But that is only // used in the vectorized reader, where we get the convertTz/rebaseDateTime value directly, // and the values here are ignored. - this(None, enableVectorizedReader = true, rebaseDateTime = false) + this(None, enableVectorizedReader = true, datetimeRebaseMode = LegacyBehaviorPolicy.CORRECTED) } /** @@ -130,7 +131,7 @@ class ParquetReadSupport( ParquetReadSupport.expandUDT(catalystRequestedSchema), new ParquetToSparkSchemaConverter(conf), convertTz, - rebaseDateTime) + datetimeRebaseMode) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala index ec037130aa7e9..bb528d548b6ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala @@ -23,6 +23,7 @@ import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types.StructType /** @@ -32,19 +33,19 @@ import org.apache.spark.sql.types.StructType * @param catalystSchema Catalyst schema of the rows to be constructed * @param schemaConverter A Parquet-Catalyst schema converter that helps initializing row converters * @param convertTz the optional time zone to convert to int96 data - * @param rebaseDateTime true if need to rebase date/timestamp from Julian to Proleptic Gregorian - * calendar + * @param datetimeRebaseMode the mode of rebasing date/timestamp from Julian to Proleptic Gregorian + * calendar */ private[parquet] class ParquetRecordMaterializer( parquetSchema: MessageType, catalystSchema: StructType, schemaConverter: ParquetToSparkSchemaConverter, convertTz: Option[ZoneId], - rebaseDateTime: Boolean) + datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends RecordMaterializer[InternalRow] { private val rootConverter = new ParquetRowConverter( - schemaConverter, parquetSchema, catalystSchema, convertTz, rebaseDateTime, NoopUpdater) + schemaConverter, parquetSchema, catalystSchema, convertTz, datetimeRebaseMode, NoopUpdater) override def getCurrentRecord: InternalRow = rootConverter.currentRecord diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 8376b7b137ae4..201ee16faeb08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -35,8 +35,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp -import org.apache.spark.sql.catalyst.util.RebaseDateTime._ +import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -121,8 +122,8 @@ private[parquet] class ParquetPrimitiveConverter(val updater: ParentContainerUpd * @param catalystType Spark SQL schema that corresponds to the Parquet record type. User-defined * types should have been expanded. * @param convertTz the optional time zone to convert to int96 data - * @param rebaseDateTime true if need to rebase date/timestamp from Julian to Proleptic Gregorian - * calendar + * @param datetimeRebaseMode the mode of rebasing date/timestamp from Julian to Proleptic Gregorian + * calendar * @param updater An updater which propagates converted field values to the parent container */ private[parquet] class ParquetRowConverter( @@ -130,7 +131,7 @@ private[parquet] class ParquetRowConverter( parquetType: GroupType, catalystType: StructType, convertTz: Option[ZoneId], - rebaseDateTime: Boolean, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, updater: ParentContainerUpdater) extends ParquetGroupConverter(updater) with Logging { @@ -181,6 +182,12 @@ private[parquet] class ParquetRowConverter( */ def currentRecord: InternalRow = currentRow + private val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( + datetimeRebaseMode, "Parquet") + + private val timestampRebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( + datetimeRebaseMode, "Parquet") + // Converters for each field. 
private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = { // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false @@ -275,35 +282,17 @@ private[parquet] class ParquetRowConverter( new ParquetStringConverter(updater) case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS => - if (rebaseDateTime) { - new ParquetPrimitiveConverter(updater) { - override def addLong(value: Long): Unit = { - val rebased = rebaseJulianToGregorianMicros(value) - updater.setLong(rebased) - } - } - } else { - new ParquetPrimitiveConverter(updater) { - override def addLong(value: Long): Unit = { - updater.setLong(value) - } + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + updater.setLong(timestampRebaseFunc(value)) } } case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS => - if (rebaseDateTime) { - new ParquetPrimitiveConverter(updater) { - override def addLong(value: Long): Unit = { - val micros = DateTimeUtils.millisToMicros(value) - val rebased = rebaseJulianToGregorianMicros(micros) - updater.setLong(rebased) - } - } - } else { - new ParquetPrimitiveConverter(updater) { - override def addLong(value: Long): Unit = { - updater.setLong(DateTimeUtils.millisToMicros(value)) - } + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + val micros = DateTimeUtils.millisToMicros(value) + updater.setLong(timestampRebaseFunc(micros)) } } @@ -328,17 +317,9 @@ private[parquet] class ParquetRowConverter( } case DateType => - if (rebaseDateTime) { - new ParquetPrimitiveConverter(updater) { - override def addInt(value: Int): Unit = { - updater.set(rebaseJulianToGregorianDays(value)) - } - } - } else { - new ParquetPrimitiveConverter(updater) { - override def addInt(value: Int): Unit = { - updater.set(value) - } + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = { + updater.set(dateRebaseFunc(value)) } } @@ -386,7 +367,12 @@ private[parquet] class ParquetRowConverter( } } new ParquetRowConverter( - schemaConverter, parquetType.asGroupType(), t, convertTz, rebaseDateTime, wrappedUpdater) + schemaConverter, + parquetType.asGroupType(), + t, + convertTz, + datetimeRebaseMode, + wrappedUpdater) case t => throw new RuntimeException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index b135611dd6416..6c333671d59cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -35,8 +35,9 @@ import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.catalyst.util.RebaseDateTime._ +import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types._ /** @@ -78,9 +79,14 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { private val decimalBuffer = new 
Array[Byte](Decimal.minBytesForPrecision(DecimalType.MAX_PRECISION)) - // Whether to rebase datetimes from Gregorian to Julian calendar in write - private val rebaseDateTime: Boolean = - SQLConf.get.getConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_WRITE) + private val datetimeRebaseMode = LegacyBehaviorPolicy.withName( + SQLConf.get.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE)) + + private val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInWrite( + datetimeRebaseMode, "Parquet") + + private val timestampRebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInWrite( + datetimeRebaseMode, "Parquet") override def init(configuration: Configuration): WriteContext = { val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) @@ -103,7 +109,13 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { val metadata = Map( SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT, ParquetReadSupport.SPARK_METADATA_KEY -> schemaString - ) ++ (if (rebaseDateTime) Some(SPARK_LEGACY_DATETIME -> "") else None) + ) ++ { + if (datetimeRebaseMode == LegacyBehaviorPolicy.LEGACY) { + Some(SPARK_LEGACY_DATETIME -> "") + } else { + None + } + } logInfo( s"""Initialized Parquet WriteSupport with Catalyst schema: @@ -152,12 +164,11 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { (row: SpecializedGetters, ordinal: Int) => recordConsumer.addInteger(row.getShort(ordinal)) - case DateType if rebaseDateTime => + case DateType => (row: SpecializedGetters, ordinal: Int) => - val rebasedDays = rebaseGregorianToJulianDays(row.getInt(ordinal)) - recordConsumer.addInteger(rebasedDays) + recordConsumer.addInteger(dateRebaseFunc(row.getInt(ordinal))) - case IntegerType | DateType => + case IntegerType => (row: SpecializedGetters, ordinal: Int) => recordConsumer.addInteger(row.getInt(ordinal)) @@ -187,24 +198,15 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) - case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS if rebaseDateTime => - (row: SpecializedGetters, ordinal: Int) => - val rebasedMicros = rebaseGregorianToJulianMicros(row.getLong(ordinal)) - recordConsumer.addLong(rebasedMicros) - case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => (row: SpecializedGetters, ordinal: Int) => - recordConsumer.addLong(row.getLong(ordinal)) - - case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS if rebaseDateTime => - (row: SpecializedGetters, ordinal: Int) => - val rebasedMicros = rebaseGregorianToJulianMicros(row.getLong(ordinal)) - val millis = DateTimeUtils.microsToMillis(rebasedMicros) - recordConsumer.addLong(millis) + val micros = row.getLong(ordinal) + recordConsumer.addLong(timestampRebaseFunc(micros)) case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => (row: SpecializedGetters, ordinal: Int) => - val millis = DateTimeUtils.microsToMillis(row.getLong(ordinal)) + val micros = row.getLong(ordinal) + val millis = DateTimeUtils.microsToMillis(timestampRebaseFunc(micros)) recordConsumer.addLong(millis) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 6e05aa56f4f72..7e8e0ed2dc675 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -94,8 +94,10 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin override def hashCode(): Int = getClass.hashCode() override def description(): String = { + val maxMetadataValueLength = 100 val locationDesc = - fileIndex.getClass.getSimpleName + fileIndex.rootPaths.mkString("[", ", ", "]") + fileIndex.getClass.getSimpleName + + Utils.buildLocationMetadata(fileIndex.rootPaths, maxMetadataValueLength) val metadata: Map[String, String] = Map( "ReadSchema" -> readDataSchema.catalogString, "PartitionFilters" -> seqToString(partitionFilters), @@ -105,7 +107,7 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin case (key, value) => val redactedValue = Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, value) - key + ": " + StringUtils.abbreviate(redactedValue, 100) + key + ": " + StringUtils.abbreviate(redactedValue, maxMetadataValueLength) }.mkString(", ") s"${this.getClass.getSimpleName} $metadataStr" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 1925fa1796d48..3b482b0c8ab62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedF import org.apache.spark.sql.execution.datasources.parquet._ import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.{AtomicType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -116,8 +117,9 @@ case class ParquetPartitionReaderFactory( private def buildReaderBase[T]( file: PartitionedFile, buildReaderFunc: ( - ParquetInputSplit, InternalRow, TaskAttemptContextImpl, Option[FilterPredicate], - Option[ZoneId], Boolean) => RecordReader[Void, T]): RecordReader[Void, T] = { + ParquetInputSplit, InternalRow, TaskAttemptContextImpl, + Option[FilterPredicate], Option[ZoneId], + LegacyBehaviorPolicy.Value) => RecordReader[Void, T]): RecordReader[Void, T] = { val conf = broadcastedConf.value.value val filePath = new Path(new URI(file.filePath)) @@ -169,12 +171,11 @@ case class ParquetPartitionReaderFactory( if (pushed.isDefined) { ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) } - val rebaseDatetime = DataSourceUtils.needRebaseDateTime( - footerFileMetaData.getKeyValueMetaData.get).getOrElse { - SQLConf.get.getConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_READ) - } + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + SQLConf.get.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ)) val reader = buildReaderFunc( - split, file.partitionValues, hadoopAttemptContext, pushed, convertTz, rebaseDatetime) + split, file.partitionValues, hadoopAttemptContext, pushed, convertTz, datetimeRebaseMode) reader.initialize(split, hadoopAttemptContext) reader } @@ -189,12 +190,12 @@ case class ParquetPartitionReaderFactory( 
hadoopAttemptContext: TaskAttemptContextImpl, pushed: Option[FilterPredicate], convertTz: Option[ZoneId], - needDateTimeRebase: Boolean): RecordReader[Void, InternalRow] = { + datetimeRebaseMode: LegacyBehaviorPolicy.Value): RecordReader[Void, InternalRow] = { logDebug(s"Falling back to parquet-mr") val taskContext = Option(TaskContext.get()) // ParquetRecordReader returns InternalRow val readSupport = new ParquetReadSupport( - convertTz, enableVectorizedReader = false, needDateTimeRebase) + convertTz, enableVectorizedReader = false, datetimeRebaseMode) val reader = if (pushed.isDefined && enableRecordFilter) { val parquetFilter = FilterCompat.get(pushed.get, null) new ParquetRecordReader[InternalRow](readSupport, parquetFilter) @@ -220,11 +221,11 @@ case class ParquetPartitionReaderFactory( hadoopAttemptContext: TaskAttemptContextImpl, pushed: Option[FilterPredicate], convertTz: Option[ZoneId], - rebaseDatetime: Boolean): VectorizedParquetRecordReader = { + datetimeRebaseMode: LegacyBehaviorPolicy.Value): VectorizedParquetRecordReader = { val taskContext = Option(TaskContext.get()) val vectorizedReader = new VectorizedParquetRecordReader( convertTz.orNull, - rebaseDatetime, + datetimeRebaseMode.toString, enableOffHeapColumnVector && taskContext.isDefined, capacity) val iter = new RecordReaderIterator(vectorizedReader) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index eb091758910cd..cfc653a23840d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.dynamicpruning import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, BindReferences, DynamicPruningExpression, DynamicPruningSubquery, Expression, ListQuery, Literal, PredicateHelper} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.catalyst.rules.Rule diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 08128d8f69dab..707ed1402d1ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 888e7af7c07ed..52b476f9cf134 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 7f90a51c1f234..c7c3e1672f034 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ExplainUtils, RowIterator} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 755a63e545ef1..2b7cd65e7d96f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d05113431df41..4b2d4195ee906 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -276,12 +276,12 @@ case class MapElementsExec( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val (funcClass, methodName) = func match { + val (funcClass, funcName) = func match { case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" case _ => FunctionUtils.getFunctionOneName(outputObjectType, child.output(0).dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) - val callFunc = Invoke(funcObj, methodName, outputObjectType, child.output) + val callFunc = Invoke(funcObj, funcName, outputObjectType, child.output, propagateNull = false) val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 10bcfe6649802..e8ae0eaf0ea48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -213,7 +213,7 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( * Returns all files except the deleted ones. */ def allFiles(): Array[T] = { - var latestId = getLatest().map(_._1).getOrElse(-1L) + var latestId = getLatestBatchId().getOrElse(-1L) // There is a race condition when `FileStreamSink` is deleting old files and `StreamFileIndex` // is calling this method. This loop will retry the reading to deal with the // race condition. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index b679f163fc561..32245470d8f5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -142,7 +142,7 @@ class FileStreamSink( } override def addBatch(batchId: Long, data: DataFrame): Unit = { - if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { + if (batchId <= fileLog.getLatestBatchId().getOrElse(-1L)) { logInfo(s"Skipping already committed batch $batchId") } else { val committer = FileCommitProtocol.instantiate( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala index 7b2ea9627a98e..c43887774c13d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -96,7 +96,7 @@ class FileStreamSourceLog( val searchKeys = removedBatches.map(_._1) val retrievedBatches = if (searchKeys.nonEmpty) { logWarning(s"Get batches from removed files, this is unexpected in the current code path!!!") - val latestBatchId = getLatest().map(_._1).getOrElse(-1L) + val latestBatchId = getLatestBatchId().getOrElse(-1L) if (latestBatchId < 0) { Map.empty[Long, Option[Array[FileEntry]]] } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index ed0c44da08c5d..5c86f8a50ddae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -182,17 +182,26 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } - override def getLatest(): Option[(Long, T)] = { - val batchIds = fileManager.list(metadataPath, batchFilesFilter) + /** + * Return the latest batch Id without reading the file. This method only checks for the existence of the + * file to avoid the cost of reading and deserializing the log file.
+ */ + def getLatestBatchId(): Option[Long] = { + fileManager.list(metadataPath, batchFilesFilter) .map(f => pathToBatchId(f.getPath)) .sorted(Ordering.Long.reverse) - for (batchId <- batchIds) { - val batch = get(batchId) - if (batch.isDefined) { - return Some((batchId, batch.get)) + .headOption + } + + override def getLatest(): Option[(Long, T)] = { + getLatestBatchId().map { batchId => + val content = get(batchId).getOrElse { + // If we find the last batch file, we must read that file, rather than falling back to + // old batches. + throw new IllegalStateException(s"failed to read log file for batch $batchId") } + (batchId, content) } - None } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 9b1951a834d9a..18fe38caa5e65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.connector.SupportsStreamingUpdate import org.apache.spark.sql.streaming._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} @@ -629,14 +630,9 @@ abstract class StreamExecution( writeBuilder.asInstanceOf[SupportsTruncate].truncate().buildForStreaming() case Update => - // Although no v2 sinks really support Update mode now, but during tests we do want them - // to pretend to support Update mode, and treat Update mode same as Append mode.
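The `getLatestBatchId()` helper introduced above only lists the batch files in the metadata directory, whereas `getLatest()` additionally reads and deserializes the newest batch file (and now fails fast if that file is unreadable). A minimal sketch of the calling pattern the file-stream call sites switch to; the `fileLog` parameter and the `isCommitted` helper are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.execution.streaming.HDFSMetadataLog

// Sketch only: a caller that needs just the latest batch id can now avoid
// deserializing the log file entirely.
def isCommitted(fileLog: HDFSMetadataLog[_], batchId: Long): Boolean = {
  // getLatestBatchId() lists batch files and picks the highest id;
  // getLatest() would also read the file's content.
  batchId <= fileLog.getLatestBatchId().getOrElse(-1L)
}
```

This mirrors the `FileStreamSink.addBatch` and `CompactibleFileStreamLog.allFiles` changes above, where only the id is needed.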
- if (Utils.isTesting) { - writeBuilder.buildForStreaming() - } else { - throw new IllegalArgumentException( - "Data source v2 streaming sinks does not support Update mode.") - } + require(writeBuilder.isInstanceOf[SupportsStreamingUpdate], + table.name + " does not support Update mode.") + writeBuilder.asInstanceOf[SupportsStreamingUpdate].update().buildForStreaming() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index e471e6c601d16..1e64021c8105e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapabi import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsTruncate, WriteBuilder} import org.apache.spark.sql.connector.write.streaming.StreamingWrite import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite -import org.apache.spark.sql.internal.connector.SimpleTableProvider +import org.apache.spark.sql.internal.connector.{SimpleTableProvider, SupportsStreamingUpdate} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -73,11 +73,12 @@ object ConsoleTable extends Table with SupportsWrite { } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder with SupportsTruncate { + new WriteBuilder with SupportsTruncate with SupportsStreamingUpdate { private val inputSchema: StructType = info.schema() - // Do nothing for truncate. Console sink is special that it just prints all the records. + // Do nothing for truncate/update. Console sink is special and it just prints all the records. override def truncate(): WriteBuilder = this + override def update(): WriteBuilder = this override def buildForStreaming(): StreamingWrite = { assert(inputSchema != null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index ba54c85d07303..57a73c740310e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapabi import org.apache.spark.sql.connector.write.{DataWriter, LogicalWriteInfo, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.execution.python.PythonForeachWriter +import org.apache.spark.sql.internal.connector.SupportsStreamingUpdate import org.apache.spark.sql.types.StructType /** @@ -54,12 +55,13 @@ case class ForeachWriterTable[T]( } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder with SupportsTruncate { + new WriteBuilder with SupportsTruncate with SupportsStreamingUpdate { private var inputSchema: StructType = info.schema() - // Do nothing for truncate. Foreach sink is special that it just forwards all the records to - // ForeachWriter. 
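For reference, a hedged sketch of the builder pattern this patch applies to the console, foreach, and memory sinks: a DSv2 streaming `WriteBuilder` opts into Update mode by mixing in `SupportsStreamingUpdate` and returning itself from `update()`. Note the trait lives under `org.apache.spark.sql.internal.connector`, so this is an internal contract rather than a public API, and the `StreamingWrite` construction is left as a stub here:

```scala
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsTruncate, WriteBuilder}
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.internal.connector.SupportsStreamingUpdate

// Sketch: a sink whose writes behave like idempotent upserts can treat Update like Append.
def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
  new WriteBuilder with SupportsTruncate with SupportsStreamingUpdate {
    // Complete mode asks for truncate(); Update mode now asks for update().
    override def truncate(): WriteBuilder = this
    override def update(): WriteBuilder = this
    override def buildForStreaming(): StreamingWrite =
      ??? // build the sink-specific StreamingWrite here
  }
```

With this in place, `StreamExecution` no longer needs the testing-only carve-out removed above: any sink that does not mix in the trait is rejected with "does not support Update mode".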
+ // Do nothing for truncate/update. Foreach sink is special and it just forwards all the + // records to ForeachWriter. override def truncate(): WriteBuilder = this + override def update(): WriteBuilder = this override def buildForStreaming(): StreamingWrite = { new StreamingWrite { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala index deab42bea36ad..03ebbb9f1b376 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapabi import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.internal.connector.SupportsStreamingUpdate import org.apache.spark.sql.types.StructType /** @@ -53,7 +54,7 @@ class MemorySink extends Table with SupportsWrite with Logging { } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder with SupportsTruncate { + new WriteBuilder with SupportsTruncate with SupportsStreamingUpdate { private var needTruncate: Boolean = false private val inputSchema: StructType = info.schema() @@ -62,6 +63,9 @@ class MemorySink extends Table with SupportsWrite with Logging { this } + // The in-memory sink treats update as append. + override def update(): WriteBuilder = this + override def buildForStreaming(): StreamingWrite = { new MemoryStreamingWrite(MemorySink.this, inputSchema, needTruncate) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 5ce052a0ae997..33539c01ee5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -21,9 +21,8 @@ import java.net.URLEncoder import java.nio.charset.StandardCharsets.UTF_8 import javax.servlet.http.HttpServletRequest -import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.xml.{Node, NodeSeq, Unparsed} +import scala.xml.{Node, NodeSeq} import org.apache.spark.JobExecutionStatus import org.apache.spark.internal.Logging @@ -159,26 +158,8 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L showSucceededJobs: Boolean, showFailedJobs: Boolean): Seq[Node] = { - val parameterOtherTable = request.getParameterMap().asScala - .filterNot(_._1.startsWith(executionTag)) - .map { case (name, vals) => - name + "=" + vals(0) - } - - val parameterExecutionPage = request.getParameter(s"$executionTag.page") - val parameterExecutionSortColumn = request.getParameter(s"$executionTag.sort") - val parameterExecutionSortDesc = request.getParameter(s"$executionTag.desc") - val parameterExecutionPageSize = request.getParameter(s"$executionTag.pageSize") - - val executionPage = Option(parameterExecutionPage).map(_.toInt).getOrElse(1) - val executionSortColumn = Option(parameterExecutionSortColumn).map { sortColumn => - UIUtils.decodeURLParameter(sortColumn) - 
}.getOrElse("ID") - val executionSortDesc = Option(parameterExecutionSortDesc).map(_.toBoolean).getOrElse( - // New executions should be shown above old executions by default. - executionSortColumn == "ID" - ) - val executionPageSize = Option(parameterExecutionPageSize).map(_.toInt).getOrElse(100) + val executionPage = + Option(request.getParameter(s"$executionTag.page")).map(_.toInt).getOrElse(1) val tableHeaderId = executionTag // "running", "completed" or "failed" @@ -191,11 +172,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L executionTag, UIUtils.prependBaseUri(request, parent.basePath), "SQL", // subPath - parameterOtherTable, currentTime, - pageSize = executionPageSize, - sortColumn = executionSortColumn, - desc = executionSortDesc, showRunningJobs, showSucceededJobs, showFailedJobs).table(executionPage) @@ -219,20 +196,17 @@ private[ui] class ExecutionPagedTable( executionTag: String, basePath: String, subPath: String, - parameterOtherTable: Iterable[String], currentTime: Long, - pageSize: Int, - sortColumn: String, - desc: Boolean, showRunningJobs: Boolean, showSucceededJobs: Boolean, showFailedJobs: Boolean) extends PagedTable[ExecutionTableRowData] { + private val (sortColumn, desc, pageSize) = getTableParameters(request, executionTag, "ID") + + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override val dataSource = new ExecutionDataSource( - request, - parent, data, - basePath, currentTime, pageSize, sortColumn, @@ -241,16 +215,15 @@ private[ui] class ExecutionPagedTable( showSucceededJobs, showFailedJobs) - private val parameterPath = s"$basePath/$subPath/?${parameterOtherTable.mkString("&")}" + private val parameterPath = + s"$basePath/$subPath/?${getParameterOtherTable(request, executionTag)}" override def tableId: String = s"$executionTag-table" override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$executionTag.sort=$encodedSortColumn" + @@ -263,89 +236,36 @@ private[ui] class ExecutionPagedTable( override def pageNumberFormField: String = s"$executionTag.page" - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$executionTag.sort=$encodedSortColumn&$executionTag.desc=$desc#$tableHeaderId" - } override def headers: Seq[Node] = { - // Information for each header: title, sortable - val executionHeadersAndCssClasses: Seq[(String, Boolean)] = + // Information for each header: title, sortable, tooltip + val executionHeadersAndCssClasses: Seq[(String, Boolean, Option[String])] = Seq( - ("ID", true), - ("Description", true), - ("Submitted", true), - ("Duration", true)) ++ { + ("ID", true, None), + ("Description", true, None), + ("Submitted", true, None), + ("Duration", true, Some("Time from query submission to completion (or if still executing," + + "time since submission)"))) ++ { if (showRunningJobs && showSucceededJobs && showFailedJobs) { Seq( - ("Running Job IDs", true), - ("Succeeded Job IDs", true), - ("Failed Job IDs", true)) + ("Running Job IDs", true, None), + ("Succeeded Job IDs", true, None), + ("Failed Job IDs", true, None)) } else if 
(showSucceededJobs && showFailedJobs) { Seq( - ("Succeeded Job IDs", true), - ("Failed Job IDs", true)) + ("Succeeded Job IDs", true, None), + ("Failed Job IDs", true, None)) } else { - Seq(("Job IDs", true)) + Seq(("Job IDs", true, None)) } } - val sortableColumnHeaders = executionHeadersAndCssClasses.filter { - case (_, sortable) => sortable - }.map { case (title, _) => title } - - require(sortableColumnHeaders.contains(sortColumn), s"Unknown column: $sortColumn") - - val headerRow: Seq[Node] = { - executionHeadersAndCssClasses.map { case (header, sortable) => - if (header == sortColumn) { - val headerLink = Unparsed( - parameterPath + - s"&$executionTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&$executionTag.desc=${!desc}" + - s"&$executionTag.pageSize=$pageSize" + - s"#$tableHeaderId") - val arrow = if (desc) "▾" else "▴" // UP or DOWN - - - - {header} -  {Unparsed(arrow)} - - - - } else { - if (sortable) { - val headerLink = Unparsed( - parameterPath + - s"&$executionTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&$executionTag.pageSize=$pageSize" + - s"#$tableHeaderId") - - - - {if (header == "Duration") { - - {header} - - } else { - {header} - }} - - - } else { - - {header} - - } - } - } - } - - {headerRow} - + isSortColumnValid(executionHeadersAndCssClasses, sortColumn) + + headerRow(executionHeadersAndCssClasses, desc, pageSize, sortColumn, parameterPath, + executionTag, tableHeaderId) } override def row(executionTableRow: ExecutionTableRowData): Seq[Node] = { @@ -423,7 +343,6 @@ private[ui] class ExecutionPagedTable( private[ui] class ExecutionTableRowData( - val submissionTime: Long, val duration: Long, val executionUIData: SQLExecutionUIData, val runningJobData: Seq[Int], @@ -432,10 +351,7 @@ private[ui] class ExecutionTableRowData( private[ui] class ExecutionDataSource( - request: HttpServletRequest, - parent: SQLTab, executionData: Seq[SQLExecutionUIData], - basePath: String, currentTime: Long, pageSize: Int, sortColumn: String, @@ -448,20 +364,13 @@ private[ui] class ExecutionDataSource( // in the table so that we can avoid creating duplicate contents during sorting the data private val data = executionData.map(executionRow).sorted(ordering(sortColumn, desc)) - private var _sliceExecutionIds: Set[Int] = _ - override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[ExecutionTableRowData] = { - val r = data.slice(from, to) - _sliceExecutionIds = r.map(_.executionUIData.executionId.toInt).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[ExecutionTableRowData] = data.slice(from, to) private def executionRow(executionUIData: SQLExecutionUIData): ExecutionTableRowData = { - val submissionTime = executionUIData.submissionTime val duration = executionUIData.completionTime.map(_.getTime()) - .getOrElse(currentTime) - submissionTime + .getOrElse(currentTime) - executionUIData.submissionTime val runningJobData = if (showRunningJobs) { executionUIData.jobs.filter { @@ -482,7 +391,6 @@ private[ui] class ExecutionDataSource( } else Seq.empty new ExecutionTableRowData( - submissionTime, duration, executionUIData, runningJobData, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 274a5a414ffa2..a798fe02700e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -153,7 
+153,7 @@ object SparkPlanGraph { * @param name the name of this SparkPlan node * @param metrics metrics that this SparkPlan node will track */ -private[ui] class SparkPlanGraphNode( +class SparkPlanGraphNode( val id: Long, val name: String, val desc: String, @@ -193,7 +193,7 @@ private[ui] class SparkPlanGraphNode( /** * Represent a tree of SparkPlan for WholeStageCodegen. */ -private[ui] class SparkPlanGraphCluster( +class SparkPlanGraphCluster( id: Long, name: String, desc: String, @@ -229,7 +229,7 @@ private[ui] class SparkPlanGraphCluster( * Represent an edge in the SparkPlan tree. `fromId` is the child node id, and `toId` is the parent * node id. */ -private[ui] case class SparkPlanGraphEdge(fromId: Long, toId: Long) { +case class SparkPlanGraphEdge(fromId: Long, toId: Long) { def makeDotEdge: String = s""" $fromId->$toId;\n""" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5481337bf6cee..0cca3e7b47c56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1306,7 +1306,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) } + def struct(cols: Column*): Column = withExpr { CreateStruct.create(cols.map(_.expr)) } /** * Creates a new struct column that composes multiple input columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 47119ab903da7..ce4385d88f1e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -53,7 +53,7 @@ private[sql] class SharedState( initialConfigs: scala.collection.Map[String, String]) extends Logging { - SharedState.setFsUrlStreamHandlerFactory(sparkContext.conf) + SharedState.setFsUrlStreamHandlerFactory(sparkContext.conf, sparkContext.hadoopConfiguration) private val (conf, hadoopConf) = { // Load hive-site.xml into hadoopConf and determine the warehouse path which will be set into @@ -174,13 +174,13 @@ private[sql] class SharedState( object SharedState extends Logging { @volatile private var fsUrlStreamHandlerFactoryInitialized = false - private def setFsUrlStreamHandlerFactory(conf: SparkConf): Unit = { + private def setFsUrlStreamHandlerFactory(conf: SparkConf, hadoopConf: Configuration): Unit = { if (!fsUrlStreamHandlerFactoryInitialized && conf.get(DEFAULT_URL_STREAM_HANDLER_FACTORY_ENABLED)) { synchronized { if (!fsUrlStreamHandlerFactoryInitialized) { try { - URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory()) + URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory(hadoopConf)) fsUrlStreamHandlerFactoryInitialized = true } catch { case NonFatal(_) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPage.scala index 733676546eab3..b969e41e4e55c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPage.scala @@ -17,15 +17,19 @@ package org.apache.spark.sql.streaming.ui +import java.net.URLEncoder +import java.nio.charset.StandardCharsets.UTF_8 import javax.servlet.http.HttpServletRequest +import 
scala.collection.mutable import scala.xml.Node import org.apache.commons.text.StringEscapeUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.streaming.ui.UIUtils._ -import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} +import org.apache.spark.ui.{PagedDataSource, PagedTable, UIUtils => SparkUIUtils, WebUIPage} +import org.apache.spark.util.Utils private[ui] class StreamingQueryPage(parent: StreamingQueryTab) extends WebUIPage("") with Logging { @@ -35,11 +39,147 @@ private[ui] class StreamingQueryPage(parent: StreamingQueryTab) SparkUIUtils.headerSparkPage(request, "Streaming Query", content, parent) } - def generateDataRow(request: HttpServletRequest, queryActive: Boolean) - (query: StreamingQueryUIData): Seq[Node] = { + private def generateStreamingQueryTable(request: HttpServletRequest): Seq[Node] = { + val (activeQueries, inactiveQueries) = parent.statusListener.allQueryStatus + .partition(_.isActive) + + val content = mutable.ListBuffer[Node]() + // show active queries table only if there is at least one active query + if (activeQueries.nonEmpty) { + // scalastyle:off + content ++= + +
+ + Active Streaming Queries ({activeQueries.length}) +
+
++ +
+
    + {queryTable(activeQueries, request, "active")} +
+
+ // scalastyle:on + } + // show active queries table only if there is at least one completed query + if (inactiveQueries.nonEmpty) { + // scalastyle:off + content ++= + +
+ + Completed Streaming Queries ({inactiveQueries.length}) +
+
++ +
+
    + {queryTable(inactiveQueries, request, "completed")} +
+
+ // scalastyle:on + } + content + } + + private def queryTable(data: Seq[StreamingQueryUIData], request: HttpServletRequest, + tableTag: String): Seq[Node] = { + + val isActive = if (tableTag.contains("active")) true else false + val page = Option(request.getParameter(s"$tableTag.page")).map(_.toInt).getOrElse(1) + + try { + new StreamingQueryPagedTable( + request, + parent, + data, + tableTag, + isActive, + SparkUIUtils.prependBaseUri(request, parent.basePath), + "StreamingQuery" + ).table(page) + } catch { + case e@(_: IllegalArgumentException | _: IndexOutOfBoundsException) => +
+

Error while rendering execution table:

+
+            {Utils.exceptionString(e)}
+          
+
+ } + } +} + +class StreamingQueryPagedTable( + request: HttpServletRequest, + parent: StreamingQueryTab, + data: Seq[StreamingQueryUIData], + tableTag: String, + isActive: Boolean, + basePath: String, + subPath: String) extends PagedTable[StructuredStreamingRow] { + + private val (sortColumn, sortDesc, pageSize) = getTableParameters(request, tableTag, "Start Time") + private val parameterPath = s"$basePath/$subPath/?${getParameterOtherTable(request, tableTag)}" + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + + override def tableId: String = s"$tableTag-table" + + override def tableCssClass: String = + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" + + override def pageSizeFormField: String = s"$tableTag.pageSize" + + override def pageNumberFormField: String = s"$tableTag.page" + + override def pageLink(page: Int): String = { + parameterPath + + s"&$pageNumberFormField=$page" + + s"&$tableTag.sort=$encodedSortColumn" + + s"&$tableTag.desc=$sortDesc" + + s"&$pageSizeFormField=$pageSize" + + s"#$tableTag" + } + + override def goButtonFormPath: String = + s"$parameterPath&$tableTag.sort=$encodedSortColumn&$tableTag.desc=$sortDesc#$tableTag" + + override def dataSource: PagedDataSource[StructuredStreamingRow] = + new StreamingQueryDataSource(data, sortColumn, sortDesc, pageSize, isActive) + + override def headers: Seq[Node] = { + val headerAndCss: Seq[(String, Boolean, Option[String])] = { + Seq( + ("Name", true, None), + ("Status", false, None), + ("ID", true, None), + ("Run ID", true, None), + ("Start Time", true, None), + ("Duration", true, None), + ("Avg Input /sec", true, None), + ("Avg Process /sec", true, None), + ("Latest Batch", true, None)) ++ { + if (!isActive) { + Seq(("Error", false, None)) + } else { + Nil + } + } + } + isSortColumnValid(headerAndCss, sortColumn) + + headerRow(headerAndCss, sortDesc, pageSize, sortColumn, parameterPath, tableTag, tableTag) + } + + override def row(query: StructuredStreamingRow): Seq[Node] = { + val streamingQuery = query.streamingUIData + val statisticsLink = "%s/%s/statistics?id=%s" + .format(SparkUIUtils.prependBaseUri(request, parent.basePath), parent.prefix, + streamingQuery.runId) def details(detail: Any): Seq[Node] = { - if (queryActive) { + if (isActive) { return Seq.empty[Node] } val detailString = detail.asInstanceOf[String] @@ -51,12 +191,39 @@ private[ui] class StreamingQueryPage(parent: StreamingQueryTab) {summary}{details} } - val statisticsLink = "%s/%s/statistics?id=%s" - .format(SparkUIUtils.prependBaseUri(request, parent.basePath), parent.prefix, query.runId) + + {UIUtils.getQueryName(streamingQuery)} + {UIUtils.getQueryStatus(streamingQuery)} + {streamingQuery.id} + {streamingQuery.runId} + {SparkUIUtils.formatDate(streamingQuery.startTimestamp)} + {query.duration} + {withNoProgress(streamingQuery, {query.avgInput.formatted("%.2f")}, "NaN")} + {withNoProgress(streamingQuery, {query.avgProcess.formatted("%.2f")}, "NaN")} + {withNoProgress(streamingQuery, {streamingQuery.lastProgress.batchId}, "NaN")} + {details(streamingQuery.exception.getOrElse("-"))} + + } +} - val name = UIUtils.getQueryName(query) - val status = UIUtils.getQueryStatus(query) - val duration = if (queryActive) { +case class StructuredStreamingRow( + duration: String, + avgInput: Double, + avgProcess: Double, + streamingUIData: StreamingQueryUIData) + +class StreamingQueryDataSource(uiData: Seq[StreamingQueryUIData], sortColumn: String, desc: Boolean, + pageSize: Int, isActive: Boolean) 
extends PagedDataSource[StructuredStreamingRow](pageSize) { + + // convert StreamingQueryUIData to StreamingRow to provide required data for sorting and sort it + private val data = uiData.map(streamingRow).sorted(ordering(sortColumn, desc)) + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[StructuredStreamingRow] = data.slice(from, to) + + private def streamingRow(query: StreamingQueryUIData): StructuredStreamingRow = { + val duration = if (isActive) { SparkUIUtils.formatDurationVerbose(System.currentTimeMillis() - query.startTimestamp) } else { withNoProgress(query, { @@ -65,79 +232,31 @@ private[ui] class StreamingQueryPage(parent: StreamingQueryTab) }, "-") } - - {name} - {status} - {query.id} - {query.runId} - {SparkUIUtils.formatDate(query.startTimestamp)} - {duration} - {withNoProgress(query, { - (query.recentProgress.map(p => withNumberInvalid(p.inputRowsPerSecond)).sum / - query.recentProgress.length).formatted("%.2f") }, "NaN")} - - {withNoProgress(query, { - (query.recentProgress.map(p => withNumberInvalid(p.processedRowsPerSecond)).sum / - query.recentProgress.length).formatted("%.2f") }, "NaN")} - - {withNoProgress(query, { query.lastProgress.batchId }, "NaN")} - {details(query.exception.getOrElse("-"))} - - } + val avgInput = (query.recentProgress.map(p => withNumberInvalid(p.inputRowsPerSecond)).sum / + query.recentProgress.length) - private def generateStreamingQueryTable(request: HttpServletRequest): Seq[Node] = { - val (activeQueries, inactiveQueries) = parent.statusListener.allQueryStatus - .partition(_.isActive) - val activeQueryTables = if (activeQueries.nonEmpty) { - val headerRow = Seq( - "Name", "Status", "Id", "Run ID", "Start Time", "Duration", "Avg Input /sec", - "Avg Process /sec", "Lastest Batch") - - Some(SparkUIUtils.listingTable(headerRow, generateDataRow(request, queryActive = true), - activeQueries, true, Some("activeQueries-table"), Seq(null), false)) - } else { - None - } + val avgProcess = (query.recentProgress.map(p => + withNumberInvalid(p.processedRowsPerSecond)).sum / query.recentProgress.length) - val inactiveQueryTables = if (inactiveQueries.nonEmpty) { - val headerRow = Seq( - "Name", "Status", "Id", "Run ID", "Start Time", "Duration", "Avg Input /sec", - "Avg Process /sec", "Lastest Batch", "Error") + StructuredStreamingRow(duration, avgInput, avgProcess, query) + } - Some(SparkUIUtils.listingTable(headerRow, generateDataRow(request, queryActive = false), - inactiveQueries, true, Some("completedQueries-table"), Seq(null), false)) + private def ordering(sortColumn: String, desc: Boolean): Ordering[StructuredStreamingRow] = { + val ordering: Ordering[StructuredStreamingRow] = sortColumn match { + case "Name" => Ordering.by(q => UIUtils.getQueryName(q.streamingUIData)) + case "ID" => Ordering.by(_.streamingUIData.id) + case "Run ID" => Ordering.by(_.streamingUIData.runId) + case "Start Time" => Ordering.by(_.streamingUIData.startTimestamp) + case "Duration" => Ordering.by(_.duration) + case "Avg Input /sec" => Ordering.by(_.avgInput) + case "Avg Process /sec" => Ordering.by(_.avgProcess) + case "Latest Batch" => Ordering.by(_.streamingUIData.lastProgress.batchId) + case unknownColumn => throw new IllegalArgumentException(s"Unknown Column: $unknownColumn") + } + if (desc) { + ordering.reverse } else { - None + ordering } - - // scalastyle:off - val content = - -
- - Active Streaming Queries ({activeQueries.length}) -
-
++ -
-
    - {activeQueryTables.getOrElse(Seq.empty[Node])} -
-
++ - -
- - Completed Streaming Queries ({inactiveQueries.length}) -
-
++ -
-
    - {inactiveQueryTables.getOrElse(Seq.empty[Node])} -
-
- // scalastyle:on - - content } } diff --git a/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/SqlResource.scala b/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/SqlResource.scala index 346e07f2bef15..c7599f864dd97 100644 --- a/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/SqlResource.scala +++ b/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/SqlResource.scala @@ -21,21 +21,29 @@ import java.util.Date import javax.ws.rs._ import javax.ws.rs.core.MediaType +import scala.util.{Failure, Success, Try} + import org.apache.spark.JobExecutionStatus -import org.apache.spark.sql.execution.ui.{SQLAppStatusStore, SQLExecutionUIData, SQLPlanMetric} +import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphNode, SQLAppStatusStore, SQLExecutionUIData} import org.apache.spark.status.api.v1.{BaseAppResource, NotFoundException} @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class SqlResource extends BaseAppResource { + val WHOLE_STAGE_CODEGEN = "WholeStageCodegen" + @GET def sqlList( - @DefaultValue("false") @QueryParam("details") details: Boolean, + @DefaultValue("true") @QueryParam("details") details: Boolean, + @DefaultValue("true") @QueryParam("planDescription") planDescription: Boolean, @DefaultValue("0") @QueryParam("offset") offset: Int, @DefaultValue("20") @QueryParam("length") length: Int): Seq[ExecutionData] = { withUI { ui => val sqlStore = new SQLAppStatusStore(ui.store.store) - sqlStore.executionsList(offset, length).map(prepareExecutionData(_, details)) + sqlStore.executionsList(offset, length).map { exec => + val graph = sqlStore.planGraph(exec.executionId) + prepareExecutionData(exec, graph, details, planDescription) + } } } @@ -43,24 +51,25 @@ private[v1] class SqlResource extends BaseAppResource { @Path("{executionId:\\d+}") def sql( @PathParam("executionId") execId: Long, - @DefaultValue("false") @QueryParam("details") details: Boolean): ExecutionData = { + @DefaultValue("true") @QueryParam("details") details: Boolean, + @DefaultValue("true") @QueryParam("planDescription") + planDescription: Boolean): ExecutionData = { withUI { ui => val sqlStore = new SQLAppStatusStore(ui.store.store) + val graph = sqlStore.planGraph(execId) sqlStore .execution(execId) - .map(prepareExecutionData(_, details)) - .getOrElse(throw new NotFoundException("unknown id: " + execId)) + .map(prepareExecutionData(_, graph, details, planDescription)) + .getOrElse(throw new NotFoundException("unknown query execution id: " + execId)) } } - private def printableMetrics( - metrics: Seq[SQLPlanMetric], - metricValues: Map[Long, String]): Seq[Metrics] = { - metrics.map(metric => - Metrics(metric.name, metricValues.get(metric.accumulatorId).getOrElse(""))) - } + private def prepareExecutionData( + exec: SQLExecutionUIData, + graph: SparkPlanGraph, + details: Boolean, + planDescription: Boolean): ExecutionData = { - private def prepareExecutionData(exec: SQLExecutionUIData, details: Boolean): ExecutionData = { var running = Seq[Int]() var completed = Seq[Int]() var failed = Seq[Int]() @@ -84,18 +93,65 @@ private[v1] class SqlResource extends BaseAppResource { } val duration = exec.completionTime.getOrElse(new Date()).getTime - exec.submissionTime - val planDetails = if (details) exec.physicalPlanDescription else "" - val metrics = if (details) printableMetrics(exec.metrics, exec.metricValues) else Seq.empty + val planDetails = if (planDescription) exec.physicalPlanDescription else "" + val nodes = if (details) 
printableMetrics(graph.allNodes, exec.metricValues) else Seq.empty + val edges = if (details) graph.edges else Seq.empty + new ExecutionData( exec.executionId, status, exec.description, planDetails, - metrics, new Date(exec.submissionTime), duration, running, completed, - failed) + failed, + nodes, + edges) } + + private def printableMetrics(allNodes: Seq[SparkPlanGraphNode], + metricValues: Map[Long, String]): Seq[Node] = { + + def getMetric(metricValues: Map[Long, String], accumulatorId: Long, + metricName: String): Option[Metric] = { + + metricValues.get(accumulatorId).map( mv => { + val metricValue = if (mv.startsWith("\n")) mv.substring(1, mv.length) else mv + Metric(metricName, metricValue) + }) + } + + val nodeIdAndWSCGIdMap = getNodeIdAndWSCGIdMap(allNodes) + val nodes = allNodes.map { node => + val wholeStageCodegenId = nodeIdAndWSCGIdMap.get(node.id).flatten + val metrics = + node.metrics.flatMap(m => getMetric(metricValues, m.accumulatorId, m.name.trim)) + Node(nodeId = node.id, nodeName = node.name.trim, wholeStageCodegenId, metrics) + } + + nodes.sortBy(_.nodeId).reverse + } + + private def getNodeIdAndWSCGIdMap(allNodes: Seq[SparkPlanGraphNode]): Map[Long, Option[Long]] = { + val wscgNodes = allNodes.filter(_.name.trim.startsWith(WHOLE_STAGE_CODEGEN)) + val nodeIdAndWSCGIdMap: Map[Long, Option[Long]] = wscgNodes.flatMap { + _ match { + case x: SparkPlanGraphCluster => x.nodes.map(_.id -> getWholeStageCodegenId(x.name.trim)) + case _ => Seq.empty + } + }.toMap + + nodeIdAndWSCGIdMap + } + + private def getWholeStageCodegenId(wscgNodeName: String): Option[Long] = { + Try(wscgNodeName.substring( + s"$WHOLE_STAGE_CODEGEN (".length, wscgNodeName.length - 1).toLong) match { + case Success(wscgId) => Some(wscgId) + case Failure(t) => None + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/api.scala index 7ace66ffb06e1..0ddf66718bce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/api.scala +++ b/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/api.scala @@ -18,16 +18,25 @@ package org.apache.spark.status.api.v1.sql import java.util.Date +import org.apache.spark.sql.execution.ui.SparkPlanGraphEdge + class ExecutionData private[spark] ( val id: Long, val status: String, val description: String, val planDescription: String, - val metrics: Seq[Metrics], val submissionTime: Date, val duration: Long, val runningJobIds: Seq[Int], val successJobIds: Seq[Int], - val failedJobIds: Seq[Int]) + val failedJobIds: Seq[Int], + val nodes: Seq[Node], + val edges: Seq[SparkPlanGraphEdge]) + +case class Node private[spark]( + nodeId: Long, + nodeName: String, + wholeStageCodegenId: Option[Long] = None, + metrics: Seq[Metric]) -case class Metrics private[spark] (metricName: String, metricValue: String) +case class Metric private[spark] (name: String, value: String) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index 5603cb988b8e7..af0a22b036030 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -18,6 +18,8 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; import java.time.Instant; import java.time.LocalDate; import 
java.util.*; @@ -210,6 +212,17 @@ private static Row createRecordSpark22000Row(Long index) { return new GenericRow(values); } + private static String timestampToString(Timestamp ts) { + String timestampString = String.valueOf(ts); + String formatted = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(ts); + + if (timestampString.length() > 19 && !timestampString.substring(19).equals(".0")) { + return formatted + timestampString.substring(19); + } else { + return formatted; + } + } + private static RecordSpark22000 createRecordSpark22000(Row recordRow) { RecordSpark22000 record = new RecordSpark22000(); record.setShortField(String.valueOf(recordRow.getShort(0))); @@ -219,7 +232,7 @@ private static RecordSpark22000 createRecordSpark22000(Row recordRow) { record.setDoubleField(String.valueOf(recordRow.getDouble(4))); record.setStringField(recordRow.getString(5)); record.setBooleanField(String.valueOf(recordRow.getBoolean(6))); - record.setTimestampField(String.valueOf(recordRow.getTimestamp(7))); + record.setTimestampField(timestampToString(recordRow.getTimestamp(7))); // This would figure out that null value will not become "null". record.setNullIntField(null); return record; diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 1e22ae2eefeb2..d245aa5a17345 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,8 +1,8 @@ - + ## Summary - - Number of queries: 333 + - Number of queries: 337 - Number of expressions that missing example: 34 - - Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,struct,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch + - Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,struct,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch ## Schema of Built-in Functions | Class name | Function name or alias | Query example | Output schema | | ---------- | ---------------------- | ------------- | ------------- | @@ -79,16 +79,18 @@ | org.apache.spark.sql.catalyst.expressions.CreateArray | array | SELECT array(1, 2, 3) | struct> | | org.apache.spark.sql.catalyst.expressions.CreateMap | map | SELECT map(1.0, '2', 3.0, '4') | struct> | | org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | named_struct | SELECT named_struct("a", 1, "b", 2, "c", 3) | struct> | +| org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | struct | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.CsvToStructs | from_csv | SELECT from_csv('1, 0.8', 'a INT, b DOUBLE') | struct> | | org.apache.spark.sql.catalyst.expressions.Cube | cube | SELECT name, age, count(*) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY cube(name, age) | struct | | org.apache.spark.sql.catalyst.expressions.CumeDist | cume_dist | N/A | N/A | +| org.apache.spark.sql.catalyst.expressions.CurrentCatalog | current_catalog | SELECT current_catalog() | struct | | 
org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_database | SELECT current_database() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDate | current_date | SELECT current_date() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentTimestamp | current_timestamp | SELECT current_timestamp() | struct | | org.apache.spark.sql.catalyst.expressions.DateAdd | date_add | SELECT date_add('2016-07-30', 1) | struct | | org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT datediff('2009-07-31', '2009-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.DateFormatClass | date_format | SELECT date_format('2016-04-08', 'y') | struct | -| org.apache.spark.sql.catalyst.expressions.DatePart | date_part | SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456') | struct | +| org.apache.spark.sql.catalyst.expressions.DatePart | date_part | SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456') | struct | | org.apache.spark.sql.catalyst.expressions.DateSub | date_sub | SELECT date_sub('2016-07-30', 1) | struct | | org.apache.spark.sql.catalyst.expressions.DayOfMonth | day | SELECT day('2009-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.DayOfMonth | dayofmonth | SELECT dayofmonth('2009-07-30') | struct | @@ -108,7 +110,7 @@ | org.apache.spark.sql.catalyst.expressions.Explode | explode | SELECT explode(array(10, 20)) | struct | | org.apache.spark.sql.catalyst.expressions.Explode | explode_outer | SELECT explode_outer(array(10, 20)) | struct | | org.apache.spark.sql.catalyst.expressions.Expm1 | expm1 | SELECT expm1(0) | struct | -| org.apache.spark.sql.catalyst.expressions.Extract | extract | SELECT extract(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456') | struct | +| org.apache.spark.sql.catalyst.expressions.Extract | extract | SELECT extract(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456') | struct | | org.apache.spark.sql.catalyst.expressions.Factorial | factorial | SELECT factorial(5) | struct | | org.apache.spark.sql.catalyst.expressions.FindInSet | find_in_set | SELECT find_in_set('ab','abc,b,ab,c,def') | struct | | org.apache.spark.sql.catalyst.expressions.Flatten | flatten | SELECT flatten(array(array(1, 2), array(3, 4))) | struct> | @@ -128,7 +130,7 @@ | org.apache.spark.sql.catalyst.expressions.Hour | hour | SELECT hour('2009-07-30 12:58:59') | struct | | org.apache.spark.sql.catalyst.expressions.Hypot | hypot | SELECT hypot(3, 4) | struct | | org.apache.spark.sql.catalyst.expressions.If | if | SELECT if(1 < 2, 'a', 'b') | struct<(IF((1 < 2), a, b)):string> | -| org.apache.spark.sql.catalyst.expressions.IfNull | ifnull | SELECT ifnull(NULL, array('2')) | struct> | +| org.apache.spark.sql.catalyst.expressions.IfNull | ifnull | SELECT ifnull(NULL, array('2')) | struct> | | org.apache.spark.sql.catalyst.expressions.In | in | SELECT 1 in(1, 2, 3) | struct<(1 IN (1, 2, 3)):boolean> | | org.apache.spark.sql.catalyst.expressions.InitCap | initcap | SELECT initcap('sPark sql') | struct | | org.apache.spark.sql.catalyst.expressions.Inline | inline | SELECT inline(array(struct(1, 'a'), struct(2, 'b'))) | struct | @@ -136,7 +138,7 @@ | org.apache.spark.sql.catalyst.expressions.InputFileBlockLength | input_file_block_length | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.InputFileBlockStart | input_file_block_start | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.InputFileName | input_file_name | N/A | N/A | -| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 
| struct<(3 div 2):bigint> | +| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(CAST(3 AS BIGINT) div CAST(2 AS BIGINT)):bigint> | | org.apache.spark.sql.catalyst.expressions.IsNaN | isnan | SELECT isnan(cast('NaN' as double)) | struct | | org.apache.spark.sql.catalyst.expressions.IsNotNull | isnotnull | SELECT isnotnull(1) | struct<(1 IS NOT NULL):boolean> | | org.apache.spark.sql.catalyst.expressions.IsNull | isnull | SELECT isnull(1) | struct<(1 IS NULL):boolean> | @@ -147,7 +149,7 @@ | org.apache.spark.sql.catalyst.expressions.LastDay | last_day | SELECT last_day('2009-01-12') | struct | | org.apache.spark.sql.catalyst.expressions.Lead | lead | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Least | least | SELECT least(10, 9, 2, 4, 3) | struct | -| org.apache.spark.sql.catalyst.expressions.Left | left | SELECT left('Spark SQL', 3) | struct | +| org.apache.spark.sql.catalyst.expressions.Left | left | SELECT left('Spark SQL', 3) | struct | | org.apache.spark.sql.catalyst.expressions.Length | character_length | SELECT character_length('Spark SQL ') | struct | | org.apache.spark.sql.catalyst.expressions.Length | char_length | SELECT char_length('Spark SQL ') | struct | | org.apache.spark.sql.catalyst.expressions.Length | length | SELECT length('Spark SQL ') | struct | @@ -156,12 +158,12 @@ | org.apache.spark.sql.catalyst.expressions.LessThanOrEqual | <= | SELECT 2 <= 2 | struct<(2 <= 2):boolean> | | org.apache.spark.sql.catalyst.expressions.Levenshtein | levenshtein | SELECT levenshtein('kitten', 'sitting') | struct | | org.apache.spark.sql.catalyst.expressions.Like | like | SELECT like('Spark', '_park') | struct | -| org.apache.spark.sql.catalyst.expressions.Log | ln | SELECT ln(1) | struct | +| org.apache.spark.sql.catalyst.expressions.Log | ln | SELECT ln(1) | struct | | org.apache.spark.sql.catalyst.expressions.Log10 | log10 | SELECT log10(10) | struct | | org.apache.spark.sql.catalyst.expressions.Log1p | log1p | SELECT log1p(0) | struct | | org.apache.spark.sql.catalyst.expressions.Log2 | log2 | SELECT log2(2) | struct | | org.apache.spark.sql.catalyst.expressions.Logarithm | log | SELECT log(10, 100) | struct | -| org.apache.spark.sql.catalyst.expressions.Lower | lcase | SELECT lcase('SparkSql') | struct | +| org.apache.spark.sql.catalyst.expressions.Lower | lcase | SELECT lcase('SparkSql') | struct | | org.apache.spark.sql.catalyst.expressions.Lower | lower | SELECT lower('SparkSql') | struct | | org.apache.spark.sql.catalyst.expressions.MakeDate | make_date | SELECT make_date(2013, 7, 15) | struct | | org.apache.spark.sql.catalyst.expressions.MakeInterval | make_interval | SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001) | struct | @@ -170,11 +172,13 @@ | org.apache.spark.sql.catalyst.expressions.MapEntries | map_entries | SELECT map_entries(map(1, 'a', 2, 'b')) | struct>> | | org.apache.spark.sql.catalyst.expressions.MapFilter | map_filter | SELECT map_filter(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v) | struct namedlambdavariable()), namedlambdavariable(), namedlambdavariable())):map> | | org.apache.spark.sql.catalyst.expressions.MapFromArrays | map_from_arrays | SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) | struct> | -| org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct> | +| org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | 
struct> | | org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT map_keys(map(1, 'a', 2, 'b')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT map_values(map(1, 'a', 2, 'b')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)) | struct> | | org.apache.spark.sql.catalyst.expressions.Md5 | md5 | SELECT md5('Spark') | struct | +| org.apache.spark.sql.catalyst.expressions.MicrosToTimestamp | timestamp_micros | SELECT timestamp_micros(1230219000123123) | struct | +| org.apache.spark.sql.catalyst.expressions.MillisToTimestamp | timestamp_millis | SELECT timestamp_millis(1230219000123) | struct | | org.apache.spark.sql.catalyst.expressions.Minute | minute | SELECT minute('2009-07-30 12:58:59') | struct | | org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID | monotonically_increasing_id | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Month | month | SELECT month('2016-07-30') | struct | @@ -183,19 +187,18 @@ | org.apache.spark.sql.catalyst.expressions.Murmur3Hash | hash | SELECT hash('Spark', array(123), 2) | struct | | org.apache.spark.sql.catalyst.expressions.NTile | ntile | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.NaNvl | nanvl | SELECT nanvl(cast('NaN' as double), 123) | struct | -| org.apache.spark.sql.catalyst.expressions.NamedStruct | struct | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.NextDay | next_day | SELECT next_day('2015-01-14', 'TU') | struct | | org.apache.spark.sql.catalyst.expressions.Not | ! | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Not | not | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Now | now | SELECT now() | struct | | org.apache.spark.sql.catalyst.expressions.NullIf | nullif | SELECT nullif(2, 2) | struct | -| org.apache.spark.sql.catalyst.expressions.Nvl | nvl | SELECT nvl(NULL, array('2')) | struct> | +| org.apache.spark.sql.catalyst.expressions.Nvl | nvl | SELECT nvl(NULL, array('2')) | struct> | | org.apache.spark.sql.catalyst.expressions.Nvl2 | nvl2 | SELECT nvl2(NULL, 2, 1) | struct | | org.apache.spark.sql.catalyst.expressions.OctetLength | octet_length | SELECT octet_length('Spark SQL') | struct | | org.apache.spark.sql.catalyst.expressions.Or | or | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Overlay | overlay | SELECT overlay('Spark SQL' PLACING '_' FROM 6) | struct | -| org.apache.spark.sql.catalyst.expressions.ParseToDate | to_date | SELECT to_date('2009-07-30 04:17:52') | struct | -| org.apache.spark.sql.catalyst.expressions.ParseToTimestamp | to_timestamp | SELECT to_timestamp('2016-12-31 00:12:00') | struct | +| org.apache.spark.sql.catalyst.expressions.ParseToDate | to_date | SELECT to_date('2009-07-30 04:17:52') | struct | +| org.apache.spark.sql.catalyst.expressions.ParseToTimestamp | to_timestamp | SELECT to_timestamp('2016-12-31 00:12:00') | struct | | org.apache.spark.sql.catalyst.expressions.ParseUrl | parse_url | SELECT parse_url('http://spark.apache.org/path?query=1', 'HOST') | struct | | org.apache.spark.sql.catalyst.expressions.PercentRank | percent_rank | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Pi | pi | SELECT pi() | struct | @@ -215,14 +218,15 @@ | org.apache.spark.sql.catalyst.expressions.Remainder | % | SELECT 2 % 1.8 | struct<(CAST(CAST(2 AS DECIMAL(1,0)) AS DECIMAL(2,1)) % CAST(1.8 AS DECIMAL(2,1))):decimal(2,1)> | | 
org.apache.spark.sql.catalyst.expressions.Remainder | mod | SELECT 2 % 1.8 | struct<(CAST(CAST(2 AS DECIMAL(1,0)) AS DECIMAL(2,1)) % CAST(1.8 AS DECIMAL(2,1))):decimal(2,1)> | | org.apache.spark.sql.catalyst.expressions.Reverse | reverse | SELECT reverse('Spark SQL') | struct | -| org.apache.spark.sql.catalyst.expressions.Right | right | SELECT right('Spark SQL', 3) | struct | -| org.apache.spark.sql.catalyst.expressions.Rint | rint | SELECT rint(12.3456) | struct | +| org.apache.spark.sql.catalyst.expressions.Right | right | SELECT right('Spark SQL', 3) | struct | +| org.apache.spark.sql.catalyst.expressions.Rint | rint | SELECT rint(12.3456) | struct | | org.apache.spark.sql.catalyst.expressions.Rollup | rollup | SELECT name, age, count(*) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY rollup(name, age) | struct | | org.apache.spark.sql.catalyst.expressions.Round | round | SELECT round(2.5, 0) | struct | | org.apache.spark.sql.catalyst.expressions.RowNumber | row_number | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.SchemaOfCsv | schema_of_csv | SELECT schema_of_csv('1,abc') | struct | | org.apache.spark.sql.catalyst.expressions.SchemaOfJson | schema_of_json | SELECT schema_of_json('[{"col":0}]') | struct | | org.apache.spark.sql.catalyst.expressions.Second | second | SELECT second('2009-07-30 12:58:59') | struct | +| org.apache.spark.sql.catalyst.expressions.SecondsToTimestamp | timestamp_seconds | SELECT timestamp_seconds(1230219000) | struct | | org.apache.spark.sql.catalyst.expressions.Sentences | sentences | SELECT sentences('Hi there! Good morning.') | struct>> | | org.apache.spark.sql.catalyst.expressions.Sequence | sequence | SELECT sequence(1, 5) | struct> | | org.apache.spark.sql.catalyst.expressions.Sha1 | sha1 | SELECT sha1('Spark') | struct | @@ -247,7 +251,7 @@ | org.apache.spark.sql.catalyst.expressions.Stack | stack | SELECT stack(2, 1, 2, 3) | struct | | org.apache.spark.sql.catalyst.expressions.StringInstr | instr | SELECT instr('SparkSQL', 'SQL') | struct | | org.apache.spark.sql.catalyst.expressions.StringLPad | lpad | SELECT lpad('hi', 5, '??') | struct | -| org.apache.spark.sql.catalyst.expressions.StringLocate | position | SELECT position('bar', 'foobarbar') | struct | +| org.apache.spark.sql.catalyst.expressions.StringLocate | position | SELECT position('bar', 'foobarbar') | struct | | org.apache.spark.sql.catalyst.expressions.StringLocate | locate | SELECT locate('bar', 'foobarbar') | struct | | org.apache.spark.sql.catalyst.expressions.StringRPad | rpad | SELECT rpad('hi', 5, '??') | struct | | org.apache.spark.sql.catalyst.expressions.StringRepeat | repeat | SELECT repeat('123', 2) | struct | @@ -278,7 +282,7 @@ | org.apache.spark.sql.catalyst.expressions.TruncTimestamp | date_trunc | SELECT date_trunc('YEAR', '2015-03-05T09:32:05.359') | struct | | org.apache.spark.sql.catalyst.expressions.TypeOf | typeof | SELECT typeof(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnBase64 | unbase64 | SELECT unbase64('U3BhcmsgU1FM') | struct | -| org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct<(- 1):int> | +| org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct | | org.apache.spark.sql.catalyst.expressions.UnixTimestamp | unix_timestamp | SELECT 
unix_timestamp() | struct | @@ -317,9 +321,9 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last_value | SELECT last_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Max | max | SELECT max(col) FROM VALUES (10), (50), (20) AS tab(col) | struct | -| org.apache.spark.sql.catalyst.expressions.aggregate.MaxBy | max_by | SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y) | struct | +| org.apache.spark.sql.catalyst.expressions.aggregate.MaxBy | max_by | SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Min | min | SELECT min(col) FROM VALUES (10), (-1), (20) AS tab(col) | struct | -| org.apache.spark.sql.catalyst.expressions.aggregate.MinBy | min_by | SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y) | struct | +| org.apache.spark.sql.catalyst.expressions.aggregate.MinBy | min_by | SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Percentile | percentile | SELECT percentile(col, 0.3) FROM VALUES (0), (10) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.Skewness | skewness | SELECT skewness(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.StddevPop | stddev_pop | SELECT stddev_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct | diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/string-functions.sql new file mode 100644 index 0000000000000..dd28e9b97fb20 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/string-functions.sql @@ -0,0 +1 @@ +--IMPORT string-functions.sql diff --git a/sql/core/src/test/resources/sql-tests/inputs/current_database_catalog.sql b/sql/core/src/test/resources/sql-tests/inputs/current_database_catalog.sql new file mode 100644 index 0000000000000..4406f1bc2e6e3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/current_database_catalog.sql @@ -0,0 +1,2 @@ +-- get current_datebase and current_catalog +select current_database(), current_catalog(); diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime-legacy.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime-legacy.sql new file mode 100644 index 0000000000000..daec2b40a620b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime-legacy.sql @@ -0,0 +1,2 @@ +--SET spark.sql.legacy.timeParserPolicy=LEGACY +--IMPORT datetime.sql diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index fd3325085df96..9bd936f6f441f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -1,5 +1,15 @@ -- date time functions +-- [SPARK-31710] TIMESTAMP_SECONDS, TIMESTAMP_MILLISECONDS and TIMESTAMP_MICROSECONDS to timestamp transfer +select TIMESTAMP_SECONDS(1230219000),TIMESTAMP_SECONDS(-1230219000),TIMESTAMP_SECONDS(null); +select TIMESTAMP_MILLIS(1230219000123),TIMESTAMP_MILLIS(-1230219000123),TIMESTAMP_MILLIS(null); +select 
TIMESTAMP_MICROS(1230219000123123),TIMESTAMP_MICROS(-1230219000123123),TIMESTAMP_MICROS(null); +-- overflow exception: +select TIMESTAMP_SECONDS(1230219000123123); +select TIMESTAMP_SECONDS(-1230219000123123); +select TIMESTAMP_MILLIS(92233720368547758); +select TIMESTAMP_MILLIS(-92233720368547758); + -- [SPARK-16836] current_date and current_timestamp literals select current_date = current_date(), current_timestamp = current_timestamp(); @@ -86,7 +96,7 @@ select date_sub('2011-11-11', str) from v; select null - date '2019-10-06'; select date '2001-10-01' - date '2001-09-28'; --- variable-length tests +-- variable-length second fraction tests select to_timestamp('2019-10-06 10:11:12.', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]'); select to_timestamp('2019-10-06 10:11:12.0', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]'); select to_timestamp('2019-10-06 10:11:12.1', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]'); @@ -95,7 +105,7 @@ select to_timestamp('2019-10-06 10:11:12.123UTC', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zz select to_timestamp('2019-10-06 10:11:12.1234', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]'); select to_timestamp('2019-10-06 10:11:12.12345CST', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]'); select to_timestamp('2019-10-06 10:11:12.123456PST', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]'); --- exceeded max variable length +-- second fraction exceeded max variable length select to_timestamp('2019-10-06 10:11:12.1234567PST', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]'); -- special cases select to_timestamp('123456 2019-10-06 10:11:12.123456PST', 'SSSSSS yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]'); @@ -122,3 +132,35 @@ select to_timestamp("2019-10-06T10:11:12'12", "yyyy-MM-dd'T'HH:mm:ss''SSSS"); -- select to_timestamp("2019-10-06T10:11:12'", "yyyy-MM-dd'T'HH:mm:ss''"); -- tail select to_timestamp("'2019-10-06T10:11:12", "''yyyy-MM-dd'T'HH:mm:ss"); -- head select to_timestamp("P2019-10-06T10:11:12", "'P'yyyy-MM-dd'T'HH:mm:ss"); -- head but as single quote + +-- missing fields +select to_timestamp("16", "dd"); +select to_timestamp("02-29", "MM-dd"); +select to_date("16", "dd"); +select to_date("02-29", "MM-dd"); +select to_timestamp("2019 40", "yyyy mm"); +select to_timestamp("2019 10:10:10", "yyyy hh:mm:ss"); + +-- Unsupported narrow text style +select date_format(date '2020-05-23', 'GGGGG'); +select date_format(date '2020-05-23', 'MMMMM'); +select date_format(date '2020-05-23', 'LLLLL'); +select date_format(timestamp '2020-05-23', 'EEEEE'); +select date_format(timestamp '2020-05-23', 'uuuuu'); +select date_format('2020-05-23', 'QQQQQ'); +select date_format('2020-05-23', 'qqqqq'); +select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG'); +select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEEE'); +select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE'); +select unix_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE'); +select from_unixtime(12345, 'MMMMM'); +select from_unixtime(54321, 'QQQQQ'); +select from_unixtime(23456, 'aaaaa'); +select from_json('{"time":"26/October/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')); +select from_json('{"date":"26/October/2015"}', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy')); +select from_csv('26/October/2015', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')); +select from_csv('26/October/2015', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy')); + +select from_unixtime(1, 'yyyyyyyyyyy-MM-dd'); +select date_format(timestamp '2018-11-17 13:33:33', 'yyyyyyyyyy-MM-dd HH:mm:ss'); +select date_format(date '2018-11-17', 'yyyyyyyyyyy-MM-dd'); diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/extract.sql b/sql/core/src/test/resources/sql-tests/inputs/extract.sql index abb9e82c9ef2e..0f1fd5bbcca0b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/extract.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/extract.sql @@ -123,3 +123,8 @@ select extract('doy', c) from t; select extract('hour', c) from t; select extract('minute', c) from t; select extract('second', c) from t; + +select c - i from t; +select year(c - i) from t; +select extract(year from c - i) from t; +select extract(month from to_timestamp(c) - i) from t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql index 6868b5902939d..3b75be19b5677 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/having.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql @@ -18,4 +18,9 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0); SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1; -- SPARK-31519: Cast in having aggregate expressions returns the wrong result -SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10 +SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10; + +-- SPARK-31663: Grouping sets with having clause returns the wrong result +SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10; +SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10; +SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10; diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index 131890fddb0db..f6fa44161a771 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -48,6 +48,21 @@ select from_json('[null, {"a":2}]', 'array>'); select from_json('[{"a": 1}, {"b":2}]', 'array>'); select from_json('[{"a": 1}, 2]', 'array>'); +-- from_json - datetime type +select from_json('{"d": "2012-12-15", "t": "2012-12-15 15:15:15"}', 'd date, t timestamp'); +select from_json( + '{"d": "12/15 2012", "t": "12/15 2012 15:15:15"}', + 'd date, t timestamp', + map('dateFormat', 'MM/dd yyyy', 'timestampFormat', 'MM/dd yyyy HH:mm:ss')); +select from_json( + '{"d": "02-29"}', + 'd date', + map('dateFormat', 'MM-dd')); +select from_json( + '{"t": "02-29"}', + 't timestamp', + map('timestampFormat', 'MM-dd')); + -- to_json - array type select to_json(array('1', '2', '3')); select to_json(array(array(1, 2, 3), array(4))); diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql index 087d7a5befd19..6e95aca7aff62 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql @@ -146,7 +146,7 @@ SELECT count(*) OVER (PARTITION BY four) FROM (SELECT * FROM tenk1 WHERE FALSE)s -- mixture of agg/wfunc in the same window -- SELECT sum(salary) OVER w, rank() OVER w FROM empsalary WINDOW w AS (PARTITION BY depname ORDER BY salary DESC); --- Cannot safely cast 'enroll_date': StringType to DateType; +-- Cannot safely cast 
'enroll_date': string to date; -- SELECT empno, depname, salary, bonus, depadj, MIN(bonus) OVER (ORDER BY empno), MAX(depadj) OVER () FROM( -- SELECT *, -- CASE WHEN enroll_date < '2008-01-01' THEN 2008 - extract(year FROM enroll_date) END * 500 AS bonus, diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql index cd3b74b3aa03f..f4b8454da0d82 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql @@ -42,7 +42,7 @@ create table datetimes ( f_timestamp timestamp ) using parquet; --- Spark cannot safely cast StringType to TimestampType +-- Spark cannot safely cast string to timestamp -- [SPARK-29636] Spark can't parse '11:00 BST' or '2000-10-19 10:23:54+01' signatures to timestamp insert into datetimes values (1, timestamp '11:00', cast ('11:00 BST' as timestamp), cast ('1 year' as timestamp), cast ('2000-10-19 10:23:54+01' as timestamp), timestamp '2000-10-19 10:23:54'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 8e33471e8b129..f5ed2036dc8ac 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -48,4 +48,8 @@ SELECT trim(LEADING 'xyz' FROM 'zzzytestxyz'); SELECT trim(LEADING 'xy' FROM 'xyxXxyLAST WORD'); SELECT trim(TRAILING 'xyz' FROM 'testxxzx'); SELECT trim(TRAILING 'xyz' FROM 'xyztestxxzx'); -SELECT trim(TRAILING 'xy' FROM 'TURNERyxXxy'); \ No newline at end of file +SELECT trim(TRAILING 'xy' FROM 'TURNERyxXxy'); + +-- Check lpad/rpad with invalid length parameter +SELECT lpad('hi', 'invalid_length'); +SELECT rpad('hi', 'invalid_length'); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out index 05c335b413bf2..ca04b008d6537 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out @@ -1,5 +1,65 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 85 +-- Number of queries: 119 + + +-- !query +select TIMESTAMP_SECONDS(1230219000),TIMESTAMP_SECONDS(-1230219000),TIMESTAMP_SECONDS(null) +-- !query schema +struct +-- !query output +2008-12-25 07:30:00 1931-01-07 00:30:00 NULL + + +-- !query +select TIMESTAMP_MILLIS(1230219000123),TIMESTAMP_MILLIS(-1230219000123),TIMESTAMP_MILLIS(null) +-- !query schema +struct +-- !query output +2008-12-25 07:30:00.123 1931-01-07 00:29:59.877 NULL + + +-- !query +select TIMESTAMP_MICROS(1230219000123123),TIMESTAMP_MICROS(-1230219000123123),TIMESTAMP_MICROS(null) +-- !query schema +struct +-- !query output +2008-12-25 07:30:00.123123 1931-01-07 00:29:59.876877 NULL + + +-- !query +select TIMESTAMP_SECONDS(1230219000123123) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query +select TIMESTAMP_SECONDS(-1230219000123123) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query +select TIMESTAMP_MILLIS(92233720368547758) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query +select TIMESTAMP_MILLIS(-92233720368547758) +-- !query schema +struct<> +-- !query output 
+java.lang.ArithmeticException +long overflow -- !query @@ -19,7 +79,7 @@ select current_date = current_date(), current_timestamp = current_timestamp() -- !query select to_date(null), to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd') -- !query schema -struct +struct -- !query output NULL 2016-12-31 2016-12-31 @@ -27,7 +87,7 @@ NULL 2016-12-31 2016-12-31 -- !query select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd') -- !query schema -struct +struct -- !query output NULL 2016-12-31 00:12:00 2016-12-31 00:00:00 @@ -178,7 +238,7 @@ requirement failed: Cannot add hours, minutes or seconds, milliseconds, microsec -- !query select '2011-11-11' - interval '2' day -- !query schema -struct +struct -- !query output 2011-11-09 00:00:00 @@ -186,7 +246,7 @@ struct -- !query select '2011-11-11 11:11:11' - interval '2' second -- !query schema -struct +struct -- !query output 2011-11-11 11:11:09 @@ -194,7 +254,7 @@ struct -- !query select '1' - interval '2' second -- !query schema -struct +struct -- !query output NULL @@ -493,7 +553,7 @@ struct -- !query select to_timestamp('2019-10-06 10:11:12.', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') -- !query schema -struct +struct -- !query output NULL @@ -501,7 +561,7 @@ NULL -- !query select to_timestamp('2019-10-06 10:11:12.0', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') -- !query schema -struct +struct -- !query output 2019-10-06 10:11:12 @@ -509,7 +569,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.1 @@ -517,7 +577,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.12 @@ -525,7 +585,7 @@ struct +struct -- !query output 2019-10-06 03:11:12.123 @@ -533,7 +593,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.1234 @@ -541,7 +601,7 @@ struct +struct -- !query output 2019-10-06 08:11:12.12345 @@ -549,7 +609,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.123456 @@ -557,7 +617,7 @@ struct +struct -- !query output NULL @@ -565,7 +625,7 @@ NULL -- !query select to_timestamp('123456 2019-10-06 10:11:12.123456PST', 'SSSSSS yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') -- !query schema -struct +struct -- !query output 2019-10-06 10:11:12.123456 @@ -573,7 +633,7 @@ struct +struct -- !query output NULL @@ -581,7 +641,7 @@ NULL -- !query select to_timestamp('2019-10-06 10:11:12.1234', 'yyyy-MM-dd HH:mm:ss.[SSSSSS]') -- !query schema -struct +struct -- !query output 2019-10-06 10:11:12.1234 @@ -589,7 +649,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.123 @@ -597,7 +657,7 @@ struct +struct -- !query output 2019-10-06 10:11:12 @@ -605,7 +665,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.12 @@ -613,7 +673,7 @@ struct +struct -- !query output 2019-10-06 10:11:00 @@ -621,7 +681,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.12345 @@ -629,7 +689,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.1234 @@ -637,7 +697,7 @@ struct +struct -- !query output NULL @@ -645,7 +705,7 @@ NULL -- !query select to_timestamp("12.1232019-10-06S10:11", "ss.SSSSyy-MM-dd'S'HH:mm") -- !query schema -struct +struct -- !query output NULL @@ -653,7 +713,7 @@ NULL -- !query select to_timestamp("12.1234019-10-06S10:11", "ss.SSSSy-MM-dd'S'HH:mm") -- !query schema -struct +struct -- !query output 0019-10-06 10:11:12.1234 @@ -661,7 +721,7 @@ struct +struct -- !query output 2019-10-06 00:00:00 @@ -669,7 +729,7 @@ struct -- !query select to_timestamp("S2019-10-06", "'S'yyyy-MM-dd") -- !query schema -struct +struct -- !query output 2019-10-06 00:00:00 @@ -703,7 +763,7 @@ struct -- !query select 
to_timestamp("2019-10-06T10:11:12'12", "yyyy-MM-dd'T'HH:mm:ss''SSSS") -- !query schema -struct +struct -- !query output 2019-10-06 10:11:12.12 @@ -711,7 +771,7 @@ struct +struct -- !query output 2019-10-06 10:11:12 @@ -719,7 +779,7 @@ struct +struct -- !query output 2019-10-06 10:11:12 @@ -727,6 +787,241 @@ struct +struct -- !query output 2019-10-06 10:11:12 + + +-- !query +select to_timestamp("16", "dd") +-- !query schema +struct +-- !query output +1970-01-16 00:00:00 + + +-- !query +select to_timestamp("02-29", "MM-dd") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_date("16", "dd") +-- !query schema +struct +-- !query output +1970-01-16 + + +-- !query +select to_date("02-29", "MM-dd") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp("2019 40", "yyyy mm") +-- !query schema +struct +-- !query output +2019-01-01 00:40:00 + + +-- !query +select to_timestamp("2019 10:10:10", "yyyy hh:mm:ss") +-- !query schema +struct +-- !query output +2019-01-01 10:10:10 + + +-- !query +select date_format(date '2020-05-23', 'GGGGG') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format(date '2020-05-23', 'MMMMM') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'MMMMM' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format(date '2020-05-23', 'LLLLL') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'LLLLL' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format(timestamp '2020-05-23', 'EEEEE') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format(timestamp '2020-05-23', 'uuuuu') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'uuuuu' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format('2020-05-23', 'QQQQQ') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Too many pattern letters: Q + + +-- !query +select date_format('2020-05-23', 'qqqqq') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Too many pattern letters: q + + +-- !query +select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyy-MM-dd GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEEE') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select unix_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_unixtime(12345, 'MMMMM') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'MMMMM' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_unixtime(54321, 'QQQQQ') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select from_unixtime(23456, 'aaaaa') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aaaaa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_json('{"time":"26/October/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_json('{"date":"26/October/2015"}', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_csv('26/October/2015', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_csv('26/October/2015', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_unixtime(1, 'yyyyyyyyyyy-MM-dd') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyyyyyyyyy-MM-dd' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format(timestamp '2018-11-17 13:33:33', 'yyyyyyyyyy-MM-dd HH:mm:ss') +-- !query schema +struct +-- !query output +0000002018-11-17 13:33:33 + + +-- !query +select date_format(date '2018-11-17', 'yyyyyyyyyyy-MM-dd') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyyyyyyyyy-MM-dd' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index afe55319d8d17..39b230fd19f3d 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -697,7 +697,7 @@ select interval '2-2' year to month + dateval from interval_arithmetic -- !query schema -struct +struct -- !query output 2012-01-01 2009-11-01 2014-03-01 2014-03-01 2009-11-01 2009-11-01 2014-03-01 @@ -713,7 +713,7 @@ select interval '2-2' year to month + tsval from interval_arithmetic -- !query schema -struct +struct -- !query output 2012-01-01 00:00:00 2009-11-01 00:00:00 2014-03-01 00:00:00 2014-03-01 00:00:00 2009-11-01 00:00:00 2009-11-01 00:00:00 2014-03-01 00:00:00 @@ -757,7 +757,7 @@ select interval '99 11:22:33.123456789' day to second + tsval from interval_arithmetic -- !query schema -struct +struct -- !query output 2012-01-01 00:00:00 2011-09-23 12:37:26.876544 2012-04-09 11:22:33.123456 2012-04-09 11:22:33.123456 2011-09-23 12:37:26.876544 2011-09-23 12:37:26.876544 2012-04-09 11:22:33.123456 @@ -773,7 +773,7 @@ select interval '99 11:22:33.123456789' day to second + strval from interval_arithmetic -- !query schema -struct +struct -- !query output 2012-01-01 2011-09-23 12:37:26.876544 2012-04-09 11:22:33.123456 2012-04-09 11:22:33.123456 2011-09-23 12:37:26.876544 2011-09-23 12:37:26.876544 2012-04-09 11:22:33.123456 diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out new file mode 100644 index 0000000000000..d5c0acb40bb1e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -0,0 +1,296 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 36 + + +-- !query +select concat_ws() +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +requirement failed: concat_ws requires at least one argument.; line 1 pos 7 + + +-- !query +select format_string() +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +requirement failed: format_string() should take at least 1 argument; line 1 pos 7 + + +-- !query +select 'a' || 'b' || 'c' +-- !query schema +struct +-- !query output +abc + + +-- !query +select replace('abc', 'b', '123') +-- !query schema +struct +-- !query output +a123c + + +-- !query +select replace('abc', 'b') +-- !query schema +struct +-- !query output +ac + + +-- !query +select length(uuid()), (uuid() <> uuid()) +-- !query schema +struct +-- !query output +36 true + + +-- !query +select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null) +-- !query schema +struct +-- !query output +4 NULL NULL + + +-- !query +select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null) +-- !query schema +struct +-- !query output +ab abcd ab NULL + + +-- !query +select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a') +-- !query schema +struct<> +-- !query output +java.lang.NumberFormatException +invalid input syntax for type numeric: a + + +-- !query +select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null) +-- !query schema +struct +-- !query output +cd abcd cd NULL + + +-- 
!query +select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') +-- !query schema +struct<> +-- !query output +java.lang.NumberFormatException +invalid input syntax for type numeric: a + + +-- !query +SELECT split('aa1cc2ee3', '[1-9]+') +-- !query schema +struct> +-- !query output +["aa","cc","ee",""] + + +-- !query +SELECT split('aa1cc2ee3', '[1-9]+', 2) +-- !query schema +struct> +-- !query output +["aa","cc2ee3"] + + +-- !query +SELECT substr('Spark SQL', 5) +-- !query schema +struct +-- !query output +k SQL + + +-- !query +SELECT substr('Spark SQL', -3) +-- !query schema +struct +-- !query output +SQL + + +-- !query +SELECT substr('Spark SQL', 5, 1) +-- !query schema +struct +-- !query output +k + + +-- !query +SELECT substr('Spark SQL' from 5) +-- !query schema +struct +-- !query output +k SQL + + +-- !query +SELECT substr('Spark SQL' from -3) +-- !query schema +struct +-- !query output +SQL + + +-- !query +SELECT substr('Spark SQL' from 5 for 1) +-- !query schema +struct +-- !query output +k + + +-- !query +SELECT substring('Spark SQL', 5) +-- !query schema +struct +-- !query output +k SQL + + +-- !query +SELECT substring('Spark SQL', -3) +-- !query schema +struct +-- !query output +SQL + + +-- !query +SELECT substring('Spark SQL', 5, 1) +-- !query schema +struct +-- !query output +k + + +-- !query +SELECT substring('Spark SQL' from 5) +-- !query schema +struct +-- !query output +k SQL + + +-- !query +SELECT substring('Spark SQL' from -3) +-- !query schema +struct +-- !query output +SQL + + +-- !query +SELECT substring('Spark SQL' from 5 for 1) +-- !query schema +struct +-- !query output +k + + +-- !query +SELECT trim(" xyz "), ltrim(" xyz "), rtrim(" xyz ") +-- !query schema +struct +-- !query output +xyz xyz xyz + + +-- !query +SELECT trim(BOTH 'xyz' FROM 'yxTomxx'), trim('xyz' FROM 'yxTomxx') +-- !query schema +struct +-- !query output +Tom Tom + + +-- !query +SELECT trim(BOTH 'x' FROM 'xxxbarxxx'), trim('x' FROM 'xxxbarxxx') +-- !query schema +struct +-- !query output +bar bar + + +-- !query +SELECT trim(LEADING 'xyz' FROM 'zzzytest') +-- !query schema +struct +-- !query output +test + + +-- !query +SELECT trim(LEADING 'xyz' FROM 'zzzytestxyz') +-- !query schema +struct +-- !query output +testxyz + + +-- !query +SELECT trim(LEADING 'xy' FROM 'xyxXxyLAST WORD') +-- !query schema +struct +-- !query output +XxyLAST WORD + + +-- !query +SELECT trim(TRAILING 'xyz' FROM 'testxxzx') +-- !query schema +struct +-- !query output +test + + +-- !query +SELECT trim(TRAILING 'xyz' FROM 'xyztestxxzx') +-- !query schema +struct +-- !query output +xyztest + + +-- !query +SELECT trim(TRAILING 'xy' FROM 'TURNERyxXxy') +-- !query schema +struct +-- !query output +TURNERyxX + + +-- !query +SELECT lpad('hi', 'invalid_length') +-- !query schema +struct<> +-- !query output +java.lang.NumberFormatException +invalid input syntax for type numeric: invalid_length + + +-- !query +SELECT rpad('hi', 'invalid_length') +-- !query schema +struct<> +-- !query output +java.lang.NumberFormatException +invalid input syntax for type numeric: invalid_length diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index be7fa5e9d5ff4..1e3173172a528 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -130,7 +130,7 @@ struct -- !query select to_csv(named_struct('time', to_timestamp('2015-08-26', 
'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')) -- !query schema -struct +struct -- !query output 26/08/2015 diff --git a/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out b/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out new file mode 100644 index 0000000000000..b714463a0aa0c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out @@ -0,0 +1,10 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 1 + + +-- !query +select current_database(), current_catalog() +-- !query schema +struct +-- !query output +default spark_catalog diff --git a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out new file mode 100644 index 0000000000000..fe932d3a706a8 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out @@ -0,0 +1,982 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 119 + + +-- !query +select TIMESTAMP_SECONDS(1230219000),TIMESTAMP_SECONDS(-1230219000),TIMESTAMP_SECONDS(null) +-- !query schema +struct +-- !query output +2008-12-25 07:30:00 1931-01-07 00:30:00 NULL + + +-- !query +select TIMESTAMP_MILLIS(1230219000123),TIMESTAMP_MILLIS(-1230219000123),TIMESTAMP_MILLIS(null) +-- !query schema +struct +-- !query output +2008-12-25 07:30:00.123 1931-01-07 00:29:59.877 NULL + + +-- !query +select TIMESTAMP_MICROS(1230219000123123),TIMESTAMP_MICROS(-1230219000123123),TIMESTAMP_MICROS(null) +-- !query schema +struct +-- !query output +2008-12-25 07:30:00.123123 1931-01-07 00:29:59.876877 NULL + + +-- !query +select TIMESTAMP_SECONDS(1230219000123123) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query +select TIMESTAMP_SECONDS(-1230219000123123) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query +select TIMESTAMP_MILLIS(92233720368547758) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query +select TIMESTAMP_MILLIS(-92233720368547758) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query +select current_date = current_date(), current_timestamp = current_timestamp() +-- !query schema +struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean> +-- !query output +true true + + +-- !query +select to_date(null), to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd') +-- !query schema +struct +-- !query output +NULL 2016-12-31 2016-12-31 + + +-- !query +select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd') +-- !query schema +struct +-- !query output +NULL 2016-12-31 00:12:00 2016-12-31 00:00:00 + + +-- !query +select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15') +-- !query schema +struct +-- !query output +7 5 7 NULL 6 + + +-- !query +create temporary view ttf1 as select * from values + (1, 2), + (2, 3) + as ttf1(current_date, current_timestamp) +-- !query schema +struct<> +-- !query output + + + +-- !query +select current_date, current_timestamp from ttf1 +-- !query schema +struct +-- !query output +1 2 +2 3 + + +-- !query +create temporary view ttf2 as select * from values + (1, 2), + (2, 3) + as ttf2(a, b) +-- !query schema +struct<> +-- 
!query output + + + +-- !query +select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2 +-- !query schema +struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean,a:int,b:int> +-- !query output +true true 1 2 +true true 2 3 + + +-- !query +select a, b from ttf2 order by a, current_date +-- !query schema +struct +-- !query output +1 2 +2 3 + + +-- !query +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') +-- !query schema +struct +-- !query output +5 3 5 NULL 4 + + +-- !query +select year('1500-01-01'), month('1500-01-01'), dayOfYear('1500-01-01') +-- !query schema +struct +-- !query output +1500 1 1 + + +-- !query +select date '2019-01-01\t' +-- !query schema +struct +-- !query output +2019-01-01 + + +-- !query +select timestamp '2019-01-01\t' +-- !query schema +struct +-- !query output +2019-01-01 00:00:00 + + +-- !query +select timestamp'2011-11-11 11:11:11' + interval '2' day +-- !query schema +struct +-- !query output +2011-11-13 11:11:11 + + +-- !query +select timestamp'2011-11-11 11:11:11' - interval '2' day +-- !query schema +struct +-- !query output +2011-11-09 11:11:11 + + +-- !query +select date'2011-11-11 11:11:11' + interval '2' second +-- !query schema +struct +-- !query output +2011-11-11 + + +-- !query +select date'2011-11-11 11:11:11' - interval '2' second +-- !query schema +struct +-- !query output +2011-11-10 + + +-- !query +select '2011-11-11' - interval '2' day +-- !query schema +struct +-- !query output +2011-11-09 00:00:00 + + +-- !query +select '2011-11-11 11:11:11' - interval '2' second +-- !query schema +struct +-- !query output +2011-11-11 11:11:09 + + +-- !query +select '1' - interval '2' second +-- !query schema +struct +-- !query output +NULL + + +-- !query +select 1 - interval '2' second +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve '1 + (- INTERVAL '2 seconds')' due to data type mismatch: argument 1 requires timestamp type, however, '1' is of int type.; line 1 pos 7 + + +-- !query +select date'2020-01-01' - timestamp'2019-10-06 10:11:12.345678' +-- !query schema +struct +-- !query output +2078 hours 48 minutes 47.654322 seconds + + +-- !query +select timestamp'2019-10-06 10:11:12.345678' - date'2020-01-01' +-- !query schema +struct +-- !query output +-2078 hours -48 minutes -47.654322 seconds + + +-- !query +select timestamp'2019-10-06 10:11:12.345678' - null +-- !query schema +struct +-- !query output +NULL + + +-- !query +select null - timestamp'2019-10-06 10:11:12.345678' +-- !query schema +struct +-- !query output +NULL + + +-- !query +select date_add('2011-11-11', 1Y) +-- !query schema +struct +-- !query output +2011-11-12 + + +-- !query +select date_add('2011-11-11', 1S) +-- !query schema +struct +-- !query output +2011-11-12 + + +-- !query +select date_add('2011-11-11', 1) +-- !query schema +struct +-- !query output +2011-11-12 + + +-- !query +select date_add('2011-11-11', 1L) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'date_add(CAST('2011-11-11' AS DATE), 1L)' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, '1L' is of bigint type.; line 1 pos 7 + + +-- !query +select date_add('2011-11-11', 1.0) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'date_add(CAST('2011-11-11' AS DATE), 1.0BD)' due 
to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7 + + +-- !query +select date_add('2011-11-11', 1E1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'date_add(CAST('2011-11-11' AS DATE), 10.0D)' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, '10.0D' is of double type.; line 1 pos 7 + + +-- !query +select date_add('2011-11-11', '1') +-- !query schema +struct +-- !query output +2011-11-12 + + +-- !query +select date_add('2011-11-11', '1.2') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The second argument of 'date_add' function needs to be an integer.; + + +-- !query +select date_add(date'2011-11-11', 1) +-- !query schema +struct +-- !query output +2011-11-12 + + +-- !query +select date_add(timestamp'2011-11-11', 1) +-- !query schema +struct +-- !query output +2011-11-12 + + +-- !query +select date_sub(date'2011-11-11', 1) +-- !query schema +struct +-- !query output +2011-11-10 + + +-- !query +select date_sub(date'2011-11-11', '1') +-- !query schema +struct +-- !query output +2011-11-10 + + +-- !query +select date_sub(date'2011-11-11', '1.2') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The second argument of 'date_sub' function needs to be an integer.; + + +-- !query +select date_sub(timestamp'2011-11-11', 1) +-- !query schema +struct +-- !query output +2011-11-10 + + +-- !query +select date_sub(null, 1) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select date_sub(date'2011-11-11', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select date'2011-11-11' + 1E1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'date_add(DATE '2011-11-11', 10.0D)' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, '10.0D' is of double type.; line 1 pos 7 + + +-- !query +select date'2011-11-11' + '1' +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'date_add(DATE '2011-11-11', CAST('1' AS DOUBLE))' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, 'CAST('1' AS DOUBLE)' is of double type.; line 1 pos 7 + + +-- !query +select null + date '2001-09-28' +-- !query schema +struct +-- !query output +NULL + + +-- !query +select date '2001-09-28' + 7Y +-- !query schema +struct +-- !query output +2001-10-05 + + +-- !query +select 7S + date '2001-09-28' +-- !query schema +struct +-- !query output +2001-10-05 + + +-- !query +select date '2001-10-01' - 7 +-- !query schema +struct +-- !query output +2001-09-24 + + +-- !query +select date '2001-10-01' - '7' +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'date_sub(DATE '2001-10-01', CAST('7' AS DOUBLE))' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, 'CAST('7' AS DOUBLE)' is of double type.; line 1 pos 7 + + +-- !query +select date '2001-09-28' + null +-- !query schema +struct +-- !query output +NULL + + +-- !query +select date '2001-09-28' - null +-- !query schema +struct +-- !query output +NULL + + +-- !query +create temp view v as select '1' str +-- !query schema +struct<> +-- !query output + + + +-- !query +select date_add('2011-11-11', str) from v +-- !query schema +struct<> +-- !query 
output +org.apache.spark.sql.AnalysisException +cannot resolve 'date_add(CAST('2011-11-11' AS DATE), v.`str`)' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, 'v.`str`' is of string type.; line 1 pos 7 + + +-- !query +select date_sub('2011-11-11', str) from v +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'date_sub(CAST('2011-11-11' AS DATE), v.`str`)' due to data type mismatch: argument 2 requires (int or smallint or tinyint) type, however, 'v.`str`' is of string type.; line 1 pos 7 + + +-- !query +select null - date '2019-10-06' +-- !query schema +struct +-- !query output +NULL + + +-- !query +select date '2001-10-01' - date '2001-09-28' +-- !query schema +struct +-- !query output +3 days + + +-- !query +select to_timestamp('2019-10-06 10:11:12.', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.0', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.1', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.12', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.123UTC', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.1234', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.12345CST', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.123456PST', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.1234567PST', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('123456 2019-10-06 10:11:12.123456PST', 'SSSSSS yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('223456 2019-10-06 10:11:12.123456PST', 'SSSSSS yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.1234', 'yyyy-MM-dd HH:mm:ss.[SSSSSS]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.123', 'yyyy-MM-dd HH:mm:ss[.SSSSSS]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12', 'yyyy-MM-dd HH:mm:ss[.SSSSSS]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11:12.12', 'yyyy-MM-dd HH:mm[:ss.SSSSSS]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('2019-10-06 10:11', 'yyyy-MM-dd HH:mm[:ss.SSSSSS]') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp("2019-10-06S10:11:12.12345", "yyyy-MM-dd'S'HH:mm:ss.SSSSSS") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp("12.12342019-10-06S10:11", "ss.SSSSyyyy-MM-dd'S'HH:mm") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp("12.1232019-10-06S10:11", 
"ss.SSSSyyyy-MM-dd'S'HH:mm") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp("12.1232019-10-06S10:11", "ss.SSSSyy-MM-dd'S'HH:mm") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp("12.1234019-10-06S10:11", "ss.SSSSy-MM-dd'S'HH:mm") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp("2019-10-06S", "yyyy-MM-dd'S'") +-- !query schema +struct +-- !query output +2019-10-06 00:00:00 + + +-- !query +select to_timestamp("S2019-10-06", "'S'yyyy-MM-dd") +-- !query schema +struct +-- !query output +2019-10-06 00:00:00 + + +-- !query +select date_format(timestamp '2019-10-06', 'yyyy-MM-dd uuee') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Illegal pattern character 'e' + + +-- !query +select date_format(timestamp '2019-10-06', 'yyyy-MM-dd uucc') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Illegal pattern character 'c' + + +-- !query +select date_format(timestamp '2019-10-06', 'yyyy-MM-dd uuuu') +-- !query schema +struct +-- !query output +2019-10-06 0007 + + +-- !query +select to_timestamp("2019-10-06T10:11:12'12", "yyyy-MM-dd'T'HH:mm:ss''SSSS") +-- !query schema +struct +-- !query output +2019-10-06 10:11:12.012 + + +-- !query +select to_timestamp("2019-10-06T10:11:12'", "yyyy-MM-dd'T'HH:mm:ss''") +-- !query schema +struct +-- !query output +2019-10-06 10:11:12 + + +-- !query +select to_timestamp("'2019-10-06T10:11:12", "''yyyy-MM-dd'T'HH:mm:ss") +-- !query schema +struct +-- !query output +2019-10-06 10:11:12 + + +-- !query +select to_timestamp("P2019-10-06T10:11:12", "'P'yyyy-MM-dd'T'HH:mm:ss") +-- !query schema +struct +-- !query output +2019-10-06 10:11:12 + + +-- !query +select to_timestamp("16", "dd") +-- !query schema +struct +-- !query output +1970-01-16 00:00:00 + + +-- !query +select to_timestamp("02-29", "MM-dd") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_date("16", "dd") +-- !query schema +struct +-- !query output +1970-01-16 + + +-- !query +select to_date("02-29", "MM-dd") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp("2019 40", "yyyy mm") +-- !query schema +struct +-- !query output +2019-01-01 00:40:00 + + +-- !query +select to_timestamp("2019 10:10:10", "yyyy hh:mm:ss") +-- !query schema +struct +-- !query output +2019-01-01 10:10:10 + + +-- !query +select date_format(date '2020-05-23', 'GGGGG') +-- !query schema +struct +-- !query output +AD + + +-- !query +select date_format(date '2020-05-23', 'MMMMM') +-- !query schema +struct +-- !query output +May + + +-- !query +select date_format(date '2020-05-23', 'LLLLL') +-- !query schema +struct +-- !query output +May + + +-- !query +select date_format(timestamp '2020-05-23', 'EEEEE') +-- !query schema +struct +-- !query output +Saturday + + +-- !query +select date_format(timestamp '2020-05-23', 'uuuuu') +-- !query schema +struct +-- !query output +00006 + + +-- !query +select date_format('2020-05-23', 'QQQQQ') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Illegal pattern character 'Q' + + +-- !query +select date_format('2020-05-23', 'qqqqq') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Illegal pattern character 'q' + + +-- !query +select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEEE') 
+-- !query schema +struct +-- !query output +2020-05-22 00:00:00 + + +-- !query +select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') +-- !query schema +struct +-- !query output +2020-05-22 00:00:00 + + +-- !query +select unix_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') +-- !query schema +struct +-- !query output +1590130800 + + +-- !query +select from_unixtime(12345, 'MMMMM') +-- !query schema +struct +-- !query output +December + + +-- !query +select from_unixtime(54321, 'QQQQQ') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select from_unixtime(23456, 'aaaaa') +-- !query schema +struct +-- !query output +PM + + +-- !query +select from_json('{"time":"26/October/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct> +-- !query output +{"time":2015-10-26 00:00:00} + + +-- !query +select from_json('{"date":"26/October/2015"}', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct> +-- !query output +{"date":2015-10-26} + + +-- !query +select from_csv('26/October/2015', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct> +-- !query output +{"time":2015-10-26 00:00:00} + + +-- !query +select from_csv('26/October/2015', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct> +-- !query output +{"date":2015-10-26} + + +-- !query +select from_unixtime(1, 'yyyyyyyyyyy-MM-dd') +-- !query schema +struct +-- !query output +00000001969-12-31 + + +-- !query +select date_format(timestamp '2018-11-17 13:33:33', 'yyyyyyyyyy-MM-dd HH:mm:ss') +-- !query schema +struct +-- !query output +0000002018-11-17 13:33:33 + + +-- !query +select date_format(date '2018-11-17', 'yyyyyyyyyyy-MM-dd') +-- !query schema +struct +-- !query output +00000002018-11-17 diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index e599153f348df..06a41da2671e6 100755 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,65 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 85 +-- Number of queries: 119 + + +-- !query +select TIMESTAMP_SECONDS(1230219000),TIMESTAMP_SECONDS(-1230219000),TIMESTAMP_SECONDS(null) +-- !query schema +struct +-- !query output +2008-12-25 07:30:00 1931-01-07 00:30:00 NULL + + +-- !query +select TIMESTAMP_MILLIS(1230219000123),TIMESTAMP_MILLIS(-1230219000123),TIMESTAMP_MILLIS(null) +-- !query schema +struct +-- !query output +2008-12-25 07:30:00.123 1931-01-07 00:29:59.877 NULL + + +-- !query +select TIMESTAMP_MICROS(1230219000123123),TIMESTAMP_MICROS(-1230219000123123),TIMESTAMP_MICROS(null) +-- !query schema +struct +-- !query output +2008-12-25 07:30:00.123123 1931-01-07 00:29:59.876877 NULL + + +-- !query +select TIMESTAMP_SECONDS(1230219000123123) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query +select TIMESTAMP_SECONDS(-1230219000123123) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query +select TIMESTAMP_MILLIS(92233720368547758) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow + + +-- !query +select TIMESTAMP_MILLIS(-92233720368547758) +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +long overflow -- !query @@ -13,7 +73,7 @@ true true -- !query select to_date(null), 
to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd') -- !query schema -struct +struct -- !query output NULL 2016-12-31 2016-12-31 @@ -21,7 +81,7 @@ NULL 2016-12-31 2016-12-31 -- !query select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd') -- !query schema -struct +struct -- !query output NULL 2016-12-31 00:12:00 2016-12-31 00:00:00 @@ -150,7 +210,7 @@ struct -- !query select '2011-11-11' - interval '2' day -- !query schema -struct +struct -- !query output 2011-11-09 00:00:00 @@ -158,7 +218,7 @@ struct -- !query select '2011-11-11 11:11:11' - interval '2' second -- !query schema -struct +struct -- !query output 2011-11-11 11:11:09 @@ -166,7 +226,7 @@ struct -- !query select '1' - interval '2' second -- !query schema -struct +struct -- !query output NULL @@ -465,7 +525,7 @@ struct -- !query select to_timestamp('2019-10-06 10:11:12.', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') -- !query schema -struct +struct -- !query output NULL @@ -473,7 +533,7 @@ NULL -- !query select to_timestamp('2019-10-06 10:11:12.0', 'yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') -- !query schema -struct +struct -- !query output 2019-10-06 10:11:12 @@ -481,7 +541,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.1 @@ -489,7 +549,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.12 @@ -497,7 +557,7 @@ struct +struct -- !query output 2019-10-06 03:11:12.123 @@ -505,7 +565,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.1234 @@ -513,7 +573,7 @@ struct +struct -- !query output 2019-10-06 08:11:12.12345 @@ -521,7 +581,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.123456 @@ -529,7 +589,7 @@ struct +struct -- !query output NULL @@ -537,7 +597,7 @@ NULL -- !query select to_timestamp('123456 2019-10-06 10:11:12.123456PST', 'SSSSSS yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]') -- !query schema -struct +struct -- !query output 2019-10-06 10:11:12.123456 @@ -545,7 +605,7 @@ struct +struct -- !query output NULL @@ -553,7 +613,7 @@ NULL -- !query select to_timestamp('2019-10-06 10:11:12.1234', 'yyyy-MM-dd HH:mm:ss.[SSSSSS]') -- !query schema -struct +struct -- !query output 2019-10-06 10:11:12.1234 @@ -561,7 +621,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.123 @@ -569,7 +629,7 @@ struct +struct -- !query output 2019-10-06 10:11:12 @@ -577,7 +637,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.12 @@ -585,7 +645,7 @@ struct +struct -- !query output 2019-10-06 10:11:00 @@ -593,7 +653,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.12345 @@ -601,7 +661,7 @@ struct +struct -- !query output 2019-10-06 10:11:12.1234 @@ -609,7 +669,7 @@ struct +struct -- !query output NULL @@ -617,7 +677,7 @@ NULL -- !query select to_timestamp("12.1232019-10-06S10:11", "ss.SSSSyy-MM-dd'S'HH:mm") -- !query schema -struct +struct -- !query output NULL @@ -625,7 +685,7 @@ NULL -- !query select to_timestamp("12.1234019-10-06S10:11", "ss.SSSSy-MM-dd'S'HH:mm") -- !query schema -struct +struct -- !query output 0019-10-06 10:11:12.1234 @@ -633,7 +693,7 @@ struct +struct -- !query output 2019-10-06 00:00:00 @@ -641,7 +701,7 @@ struct -- !query select to_timestamp("S2019-10-06", "'S'yyyy-MM-dd") -- !query schema -struct +struct -- !query output 2019-10-06 00:00:00 @@ -675,7 +735,7 @@ struct -- !query select to_timestamp("2019-10-06T10:11:12'12", "yyyy-MM-dd'T'HH:mm:ss''SSSS") -- !query schema -struct +struct -- !query output 2019-10-06 10:11:12.12 @@ -683,7 +743,7 @@ struct +struct -- !query output 2019-10-06 10:11:12 @@ -691,7 +751,7 @@ struct +struct -- !query output 
2019-10-06 10:11:12 @@ -699,6 +759,241 @@ struct +struct -- !query output 2019-10-06 10:11:12 + + +-- !query +select to_timestamp("16", "dd") +-- !query schema +struct +-- !query output +1970-01-16 00:00:00 + + +-- !query +select to_timestamp("02-29", "MM-dd") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_date("16", "dd") +-- !query schema +struct +-- !query output +1970-01-16 + + +-- !query +select to_date("02-29", "MM-dd") +-- !query schema +struct +-- !query output +NULL + + +-- !query +select to_timestamp("2019 40", "yyyy mm") +-- !query schema +struct +-- !query output +2019-01-01 00:40:00 + + +-- !query +select to_timestamp("2019 10:10:10", "yyyy hh:mm:ss") +-- !query schema +struct +-- !query output +2019-01-01 10:10:10 + + +-- !query +select date_format(date '2020-05-23', 'GGGGG') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format(date '2020-05-23', 'MMMMM') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'MMMMM' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format(date '2020-05-23', 'LLLLL') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'LLLLL' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format(timestamp '2020-05-23', 'EEEEE') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format(timestamp '2020-05-23', 'uuuuu') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'uuuuu' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format('2020-05-23', 'QQQQQ') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Too many pattern letters: Q + + +-- !query +select date_format('2020-05-23', 'qqqqq') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Too many pattern letters: q + + +-- !query +select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyy-MM-dd GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEEE') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select unix_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_unixtime(12345, 'MMMMM') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'MMMMM' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_unixtime(54321, 'QQQQQ') +-- !query schema +struct +-- !query output +NULL + + +-- !query +select from_unixtime(23456, 'aaaaa') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aaaaa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_json('{"time":"26/October/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_json('{"date":"26/October/2015"}', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_csv('26/October/2015', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_csv('26/October/2015', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select from_unixtime(1, 'yyyyyyyyyyy-MM-dd') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyyyyyyyyy-MM-dd' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + +-- !query +select date_format(timestamp '2018-11-17 13:33:33', 'yyyyyyyyyy-MM-dd HH:mm:ss') +-- !query schema +struct +-- !query output +0000002018-11-17 13:33:33 + + +-- !query +select date_format(date '2018-11-17', 'yyyyyyyyyyy-MM-dd') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyyyyyyyyy-MM-dd' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out index 29cbefdb38541..9d3fe5d17fafa 100644 --- a/sql/core/src/test/resources/sql-tests/results/extract.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 96 +-- Number of queries: 100 -- !query @@ -13,7 +13,7 @@ struct<> -- !query select extract(year from c), extract(year from i) from t -- !query schema -struct +struct -- !query output 2011 11 @@ -21,7 +21,7 @@ struct -- !query select extract(y from c), extract(y from i) from t -- !query schema -struct +struct -- !query output 2011 11 @@ -29,7 +29,7 @@ struct -- !query select extract(years from c), extract(years from i) from t -- !query schema -struct +struct -- !query output 2011 11 @@ -37,7 +37,7 @@ struct -- !query select extract(yr from c), extract(yr from i) from t -- !query schema -struct +struct -- !query output 2011 11 @@ -45,7 +45,7 @@ struct -- !query select extract(yrs from c), extract(yrs from i) from t -- !query schema -struct +struct -- !query output 2011 11 @@ -53,7 +53,7 @@ struct -- !query select extract(yearofweek from c) from t -- !query schema -struct +struct -- !query output 2011 @@ -61,7 +61,7 @@ struct -- !query select extract(quarter from c) from t -- !query schema -struct +struct -- !query output 2 @@ -69,7 +69,7 @@ struct -- !query select extract(qtr from c) from t -- !query schema -struct +struct -- !query output 2 @@ -77,7 +77,7 @@ struct -- !query select extract(month from c), extract(month from i) from t -- !query schema -struct +struct -- !query output 5 8 @@ -85,7 +85,7 @@ struct -- !query select extract(mon from c), extract(mon from i) from t -- !query schema -struct +struct -- !query output 5 8 @@ -93,7 +93,7 @@ struct -- !query select extract(mons from c), extract(mons from i) from t -- !query schema -struct +struct -- !query output 5 8 @@ -101,7 +101,7 @@ struct -- !query select extract(months from c), extract(months from i) from t -- !query schema -struct +struct -- !query output 5 8 @@ -109,7 +109,7 @@ struct -- !query select extract(week from c) from t -- !query schema -struct +struct -- !query output 18 @@ -117,7 +117,7 @@ struct -- !query select extract(w from c) from t -- !query schema -struct +struct -- !query output 18 @@ -125,7 +125,7 @@ struct -- !query select extract(weeks from c) from t -- !query schema -struct +struct -- !query output 18 @@ -133,7 +133,7 @@ struct -- !query select extract(day from c), extract(day from i) from t -- !query schema -struct +struct -- !query output 6 31 @@ -141,7 +141,7 @@ struct -- !query select extract(d from c), extract(d from i) from t -- !query schema -struct +struct -- !query output 6 31 @@ -149,7 +149,7 @@ struct -- !query select extract(days from c), extract(days from i) from t -- !query schema -struct +struct -- !query output 6 31 @@ -157,7 +157,7 @@ struct -- !query select extract(dayofweek from c) from t -- !query schema -struct +struct -- !query output 6 @@ -165,7 +165,7 @@ struct -- !query select extract(dow from c) from t -- !query schema -struct +struct -- !query output 6 @@ -173,7 +173,7 @@ struct -- !query select extract(dayofweek_iso from c) from t -- !query schema -struct +struct -- !query output 5 @@ -181,7 +181,7 @@ struct -- !query select extract(dow_iso from c) from t 
-- !query schema -struct +struct -- !query output 5 @@ -189,7 +189,7 @@ struct -- !query select extract(doy from c) from t -- !query schema -struct +struct -- !query output 126 @@ -197,7 +197,7 @@ struct -- !query select extract(hour from c), extract(hour from i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -205,7 +205,7 @@ struct -- !query select extract(h from c), extract(h from i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -213,7 +213,7 @@ struct -- !query select extract(hours from c), extract(hours from i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -221,7 +221,7 @@ struct -- !query select extract(hr from c), extract(hr from i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -229,7 +229,7 @@ struct -- !query select extract(hrs from c), extract(hrs from i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -237,7 +237,7 @@ struct -- !query select extract(minute from c), extract(minute from i) from t -- !query schema -struct +struct -- !query output 8 50 @@ -245,7 +245,7 @@ struct -- !query select extract(m from c), extract(m from i) from t -- !query schema -struct +struct -- !query output 8 50 @@ -253,7 +253,7 @@ struct -- !query select extract(min from c), extract(min from i) from t -- !query schema -struct +struct -- !query output 8 50 @@ -261,7 +261,7 @@ struct -- !query select extract(mins from c), extract(mins from i) from t -- !query schema -struct +struct -- !query output 8 50 @@ -269,7 +269,7 @@ struct -- !query select extract(minutes from c), extract(minutes from i) from t -- !query schema -struct +struct -- !query output 8 50 @@ -277,7 +277,7 @@ struct -- !query select extract(second from c), extract(second from i) from t -- !query schema -struct +struct -- !query output 9.123456 6.789000 @@ -285,7 +285,7 @@ struct +struct -- !query output 9.123456 6.789000 @@ -293,7 +293,7 @@ struct +struct -- !query output 9.123456 6.789000 @@ -301,7 +301,7 @@ struct +struct -- !query output 9.123456 6.789000 @@ -309,7 +309,7 @@ struct +struct -- !query output 9.123456 6.789000 @@ -335,7 +335,7 @@ Literals of type 'not_supported' are currently not supported for the interval ty -- !query select date_part('year', c), date_part('year', i) from t -- !query schema -struct +struct -- !query output 2011 11 @@ -343,7 +343,7 @@ struct -- !query select date_part('y', c), date_part('y', i) from t -- !query schema -struct +struct -- !query output 2011 11 @@ -351,7 +351,7 @@ struct -- !query select date_part('years', c), date_part('years', i) from t -- !query schema -struct +struct -- !query output 2011 11 @@ -359,7 +359,7 @@ struct -- !query select date_part('yr', c), date_part('yr', i) from t -- !query schema -struct +struct -- !query output 2011 11 @@ -367,7 +367,7 @@ struct -- !query select date_part('yrs', c), date_part('yrs', i) from t -- !query schema -struct +struct -- !query output 2011 11 @@ -375,7 +375,7 @@ struct -- !query select date_part('yearofweek', c) from t -- !query schema -struct +struct -- !query output 2011 @@ -383,7 +383,7 @@ struct -- !query select date_part('quarter', c) from t -- !query schema -struct +struct -- !query output 2 @@ -391,7 +391,7 @@ struct -- !query select date_part('qtr', c) from t -- !query schema -struct +struct -- !query output 2 @@ -399,7 +399,7 @@ struct -- !query select date_part('month', c), date_part('month', i) from t -- !query schema -struct +struct -- !query output 5 8 @@ -407,7 +407,7 @@ struct -- !query select date_part('mon', c), date_part('mon', i) from t 
-- !query schema -struct +struct -- !query output 5 8 @@ -415,7 +415,7 @@ struct -- !query select date_part('mons', c), date_part('mons', i) from t -- !query schema -struct +struct -- !query output 5 8 @@ -423,7 +423,7 @@ struct -- !query select date_part('months', c), date_part('months', i) from t -- !query schema -struct +struct -- !query output 5 8 @@ -431,7 +431,7 @@ struct -- !query select date_part('week', c) from t -- !query schema -struct +struct -- !query output 18 @@ -439,7 +439,7 @@ struct -- !query select date_part('w', c) from t -- !query schema -struct +struct -- !query output 18 @@ -447,7 +447,7 @@ struct -- !query select date_part('weeks', c) from t -- !query schema -struct +struct -- !query output 18 @@ -455,7 +455,7 @@ struct -- !query select date_part('day', c), date_part('day', i) from t -- !query schema -struct +struct -- !query output 6 31 @@ -463,7 +463,7 @@ struct -- !query select date_part('d', c), date_part('d', i) from t -- !query schema -struct +struct -- !query output 6 31 @@ -471,7 +471,7 @@ struct -- !query select date_part('days', c), date_part('days', i) from t -- !query schema -struct +struct -- !query output 6 31 @@ -479,7 +479,7 @@ struct -- !query select date_part('dayofweek', c) from t -- !query schema -struct +struct -- !query output 6 @@ -487,7 +487,7 @@ struct -- !query select date_part('dow', c) from t -- !query schema -struct +struct -- !query output 6 @@ -495,7 +495,7 @@ struct -- !query select date_part('dayofweek_iso', c) from t -- !query schema -struct +struct -- !query output 5 @@ -503,7 +503,7 @@ struct -- !query select date_part('dow_iso', c) from t -- !query schema -struct +struct -- !query output 5 @@ -511,7 +511,7 @@ struct -- !query select date_part('doy', c) from t -- !query schema -struct +struct -- !query output 126 @@ -519,7 +519,7 @@ struct -- !query select date_part('hour', c), date_part('hour', i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -527,7 +527,7 @@ struct -- !query select date_part('h', c), date_part('h', i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -535,7 +535,7 @@ struct -- !query select date_part('hours', c), date_part('hours', i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -543,7 +543,7 @@ struct -- !query select date_part('hr', c), date_part('hr', i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -551,7 +551,7 @@ struct -- !query select date_part('hrs', c), date_part('hrs', i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -559,7 +559,7 @@ struct -- !query select date_part('minute', c), date_part('minute', i) from t -- !query schema -struct +struct -- !query output 8 50 @@ -567,7 +567,7 @@ struct -- !query select date_part('m', c), date_part('m', i) from t -- !query schema -struct +struct -- !query output 8 50 @@ -575,7 +575,7 @@ struct -- !query select date_part('min', c), date_part('min', i) from t -- !query schema -struct +struct -- !query output 8 50 @@ -583,7 +583,7 @@ struct -- !query select date_part('mins', c), date_part('mins', i) from t -- !query schema -struct +struct -- !query output 8 50 @@ -591,7 +591,7 @@ struct -- !query select date_part('minutes', c), date_part('minutes', i) from t -- !query schema -struct +struct -- !query output 8 50 @@ -599,7 +599,7 @@ struct -- !query select date_part('second', c), date_part('second', i) from t -- !query schema -struct +struct -- !query output 9.123456 6.789000 @@ -607,7 +607,7 @@ struct +struct -- !query output 9.123456 6.789000 @@ -615,7 +615,7 @@ struct -- 
!query select date_part('sec', c), date_part('sec', i) from t -- !query schema -struct +struct -- !query output 9.123456 6.789000 @@ -623,7 +623,7 @@ struct +struct -- !query output 9.123456 6.789000 @@ -631,7 +631,7 @@ struct +struct -- !query output 9.123456 6.789000 @@ -657,7 +657,7 @@ The field parameter needs to be a foldable string value.;; line 1 pos 7 -- !query select date_part(null, c) from t -- !query schema -struct +struct -- !query output NULL @@ -674,7 +674,7 @@ The field parameter needs to be a foldable string value.;; line 1 pos 7 -- !query select date_part(null, i) from t -- !query schema -struct +struct -- !query output NULL @@ -682,7 +682,7 @@ NULL -- !query select extract('year', c) from t -- !query schema -struct +struct -- !query output 2011 @@ -690,7 +690,7 @@ struct -- !query select extract('quarter', c) from t -- !query schema -struct +struct -- !query output 2 @@ -698,7 +698,7 @@ struct -- !query select extract('month', c) from t -- !query schema -struct +struct -- !query output 5 @@ -706,7 +706,7 @@ struct -- !query select extract('week', c) from t -- !query schema -struct +struct -- !query output 18 @@ -714,7 +714,7 @@ struct -- !query select extract('day', c) from t -- !query schema -struct +struct -- !query output 6 @@ -722,7 +722,7 @@ struct -- !query select extract('days', c) from t -- !query schema -struct +struct -- !query output 6 @@ -730,7 +730,7 @@ struct -- !query select extract('dayofweek', c) from t -- !query schema -struct +struct -- !query output 6 @@ -738,7 +738,7 @@ struct -- !query select extract('dow', c) from t -- !query schema -struct +struct -- !query output 6 @@ -746,7 +746,7 @@ struct -- !query select extract('doy', c) from t -- !query schema -struct +struct -- !query output 126 @@ -754,7 +754,7 @@ struct -- !query select extract('hour', c) from t -- !query schema -struct +struct -- !query output 7 @@ -762,7 +762,7 @@ struct -- !query select extract('minute', c) from t -- !query schema -struct +struct -- !query output 8 @@ -770,6 +770,38 @@ struct -- !query select extract('second', c) from t -- !query schema -struct +struct -- !query output 9.123456 + + +-- !query +select c - i from t +-- !query schema +struct +-- !query output +1999-08-05 14:18:02.334456 + + +-- !query +select year(c - i) from t +-- !query schema +struct +-- !query output +1999 + + +-- !query +select extract(year from c - i) from t +-- !query schema +struct +-- !query output +1999 + + +-- !query +select extract(month from to_timestamp(c) - i) from t +-- !query schema +struct +-- !query output +8 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index a4c7c2cf90cd7..d41d25280146b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -73,7 +73,7 @@ struct -- !query SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp -- !query schema -struct +struct -- !query output 2 @@ -81,7 +81,7 @@ struct +struct -- !query output 2 @@ -141,7 +141,7 @@ NULL NULL -- !query SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id -- !query schema -struct to_date('2003-01-01'))):double> +struct to_date(2003-01-01))):double> -- !query output 10 200.0 100 400.0 @@ -154,7 +154,7 @@ NULL NULL -- !query SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY 
dept_id -- !query schema -struct to_timestamp('2003-01-01 00:00:00'))):double> +struct to_timestamp(2003-01-01 00:00:00))):double> -- !query output 10 200.0 100 400.0 @@ -196,7 +196,7 @@ foo 1350.0 -- !query SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1 -- !query schema -struct= to_date('2003-01-01'))):double> +struct= to_date(2003-01-01))):double> -- !query output foo 1350.0 @@ -204,7 +204,7 @@ foo 1350.0 -- !query SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1 -- !query schema -struct= to_timestamp('2003-01-01'))):double> +struct= to_timestamp(2003-01-01))):double> -- !query output foo 1350.0 @@ -272,7 +272,7 @@ struct= 0)):bigint> -- !query SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1 -- !query schema -struct= 1)):struct> +struct= 1)):struct> -- !query output diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 7bfdd0ad53a95..50eb2a9f22f69 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -87,7 +87,7 @@ struct -- !query SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1 -- !query schema -struct> +struct> -- !query output diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out index aa8ff73723586..1b3ac7865159f 100644 --- a/sql/core/src/test/resources/sql-tests/results/having.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 9 -- !query @@ -55,3 +55,29 @@ SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, struct -- !query output 2 2020-01-01 + + +-- !query +SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10 +-- !query schema +struct +-- !query output +2 +2 + + +-- !query +SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10 +-- !query schema +struct +-- !query output +2 +2 + + +-- !query +SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10 +-- !query schema +struct +-- !query output +2 diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 4b3dd17001f41..01db43ce9e8bc 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -676,7 +676,7 @@ select interval '2-2' year to month + dateval from interval_arithmetic -- !query schema -struct +struct -- !query output 2012-01-01 2009-11-01 2014-03-01 2014-03-01 2009-11-01 2009-11-01 2014-03-01 @@ -692,7 +692,7 @@ select interval '2-2' year to month + tsval from interval_arithmetic -- !query schema -struct +struct -- !query output 2012-01-01 00:00:00 2009-11-01 00:00:00 2014-03-01 00:00:00 2014-03-01 00:00:00 2009-11-01 00:00:00 2009-11-01 00:00:00 2014-03-01 00:00:00 @@ -719,7 +719,7 @@ select interval '99 11:22:33.123456789' day to second + dateval from interval_arithmetic -- !query schema -struct +struct -- !query output 2012-01-01 2011-09-23 2012-04-09 2012-04-09 2011-09-23 2011-09-23 2012-04-09 @@ -735,7 +735,7 @@ select interval '99 
11:22:33.123456789' day to second + tsval from interval_arithmetic -- !query schema -struct +struct -- !query output 2012-01-01 00:00:00 2011-09-23 12:37:26.876544 2012-04-09 11:22:33.123456 2012-04-09 11:22:33.123456 2011-09-23 12:37:26.876544 2011-09-23 12:37:26.876544 2012-04-09 11:22:33.123456 @@ -751,7 +751,7 @@ select interval '99 11:22:33.123456789' day to second + strval from interval_arithmetic -- !query schema -struct +struct -- !query output 2012-01-01 2011-09-23 12:37:26.876544 2012-04-09 11:22:33.123456 2012-04-09 11:22:33.123456 2011-09-23 12:37:26.876544 2011-09-23 12:37:26.876544 2012-04-09 11:22:33.123456 diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 866fd1245d0ed..34a329627f5dd 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 67 +-- Number of queries: 71 -- !query @@ -13,7 +13,7 @@ struct -- !query select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')) -- !query schema -struct +struct -- !query output {"time":"26/08/2015"} @@ -288,6 +288,49 @@ struct>> NULL +-- !query +select from_json('{"d": "2012-12-15", "t": "2012-12-15 15:15:15"}', 'd date, t timestamp') +-- !query schema +struct> +-- !query output +{"d":2012-12-15,"t":2012-12-15 15:15:15} + + +-- !query +select from_json( + '{"d": "12/15 2012", "t": "12/15 2012 15:15:15"}', + 'd date, t timestamp', + map('dateFormat', 'MM/dd yyyy', 'timestampFormat', 'MM/dd yyyy HH:mm:ss')) +-- !query schema +struct> +-- !query output +{"d":2012-12-15,"t":2012-12-15 15:15:15} + + +-- !query +select from_json( + '{"d": "02-29"}', + 'd date', + map('dateFormat', 'MM-dd')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to parse '02-29' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. + + +-- !query +select from_json( + '{"t": "02-29"}', + 't timestamp', + map('timestampFormat', 'MM-dd')) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to parse '02-29' in the new parser. You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0, or set to CORRECTED and treat it as an invalid datetime string. 
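The new json-functions golden results above exercise the spark.sql.legacy.timeParserPolicy switch that the SparkUpgradeException messages point to. Below is a minimal, self-contained Scala sketch of how that fallback can be reproduced outside the golden files; the local session setup and object name are assumptions for illustration only and are not part of this patch.

import org.apache.spark.sql.SparkSession

object TimeParserPolicySketch {
  def main(args: Array[String]): Unit = {
    // Assumed local session purely for illustration; not part of the patch.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("legacy-parser-sketch")
      .getOrCreate()

    // Under the default EXCEPTION policy, parsing '02-29' with pattern 'MM-dd'
    // fails with org.apache.spark.SparkUpgradeException, as in the golden output above.
    // Switching to LEGACY restores the pre-3.0 parser, so the same query no longer throws.
    spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")
    spark.sql(
      """select from_json('{"t": "02-29"}', 't timestamp', map('timestampFormat', 'MM-dd'))""")
      .show(false)

    spark.stop()
  }
}

As the error messages themselves note, the alternative CORRECTED value keeps the new parser but treats such strings as invalid datetimes instead of failing.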
+ + -- !query select to_json(array('1', '2', '3')) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index cf857cf9f98ad..9accc57d0bf60 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -157,7 +157,7 @@ NULL -- !query select 5 div 2 -- !query schema -struct<(5 div 2):bigint> +struct<(CAST(5 AS BIGINT) div CAST(2 AS BIGINT)):bigint> -- !query output 2 @@ -165,7 +165,7 @@ struct<(5 div 2):bigint> -- !query select 5 div 0 -- !query schema -struct<(5 div 0):bigint> +struct<(CAST(5 AS BIGINT) div CAST(0 AS BIGINT)):bigint> -- !query output NULL @@ -173,7 +173,7 @@ NULL -- !query select 5 div null -- !query schema -struct<(5 div CAST(NULL AS INT)):bigint> +struct<(CAST(5 AS BIGINT) div CAST(NULL AS BIGINT)):bigint> -- !query output NULL @@ -181,7 +181,7 @@ NULL -- !query select null div 5 -- !query schema -struct<(CAST(NULL AS INT) div 5):bigint> +struct<(CAST(NULL AS BIGINT) div CAST(5 AS BIGINT)):bigint> -- !query output NULL @@ -437,7 +437,7 @@ struct -- !query select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11) -- !query schema -struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)> +struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),negative(CAST(-1.11 AS DOUBLE)):double,negative(-1.11):decimal(3,2)> -- !query output -1.11 -1.11 1.11 1.11 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out index 1d862ba8a41a8..151fa1e28d725 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out @@ -584,7 +584,7 @@ select make_date(-44, 3, 15) -- !query schema struct -- !query output --0044-03-15 +0045-03-15 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out index e59b9d5b63a40..7b7aeb4ec7934 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out @@ -4654,7 +4654,7 @@ struct -- !query select ln(1.2345678e-28) -- !query schema -struct +struct -- !query output -64.26166165451762 @@ -4662,7 +4662,7 @@ struct -- !query select ln(0.0456789) -- !query schema -struct +struct -- !query output -3.0861187944847437 @@ -4670,7 +4670,7 @@ struct -- !query select ln(0.99949452) -- !query schema -struct +struct -- !query output -5.056077980832118E-4 @@ -4678,7 +4678,7 @@ struct -- !query select ln(1.00049687395) -- !query schema -struct +struct -- !query output 4.967505490136803E-4 @@ -4686,7 +4686,7 @@ struct -- !query select ln(1234.567890123456789) -- !query schema -struct +struct -- !query output 7.11847630129779 @@ -4694,7 +4694,7 @@ struct -- !query select ln(5.80397490724e5) -- !query schema -struct +struct -- !query output 13.271468476626518 @@ -4702,7 +4702,7 @@ struct -- !query select ln(9.342536355e34) -- !query schema -struct +struct -- !query output 80.52247093552418 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out index ccca1ba8cd8b4..811e7d6e4ca65 100755 --- 
a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out @@ -99,7 +99,7 @@ one -- !query select concat(1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd')) -- !query schema -struct +struct -- !query output 123hellotruefalse2010-03-09 @@ -115,7 +115,7 @@ one -- !query select concat_ws('#',1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd')) -- !query schema -struct +struct -- !query output 1#x#x#hello#true#false#x-03-09 @@ -155,7 +155,7 @@ edcba -- !query select i, left('ahoj', i), right('ahoj', i) from range(-5, 6) t(i) order by i -- !query schema -struct +struct -- !query output -5 -4 diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out index 08cc6fa993e0b..819be95603b0c 100644 --- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -77,7 +77,7 @@ true -- !query select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52') -- !query schema -struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean> +struct<(to_date(2009-07-30 04:17:52) > to_date(2009-07-30 04:17:52)):boolean> -- !query output false @@ -85,7 +85,7 @@ false -- !query select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52' -- !query schema -struct<(to_date('2009-07-30 04:17:52') > CAST(2009-07-30 04:17:52 AS DATE)):boolean> +struct<(to_date(2009-07-30 04:17:52) > CAST(2009-07-30 04:17:52 AS DATE)):boolean> -- !query output false @@ -133,7 +133,7 @@ true -- !query select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52') -- !query schema -struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean> +struct<(to_date(2009-07-30 04:17:52) >= to_date(2009-07-30 04:17:52)):boolean> -- !query output true @@ -141,7 +141,7 @@ true -- !query select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52' -- !query schema -struct<(to_date('2009-07-30 04:17:52') >= CAST(2009-07-30 04:17:52 AS DATE)):boolean> +struct<(to_date(2009-07-30 04:17:52) >= CAST(2009-07-30 04:17:52 AS DATE)):boolean> -- !query output true @@ -189,7 +189,7 @@ true -- !query select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52') -- !query schema -struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean> +struct<(to_date(2009-07-30 04:17:52) < to_date(2009-07-30 04:17:52)):boolean> -- !query output false @@ -197,7 +197,7 @@ false -- !query select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52' -- !query schema -struct<(to_date('2009-07-30 04:17:52') < CAST(2009-07-30 04:17:52 AS DATE)):boolean> +struct<(to_date(2009-07-30 04:17:52) < CAST(2009-07-30 04:17:52 AS DATE)):boolean> -- !query output false @@ -245,7 +245,7 @@ true -- !query select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52') -- !query schema -struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean> +struct<(to_date(2009-07-30 04:17:52) <= to_date(2009-07-30 04:17:52)):boolean> -- !query output true @@ -253,7 +253,7 @@ true -- !query select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' -- !query schema -struct<(to_date('2009-07-30 04:17:52') <= CAST(2009-07-30 04:17:52 AS DATE)):boolean> +struct<(to_date(2009-07-30 04:17:52) <= CAST(2009-07-30 04:17:52 AS DATE)):boolean> -- !query output true @@ -261,7 +261,7 @@ true -- 
!query select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00') -- !query schema -struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) = to_timestamp('2017-03-01 00:00:00')):boolean> +struct<(CAST(to_date(2017-03-01) AS TIMESTAMP) = to_timestamp(2017-03-01 00:00:00)):boolean> -- !query output true @@ -269,7 +269,7 @@ true -- !query select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01') -- !query schema -struct<(to_timestamp('2017-03-01 00:00:01') > CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +struct<(to_timestamp(2017-03-01 00:00:01) > CAST(to_date(2017-03-01) AS TIMESTAMP)):boolean> -- !query output true @@ -277,7 +277,7 @@ true -- !query select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01') -- !query schema -struct<(to_timestamp('2017-03-01 00:00:01') >= CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +struct<(to_timestamp(2017-03-01 00:00:01) >= CAST(to_date(2017-03-01) AS TIMESTAMP)):boolean> -- !query output true @@ -285,7 +285,7 @@ true -- !query select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01') -- !query schema -struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) < to_timestamp('2017-03-01 00:00:01')):boolean> +struct<(CAST(to_date(2017-03-01) AS TIMESTAMP) < to_timestamp(2017-03-01 00:00:01)):boolean> -- !query output true @@ -293,6 +293,6 @@ true -- !query select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01') -- !query schema -struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) <= to_timestamp('2017-03-01 00:00:01')):boolean> +struct<(CAST(to_date(2017-03-01) AS TIMESTAMP) <= to_timestamp(2017-03-01 00:00:01)):boolean> -- !query output true diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 6f1bbd03bc223..26a44a85841e0 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -5,7 +5,7 @@ -- !query SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null) -- !query schema -struct +struct -- !query output x y NULL @@ -13,7 +13,7 @@ x y NULL -- !query SELECT nullif('x', 'x'), nullif('x', 'y') -- !query schema -struct +struct -- !query output NULL x @@ -21,7 +21,7 @@ NULL x -- !query SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null) -- !query schema -struct +struct -- !query output x y NULL @@ -29,7 +29,7 @@ x y NULL -- !query SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null) -- !query schema -struct +struct -- !query output y x NULL @@ -37,7 +37,7 @@ y x NULL -- !query SELECT ifnull(1, 2.1d), ifnull(null, 2.1d) -- !query schema -struct +struct -- !query output 1.0 2.1 @@ -45,7 +45,7 @@ struct -- !query SELECT nullif(1, 2.1d), nullif(1, 1.0d) -- !query schema -struct +struct -- !query output 1 NULL @@ -53,7 +53,7 @@ struct -- !query SELECT nvl(1, 2.1d), nvl(null, 2.1d) -- !query schema -struct +struct -- !query output 1.0 2.1 @@ -61,7 +61,7 @@ struct -- !query SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d) -- !query schema -struct +struct -- !query output 2.1 1.0 @@ -110,6 +110,6 @@ struct<> -- !query SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value") -- !query schema -struct +struct -- !query output gamma 1 diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 
43c18f5417110..20c31b140b009 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 34 +-- Number of queries: 36 -- !query @@ -55,7 +55,7 @@ struct -- !query select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null) -- !query schema -struct +struct -- !query output 4 NULL NULL @@ -63,7 +63,7 @@ struct +struct -- !query output ab abcd ab NULL @@ -71,7 +71,7 @@ ab abcd ab NULL -- !query select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a') -- !query schema -struct +struct -- !query output NULL NULL @@ -79,7 +79,7 @@ NULL NULL -- !query select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null) -- !query schema -struct +struct -- !query output cd abcd cd NULL @@ -87,7 +87,7 @@ cd abcd cd NULL -- !query select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') -- !query schema -struct +struct -- !query output NULL NULL @@ -274,3 +274,19 @@ SELECT trim(TRAILING 'xy' FROM 'TURNERyxXxy') struct -- !query output TURNERyxX + + +-- !query +SELECT lpad('hi', 'invalid_length') +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT rpad('hi', 'invalid_length') +-- !query schema +struct +-- !query output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out index f294c5213d319..3b610edc47169 100644 --- a/sql/core/src/test/resources/sql-tests/results/struct.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out @@ -83,7 +83,7 @@ struct -- !query SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x -- !query schema -struct +struct -- !query output 1 delta 2 eta diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out index f81e9f6b13dc9..41f888ee28923 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out @@ -302,7 +302,7 @@ cannot resolve 'CAST(1 AS DECIMAL(10,0)) + (- INTERVAL '2 days')' due to data ty -- !query select cast('2017-12-11' as string) - interval 2 day -- !query schema -struct +struct -- !query output 2017-12-09 00:00:00 @@ -310,7 +310,7 @@ struct -- !query select cast('2017-12-11 09:30:00' as string) - interval 2 day -- !query schema -struct +struct -- !query output 2017-12-09 09:30:00 @@ -336,7 +336,7 @@ cannot resolve 'CAST(1 AS BOOLEAN) + (- INTERVAL '2 days')' due to data type mis -- !query select cast('2017-12-11 09:30:00.0' as timestamp) - interval 2 day -- !query schema -struct +struct -- !query output 2017-12-09 09:30:00 @@ -344,6 +344,6 @@ struct +struct -- !query output 2017-12-09 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out index ed7ab5a342c12..d046ff249379f 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -85,7 +85,7 @@ FROM various_maps struct<> -- !query output 
org.apache.spark.sql.AnalysisException -cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 +cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 -- !query @@ -113,7 +113,7 @@ FROM various_maps struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 +cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out index 5c56eff85b264..02944c268ed21 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -128,7 +128,7 @@ cannot resolve 't.`a`' due to data type mismatch: cannot cast string to map +struct -- !query output NULL @@ -136,9 +136,10 @@ NULL -- !query select to_timestamp('2018-01-01', a) from t -- !query schema -struct +struct<> -- !query output -NULL +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -152,9 +153,10 @@ NULL -- !query select to_unix_timestamp('2018-01-01', a) from t -- !query schema -struct +struct<> -- !query output -NULL +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query @@ -168,9 +170,10 @@ NULL -- !query select unix_timestamp('2018-01-01', a) from t -- !query schema -struct +struct<> -- !query output -NULL +org.apache.spark.SparkUpgradeException +You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aa' pattern in the DateTimeFormatter. 
1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out index 6403406413db9..da5256f5c0453 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out @@ -87,7 +87,7 @@ struct> +struct> -- !query output diff --git a/sql/core/src/test/resources/test-data/before_1582_date_v2_4.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_date_v2_4.snappy.parquet deleted file mode 100644 index 7d5cc12eefe04..0000000000000 Binary files a/sql/core/src/test/resources/test-data/before_1582_date_v2_4.snappy.parquet and /dev/null differ diff --git a/sql/core/src/test/resources/test-data/before_1582_date_v2_4_5.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_date_v2_4_5.snappy.parquet new file mode 100644 index 0000000000000..edd61c9b9fec8 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_date_v2_4_5.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_date_v2_4_6.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_date_v2_4_6.snappy.parquet new file mode 100644 index 0000000000000..01f4887f5e994 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_date_v2_4_6.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_5.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_5.snappy.parquet new file mode 100644 index 0000000000000..c7e8d3926f63a Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_5.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_6.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_6.snappy.parquet new file mode 100644 index 0000000000000..939e2b8088eb0 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_6.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_5.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_5.snappy.parquet new file mode 100644 index 0000000000000..88a94ac482052 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_5.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_6.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_6.snappy.parquet new file mode 100644 index 0000000000000..68bfa33aac13f Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_6.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_v2_4.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_v2_4.snappy.parquet deleted file mode 100644 index 13254bd93a5e6..0000000000000 Binary files a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_v2_4.snappy.parquet 
and /dev/null differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4.snappy.parquet deleted file mode 100644 index 7d2b46e9bea41..0000000000000 Binary files a/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4.snappy.parquet and /dev/null differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_5.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_5.snappy.parquet new file mode 100644 index 0000000000000..62e6048354dc1 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_5.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_6.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_6.snappy.parquet new file mode 100644 index 0000000000000..d7fdaa3e67212 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_6.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4.snappy.parquet deleted file mode 100644 index e9825455c2015..0000000000000 Binary files a/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4.snappy.parquet and /dev/null differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_5.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_5.snappy.parquet new file mode 100644 index 0000000000000..a7cef9e60f134 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_5.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_6.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_6.snappy.parquet new file mode 100644 index 0000000000000..4c213f4540a73 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_6.snappy.parquet differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 4edf3a5d39fdd..2293d4ae61aff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -973,4 +973,43 @@ class DataFrameAggregateSuite extends QueryTest assert(error.message.contains("function count_if requires boolean type")) } } + + Seq(true, false).foreach { value => + test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) { + withTempView("t1", "t2") { + sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)") + sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)") + + // test without grouping keys + checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"), + Row(4) :: Nil) + + // test with grouping keys + checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as csum from " + + "t2 group by c"), Row(3, 4) :: Nil) + + // test with distinct + checkAnswer(sql("select avg(distinct(d)), sum(distinct(if(c > (select a from t1)," + + " 
d, 0))) as csum from t2 group by c"), Row(4, 4) :: Nil) + + // test subquery with agg + checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a)) from t1)," + + " d, 0))) as csum from t2 group by c"), Row(4) :: Nil) + + // test SortAggregateExec + var df = sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2") + assert(df.queryExecution.executedPlan + .find { case _: SortAggregateExec => true }.isDefined) + checkAnswer(df, Row("str1") :: Nil) + + // test ObjectHashAggregateExec + df = sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum from t2") + assert(df.queryExecution.executedPlan + .find { case _: ObjectHashAggregateExec => true }.isDefined) + checkAnswer(df, Row(Array(4), 4) :: Nil) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 250ec7dc0ba5a..fb58c9851224b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.{count, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -202,4 +203,15 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assertAmbiguousSelfJoin(df1.join(df4).join(df2).select(df2("id"))) } } + + test("SPARK-28344: don't fail as ambiguous self join when there is no join") { + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true") { + val df = Seq(1, 1, 2, 2).toDF("a") + val w = Window.partitionBy(df("a")) + checkAnswer( + df.select(df("a").alias("x"), sum(df("a")).over(w)), + Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4e91a7c7bb0f4..8359dff674a87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -192,6 +192,28 @@ class DataFrameSuite extends QueryTest structDf.select(xxhash64($"a", $"record.*"))) } + private def assertDecimalSumOverflow( + df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { + if (!ansiEnabled) { + try { + checkAnswer(df, expectedAnswer) + } catch { + case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] => + // This is an existing bug that we can write overflowed decimal to UnsafeRow but fail + // to read it. 
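// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: the behavior assertDecimalSumOverflow
// (above) encodes, reproduced outside the test harness. With spark.sql.ansi.enabled
// set to false an overflowing decimal sum returns NULL; with it enabled the query
// fails with an ArithmeticException. Assumes spark-shell; the column name `d` and
// the 12-row input are arbitrary.
// ---------------------------------------------------------------------------
import org.apache.spark.SparkException
import org.apache.spark.sql.functions.{expr, sum}
import spark.implicits._

// Twelve values of 1e19 overflow the decimal(38, 18) result type of sum.
val overflowing = spark.range(0, 12, 1, 1)
  .select(expr("cast('10000000000000000000' as decimal(38, 18)) as d"))

spark.conf.set("spark.sql.ansi.enabled", "false")
overflowing.agg(sum($"d")).show()          // NULL rather than a wrong value
// (some plans may instead hit the known precision error tolerated above)

spark.conf.set("spark.sql.ansi.enabled", "true")
try overflowing.agg(sum($"d")).collect()   // overflow now surfaces as an error
catch { case e: SparkException => println(s"overflow: ${e.getCause}") }
// ---------------------------------------------------------------------------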
+ assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + } + } else { + val e = intercept[SparkException] { + df.collect + } + assert(e.getCause.isInstanceOf[ArithmeticException]) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal") || + e.getCause.getMessage.contains("Overflow in sum of decimals") || + e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + } + } + test("SPARK-28224: Aggregate sum big decimal overflow") { val largeDecimals = spark.sparkContext.parallelize( DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) :: @@ -200,14 +222,90 @@ class DataFrameSuite extends QueryTest Seq(true, false).foreach { ansiEnabled => withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { val structDf = largeDecimals.select("a").agg(sum("a")) - if (!ansiEnabled) { - checkAnswer(structDf, Row(null)) - } else { - val e = intercept[SparkException] { - structDf.collect + assertDecimalSumOverflow(structDf, ansiEnabled, Row(null)) + } + } + } + + test("SPARK-28067: sum of null decimal values") { + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + Seq("true", "false").foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) { + val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) + checkAnswer(df.agg(sum($"d")), Row(null)) + } + } + } + } + } + + test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") { + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + Seq(true, false).foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + val df0 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df1 = Seq( + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df = df0.union(df1) + val df2 = df.withColumnRenamed("decNum", "decNum2"). + join(df, "intNum").agg(sum("decNum")) + + val expectedAnswer = Row(null) + assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer) + + val decStr = "1" + "0" * 19 + val d1 = spark.range(0, 12, 1, 1) + val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) + assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer) + + val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) + val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) + assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer) + + val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"), + lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd") + assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer) + + val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) + + val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")). 
+ toDF("d") + assertDecimalSumOverflow( + nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer) + + val df3 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("50000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + + val df4 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + + val df5 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum") + + val df6 = df3.union(df4).union(df5) + val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")). + filter("intNum == 1") + assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2)) } - assert(e.getCause.getClass.equals(classOf[ArithmeticException])) - assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) } } } @@ -2439,6 +2537,17 @@ class DataFrameSuite extends QueryTest val nestedDecArray = Array(decSpark) checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava)))) } + + test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { + withTempPath { f => + sql("select cast(1 as decimal(38, 0)) as d") + .write.mode("overwrite") + .parquet(f.getAbsolutePath) + + val df = spark.read.parquet(f.getAbsolutePath).as[BigDecimal] + assert(df.schema === new StructType().add(StructField("d", DecimalType(38, 0)))) + } + } } case class GroupByKey(a: Int, b: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index af65957691b37..06600c1e4b1d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1916,6 +1916,16 @@ class DatasetSuite extends QueryTest assert(df1.semanticHash !== df3.semanticHash) assert(df3.semanticHash === df4.semanticHash) } + + test("SPARK-31854: Invoke in MapElementsExec should not propagate null") { + Seq("true", "false").foreach { wholeStage => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStage) { + val ds = Seq(1.asInstanceOf[Integer], null.asInstanceOf[Integer]).toDS() + val expectedAnswer = Seq[(Integer, Integer)]((1, 1), (null, null)) + checkDataset(ds.map(v => (v, v)), expectedAnswer: _*) + } + } + } } object AssertExecutionId { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 14e6ee2b04c14..c12468a4e70f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -23,7 +23,7 @@ import java.time.{Instant, LocalDateTime, ZoneId} import java.util.{Locale, TimeZone} import java.util.concurrent.TimeUnit -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkUpgradeException} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{CEST, LA} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ @@ -450,9 +450,9 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer( df.select(to_date(col("s"), "yyyy-hh-MM")), Seq(Row(null), Row(null), Row(null))) - checkAnswer( - df.select(to_date(col("s"), "yyyy-dd-aa")), - Seq(Row(null), Row(null), Row(null))) + val e = 
intercept[SparkUpgradeException](df.select(to_date(col("s"), "yyyy-dd-aa")).collect()) + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + assert(e.getMessage.contains("You may get a different result due to the upgrading of Spark")) // february val x1 = "2016-02-29" @@ -618,8 +618,16 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { Row(secs(ts4.getTime)), Row(null), Row(secs(ts3.getTime)), Row(null))) // invalid format - checkAnswer(df1.selectExpr(s"unix_timestamp(x, 'yyyy-MM-dd aa:HH:ss')"), Seq( - Row(null), Row(null), Row(null), Row(null))) + val invalid = df1.selectExpr(s"unix_timestamp(x, 'yyyy-MM-dd aa:HH:ss')") + if (legacyParserPolicy == "legacy") { + checkAnswer(invalid, + Seq(Row(null), Row(null), Row(null), Row(null))) + } else { + val e = intercept[SparkUpgradeException](invalid.collect()) + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + assert( + e.getMessage.contains("You may get a different result due to the upgrading of Spark")) + } // february val y1 = "2016-02-29" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index d41d624f1762d..5aeecd2df91e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -214,9 +214,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite val df = sql("select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') " + "from range(2)") checkKeywordsExistsInExplain(df, - "Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, " + - "id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, " + - "x AS nvl2(`id`, 'x', 'y')#x]") + "Project [coalesce(cast(id#xL as string), x) AS ifnull(id, x)#x, " + + "id#xL AS nullif(id, x)#xL, coalesce(cast(id#xL as string), x) AS nvl(id, x)#x, " + + "x AS nvl2(id, x, y)#x]") } test("SPARK-26659: explain of DataWritingCommandExec should not contain duplicate cmd.nodeName") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala index 6d6cbf7508d1b..4c9ba9455c33f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala @@ -137,7 +137,7 @@ class ExpressionsSchemaSuite extends QueryTest with SharedSparkSession { } val header = Seq( - s"", + s"", "## Summary", s" - Number of queries: ${outputs.size}", s" - Number of expressions that missing example: ${missingExamples.size}", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index f68c416941266..234978b9ce176 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.log4j.Level -import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, EliminateResolvedHint} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala index a9f443be69cb2..956bd7861d99d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.io.File import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -56,8 +55,8 @@ abstract class MetadataCacheSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkException] { df.count() } - assertExceptionMessage(e, "FileNotFoundException") - assertExceptionMessage(e, "recreating the Dataset/DataFrame involved") + assert(e.getMessage.contains("FileNotFoundException")) + assert(e.getMessage.contains("recreating the Dataset/DataFrame involved")) } } } @@ -85,8 +84,8 @@ class MetadataCacheV1Suite extends MetadataCacheSuite { val e = intercept[SparkException] { sql("select count(*) from view_refresh").first() } - assertExceptionMessage(e, "FileNotFoundException") - assertExceptionMessage(e, "REFRESH") + assert(e.getMessage.contains("FileNotFoundException")) + assert(e.getMessage.contains("REFRESH")) // Refresh and we should be able to read it again. spark.catalog.refreshTable("view_refresh") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 4a21ae9242039..e52d2262a6bf8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.{Locale, TimeZone} +import java.util.TimeZone import scala.collection.JavaConverters._ @@ -35,11 +35,6 @@ abstract class QueryTest extends PlanTest { protected def spark: SparkSession - // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - // Add Locale setting - Locale.setDefault(Locale.US) - /** * Runs the plan and makes sure the answer contains all of the keywords. 
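// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: the hunk above removes the JVM-wide
// TimeZone/Locale defaults from the test base class; the session-scoped way to pin
// a timezone is the spark.sql.session.timeZone conf, as below. Assumes spark-shell.
// ---------------------------------------------------------------------------
spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
spark.sql("SELECT current_timestamp()").show(truncate = false)
spark.conf.set("spark.sql.session.timeZone", "UTC")
spark.sql("SELECT current_timestamp()").show(truncate = false)  // rendered in UTC
// ---------------------------------------------------------------------------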
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 87012f304fe02..45c38ac76cc05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3547,6 +3547,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |with cube |""".stripMargin), Nil) } + + test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type") { + checkAnswer(sql("Select -2147483648 DIV -1"), Seq(Row(Integer.MIN_VALUE.toLong * -1))) + checkAnswer(sql("select CAST(-128 as Byte) DIV CAST (-1 as Byte)"), + Seq(Row(Byte.MinValue.toLong * -1))) + checkAnswer(sql("select CAST(-32768 as short) DIV CAST (-1 as short)"), + Seq(Row(Short.MinValue.toLong * -1))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 2b977e74ebd15..92da58c27a141 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import java.io.File -import java.util.{Locale, TimeZone} +import java.util.Locale import java.util.regex.Pattern import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -672,16 +672,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession { session.sql("DROP TABLE IF EXISTS tenk1") } - private val originalTimeZone = TimeZone.getDefault - private val originalLocale = Locale.getDefault - override def beforeAll(): Unit = { super.beforeAll() createTestTables(spark) - // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - // Add Locale setting - Locale.setDefault(Locale.US) RuleExecutor.resetMetrics() CodeGenerator.resetCompileTime() WholeStageCodegenExec.resetCodeGenTime() @@ -689,8 +682,6 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession { override def afterAll(): Unit = { try { - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) removeTestTables(spark) // For debugging dump some statistics about how much time was spent in various optimizer rules diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 7b76d0702d835..0a522fdbdeed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -169,6 +169,31 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { assert(session.sessionState.conf.getConf(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31234") } + test("SPARK-31354: SparkContext only register one SparkSession ApplicationEnd listener") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test-app-SPARK-31354-1") + val context = new SparkContext(conf) + SparkSession + .builder() + .sparkContext(context) + .master("local") + .getOrCreate() + val postFirstCreation = context.listenerBus.listeners.size() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + + SparkSession + .builder() + .sparkContext(context) + .master("local") + .getOrCreate() + val postSecondCreation = 
context.listenerBus.listeners.size() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + assert(postFirstCreation == postSecondCreation) + } + test("SPARK-31532: should not propagate static sql configs to the existing" + " active/default SparkSession") { val session = SparkSession.builder() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index e947e15a179e8..8462ce5a6c44f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -256,8 +256,8 @@ class DataSourceV2SQLSuite checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) } - test("CreateTable: without USING clause") { - spark.conf.set(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED.key, "false") + // TODO: ignored by SPARK-31707, restore the test after create table syntax unification + ignore("CreateTable: without USING clause") { // unset this config to use the default v2 session catalog. spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) val testCatalog = catalog("testcat").asTableCatalog @@ -681,8 +681,8 @@ class DataSourceV2SQLSuite } } - test("CreateTableAsSelect: without USING clause") { - spark.conf.set(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED.key, "false") + // TODO: ignored by SPARK-31707, restore the test after create table syntax unification + ignore("CreateTableAsSelect: without USING clause") { // unset this config to use the default v2 session catalog. spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) val testCatalog = catalog("testcat").asTableCatalog diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index f1411b263c77b..c99be986ddca5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution +import java.io.File + import scala.collection.mutable import org.apache.hadoop.fs.Path @@ -116,6 +118,30 @@ class DataSourceScanExecRedactionSuite extends DataSourceScanRedactionTest { assert(isIncluded(df.queryExecution, "Location")) } } + + test("SPARK-31793: FileSourceScanExec metadata should contain limited file paths") { + withTempPath { path => + val dir = path.getCanonicalPath + val partitionCol = "partitionCol" + spark.range(10) + .select("id", "id") + .toDF("value", partitionCol) + .write + .partitionBy(partitionCol) + .orc(dir) + val paths = (0 to 9).map(i => new File(dir, s"$partitionCol=$i").getCanonicalPath) + val plan = spark.read.orc(paths: _*).queryExecution.executedPlan + val location = plan collectFirst { + case f: FileSourceScanExec => f.metadata("Location") + } + assert(location.isDefined) + // The location metadata should at least contain one path + assert(location.get.contains(paths.head)) + // If the temp path length is larger than 100, the metadata length should not exceed + // twice of the length; otherwise, the metadata length should be controlled within 200. 
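// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: one way to look at the Location
// metadata entry whose length the assertions around this point bound. The
// output directory /tmp/scan_demo is made up; assumes spark-shell.
// ---------------------------------------------------------------------------
import org.apache.spark.sql.execution.FileSourceScanExec

spark.range(10).selectExpr("id", "id % 5 as p")
  .write.mode("overwrite").partitionBy("p").orc("/tmp/scan_demo")

val plan = spark.read.orc("/tmp/scan_demo").queryExecution.executedPlan
plan.collectFirst { case f: FileSourceScanExec => f.metadata("Location") }
  .foreach(println)   // e.g. InMemoryFileIndex[file:/tmp/scan_demo], truncated when long
// ---------------------------------------------------------------------------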
+ assert(location.get.length < Math.max(paths.head.length, 100) * 2) + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala index 5e81c74420fd0..a0b212d2cf6fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala @@ -17,21 +17,27 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.connector.InMemoryTableCatalog +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} class HiveResultSuite extends SharedSparkSession { import testImplicits._ test("date formatting in hive result") { - val dates = Seq("2018-12-28", "1582-10-03", "1582-10-04", "1582-10-15") - val df = dates.toDF("a").selectExpr("cast(a as date) as b") - val executedPlan1 = df.queryExecution.executedPlan - val result = HiveResult.hiveResultString(executedPlan1) - assert(result == dates) - val executedPlan2 = df.selectExpr("array(b)").queryExecution.executedPlan - val result2 = HiveResult.hiveResultString(executedPlan2) - assert(result2 == dates.map(x => s"[$x]")) + DateTimeTestUtils.outstandingTimezonesIds.foreach { zoneId => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zoneId) { + val dates = Seq("2018-12-28", "1582-10-03", "1582-10-04", "1582-10-15") + val df = dates.toDF("a").selectExpr("cast(a as date) as b") + val executedPlan1 = df.queryExecution.executedPlan + val result = HiveResult.hiveResultString(executedPlan1) + assert(result == dates) + val executedPlan2 = df.selectExpr("array(b)").queryExecution.executedPlan + val result2 = HiveResult.hiveResultString(executedPlan2) + assert(result2 == dates.map(x => s"[$x]")) + } + } } test("timestamp formatting in hive result") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index eca39f3f81726..5c35cedba9bab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -53,6 +53,7 @@ class QueryExecutionSuite extends SharedSparkSession { s"*(1) Range (0, $expected, step=1, splits=2)", "")) } + test("dumping query execution info to a file") { withTempDir { dir => val path = dir.getCanonicalPath + "/plans.txt" @@ -93,6 +94,25 @@ class QueryExecutionSuite extends SharedSparkSession { assert(exception.getMessage.contains("Illegal character in scheme name")) } + test("dumping query execution info to a file - explainMode=formatted") { + withTempDir { dir => + val path = dir.getCanonicalPath + "/plans.txt" + val df = spark.range(0, 10) + df.queryExecution.debug.toFile(path, explainMode = Option("formatted")) + assert(Source.fromFile(path).getLines.toList + .takeWhile(_ != "== Whole Stage Codegen ==").map(_.replaceAll("#\\d+", "#x")) == List( + "== Physical Plan ==", + s"* Range (1)", + "", + "", + s"(1) Range [codegen id : 1]", + "Output [1]: [id#xL]", + s"Arguments: Range (0, 10, step=1, splits=Some(2))", + "", + "")) + } + } + test("limit number of fields by sql config") { def relationPlans: String = { val ds = spark.createDataset(Seq(QueryExecutionTestRecord( diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala index 987338cf6cbbf..5ff459513e848 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala @@ -58,4 +58,13 @@ class QueryPlanningTrackerEndToEndSuite extends StreamTest { StopStream) } + test("The start times should be in order: parsing <= analysis <= optimization <= planning") { + val df = spark.sql("select count(*) from range(1)") + df.queryExecution.executedPlan + val phases = df.queryExecution.tracker.phases + assert(phases("parsing").startTimeMs <= phases("analysis").startTimeMs) + assert(phases("analysis").startTimeMs <= phases("optimization").startTimeMs) + assert(phases("optimization").startTimeMs <= phases("planning").startTimeMs) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index b29e822add8bc..7ddf9d87a6aca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.test.SQLTestUtils /** @@ -237,7 +238,7 @@ object SparkPlanTest { * @param spark SqlContext used for execution of the plan */ def executePlan(outputPlan: SparkPlan, spark: SQLContext): Seq[Row] = { - val execution = new QueryExecution(spark.sparkSession, null) { + val execution = new QueryExecution(spark.sparkSession, LocalRelation(Nil)) { override lazy val sparkPlan: SparkPlan = outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 7a23e048dadbd..3d0ba05f76b71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -24,11 +24,12 @@ import org.apache.log4j.Level import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.execution.{ReusedSubqueryExec, ShuffledRowRDD, SparkPlan} +import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan} import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
SortMergeJoinExec} import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -628,98 +629,92 @@ class AdaptiveQueryExecSuite withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "700") { + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { withTempView("skewData1", "skewData2") { spark .range(0, 1000, 1, 10) - .selectExpr("id % 2 as key1", "id as value1") + .selectExpr("id % 3 as key1", "id as value1") .createOrReplaceTempView("skewData1") spark .range(0, 1000, 1, 10) .selectExpr("id % 1 as key2", "id as value2") .createOrReplaceTempView("skewData2") - val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT key1 FROM skewData1 join skewData2 ON key1 = key2 group by key1") + + def checkSkewJoin(query: String, optimizeSkewJoin: Boolean): Unit = { + val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query) + val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan) + assert(innerSmj.size == 1 && innerSmj.head.isSkewJoin == optimizeSkewJoin) + } + + checkSkewJoin( + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2", true) // Additional shuffle introduced, so disable the "OptimizeSkewedJoin" optimization - val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan) - assert(innerSmj.size == 1 && !innerSmj.head.isSkewJoin) + checkSkewJoin( + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 GROUP BY key1", false) } } } - // TODO: we need a way to customize data distribution after shuffle, to improve test coverage - // of this case. test("SPARK-29544: adaptive skew join with different join types") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "2000", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "2000") { + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") { withTempView("skewData1", "skewData2") { spark .range(0, 1000, 1, 10) - .selectExpr("id % 2 as key1", "id as value1") + .select( + when('id < 250, 249) + .when('id >= 750, 1000) + .otherwise('id).as("key1"), + 'id as "value1") .createOrReplaceTempView("skewData1") spark .range(0, 1000, 1, 10) - .selectExpr("id % 1 as key2", "id as value2") + .select( + when('id < 250, 249) + .otherwise('id).as("key2"), + 'id as "value2") .createOrReplaceTempView("skewData2") - def checkSkewJoin(joins: Seq[SortMergeJoinExec], expectedNumPartitions: Int): Unit = { + def checkSkewJoin( + joins: Seq[SortMergeJoinExec], + leftSkewNum: Int, + rightSkewNum: Int): Unit = { assert(joins.size == 1 && joins.head.isSkewJoin) assert(joins.head.left.collect { case r: CustomShuffleReaderExec => r - }.head.partitionSpecs.length == expectedNumPartitions) + }.head.partitionSpecs.collect { + case p: PartialReducerPartitionSpec => p.reducerIndex + }.distinct.length == leftSkewNum) assert(joins.head.right.collect { case r: CustomShuffleReaderExec => r - }.head.partitionSpecs.length == expectedNumPartitions) + }.head.partitionSpecs.collect { + case p: PartialReducerPartitionSpec => p.reducerIndex + }.distinct.length == rightSkewNum) } // skewed inner join optimization val 
(_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM skewData1 join skewData2 ON key1 = key2") - // left stats: [3496, 0, 0, 0, 4014] - // right stats:[6292, 0, 0, 0, 0] - // Partition 0: both left and right sides are skewed, left side is divided - // into 2 splits and right side is divided into 4 splits, so - // 2 x 4 sub-partitions. - // Partition 1, 2, 3: not skewed, and coalesced into 1 partition, but it's ignored as the - // size is 0. - // Partition 4: only left side is skewed, and divide into 2 splits, so - // 2 sub-partitions. - // So total (8 + 0 + 2) partitions. val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan) - checkSkewJoin(innerSmj, 8 + 0 + 2) + checkSkewJoin(innerSmj, 2, 1) // skewed left outer join optimization val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2") - // left stats: [3496, 0, 0, 0, 4014] - // right stats:[6292, 0, 0, 0, 0] - // Partition 0: both left and right sides are skewed, but left join can't split right side, - // so only left side is divided into 2 splits, and thus 2 sub-partitions. - // Partition 1, 2, 3: not skewed, and coalesced into 1 partition, but it's ignored as the - // size is 0. - // Partition 4: only left side is skewed, and divide into 2 splits, so - // 2 sub-partitions. - // So total (2 + 0 + 2) partitions. val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan) - checkSkewJoin(leftSmj, 2 + 0 + 2) + checkSkewJoin(leftSmj, 2, 0) // skewed right outer join optimization val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2") - // left stats: [3496, 0, 0, 0, 4014] - // right stats:[6292, 0, 0, 0, 0] - // Partition 0: both left and right sides are skewed, but right join can't split left side, - // so only right side is divided into 4 splits, and thus 4 sub-partitions. - // Partition 1, 2, 3: not skewed, and coalesced into 1 partition, but it's ignored as the - // size is 0. - // Partition 4: only left side is skewed, but right join can't split left side, so just - // 1 partition. - // So total (4 + 0 + 1) partitions. 
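// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: the AQE skew-join knobs the tests
// above tune, written as plain conf keys (my expansion of the SQLConf constants;
// double-check the key strings against the Spark version in use). The tiny 800-byte
// thresholds only exist to force skew handling on toy data. Assumes spark-shell.
// ---------------------------------------------------------------------------
import org.apache.spark.sql.functions.when
import spark.implicits._

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "800")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "800")

// Deliberately skewed keys: a quarter of the rows collapse onto key 249.
spark.range(0, 1000, 1, 10)
  .select(when('id < 250, 249).otherwise('id).as("key1"), 'id.as("value1"))
  .createOrReplaceTempView("skewData1")
spark.range(0, 1000, 1, 10)
  .select(when('id < 250, 249).otherwise('id).as("key2"), 'id.as("value2"))
  .createOrReplaceTempView("skewData2")

// With AQE on, skewed sort-merge join partitions are split into sub-partitions at runtime.
spark.sql("SELECT * FROM skewData1 JOIN skewData2 ON key1 = key2").count()
// ---------------------------------------------------------------------------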
val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan) - checkSkewJoin(rightSmj, 4 + 0 + 1) + checkSkewJoin(rightSmj, 0, 1) } } } @@ -740,7 +735,7 @@ class AdaptiveQueryExecSuite val error = intercept[Exception] { agged.count() } - assert(error.getCause().toString contains "Early failed query stage found") + assert(error.getCause().toString contains "Invalid bucket file") assert(error.getSuppressed.size === 0) } } @@ -873,28 +868,40 @@ class AdaptiveQueryExecSuite withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "2000", - SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "2000") { + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") { withTempView("skewData1", "skewData2") { spark .range(0, 1000, 1, 10) - .selectExpr("id % 2 as key1", "id as value1") + .select( + when('id < 250, 249) + .when('id >= 750, 1000) + .otherwise('id).as("key1"), + 'id as "value1") .createOrReplaceTempView("skewData1") spark .range(0, 1000, 1, 10) - .selectExpr("id % 1 as key2", "id as value2") + .select( + when('id < 250, 249) + .otherwise('id).as("key2"), + 'id as "value2") .createOrReplaceTempView("skewData2") val (_, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM skewData1 join skewData2 ON key1 = key2") - val reader = collect(adaptivePlan) { + val readers = collect(adaptivePlan) { case r: CustomShuffleReaderExec => r - }.head - assert(!reader.isLocalReader) - assert(reader.hasSkewedPartition) - assert(!reader.hasCoalescedPartition) // 0-size partitions are ignored. - assert(reader.metrics.contains("numSkewedPartitions")) - assert(reader.metrics("numSkewedPartitions").value > 0) - assert(reader.metrics("numSkewedSplits").value > 0) + } + readers.foreach { reader => + assert(!reader.isLocalReader) + assert(reader.hasCoalescedPartition) + assert(reader.hasSkewedPartition) + assert(reader.metrics.contains("numSkewedPartitions")) + } + assert(readers(0).metrics("numSkewedPartitions").value == 2) + assert(readers(0).metrics("numSkewedSplits").value == 15) + assert(readers(1).metrics("numSkewedPartitions").value == 1) + assert(readers(1).metrics("numSkewedSplits").value == 12) } } } @@ -943,9 +950,11 @@ class AdaptiveQueryExecSuite test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { - val plan = sql("CREATE TABLE t1 AS SELECT 1 col").queryExecution.executedPlan - assert(plan.isInstanceOf[DataWritingCommandExec]) - assert(plan.asInstanceOf[DataWritingCommandExec].child.isInstanceOf[AdaptiveSparkPlanExec]) + withTable("t1") { + val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan + assert(plan.isInstanceOf[DataWritingCommandExec]) + assert(plan.asInstanceOf[DataWritingCommandExec].child.isInstanceOf[AdaptiveSparkPlanExec]) + } } } @@ -985,4 +994,31 @@ class AdaptiveQueryExecSuite } } } + + test("SPARK-31658: SQL UI should show write commands") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + withTable("t1") { + var checkDone = false + val listener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => + assert(planInfo.nodeName == "Execute 
CreateDataSourceTableAsSelectCommand") + checkDone = true + case _ => // ignore other events + } + } + } + spark.sparkContext.addSparkListener(listener) + try { + sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(checkDone) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveTestUtils.scala index ddaeb57d31547..48f85ae76cd8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveTestUtils.scala @@ -69,25 +69,3 @@ trait DisableAdaptiveExecutionSuite extends SQLTestUtils { } } } - -object AdaptiveTestUtils { - def assertExceptionMessage(e: Exception, expected: String): Unit = { - val stringWriter = new StringWriter() - e.printStackTrace(new PrintWriter(stringWriter)) - val errorMsg = stringWriter.toString - assert(errorMsg.contains(expected)) - } - - def assertExceptionCause(t: Throwable, causeClass: Class[_]): Unit = { - var c = t.getCause - var foundCause = false - while (c != null && !foundCause) { - if (causeClass.isAssignableFrom(c.getClass)) { - foundCause = true - } else { - c = c.getCause - } - } - assert(foundCause, s"Can not find cause: $causeClass") - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeRebaseBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeRebaseBenchmark.scala index aa47d36fe8c87..d6167f98b5a51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeRebaseBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeRebaseBenchmark.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.util.DateTimeConstants.SECONDS_PER_DAY import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, LA} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType object DateTime extends Enumeration { @@ -161,9 +162,10 @@ object DateTimeRebaseBenchmark extends SqlBasedBenchmark { Seq(true, false).foreach { modernDates => Seq(false, true).foreach { rebase => benchmark.addCase(caseName(modernDates, dateTime, Some(rebase)), 1) { _ => + val mode = if (rebase) LEGACY else CORRECTED withSQLConf( SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> getOutputType(dateTime), - SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_WRITE.key -> rebase.toString) { + SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> mode.toString) { genDF(rowsNum, dateTime, modernDates) .write .mode("overwrite") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 28e5082886b67..c6a533dfae4d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -40,8 +40,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class DDLParserSuite extends AnalysisTest with 
SharedSparkSession { - private lazy val parser = new SparkSqlParser(new SQLConf().copy( - SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED -> false)) + private lazy val parser = new SparkSqlParser(new SQLConf) private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { val e = intercept[ParseException] { @@ -76,12 +75,6 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { }.head } - private def withCreateTableStatement(sql: String)(prediction: CreateTableStatement => Unit) - : Unit = { - val statement = parser.parsePlan(sql).asInstanceOf[CreateTableStatement] - prediction(statement) - } - test("alter database - property values must be set") { assertUnsupported( sql = "ALTER DATABASE my_db SET DBPROPERTIES('key_without_value', 'key_with_value'='x')", @@ -487,17 +480,21 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { test("Test CTAS #3") { val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" - val statement = parser.parsePlan(s3).asInstanceOf[CreateTableAsSelectStatement] - assert(statement.tableName(0) == "page_view") - assert(statement.asSelect == parser.parsePlan("SELECT * FROM src")) - assert(statement.partitioning.isEmpty) - assert(statement.bucketSpec.isEmpty) - assert(statement.properties.isEmpty) - assert(statement.provider.isEmpty) - assert(statement.options.isEmpty) - assert(statement.location.isEmpty) - assert(statement.comment.isEmpty) - assert(!statement.ifNotExists) + val (desc, exists) = extractTableDesc(s3) + assert(exists == false) + assert(desc.identifier.database == None) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.storage.locationUri == None) + assert(desc.schema.isEmpty) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc.properties == Map()) } test("Test CTAS #4") { @@ -657,60 +654,67 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { test("create table - basic") { val query = "CREATE TABLE my_table (id int, name string)" - withCreateTableStatement(query) { state => - assert(state.tableName(0) == "my_table") - assert(state.tableSchema == new StructType().add("id", "int").add("name", "string")) - assert(state.partitioning.isEmpty) - assert(state.bucketSpec.isEmpty) - assert(state.properties.isEmpty) - assert(state.provider.isEmpty) - assert(state.options.isEmpty) - assert(state.location.isEmpty) - assert(state.comment.isEmpty) - assert(!state.ifNotExists) - } + val (desc, allowExisting) = extractTableDesc(query) + assert(!allowExisting) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.schema == new StructType().add("id", "int").add("name", "string")) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.bucketSpec.isEmpty) + assert(desc.viewText.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.locationUri.isEmpty) + assert(desc.storage.inputFormat == + Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + 
Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc.storage.properties.isEmpty) + assert(desc.properties.isEmpty) + assert(desc.comment.isEmpty) } test("create table - with database name") { val query = "CREATE TABLE dbx.my_table (id int, name string)" - withCreateTableStatement(query) { state => - assert(state.tableName(0) == "dbx") - assert(state.tableName(1) == "my_table") - } + val (desc, _) = extractTableDesc(query) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") } test("create table - temporary") { val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" val e = intercept[ParseException] { parser.parsePlan(query) } - assert(e.message.contains("CREATE TEMPORARY TABLE without a provider is not allowed.")) + assert(e.message.contains("CREATE TEMPORARY TABLE is not supported yet")) } test("create table - external") { val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" - val e = intercept[ParseException] { parser.parsePlan(query) } - assert(e.message.contains("Operation not allowed: CREATE EXTERNAL TABLE ...")) + val (desc, _) = extractTableDesc(query) + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere"))) } test("create table - if not exists") { val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" - withCreateTableStatement(query) { state => - assert(state.ifNotExists) - } + val (_, allowExisting) = extractTableDesc(query) + assert(allowExisting) } test("create table - comment") { val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" - withCreateTableStatement(query) { state => - assert(state.comment == Some("its hot as hell below")) - } + val (desc, _) = extractTableDesc(query) + assert(desc.comment == Some("its hot as hell below")) } test("create table - partitioned columns") { - val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (id)" - withCreateTableStatement(query) { state => - val transform = IdentityTransform(FieldReference(Seq("id"))) - assert(state.partitioning == Seq(transform)) - } + val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" + val (desc, _) = extractTableDesc(query) + assert(desc.schema == new StructType() + .add("id", "int") + .add("name", "string") + .add("month", "int")) + assert(desc.partitionColumnNames == Seq("month")) } test("create table - clustered by") { @@ -726,22 +730,20 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { """ val query1 = s"$baseQuery INTO $numBuckets BUCKETS" - withCreateTableStatement(query1) { state => - assert(state.bucketSpec.isDefined) - val bucketSpec = state.bucketSpec.get - assert(bucketSpec.numBuckets == numBuckets) - assert(bucketSpec.bucketColumnNames.head.equals(bucketedColumn)) - assert(bucketSpec.sortColumnNames.isEmpty) - } + val (desc1, _) = extractTableDesc(query1) + assert(desc1.bucketSpec.isDefined) + val bucketSpec1 = desc1.bucketSpec.get + assert(bucketSpec1.numBuckets == numBuckets) + assert(bucketSpec1.bucketColumnNames.head.equals(bucketedColumn)) + assert(bucketSpec1.sortColumnNames.isEmpty) val query2 = s"$baseQuery SORTED BY($sortColumn) INTO $numBuckets BUCKETS" - withCreateTableStatement(query2) { state => - assert(state.bucketSpec.isDefined) - val bucketSpec = state.bucketSpec.get - 
assert(bucketSpec.numBuckets == numBuckets) - assert(bucketSpec.bucketColumnNames.head.equals(bucketedColumn)) - assert(bucketSpec.sortColumnNames.head.equals(sortColumn)) - } + val (desc2, _) = extractTableDesc(query2) + assert(desc2.bucketSpec.isDefined) + val bucketSpec2 = desc2.bucketSpec.get + assert(bucketSpec2.numBuckets == numBuckets) + assert(bucketSpec2.bucketColumnNames.head.equals(bucketedColumn)) + assert(bucketSpec2.sortColumnNames.head.equals(sortColumn)) } test("create table(hive) - skewed by") { @@ -811,9 +813,8 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { test("create table - properties") { val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" - withCreateTableStatement(query) { state => - assert(state.properties == Map("k1" -> "v1", "k2" -> "v2")) - } + val (desc, _) = extractTableDesc(query) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) } test("create table(hive) - everything!") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 10ad8acc68937..e4709e469dca3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1203,14 +1203,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("alter table: recover partitions (sequential)") { - withSQLConf(RDD_PARALLEL_LISTING_THRESHOLD.key -> "10") { + val oldRddParallelListingThreshold = spark.sparkContext.conf.get( + RDD_PARALLEL_LISTING_THRESHOLD) + try { + spark.sparkContext.conf.set(RDD_PARALLEL_LISTING_THRESHOLD.key, "10") testRecoverPartitions() + } finally { + spark.sparkContext.conf.set(RDD_PARALLEL_LISTING_THRESHOLD, oldRddParallelListingThreshold) } } test("alter table: recover partition (parallel)") { - withSQLConf(RDD_PARALLEL_LISTING_THRESHOLD.key -> "0") { + val oldRddParallelListingThreshold = spark.sparkContext.conf.get( + RDD_PARALLEL_LISTING_THRESHOLD) + try { + spark.sparkContext.conf.set(RDD_PARALLEL_LISTING_THRESHOLD.key, "0") testRecoverPartitions() + } finally { + spark.sparkContext.conf.set(RDD_PARALLEL_LISTING_THRESHOLD, oldRddParallelListingThreshold) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala new file mode 100644 index 0000000000000..ff5fe4f620a1d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc.connection + +import javax.security.auth.login.Configuration + +class ConnectionProviderSuite extends ConnectionProviderSuiteBase { + test("Multiple security configs must be reachable") { + Configuration.setConfiguration(null) + val postgresDriver = registerDriver(PostgresConnectionProvider.driverClass) + val postgresProvider = new PostgresConnectionProvider( + postgresDriver, options("jdbc:postgresql://localhost/postgres")) + val db2Driver = registerDriver(DB2ConnectionProvider.driverClass) + val db2Provider = new DB2ConnectionProvider(db2Driver, options("jdbc:db2://localhost/db2")) + + // Make sure no authentication for the databases are set + val oldConfig = Configuration.getConfiguration + assert(oldConfig.getAppConfigurationEntry(postgresProvider.appEntry) == null) + assert(oldConfig.getAppConfigurationEntry(db2Provider.appEntry) == null) + + postgresProvider.setAuthenticationConfigIfNeeded() + db2Provider.setAuthenticationConfigIfNeeded() + + // Make sure authentication for the databases are set + val newConfig = Configuration.getConfiguration + assert(oldConfig != newConfig) + assert(newConfig.getAppConfigurationEntry(postgresProvider.appEntry) != null) + assert(newConfig.getAppConfigurationEntry(db2Provider.appEntry) != null) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index dcea4835b5f2a..3f8ee12f97776 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.ExternalRDD -import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -2239,7 +2238,7 @@ abstract class JsonSuite extends QueryTest with SharedSparkSession with TestJson .count() } - assertExceptionMessage(exception, "Malformed records are detected in record parsing") + assert(exception.getMessage.contains("Malformed records are detected in record parsing")) } def checkEncoding(expectedEncoding: String, pathToJsonFiles: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 60f278b8e5bb0..9caf0c836f711 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -34,7 +34,6 @@ import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -599,19 +598,19 @@ abstract class OrcQueryTest extends OrcTest { val e1 = intercept[SparkException] { testIgnoreCorruptFiles() } - assertExceptionMessage(e1, "Malformed ORC file") + assert(e1.getMessage.contains("Malformed ORC file")) val e2 = intercept[SparkException] { testIgnoreCorruptFilesWithoutSchemaInfer() } - assertExceptionMessage(e2, "Malformed ORC file") + assert(e2.getMessage.contains("Malformed ORC file")) val e3 = intercept[SparkException] { testAllCorruptFiles() } - assertExceptionMessage(e3, "Could not read footer for file") + assert(e3.getMessage.contains("Could not read footer for file")) val e4 = intercept[SparkException] { testAllCorruptFilesWithoutSchemaInfer() } - assertExceptionMessage(e4, "Malformed ORC file") + assert(e4.getMessage.contains("Malformed ORC file")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 5cf21293fd07f..d20a07f420e87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import java.time.LocalDate +import java.time.{LocalDate, LocalDateTime, ZoneId} import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ @@ -143,7 +143,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } } - private def testTimestampPushdown(data: Seq[Timestamp]): Unit = { + private def testTimestampPushdown(data: Seq[String], java8Api: Boolean): Unit = { + implicit class StringToTs(s: String) { + def ts: Timestamp = Timestamp.valueOf(s) + } assert(data.size === 4) val ts1 = data.head val ts2 = data(1) @@ -151,7 +154,18 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val ts4 = data(3) import testImplicits._ - withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF, colName, resultFun) => + val df = data.map(i => Tuple1(Timestamp.valueOf(i))).toDF() + withNestedDataFrame(df) { case (inputDF, colName, fun) => + def resultFun(tsStr: String): Any = { + val parsed = if (java8Api) { + LocalDateTime.parse(tsStr.replace(" ", "T")) + .atZone(ZoneId.systemDefault()) + .toInstant + } else { + Timestamp.valueOf(tsStr) + } + fun(parsed) + } withParquetDataFrame(inputDF) { implicit df => val tsAttr = df(colName).expr assert(df(colName).expr.dataType === TimestampType) @@ -160,26 +174,26 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(tsAttr === ts1, classOf[Eq[_]], resultFun(ts1)) - checkFilterPredicate(tsAttr <=> ts1, classOf[Eq[_]], resultFun(ts1)) - checkFilterPredicate(tsAttr =!= ts1, classOf[NotEq[_]], + checkFilterPredicate(tsAttr === ts1.ts, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr <=> ts1.ts, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr =!= ts1.ts, classOf[NotEq[_]], Seq(ts2, ts3, ts4).map(i => 
Row.apply(resultFun(i)))) - checkFilterPredicate(tsAttr < ts2, classOf[Lt[_]], resultFun(ts1)) - checkFilterPredicate(tsAttr > ts1, classOf[Gt[_]], + checkFilterPredicate(tsAttr < ts2.ts, classOf[Lt[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr > ts1.ts, classOf[Gt[_]], Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(tsAttr <= ts1, classOf[LtEq[_]], resultFun(ts1)) - checkFilterPredicate(tsAttr >= ts4, classOf[GtEq[_]], resultFun(ts4)) - - checkFilterPredicate(Literal(ts1) === tsAttr, classOf[Eq[_]], resultFun(ts1)) - checkFilterPredicate(Literal(ts1) <=> tsAttr, classOf[Eq[_]], resultFun(ts1)) - checkFilterPredicate(Literal(ts2) > tsAttr, classOf[Lt[_]], resultFun(ts1)) - checkFilterPredicate(Literal(ts3) < tsAttr, classOf[Gt[_]], resultFun(ts4)) - checkFilterPredicate(Literal(ts1) >= tsAttr, classOf[LtEq[_]], resultFun(ts1)) - checkFilterPredicate(Literal(ts4) <= tsAttr, classOf[GtEq[_]], resultFun(ts4)) - - checkFilterPredicate(!(tsAttr < ts4), classOf[GtEq[_]], resultFun(ts4)) - checkFilterPredicate(tsAttr < ts2 || tsAttr > ts3, classOf[Operators.Or], + checkFilterPredicate(tsAttr <= ts1.ts, classOf[LtEq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr >= ts4.ts, classOf[GtEq[_]], resultFun(ts4)) + + checkFilterPredicate(Literal(ts1.ts) === tsAttr, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts1.ts) <=> tsAttr, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts2.ts) > tsAttr, classOf[Lt[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts3.ts) < tsAttr, classOf[Gt[_]], resultFun(ts4)) + checkFilterPredicate(Literal(ts1.ts) >= tsAttr, classOf[LtEq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts4.ts) <= tsAttr, classOf[GtEq[_]], resultFun(ts4)) + + checkFilterPredicate(!(tsAttr < ts4.ts), classOf[GtEq[_]], resultFun(ts4)) + checkFilterPredicate(tsAttr < ts2.ts || tsAttr > ts3.ts, classOf[Operators.Or], Seq(Row(resultFun(ts1)), Row(resultFun(ts4)))) } } @@ -588,34 +602,41 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - timestamp") { - // spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS - val millisData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123"), - Timestamp.valueOf("2018-06-15 08:28:53.123"), - Timestamp.valueOf("2018-06-16 08:28:53.123"), - Timestamp.valueOf("2018-06-17 08:28:53.123")) - withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> - ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) { - testTimestampPushdown(millisData) - } - - // spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS - val microsData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123456"), - Timestamp.valueOf("2018-06-15 08:28:53.123456"), - Timestamp.valueOf("2018-06-16 08:28:53.123456"), - Timestamp.valueOf("2018-06-17 08:28:53.123456")) - withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> - ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) { - testTimestampPushdown(microsData) - } - - // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown - withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> - ParquetOutputTimestampType.INT96.toString) { - import testImplicits._ - withParquetDataFrame(millisData.map(i => Tuple1(i)).toDF()) { implicit df => - val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema) - assertResult(None) { - createParquetFilters(schema).createFilter(sources.IsNull("_1")) + Seq(true, false).foreach { java8Api => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) { 
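For context on the `java8Api` flag toggled above, a minimal sketch of what it changes on the read side (assuming an active SparkSession named `spark`; the literal below is illustrative and not taken from this patch):

// With spark.sql.datetime.java8API.enabled = true, timestamp values collect as
// java.time.Instant; with it disabled they collect as java.sql.Timestamp. That is
// why resultFun above rebuilds the expected value from the same string in two ways.
spark.conf.set("spark.sql.datetime.java8API.enabled", "true")
spark.sql("SELECT TIMESTAMP '2018-06-17 08:28:53.999'").head.get(0)  // java.time.Instant
spark.conf.set("spark.sql.datetime.java8API.enabled", "false")
spark.sql("SELECT TIMESTAMP '2018-06-17 08:28:53.999'").head.get(0)  // java.sql.Timestamp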
+ // spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS + val millisData = Seq( + "1000-06-14 08:28:53.123", + "1582-06-15 08:28:53.001", + "1900-06-16 08:28:53.0", + "2018-06-17 08:28:53.999") + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) { + testTimestampPushdown(millisData, java8Api) + } + + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS + val microsData = Seq( + "1000-06-14 08:28:53.123456", + "1582-06-15 08:28:53.123456", + "1900-06-16 08:28:53.123456", + "2018-06-17 08:28:53.123456") + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) { + testTimestampPushdown(microsData, java8Api) + } + + // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.INT96.toString) { + import testImplicits._ + withParquetDataFrame( + millisData.map(i => Tuple1(Timestamp.valueOf(i))).toDF()) { implicit df => + val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema) + assertResult(None) { + createParquetFilters(schema).createFilter(sources.IsNull("_1")) + } + } } } } @@ -781,10 +802,9 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("Filter applied on merged Parquet schema with new column should work") { import testImplicits._ - Seq("true", "false").foreach { vectorized => + withAllParquetReaders { withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", - SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", - SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { withTempPath { dir => val path1 = s"${dir.getCanonicalPath}/table1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path1) @@ -1219,24 +1239,22 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("SPARK-17213: Broken Parquet filter push-down for string columns") { - Seq(true, false).foreach { vectorizedEnabled => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedEnabled.toString) { - withTempPath { dir => - import testImplicits._ + withAllParquetReaders { + withTempPath { dir => + import testImplicits._ - val path = dir.getCanonicalPath - // scalastyle:off nonascii - Seq("a", "é").toDF("name").write.parquet(path) - // scalastyle:on nonascii + val path = dir.getCanonicalPath + // scalastyle:off nonascii + Seq("a", "é").toDF("name").write.parquet(path) + // scalastyle:on nonascii - assert(spark.read.parquet(path).where("name > 'a'").count() == 1) - assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) + assert(spark.read.parquet(path).where("name > 'a'").count() == 1) + assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) - // scalastyle:off nonascii - assert(spark.read.parquet(path).where("name < 'é'").count() == 1) - assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) - // scalastyle:on nonascii - } + // scalastyle:off nonascii + assert(spark.read.parquet(path).where("name < 'é'").count() == 1) + assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) + // scalastyle:on nonascii } } } @@ -1244,8 +1262,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("SPARK-31026: Parquet predicate pushdown for fields having dots in the names") { import testImplicits._ - Seq(true, false).foreach { vectorized => - 
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString, + withAllParquetReaders { + withSQLConf( SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> true.toString, SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { withTempPath { path => @@ -1255,7 +1273,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } } - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString, + withSQLConf( // Makes sure disabling 'spark.sql.parquet.recordFilter' still enables // row group level filtering. SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> "false", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 7f0a2286690bf..79c32976f02ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.nio.file.{Files, Paths, StandardCopyOption} import java.sql.{Date, Timestamp} import java.time._ import java.util.Locale @@ -41,14 +42,15 @@ import org.apache.parquet.hadoop.util.HadoopInputFile import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} -import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException, SparkUpgradeException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -646,47 +648,39 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession } test("read dictionary encoded decimals written as INT32") { - ("true" :: "false" :: Nil).foreach { vectorized => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { - checkAnswer( - // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("test-data/dec-in-i32.parquet"), - spark.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) - } + withAllParquetReaders { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("test-data/dec-in-i32.parquet"), + spark.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) } } test("read dictionary encoded decimals written as INT64") { - ("true" :: "false" :: Nil).foreach { vectorized => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { - checkAnswer( - // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("test-data/dec-in-i64.parquet"), - spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) - } + withAllParquetReaders { + checkAnswer( + // Decimal column in 
this file is encoded using plain dictionary + readResourceParquetFile("test-data/dec-in-i64.parquet"), + spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) } } test("read dictionary encoded decimals written as FIXED_LEN_BYTE_ARRAY") { - ("true" :: "false" :: Nil).foreach { vectorized => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { - checkAnswer( - // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("test-data/dec-in-fixed-len.parquet"), - spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) - } + withAllParquetReaders { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("test-data/dec-in-fixed-len.parquet"), + spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) } } test("read dictionary and plain encoded timestamp_millis written as INT64") { - ("true" :: "false" :: Nil).foreach { vectorized => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { - checkAnswer( - // timestamp column in this file is encoded using combination of plain - // and dictionary encodings. - readResourceParquetFile("test-data/timemillis-in-i64.parquet"), - (1 to 3).map(i => Row(new java.sql.Timestamp(10)))) - } + withAllParquetReaders { + checkAnswer( + // timestamp column in this file is encoded using combination of plain + // and dictionary encodings. + readResourceParquetFile("test-data/timemillis-in-i64.parquet"), + (1 to 3).map(i => Row(new java.sql.Timestamp(10)))) } } @@ -882,88 +876,196 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession } } + // It generates input files for the test below: + // "SPARK-31159: compatibility with Spark 2.4 in reading dates/timestamps" + ignore("SPARK-31806: generate test files for checking compatibility with Spark 2.4") { + val resourceDir = "sql/core/src/test/resources/test-data" + val version = "2_4_5" + val N = 8 + def save( + in: Seq[(String, String)], + t: String, + dstFile: String, + options: Map[String, String] = Map.empty): Unit = { + withTempDir { dir => + in.toDF("dict", "plain") + .select($"dict".cast(t), $"plain".cast(t)) + .repartition(1) + .write + .mode("overwrite") + .options(options) + .parquet(dir.getCanonicalPath) + Files.copy( + dir.listFiles().filter(_.getName.endsWith(".snappy.parquet")).head.toPath, + Paths.get(resourceDir, dstFile), + StandardCopyOption.REPLACE_EXISTING) + } + } + DateTimeTestUtils.withDefaultTimeZone(DateTimeTestUtils.LA) { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> DateTimeTestUtils.LA.getId) { + save( + (1 to N).map(i => ("1001-01-01", s"1001-01-0$i")), + "date", + s"before_1582_date_v$version.snappy.parquet") + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "TIMESTAMP_MILLIS") { + save( + (1 to N).map(i => ("1001-01-01 01:02:03.123", s"1001-01-0$i 01:02:03.123")), + "timestamp", + s"before_1582_timestamp_millis_v$version.snappy.parquet") + } + val usTs = (1 to N).map(i => ("1001-01-01 01:02:03.123456", s"1001-01-0$i 01:02:03.123456")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "TIMESTAMP_MICROS") { + save(usTs, "timestamp", s"before_1582_timestamp_micros_v$version.snappy.parquet") + } + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "INT96") { + // Comparing to other logical types, Parquet-MR chooses dictionary encoding for the + // INT96 logical type because it consumes less memory for small column cardinality. 
+ // Huge parquet files doesn't make sense to place to the resource folder. That's why + // we explicitly set `parquet.enable.dictionary` and generate two files w/ and w/o + // dictionary encoding. + save( + usTs, + "timestamp", + s"before_1582_timestamp_int96_plain_v$version.snappy.parquet", + Map("parquet.enable.dictionary" -> "false")) + save( + usTs, + "timestamp", + s"before_1582_timestamp_int96_dict_v$version.snappy.parquet", + Map("parquet.enable.dictionary" -> "true")) + } + } + } + } + test("SPARK-31159: compatibility with Spark 2.4 in reading dates/timestamps") { + val N = 8 // test reading the existing 2.4 files and new 3.0 files (with rebase on/off) together. - def checkReadMixedFiles(fileName: String, dt: String, dataStr: String): Unit = { + def checkReadMixedFiles[T]( + fileName: String, + catalystType: String, + rowFunc: Int => (String, String), + toJavaType: String => T, + checkDefaultLegacyRead: String => Unit, + tsOutputType: String = "TIMESTAMP_MICROS"): Unit = { withTempPaths(2) { paths => paths.foreach(_.delete()) val path2_4 = getResourceParquetFilePath("test-data/" + fileName) val path3_0 = paths(0).getCanonicalPath val path3_0_rebase = paths(1).getCanonicalPath - if (dt == "date") { - val df = Seq(dataStr).toDF("str").select($"str".cast("date").as("date")) - df.write.parquet(path3_0) - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_WRITE.key -> "true") { - df.write.parquet(path3_0_rebase) + val df = Seq.tabulate(N)(rowFunc).toDF("dict", "plain") + .select($"dict".cast(catalystType), $"plain".cast(catalystType)) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> tsOutputType) { + checkDefaultLegacyRead(path2_4) + // By default we should fail to write ancient datetime values. + val e = intercept[SparkException](df.write.parquet(path3_0)) + assert(e.getCause.getCause.getCause.isInstanceOf[SparkUpgradeException]) + withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) { + df.write.mode("overwrite").parquet(path3_0) } - checkAnswer( - spark.read.format("parquet").load(path2_4, path3_0, path3_0_rebase), - 1.to(3).map(_ => Row(java.sql.Date.valueOf(dataStr)))) - } else { - val df = Seq(dataStr).toDF("str").select($"str".cast("timestamp").as("ts")) - withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> dt) { - df.write.parquet(path3_0) - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_WRITE.key -> "true") { - df.write.parquet(path3_0_rebase) - } + withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> LEGACY.toString) { + df.write.parquet(path3_0_rebase) } + } + // For Parquet files written by Spark 3.0, we know the writer info and don't need the + // config to guide the rebase behavior. 
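The comments above capture the key rebase behavior; a minimal write-side sketch, assuming an active SparkSession `spark` and a hypothetical output directory `path`:

import org.apache.spark.sql.internal.SQLConf
import spark.implicits._
val df = Seq("1001-01-01").toDF("s").selectExpr("CAST(s AS DATE) AS d")
// Under the default EXCEPTION mode, writing this pre-1582 date fails with a
// SparkUpgradeException; choosing a rebase mode explicitly makes the write succeed.
spark.conf.set(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key, "CORRECTED")
df.write.mode("overwrite").parquet(path)
// Spark 3.0 records the writer info in the file metadata, so reading the file back
// does not depend on the read-side rebase mode.
spark.read.parquet(path).show()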
+ withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key -> LEGACY.toString) { checkAnswer( spark.read.format("parquet").load(path2_4, path3_0, path3_0_rebase), - 1.to(3).map(_ => Row(java.sql.Timestamp.valueOf(dataStr)))) + (0 until N).flatMap { i => + val (dictS, plainS) = rowFunc(i) + Seq.tabulate(3) { _ => + Row(toJavaType(dictS), toJavaType(plainS)) + } + }) } } } - - Seq(false, true).foreach { vectorized => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_READ.key -> "true") { - checkReadMixedFiles("before_1582_date_v2_4.snappy.parquet", "date", "1001-01-01") - checkReadMixedFiles( - "before_1582_timestamp_micros_v2_4.snappy.parquet", - "TIMESTAMP_MICROS", - "1001-01-01 01:02:03.123456") - checkReadMixedFiles( - "before_1582_timestamp_millis_v2_4.snappy.parquet", - "TIMESTAMP_MILLIS", - "1001-01-01 01:02:03.123") - } - + def failInRead(path: String): Unit = { + val e = intercept[SparkException](spark.read.parquet(path).collect()) + assert(e.getCause.isInstanceOf[SparkUpgradeException]) + } + def successInRead(path: String): Unit = spark.read.parquet(path).collect() + Seq( + // By default we should fail to read ancient datetime values when parquet files don't + // contain Spark version. + "2_4_5" -> failInRead _, + "2_4_6" -> successInRead _).foreach { case (version, checkDefaultRead) => + withAllParquetReaders { + checkReadMixedFiles( + s"before_1582_date_v$version.snappy.parquet", + "date", + (i: Int) => ("1001-01-01", s"1001-01-0${i + 1}"), + java.sql.Date.valueOf, + checkDefaultRead) + checkReadMixedFiles( + s"before_1582_timestamp_micros_v$version.snappy.parquet", + "timestamp", + (i: Int) => ("1001-01-01 01:02:03.123456", s"1001-01-0${i + 1} 01:02:03.123456"), + java.sql.Timestamp.valueOf, + checkDefaultRead) + checkReadMixedFiles( + s"before_1582_timestamp_millis_v$version.snappy.parquet", + "timestamp", + (i: Int) => ("1001-01-01 01:02:03.123", s"1001-01-0${i + 1} 01:02:03.123"), + java.sql.Timestamp.valueOf, + checkDefaultRead, + tsOutputType = "TIMESTAMP_MILLIS") // INT96 is a legacy timestamp format and we always rebase the seconds for it. - checkAnswer(readResourceParquetFile( - "test-data/before_1582_timestamp_int96_v2_4.snappy.parquet"), - Row(java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456"))) + Seq("plain", "dict").foreach { enc => + checkAnswer(readResourceParquetFile( + s"test-data/before_1582_timestamp_int96_${enc}_v$version.snappy.parquet"), + Seq.tabulate(N) { i => + Row( + java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456"), + java.sql.Timestamp.valueOf(s"1001-01-0${i + 1} 01:02:03.123456")) + }) + } } } } test("SPARK-31159: rebasing timestamps in write") { - Seq( - ("TIMESTAMP_MILLIS", "1001-01-01 01:02:03.123", "1001-01-07 01:09:05.123"), - ("TIMESTAMP_MICROS", "1001-01-01 01:02:03.123456", "1001-01-07 01:09:05.123456"), - ("INT96", "1001-01-01 01:02:03.123456", "1001-01-01 01:02:03.123456") - ).foreach { case (outType, tsStr, nonRebased) => - withClue(s"output type $outType") { - withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> outType) { - withTempPath { dir => - val path = dir.getAbsolutePath - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_WRITE.key -> "true") { - Seq(tsStr).toDF("tsS") - .select($"tsS".cast("timestamp").as("ts")) - .write - .parquet(path) - } - // The file metadata indicates if it needs rebase or not, so we can always get the - // correct result regardless of the "rebaseInRead" config. 
- Seq(true, false).foreach { rebase => - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_READ.key -> rebase.toString) { - checkAnswer(spark.read.parquet(path), Row(Timestamp.valueOf(tsStr))) + val N = 8 + Seq(false, true).foreach { dictionaryEncoding => + Seq( + ("TIMESTAMP_MILLIS", "1001-01-01 01:02:03.123", "1001-01-07 01:09:05.123"), + ("TIMESTAMP_MICROS", "1001-01-01 01:02:03.123456", "1001-01-07 01:09:05.123456"), + ("INT96", "1001-01-01 01:02:03.123456", "1001-01-01 01:02:03.123456") + ).foreach { case (outType, tsStr, nonRebased) => + withClue(s"output type $outType") { + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> outType) { + withTempPath { dir => + val path = dir.getAbsolutePath + withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> LEGACY.toString) { + Seq.tabulate(N)(_ => tsStr).toDF("tsS") + .select($"tsS".cast("timestamp").as("ts")) + .repartition(1) + .write + .option("parquet.enable.dictionary", dictionaryEncoding) + .parquet(path) } - } - // Force to not rebase to prove the written datetime values are rebased and we will get - // wrong result if we don't rebase while reading. - withSQLConf("spark.test.forceNoRebase" -> "true") { - checkAnswer(spark.read.parquet(path), Row(Timestamp.valueOf(nonRebased))) + withAllParquetReaders { + // The file metadata indicates if it needs rebase or not, so we can always get the + // correct result regardless of the "rebase mode" config. + Seq(LEGACY, CORRECTED, EXCEPTION).foreach { mode => + withSQLConf( + SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key -> mode.toString) { + checkAnswer( + spark.read.parquet(path), + Seq.tabulate(N)(_ => Row(Timestamp.valueOf(tsStr)))) + } + } + + // Force to not rebase to prove the written datetime values are rebased + // and we will get wrong result if we don't rebase while reading. + withSQLConf("spark.test.forceNoRebase" -> "true") { + checkAnswer( + spark.read.parquet(path), + Seq.tabulate(N)(_ => Row(Timestamp.valueOf(nonRebased)))) + } + } } } } @@ -972,27 +1074,38 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession } test("SPARK-31159: rebasing dates in write") { - withTempPath { dir => - val path = dir.getAbsolutePath - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_WRITE.key -> "true") { - Seq("1001-01-01").toDF("dateS") - .select($"dateS".cast("date").as("date")) - .write - .parquet(path) - } - - // The file metadata indicates if it needs rebase or not, so we can always get the correct - // result regardless of the "rebaseInRead" config. - Seq(true, false).foreach { rebase => - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_DATETIME_IN_READ.key -> rebase.toString) { - checkAnswer(spark.read.parquet(path), Row(Date.valueOf("1001-01-01"))) + val N = 8 + Seq(false, true).foreach { dictionaryEncoding => + withTempPath { dir => + val path = dir.getAbsolutePath + withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> LEGACY.toString) { + Seq.tabulate(N)(_ => "1001-01-01").toDF("dateS") + .select($"dateS".cast("date").as("date")) + .repartition(1) + .write + .option("parquet.enable.dictionary", dictionaryEncoding) + .parquet(path) } - } - // Force to not rebase to prove the written datetime values are rebased and we will get - // wrong result if we don't rebase while reading. 
- withSQLConf("spark.test.forceNoRebase" -> "true") { - checkAnswer(spark.read.parquet(path), Row(Date.valueOf("1001-01-07"))) + withAllParquetReaders { + // The file metadata indicates if it needs rebase or not, so we can always get the + // correct result regardless of the "rebase mode" config. + Seq(LEGACY, CORRECTED, EXCEPTION).foreach { mode => + withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key -> mode.toString) { + checkAnswer( + spark.read.parquet(path), + Seq.tabulate(N)(_ => Row(Date.valueOf("1001-01-01")))) + } + } + + // Force to not rebase to prove the written datetime values are rebased and we will get + // wrong result if we don't rebase while reading. + withSQLConf("spark.test.forceNoRebase" -> "true") { + checkAnswer( + spark.read.parquet(path), + Seq.tabulate(N)(_ => Row(Date.valueOf("1001-01-07")))) + } + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index 7d75077a9732a..a14f6416199a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -124,12 +124,11 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS FileUtils.copyFile(new File(impalaPath), new File(tableDir, "part-00001.parq")) Seq(false, true).foreach { int96TimestampConversion => - Seq(false, true).foreach { vectorized => + withAllParquetReaders { withSQLConf( (SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, SQLConf.ParquetOutputTimestampType.INT96.toString), - (SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key, int96TimestampConversion.toString()), - (SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, vectorized.toString()) + (SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key, int96TimestampConversion.toString()) ) { val readBack = spark.read.parquet(tableDir.getAbsolutePath).collect() assert(readBack.size === 6) @@ -149,7 +148,8 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS val fullExpectations = (ts ++ impalaExpectations).map(_.toString).sorted.toArray val actual = readBack.map(_.getTimestamp(0).toString).sorted withClue( - s"int96TimestampConversion = $int96TimestampConversion; vectorized = $vectorized") { + s"int96TimestampConversion = $int96TimestampConversion; " + + s"vectorized = ${SQLConf.get.parquetVectorizedReaderEnabled}") { assert(fullExpectations === actual) // Now test that the behavior is still correct even with a filter which could get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index d3301ced2ba19..32a9558e91f10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -60,7 +60,7 @@ abstract class ParquetPartitionDiscoverySuite val timeZoneId = ZoneId.systemDefault() val df = DateFormatter(timeZoneId) val tf = TimestampFormatter( - timestampPartitionPattern, timeZoneId, needVarLengthSecondFraction = true) + timestampPartitionPattern, timeZoneId, isParsing = true) protected override 
def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 917aaba2669ce..05d305a9b52ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -168,11 +168,9 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS withTempPath { file => val df = spark.createDataFrame(sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - ("true" :: "false" :: Nil).foreach { vectorized => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { - val df2 = spark.read.parquet(file.getCanonicalPath) - checkAnswer(df2, df.collect().toSeq) - } + withAllParquetReaders { + val df2 = spark.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) } } } @@ -791,15 +789,13 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } test("SPARK-26677: negated null-safe equality comparison should not filter matched row groups") { - (true :: false :: Nil).foreach { vectorized => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - withTempPath { path => - // Repeated values for dictionary encoding. - Seq(Some("A"), Some("A"), None).toDF.repartition(1) - .write.parquet(path.getAbsolutePath) - val df = spark.read.parquet(path.getAbsolutePath) - checkAnswer(stripSparkFilter(df.where("NOT (value <=> 'A')")), df) - } + withAllParquetReaders { + withTempPath { path => + // Repeated values for dictionary encoding. 
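Many of the rewrites in these suites lean on the `withAllParquetReaders` helper this patch adds to `ParquetTest` (shown further down); a usage sketch with a hypothetical `path` holding three rows:

withAllParquetReaders {
  // The body runs twice: first with spark.sql.parquet.enableVectorizedReader=false
  // (row-based reader), then with it set to true (vectorized reader).
  val df = spark.read.parquet(path)
  assert(df.count() == 3)
}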
+ Seq(Some("A"), Some("A"), None).toDF.repartition(1) + .write.parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) + checkAnswer(stripSparkFilter(df.where("NOT (value <=> 'A')")), df) } } } @@ -821,10 +817,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> toTsType) { write(df2.write.mode(SaveMode.Append)) } - Seq("true", "false").foreach { vectorized => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { - checkAnswer(readback, df1.unionAll(df2)) - } + withAllParquetReaders { + checkAnswer(readback, df1.unionAll(df2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index c833d5f1ab1f6..105f025adc0ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -69,7 +69,9 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { protected def withParquetDataFrame(df: DataFrame, testVectorized: Boolean = true) (f: DataFrame => Unit): Unit = { withTempPath { file => - df.write.format(dataSourceName).save(file.getCanonicalPath) + withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> "CORRECTED") { + df.write.format(dataSourceName).save(file.getCanonicalPath) + } readFile(file.getCanonicalPath, testVectorized)(f) } } @@ -162,4 +164,11 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { protected def getResourceParquetFilePath(name: String): String = { Thread.currentThread().getContextClassLoader.getResource(name).toString } + + def withAllParquetReaders(code: => Unit): Unit = { + // test the row-based reader + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false")(code) + // test the vectorized reader + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true")(code) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1be9308c06d8c..f7d5a899df1c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,9 +22,10 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans.logical.BROADCAST import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AdaptiveTestUtils, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ @@ -411,7 +412,7 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils 
val e = intercept[Exception] { testDf.collect() } - AdaptiveTestUtils.assertExceptionMessage(e, s"Could not execute broadcast in $timeout secs.") + assert(e.getMessage.contains(s"Could not execute broadcast in $timeout secs.")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 5490246baceea..554990413c28c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 08898f80034e6..44ab3f7d023d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} @@ -133,7 +134,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildLeft), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } @@ -145,7 +146,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildRight), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } @@ -157,7 +158,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildLeft), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } @@ -169,7 +170,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, 
joins.BuildRight), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildRight), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index a5ade0d8d7508..879f282e4d05d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 08fb655bde467..50652690339a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -310,8 +310,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("ShuffledHashJoin metrics") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "40", - SQLConf.SHUFFLE_PARTITIONS.key -> "2", - SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key", "value") // Assume the execution plan is @@ -325,30 +325,49 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // +- LocalTableScan(nodeId = 7) Seq((1L, 2L, 5L, false), (2L, 3L, 7L, true)).foreach { case (nodeId1, nodeId2, nodeId3, enableWholeStage) => - val df = df1.join(df2, "key") + val df = df1.join(df2, "key") + testSparkPlanMetrics(df, 1, Map( + nodeId1 -> (("ShuffledHashJoin", Map( + "number of output rows" -> 2L))), + nodeId2 -> (("Exchange", Map( + "shuffle records written" -> 2L, + "records read" -> 2L))), + nodeId3 -> (("Exchange", Map( + "shuffle records written" -> 10L, + "records read" -> 10L)))), + enableWholeStage + ) + } + } + } + + test("ShuffledHashJoin(left, outer) metrics") { + val leftDf = Seq((1, "1"), (2, "2")).toDF("key", "value") + val rightDf = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key2", "value") + Seq((0L, "right_outer", leftDf, rightDf, 10L, false), + (0L, "left_outer", rightDf, leftDf, 10L, false), + (0L, "right_outer", leftDf, rightDf, 10L, true), + (0L, "left_outer", rightDf, leftDf, 10L, true), + (2L, "left_anti", rightDf, leftDf, 8L, true), + (2L, "left_semi", rightDf, leftDf, 2L, true), + (1L, "left_anti", rightDf, leftDf, 8L, false), + (1L, "left_semi", rightDf, leftDf, 2L, false)) + .foreach { case (nodeId, joinType, leftDf, rightDf, rows, enableWholeStage) => + val df = leftDf.hint("shuffle_hash").join( + rightDf.hint("shuffle_hash"), $"key" === $"key2", joinType) testSparkPlanMetrics(df, 1, Map( - nodeId1 -> (("ShuffledHashJoin", Map( - "number of output rows" -> 2L))), - nodeId2 -> (("Exchange", Map( - "shuffle records written" -> 2L, - "records read" -> 2L))), - nodeId3 -> 
(("Exchange", Map( - "shuffle records written" -> 10L, - "records read" -> 10L)))), + nodeId -> (("ShuffledHashJoin", Map( + "number of output rows" -> rows)))), enableWholeStage ) } - } } test("BroadcastHashJoin(outer) metrics") { val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") - // Assume the execution plan is - // ... -> BroadcastHashJoin(nodeId = 0) - Seq(("left_outer", 0L, 5L, false), ("right_outer", 0L, 6L, false), - ("left_outer", 1L, 5L, true), ("right_outer", 1L, 6L, true)).foreach { - case (joinType, nodeId, numRows, enableWholeStage) => + Seq(("left_outer", 0L, 5L, false), ("right_outer", 0L, 6L, false), ("left_outer", 1L, 5L, true), + ("right_outer", 1L, 6L, true)).foreach { case (joinType, nodeId, numRows, enableWholeStage) => val df = df1.join(broadcast(df2), $"key" === $"key2", joinType) testSparkPlanMetrics(df, 2, Map( nodeId -> (("BroadcastHashJoin", Map( @@ -365,9 +384,12 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils withTempView("testDataForJoin") { // Assume the execution plan is // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val query = "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + val leftQuery = "SELECT * FROM testData2 LEFT JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a" + val rightQuery = "SELECT * FROM testData2 RIGHT JOIN testDataForJoin ON " + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a" - Seq(false, true).foreach { enableWholeStage => + Seq((leftQuery, false), (rightQuery, false), (leftQuery, true), (rightQuery, true)) + .foreach { case (query, enableWholeStage) => val df = spark.sql(query) testSparkPlanMetrics(df, 2, Map( 0L -> (("BroadcastNestedLoopJoin", Map( @@ -394,6 +416,19 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } } + test("BroadcastLeftAntiJoinHash metrics") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + Seq((1L, false), (2L, true)).foreach { case (nodeId, enableWholeStage) => + val df = df2.join(broadcast(df1), $"key" === $"key2", "left_anti") + testSparkPlanMetrics(df, 2, Map( + nodeId -> (("BroadcastHashJoin", Map( + "number of output rows" -> 2L)))), + enableWholeStage + ) + } + } + test("CartesianProduct metrics") { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index f95daafdfe19b..6d615b5ef0449 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -18,7 +18,15 @@ package org.apache.spark.sql.execution.streaming import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.lang.{Long => JLong} +import java.net.URI import java.nio.charset.StandardCharsets.UTF_8 +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicLong + +import scala.util.Random + +import org.apache.hadoop.fs.{FSDataInputStream, Path, RawLocalFileSystem} import org.apache.spark.SparkFunSuite import 
org.apache.spark.sql.internal.SQLConf @@ -240,6 +248,44 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSparkSession { )) } + test("getLatestBatchId") { + withCountOpenLocalFileSystemAsLocalFileSystem { + val scheme = CountOpenLocalFileSystem.scheme + withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { + withTempDir { dir => + val sinkLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark, + s"$scheme:///${dir.getCanonicalPath}") + for (batchId <- 0L to 2L) { + sinkLog.add( + batchId, + Array(newFakeSinkFileStatus("/a/b/" + batchId, FileStreamSinkLog.ADD_ACTION))) + } + + def getCountForOpenOnMetadataFile(batchId: Long): Long = { + val path = sinkLog.batchIdToPath(batchId).toUri.getPath + CountOpenLocalFileSystem.pathToNumOpenCalled.getOrDefault(path, 0L) + } + + CountOpenLocalFileSystem.resetCount() + + assert(sinkLog.getLatestBatchId() === Some(2L)) + // getLatestBatchId doesn't open the latest metadata log file + (0L to 2L).foreach { batchId => + assert(getCountForOpenOnMetadataFile(batchId) === 0L) + } + + assert(sinkLog.getLatest().map(_._1).getOrElse(-1L) === 2L) + (0L to 1L).foreach { batchId => + assert(getCountForOpenOnMetadataFile(batchId) === 0L) + } + // getLatest opens the latest metadata log file, which explains the needs on + // having "getLatestBatchId". + assert(getCountForOpenOnMetadataFile(2L) === 1L) + } + } + } + } + /** * Create a fake SinkFileStatus using path and action. Most of tests don't care about other fields * in SinkFileStatus. @@ -267,4 +313,41 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSparkSession { val log = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark, input.toString) log.allFiles() } + + private def withCountOpenLocalFileSystemAsLocalFileSystem(body: => Unit): Unit = { + val optionKey = s"fs.${CountOpenLocalFileSystem.scheme}.impl" + val originClassForLocalFileSystem = spark.conf.getOption(optionKey) + try { + spark.conf.set(optionKey, classOf[CountOpenLocalFileSystem].getName) + body + } finally { + originClassForLocalFileSystem match { + case Some(fsClazz) => spark.conf.set(optionKey, fsClazz) + case _ => spark.conf.unset(optionKey) + } + } + } +} + +class CountOpenLocalFileSystem extends RawLocalFileSystem { + import CountOpenLocalFileSystem._ + + override def getUri: URI = { + URI.create(s"$scheme:///") + } + + override def open(f: Path, bufferSize: Int): FSDataInputStream = { + val path = f.toUri.getPath + pathToNumOpenCalled.compute(path, (_, v) => { + if (v == null) 1L else v + 1 + }) + super.open(f, bufferSize) + } +} + +object CountOpenLocalFileSystem { + val scheme = s"FileStreamSinkLogSuite${math.abs(Random.nextInt)}fs" + val pathToNumOpenCalled = new ConcurrentHashMap[String, JLong] + + def resetCount(): Unit = pathToNumOpenCalled.clear() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanInfoSuite.scala new file mode 100644 index 0000000000000..a702e00ff9f92 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanInfoSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.test.SharedSparkSession + +class SparkPlanInfoSuite extends SharedSparkSession{ + + import testImplicits._ + + def vaidateSparkPlanInfo(sparkPlanInfo: SparkPlanInfo): Unit = { + sparkPlanInfo.nodeName match { + case "InMemoryTableScan" => assert(sparkPlanInfo.children.length == 1) + case _ => sparkPlanInfo.children.foreach(vaidateSparkPlanInfo) + } + } + + test("SparkPlanInfo creation from SparkPlan with InMemoryTableScan node") { + val dfWithCache = Seq( + (1, 1), + (2, 2) + ).toDF().filter("_1 > 1").cache().repartition(10) + + val planInfoResult = SparkPlanInfo.fromSparkPlan(dfWithCache.queryExecution.executedPlan) + + vaidateSparkPlanInfo(planInfoResult) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index e18514c6f93f9..53f9757750735 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.expressions import scala.collection.parallel.immutable.ParVector import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { @@ -156,4 +157,38 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { } } } + + test("Check whether SQL expressions should extend NullIntolerant") { + // Only check expressions extended from these expressions because these expressions are + // NullIntolerant by default. + val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression], + classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression]) + + // Do not check these expressions, because these expressions extend NullIntolerant + // and override the eval method to avoid evaluating input1 if input2 is 0. 
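As background for this check, a minimal sketch of why NullIntolerant matters to the optimizer (assuming an active SparkSession `spark`; the column name is illustrative):

import spark.implicits._
val df = Seq(Some(1), None).toDF("a")
// Arithmetic and comparison expressions are null-intolerant, so the optimizer can
// infer an IS NOT NULL constraint on `a` from the predicate below; the inferred
// filter can then be pushed down to the data source.
val optimized = df.filter($"a" + 1 > 0).queryExecution.optimizedPlan
assert(optimized.toString.contains("isnotnull(a"))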
+ val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod]) + + val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction() + .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName) + .filterNot(c => ignoreSet.exists(_.getName.equals(c))) + .map(name => Utils.classForName(name)) + .filterNot(classOf[NonSQLExpression].isAssignableFrom) + + exprTypesToCheck.foreach { superClass => + candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz => + val isEvalOverrode = clazz.getMethod("eval", classOf[InternalRow]) != + superClass.getMethod("eval", classOf[InternalRow]) + val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz) + if (isEvalOverrode && isNullIntolerantMixedIn) { + fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " + + s"or add ${clazz.getName} in the ignoreSet of this test.") + } else if (!isEvalOverrode && !isNullIntolerantMixedIn) { + fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.") + } else { + assert((!isEvalOverrode && isNullIntolerantMixedIn) || + (isEvalOverrode && !isNullIntolerantMixedIn)) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SharedStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SharedStateSuite.scala new file mode 100644 index 0000000000000..81bf15342423c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SharedStateSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.net.URL + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FsUrlStreamHandlerFactory + +import org.apache.spark.SparkConf +import org.apache.spark.sql.test.SharedSparkSession + + +/** + * Tests for [[org.apache.spark.sql.internal.SharedState]]. 
+ */ +class SharedStateSuite extends SharedSparkSession { + + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.hadoop.fs.defaultFS", "file:///") + } + + test("SPARK-31692: Url handler factory should have the hadoop configs from Spark conf") { + // Accessing shared state to init the object since it is `lazy val` + spark.sharedState + val field = classOf[URL].getDeclaredField("factory") + field.setAccessible(true) + val value = field.get(null) + assert(value.isInstanceOf[FsUrlStreamHandlerFactory]) + val streamFactory = value.asInstanceOf[FsUrlStreamHandlerFactory] + + val confField = classOf[FsUrlStreamHandlerFactory].getDeclaredField("conf") + confField.setAccessible(true) + val conf = confField.get(streamFactory) + + assert(conf.isInstanceOf[Configuration]) + assert(conf.asInstanceOf[Configuration].get("fs.defaultFS") == "file:///") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index e153c7168dbf2..1d8303b9e7750 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, FileSourceScanExec, SortExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec -import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec @@ -771,7 +770,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { agged.count() } - assertExceptionMessage(error, "Invalid bucket file") + assert(error.getCause().toString contains "Invalid bucket file") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 87a4d061b8170..abd33ab8a8f22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -623,12 +623,12 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { var msg = intercept[AnalysisException] { sql("insert into t select 1L, 2") }.getMessage - assert(msg.contains("Cannot safely cast 'i': LongType to IntegerType")) + assert(msg.contains("Cannot safely cast 'i': bigint to int")) msg = intercept[AnalysisException] { sql("insert into t select 1, 2.0") }.getMessage - assert(msg.contains("Cannot safely cast 'd': DecimalType(2,1) to DoubleType")) + assert(msg.contains("Cannot safely cast 'd': decimal(2,1) to double")) msg = intercept[AnalysisException] { sql("insert into t select 1, 2.0D, 3") @@ -660,18 +660,18 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { var msg = intercept[AnalysisException] { sql("insert into t values('a', 'b')") }.getMessage - assert(msg.contains("Cannot safely cast 'i': StringType to IntegerType") && - msg.contains("Cannot safely cast 'd': StringType to DoubleType")) + assert(msg.contains("Cannot safely cast 'i': string to int") && + msg.contains("Cannot safely cast 'd': string to double")) msg = intercept[AnalysisException] 
{ sql("insert into t values(now(), now())") }.getMessage - assert(msg.contains("Cannot safely cast 'i': TimestampType to IntegerType") && - msg.contains("Cannot safely cast 'd': TimestampType to DoubleType")) + assert(msg.contains("Cannot safely cast 'i': timestamp to int") && + msg.contains("Cannot safely cast 'd': timestamp to double")) msg = intercept[AnalysisException] { sql("insert into t values(true, false)") }.getMessage - assert(msg.contains("Cannot safely cast 'i': BooleanType to IntegerType") && - msg.contains("Cannot safely cast 'd': BooleanType to DoubleType")) + assert(msg.contains("Cannot safely cast 'i': boolean to int") && + msg.contains("Cannot safely cast 'd': boolean to double")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala index 2a1e18ab66bb7..640c21c52a146 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala @@ -43,13 +43,11 @@ class StreamingQueryPageSuite extends SharedSparkSession with BeforeAndAfter { var html = renderStreamingQueryPage(request, tab) .toString().toLowerCase(Locale.ROOT) assert(html.contains("active streaming queries (1)")) - assert(html.contains("completed streaming queries (0)")) when(streamQuery.isActive).thenReturn(false) when(streamQuery.exception).thenReturn(None) html = renderStreamingQueryPage(request, tab) .toString().toLowerCase(Locale.ROOT) - assert(html.contains("active streaming queries (0)")) assert(html.contains("completed streaming queries (1)")) assert(html.contains("finished")) @@ -57,7 +55,6 @@ class StreamingQueryPageSuite extends SharedSparkSession with BeforeAndAfter { when(streamQuery.exception).thenReturn(Option("exception in query")) html = renderStreamingQueryPage(request, tab) .toString().toLowerCase(Locale.ROOT) - assert(html.contains("active streaming queries (0)")) assert(html.contains("completed streaming queries (1)")) assert(html.contains("failed")) assert(html.contains("exception in query")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala index fdf4c6634d79f..63b5792ebd515 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala @@ -91,21 +91,23 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B goToUi(spark, "/StreamingQuery") findAll(cssSelector("h3")).map(_.text).toSeq should contain("Streaming Query") - findAll(cssSelector("""#activeQueries-table th""")).map(_.text).toSeq should be { - List("Name", "Status", "Id", "Run ID", "Start Time", "Duration", "Avg Input /sec", - "Avg Process /sec", "Lastest Batch") + + val arrow = 0x25BE.toChar + findAll(cssSelector("""#active-table th""")).map(_.text).toList should be { + List("Name", "Status", "ID", "Run ID", s"Start Time $arrow", "Duration", + "Avg Input /sec", "Avg Process /sec", "Latest Batch") } val activeQueries = - findAll(cssSelector("""#activeQueries-table td""")).map(_.text).toSeq + findAll(cssSelector("""#active-table td""")).map(_.text).toSeq activeQueries should contain(activeQuery.id.toString) activeQueries should contain(activeQuery.runId.toString) - 
findAll(cssSelector("""#completedQueries-table th""")) - .map(_.text).toSeq should be { - List("Name", "Status", "Id", "Run ID", "Start Time", "Duration", "Avg Input /sec", - "Avg Process /sec", "Lastest Batch", "Error") + findAll(cssSelector("""#completed-table th""")) + .map(_.text).toList should be { + List("Name", "Status", "ID", "Run ID", s"Start Time $arrow", "Duration", + "Avg Input /sec", "Avg Process /sec", "Latest Batch", "Error") } val completedQueries = - findAll(cssSelector("""#completedQueries-table td""")).map(_.text).toSeq + findAll(cssSelector("""#completed-table td""")).map(_.text).toSeq completedQueries should contain(completedQuery.id.toString) completedQueries should contain(completedQuery.runId.toString) completedQueries should contain(failedQuery.id.toString) @@ -113,7 +115,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B // Check the query statistics page val activeQueryLink = - findAll(cssSelector("""#activeQueries-table a""")).flatMap(_.attribute("href")).next + findAll(cssSelector("""#active-table td a""")).flatMap(_.attribute("href")).next go to activeQueryLink findAll(cssSelector("h3")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 9747840ce4032..fe0a8439acc2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -333,7 +333,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with var msg = intercept[AnalysisException] { Seq((1L, 2.0)).toDF("i", "d").write.mode("append").saveAsTable("t") }.getMessage - assert(msg.contains("Cannot safely cast 'i': LongType to IntegerType")) + assert(msg.contains("Cannot safely cast 'i': bigint to int")) // Insert into table successfully. Seq((1, 2.0)).toDF("i", "d").write.mode("append").saveAsTable("t") @@ -354,14 +354,14 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with var msg = intercept[AnalysisException] { Seq(("a", "b")).toDF("i", "d").write.mode("append").saveAsTable("t") }.getMessage - assert(msg.contains("Cannot safely cast 'i': StringType to IntegerType") && - msg.contains("Cannot safely cast 'd': StringType to DoubleType")) + assert(msg.contains("Cannot safely cast 'i': string to int") && + msg.contains("Cannot safely cast 'd': string to double")) msg = intercept[AnalysisException] { Seq((true, false)).toDF("i", "d").write.mode("append").saveAsTable("t") }.getMessage - assert(msg.contains("Cannot safely cast 'i': BooleanType to IntegerType") && - msg.contains("Cannot safely cast 'd': BooleanType to DoubleType")) + assert(msg.contains("Cannot safely cast 'i': boolean to int") && + msg.contains("Cannot safely cast 'd': boolean to double")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala new file mode 100644 index 0000000000000..43cca246cc47c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1.sql + +import java.util.Date + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.PrivateMethodTester + +import org.apache.spark.{JobExecutionStatus, SparkFunSuite} +import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphEdge, SparkPlanGraphNode, SQLExecutionUIData, SQLPlanMetric} + +object SqlResourceSuite { + + val SCAN_TEXT = "Scan text" + val FILTER = "Filter" + val WHOLE_STAGE_CODEGEN_1 = "WholeStageCodegen (1)" + val DURATION = "duration" + val NUMBER_OF_OUTPUT_ROWS = "number of output rows" + val METADATA_TIME = "metadata time" + val NUMBER_OF_FILES_READ = "number of files read" + val SIZE_OF_FILES_READ = "size of files read" + val PLAN_DESCRIPTION = "== Physical Plan ==\nCollectLimit (3)\n+- * Filter (2)\n +- Scan text..." + val DESCRIPTION = "csv at MyDataFrames.scala:57" + + val nodeIdAndWSCGIdMap: Map[Long, Option[Long]] = Map(1L -> Some(1L)) + + val filterNode = new SparkPlanGraphNode(1, FILTER, "", + metrics = Seq(SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 1, ""))) + val nodes: Seq[SparkPlanGraphNode] = Seq( + new SparkPlanGraphCluster(0, WHOLE_STAGE_CODEGEN_1, "", + nodes = ArrayBuffer(filterNode), + metrics = Seq(SQLPlanMetric(DURATION, 0, ""))), + new SparkPlanGraphNode(2, SCAN_TEXT, "", + metrics = Seq( + SQLPlanMetric(METADATA_TIME, 2, ""), + SQLPlanMetric(NUMBER_OF_FILES_READ, 3, ""), + SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 4, ""), + SQLPlanMetric(SIZE_OF_FILES_READ, 5, "")))) + + val nodesWhenCodegenIsOff: Seq[SparkPlanGraphNode] = + SparkPlanGraph(nodes, edges).allNodes.filterNot(_.name == WHOLE_STAGE_CODEGEN_1) + + val edges: Seq[SparkPlanGraphEdge] = + Seq(SparkPlanGraphEdge(3, 2)) + + val metrics: Seq[SQLPlanMetric] = { + Seq(SQLPlanMetric(DURATION, 0, ""), + SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 1, ""), + SQLPlanMetric(METADATA_TIME, 2, ""), + SQLPlanMetric(NUMBER_OF_FILES_READ, 3, ""), + SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 4, ""), + SQLPlanMetric(SIZE_OF_FILES_READ, 5, "")) + } + + val sqlExecutionUIData: SQLExecutionUIData = { + def getMetricValues() = { + Map[Long, String]( + 0L -> "0 ms", + 1L -> "1", + 2L -> "2 ms", + 3L -> "1", + 4L -> "1", + 5L -> "330.0 B" + ) + } + + new SQLExecutionUIData( + executionId = 0, + description = DESCRIPTION, + details = "", + physicalPlanDescription = PLAN_DESCRIPTION, + metrics = metrics, + submissionTime = 1586768888233L, + completionTime = Some(new Date(1586768888999L)), + jobs = Map[Int, JobExecutionStatus]( + 0 -> JobExecutionStatus.SUCCEEDED, + 1 -> JobExecutionStatus.SUCCEEDED), + stages = Set[Int](), + metricValues = getMetricValues() + ) + } + + private def getNodes(): Seq[Node] = { + val node = Node(0, WHOLE_STAGE_CODEGEN_1, + wholeStageCodegenId = None, metrics = Seq(Metric(DURATION, "0 ms"))) + val node2 = Node(1, FILTER, + wholeStageCodegenId = Some(1), metrics = Seq(Metric(NUMBER_OF_OUTPUT_ROWS, "1"))) + val node3 = Node(2, SCAN_TEXT, 
wholeStageCodegenId = None, + metrics = Seq(Metric(METADATA_TIME, "2 ms"), + Metric(NUMBER_OF_FILES_READ, "1"), + Metric(NUMBER_OF_OUTPUT_ROWS, "1"), + Metric(SIZE_OF_FILES_READ, "330.0 B"))) + + // reverse order because of supporting execution order by aligning with Spark-UI + Seq(node3, node2, node) + } + + private def getExpectedNodesWhenWholeStageCodegenIsOff(): Seq[Node] = { + val node = Node(1, FILTER, metrics = Seq(Metric(NUMBER_OF_OUTPUT_ROWS, "1"))) + val node2 = Node(2, SCAN_TEXT, + metrics = Seq(Metric(METADATA_TIME, "2 ms"), + Metric(NUMBER_OF_FILES_READ, "1"), + Metric(NUMBER_OF_OUTPUT_ROWS, "1"), + Metric(SIZE_OF_FILES_READ, "330.0 B"))) + + // reverse order because of supporting execution order by aligning with Spark-UI + Seq(node2, node) + } + + private def verifyExpectedExecutionData(executionData: ExecutionData, + nodes: Seq[Node], + edges: Seq[SparkPlanGraphEdge], + planDescription: String): Unit = { + + assert(executionData.id == 0) + assert(executionData.status == "COMPLETED") + assert(executionData.description == DESCRIPTION) + assert(executionData.planDescription == planDescription) + assert(executionData.submissionTime == new Date(1586768888233L)) + assert(executionData.duration == 766L) + assert(executionData.successJobIds == Seq[Int](0, 1)) + assert(executionData.runningJobIds == Seq[Int]()) + assert(executionData.failedJobIds == Seq.empty) + assert(executionData.nodes == nodes) + assert(executionData.edges == edges) + } + +} + +/** + * Sql Resource Public API Unit Tests. + */ +class SqlResourceSuite extends SparkFunSuite with PrivateMethodTester { + + import SqlResourceSuite._ + + val sqlResource = new SqlResource() + val prepareExecutionData = PrivateMethod[ExecutionData]('prepareExecutionData) + + test("Prepare ExecutionData when details = false and planDescription = false") { + val executionData = + sqlResource invokePrivate prepareExecutionData( + sqlExecutionUIData, SparkPlanGraph(Seq.empty, Seq.empty), false, false) + verifyExpectedExecutionData(executionData, edges = Seq.empty, + nodes = Seq.empty, planDescription = "") + } + + test("Prepare ExecutionData when details = true and planDescription = false") { + val executionData = + sqlResource invokePrivate prepareExecutionData( + sqlExecutionUIData, SparkPlanGraph(nodes, edges), true, false) + verifyExpectedExecutionData( + executionData, + nodes = getNodes(), + edges, + planDescription = "") + } + + test("Prepare ExecutionData when details = true and planDescription = true") { + val executionData = + sqlResource invokePrivate prepareExecutionData( + sqlExecutionUIData, SparkPlanGraph(nodes, edges), true, true) + verifyExpectedExecutionData( + executionData, + nodes = getNodes(), + edges = edges, + planDescription = PLAN_DESCRIPTION) + } + + test("Prepare ExecutionData when details = true and planDescription = false and WSCG = off") { + val executionData = + sqlResource invokePrivate prepareExecutionData( + sqlExecutionUIData, SparkPlanGraph(nodesWhenCodegenIsOff, edges), true, false) + verifyExpectedExecutionData( + executionData, + nodes = getExpectedNodesWhenWholeStageCodegenIsOff(), + edges = edges, + planDescription = "") + } + + test("Parse wholeStageCodegenId from nodeName") { + val getWholeStageCodegenId = PrivateMethod[Option[Long]]('getWholeStageCodegenId) + val wholeStageCodegenId = + sqlResource invokePrivate getWholeStageCodegenId(WHOLE_STAGE_CODEGEN_1) + assert(wholeStageCodegenId == Some(1)) + } + +} diff --git 
a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index a01d5a44da714..b68563956c82c 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc -import java.time.LocalDate +import java.time.{Instant, LocalDate} import org.apache.orc.storage.common.`type`.HiveDecimal import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} @@ -26,7 +26,7 @@ import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.orc.storage.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateToDays, toJavaDate} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -167,6 +167,8 @@ private[sql] object OrcFilters extends OrcFiltersBase { new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal])) case _: DateType if value.isInstanceOf[LocalDate] => toJavaDate(localDateToDays(value.asInstanceOf[LocalDate])) + case _: TimestampType if value.isInstanceOf[Instant] => + toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant])) case _ => value } diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index a1c325e7bb876..88b4b243b543a 100644 --- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -245,29 +245,41 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } test("filter pushdown - timestamp") { - val timeString = "2015-08-20 14:57:00" - val timestamps = (1 to 4).map { i => - val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 - new Timestamp(milliseconds) - } - withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + val input = Seq( + "1000-01-01 01:02:03", + "1582-10-01 00:11:22", + "1900-01-01 23:59:59", + "2020-05-25 10:11:12").map(Timestamp.valueOf) - checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) - - checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) - - checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(timestamps(0)) <=> $"_1", - PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(timestamps(2)) < $"_1", 
PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(0)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + withOrcFile(input.map(Tuple1(_))) { path => + Seq(false, true).foreach { java8Api => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) { + readFile(path) { implicit df => + val timestamps = input.map(Literal(_)) + checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(timestamps(2)) < $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) >= $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + } + } + } } } diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 445a52cece1c3..4b642080d25ad 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc -import java.time.LocalDate +import java.time.{Instant, LocalDate} import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateToDays, toJavaDate} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -167,6 +167,8 @@ private[sql] object OrcFilters extends OrcFiltersBase { new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal])) case _: DateType if value.isInstanceOf[LocalDate] => toJavaDate(localDateToDays(value.asInstanceOf[LocalDate])) + case _: TimestampType if value.isInstanceOf[Instant] => + toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant])) case _ => value } diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 
815af05beb002..2263179515a5f 100644 --- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -246,29 +246,41 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } test("filter pushdown - timestamp") { - val timeString = "2015-08-20 14:57:00" - val timestamps = (1 to 4).map { i => - val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 - new Timestamp(milliseconds) - } - withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) - - checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val input = Seq( + "1000-01-01 01:02:03", + "1582-10-01 00:11:22", + "1900-01-01 23:59:59", + "2020-05-25 10:11:12").map(Timestamp.valueOf) - checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + withOrcFile(input.map(Tuple1(_))) { path => + Seq(false, true).foreach { java8Api => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) { + readFile(path) { implicit df => + val timestamps = input.map(Literal(_)) + checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate( - Literal(timestamps(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(timestamps(2)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(0)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(timestamps(2)) < $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) >= $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + } + } + } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 
d14d70f7d3d83..b193c73563ae0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -44,12 +44,13 @@ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.{Utils => SparkUtils} private[hive] class SparkExecuteStatementOperation( + val sqlContext: SQLContext, parentSession: HiveSession, statement: String, confOverlay: JMap[String, String], runInBackground: Boolean = true) - (sqlContext: SQLContext, sessionToActivePool: JMap[SessionHandle, String]) extends ExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground) + with SparkOperation with Logging { private var result: DataFrame = _ @@ -62,7 +63,6 @@ private[hive] class SparkExecuteStatementOperation( private var previousFetchStartOffset: Long = 0 private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ - private var statementId: String = _ private lazy val resultSchema: TableSchema = { if (result == null || result.schema.isEmpty) { @@ -73,13 +73,6 @@ private[hive] class SparkExecuteStatementOperation( } } - override def close(): Unit = { - // RDDs will be cleaned automatically upon garbage collection. - logInfo(s"Close statement with $statementId") - cleanup(OperationState.CLOSED) - HiveThriftServer2.eventManager.onOperationClosed(statementId) - } - def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int): Unit = { dataTypes(ordinal) match { case StringType => @@ -100,12 +93,15 @@ private[hive] class SparkExecuteStatementOperation( to += from.getByte(ordinal) case ShortType => to += from.getShort(ordinal) - case DateType => - to += from.getAs[Date](ordinal) - case TimestampType => - to += from.getAs[Timestamp](ordinal) case BinaryType => to += from.getAs[Array[Byte]](ordinal) + // SPARK-31859, SPARK-31861: Date and Timestamp need to be turned to String here to: + // - respect spark.sql.session.timeZone + // - work with spark.sql.datetime.java8API.enabled + // These types have always been sent over the wire as string, converted later. 
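          // A sketch of the behaviour (illustrative comment, not lines from the patch):
          //   HiveResult.toHiveString((Timestamp.valueOf("2020-05-25 10:11:12"), TimestampType))
          // is expected to render the value in the session time zone, e.g. "2020-05-25 10:11:12",
          // and java.time.Instant / java.time.LocalDate values are formatted the same way when
          // spark.sql.datetime.java8API.enabled is set.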
+ case _: DateType | _: TimestampType => + val hiveString = HiveResult.toHiveString((from.get(ordinal), dataTypes(ordinal))) + to += hiveString case CalendarIntervalType => to += HiveResult.toHiveString((from.getAs[CalendarInterval](ordinal), CalendarIntervalType)) case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] => @@ -114,7 +110,7 @@ private[hive] class SparkExecuteStatementOperation( } } - def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = withSchedulerPool { + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = withLocalProperties { log.info(s"Received getNextRowSet request order=${order} and maxRowsL=${maxRowsL} " + s"with ${statementId}") validateDefaultFetchOrientation(order) @@ -193,7 +189,6 @@ private[hive] class SparkExecuteStatementOperation( override def runInternal(): Unit = { setState(OperationState.PENDING) - statementId = UUID.randomUUID().toString logInfo(s"Submitting query '$statement' with $statementId") HiveThriftServer2.eventManager.onStatementStart( statementId, @@ -217,7 +212,9 @@ private[hive] class SparkExecuteStatementOperation( override def run(): Unit = { registerCurrentOperationLog() try { - execute() + withLocalProperties { + execute() + } } catch { case e: HiveSQLException => setOperationException(e) @@ -259,7 +256,7 @@ private[hive] class SparkExecuteStatementOperation( } } - private def execute(): Unit = withSchedulerPool { + private def execute(): Unit = { try { synchronized { if (getStatus.getState.isTerminal) { @@ -282,13 +279,6 @@ private[hive] class SparkExecuteStatementOperation( sqlContext.sparkContext.setJobGroup(statementId, statement) result = sqlContext.sql(statement) logDebug(result.queryExecution.toString()) - result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL.key, Some(value)))) => - sessionToActivePool.put(parentSession.getSessionHandle, value) - logInfo(s"Setting ${SparkContext.SPARK_SCHEDULER_POOL}=$value for future statements " + - "in this session.") - case _ => - } HiveThriftServer2.eventManager.onStatementParsed(statementId, result.queryExecution.toString()) iter = { @@ -346,38 +336,25 @@ private[hive] class SparkExecuteStatementOperation( synchronized { if (!getStatus.getState.isTerminal) { logInfo(s"Cancel query with $statementId") - cleanup(OperationState.CANCELED) + cleanup() + setState(OperationState.CANCELED) HiveThriftServer2.eventManager.onStatementCanceled(statementId) } } } - private def cleanup(state: OperationState): Unit = { - setState(state) + override protected def cleanup(): Unit = { if (runInBackground) { val backgroundHandle = getBackgroundHandle() if (backgroundHandle != null) { backgroundHandle.cancel(true) } } + // RDDs will be cleaned automatically upon garbage collection. 
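    // execute() sets the job group to statementId (sparkContext.setJobGroup), so cancelling that
    // group below stops any Spark jobs still running for this statement on cancel or close.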
if (statementId != null) { sqlContext.sparkContext.cancelJobGroup(statementId) } } - - private def withSchedulerPool[T](body: => T): T = { - val pool = sessionToActivePool.get(parentSession.getSessionHandle) - if (pool != null) { - sqlContext.sparkContext.setLocalProperty(SparkContext.SPARK_SCHEDULER_POOL, pool) - } - try { - body - } finally { - if (pool != null) { - sqlContext.sparkContext.setLocalProperty(SparkContext.SPARK_SCHEDULER_POOL, null) - } - } - } } object SparkExecuteStatementOperation { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala index 2945cfd200e46..55070e035b944 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala @@ -36,19 +36,13 @@ import org.apache.spark.util.{Utils => SparkUtils} * @param parentSession a HiveSession from SessionManager */ private[hive] class SparkGetCatalogsOperation( - sqlContext: SQLContext, + val sqlContext: SQLContext, parentSession: HiveSession) - extends GetCatalogsOperation(parentSession) with Logging { - - private var statementId: String = _ - - override def close(): Unit = { - super.close() - HiveThriftServer2.eventManager.onOperationClosed(statementId) - } + extends GetCatalogsOperation(parentSession) + with SparkOperation + with Logging { override def runInternal(): Unit = { - statementId = UUID.randomUUID().toString val logMsg = "Listing catalogs" logInfo(s"$logMsg with $statementId") setState(OperationState.RUNNING) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala index ff7cbfeae13be..ca8ad5e6ad134 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala @@ -48,26 +48,19 @@ import org.apache.spark.util.{Utils => SparkUtils} * @param columnName column name */ private[hive] class SparkGetColumnsOperation( - sqlContext: SQLContext, + val sqlContext: SQLContext, parentSession: HiveSession, catalogName: String, schemaName: String, tableName: String, columnName: String) extends GetColumnsOperation(parentSession, catalogName, schemaName, tableName, columnName) - with Logging { + with SparkOperation + with Logging { val catalog: SessionCatalog = sqlContext.sessionState.catalog - private var statementId: String = _ - - override def close(): Unit = { - super.close() - HiveThriftServer2.eventManager.onOperationClosed(statementId) - } - override def runInternal(): Unit = { - statementId = UUID.randomUUID().toString // Do not change cmdStr. It's used for Hive auditing and authorization. 
val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName, tablePattern : $tableName" val logMsg = s"Listing columns '$cmdStr, columnName : $columnName'" diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala index d9c12b6ca9e64..f5e647bfd4f38 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala @@ -43,22 +43,16 @@ import org.apache.spark.util.{Utils => SparkUtils} * @param functionName function name pattern */ private[hive] class SparkGetFunctionsOperation( - sqlContext: SQLContext, + val sqlContext: SQLContext, parentSession: HiveSession, catalogName: String, schemaName: String, functionName: String) - extends GetFunctionsOperation(parentSession, catalogName, schemaName, functionName) with Logging { - - private var statementId: String = _ - - override def close(): Unit = { - super.close() - HiveThriftServer2.eventManager.onOperationClosed(statementId) - } + extends GetFunctionsOperation(parentSession, catalogName, schemaName, functionName) + with SparkOperation + with Logging { override def runInternal(): Unit = { - statementId = UUID.randomUUID().toString // Do not change cmdStr. It's used for Hive auditing and authorization. val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName" val logMsg = s"Listing functions '$cmdStr, functionName : $functionName'" diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala index db19880d1b99f..74220986fcd34 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala @@ -40,21 +40,15 @@ import org.apache.spark.util.{Utils => SparkUtils} * @param schemaName database name, null or a concrete database name */ private[hive] class SparkGetSchemasOperation( - sqlContext: SQLContext, + val sqlContext: SQLContext, parentSession: HiveSession, catalogName: String, schemaName: String) - extends GetSchemasOperation(parentSession, catalogName, schemaName) with Logging { - - private var statementId: String = _ - - override def close(): Unit = { - super.close() - HiveThriftServer2.eventManager.onOperationClosed(statementId) - } + extends GetSchemasOperation(parentSession, catalogName, schemaName) + with SparkOperation + with Logging { override def runInternal(): Unit = { - statementId = UUID.randomUUID().toString // Do not change cmdStr. It's used for Hive auditing and authorization. 
val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName" val logMsg = s"Listing databases '$cmdStr'" diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala index b4093e58d3c07..1cf9c3a731af5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala @@ -37,16 +37,11 @@ import org.apache.spark.util.{Utils => SparkUtils} * @param parentSession a HiveSession from SessionManager */ private[hive] class SparkGetTableTypesOperation( - sqlContext: SQLContext, + val sqlContext: SQLContext, parentSession: HiveSession) - extends GetTableTypesOperation(parentSession) with SparkMetadataOperationUtils with Logging { - - private var statementId: String = _ - - override def close(): Unit = { - super.close() - HiveThriftServer2.eventManager.onOperationClosed(statementId) - } + extends GetTableTypesOperation(parentSession) + with SparkOperation + with Logging { override def runInternal(): Unit = { statementId = UUID.randomUUID().toString diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala index 45c6d980aac47..a1d21e2d60c63 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala @@ -46,24 +46,17 @@ import org.apache.spark.util.{Utils => SparkUtils} * @param tableTypes list of allowed table types, e.g. "TABLE", "VIEW" */ private[hive] class SparkGetTablesOperation( - sqlContext: SQLContext, + val sqlContext: SQLContext, parentSession: HiveSession, catalogName: String, schemaName: String, tableName: String, tableTypes: JList[String]) extends GetTablesOperation(parentSession, catalogName, schemaName, tableName, tableTypes) - with SparkMetadataOperationUtils with Logging { - - private var statementId: String = _ - - override def close(): Unit = { - super.close() - HiveThriftServer2.eventManager.onOperationClosed(statementId) - } + with SparkOperation + with Logging { override def runInternal(): Unit = { - statementId = UUID.randomUUID().toString // Do not change cmdStr. It's used for Hive auditing and authorization. 
val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName" val tableTypesStr = if (tableTypes == null) "null" else tableTypes.asScala.mkString(",") diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala index dd5668a93f82d..e38139d60df60 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala @@ -36,16 +36,11 @@ import org.apache.spark.util.{Utils => SparkUtils} * @param parentSession a HiveSession from SessionManager */ private[hive] class SparkGetTypeInfoOperation( - sqlContext: SQLContext, + val sqlContext: SQLContext, parentSession: HiveSession) - extends GetTypeInfoOperation(parentSession) with Logging { - - private var statementId: String = _ - - override def close(): Unit = { - super.close() - HiveThriftServer2.eventManager.onOperationClosed(statementId) - } + extends GetTypeInfoOperation(parentSession) + with SparkOperation + with Logging { override def runInternal(): Unit = { statementId = UUID.randomUUID().toString diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala new file mode 100644 index 0000000000000..3da568cfa256e --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.hive.service.cli.OperationState +import org.apache.hive.service.cli.operation.Operation + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.sql.catalyst.catalog.CatalogTableType.{EXTERNAL, MANAGED, VIEW} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * Utils for Spark operations. 
+ */ +private[hive] trait SparkOperation extends Operation with Logging { + + protected def sqlContext: SQLContext + + protected var statementId = getHandle().getHandleIdentifier().getPublicId().toString() + + protected def cleanup(): Unit = Unit // noop by default + + abstract override def run(): Unit = { + withLocalProperties { + super.run() + } + } + + abstract override def close(): Unit = { + cleanup() + super.close() + logInfo(s"Close statement with $statementId") + HiveThriftServer2.eventManager.onOperationClosed(statementId) + } + + // Set thread local properties for the execution of the operation. + // This method should be applied during the execution of the operation, by all the child threads. + // The original spark context local properties will be restored after the operation. + // + // It is used to: + // - set appropriate SparkSession + // - set scheduler pool for the operation + def withLocalProperties[T](f: => T): T = { + val originalProps = Utils.cloneProperties(sqlContext.sparkContext.getLocalProperties) + val originalSession = SparkSession.getActiveSession + + try { + // Set active SparkSession + SparkSession.setActiveSession(sqlContext.sparkSession) + + // Set scheduler pool + sqlContext.sparkSession.conf.getOption(SQLConf.THRIFTSERVER_POOL.key) match { + case Some(pool) => + sqlContext.sparkContext.setLocalProperty(SparkContext.SPARK_SCHEDULER_POOL, pool) + case None => + } + + // run the body + f + } finally { + // reset local properties, will also reset SPARK_SCHEDULER_POOL + sqlContext.sparkContext.setLocalProperties(originalProps) + + originalSession match { + case Some(session) => SparkSession.setActiveSession(session) + case None => SparkSession.clearActiveSession() + } + } + } + + def tableTypeString(tableType: CatalogTableType): String = tableType match { + case EXTERNAL | MANAGED => "TABLE" + case VIEW => "VIEW" + case t => + throw new IllegalArgumentException(s"Unknown table type is found: $t") + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index bffa24c469601..c7848afd822d5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -380,10 +380,18 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { ret = rc.getResponseCode if (ret != 0) { - // For analysis exception, only the error is printed out to the console. - rc.getException() match { - case e : AnalysisException => - err.println(s"""Error in query: ${e.getMessage}""") + rc.getException match { + case e: AnalysisException => e.cause match { + case Some(_) if !sessionState.getIsSilent => + err.println( + s"""Error in query: ${e.getMessage} + |${org.apache.hadoop.util.StringUtils.stringifyException(e)} + """.stripMargin) + // For analysis exceptions in silent mode or simple ones that only related to the + // query itself, such as `NoSuchDatabaseException`, only the error is printed out + // to the console. 
+ case _ => err.println(s"""Error in query: ${e.getMessage}""") + } case _ => err.println(rc.getErrorMessage()) } driver.close() @@ -516,7 +524,6 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { var insideComment = false var escape = false var beginIndex = 0 - var endIndex = line.length val ret = new JArrayList[String] for (index <- 0 until line.length) { @@ -544,8 +551,6 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else if (hasNext && line.charAt(index + 1) == '-') { // ignore quotes and ; insideComment = true - // ignore eol - endIndex = index } } else if (line.charAt(index) == ';') { if (insideSingleQuote || insideDoubleQuote || insideComment) { @@ -555,8 +560,11 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { ret.add(line.substring(beginIndex, index)) beginIndex = index + 1 } - } else { - // nothing to do + } else if (line.charAt(index) == '\n') { + // with a new line the inline comment should end. + if (!escape) { + insideComment = false + } } // set the escape if (escape) { @@ -565,7 +573,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { escape = true } } - ret.add(line.substring(beginIndex, endIndex)) + ret.add(line.substring(beginIndex)) ret } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index b3171897141c2..e10e7ed1a2769 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -78,7 +78,6 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: val ctx = sparkSqlOperationManager.sessionToContexts.getOrDefault(sessionHandle, sqlContext) ctx.sparkSession.sessionState.catalog.getTempViewNames().foreach(ctx.uncacheTable) super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool.remove(sessionHandle) sparkSqlOperationManager.sessionToContexts.remove(sessionHandle) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 3396560f43502..bc9c13eb0d4f8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -38,7 +38,6 @@ private[thriftserver] class SparkSQLOperationManager() val handleToOperation = ReflectionUtils .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") - val sessionToActivePool = new ConcurrentHashMap[SessionHandle, String]() val sessionToContexts = new ConcurrentHashMap[SessionHandle, SQLContext]() override def newExecuteStatementOperation( @@ -51,8 +50,8 @@ private[thriftserver] class SparkSQLOperationManager() s" initialized or had already closed.") val conf = sqlContext.sessionState.conf val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) - val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, - runInBackground)(sqlContext, sessionToActivePool) + val operation = new SparkExecuteStatementOperation( + 
sqlContext, parentSession, statement, confOverlay, runInBackground) handleToOperation.put(operation.getHandle, operation) logDebug(s"Created Operation for $statement with session=$parentSession, " + s"runInBackground=$runInBackground") diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2Listener.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2Listener.scala index 6d0a506fa94dc..6b7e5ee611417 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2Listener.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2Listener.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hive.service.server.HiveServer2 import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Status.LIVE_ENTITY_UPDATE_PERIOD import org.apache.spark.scheduler._ import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.ExecutionState @@ -38,7 +39,7 @@ private[thriftserver] class HiveThriftServer2Listener( kvstore: ElementTrackingStore, sparkConf: SparkConf, server: Option[HiveServer2], - live: Boolean = true) extends SparkListener { + live: Boolean = true) extends SparkListener with Logging { private val sessionList = new ConcurrentHashMap[String, LiveSessionData]() private val executionList = new ConcurrentHashMap[String, LiveExecutionData]() @@ -131,60 +132,83 @@ private[thriftserver] class HiveThriftServer2Listener( updateLiveStore(session) } - private def onSessionClosed(e: SparkListenerThriftServerSessionClosed): Unit = { - val session = sessionList.get(e.sessionId) - session.finishTimestamp = e.finishTime - updateStoreWithTriggerEnabled(session) - sessionList.remove(e.sessionId) - } + private def onSessionClosed(e: SparkListenerThriftServerSessionClosed): Unit = + Option(sessionList.get(e.sessionId)) match { + case Some(sessionData) => + sessionData.finishTimestamp = e.finishTime + updateStoreWithTriggerEnabled(sessionData) + sessionList.remove(e.sessionId) + case None => logWarning(s"onSessionClosed called with unknown session id: ${e.sessionId}") + } private def onOperationStart(e: SparkListenerThriftServerOperationStart): Unit = { - val info = getOrCreateExecution( + val executionData = getOrCreateExecution( e.id, e.statement, e.sessionId, e.startTime, e.userName) - info.state = ExecutionState.STARTED - executionList.put(e.id, info) - sessionList.get(e.sessionId).totalExecution += 1 - executionList.get(e.id).groupId = e.groupId - updateLiveStore(executionList.get(e.id)) - updateLiveStore(sessionList.get(e.sessionId)) + executionData.state = ExecutionState.STARTED + executionList.put(e.id, executionData) + executionData.groupId = e.groupId + updateLiveStore(executionData) + + Option(sessionList.get(e.sessionId)) match { + case Some(sessionData) => + sessionData.totalExecution += 1 + updateLiveStore(sessionData) + case None => logWarning(s"onOperationStart called with unknown session id: ${e.sessionId}." 
+ + s"Regardless, the operation has been registered.") + } } - private def onOperationParsed(e: SparkListenerThriftServerOperationParsed): Unit = { - executionList.get(e.id).executePlan = e.executionPlan - executionList.get(e.id).state = ExecutionState.COMPILED - updateLiveStore(executionList.get(e.id)) - } + private def onOperationParsed(e: SparkListenerThriftServerOperationParsed): Unit = + Option(executionList.get(e.id)) match { + case Some(executionData) => + executionData.executePlan = e.executionPlan + executionData.state = ExecutionState.COMPILED + updateLiveStore(executionData) + case None => logWarning(s"onOperationParsed called with unknown operation id: ${e.id}") + } - private def onOperationCanceled(e: SparkListenerThriftServerOperationCanceled): Unit = { - executionList.get(e.id).finishTimestamp = e.finishTime - executionList.get(e.id).state = ExecutionState.CANCELED - updateLiveStore(executionList.get(e.id)) - } + private def onOperationCanceled(e: SparkListenerThriftServerOperationCanceled): Unit = + Option(executionList.get(e.id)) match { + case Some(executionData) => + executionData.finishTimestamp = e.finishTime + executionData.state = ExecutionState.CANCELED + updateLiveStore(executionData) + case None => logWarning(s"onOperationCanceled called with unknown operation id: ${e.id}") + } - private def onOperationError(e: SparkListenerThriftServerOperationError): Unit = { - executionList.get(e.id).finishTimestamp = e.finishTime - executionList.get(e.id).detail = e.errorMsg - executionList.get(e.id).state = ExecutionState.FAILED - updateLiveStore(executionList.get(e.id)) - } + private def onOperationError(e: SparkListenerThriftServerOperationError): Unit = + Option(executionList.get(e.id)) match { + case Some(executionData) => + executionData.finishTimestamp = e.finishTime + executionData.detail = e.errorMsg + executionData.state = ExecutionState.FAILED + updateLiveStore(executionData) + case None => logWarning(s"onOperationError called with unknown operation id: ${e.id}") + } - private def onOperationFinished(e: SparkListenerThriftServerOperationFinish): Unit = { - executionList.get(e.id).finishTimestamp = e.finishTime - executionList.get(e.id).state = ExecutionState.FINISHED - updateLiveStore(executionList.get(e.id)) - } + private def onOperationFinished(e: SparkListenerThriftServerOperationFinish): Unit = + Option(executionList.get(e.id)) match { + case Some(executionData) => + executionData.finishTimestamp = e.finishTime + executionData.state = ExecutionState.FINISHED + updateLiveStore(executionData) + case None => logWarning(s"onOperationFinished called with unknown operation id: ${e.id}") + } - private def onOperationClosed(e: SparkListenerThriftServerOperationClosed): Unit = { - executionList.get(e.id).closeTimestamp = e.closeTime - executionList.get(e.id).state = ExecutionState.CLOSED - updateStoreWithTriggerEnabled(executionList.get(e.id)) - executionList.remove(e.id) - } + private def onOperationClosed(e: SparkListenerThriftServerOperationClosed): Unit = + Option(executionList.get(e.id)) match { + case Some(executionData) => + executionData.closeTimestamp = e.closeTime + executionData.state = ExecutionState.CLOSED + updateStoreWithTriggerEnabled(executionData) + executionList.remove(e.id) + case None => logWarning(s"onOperationClosed called with unknown operation id: ${e.id}") + } // Update both live and history stores. Trigger is enabled by default, hence // it will cleanup the entity which exceeds the threshold. 
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 7fb755c292b38..8efbdb30c605c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -22,8 +22,7 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.Calendar import javax.servlet.http.HttpServletRequest -import scala.collection.JavaConverters._ -import scala.xml.{Node, Unparsed} +import scala.xml.Node import org.apache.commons.text.StringEscapeUtils @@ -78,26 +77,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val sqlTableTag = "sqlstat" - val parameterOtherTable = request.getParameterMap().asScala - .filterNot(_._1.startsWith(sqlTableTag)) - .map { case (name, vals) => - name + "=" + vals(0) - } - - val parameterSqlTablePage = request.getParameter(s"$sqlTableTag.page") - val parameterSqlTableSortColumn = request.getParameter(s"$sqlTableTag.sort") - val parameterSqlTableSortDesc = request.getParameter(s"$sqlTableTag.desc") - val parameterSqlPageSize = request.getParameter(s"$sqlTableTag.pageSize") - - val sqlTablePage = Option(parameterSqlTablePage).map(_.toInt).getOrElse(1) - val sqlTableSortColumn = Option(parameterSqlTableSortColumn).map { sortColumn => - UIUtils.decodeURLParameter(sortColumn) - }.getOrElse("Start Time") - val sqlTableSortDesc = Option(parameterSqlTableSortDesc).map(_.toBoolean).getOrElse( - // New executions should be shown above old executions by default. - sqlTableSortColumn == "Start Time" - ) - val sqlTablePageSize = Option(parameterSqlPageSize).map(_.toInt).getOrElse(100) + val sqlTablePage = + Option(request.getParameter(s"$sqlTableTag.page")).map(_.toInt).getOrElse(1) try { Some(new SqlStatsPagedTable( @@ -106,12 +87,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" store.getExecutionList, "sqlserver", UIUtils.prependBaseUri(request, parent.basePath), - parameterOtherTable, - sqlTableTag, - pageSize = sqlTablePageSize, - sortColumn = sqlTableSortColumn, - desc = sqlTableSortDesc - ).table(sqlTablePage)) + sqlTableTag).table(sqlTablePage)) } catch { case e@(_: IllegalArgumentException | _: IndexOutOfBoundsException) => Some(
@@ -146,26 +122,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val sessionTableTag = "sessionstat" - val parameterOtherTable = request.getParameterMap().asScala - .filterNot(_._1.startsWith(sessionTableTag)) - .map { case (name, vals) => - name + "=" + vals(0) - } - - val parameterSessionTablePage = request.getParameter(s"$sessionTableTag.page") - val parameterSessionTableSortColumn = request.getParameter(s"$sessionTableTag.sort") - val parameterSessionTableSortDesc = request.getParameter(s"$sessionTableTag.desc") - val parameterSessionPageSize = request.getParameter(s"$sessionTableTag.pageSize") - - val sessionTablePage = Option(parameterSessionTablePage).map(_.toInt).getOrElse(1) - val sessionTableSortColumn = Option(parameterSessionTableSortColumn).map { sortColumn => - UIUtils.decodeURLParameter(sortColumn) - }.getOrElse("Start Time") - val sessionTableSortDesc = Option(parameterSessionTableSortDesc).map(_.toBoolean).getOrElse( - // New session should be shown above old session by default. - (sessionTableSortColumn == "Start Time") - ) - val sessionTablePageSize = Option(parameterSessionPageSize).map(_.toInt).getOrElse(100) + val sessionTablePage = + Option(request.getParameter(s"$sessionTableTag.page")).map(_.toInt).getOrElse(1) try { Some(new SessionStatsPagedTable( @@ -174,11 +132,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" store.getSessionList, "sqlserver", UIUtils.prependBaseUri(request, parent.basePath), - parameterOtherTable, - sessionTableTag, - pageSize = sessionTablePageSize, - sortColumn = sessionTableSortColumn, - desc = sessionTableSortDesc + sessionTableTag ).table(sessionTablePage)) } catch { case e@(_: IllegalArgumentException | _: IndexOutOfBoundsException) => @@ -216,104 +170,59 @@ private[ui] class SqlStatsPagedTable( data: Seq[ExecutionInfo], subPath: String, basePath: String, - parameterOtherTable: Iterable[String], - sqlStatsTableTag: String, - pageSize: Int, - sortColumn: String, - desc: Boolean) extends PagedTable[SqlStatsTableRow] { + sqlStatsTableTag: String) extends PagedTable[SqlStatsTableRow] { - override val dataSource = new SqlStatsTableDataSource(data, pageSize, sortColumn, desc) + private val (sortColumn, desc, pageSize) = + getTableParameters(request, sqlStatsTableTag, "Start Time") + + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + + private val parameterPath = + s"$basePath/$subPath/?${getParameterOtherTable(request, sqlStatsTableTag)}" - private val parameterPath = s"$basePath/$subPath/?${parameterOtherTable.mkString("&")}" + override val dataSource = new SqlStatsTableDataSource(data, pageSize, sortColumn, desc) override def tableId: String = sqlStatsTableTag override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$sqlStatsTableTag.sort=$encodedSortColumn" + s"&$sqlStatsTableTag.desc=$desc" + - s"&$pageSizeFormField=$pageSize" + s"&$pageSizeFormField=$pageSize" + + s"#$sqlStatsTableTag" } override def pageSizeFormField: String = s"$sqlStatsTableTag.pageSize" override def pageNumberFormField: String = s"$sqlStatsTableTag.page" - override def goButtonFormPath: String = { - val encodedSortColumn = 
URLEncoder.encode(sortColumn, UTF_8.name()) - s"$parameterPath&$sqlStatsTableTag.sort=$encodedSortColumn&$sqlStatsTableTag.desc=$desc" - } + override def goButtonFormPath: String = + s"$parameterPath&$sqlStatsTableTag.sort=$encodedSortColumn" + + s"&$sqlStatsTableTag.desc=$desc#$sqlStatsTableTag" override def headers: Seq[Node] = { - val sqlTableHeaders = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", - "Close Time", "Execution Time", "Duration", "Statement", "State", "Detail") - - val tooltips = Seq(None, None, None, None, Some(THRIFT_SERVER_FINISH_TIME), - Some(THRIFT_SERVER_CLOSE_TIME), Some(THRIFT_SERVER_EXECUTION), - Some(THRIFT_SERVER_DURATION), None, None, None) - - assert(sqlTableHeaders.length == tooltips.length) - - val headerRow: Seq[Node] = { - sqlTableHeaders.zip(tooltips).map { case (header, tooltip) => - if (header == sortColumn) { - val headerLink = Unparsed( - parameterPath + - s"&$sqlStatsTableTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&$sqlStatsTableTag.desc=${!desc}" + - s"&$sqlStatsTableTag.pageSize=$pageSize" + - s"#$sqlStatsTableTag") - val arrow = if (desc) "▾" else "▴" // UP or DOWN - - if (tooltip.nonEmpty) { - - - - {header} {Unparsed(arrow)} - - - - } else { - - - {header} {Unparsed(arrow)} - - - } - } else { - val headerLink = Unparsed( - parameterPath + - s"&$sqlStatsTableTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&$sqlStatsTableTag.pageSize=$pageSize" + - s"#$sqlStatsTableTag") - - if(tooltip.nonEmpty) { - - - - {header} - - - - } else { - - - {header} - - - } - } - } - } - - {headerRow} - + val sqlTableHeadersAndTooltips: Seq[(String, Boolean, Option[String])] = + Seq( + ("User", true, None), + ("JobID", true, None), + ("GroupID", true, None), + ("Start Time", true, None), + ("Finish Time", true, Some(THRIFT_SERVER_FINISH_TIME)), + ("Close Time", true, Some(THRIFT_SERVER_CLOSE_TIME)), + ("Execution Time", true, Some(THRIFT_SERVER_EXECUTION)), + ("Duration", true, Some(THRIFT_SERVER_DURATION)), + ("Statement", true, None), + ("State", true, None), + ("Detail", true, None)) + + isSortColumnValid(sqlTableHeadersAndTooltips, sortColumn) + + headerRow(sqlTableHeadersAndTooltips, desc, pageSize, sortColumn, parameterPath, + sqlStatsTableTag, sqlStatsTableTag) } override def row(sqlStatsTableRow: SqlStatsTableRow): Seq[Node] = { @@ -391,101 +300,55 @@ private[ui] class SessionStatsPagedTable( data: Seq[SessionInfo], subPath: String, basePath: String, - parameterOtherTable: Iterable[String], - sessionStatsTableTag: String, - pageSize: Int, - sortColumn: String, - desc: Boolean) extends PagedTable[SessionInfo] { + sessionStatsTableTag: String) extends PagedTable[SessionInfo] { - override val dataSource = new SessionStatsTableDataSource(data, pageSize, sortColumn, desc) + private val (sortColumn, desc, pageSize) = + getTableParameters(request, sessionStatsTableTag, "Start Time") + + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) - private val parameterPath = s"$basePath/$subPath/?${parameterOtherTable.mkString("&")}" + private val parameterPath = + s"$basePath/$subPath/?${getParameterOtherTable(request, sessionStatsTableTag)}" + + override val dataSource = new SessionStatsTableDataSource(data, pageSize, sortColumn, desc) override def tableId: String = sessionStatsTableTag override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable 
table-cell-width-limited" override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$sessionStatsTableTag.sort=$encodedSortColumn" + s"&$sessionStatsTableTag.desc=$desc" + - s"&$pageSizeFormField=$pageSize" + s"&$pageSizeFormField=$pageSize" + + s"#$sessionStatsTableTag" } override def pageSizeFormField: String = s"$sessionStatsTableTag.pageSize" override def pageNumberFormField: String = s"$sessionStatsTableTag.page" - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) - s"$parameterPath&$sessionStatsTableTag.sort=$encodedSortColumn&$sessionStatsTableTag.desc=$desc" - } + override def goButtonFormPath: String = + s"$parameterPath&$sessionStatsTableTag.sort=$encodedSortColumn" + + s"&$sessionStatsTableTag.desc=$desc#$sessionStatsTableTag" override def headers: Seq[Node] = { - val sessionTableHeaders = - Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") - - val tooltips = Seq(None, None, None, None, None, Some(THRIFT_SESSION_DURATION), - Some(THRIFT_SESSION_TOTAL_EXECUTE)) - assert(sessionTableHeaders.length == tooltips.length) - val colWidthAttr = s"${100.toDouble / sessionTableHeaders.size}%" - - val headerRow: Seq[Node] = { - sessionTableHeaders.zip(tooltips).map { case (header, tooltip) => - if (header == sortColumn) { - val headerLink = Unparsed( - parameterPath + - s"&$sessionStatsTableTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&$sessionStatsTableTag.desc=${!desc}" + - s"&$sessionStatsTableTag.pageSize=$pageSize" + - s"#$sessionStatsTableTag") - val arrow = if (desc) "▾" else "▴" // UP or DOWN - - - { - if (tooltip.nonEmpty) { - - {header} {Unparsed(arrow)} - - } else { - - {header} {Unparsed(arrow)} - - } - } - - - - } else { - val headerLink = Unparsed( - parameterPath + - s"&$sessionStatsTableTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + - s"&$sessionStatsTableTag.pageSize=$pageSize" + - s"#$sessionStatsTableTag") - - - - { - if (tooltip.nonEmpty) { - - {header} - - } else { - {header} - } - } - - - } - } - } - - {headerRow} - + val sessionTableHeadersAndTooltips: Seq[(String, Boolean, Option[String])] = + Seq( + ("User", true, None), + ("IP", true, None), + ("Session ID", true, None), + ("Start Time", true, None), + ("Finish Time", true, None), + ("Duration", true, Some(THRIFT_SESSION_DURATION)), + ("Total Execute", true, Some(THRIFT_SESSION_TOTAL_EXECUTE))) + + isSortColumnValid(sessionTableHeadersAndTooltips, sortColumn) + + headerRow(sessionTableHeadersAndTooltips, desc, pageSize, sortColumn, + parameterPath, sessionStatsTableTag, sessionStatsTableTag) } override def row(session: SessionInfo): Seq[Node] = { @@ -503,108 +366,94 @@ private[ui] class SessionStatsPagedTable( } } - private[ui] class SqlStatsTableRow( +private[ui] class SqlStatsTableRow( val jobId: Seq[String], val duration: Long, val executionTime: Long, val executionInfo: ExecutionInfo, val detail: String) - private[ui] class SqlStatsTableDataSource( +private[ui] class SqlStatsTableDataSource( info: Seq[ExecutionInfo], pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[SqlStatsTableRow](pageSize) { - // Convert ExecutionInfo to SqlStatsTableRow which contains the final contents to show in - // the table so that we can avoid creating duplicate contents during sorting the data - private val data = info.map(sqlStatsTableRow).sorted(ordering(sortColumn, desc)) + // 
Convert ExecutionInfo to SqlStatsTableRow which contains the final contents to show in + // the table so that we can avoid creating duplicate contents during sorting the data + private val data = info.map(sqlStatsTableRow).sorted(ordering(sortColumn, desc)) - private var _slicedStartTime: Set[Long] = null + override def dataSize: Int = data.size - override def dataSize: Int = data.size + override def sliceData(from: Int, to: Int): Seq[SqlStatsTableRow] = data.slice(from, to) - override def sliceData(from: Int, to: Int): Seq[SqlStatsTableRow] = { - val r = data.slice(from, to) - _slicedStartTime = r.map(_.executionInfo.startTimestamp).toSet - r - } + private def sqlStatsTableRow(executionInfo: ExecutionInfo): SqlStatsTableRow = { + val duration = executionInfo.totalTime(executionInfo.closeTimestamp) + val executionTime = executionInfo.totalTime(executionInfo.finishTimestamp) + val detail = Option(executionInfo.detail).filter(!_.isEmpty) + .getOrElse(executionInfo.executePlan) + val jobId = executionInfo.jobId.toSeq.sorted - private def sqlStatsTableRow(executionInfo: ExecutionInfo): SqlStatsTableRow = { - val duration = executionInfo.totalTime(executionInfo.closeTimestamp) - val executionTime = executionInfo.totalTime(executionInfo.finishTimestamp) - val detail = Option(executionInfo.detail).filter(!_.isEmpty) - .getOrElse(executionInfo.executePlan) - val jobId = executionInfo.jobId.toSeq.sorted - - new SqlStatsTableRow(jobId, duration, executionTime, executionInfo, detail) + new SqlStatsTableRow(jobId, duration, executionTime, executionInfo, detail) + } + /** + * Return Ordering according to sortColumn and desc. + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[SqlStatsTableRow] = { + val ordering: Ordering[SqlStatsTableRow] = sortColumn match { + case "User" => Ordering.by(_.executionInfo.userName) + case "JobID" => Ordering by (_.jobId.headOption) + case "GroupID" => Ordering.by(_.executionInfo.groupId) + case "Start Time" => Ordering.by(_.executionInfo.startTimestamp) + case "Finish Time" => Ordering.by(_.executionInfo.finishTimestamp) + case "Close Time" => Ordering.by(_.executionInfo.closeTimestamp) + case "Execution Time" => Ordering.by(_.executionTime) + case "Duration" => Ordering.by(_.duration) + case "Statement" => Ordering.by(_.executionInfo.statement) + case "State" => Ordering.by(_.executionInfo.state) + case "Detail" => Ordering.by(_.detail) + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") } - - /** - * Return Ordering according to sortColumn and desc. 
- */ - private def ordering(sortColumn: String, desc: Boolean): Ordering[SqlStatsTableRow] = { - val ordering: Ordering[SqlStatsTableRow] = sortColumn match { - case "User" => Ordering.by(_.executionInfo.userName) - case "JobID" => Ordering by (_.jobId.headOption) - case "GroupID" => Ordering.by(_.executionInfo.groupId) - case "Start Time" => Ordering.by(_.executionInfo.startTimestamp) - case "Finish Time" => Ordering.by(_.executionInfo.finishTimestamp) - case "Close Time" => Ordering.by(_.executionInfo.closeTimestamp) - case "Execution Time" => Ordering.by(_.executionTime) - case "Duration" => Ordering.by(_.duration) - case "Statement" => Ordering.by(_.executionInfo.statement) - case "State" => Ordering.by(_.executionInfo.state) - case "Detail" => Ordering.by(_.detail) - case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") - } - if (desc) { - ordering.reverse - } else { - ordering - } + if (desc) { + ordering.reverse + } else { + ordering } - } +} - private[ui] class SessionStatsTableDataSource( +private[ui] class SessionStatsTableDataSource( info: Seq[SessionInfo], pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[SessionInfo](pageSize) { - // Sorting SessionInfo data - private val data = info.sorted(ordering(sortColumn, desc)) - - private var _slicedStartTime: Set[Long] = null - - override def dataSize: Int = data.size - - override def sliceData(from: Int, to: Int): Seq[SessionInfo] = { - val r = data.slice(from, to) - _slicedStartTime = r.map(_.startTimestamp).toSet - r + // Sorting SessionInfo data + private val data = info.sorted(ordering(sortColumn, desc)) + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[SessionInfo] = data.slice(from, to) + + /** + * Return Ordering according to sortColumn and desc. + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[SessionInfo] = { + val ordering: Ordering[SessionInfo] = sortColumn match { + case "User" => Ordering.by(_.userName) + case "IP" => Ordering.by(_.ip) + case "Session ID" => Ordering.by(_.sessionId) + case "Start Time" => Ordering by (_.startTimestamp) + case "Finish Time" => Ordering.by(_.finishTimestamp) + case "Duration" => Ordering.by(_.totalTime) + case "Total Execute" => Ordering.by(_.totalExecution) + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") } - - /** - * Return Ordering according to sortColumn and desc. 
- */ - private def ordering(sortColumn: String, desc: Boolean): Ordering[SessionInfo] = { - val ordering: Ordering[SessionInfo] = sortColumn match { - case "User" => Ordering.by(_.userName) - case "IP" => Ordering.by(_.ip) - case "Session ID" => Ordering.by(_.sessionId) - case "Start Time" => Ordering by (_.startTimestamp) - case "Finish Time" => Ordering.by(_.finishTimestamp) - case "Duration" => Ordering.by(_.totalTime) - case "Total Execute" => Ordering.by(_.totalExecution) - case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") - } - if (desc) { - ordering.reverse - } else { - ordering - } + if (desc) { + ordering.reverse + } else { + ordering } } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 2d7adf552738c..87165cc8cac45 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.thriftserver.ui import javax.servlet.http.HttpServletRequest -import scala.collection.JavaConverters._ import scala.xml.Node import org.apache.spark.internal.Logging @@ -77,26 +76,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) val sqlTableTag = "sqlsessionstat" - val parameterOtherTable = request.getParameterMap().asScala - .filterNot(_._1.startsWith(sqlTableTag)) - .map { case (name, vals) => - name + "=" + vals(0) - } - - val parameterSqlTablePage = request.getParameter(s"$sqlTableTag.page") - val parameterSqlTableSortColumn = request.getParameter(s"$sqlTableTag.sort") - val parameterSqlTableSortDesc = request.getParameter(s"$sqlTableTag.desc") - val parameterSqlPageSize = request.getParameter(s"$sqlTableTag.pageSize") - - val sqlTablePage = Option(parameterSqlTablePage).map(_.toInt).getOrElse(1) - val sqlTableSortColumn = Option(parameterSqlTableSortColumn).map { sortColumn => - UIUtils.decodeURLParameter(sortColumn) - }.getOrElse("Start Time") - val sqlTableSortDesc = Option(parameterSqlTableSortDesc).map(_.toBoolean).getOrElse( - // New executions should be shown above old executions by default. 
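ThriftServerPage and ThriftServerSessionPage now parse only the `<tableTag>.page` parameter themselves and let the paged tables obtain sort column, sort direction and page size through the shared `getTableParameters` helper. The following is a minimal sketch, reconstructed from the per-page parsing removed above, of what that helper is assumed to do; the real helper on `PagedTable` may differ in name, placement and signature:

    import javax.servlet.http.HttpServletRequest
    import org.apache.spark.ui.UIUtils

    // Hypothetical stand-in for the shared helper; the defaults (descending on the default
    // column, page size 100) are taken from the parameter handling removed in this patch.
    def getTableParametersSketch(
        request: HttpServletRequest,
        tableTag: String,
        defaultSortColumn: String): (String, Boolean, Int) = {
      val sortColumn = Option(request.getParameter(s"$tableTag.sort"))
        .map(UIUtils.decodeURLParameter)
        .getOrElse(defaultSortColumn)
      // New executions/sessions are shown first, so sort descending by default
      // only when sorting on the default "Start Time" column.
      val desc = Option(request.getParameter(s"$tableTag.desc"))
        .map(_.toBoolean)
        .getOrElse(sortColumn == defaultSortColumn)
      val pageSize = Option(request.getParameter(s"$tableTag.pageSize"))
        .map(_.toInt)
        .getOrElse(100)
      (sortColumn, desc, pageSize)
    }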
- sqlTableSortColumn == "Start Time" - ) - val sqlTablePageSize = Option(parameterSqlPageSize).map(_.toInt).getOrElse(100) + val sqlTablePage = + Option(request.getParameter(s"$sqlTableTag.page")).map(_.toInt).getOrElse(1) try { Some(new SqlStatsPagedTable( @@ -105,11 +86,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) executionList, "sqlserver/session", UIUtils.prependBaseUri(request, parent.basePath), - parameterOtherTable, - sqlTableTag, - pageSize = sqlTablePageSize, - sortColumn = sqlTableSortColumn, - desc = sqlTableSortDesc + sqlTableTag ).table(sqlTablePage)) } catch { case e@(_: IllegalArgumentException | _: IndexOutOfBoundsException) => diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 265e7772a691c..ea1a371151c36 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} /** * A test suite for the `spark-sql` CLI tool. */ -class CliSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach with Logging { +class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() val scratchDirPath = Utils.createTempDir() @@ -62,12 +62,6 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterE } } - override def afterEach(): Unit = { - // Only running `runCliWithin` in a single test case will share the same temporary - // Hive metastore - Utils.deleteRecursively(metastorePath) - } - /** * Run a CLI operation and expect all the queries and expected answers to be returned. * @@ -77,6 +71,12 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterE * is taken as an immediate error condition. That is: if a line containing * with one of these strings is found, fail the test immediately. * The default value is `Seq("Error:")` + * @param maybeWarehouse an option for warehouse path, which will be set via + * `hive.metastore.warehouse.dir`. + * @param useExternalHiveFile whether to load the hive-site.xml from `src/test/noclasspath` or + * not, disabled by default + * @param metastore which path the embedded derby database for metastore locates. Use the the + * global `metastorePath` by default * @param queriesAndExpectedAnswers one or more tuples of query + answer */ def runCliWithin( @@ -84,7 +84,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterE extraArgs: Seq[String] = Seq.empty, errorResponses: Seq[String] = Seq("Error:"), maybeWarehouse: Option[File] = Some(warehousePath), - useExternalHiveFile: Boolean = false)( + useExternalHiveFile: Boolean = false, + metastore: File = metastorePath)( queriesAndExpectedAnswers: (String, String)*): Unit = { // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. 
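The new `metastore` parameter lets a test point the embedded Derby metastore at its own temporary directory instead of the suite-wide `metastorePath`, which is what the warehouse-dir tests below do. A hedged usage sketch, assuming the CliSuite context above (`Utils`, the duration DSL and `runCliWithin` in scope); the query/answer pair is only illustrative:

    val isolatedMetastore = Utils.createTempDir()
    // Derby creates the database directory itself, so drop the freshly created temp dir first.
    isolatedMetastore.delete()
    try {
      runCliWithin(1.minute, metastore = isolatedMetastore)(
        "show databases;" -> "default")  // illustrative query and expected answer
    } finally {
      Utils.deleteRecursively(isolatedMetastore)
    }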
@@ -116,7 +117,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterE maybeWarehouse.map(dir => s"--hiveconf ${ConfVars.METASTOREWAREHOUSE}=$dir").getOrElse("") val command = { val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) - val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" + val jdbcUrl = s"jdbc:derby:;databaseName=$metastore;create=true" s"""$cliScript | --master local | --driver-java-options -Dderby.system.durability=test @@ -202,9 +203,18 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterE } test("load warehouse dir from hive-site.xml") { - runCliWithin(1.minute, maybeWarehouse = None, useExternalHiveFile = true)( - "desc database default;" -> "hive_one", - "set spark.sql.warehouse.dir;" -> "hive_one") + val metastore = Utils.createTempDir() + metastore.delete() + try { + runCliWithin(1.minute, + maybeWarehouse = None, + useExternalHiveFile = true, + metastore = metastore)( + "desc database default;" -> "hive_one", + "set spark.sql.warehouse.dir;" -> "hive_one") + } finally { + Utils.deleteRecursively(metastore) + } } test("load warehouse dir from --hiveconf") { @@ -218,35 +228,47 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterE test("load warehouse dir from --conf spark(.hadoop).hive.*") { // override conf from hive-site.xml - runCliWithin( - 2.minute, - extraArgs = Seq("--conf", s"spark.hadoop.${ConfVars.METASTOREWAREHOUSE}=$sparkWareHouseDir"), - maybeWarehouse = None, - useExternalHiveFile = true)( - "desc database default;" -> sparkWareHouseDir.getAbsolutePath, - "create database cliTestDb;" -> "", - "desc database cliTestDb;" -> sparkWareHouseDir.getAbsolutePath, - "set spark.sql.warehouse.dir;" -> sparkWareHouseDir.getAbsolutePath) - - // override conf from --hiveconf too - runCliWithin( - 2.minute, - extraArgs = Seq("--conf", s"spark.${ConfVars.METASTOREWAREHOUSE}=$sparkWareHouseDir"))( - "desc database default;" -> sparkWareHouseDir.getAbsolutePath, - "create database cliTestDb;" -> "", - "desc database cliTestDb;" -> sparkWareHouseDir.getAbsolutePath, - "set spark.sql.warehouse.dir;" -> sparkWareHouseDir.getAbsolutePath) + val metastore = Utils.createTempDir() + metastore.delete() + try { + runCliWithin(2.minute, + extraArgs = + Seq("--conf", s"spark.hadoop.${ConfVars.METASTOREWAREHOUSE}=$sparkWareHouseDir"), + maybeWarehouse = None, + useExternalHiveFile = true, + metastore = metastore)( + "desc database default;" -> sparkWareHouseDir.getAbsolutePath, + "create database cliTestDb;" -> "", + "desc database cliTestDb;" -> sparkWareHouseDir.getAbsolutePath, + "set spark.sql.warehouse.dir;" -> sparkWareHouseDir.getAbsolutePath) + + // override conf from --hiveconf too + runCliWithin(2.minute, + extraArgs = Seq("--conf", s"spark.${ConfVars.METASTOREWAREHOUSE}=$sparkWareHouseDir"), + metastore = metastore)( + "desc database default;" -> sparkWareHouseDir.getAbsolutePath, + "create database cliTestDb;" -> "", + "desc database cliTestDb;" -> sparkWareHouseDir.getAbsolutePath, + "set spark.sql.warehouse.dir;" -> sparkWareHouseDir.getAbsolutePath) + } finally { + Utils.deleteRecursively(metastore) + } } test("load warehouse dir from spark.sql.warehouse.dir") { // spark.sql.warehouse.dir overrides all hive ones - runCliWithin( - 2.minute, - extraArgs = - Seq("--conf", - s"${StaticSQLConf.WAREHOUSE_PATH.key}=${sparkWareHouseDir}1", - "--conf", s"spark.hadoop.${ConfVars.METASTOREWAREHOUSE}=${sparkWareHouseDir}2"))( - "desc database default;" -> 
sparkWareHouseDir.getAbsolutePath.concat("1")) + val metastore = Utils.createTempDir() + metastore.delete() + try { + runCliWithin(2.minute, + extraArgs = Seq( + "--conf", s"${StaticSQLConf.WAREHOUSE_PATH.key}=${sparkWareHouseDir}1", + "--conf", s"spark.hadoop.${ConfVars.METASTOREWAREHOUSE}=${sparkWareHouseDir}2"), + metastore = metastore)( + "desc database default;" -> sparkWareHouseDir.getAbsolutePath.concat("1")) + } finally { + Utils.deleteRecursively(metastore) + } } test("Simple commands") { @@ -486,18 +508,20 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterE ) } - test("SPARK-30049 Should not complain for quotes in commented with multi-lines") { + test("SPARK-31102 spark-sql fails to parse when contains comment") { runCliWithin(1.minute)( - """SELECT concat('test', 'comment') -- someone's comment here \\ - | comment continues here with single ' quote \\ - | extra ' \\ - |;""".stripMargin -> "testcomment" + """SELECT concat('test', 'comment'), + | -- someone's comment here + | 2;""".stripMargin -> "testcomment" ) + } + + test("SPARK-30049 Should not complain for quotes in commented with multi-lines") { runCliWithin(1.minute)( - """SELECT concat('test', 'comment') -- someone's comment here \\ - | comment continues here with single ' quote \\ - | extra ' \\ - | ;""".stripMargin -> "testcomment" + """SELECT concat('test', 'comment') -- someone's comment here \ + | comment continues here with single ' quote \ + | extra ' \ + |;""".stripMargin -> "testcomment" ) } @@ -509,4 +533,22 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterE "SELECT \"legal 'string b\";select 22222 + 1;".stripMargin -> "22223" ) } + + test("AnalysisException with root cause will be printStacktrace") { + // If it is not in silent mode, will print the stacktrace + runCliWithin( + 1.minute, + extraArgs = Seq("--hiveconf", "hive.session.silent=false", + "-e", "select date_sub(date'2011-11-11', '1.2');"), + errorResponses = Seq("NumberFormatException"))( + ("", "Error in query: The second argument of 'date_sub' function needs to be an integer."), + ("", "NumberFormatException: invalid input syntax for type numeric: 1.2")) + // If it is in silent mode, will print the error message only + runCliWithin( + 1.minute, + extraArgs = Seq("--conf", "spark.hive.session.silent=true", + "-e", "select date_sub(date'2011-11-11', '1.2');"), + errorResponses = Seq("AnalysisException"))( + ("", "Error in query: The second argument of 'date_sub' function needs to be an integer.")) + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveSessionImplSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveSessionImplSuite.scala new file mode 100644 index 0000000000000..05d540d782e31 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveSessionImplSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hive.service.cli.OperationHandle +import org.apache.hive.service.cli.operation.{GetCatalogsOperation, OperationManager} +import org.apache.hive.service.cli.session.{HiveSessionImpl, SessionManager} +import org.mockito.Mockito.{mock, verify, when} +import org.mockito.invocation.InvocationOnMock + +import org.apache.spark.SparkFunSuite + +class HiveSessionImplSuite extends SparkFunSuite { + private var session: HiveSessionImpl = _ + private var operationManager: OperationManager = _ + + override def beforeAll() { + super.beforeAll() + + session = new HiveSessionImpl( + ThriftserverShimUtils.testedProtocolVersions.head, + "", + "", + new HiveConf(), + "" + ) + val sessionManager = mock(classOf[SessionManager]) + session.setSessionManager(sessionManager) + operationManager = mock(classOf[OperationManager]) + session.setOperationManager(operationManager) + when(operationManager.newGetCatalogsOperation(session)).thenAnswer( + (_: InvocationOnMock) => { + val operation = mock(classOf[GetCatalogsOperation]) + when(operation.getHandle).thenReturn(mock(classOf[OperationHandle])) + operation + } + ) + + session.open(Map.empty[String, String].asJava) + } + + test("SPARK-31387 - session.close() closes all sessions regardless of thrown exceptions") { + val operationHandle1 = session.getCatalogs + val operationHandle2 = session.getCatalogs + + when(operationManager.closeOperation(operationHandle1)) + .thenThrow(classOf[NullPointerException]) + when(operationManager.closeOperation(operationHandle2)) + .thenThrow(classOf[NullPointerException]) + + session.close() + + verify(operationManager).closeOperation(operationHandle1) + verify(operationManager).closeOperation(operationHandle2) + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 0cec63460814c..21256ad02c134 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -811,6 +811,61 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } } + + test("SPARK-31859 Thriftserver works with spark.sql.datetime.java8API.enabled=true") { + withJdbcStatement() { statement => + withJdbcStatement() { st => + st.execute("set spark.sql.datetime.java8API.enabled=true") + val rs = st.executeQuery("select date '2020-05-28', timestamp '2020-05-28 00:00:00'") + rs.next() + assert(rs.getDate(1).toString() == "2020-05-28") + assert(rs.getTimestamp(2).toString() == "2020-05-28 00:00:00.0") + } + } + } + + test("SPARK-31861 Thriftserver respects spark.sql.session.timeZone") { + withJdbcStatement() { statement => + withJdbcStatement() { st => + st.execute("set spark.sql.session.timeZone=+03:15") // different than Thriftserver's JVM tz 
+ val rs = st.executeQuery("select timestamp '2020-05-28 10:00:00'") + rs.next() + // The timestamp as string is the same as the literal + assert(rs.getString(1) == "2020-05-28 10:00:00.0") + // Parsing it to java.sql.Timestamp in the client will always result in a timestamp + // in client default JVM timezone. The string value of the Timestamp will match the literal, + // but if the JDBC application cares about the internal timezone and UTC offset of the + // Timestamp object, it should set spark.sql.session.timeZone to match its client JVM tz. + assert(rs.getTimestamp(1).toString() == "2020-05-28 10:00:00.0") + } + } + } + + test("SPARK-31863 Session conf should persist between Thriftserver worker threads") { + val iter = 20 + withJdbcStatement() { statement => + // date 'now' is resolved during parsing, and relies on SQLConf.get to + // obtain the current set timezone. We exploit this to run this test. + // If the timezones are set correctly to 25 hours apart across threads, + // the dates should reflect this. + + // iterate a few times for the odd chance the same thread is selected + for (_ <- 0 until iter) { + statement.execute("SET spark.sql.session.timeZone=GMT-12") + val firstResult = statement.executeQuery("SELECT date 'now'") + firstResult.next() + val beyondDateLineWest = firstResult.getDate(1) + + statement.execute("SET spark.sql.session.timeZone=GMT+13") + val secondResult = statement.executeQuery("SELECT date 'now'") + secondResult.next() + val dateLineEast = secondResult.getDate(1) + assert( + dateLineEast after beyondDateLineWest, + "SQLConf changes should persist across execution threads") + } + } + } } class SingleSessionSuite extends HiveThriftJdbcTest { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala index ce610098156f3..e002bc0117c8b 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala @@ -19,29 +19,25 @@ package org.apache.spark.sql.hive.thriftserver import java.sql.{DriverManager, Statement} +import scala.collection.JavaConverters._ import scala.concurrent.duration._ -import scala.util.{Random, Try} +import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.thrift.ThriftCLIService -import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSparkSession trait SharedThriftServer extends SharedSparkSession { private var hiveServer2: HiveThriftServer2 = _ + private var serverPort: Int = 0 override def beforeAll(): Unit = { super.beforeAll() - // Chooses a random port between 10000 and 19999 - var listeningPort = 10000 + Random.nextInt(10000) - // Retries up to 3 times with different port numbers if the server fails to start - (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) => - started.orElse { - listeningPort += 1 - Try(startThriftServer(listeningPort, attempt)) - } + (1 to 3).foldLeft(Try(startThriftServer(0))) { case (started, attempt) => + started.orElse(Try(startThriftServer(attempt))) }.recover { case cause: Throwable => throw cause @@ -59,8 +55,7 @@ trait SharedThriftServer extends SharedSparkSession { protected def withJdbcStatement(fs: (Statement => Unit)*): Unit = { val user = System.getProperty("user.name") - - val 
serverPort = hiveServer2.getHiveConf.get(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname) + require(serverPort != 0, "Failed to bind an actual port for HiveThriftServer2") val connections = fs.map { _ => DriverManager.getConnection(s"jdbc:hive2://localhost:$serverPort", user, "") } val statements = connections.map(_.createStatement()) @@ -73,11 +68,19 @@ trait SharedThriftServer extends SharedSparkSession { } } - private def startThriftServer(port: Int, attempt: Int): Unit = { - logInfo(s"Trying to start HiveThriftServer2: port=$port, attempt=$attempt") + private def startThriftServer(attempt: Int): Unit = { + logInfo(s"Trying to start HiveThriftServer2:, attempt=$attempt") val sqlContext = spark.newSession().sqlContext - sqlContext.setConf(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname, port.toString) + // Set the HIVE_SERVER2_THRIFT_PORT to 0, so it could randomly pick any free port to use. + // It's much more robust than set a random port generated by ourselves ahead + sqlContext.setConf(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname, "0") hiveServer2 = HiveThriftServer2.startWithContext(sqlContext) + hiveServer2.getServices.asScala.foreach { + case t: ThriftCLIService if t.getPortNumber != 0 => + serverPort = t.getPortNumber + logInfo(s"Started HiveThriftServer2: port=$serverPort, attempt=$attempt") + case _ => + } // Wait for thrift server to be ready to serve the query, via executing simple query // till the query succeeds. See SPARK-30345 for more details. diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnvSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnvSuite.scala index ffd1fc48f19fe..f28faea2be868 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnvSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnvSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive.thriftserver +import org.apache.commons.io.FileUtils import test.custom.listener.{DummyQueryExecutionListener, DummyStreamingQueryListener} import org.apache.spark.SparkFunSuite @@ -25,10 +26,19 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.hive.HiveUtils.{HIVE_METASTORE_JARS, HIVE_METASTORE_VERSION} import org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.sql.internal.StaticSQLConf.{QUERY_EXECUTION_LISTENERS, STREAMING_QUERY_LISTENERS, WAREHOUSE_PATH} +import org.apache.spark.util.Utils class SparkSQLEnvSuite extends SparkFunSuite { test("SPARK-29604 external listeners should be initialized with Spark classloader") { + val metastorePath = Utils.createTempDir("spark_derby") + FileUtils.forceDelete(metastorePath) + + val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" + withSystemProperties( + "javax.jdo.option.ConnectionURL" -> jdbcUrl, + "derby.system.durability" -> "test", + "spark.ui.enabled" -> "false", QUERY_EXECUTION_LISTENERS.key -> classOf[DummyQueryExecutionListener].getCanonicalName, STREAMING_QUERY_LISTENERS.key -> classOf[DummyStreamingQueryListener].getCanonicalName, WAREHOUSE_PATH.key -> TestHiveContext.makeWarehouseDir().toURI.getPath, diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2ListenerSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2ListenerSuite.scala index 075032fa5d099..9a9f574153a0a 100644 --- 
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2ListenerSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2ListenerSuite.scala @@ -140,6 +140,23 @@ class HiveThriftServer2ListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(listener.noLiveData()) } + test("SPARK-31387 - listener update methods should not throw exception with unknown input") { + val (statusStore: HiveThriftServer2AppStatusStore, listener: HiveThriftServer2Listener) = + createAppStatusStore(true) + + val unknownSession = "unknown_session" + val unknownOperation = "unknown_operation" + listener.onOtherEvent(SparkListenerThriftServerSessionClosed(unknownSession, 0)) + listener.onOtherEvent(SparkListenerThriftServerOperationStart("id", unknownSession, + "stmt", "groupId", 0)) + listener.onOtherEvent(SparkListenerThriftServerOperationParsed(unknownOperation, "query")) + listener.onOtherEvent(SparkListenerThriftServerOperationCanceled(unknownOperation, 0)) + listener.onOtherEvent(SparkListenerThriftServerOperationError(unknownOperation, + "msg", "trace", 0)) + listener.onOtherEvent(SparkListenerThriftServerOperationFinish(unknownOperation, 0)) + listener.onOtherEvent(SparkListenerThriftServerOperationClosed(unknownOperation, 0)) + } + private def createProperties: Properties = { val properties = new Properties() properties.setProperty(SparkContext.SPARK_JOB_GROUP_ID, "groupId") diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/ColumnValue.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/ColumnValue.java index a770bea9c2aa6..462b93a0f09fe 100644 --- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/ColumnValue.java +++ b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/ColumnValue.java @@ -123,22 +123,6 @@ private static TColumnValue stringValue(HiveVarchar value) { return TColumnValue.stringVal(tStringValue); } - private static TColumnValue dateValue(Date value) { - TStringValue tStringValue = new TStringValue(); - if (value != null) { - tStringValue.setValue(value.toString()); - } - return new TColumnValue(TColumnValue.stringVal(tStringValue)); - } - - private static TColumnValue timestampValue(Timestamp value) { - TStringValue tStringValue = new TStringValue(); - if (value != null) { - tStringValue.setValue(value.toString()); - } - return TColumnValue.stringVal(tStringValue); - } - private static TColumnValue stringValue(HiveIntervalYearMonth value) { TStringValue tStrValue = new TStringValue(); if (value != null) { @@ -178,9 +162,9 @@ public static TColumnValue toTColumnValue(Type type, Object value) { case VARCHAR_TYPE: return stringValue((HiveVarchar)value); case DATE_TYPE: - return dateValue((Date)value); case TIMESTAMP_TYPE: - return timestampValue((Timestamp)value); + // SPARK-31859, SPARK-31861: converted to string already in SparkExecuteStatementOperation + return stringValue((String)value); case INTERVAL_YEAR_MONTH_TYPE: return stringValue((HiveIntervalYearMonth) value); case INTERVAL_DAY_TIME_TYPE: diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/operation/Operation.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/operation/Operation.java index 51bb28748d9e2..4b331423948fa 100644 --- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/operation/Operation.java +++ 
b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/operation/Operation.java @@ -280,7 +280,10 @@ public void cancel() throws HiveSQLException { throw new UnsupportedOperationException("SQLOperation.cancel()"); } - public abstract void close() throws HiveSQLException; + public void close() throws HiveSQLException { + setState(OperationState.CLOSED); + cleanupOperationLog(); + } public abstract TableSchema getResultSetSchema() throws HiveSQLException; diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 745f385e87f78..e3fb54d9f47e9 100644 --- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -636,7 +636,11 @@ public void close() throws HiveSQLException { acquire(true); // Iterate through the opHandles and close their operations for (OperationHandle opHandle : opHandleSet) { - operationManager.closeOperation(opHandle); + try { + operationManager.closeOperation(opHandle); + } catch (Exception e) { + LOG.warn("Exception is thrown closing operation " + opHandle, e); + } } opHandleSet.clear(); // Cleanup session log directory. @@ -674,11 +678,15 @@ private void cleanupPipeoutFile() { File[] fileAry = new File(lScratchDir).listFiles( (dir, name) -> name.startsWith(sessionID) && name.endsWith(".pipeout")); - for (File file : fileAry) { - try { - FileUtils.forceDelete(file); - } catch (Exception e) { - LOG.error("Failed to cleanup pipeout file: " + file, e); + if (fileAry == null) { + LOG.error("Unable to access pipeout files in " + lScratchDir); + } else { + for (File file : fileAry) { + try { + FileUtils.forceDelete(file); + } catch (Exception e) { + LOG.error("Failed to cleanup pipeout file: " + file, e); + } } } } diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java index 21b8bf7de75ce..e1ee503b81209 100644 --- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java +++ b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java @@ -76,6 +76,10 @@ public void run() { keyStorePassword, sslVersionBlacklist); } + // In case HIVE_SERVER2_THRIFT_PORT or hive.server2.thrift.port is configured with 0 which + // represents any free port, we should set it to the actual one + portNum = serverSocket.getServerSocket().getLocalPort(); + // Server args int maxMessageSize = hiveConf.getIntVar(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_MAX_MESSAGE_SIZE); int requestTimeout = (int) hiveConf.getTimeVar( diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java index 504e63dbc5e5e..1099a00b67eb7 100644 --- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java +++ b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java @@ -143,6 +143,9 @@ public void run() { // TODO: check defaults: maxTimeout, keepalive, maxBodySize, bodyRecieveDuration, etc. 
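The port handling added for ThriftBinaryCLIService above (and mirrored for the HTTP service and the v2.3 copies below) uses the standard pattern of binding to port 0 and then asking the server socket which port it actually got; SharedThriftServer later reads that value back via `ThriftCLIService#getPortNumber`. A small self-contained Scala illustration of the same pattern, using a plain `java.net.ServerSocket` rather than the Thrift/Jetty server types used in the patch:

    import java.net.ServerSocket

    val socket = new ServerSocket(0)         // port 0 asks the OS for any free port
    try {
      val actualPort = socket.getLocalPort   // the port that was really bound
      println(s"Listening on ephemeral port $actualPort")
    } finally {
      socket.close()
    }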
// Finally, start the server httpServer.start(); + // In case HIVE_SERVER2_THRIFT_HTTP_PORT or hive.server2.thrift.http.port is configured with + // 0 which represents any free port, we should set it to the actual one + portNum = connector.getLocalPort(); String msg = "Started " + ThriftHttpCLIService.class.getSimpleName() + " in " + schemeName + " mode on port " + connector.getLocalPort()+ " path=" + httpPath + " with " + minWorkerThreads + "..." + maxWorkerThreads + " worker threads"; diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/ColumnValue.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/ColumnValue.java index 53f0465a056d8..85adf55df15e0 100644 --- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/ColumnValue.java +++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/ColumnValue.java @@ -124,22 +124,6 @@ private static TColumnValue stringValue(HiveVarchar value) { return TColumnValue.stringVal(tStringValue); } - private static TColumnValue dateValue(Date value) { - TStringValue tStringValue = new TStringValue(); - if (value != null) { - tStringValue.setValue(value.toString()); - } - return new TColumnValue(TColumnValue.stringVal(tStringValue)); - } - - private static TColumnValue timestampValue(Timestamp value) { - TStringValue tStringValue = new TStringValue(); - if (value != null) { - tStringValue.setValue(value.toString()); - } - return TColumnValue.stringVal(tStringValue); - } - private static TColumnValue stringValue(HiveIntervalYearMonth value) { TStringValue tStrValue = new TStringValue(); if (value != null) { @@ -181,9 +165,9 @@ public static TColumnValue toTColumnValue(TypeDescriptor typeDescriptor, Object case VARCHAR_TYPE: return stringValue((HiveVarchar)value); case DATE_TYPE: - return dateValue((Date)value); case TIMESTAMP_TYPE: - return timestampValue((Timestamp)value); + // SPARK-31859, SPARK-31861: converted to string already in SparkExecuteStatementOperation + return stringValue((String)value); case INTERVAL_YEAR_MONTH_TYPE: return stringValue((HiveIntervalYearMonth) value); case INTERVAL_DAY_TIME_TYPE: diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/operation/Operation.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/operation/Operation.java index f26c715add987..558c68f85c16b 100644 --- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/operation/Operation.java +++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/operation/Operation.java @@ -298,7 +298,10 @@ public void cancel() throws HiveSQLException { throw new UnsupportedOperationException("SQLOperation.cancel()"); } - public abstract void close() throws HiveSQLException; + public void close() throws HiveSQLException { + setState(OperationState.CLOSED); + cleanupOperationLog(); + } public abstract TableSchema getResultSetSchema() throws HiveSQLException; diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 14e9c4704c977..1b3e8fe6bfb9d 100644 --- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -650,7 +650,11 @@ public void close() throws HiveSQLException { acquire(true); // Iterate through the 
opHandles and close their operations for (OperationHandle opHandle : opHandleSet) { - operationManager.closeOperation(opHandle); + try { + operationManager.closeOperation(opHandle); + } catch (Exception e) { + LOG.warn("Exception is thrown closing operation " + opHandle, e); + } } opHandleSet.clear(); // Cleanup session log directory. @@ -688,11 +692,15 @@ private void cleanupPipeoutFile() { File[] fileAry = new File(lScratchDir).listFiles( (dir, name) -> name.startsWith(sessionID) && name.endsWith(".pipeout")); - for (File file : fileAry) { - try { - FileUtils.forceDelete(file); - } catch (Exception e) { - LOG.error("Failed to cleanup pipeout file: " + file, e); + if (fileAry == null) { + LOG.error("Unable to access pipeout files in " + lScratchDir); + } else { + for (File file : fileAry) { + try { + FileUtils.forceDelete(file); + } catch (Exception e) { + LOG.error("Failed to cleanup pipeout file: " + file, e); + } } } } diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java index fc19c65daaf54..a7de9c0f3d0d2 100644 --- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java +++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java @@ -77,6 +77,10 @@ public void run() { keyStorePassword, sslVersionBlacklist); } + // In case HIVE_SERVER2_THRIFT_PORT or hive.server2.thrift.port is configured with 0 which + // represents any free port, we should set it to the actual one + portNum = serverSocket.getServerSocket().getLocalPort(); + // Server args int maxMessageSize = hiveConf.getIntVar(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_MAX_MESSAGE_SIZE); int requestTimeout = (int) hiveConf.getTimeVar( diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java index 08626e7eb146d..73d5f84476af0 100644 --- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java +++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java @@ -144,6 +144,9 @@ public void run() { // TODO: check defaults: maxTimeout, keepalive, maxBodySize, bodyRecieveDuration, etc. // Finally, start the server httpServer.start(); + // In case HIVE_SERVER2_THRIFT_HTTP_PORT or hive.server2.thrift.http.port is configured with + // 0 which represents any free port, we should set it to the actual one + portNum = connector.getLocalPort(); String msg = "Started " + ThriftHttpCLIService.class.getSimpleName() + " in " + schemeName + " mode on port " + portNum + " path=" + httpPath + " with " + minWorkerThreads + "..." 
+ maxWorkerThreads + " worker threads"; diff --git a/sql/hive/benchmarks/InsertIntoHiveTableBenchmark-hive1.2-results.txt b/sql/hive/benchmarks/InsertIntoHiveTableBenchmark-hive1.2-results.txt new file mode 100644 index 0000000000000..85884a1aaf739 --- /dev/null +++ b/sql/hive/benchmarks/InsertIntoHiveTableBenchmark-hive1.2-results.txt @@ -0,0 +1,11 @@ +Java HotSpot(TM) 64-Bit Server VM 1.8.0_251-b08 on Mac OS X 10.15.4 +Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +insert hive table benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +INSERT INTO DYNAMIC 6812 7043 328 0.0 665204.8 1.0X +INSERT INTO HYBRID 817 852 32 0.0 79783.6 8.3X +INSERT INTO STATIC 231 246 21 0.0 22568.2 29.5X +INSERT OVERWRITE DYNAMIC 25947 26671 1024 0.0 2533910.2 0.3X +INSERT OVERWRITE HYBRID 2846 2884 54 0.0 277908.7 2.4X +INSERT OVERWRITE STATIC 232 247 26 0.0 22659.9 29.4X + diff --git a/sql/hive/benchmarks/InsertIntoHiveTableBenchmark-hive2.3-results.txt b/sql/hive/benchmarks/InsertIntoHiveTableBenchmark-hive2.3-results.txt new file mode 100644 index 0000000000000..ea8e6057ea610 --- /dev/null +++ b/sql/hive/benchmarks/InsertIntoHiveTableBenchmark-hive2.3-results.txt @@ -0,0 +1,11 @@ +Java HotSpot(TM) 64-Bit Server VM 1.8.0_251-b08 on Mac OS X 10.15.4 +Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +insert hive table benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +INSERT INTO DYNAMIC 4326 4373 66 0.0 422486.0 1.0X +INSERT INTO HYBRID 726 741 21 0.0 70877.2 6.0X +INSERT INTO STATIC 256 270 12 0.0 25015.7 16.9X +INSERT OVERWRITE DYNAMIC 4115 4150 49 0.0 401828.8 1.1X +INSERT OVERWRITE HYBRID 690 699 8 0.0 67370.5 6.3X +INSERT OVERWRITE STATIC 277 283 5 0.0 27097.9 15.6X + diff --git a/sql/hive/benchmarks/InsertIntoHiveTableBenchmark-jdk11-hive2.3-results.txt b/sql/hive/benchmarks/InsertIntoHiveTableBenchmark-jdk11-hive2.3-results.txt new file mode 100644 index 0000000000000..c7a642aad5273 --- /dev/null +++ b/sql/hive/benchmarks/InsertIntoHiveTableBenchmark-jdk11-hive2.3-results.txt @@ -0,0 +1,11 @@ +Java HotSpot(TM) 64-Bit Server VM 11.0.5+10-LTS on Mac OS X 10.15.4 +Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz +insert hive table benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +INSERT INTO DYNAMIC 5083 5412 466 0.0 496384.5 1.0X +INSERT INTO HYBRID 822 864 43 0.0 80283.6 6.2X +INSERT INTO STATIC 335 342 5 0.0 32694.1 15.2X +INSERT OVERWRITE DYNAMIC 4941 5068 179 0.0 482534.5 1.0X +INSERT OVERWRITE HYBRID 722 745 27 0.0 70502.7 7.0X +INSERT OVERWRITE STATIC 295 314 12 0.0 28846.8 17.2X + diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 29825e5116ef9..db1f6fbd97d90 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.execution import java.io.File -import java.util.{Locale, TimeZone} 
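The compatibility suites changed below stop mutating the JVM-wide TimeZone and Locale defaults and rely only on the session-scoped SQL configuration, matching the Thrift server tests earlier in this patch. A hedged sketch of the session-scoped approach, assuming an already-built `SparkSession` named `spark`:

    // "spark.sql.session.timeZone" is the string key behind SQLConf.SESSION_LOCAL_TIMEZONE
    // used by these suites; only the current session is affected, the JVM default stays put.
    spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
    spark.sql("SELECT timestamp '2020-05-28 10:00:00'").show()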
import org.scalatest.BeforeAndAfter @@ -36,13 +35,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private lazy val hiveQueryDir = TestHive.getHiveFile( "ql/src/test/queries/clientpositive".split("/").mkString(File.separator)) - private val originalTimeZone = TimeZone.getDefault - private val originalLocale = Locale.getDefault private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone - private val originalCreateHiveTable = TestHive.conf.createHiveTableByDefaultEnabled def testCases: Seq[(String, File)] = { hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) @@ -51,10 +47,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(true) - // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - // Add Locale setting - Locale.setDefault(Locale.US) // Set a relatively small column batch size for testing purposes TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes @@ -66,21 +58,16 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") - TestHive.setConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED, true) RuleExecutor.resetMetrics() } override def afterAll(): Unit = { try { TestHive.setCacheTables(false) - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) - TestHive.setConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED, - originalCreateHiveTable) // For debugging dump some statistics about how much time was spent in various optimizer rules logWarning(RuleExecutor.dumpTimeSpent()) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index ed23f65815917..2c0970c85449f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.execution import java.io.File -import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter @@ -33,17 +32,11 @@ import org.apache.spark.util.Utils * files, every `createQueryTest` calls should explicitly set `reset` to `false`. 
*/ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfter { - private val originalTimeZone = TimeZone.getDefault - private val originalLocale = Locale.getDefault private val testTempDir = Utils.createTempDir() override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(true) - // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - // Add Locale setting - Locale.setDefault(Locale.US) // Create the table used in windowing.q sql("DROP TABLE IF EXISTS part") @@ -103,8 +96,6 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte override def afterAll(): Unit = { try { TestHive.setCacheTables(false) - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) TestHive.reset() } finally { super.afterAll() @@ -747,17 +738,11 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte class HiveWindowFunctionQueryFileSuite extends HiveCompatibilitySuite with BeforeAndAfter { - private val originalTimeZone = TimeZone.getDefault - private val originalLocale = Locale.getDefault private val testTempDir = Utils.createTempDir() override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(true) - // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - // Add Locale setting - Locale.setDefault(Locale.US) // The following settings are used for generating golden files with Hive. // We have to use kryo to correctly let Hive serialize plans with window functions. @@ -772,8 +757,6 @@ class HiveWindowFunctionQueryFileSuite override def afterAll(): Unit = { try { TestHive.setCacheTables(false) - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) TestHive.reset() } finally { super.afterAll() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 8526d86454604..27ba3eca81948 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -22,7 +22,20 @@ package object client { private[hive] sealed abstract class HiveVersion( val fullVersion: String, val extraDeps: Seq[String] = Nil, - val exclusions: Seq[String] = Nil) + val exclusions: Seq[String] = Nil) extends Ordered[HiveVersion] { + override def compare(that: HiveVersion): Int = { + val thisVersionParts = fullVersion.split('.').map(_.toInt) + val thatVersionParts = that.fullVersion.split('.').map(_.toInt) + assert(thisVersionParts.length == thatVersionParts.length) + thisVersionParts.zip(thatVersionParts).foreach { case (l, r) => + val candidate = l - r + if (candidate != 0) { + return candidate + } + } + 0 + } + } // scalastyle:off private[hive] object hive { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 19f439598142e..9f83f2ab96094 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -26,13 +26,15 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.SparkException import 
org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalog, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.CommandUtils +import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.client.HiveClientImpl +import org.apache.spark.sql.hive.client.hive._ /** @@ -285,7 +287,21 @@ case class InsertIntoHiveTable( // Newer Hive largely improves insert overwrite performance. As Spark uses older Hive // version and we may not want to catch up new Hive version every time. We delete the // Hive partition first and then load data file into the Hive partition. - if (partitionPath.nonEmpty && overwrite) { + val hiveVersion = externalCatalog.asInstanceOf[ExternalCatalogWithListener] + .unwrapped.asInstanceOf[HiveExternalCatalog] + .client + .version + // SPARK-31684: + // For Hive 2.0.0 and onwards, since https://issues.apache.org/jira/browse/HIVE-11940 + // has been fixed, there is no performance issue anymore, and we should leave the + // overwrite logic to Hive to avoid a failure in `FileSystem#checkPath` when the table + // and partition locations do not belong to the same `FileSystem`. + // TODO(SPARK-31675): For Hive 2.2.0 and earlier, if the table and partition locations + // do not belong to the same `FileSystem`, we still get the same error thrown by the + // Hive encryption check. See https://issues.apache.org/jira/browse/HIVE-14380. + // So we still disable Hive's overwrite for Hive 1.x for better performance, because + // the partition and table are on the same cluster in most cases. + if (partitionPath.nonEmpty && overwrite && hiveVersion < v2_0) { partitionPath.foreach { path => val fs = path.getFileSystem(hadoopConf) if (fs.exists(path)) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala index 40f7b4e8db7c5..c7183fd7385a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io._ import java.nio.charset.StandardCharsets import java.util.Properties +import java.util.concurrent.TimeUnit import javax.annotation.Nullable import scala.collection.JavaConverters._ @@ -42,6 +43,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveInspectors import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} @@ -136,6 +138,15 @@ case class ScriptTransformationExec( throw writerThread.exception.get } + // There can be a lag between the reader reaching EOF and the process terminating. + // If the script fails to start up, this kind of error may be missed. + // So we explicitly wait for the process to terminate.
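For reference, the wait-with-timeout call added right below can be exercised on its own with the standard `java.lang.Process` API. A minimal standalone sketch, where the shell command and the 10-second timeout are made up for illustration:

```scala
import java.util.concurrent.TimeUnit

object WaitForProcessSketch {
  def main(args: Array[String]): Unit = {
    // Launch a short-lived child process, standing in for the transformation script.
    val proc = new ProcessBuilder("/bin/sh", "-c", "exit 7").start()

    // Unlike the no-argument waitFor(), the timed variant returns false if the process
    // is still running when the timeout elapses, so the caller can log and move on.
    val exited = proc.waitFor(10, TimeUnit.SECONDS)
    if (!exited) {
      println("process did not terminate within 10 seconds")
    } else if (proc.exitValue() != 0) {
      // A non-zero exit code is how a script that failed to start up surfaces.
      println(s"process exited with status ${proc.exitValue()}")
    }
  }
}
```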
+ val timeout = conf.getConf(SQLConf.SCRIPT_TRANSFORMATION_EXIT_TIMEOUT) + val exitRes = proc.waitFor(timeout, TimeUnit.SECONDS) + if (!exitRes) { + log.warn(s"Transformation script process did not exit within $timeout seconds") + } + if (!proc.isAlive) { val exitCode = proc.exitValue() if (exitCode != 0) { @@ -173,7 +184,6 @@ case class ScriptTransformationExec( // Ideally the proc should *not* be alive at this point but // there can be a lag between EOF being written out and the process // being terminated. So explicitly waiting for the process to be done. - proc.waitFor() checkFailureAndPropagate() return false } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/InsertIntoHiveTableBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/InsertIntoHiveTableBenchmark.scala new file mode 100644 index 0000000000000..81eb5e2591f13 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/InsertIntoHiveTableBenchmark.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.test.TestHive + +/** + * Benchmark to measure Hive table write performance. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class <this class> + * --jars <spark core test jar>,<spark catalyst test jar>,<spark sql test jar> + * --packages org.spark-project.hive:hive-exec:1.2.1.spark2 + * <spark hive test jar> + * 2. build/sbt "hive/test:runMain <this class>" -Phive-1.2 or + * build/sbt "hive/test:runMain <this class>" -Phive-2.3 + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "hive/test:runMain <this class>" + * Results will be written to "benchmarks/InsertIntoHiveTableBenchmark-hive2.3-results.txt". + * 4.
-Phive-1.2 does not work for JDK 11 + * }}} + */ +object InsertIntoHiveTableBenchmark extends SqlBasedBenchmark { + + override def getSparkSession: SparkSession = TestHive.sparkSession + + val tempView = "temp" + val numRows = 1024 * 10 + val sql = spark.sql _ + + // scalastyle:off hadoopconfiguration + private val hadoopConf = spark.sparkContext.hadoopConfiguration + // scalastyle:on hadoopconfiguration + hadoopConf.set("hive.exec.dynamic.partition", "true") + hadoopConf.set("hive.exec.dynamic.partition.mode", "nonstrict") + hadoopConf.set("hive.exec.max.dynamic.partitions", numRows.toString) + + def withTable(tableNames: String*)(f: => Unit): Unit = { + tableNames.foreach { name => + sql(s"CREATE TABLE $name(a INT) STORED AS TEXTFILE PARTITIONED BY (b INT, c INT)") + } + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + def insertOverwriteDynamic(table: String, benchmark: Benchmark): Unit = { + benchmark.addCase("INSERT OVERWRITE DYNAMIC") { _ => + sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS a," + + s" CAST(id % 10 AS INT) AS b, CAST(id % 100 AS INT) AS c FROM $tempView") + } + } + + def insertOverwriteHybrid(table: String, benchmark: Benchmark): Unit = { + benchmark.addCase("INSERT OVERWRITE HYBRID") { _ => + sql(s"INSERT OVERWRITE TABLE $table partition(b=1, c) SELECT CAST(id AS INT) AS a," + + s" CAST(id % 10 AS INT) AS c FROM $tempView") + } + } + + def insertOverwriteStatic(table: String, benchmark: Benchmark): Unit = { + benchmark.addCase("INSERT OVERWRITE STATIC") { _ => + sql(s"INSERT OVERWRITE TABLE $table partition(b=1, c=10) SELECT CAST(id AS INT) AS a" + + s" FROM $tempView") + } + } + + def insertIntoDynamic(table: String, benchmark: Benchmark): Unit = { + benchmark.addCase("INSERT INTO DYNAMIC") { _ => + sql(s"INSERT INTO TABLE $table SELECT CAST(id AS INT) AS a," + + s" CAST(id % 10 AS INT) AS b, CAST(id % 100 AS INT) AS c FROM $tempView") + } + } + + def insertIntoHybrid(table: String, benchmark: Benchmark): Unit = { + benchmark.addCase("INSERT INTO HYBRID") { _ => + sql(s"INSERT INTO TABLE $table partition(b=1, c) SELECT CAST(id AS INT) AS a," + + s" CAST(id % 10 AS INT) AS c FROM $tempView") + } + } + + def insertIntoStatic(table: String, benchmark: Benchmark): Unit = { + benchmark.addCase("INSERT INTO STATIC") { _ => + sql(s"INSERT INTO TABLE $table partition(b=1, c=10) SELECT CAST(id AS INT) AS a" + + s" FROM $tempView") + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + spark.range(numRows).createOrReplaceTempView(tempView) + + try { + val t1 = "t1" + val t2 = "t2" + val t3 = "t3" + val t4 = "t4" + val t5 = "t5" + val t6 = "t6" + + val benchmark = new Benchmark(s"insert hive table benchmark", numRows, output = output) + + withTable(t1, t2, t3, t4, t5, t6) { + + insertIntoDynamic(t1, benchmark) + insertIntoHybrid(t2, benchmark) + insertIntoStatic(t3, benchmark) + + insertOverwriteDynamic(t4, benchmark) + insertOverwriteHybrid(t5, benchmark) + insertOverwriteStatic(t6, benchmark) + + benchmark.run() + } + } finally { + spark.catalog.dropTempView(tempView) + } + } + + override def suffix: String = if (HiveUtils.isHive23) "-hive2.3" else "-hive1.2" +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 3b5a1247bc09c..8be3d26bfc93a 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -242,7 +242,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { .filter(_ < org.apache.spark.SPARK_VERSION) } catch { // do not throw exception during object initialization. - case NonFatal(_) => Nil + case NonFatal(_) => Seq("2.3.4", "2.4.5") // A temporary fallback to use a specific version } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 5912992694e84..13c48f38e7f78 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive import java.util -import java.util.{Locale, TimeZone} import org.apache.hadoop.hive.ql.udf.UDAFPercentile import org.apache.hadoop.hive.serde2.io.DoubleWritable @@ -74,11 +73,6 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { .get()) } - // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - // Add Locale setting - Locale.setDefault(Locale.US) - val data = Literal(true) :: Literal(null) :: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala index 743cdbd6457d7..db8ebcd45f3eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -21,7 +21,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -100,7 +99,7 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi val e = intercept[SparkException] { sql("select * from test").count() } - assertExceptionMessage(e, "FileNotFoundException") + assert(e.getMessage.contains("FileNotFoundException")) // Test refreshing the cache. 
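Stepping back to the HiveExternalCatalogVersionsSuite hunk above: the PROCESS_TABLES fallback now returns a pinned list of versions instead of a silent `Nil`. As a standalone illustration of that try/NonFatal-fallback pattern (the fetch function and its failure are hypothetical):

```scala
import scala.util.control.NonFatal

object FallbackVersionsSketch {
  // Hypothetical lookup that can fail during object initialization
  // (for example, the release mirror is unreachable on the test machine).
  private def fetchReleasedVersions(): Seq[String] =
    throw new java.io.IOException("mirror unreachable")

  // Mirrors the PROCESS_TABLES change: instead of falling back to Nil and silently
  // skipping every backward-compatibility check, fall back to a pinned set of versions.
  val testingVersions: Seq[String] =
    try fetchReleasedVersions() catch {
      case NonFatal(_) => Seq("2.3.4", "2.4.5")
    }

  def main(args: Array[String]): Unit =
    println(testingVersions.mkString(", "))
}
```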
spark.catalog.refreshTable("test") @@ -115,7 +114,7 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi val e2 = intercept[SparkException] { sql("select * from test").count() } - assertExceptionMessage(e2, "FileNotFoundException") + assert(e2.getMessage.contains("FileNotFoundException")) spark.catalog.refreshByPath(dir.getAbsolutePath) assert(sql("select * from test").count() == 3) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala index 1e31e8b1bf234..cfcf70c0e79f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala @@ -25,22 +25,6 @@ import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} class HiveShowCreateTableSuite extends ShowCreateTableSuite with TestHiveSingleton { - private var origCreateHiveTableConfig = false - - protected override def beforeAll(): Unit = { - super.beforeAll() - origCreateHiveTableConfig = - spark.conf.get(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED) - spark.conf.set(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED.key, true) - } - - protected override def afterAll(): Unit = { - spark.conf.set( - SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED.key, - origCreateHiveTableConfig) - super.afterAll() - } - test("view") { Seq(true, false).foreach { serde => withView("v1") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index d1dd13623650d..8642a5ff16812 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -982,7 +982,7 @@ class VersionsSuite extends SparkFunSuite with Logging { """.stripMargin ) - val errorMsg = "Cannot safely cast 'f0': DecimalType(2,1) to BinaryType" + val errorMsg = "Cannot safely cast 'f0': decimal(2,1) to binary" if (isPartitioned) { val insertStmt = s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1.3" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 61e1fefb5b5df..e8548fd62ddc1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -2706,33 +2706,6 @@ class HiveDDLSuite } } - test("SPARK-30098: create table without provider should " + - "use default data source under non-legacy mode") { - val catalog = spark.sessionState.catalog - withSQLConf( - SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED.key -> "false") { - withTable("s") { - val defaultProvider = conf.defaultDataSourceName - sql("CREATE TABLE s(a INT, b INT)") - val table = catalog.getTableMetadata(TableIdentifier("s")) - assert(table.provider === Some(defaultProvider)) - } - } - } - - test("SPARK-30098: create table without provider should " + - "use hive under legacy mode") { - val catalog = spark.sessionState.catalog - withSQLConf( - SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED.key -> "true") { - withTable("s") { - sql("CREATE TABLE s(a INT, b INT)") - val table = catalog.getTableMetadata(TableIdentifier("s")) - assert(table.provider === Some("hive")) - } - } - } -
test("SPARK-30785: create table like a partitioned table") { val catalog = spark.sessionState.catalog withTable("sc_part", "ta_part") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 1e89db2bdd01a..63b985fbe4d32 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.net.URI import java.sql.Timestamp -import java.util.{Locale, TimeZone} +import java.util.Locale import scala.util.Try @@ -47,9 +47,6 @@ case class TestData(a: Int, b: String) * included in the hive distribution. */ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAndAfter { - private val originalTimeZone = TimeZone.getDefault - private val originalLocale = Locale.getDefault - import org.apache.spark.sql.hive.test.TestHive.implicits._ private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled @@ -59,10 +56,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(true) - // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - // Add Locale setting - Locale.setDefault(Locale.US) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) } @@ -70,8 +63,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd override def afterAll(): Unit = { try { TestHive.setCacheTables(false) - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2") TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) } finally { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index d2d350221aca0..24b1e3405379c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -87,8 +87,7 @@ class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfte SQLConf.withExistingConf(TestHive.conf)(super.withSQLConf(pairs: _*)(f)) test("Test the default fileformat for Hive-serde tables") { - withSQLConf("hive.default.fileformat" -> "orc", - SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED.key -> "true") { + withSQLConf("hive.default.fileformat" -> "orc") { val (desc, exists) = extractTableDesc( "CREATE TABLE IF NOT EXISTS fileformat_test (id int)") assert(exists) @@ -97,8 +96,7 @@ class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfte assert(desc.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) } - withSQLConf("hive.default.fileformat" -> "parquet", - SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED.key -> "true") { + withSQLConf("hive.default.fileformat" -> "parquet") { val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") assert(exists) val input = desc.storage.inputFormat diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 7153d3f03cd57..b97eb869a9e54 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -227,6 +227,42 @@ class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with Tes 'e.cast("string")).collect()) } } + + test("SPARK-30973: TRANSFORM should wait for the termination of the script (no serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[SparkException] { + val plan = + new ScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "some_non_existent_command", + output = Seq(AttributeReference("a", StringType)()), + child = rowsDf.queryExecution.sparkPlan, + ioschema = noSerdeIOSchema) + SparkPlanTest.executePlan(plan, hiveContext) + } + assert(e.getMessage.contains("Subprocess exited with status")) + assert(uncaughtExceptionHandler.exception.isEmpty) + } + + test("SPARK-30973: TRANSFORM should wait for the termination of the script (with serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[SparkException] { + val plan = + new ScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "some_non_existent_command", + output = Seq(AttributeReference("a", StringType)()), + child = rowsDf.queryExecution.sparkPlan, + ioschema = serdeIOSchema) + SparkPlanTest.executePlan(plan, hiveContext) + } + assert(e.getMessage.contains("Subprocess exited with status")) + assert(uncaughtExceptionHandler.exception.isEmpty) + } } private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 4ada5077aec7f..cbea74103343e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -145,40 +146,52 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes val seed = System.nanoTime() withClue(s"Random data generated with the seed: ${seed}") { - val dataGenerator = RandomDataGenerator.forType( - dataType = dataType, - nullable = true, - new Random(seed) - ).getOrElse { - fail(s"Failed to create data generator for schema $dataType") + val java8ApiConfValues = if (dataType == DateType || dataType == TimestampType) { + Seq(false, true) + } else { + Seq(false) + } + java8ApiConfValues.foreach { java8Api => + withSQLConf( + SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString, + SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString, + SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) { + val 
dataGenerator = RandomDataGenerator.forType( + dataType = dataType, + nullable = true, + new Random(seed) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") + } + + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = + spark.sparkContext.parallelize((1 to 20).map(i => Row(i, dataGenerator()))) + val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .options(extraOptions) + .save(path) + + val loadedDF = spark + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .options(extraOptions) + .load(path) + .orderBy("index") + + checkAnswer(loadedDF, df) + } } - - // Create a DF for the schema with random data. The index field is used to sort the - // DataFrame. This is a workaround for SPARK-10591. - val schema = new StructType() - .add("index", IntegerType, nullable = false) - .add("col", dataType, nullable = true) - val rdd = - spark.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) - val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - - df.write - .mode("overwrite") - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .options(extraOptions) - .save(path) - - val loadedDF = spark - .read - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .schema(df.schema) - .options(extraOptions) - .load(path) - .orderBy("index") - - checkAnswer(loadedDF, df) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 1d6637861511f..4eff464dcdafb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -293,7 +293,8 @@ class StreamingContextSuite } } - test("stop gracefully") { + // TODO (SPARK-31728): re-enable it + ignore("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) conf.set("spark.dummyTimeConfig", "3600s") val sc = new SparkContext(conf)
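The HadoopFsRelationTest change above reruns the date/timestamp round trip with `spark.sql.datetime.java8API.enabled` both off and on. A minimal sketch of what that flag toggles, assuming a local Spark 3.0+ session (the session setup and the query are made up for illustration; the real test runs under TestHive and also forces the write-side rebase modes to CORRECTED):

```scala
import org.apache.spark.sql.SparkSession

object Java8ApiToggleSketch {
  def main(args: Array[String]): Unit = {
    // Local session for illustration only.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("Java8ApiToggleSketch")
      .getOrCreate()

    // Mirrors the loop added to HadoopFsRelationTest: run the same round trip with the
    // Java 8 datetime API disabled and enabled.
    Seq(false, true).foreach { java8Api =>
      spark.conf.set("spark.sql.datetime.java8API.enabled", java8Api.toString)
      val d = spark.sql("SELECT DATE '2020-05-20' AS d").head().get(0)
      // With the flag off, dates collect as java.sql.Date; with it on, as java.time.LocalDate.
      println(s"java8API=$java8Api -> ${d.getClass.getName}")
    }

    spark.stop()
  }
}
```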