diff --git a/LICENSE b/LICENSE index 820f14dbdeed0..6f5d9452e800d 100644 --- a/LICENSE +++ b/LICENSE @@ -237,23 +237,24 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) + (BSD 3 Clause) jmock (org.jmock:jmock-junit4:2.8.4 - http://jmock.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) - (BSD) JLine (jline:jline:0.9.94 - http://jline.sourceforge.net) + (BSD) JLine (jline:jline:2.14.3 - https://github.com/jline/jline2) (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) (Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scalap (org.scala-lang:scalap:2.11.12 - http://www.scala-lang.org/) (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index a871ab5d448c3..a3f1bcffaea57 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -70,17 +70,18 @@ function build { local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} - docker build "${BUILD_ARGS[@]}" \ + docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ -f "$BASEDOCKERFILE" . - docker build "${BINDING_BUILD_ARGS[@]}" \ + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ -t $(image_ref spark-py) \ -f "$PYDOCKERFILE" . } function push { docker push "$(image_ref spark)" + docker push "$(image_ref spark-py)" } function usage { @@ -99,6 +100,7 @@ Options: -r repo Repository address. 
-t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. + -n Build docker image with --no-cache Using minikube when building images will do so directly into minikube's Docker daemon. There is no need to push the images into minikube in that case, they'll be automatically @@ -127,7 +129,8 @@ REPO= TAG= BASEDOCKERFILE= PYDOCKERFILE= -while getopts f:mr:t: option +NOCACHEARG= +while getopts f:mr:t:n option do case "${option}" in @@ -135,6 +138,7 @@ do p) PYDOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; + n) NOCACHEARG="--no-cache";; m) if ! which minikube 1>/dev/null; then error "Cannot find minikube." diff --git a/build/mvn b/build/mvn index efa4f9364ea52..1405983982d4c 100755 --- a/build/mvn +++ b/build/mvn @@ -154,4 +154,4 @@ export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 # Last, call the `mvn` command as usual -${MVN_BIN} -DzincPort=${ZINC_PORT} "$@" +"${MVN_BIN}" -DzincPort=${ZINC_PORT} "$@" diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index a5337656cbd84..e7b66a6f33a82 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -137,30 +137,15 @@ protected void deallocate() { } private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { - ByteBuffer buffer = buf.nioBuffer(); - int written = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? - target.write(buffer) : writeNioBuffer(target, buffer); + // SPARK-24578: cap the sub-region's size of returned nio buffer to improve the performance + // for the case that the passed-in buffer has too many components. + int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); + ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); + int written = target.write(buffer); buf.skipBytes(written); return written; } - private int writeNioBuffer( - WritableByteChannel writeCh, - ByteBuffer buf) throws IOException { - int originalLimit = buf.limit(); - int ret = 0; - - try { - int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); - buf.limit(buf.position() + ioSize); - ret = writeCh.write(buf); - } finally { - buf.limit(originalLimit); - } - - return ret; - } - @Override public MessageWithHeader touch(Object o) { super.touch(o); diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5f0045507aaab..9a767dd739b91 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -703,7 +703,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // must be stored in the same memory page. 
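Note on the SPARK-24578 change to `MessageWithHeader.copyByteBuf` above: instead of materializing the whole (possibly composite) buffer as one NIO buffer and then limiting it, the rewrite only ever exposes a sub-region of at most `NIO_BUFFER_LIMIT` bytes per write. A minimal standalone sketch of the same idea; the object name, the limit value, and the helper name are illustrative, not taken from the patch:

```scala
import java.nio.channels.WritableByteChannel
import io.netty.buffer.ByteBuf

object ChunkedCopy {
  // Illustrative cap; the real constant is MessageWithHeader.NIO_BUFFER_LIMIT.
  private val NioBufferLimit = 256 * 1024

  /** Writes at most NioBufferLimit readable bytes of `buf` to `target`; returns bytes written. */
  def copyChunk(buf: ByteBuf, target: WritableByteChannel): Int = {
    // Expose only a bounded sub-region, so a CompositeByteBuf with many components
    // never has to be consolidated into one large ByteBuffer.
    val length = math.min(buf.readableBytes(), NioBufferLimit)
    val nioBuffer = buf.nioBuffer(buf.readerIndex(), length)
    val written = target.write(nioBuffer)
    buf.skipBytes(written) // advance past what the channel actually accepted
    written
  }
}
```

The caller loops until the buffer is drained, so the cap only bounds the size of each individual `write`, not the total transferred.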
// (8 byte key length) (key) (value) (8 byte pointer to next value) int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final long recordLength = (2 * uaoSize) + klen + vlen + 8; + final long recordLength = (2L * uaoSize) + klen + vlen + 8; if (currentPage == null || currentPage.size() - pageCursor < recordLength) { if (!acquireNewPage(recordLength + uaoSize)) { return false; diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 04c38f12acc78..1632e0c69eef5 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -21,6 +21,7 @@ import java.io.File import java.security.NoSuchAlgorithmException import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration import org.eclipse.jetty.util.ssl.SslContextFactory import org.apache.spark.internal.Logging @@ -163,11 +164,16 @@ private[spark] object SSLOptions extends Logging { * missing in SparkConf, the corresponding setting is used from the default configuration. * * @param conf Spark configuration object where the settings are collected from + * @param hadoopConf Hadoop configuration to get settings * @param ns the namespace name * @param defaults the default configuration * @return [[org.apache.spark.SSLOptions]] object */ - def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { + def parse( + conf: SparkConf, + hadoopConf: Configuration, + ns: String, + defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) val port = conf.getWithSubstitution(s"$ns.port").map(_.toInt) @@ -179,9 +185,11 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.keyStore)) val keyStorePassword = conf.getWithSubstitution(s"$ns.keyStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyStorePassword)) val keyPassword = conf.getWithSubstitution(s"$ns.keyPassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyPassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyPassword)) val keyStoreType = conf.getWithSubstitution(s"$ns.keyStoreType") @@ -194,6 +202,7 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.trustStore)) val trustStorePassword = conf.getWithSubstitution(s"$ns.trustStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.trustStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.trustStorePassword)) val trustStoreType = conf.getWithSubstitution(s"$ns.trustStoreType") diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index b87476322573d..3cfafeb951105 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -19,11 +19,11 @@ package org.apache.spark import java.net.{Authenticator, PasswordAuthentication} import java.nio.charset.StandardCharsets.UTF_8 -import javax.net.ssl._ import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher @@ -111,11 +111,14 @@ private[spark] class SecurityManager( ) } + private val hadoopConf = 
SparkHadoopUtil.get.newConfiguration(sparkConf) // the default SSL configuration - it will be used by all communication layers unless overwritten - private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) + private val defaultSSLOptions = + SSLOptions.parse(sparkConf, hadoopConf, "spark.ssl", defaults = None) def getSSLOptions(module: String): SSLOptions = { - val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) + val opts = + SSLOptions.parse(sparkConf, hadoopConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") opts } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5e8595603cc90..74bfb5d6d2ea3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1306,11 +1306,12 @@ class SparkContext(config: SparkConf) extends Logging { /** Build the union of a list of RDDs. */ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = withScope { - val partitioners = rdds.flatMap(_.partitioner).toSet - if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { - new PartitionerAwareUnionRDD(this, rdds) + val nonEmptyRdds = rdds.filter(!_.partitions.isEmpty) + val partitioners = nonEmptyRdds.flatMap(_.partitioner).toSet + if (nonEmptyRdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { + new PartitionerAwareUnionRDD(this, nonEmptyRdds) } else { - new UnionRDD(this, rdds) + new UnionRDD(this, nonEmptyRdds) } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 41eac10d9b267..ebabedf950e39 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -40,6 +40,7 @@ private[spark] object PythonEvalType { val SQL_SCALAR_PANDAS_UDF = 200 val SQL_GROUPED_MAP_PANDAS_UDF = 201 val SQL_GROUPED_AGG_PANDAS_UDF = 202 + val SQL_WINDOW_AGG_PANDAS_UDF = 203 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -47,6 +48,7 @@ private[spark] object PythonEvalType { case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" + case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index a9a4d5a4ec6a2..56f3f59504a7d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -124,7 +124,7 @@ class HistoryServer( attachHandler(ApiRootResource.getServletHandler(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) val contextHandler = new ServletContextHandler contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX) @@ -152,7 +152,6 @@ class HistoryServer( assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") handlers.synchronized { ui.getHandlers.foreach(attachHandler) - addFilters(ui.getHandlers, conf) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala 
b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 35b7ddd46e4db..e87b2240564bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -43,7 +43,7 @@ class MasterWebUI( val masterPage = new MasterPage(this) attachPage(new ApplicationPage(this)) attachPage(masterPage) - attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler( "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 58a181128eb4d..a6d13d12fc28d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -225,7 +225,7 @@ private[deploy] class DriverRunner( // check if attempting another run keepTrying = supervise && exitCode != 0 && !killed if (keepTrying) { - if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000L) { waitSeconds = 1 } logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index db696b04384bd..ea67b7434a769 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -47,7 +47,7 @@ class WorkerWebUI( val logPage = new LogPage(this) attachPage(logPage) attachPage(new WorkerPage(this)) - attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) + addStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE) attachHandler(createServletHandler("/log", (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr, diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a54b091a64d50..38a043c85ae33 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -552,4 +552,11 @@ package object config { .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1h") + private[spark] val SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS = + ConfigBuilder("spark.shuffle.minNumPartitionsToHighlyCompress") + .internal() + .doc("Number of partitions to determine if MapStatus should use HighlyCompressedMapStatus") + .intConf + .checkValue(v => v > 0, "The value should be a positive integer.") + .createWithDefault(2000) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index abf39213fa0d2..9ebd0aa301592 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -76,13 +76,17 @@ object SparkHadoopWriter extends Logging { // Try to write all RDD partitions as a Hadoop OutputFormat. 
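The `SSLOptions.parse` change above adds a fallback to Hadoop's credential provider API (`hadoopConf.getPassword`) when an SSL password is not present in `SparkConf`. A rough sketch of that lookup order, assuming a credential provider path has been configured; the `SslPasswordLookup` helper and the JCEKS path are my own illustrations, only the `spark.ssl.*` key names come from the patch:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkConf

object SslPasswordLookup {
  /** Prefer SparkConf, then fall back to hadoopConf.getPassword (e.g. a JCEKS credential store). */
  def password(conf: SparkConf, hadoopConf: Configuration, key: String): Option[String] = {
    conf.getOption(key)
      .orElse(Option(hadoopConf.getPassword(key)).map(chars => new String(chars)))
  }
}

// Example wiring (the provider path is a placeholder):
// hadoopConf.set("hadoop.security.credential.provider.path", "jceks://file/tmp/ssl.jceks")
// val keyStorePassword =
//   SslPasswordLookup.password(conf, hadoopConf, "spark.ssl.keyStorePassword")
```

The `SSLOptionsSuite` addition further down in this diff exercises the real code path end to end with a `localjceks` provider.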
try { val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + // SPARK-24552: Generate a unique "attempt ID" based on the stage and task attempt numbers. + // Assumes that there won't be more than Short.MaxValue attempts, at least not concurrently. + val attemptId = (context.stageAttemptNumber << 16) | context.attemptNumber + executeTask( context = context, config = config, jobTrackerId = jobTrackerId, commitJobId = commitJobId, sparkPartitionId = context.partitionId, - sparkAttemptNumber = context.attemptNumber, + sparkAttemptNumber = attemptId, committer = committer, iterator = iter) }) diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 764735dc4eae7..db8aff94ea1e1 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -69,9 +69,9 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val taskAttemptNumber = TaskContext.get().attemptNumber() - val stageId = TaskContext.get().stageId() - val canCommit = outputCommitCoordinator.canCommit(stageId, splitId, taskAttemptNumber) + val ctx = TaskContext.get() + val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(), + splitId, ctx.attemptNumber()) if (canCommit) { performCommit() @@ -81,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, stageId, splitId, taskAttemptNumber) + throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.attemptNumber()) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 13db4985b0b80..ba9dae4ad48ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -95,7 +95,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 30cf75d43ee09..980fbbe516b91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -371,7 +371,7 @@ private[scheduler] class BlacklistTracker ( } -private[scheduler] object BlacklistTracker extends Logging { +private[spark] object BlacklistTracker extends Logging { private val DEFAULT_TIMEOUT = "1h" diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 041eade82d3ca..f74425d73b392 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1171,6 +1171,7 @@ class DAGScheduler( outputCommitCoordinator.taskCompleted( stageId, + task.stageAttemptId, task.partitionId, event.taskInfo.attemptNumber, // this is a task attempt number event.reason) @@ -1330,23 +1331,24 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { + failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is // possible the fetch failure has already been handled by the scheduler. if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) + markStageAsFinished(failedStage, errorMessage = Some(failureMessage), + willRetry = !shouldAbortStage) } else { logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " + s"longer running") } - failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest - if (shouldAbortStage) { val abortMessage = if (disallowStageRetryForTest) { "Fetch failure will not retry stage due to testing config" @@ -1545,7 +1547,10 @@ class DAGScheduler( /** * Marks a stage as finished and removes it from the list of running stages. 
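For the SPARK-24552 change to `SparkHadoopWriter` above: the stage attempt and task attempt numbers are packed into a single int so Hadoop committers see an attempt ID that stays unique across stage retries. A small sketch of the packing, plus hypothetical unpacking helpers (the patch itself only needs the packed value):

```scala
object AttemptIdPacking {
  /** Packs two small non-negative numbers into one Int: high 16 bits hold the stage attempt. */
  def pack(stageAttemptNumber: Int, taskAttemptNumber: Int): Int = {
    // Assumes both values fit in 16 bits (< 65536), as the patch comment notes.
    (stageAttemptNumber << 16) | taskAttemptNumber
  }

  def stageAttempt(packed: Int): Int = packed >>> 16
  def taskAttempt(packed: Int): Int = packed & 0xFFFF

  // pack(stageAttemptNumber = 2, taskAttemptNumber = 5) == 131077 (0x00020005)
}
```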
*/ - private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + private def markStageAsFinished( + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" @@ -1564,7 +1569,9 @@ class DAGScheduler( logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}") } - outputCommitCoordinator.stageEnd(stage.id) + if (!willRetry) { + outputCommitCoordinator.stageEnd(stage.id) + } listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 2ec2f2031aa45..659694dd189ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -50,7 +50,9 @@ private[spark] sealed trait MapStatus { private[spark] object MapStatus { def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { - if (uncompressedSizes.length > 2000) { + if (uncompressedSizes.length > Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) + .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { HighlyCompressedMapStatus(loc, uncompressedSizes) } else { new CompressedMapStatus(loc, uncompressedSizes) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 83d87b548a430..b382d623806e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils} private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) +private case class AskPermissionToCommitOutput( + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Initialized by SparkEnv var coordinatorRef: Option[RpcEndpointRef] = None - private type StageId = Int - private type PartitionId = Int - private type TaskAttemptNumber = Int - private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + // Class used to identify a committer. The task ID for a committer is implicitly defined by + // the partition being processed, but the coordinator needs to keep track of both the stage + // attempt and the task attempt, because in some situations the same task may be running + // concurrently in two different attempts of the same stage. 
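For the `MapStatus` change above (SPARK-24519): the previously hard-coded 2000-partition cutoff that selects `HighlyCompressedMapStatus` over `CompressedMapStatus` is now read from `spark.shuffle.minNumPartitionsToHighlyCompress`. A hedged usage sketch; the value 500 is only an example of the selection rule, not a recommendation:

```scala
import org.apache.spark.SparkConf

object ShuffleStatusThresholdExample {
  def main(args: Array[String]): Unit = {
    // Lower the cutoff so map tasks with more than 500 shuffle partitions report a
    // HighlyCompressedMapStatus; the default stays 2000 and the value must be positive.
    val conf = new SparkConf()
      .setAppName("threshold-example")
      .set("spark.shuffle.minNumPartitionsToHighlyCompress", "500")
    // With this setting, 501+ reported partition sizes select HighlyCompressedMapStatus,
    // while 500 or fewer keeps CompressedMapStatus (see the MapStatusSuite test below).
    println(conf.get("spark.shuffle.minNumPartitionsToHighlyCompress"))
  }
}
```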
+ private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int) + private case class StageState(numPartitions: Int) { - val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER) - val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]() + val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null) + val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]() } /** @@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ - private val stageStates = mutable.Map[StageId, StageState]() + private val stageStates = mutable.Map[Int, StageState]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. @@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * @return true if this task is authorized to commit, false otherwise */ def canCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = { + val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg), @@ -103,26 +110,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } /** - * Called by the DAGScheduler when a stage starts. + * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't + * yet been initialized. * * @param stage the stage id. * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. * the maximum possible value of `context.partitionId`). 
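To make the coordinator change above easier to follow: the authorized committer for a partition is now identified by a (stage attempt, task attempt) pair rather than a bare attempt number, so a task from a zombie stage attempt can no longer be confused with the same task number in a newer attempt. A minimal standalone model of the first-committer-wins rule under this keying; this is a sketch of the semantics, not the coordinator's actual code:

```scala
import scala.collection.mutable

// Toy model of per-stage commit state keyed by (stageAttempt, taskAttempt).
case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)

class CommitModel(numPartitions: Int) {
  private val authorized = Array.fill[Option[TaskIdentifier]](numPartitions)(None)
  private val failed = mutable.Map[Int, mutable.Set[TaskIdentifier]]()

  /** First committer wins; failed attempts and later askers for the partition are denied. */
  def canCommit(stageAttempt: Int, partition: Int, taskAttempt: Int): Boolean = {
    val id = TaskIdentifier(stageAttempt, taskAttempt)
    if (failed.getOrElse(partition, mutable.Set.empty[TaskIdentifier]).contains(id)) {
      false
    } else authorized(partition) match {
      case None =>
        authorized(partition) = Some(id)
        true
      case Some(_) => false
    }
  }

  /** A failure of the authorized committer releases the partition for another attempt. */
  def taskFailed(stageAttempt: Int, partition: Int, taskAttempt: Int): Unit = {
    val id = TaskIdentifier(stageAttempt, taskAttempt)
    failed.getOrElseUpdate(partition, mutable.Set()) += id
    if (authorized(partition).contains(id)) authorized(partition) = None
  }
}
```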
*/ - private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized { - stageStates(stage) = new StageState(maxPartitionId + 1) + private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized { + stageStates.get(stage) match { + case Some(state) => + require(state.authorizedCommitters.length == maxPartitionId + 1) + logInfo(s"Reusing state from previous attempt of stage $stage.") + + case _ => + stageStates(stage) = new StageState(maxPartitionId + 1) + } } // Called by DAGScheduler - private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { + private[scheduler] def stageEnd(stage: Int): Unit = synchronized { stageStates.remove(stage) } // Called by DAGScheduler private[scheduler] def taskCompleted( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber, + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int, reason: TaskEndReason): Unit = synchronized { val stageState = stageStates.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -131,16 +147,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) reason match { case Success => // The task output has been committed successfully - case denied: TaskCommitDenied => - logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + - s"attempt: $attemptNumber") - case otherReason => + case _: TaskCommitDenied => + logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " + + s"partition: $partition, attempt: $attemptNumber") + case _ => // Mark the attempt as failed to blacklist from future commit protocol - stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber - if (stageState.authorizedCommitters(partition) == attemptNumber) { + val taskId = TaskIdentifier(stageAttempt, attemptNumber) + stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId + if (stageState.authorizedCommitters(partition) == taskId) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER + stageState.authorizedCommitters(partition) = null } } } @@ -155,47 +172,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Marked private[scheduler] instead of private so this can be mocked in tests private[scheduler] def handleAskPermissionToCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = synchronized { + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = synchronized { stageStates.get(stage) match { - case Some(state) if attemptFailed(state, partition, attemptNumber) => - logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition as task attempt $attemptNumber has already failed.") + case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) => + logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"task attempt $attemptNumber already marked as failed.") false case Some(state) => - state.authorizedCommitters(partition) match { - case NO_AUTHORIZED_COMMITTER => - logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition") - state.authorizedCommitters(partition) = attemptNumber - true - case existingCommitter => - // Coordinator should be 
idempotent when receiving AskPermissionToCommit. - if (existingCommitter == attemptNumber) { - logWarning(s"Authorizing duplicate request to commit for " + - s"attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition; existingCommitter = $existingCommitter." + - s" This can indicate dropped network traffic.") - true - } else { - logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition; existingCommitter = $existingCommitter") - false - } + val existing = state.authorizedCommitters(partition) + if (existing == null) { + logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " + + s"task attempt $attemptNumber") + state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber) + true + } else { + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"already committed by $existing") + false } case None => - logDebug(s"Stage $stage has completed, so not allowing" + - s" attempt number $attemptNumber of partition $partition to commit") + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + "stage already marked as completed.") false } } private def attemptFailed( stageState: StageState, - partition: PartitionId, - attempt: TaskAttemptNumber): Boolean = synchronized { - stageState.failures.get(partition).exists(_.contains(attempt)) + stageAttempt: Int, + partition: Int, + attempt: Int): Boolean = synchronized { + val failInfo = TaskIdentifier(stageAttempt, attempt) + stageState.failures.get(partition).exists(_.contains(failInfo)) } } @@ -215,9 +226,10 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, attemptNumber) => + case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition, + attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index d8794e8e551aa..9b90e309d2e04 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -170,8 +170,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorDataMap.contains(executorId)) { executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) context.reply(true) - } else if (scheduler.nodeBlacklist != null && - scheduler.nodeBlacklist.contains(hostname)) { + } else if (scheduler.nodeBlacklist.contains(hostname)) { // If the cluster manager gives us an executor on a blacklisted node (because it // already started allocating those resources before we informed it of our blacklist, // or if it ignored our blacklist), then we reject that executor immediately. 
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc4224..df1a4bef616b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -291,7 +291,7 @@ private[spark] class BlockManager( case e: Exception if i < MAX_ATTEMPTS => logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) - Thread.sleep(SLEEP_TIME_SECS * 1000) + Thread.sleep(SLEEP_TIME_SECS * 1000L) case NonFatal(e) => throw new SparkException("Unable to register with external shuffle server due to : " + e.getMessage, e) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index d6a025a6f12da..52a955111231a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -263,7 +263,7 @@ private[spark] object JettyUtils extends Logging { filters.foreach { case filter : String => if (!filter.isEmpty) { - logInfo("Adding filter: " + filter) + logInfo(s"Adding filter $filter to ${handlers.map(_.getContextPath).mkString(", ")}.") val holder : FilterHolder = new FilterHolder() holder.setClassName(filter) // Get any parameters for each filter @@ -407,7 +407,7 @@ private[spark] object JettyUtils extends Logging { } pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) - ServerInfo(server, httpPort, securePort, collection) + ServerInfo(server, httpPort, securePort, conf, collection) } catch { case e: Exception => server.stop() @@ -507,10 +507,12 @@ private[spark] case class ServerInfo( server: Server, boundPort: Int, securePort: Option[Int], + conf: SparkConf, private val rootHandler: ContextHandlerCollection) { - def addHandler(handler: ContextHandler): Unit = { + def addHandler(handler: ServletContextHandler): Unit = { handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME)) + JettyUtils.addFilters(Seq(handler), conf) rootHandler.addHandler(handler) if (!handler.isStarted()) { handler.start() diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index b44ac0ea1febc..d315ef66e0dc0 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -65,7 +65,7 @@ private[spark] class SparkUI private ( attachTab(new StorageTab(this, store)) attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 8b75f5d8fe1a8..2e43f17e6a8e3 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -60,23 +60,25 @@ private[spark] abstract class WebUI( def getHandlers: Seq[ServletContextHandler] = handlers def getSecurityManager: SecurityManager = securityManager - /** Attach a tab to this UI, along with all of its attached pages. 
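Background for the `JettyUtils`/`ServerInfo.addHandler` change above: because `addHandler` now applies the configured servlet filters itself, handlers attached after the server starts (for example SparkUIs loaded by the HistoryServer) get the same filters as the initial ones, which is why the explicit `addFilters` call was dropped from `HistoryServer` earlier in this diff. A rough sketch of what installing one filter on a Jetty handler looks like; the object, method, and filter class name are placeholders, only the general Jetty API usage is implied by the patch:

```scala
import java.util.EnumSet
import javax.servlet.DispatcherType
import org.eclipse.jetty.servlet.{FilterHolder, ServletContextHandler}

object JettyFilterExample {
  // Spark reads filter class names from spark.ui.filters; the argument here is a placeholder.
  def addFilter(handler: ServletContextHandler, filterClassName: String): Unit = {
    val holder = new FilterHolder()
    holder.setClassName(filterClassName)
    // Apply the filter to every request dispatched to this context.
    handler.addFilter(holder, "/*", EnumSet.allOf(classOf[DispatcherType]))
  }
}
```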
*/ - def attachTab(tab: WebUITab) { + /** Attaches a tab to this UI, along with all of its attached pages. */ + def attachTab(tab: WebUITab): Unit = { tab.pages.foreach(attachPage) tabs += tab } - def detachTab(tab: WebUITab) { + /** Detaches a tab from this UI, along with all of its attached pages. */ + def detachTab(tab: WebUITab): Unit = { tab.pages.foreach(detachPage) tabs -= tab } - def detachPage(page: WebUIPage) { + /** Detaches a page from this UI, along with all of its attached handlers. */ + def detachPage(page: WebUIPage): Unit = { pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) } - /** Attach a page to this UI. */ - def attachPage(page: WebUIPage) { + /** Attaches a page to this UI. */ + def attachPage(page: WebUIPage): Unit = { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) @@ -88,41 +90,41 @@ private[spark] abstract class WebUI( handlers += renderHandler } - /** Attach a handler to this UI. */ - def attachHandler(handler: ServletContextHandler) { + /** Attaches a handler to this UI. */ + def attachHandler(handler: ServletContextHandler): Unit = { handlers += handler serverInfo.foreach(_.addHandler(handler)) } - /** Detach a handler from this UI. */ - def detachHandler(handler: ServletContextHandler) { + /** Detaches a handler from this UI. */ + def detachHandler(handler: ServletContextHandler): Unit = { handlers -= handler serverInfo.foreach(_.removeHandler(handler)) } /** - * Add a handler for static content. + * Detaches the content handler at `path` URI. * - * @param resourceBase Root of where to find resources to serve. - * @param path Path in UI where to mount the resources. + * @param path Path in UI to unmount. */ - def addStaticHandler(resourceBase: String, path: String): Unit = { - attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) + def detachHandler(path: String): Unit = { + handlers.find(_.getContextPath() == path).foreach(detachHandler) } /** - * Remove a static content handler. + * Adds a handler for static content. * - * @param path Path in UI to unmount. + * @param resourceBase Root of where to find resources to serve. + * @param path Path in UI where to mount the resources. */ - def removeStaticHandler(path: String): Unit = { - handlers.find(_.getContextPath() == path).foreach(detachHandler) + def addStaticHandler(resourceBase: String, path: String = "/static"): Unit = { + attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) } - /** Initialize all components of the server. */ + /** A hook to initialize components of the UI */ def initialize(): Unit - /** Bind to the HTTP server behind this web interface. */ + /** Binds to the HTTP server behind this web interface. */ def bind(): Unit = { assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { @@ -136,17 +138,17 @@ private[spark] abstract class WebUI( } } - /** Return the url of web interface. Only valid after bind(). */ + /** @return The url of web interface. Only valid after [[bind]]. */ def webUrl: String = s"http://$publicHostName:$boundPort" - /** Return the actual port to which this server is bound. Only valid after bind(). */ + /** @return The actual port to which this server is bound. Only valid after [[bind]]. */ def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) - /** Stop the server behind this web interface. Only valid after bind(). */ + /** Stops the server behind this web interface. 
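The `WebUI` refactor above gives `addStaticHandler` a default mount point and turns the old `removeStaticHandler` into a `detachHandler(path)` overload, which is what let the `SparkUI`, `MasterWebUI`, and `WorkerWebUI` call sites earlier in this diff drop their explicit `"/static"` argument. Hypothetical usage, where `ui` stands for any `WebUI` subclass and the resource paths are placeholders:

```scala
ui.addStaticHandler("org/apache/spark/ui/static")            // mounts at the default "/static"
ui.addStaticHandler("my/extra/resources", "/plugin-static")  // explicit mount point
ui.detachHandler("/plugin-static")                           // detach by context path
```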
Only valid after [[bind]]. */ def stop(): Unit = { assert(serverInfo.isDefined, s"Attempted to stop $className before binding to a server!") - serverInfo.get.stop() + serverInfo.foreach(_.stop()) } } diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 3b469a69437b9..bf618b4afbce0 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -200,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } override def toString: String = { + // getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead if (metadata == null) { - "Un-registered Accumulator: " + getClass.getSimpleName + "Un-registered Accumulator: " + Utils.getSimpleName(getClass) } else { - getClass.getSimpleName + s"(id: $id, name: $name, value: $value)" + Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)" } } } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 165a15c73e7ca..0f08a2b0ad895 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,13 +19,12 @@ package org.apache.spark.util import java.util.concurrent._ +import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} -import scala.concurrent.duration.Duration +import scala.concurrent.duration.{Duration, FiniteDuration} import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} - import org.apache.spark.SparkException private[spark] object ThreadUtils { @@ -103,6 +102,22 @@ private[spark] object ThreadUtils { executor } + /** + * Wrapper over ScheduledThreadPoolExecutor. + */ + def newDaemonThreadPoolScheduledExecutor(threadNamePrefix: String, numThreads: Int) + : ScheduledExecutorService = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(s"$threadNamePrefix-%d") + .build() + val executor = new ScheduledThreadPoolExecutor(numThreads, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor + } + /** * Run a piece of code in a new thread and return the result. 
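The new `ThreadUtils.newDaemonThreadPoolScheduledExecutor` above builds a daemon `ScheduledThreadPoolExecutor` and enables remove-on-cancel so cancelled tasks do not sit in the work queue until their delay elapses. A self-contained sketch of the same construction with a hedged usage example; the pool name, thread count, and schedule are illustrative:

```scala
import java.util.concurrent.{ScheduledThreadPoolExecutor, TimeUnit}
import com.google.common.util.concurrent.ThreadFactoryBuilder

object ScheduledPoolExample {
  def newDaemonScheduledExecutor(prefix: String, numThreads: Int): ScheduledThreadPoolExecutor = {
    val threadFactory = new ThreadFactoryBuilder()
      .setDaemon(true)               // daemon threads do not keep the JVM alive
      .setNameFormat(s"$prefix-%d")  // e.g. "my-scheduler-0", "my-scheduler-1"
      .build()
    val executor = new ScheduledThreadPoolExecutor(numThreads, threadFactory)
    // Without this, a cancelled task stays queued until its delay elapses.
    executor.setRemoveOnCancelPolicy(true)
    executor
  }

  def main(args: Array[String]): Unit = {
    val pool = newDaemonScheduledExecutor("my-scheduler", 1)
    val task = pool.scheduleAtFixedRate(new Runnable {
      override def run(): Unit = println("tick")
    }, 0L, 1L, TimeUnit.SECONDS)
    Thread.sleep(2500)
    task.cancel(false) // removed from the queue immediately thanks to the policy above
    pool.shutdown()
  }
}
```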
Exception in the new thread is * thrown in the caller thread with an adjusted stack trace that removes references to this @@ -229,4 +244,14 @@ private[spark] object ThreadUtils { } } // scalastyle:on awaitready + + def shutdown( + executor: ExecutorService, + gracePeriod: Duration = FiniteDuration(30, TimeUnit.SECONDS)): Unit = { + executor.shutdown() + executor.awaitTermination(gracePeriod.toMillis, TimeUnit.MILLISECONDS) + if (!executor.isShutdown) { + executor.shutdownNow() + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f9191a59c1655..a6fd3637663e8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.io._ import java.lang.{Byte => JByte} +import java.lang.InternalError import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} @@ -30,6 +31,7 @@ import java.nio.file.Files import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ +import java.util.concurrent.TimeUnit.NANOSECONDS import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream @@ -98,7 +100,7 @@ private[spark] object Utils extends Logging { */ val DEFAULT_MAX_TO_STRING_FIELDS = 25 - private def maxNumToStringFields = { + private[spark] def maxNumToStringFields = { if (SparkEnv.get != null) { SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) } else { @@ -433,7 +435,7 @@ private[spark] object Utils extends Logging { new URI("file:///" + rawFileName).getPath.substring(1) } - /** + /** * Download a file or directory to target directory. Supports fetching the file in a variety of * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based * on the URL parameter. Fetching directories is only supported from Hadoop-compatible @@ -506,6 +508,14 @@ private[spark] object Utils extends Logging { targetFile } + /** Records the duration of running `body`. */ + def timeTakenMs[T](body: => T): (T, Long) = { + val startTime = System.nanoTime() + val result = body + val endTime = System.nanoTime() + (result, math.max(NANOSECONDS.toMillis(endTime - startTime), 0)) + } + /** * Download `in` to `tempFile`, then move it to `destFile`. * @@ -1820,7 +1830,7 @@ private[spark] object Utils extends Logging { /** Return the class name of the given object, removing all dollar signs */ def getFormattedClassName(obj: AnyRef): String = { - obj.getClass.getSimpleName.replace("$", "") + getSimpleName(obj.getClass).replace("$", "") } /** @@ -2715,6 +2725,62 @@ private[spark] object Utils extends Logging { HashCodes.fromBytes(secretBytes).toString() } + /** + * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala. + * This method mimicks scalatest's getSimpleNameOfAnObjectsClass. 
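The new `Utils.timeTakenMs` helper above times a by-name block with `System.nanoTime` and clamps the result at zero. A small self-contained usage sketch; the local copy of the helper and the timed body are only for illustration:

```scala
import java.util.concurrent.TimeUnit.NANOSECONDS

object TimeTakenExample {
  // Same shape as the helper added to Utils above; copied locally so the example compiles alone.
  private def timeTakenMs[T](body: => T): (T, Long) = {
    val start = System.nanoTime()
    val result = body
    (result, math.max(NANOSECONDS.toMillis(System.nanoTime() - start), 0))
  }

  def main(args: Array[String]): Unit = {
    val (sum, ms) = timeTakenMs { (1L to 10000000L).sum }
    println(s"sum=$sum computed in $ms ms")
  }
}
```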
+ */ + def getSimpleName(cls: Class[_]): String = { + try { + return cls.getSimpleName + } catch { + case err: InternalError => return stripDollars(stripPackages(cls.getName)) + } + } + + /** + * Remove the packages from full qualified class name + */ + private def stripPackages(fullyQualifiedName: String): String = { + fullyQualifiedName.split("\\.").takeRight(1)(0) + } + + /** + * Remove trailing dollar signs from qualified class name, + * and return the trailing part after the last dollar sign in the middle + */ + private def stripDollars(s: String): String = { + val lastDollarIndex = s.lastIndexOf('$') + if (lastDollarIndex < s.length - 1) { + // The last char is not a dollar sign + if (lastDollarIndex == -1 || !s.contains("$iw")) { + // The name does not have dollar sign or is not an intepreter + // generated class, so we should return the full string + s + } else { + // The class name is intepreter generated, + // return the part after the last dollar sign + // This is the same behavior as getClass.getSimpleName + s.substring(lastDollarIndex + 1) + } + } + else { + // The last char is a dollar sign + // Find last non-dollar char + val lastNonDollarChar = s.reverse.find(_ != '$') + lastNonDollarChar match { + case None => s + case Some(c) => + val lastNonDollarIndex = s.lastIndexOf(c) + if (lastNonDollarIndex == -1) { + s + } else { + // Strip the trailing dollar signs + // Invoke stripDollars again to get the simple name + stripDollars(s.substring(0, lastNonDollarIndex + 1)) + } + } + } + } } private[util] object CallerContext extends Logging { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index c145532328514..85ffdca436e14 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -129,7 +129,6 @@ public int compare( final UnsafeSorterIterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; - Arrays.sort(dataToSort); while (iter.hasNext()) { iter.loadNext(); final String str = diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 88916488c0def..b705556e54b14 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.util.concurrent.{ExecutorService, TimeUnit} -import scala.collection.Map import scala.collection.mutable import scala.concurrent.Future import scala.concurrent.duration._ @@ -73,6 +72,7 @@ class HeartbeatReceiverSuite sc = spy(new SparkContext(conf)) scheduler = mock(classOf[TaskSchedulerImpl]) when(sc.taskScheduler).thenReturn(scheduler) + when(scheduler.nodeBlacklist).thenReturn(Predef.Set[String]()) when(scheduler.sc).thenReturn(sc) heartbeatReceiverClock = new ManualClock heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) @@ -241,7 +241,7 @@ class HeartbeatReceiverSuite } === Some(true)) } - private def getTrackedExecutors: Map[String, Long] = { + private def getTrackedExecutors: collection.Map[String, Long] = { // We may receive undesired SparkListenerExecutorAdded from LocalSchedulerBackend, // so exclude it from the map. See SPARK-10800. 
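For the `Utils.getSimpleName` addition above: `getClass.getSimpleName` can throw `java.lang.InternalError` ("Malformed class name") on some Scala-generated classes, notably REPL wrappers, so the fallback strips packages and dollar-sign segments by hand. Two worked inputs traced through the `stripPackages`/`stripDollars` logic; the class names are made-up examples and the local copies below exist only so the traces can be checked:

```scala
object SimpleNameExample {
  private def stripPackages(fqn: String): String = fqn.split("\\.").takeRight(1)(0)

  // Condensed but behavior-equivalent copy of the patch's stripDollars.
  private def stripDollars(s: String): String = {
    val lastDollarIndex = s.lastIndexOf('$')
    if (lastDollarIndex < s.length - 1) {
      if (lastDollarIndex == -1 || !s.contains("$iw")) s else s.substring(lastDollarIndex + 1)
    } else {
      s.reverse.find(_ != '$') match {
        case None => s
        case Some(c) => stripDollars(s.substring(0, s.lastIndexOf(c) + 1))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Ordinary class: the package is dropped and the name is kept as-is.
    println(stripDollars(stripPackages("org.apache.spark.util.Utils")))      // Utils
    // REPL-generated wrapper: only the part after the last '$' survives.
    println(stripDollars(stripPackages("$line3.$read$$iw$$iw$MyCaseClass"))) // MyCaseClass
  }
}
```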
heartbeatReceiver.invokePrivate(_executorLastSeen()). @@ -272,7 +272,7 @@ private class FakeSchedulerBackend( protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { clusterManagerEndpoint.ask[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty[String])) + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty)) } protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 8eabc2b3cb958..5dbfc5c10a6f8 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,8 +18,11 @@ package org.apache.spark import java.io.File +import java.util.UUID import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.alias.{CredentialProvider, CredentialProviderFactory} import org.scalatest.BeforeAndAfterAll import org.apache.spark.util.SparkConfWithEnv @@ -40,6 +43,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { .toSet val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -49,7 +53,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) conf.set("spark.ssl.protocol", "TLSv1.2") - val opts = SSLOptions.parse(conf, "spark.ssl") + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl") assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -70,6 +74,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -80,8 +85,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -103,6 +108,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.ui.enabled", "false") conf.set("spark.ssl.ui.port", "4242") @@ -117,8 +123,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.ui.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, 
hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) assert(opts.port === Some(4242)) @@ -139,14 +145,71 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConfWithEnv(Map( "ENV1" -> "val1", "ENV2" -> "val2")) + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", "${env:ENV1}") conf.set("spark.ssl.trustStore", "${env:ENV2}") - val opts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) assert(opts.keyStore === Some(new File("val1"))) assert(opts.trustStore === Some(new File("val2"))) } + test("get password from Hadoop credential provider") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + val hadoopConf = new Configuration() + val tmpPath = s"localjceks://file${sys.props("java.io.tmpdir")}/test-" + + s"${UUID.randomUUID().toString}.jceks" + val provider = createCredentialProvider(tmpPath, hadoopConf) + + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + storePassword(provider, "spark.ssl.keyStorePassword", "password") + storePassword(provider, "spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + storePassword(provider, "spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + private def createCredentialProvider(tmpPath: String, conf: Configuration): CredentialProvider = { + conf.set(CredentialProviderFactory.CREDENTIAL_PROVIDER_PATH, tmpPath) + + val provider = CredentialProviderFactory.getProviders(conf).get(0) + if (provider == null) { + throw new IllegalStateException(s"Fail to get credential provider with path $tmpPath") + } + + provider + } + + private def storePassword( + provider: CredentialProvider, + passwordKey: String, + password: String): Unit = { + provider.createCredentialEntry(passwordKey, password.toCharArray) + provider.flush() + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 191c61250ce21..5148ce05bd918 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -154,6 +154,16 @@ class RDDSuite extends SparkFunSuite with 
SharedSparkContext { } } + test("SPARK-23778: empty RDD in union should not produce a UnionRDD") { + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val emptyRDD = sc.emptyRDD[(Int, Boolean)] + val unionRDD = sc.union(emptyRDD, rddWithPartitioner) + assert(unionRDD.isInstanceOf[PartitionerAwareUnionRDD[_]]) + val unionAllEmptyRDD = sc.union(emptyRDD, emptyRDD) + assert(unionAllEmptyRDD.isInstanceOf[UnionRDD[_]]) + assert(unionAllEmptyRDD.collect().isEmpty) + } + test("partitioner aware union") { def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) @@ -1047,7 +1057,9 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) { private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty override def compute(p: Partition, c: TaskContext): Iterator[T] = Iterator.empty - override def getPartitions: Array[Partition] = Array.empty + override def getPartitions: Array[Partition] = Array(new Partition { + override def index: Int = 0 + }) override def getDependencies: Seq[Dependency[_]] = mutableDependencies def addDependency(dep: Dependency[_]) { mutableDependencies += dep diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 2155a0f2b6c21..354e6386fa60e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -188,4 +188,32 @@ class MapStatusSuite extends SparkFunSuite { assert(count === 3000) } } + + test("SPARK-24519: HighlyCompressedMapStatus has configurable threshold") { + val conf = new SparkConf() + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + val sizes = Array.fill[Long](500)(150L) + // Test default value + val status = MapStatus(null, sizes) + assert(status.isInstanceOf[CompressedMapStatus]) + // Test Non-positive values + for (s <- -1 to 0) { + assertThrows[IllegalArgumentException] { + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes) + } + } + // Test positive values + Seq(1, 100, 499, 500, 501).foreach { s => + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes) + if(sizes.length > s) { + assert(status.isInstanceOf[HighlyCompressedMapStatus]) + } else { + assert(status.isInstanceOf[CompressedMapStatus]) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 03b1903902491..158c9eb75f2b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -153,7 +154,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Job should not complete if all commits are denied") { // Create a mock OutputCommitCoordinator that denies all attempts to commit 
doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit( - Matchers.any(), Matchers.any(), Matchers.any()) + Matchers.any(), Matchers.any(), Matchers.any(), Matchers.any()) val rdd: RDD[Int] = sc.parallelize(Seq(1), 1) def resultHandler(x: Int, y: Unit): Unit = {} val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd, @@ -169,45 +170,106 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 2 val authorizedCommitter: Int = 3 val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage, maxPartitionId = 2) - assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter)) // The non-authorized committer fails - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer - assert( - outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 2)) // There can only be one authorized committer - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) - } - - test("Duplicate calls to canCommit from the authorized committer gets idempotent responses.") { - val rdd = sc.parallelize(Seq(1), 1) - sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _, - 0 until rdd.partitions.size) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 3)) } test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 1 val failedAttempt: Int = 0 outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) - outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt, + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = failedAttempt, reason = ExecutorLostFailure("0", exitCausedByApp = true, None)) - assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt)) - assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1)) + 
assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt + 1)) + } + + test("SPARK-24589: Differentiate tasks from different stage attempts") { + var stage = 1 + val taskAttempt = 1 + val partition = 1 + + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(!outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Fail the task in the first attempt, the task in the second attempt should succeed. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Commit the 1st attempt, fail the 2nd attempt, make sure 3rd attempt cannot commit, + // then fail the 1st attempt and make sure the 4th one can commit again. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 2, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 3, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(outputCommitCoordinator.canCommit(stage, 4, partition, taskAttempt)) + } + + test("SPARK-24589: Make sure stage state is cleaned up") { + // Normal application without stage failures. + sc.parallelize(1 to 100, 100) + .map { i => (i % 10, i) } + .reduceByKey(_ + _) + .collect() + + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + + // Force failures in a few tasks so that a stage is retried. Collect the ID of the failing + // stage so that we can check the state of the output committer. 
+ val retriedStage = sc.parallelize(1 to 100, 10) + .map { i => (i % 10, i) } + .reduceByKey { case (_, _) => + val ctx = TaskContext.get() + if (ctx.stageAttemptNumber() == 0) { + throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1, 1, + new Exception("Failure for test.")) + } else { + ctx.stageId() + } + } + .collect() + .map { case (k, v) => v } + .toSet + + assert(retriedStage.size === 1) + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + verify(sc.env.outputCommitCoordinator, times(2)) + .stageStart(Matchers.eq(retriedStage.head), Matchers.any()) + verify(sc.env.outputCommitCoordinator).stageEnd(Matchers.eq(retriedStage.head)) } } @@ -243,16 +305,6 @@ private case class OutputCommitFunctions(tempDirPath: String) { if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter) } - // Receiver should be idempotent for AskPermissionToCommitOutput - def callCanCommitMultipleTimes(iter: Iterator[Int]): Unit = { - val ctx = TaskContext.get() - val canCommit1 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - val canCommit2 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - assert(canCommit1 && canCommit2) - } - private def runCommitWithProvidedCommitter( ctx: TaskContext, iter: Iterator[Int], diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 3b4273184f1e9..418d2f9b88500 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1168,6 +1168,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port") } } + + object MalformedClassObject { + class MalformedClass + } + + test("Safe getSimpleName") { + // getSimpleName on class of MalformedClass will result in error: Malformed class name + // Utils.getSimpleName works + val err = intercept[java.lang.InternalError] { + classOf[MalformedClassObject.MalformedClass].getSimpleName + } + assert(err.getMessage === "Malformed class name") + + assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) === + "UtilsSuite$MalformedClassObject$MalformedClass") + } } private class SimpleExtension diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 9552d001a079c..23b24212b4d29 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -106,3 +106,4 @@ spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin kafka-source-initial-offset-future-version.bin +vote.tmpl diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh new file mode 100755 index 0000000000000..fa7b73cdb40ec --- /dev/null +++ b/dev/create-release/do-release-docker.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Creates a Spark release candidate. The script will update versions, tag the branch, +# build Spark binary packages and documentation, and upload maven artifacts to a staging +# repository. There is also a dry run mode where only local builds are performed, and +# nothing is uploaded to the ASF repos. +# +# Run with "-h" for options. +# + +set -e +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +function usage { + local NAME=$(basename $0) + cat < "$GPG_KEY_FILE" + +run_silent "Building spark-rm image with tag $IMGTAG..." "docker-build.log" \ + docker build -t "spark-rm:$IMGTAG" --build-arg UID=$UID "$SELF/spark-rm" + +# Write the release information to a file with environment variables to be used when running the +# image. +ENVFILE="$WORKDIR/env.list" +fcreate_secure "$ENVFILE" + +function cleanup { + rm -f "$ENVFILE" + rm -f "$GPG_KEY_FILE" +} + +trap cleanup EXIT + +cat > $ENVFILE <> $ENVFILE + JAVA_VOL="--volume $JAVA:/opt/spark-java" +fi + +echo "Building $RELEASE_TAG; output will be at $WORKDIR/output" +docker run -ti \ + --env-file "$ENVFILE" \ + --volume "$WORKDIR:/opt/spark-rm" \ + $JAVA_VOL \ + "spark-rm:$IMGTAG" diff --git a/dev/create-release/do-release.sh b/dev/create-release/do-release.sh new file mode 100755 index 0000000000000..f1d4f3ab5ddec --- /dev/null +++ b/dev/create-release/do-release.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +while getopts "bn" opt; do + case $opt in + b) GIT_BRANCH=$OPTARG ;; + n) DRY_RUN=1 ;; + ?) error "Invalid option: $OPTARG" ;; + esac +done + +if [ "$RUNNING_IN_DOCKER" = "1" ]; then + # Inside docker, need to import the GPG key stored in the current directory. + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --import "$SELF/gpg.key" + + # We may need to adjust the path since JAVA_HOME may be overridden by the driver script. + if [ -n "$JAVA_HOME" ]; then + export PATH="$JAVA_HOME/bin:$PATH" + else + # JAVA_HOME for the openjdk package. + export JAVA_HOME=/usr + fi +else + # Outside docker, need to ask for information about the release. + get_release_info +fi + +function should_build { + local WHAT=$1 + [ -z "$RELEASE_STEP" ] || [ "$WHAT" = "$RELEASE_STEP" ] +} + +if should_build "tag" && [ $SKIP_TAG = 0 ]; then + run_silent "Creating release tag $RELEASE_TAG..." 
"tag.log" \ + "$SELF/release-tag.sh" + echo "It may take some time for the tag to be synchronized to github." + echo "Press enter when you've verified that the new tag ($RELEASE_TAG) is available." + read +else + echo "Skipping tag creation for $RELEASE_TAG." +fi + +if should_build "build"; then + run_silent "Building Spark..." "build.log" \ + "$SELF/release-build.sh" package +else + echo "Skipping build step." +fi + +if should_build "docs"; then + run_silent "Building documentation..." "docs.log" \ + "$SELF/release-build.sh" docs +else + echo "Skipping docs step." +fi + +if should_build "publish"; then + run_silent "Publishing release" "publish.log" \ + "$SELF/release-build.sh" publish-release +else + echo "Skipping publish step." +fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 5faa3d3260a56..24a62a8f4c7d3 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: release-build.sh @@ -87,49 +90,56 @@ NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads BASE_DIR=$(pwd) -MVN="build/mvn --force" - -# Hive-specific profiles for some builds -HIVE_PROFILES="-Phive -Phive-thriftserver" -# Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" -# Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" -# Scala 2.11 only profiles for some builds -SCALA_2_11_PROFILES="-Pkafka-0-8" -# Scala 2.12 only profiles for some builds -SCALA_2_12_PROFILES="-Pscala-2.12" +init_java +init_maven_sbt rm -rf spark -git clone https://git-wip-us.apache.org/repos/asf/spark.git +git clone "$ASF_REPO" cd spark git checkout $GIT_REF git_hash=`git rev-parse --short HEAD` echo "Checked out Spark git hash $git_hash" if [ -z "$SPARK_VERSION" ]; then - SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ - | grep -v INFO | grep -v WARNING | grep -v Download) + # Run $MVN in a separate command so that 'set -e' does the right thing. + TMP=$(mktemp) + $MVN help:evaluate -Dexpression=project.version > $TMP + SPARK_VERSION=$(cat $TMP | grep -v INFO | grep -v WARNING | grep -v Download) + rm $TMP fi -# Verify we have the right java version set -if [ -z "$JAVA_HOME" ]; then - echo "Please set JAVA_HOME." - exit 1 +# Depending on the version being built, certain extra profiles need to be activated, and +# different versions of Scala are supported. +BASE_PROFILES="-Pmesos -Pyarn" +PUBLISH_SCALA_2_10=0 +SCALA_2_10_PROFILES="-Pscala-2.10" +SCALA_2_11_PROFILES= +SCALA_2_12_PROFILES="-Pscala-2.12" + +if [[ $SPARK_VERSION > "2.3" ]]; then + BASE_PROFILES="$BASE_PROFILES -Pkubernetes -Pflume" + SCALA_2_11_PROFILES="-Pkafka-0-8" +else + PUBLISH_SCALA_2_10=1 fi -java_version=$("${JAVA_HOME}"/bin/javac -version 2>&1 | cut -d " " -f 2) +# Hive-specific profiles for some builds +HIVE_PROFILES="-Phive -Phive-thriftserver" +# Profiles for publishing snapshots and release to Maven Central +PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +# Profiles for building binary releases +BASE_RELEASE_PROFILES="$BASE_PROFILES -Psparkr" if [[ ! $SPARK_VERSION < "2.2." ]]; then - if [[ $java_version < "1.8." 
]]; then - echo "Java version $java_version is less than required 1.8 for 2.2+" + if [[ $JAVA_VERSION < "1.8." ]]; then + echo "Java version $JAVA_VERSION is less than required 1.8 for 2.2+" echo "Please set JAVA_HOME correctly." exit 1 fi else - if [[ $java_version > "1.7." ]]; then + if ! [[ $JAVA_VERSION =~ 1\.7\..* ]]; then if [ -z "$JAVA_7_HOME" ]; then - echo "Java version $java_version is higher than required 1.7 for pre-2.2" + echo "Java version $JAVA_VERSION is higher than required 1.7 for pre-2.2" echo "Please set JAVA_HOME correctly." exit 1 else @@ -174,8 +184,9 @@ if [[ "$1" == "package" ]]; then FLAGS=$2 ZINC_PORT=$3 BUILD_PACKAGE=$4 - cp -r spark spark-$SPARK_VERSION-bin-$NAME + echo "Building binary dist $NAME" + cp -r spark spark-$SPARK_VERSION-bin-$NAME cd spark-$SPARK_VERSION-bin-$NAME # TODO There should probably be a flag to make-distribution to allow 2.12 support @@ -244,31 +255,39 @@ if [[ "$1" == "package" ]]; then spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 } - # TODO: Check exit codes of children here: - # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop2.6" "-Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr" & - make_binary_release "hadoop2.7" "-Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip" & - make_binary_release "without-hadoop" "-Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3038" & - wait - rm -rf spark-$SPARK_VERSION-bin-*/ + if ! make_binary_release "hadoop2.6" "$MVN_EXTRA_OPTS -B -Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr"; then + error "Failed to build hadoop2.6 package. Check logs for details." + fi + if ! is_dry_run; then + if ! make_binary_release "hadoop2.7" "$MVN_EXTRA_OPTS -B -Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip"; then + error "Failed to build hadoop2.7 package. Check logs for details." + fi + if ! make_binary_release "without-hadoop" "$MVN_EXTRA_OPTS -B -Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3037"; then + error "Failed to build without-hadoop package. Check logs for details." + fi + fi - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-bin" - mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + rm -rf spark-$SPARK_VERSION-bin-*/ - echo "Copying release tarballs" - cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" - svn add "svn-spark/${DEST_DIR_NAME}-bin" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-bin" + mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + + echo "Copying release tarballs" + cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" + svn add "svn-spark/${DEST_DIR_NAME}-bin" + + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" + cd .. + rm -rf svn-spark + fi - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" - cd .. - rm -rf svn-spark exit 0 fi @@ -282,18 +301,22 @@ if [[ "$1" == "docs" ]]; then cd .. cd .. 
- svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-docs" - mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-docs" + mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" - echo "Copying release documentation" - cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" - svn add "svn-spark/${DEST_DIR_NAME}-docs" + echo "Copying release documentation" + cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" + svn add "svn-spark/${DEST_DIR_NAME}-docs" - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" - cd .. - rm -rf svn-spark + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" + cd .. + rm -rf svn-spark + fi + + mv "spark/docs/_site" docs/ exit 0 fi @@ -341,13 +364,15 @@ if [[ "$1" == "publish-release" ]]; then # Using Nexus API documented here: # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" + if ! is_dry_run; then + echo "Creating Nexus staging repository" + repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + fi tmp_repo=$(mktemp -d spark-repo-XXXXX) @@ -356,6 +381,12 @@ if [[ "$1" == "publish-release" ]]; then $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $SCALA_2_11_PROFILES $PUBLISH_PROFILES clean install + if ! is_dry_run && [[ $PUBLISH_SCALA_2_10 = 1 ]]; then + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$((ZINC_PORT + 1)) -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \ + -DskipTests $PUBLISH_PROFILES $SCALA_2_10_PROFILES clean install + fi + #./dev/change-scala-version.sh 2.12 #$MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo \ # -DskipTests $SCALA_2_12_PROFILES §$PUBLISH_PROFILES clean install @@ -386,23 +417,26 @@ if [[ "$1" == "publish-release" ]]; then sha1sum $file | cut -f1 -d' ' > $file.sha1 done - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done + if ! is_dry_run; then + nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id + echo "Uplading files to $nexus_upload" + for file in $(find . 
-type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$nexus_upload/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + fi - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" popd rm -rf $tmp_repo cd .. diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index a05716a5f66bb..628bc0504c9c8 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: tag-release.sh @@ -36,6 +39,7 @@ EOF } set -e +set -o pipefail if [[ $@ == *"help"* ]]; then exit_with_usage @@ -54,8 +58,10 @@ for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GI fi done +init_java +init_maven_sbt + ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" -MVN="build/mvn --force" rm -rf spark git clone "https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO" -b $GIT_BRANCH @@ -94,9 +100,15 @@ sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION" git commit -a -m "Preparing development version $NEXT_VERSION" -# Push changes -git push origin $RELEASE_TAG -git push origin HEAD:$GIT_BRANCH - -cd .. -rm -rf spark +if ! is_dry_run; then + # Push changes + git push origin $RELEASE_TAG + git push origin HEAD:$GIT_BRANCH + + cd .. + rm -rf spark +else + cd .. + mv spark spark.tag + echo "Clone with version changes and tag available as spark.tag in the output directory." +fi diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh new file mode 100644 index 0000000000000..7426b0d6ca08d --- /dev/null +++ b/dev/create-release/release-util.sh @@ -0,0 +1,228 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +DRY_RUN=${DRY_RUN:-0} +GPG="gpg --no-tty --batch" +ASF_REPO="https://git-wip-us.apache.org/repos/asf/spark.git" +ASF_REPO_WEBUI="https://git-wip-us.apache.org/repos/asf?p=spark.git" + +function error { + echo "$*" + exit 1 +} + +function read_config { + local PROMPT="$1" + local DEFAULT="$2" + local REPLY= + + read -p "$PROMPT [$DEFAULT]: " REPLY + local RETVAL="${REPLY:-$DEFAULT}" + if [ -z "$RETVAL" ]; then + error "$PROMPT is must be provided." + fi + echo "$RETVAL" +} + +function parse_version { + grep -e '.*' | \ + head -n 2 | tail -n 1 | cut -d'>' -f2 | cut -d '<' -f1 +} + +function run_silent { + local BANNER="$1" + local LOG_FILE="$2" + shift 2 + + echo "========================" + echo "= $BANNER" + echo "Command: $@" + echo "Log file: $LOG_FILE" + + "$@" 1>"$LOG_FILE" 2>&1 + + local EC=$? + if [ $EC != 0 ]; then + echo "Command FAILED. Check full logs for details." + tail "$LOG_FILE" + exit $EC + fi +} + +function fcreate_secure { + local FPATH="$1" + rm -f "$FPATH" + touch "$FPATH" + chmod 600 "$FPATH" +} + +function check_for_tag { + curl -s --head --fail "$ASF_REPO_WEBUI;a=commit;h=$1" >/dev/null +} + +function get_release_info { + if [ -z "$GIT_BRANCH" ]; then + # If no branch is specified, found out the latest branch from the repo. + GIT_BRANCH=$(git ls-remote --heads "$ASF_REPO" | + grep -v refs/heads/master | + awk '{print $2}' | + sort -r | + head -n 1 | + cut -d/ -f3) + fi + + export GIT_BRANCH=$(read_config "Branch" "$GIT_BRANCH") + + # Find the current version for the branch. + local VERSION=$(curl -s "$ASF_REPO_WEBUI;a=blob_plain;f=pom.xml;hb=refs/heads/$GIT_BRANCH" | + parse_version) + echo "Current branch version is $VERSION." + + if [[ ! $VERSION =~ .*-SNAPSHOT ]]; then + error "Not a SNAPSHOT version: $VERSION" + fi + + NEXT_VERSION="$VERSION" + RELEASE_VERSION="${VERSION/-SNAPSHOT/}" + SHORT_VERSION=$(echo "$VERSION" | cut -d . -f 1-2) + local REV=$(echo "$VERSION" | cut -d . -f 3) + + # Find out what rc is being prepared. + # - If the current version is "x.y.0", then this is rc1 of the "x.y.0" release. + # - If not, need to check whether the previous version has been already released or not. + # - If it has, then we're building rc1 of the current version. + # - If it has not, we're building the next RC of the previous version. + local RC_COUNT + if [ $REV != 0 ]; then + local PREV_REL_REV=$((REV - 1)) + local PREV_REL_TAG="v${SHORT_VERSION}.${PREV_REL_REV}" + if check_for_tag "$PREV_REL_TAG"; then + RC_COUNT=1 + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + else + RELEASE_VERSION="${SHORT_VERSION}.${PREV_REL_REV}" + RC_COUNT=$(git ls-remote --tags "$ASF_REPO" "v${RELEASE_VERSION}-rc*" | wc -l) + RC_COUNT=$((RC_COUNT + 1)) + fi + else + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + RC_COUNT=1 + fi + + export NEXT_VERSION + export RELEASE_VERSION=$(read_config "Release" "$RELEASE_VERSION") + + RC_COUNT=$(read_config "RC #" "$RC_COUNT") + + # Check if the RC already exists, and if re-creating the RC, skip tag creation. + RELEASE_TAG="v${RELEASE_VERSION}-rc${RC_COUNT}" + SKIP_TAG=0 + if check_for_tag "$RELEASE_TAG"; then + read -p "$RELEASE_TAG already exists. Continue anyway [y/n]? " ANSWER + if [ "$ANSWER" != "y" ]; then + error "Exiting." + fi + SKIP_TAG=1 + fi + + + export RELEASE_TAG + + GIT_REF="$RELEASE_TAG" + if is_dry_run; then + echo "This is a dry run. Please confirm the ref that will be built for testing." 
+ GIT_REF=$(read_config "Ref" "$GIT_REF") + fi + export GIT_REF + export SPARK_PACKAGE_VERSION="$RELEASE_TAG" + + # Gather some user information. + export ASF_USERNAME=$(read_config "ASF user" "$LOGNAME") + + GIT_NAME=$(git config user.name || echo "") + export GIT_NAME=$(read_config "Full name" "$GIT_NAME") + + export GIT_EMAIL="$ASF_USERNAME@apache.org" + export GPG_KEY=$(read_config "GPG key" "$GIT_EMAIL") + + cat <&1 | cut -d " " -f 2) + export JAVA_VERSION +} + +# Initializes MVN_EXTRA_OPTS and SBT_OPTS depending on the JAVA_VERSION in use. Requires init_java. +function init_maven_sbt { + MVN="build/mvn -B" + MVN_EXTRA_OPTS= + SBT_OPTS= + if [[ $JAVA_VERSION < "1.8." ]]; then + # Needed for maven central when using Java 7. + SBT_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN_EXTRA_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN="$MVN $MVN_EXTRA_OPTS" + fi + export MVN MVN_EXTRA_OPTS SBT_OPTS +} diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile new file mode 100644 index 0000000000000..07ce320177f5a --- /dev/null +++ b/dev/create-release/spark-rm/Dockerfile @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Image for building Spark releases. Based on Ubuntu 16.04. +# +# Includes: +# * Java 8 +# * Ivy +# * Python/PyPandoc (2.7.12/3.5.2) +# * R-base/R-base-dev (3.3.2+) +# * Ruby 2.3 build utilities + +FROM ubuntu:16.04 + +# These arguments are just for reuse and not really meant to be customized. +ARG APT_INSTALL="apt-get install --no-install-recommends -y" + +ARG BASE_PIP_PKGS="setuptools wheel virtualenv" +ARG PIP_PKGS="pyopenssl pypandoc numpy pygments sphinx" + +# Install extra needed repos and refresh. +# - CRAN repo +# - Ruby repo (for doc generation) +# +# This is all in a single "RUN" command so that if anything changes, "apt update" is run to fetch +# the most current package versions (instead of potentially using old versions cached by docker). +RUN echo 'deb http://cran.cnr.Berkeley.edu/bin/linux/ubuntu xenial/' >> /etc/apt/sources.list && \ + gpg --keyserver keyserver.ubuntu.com --recv-key E084DAB9 && \ + gpg -a --export E084DAB9 | apt-key add - && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean && \ + apt-get update && \ + $APT_INSTALL software-properties-common && \ + apt-add-repository -y ppa:brightbox/ruby-ng && \ + apt-get update && \ + # Install openjdk 8. 
+ $APT_INSTALL openjdk-8-jdk && \ + update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java && \ + # Install build / source control tools + $APT_INSTALL curl wget git maven ivy subversion make gcc lsof libffi-dev \ + pandoc pandoc-citeproc libssl-dev libcurl4-openssl-dev libxml2-dev && \ + ln -s -T /usr/share/java/ivy.jar /usr/share/ant/lib/ivy.jar && \ + curl -sL https://deb.nodesource.com/setup_4.x | bash && \ + $APT_INSTALL nodejs && \ + # Install needed python packages. Use pip for installing packages (for consistency). + $APT_INSTALL libpython2.7-dev libpython3-dev python-pip python3-pip && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + cd && \ + virtualenv -p python3 p35 && \ + . p35/bin/activate && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + # Install R packages and dependencies used when building. + # R depends on pandoc*, libssl (which are installed above). + $APT_INSTALL r-base r-base-dev && \ + $APT_INSTALL texlive-latex-base texlive texlive-fonts-extra texinfo qpdf && \ + Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='http://cran.us.r-project.org/')" && \ + Rscript -e "devtools::install_github('jimhester/lintr')" && \ + # Install tools needed to build the documentation. + $APT_INSTALL ruby2.3 ruby2.3-dev && \ + gem install jekyll --no-rdoc --no-ri && \ + gem install jekyll-redirect-from && \ + gem install pygments.rb + +WORKDIR /opt/spark-rm/output + +ARG UID +RUN useradd -m -s /bin/bash -p spark-rm -u $UID spark-rm +USER spark-rm:spark-rm + +ENTRYPOINT [ "/opt/spark-rm/do-release.sh" ] diff --git a/dev/create-release/vote.tmpl b/dev/create-release/vote.tmpl new file mode 100644 index 0000000000000..2ce953c2f7ec4 --- /dev/null +++ b/dev/create-release/vote.tmpl @@ -0,0 +1,65 @@ +Please vote on releasing the following candidate as Apache Spark version {version}. + +The vote is open until {deadline} and passes if a majority +1 PMC votes are cast, with +a minimum of 3 +1 votes. + +[ ] +1 Release this package as Apache Spark {version} +[ ] -1 Do not release this package because ... + +To learn more about Apache Spark, please see http://spark.apache.org/ + +The tag to be voted on is {tag} (commit {tag_commit}): +https://github.com/apache/spark/tree/{tag} + +The release files, including signatures, digests, etc. can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-bin/ + +Signatures used for Spark RCs can be found in this file: +https://dist.apache.org/repos/dist/dev/spark/KEYS + +The staging repository for this release can be found at: +https://repository.apache.org/content/repositories/orgapachespark-{repo_id}/ + +The documentation corresponding to this release can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-docs/ + +The list of bug fixes going into {version} can be found at the following URL: +https://issues.apache.org/jira/projects/SPARK/versions/{jira_version_id} + +FAQ + +========================= +How can I help test this release? +========================= + +If you are a Spark user, you can help us test this release by taking +an existing Spark workload and running on this release candidate, then +reporting any regressions. 
+ +If you're working in PySpark you can set up a virtual env and install +the current RC and see if anything important breaks, in the Java/Scala +you can add the staging repository to your projects resolvers and test +with the RC (make sure to clean up the artifact cache before/after so +you don't end up building with a out of date RC going forward). + +=========================================== +What should happen to JIRA tickets still targeting {version}? +=========================================== + +The current list of open tickets targeted at {version} can be found at: +{open_issues_link} + +Committers should look at those and triage. Extremely important bug +fixes, documentation, and API tweaks that impact compatibility should +be worked on immediately. Everything else please retarget to an +appropriate release. + +================== +But my bug isn't fixed? +================== + +In order to make timely releases, we will typically not hold the +release unless the bug in question is a regression from the previous +release. That being said, if there is something which is a regression +that has not been correctly targeted please ping me or a committer to +help target the issue. \ No newline at end of file diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 723180a14febb..96e9c27210d05 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -172,10 +172,10 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ea08a001a1c9b..4a6ee027ec355 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -173,10 +173,10 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index da874026d7d10..e0b560c8ec71f 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-webapp-9.3.20.v20170531.jar jetty-xml-9.3.20.v20170531.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -192,10 +192,10 @@ protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar re2j-1.1.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar 
+scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/dev/run-tests.py b/dev/run-tests.py index 5e8c8590b5c34..cd4590864b7d7 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -357,7 +357,7 @@ def build_spark_unidoc_sbt(hadoop_version): exec_sbt(profiles_and_goals) -def build_spark_assembly_sbt(hadoop_version): +def build_spark_assembly_sbt(hadoop_version, checkstyle=False): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["assembly/package"] @@ -366,6 +366,9 @@ def build_spark_assembly_sbt(hadoop_version): " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + if checkstyle: + run_java_style_checks() + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the # documentation build fails on a specific machine & environment in Jenkins but it was unable @@ -570,11 +573,13 @@ def main(): or f.endswith("scalastyle-config.xml") for f in changed_files): run_scala_style_checks() + should_run_java_style_checks = False if not changed_files or any(f.endswith(".java") or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - run_java_style_checks() + # Run SBT Checkstyle after the build to prevent a side-effect to the build. + should_run_java_style_checks = True if not changed_files or any(f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py") @@ -603,7 +608,7 @@ def main(): detect_binary_inop_with_mima(hadoop_version) # Since we did not build assembly/package before running dev/mima, we need to # do it here because the tests still rely on it; see SPARK-13294 for details. - build_spark_assembly_sbt(hadoop_version) + build_spark_assembly_sbt(hadoop_version, should_run_java_style_checks) # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) diff --git a/docs/building-spark.md b/docs/building-spark.md index 0236bb05849ad..c3bcd90ccc78f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -215,19 +215,23 @@ If you are building Spark for use in a Python environment and you wish to pip in Alternatively, you can also run make-distribution with the --pip option. -## PySpark Tests with Maven +## PySpark Tests with Maven or SBT If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. ./build/mvn -DskipTests clean package -Phive ./python/run-tests +If you are building PySpark with SBT and wish to run the PySpark tests, you will need to build Spark with Hive support and also build the test components: + + ./build/sbt -Phive clean package + ./build/sbt test:compile + ./python/run-tests + The run-tests script also can be limited to a specific Python version or a specific module ./python/run-tests --python-executables=python --modules=pyspark-sql -**Note:** You can also run Python tests with an sbt build, provided you build Spark with Hive support. 
- ## Running R Tests To run the SparkR tests you will need to install the [knitr](https://cran.r-project.org/package=knitr), [rmarkdown](https://cran.r-project.org/package=rmarkdown), [testthat](https://cran.r-project.org/package=testthat), [e1071](https://cran.r-project.org/package=e1071) and [survival](https://cran.r-project.org/package=survival) packages first: diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index ac1c336988930..18e8fe77bbdbe 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -70,7 +70,7 @@ be safely used as the direct destination of work with the normal rename-based co ### Installation With the relevant libraries on the classpath and Spark configured with valid credentials, -objects can be can be read or written by using their URLs as the path to data. +objects can be read or written by using their URLs as the path to data. For example `sparkContext.textFile("s3a://landsat-pds/scene_list.gz")` will create an RDD of the file `scene_list.gz` stored in S3, using the s3a connector. @@ -184,7 +184,8 @@ is no need for a workflow of write-then-rename to ensure that files aren't picke while they are still being written. Applications can write straight to the monitored directory. 1. Streams should only be checkpointed to a store implementing a fast and -atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. +atomic `rename()` operation. +Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading diff --git a/docs/configuration.md b/docs/configuration.md index 5588c372d3e42..6aa7878fe614d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1656,9 +1656,10 @@ Apart from these, the following properties are also available, and may be useful spark.blacklist.killBlacklistedExecutors false - (Experimental) If set to "true", allow Spark to automatically kill, and attempt to re-create, - executors when they are blacklisted. Note that, when an entire node is added to the blacklist, - all of the executors on that node will be killed. + (Experimental) If set to "true", allow Spark to automatically kill the executors + when they are blacklisted on fetch failure or blacklisted for the entire application, + as controlled by spark.blacklist.application.*. Note that, when an entire node is added + to the blacklist, all of the executors on that node will be killed. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4dbcbeafbbd9d..575da7205b529 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -411,6 +411,16 @@ To use a custom metrics.properties for the application master and executors, upd name matches both the include and the exclude pattern, this file will be excluded eventually. + + spark.yarn.blacklist.executor.launch.blacklisting.enabled + false + + Flag to enable blacklisting of nodes having YARN resource allocation problems. + The error limit for blacklisting can be configured by + spark.blacklist.application.maxFailedExecutorsPerNode. + + + # Important notes diff --git a/docs/security.md b/docs/security.md index 8c0c66fb5a285..6ef3a808e0471 100644 --- a/docs/security.md +++ b/docs/security.md @@ -177,7 +177,7 @@ ACLs can be configured for either users or groups. Configuration entries accept lists as input, meaning multiple users or groups can be given the desired privileges. 
This can be used if you run on a shared cluster and have a set of administrators or developers who need to monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL -means that all users will have the respective pivilege. By default, only the user submitting the +means that all users will have the respective privilege. By default, only the user submitting the application is added to the ACLs. Group membership is established by using a configurable group mapping provider. The mapper is @@ -446,6 +446,27 @@ replaced with one of the above namespaces. +Spark also supports retrieving `${ns}.keyPassword`, `${ns}.keyStorePassword` and `${ns}.trustStorePassword` from +[Hadoop Credential Providers](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/CredentialProviderAPI.html). +User could store password into credential file and make it accessible by different components, like: + +``` +hadoop credential create spark.ssl.keyPassword -value password \ + -provider jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks +``` + +To configure the location of the credential provider, set the `hadoop.security.credential.provider.path` +config option in the Hadoop configuration used by Spark, like: + +``` + + hadoop.security.credential.provider.path + jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks + +``` + +Or via SparkConf "spark.hadoop.hadoop.security.credential.provider.path=jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks". + ## Preparing the key stores Key stores can be generated by `keytool` program. The reference documentation for this tool for diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4d8a738507bd1..196b814420be1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1752,14 +1752,10 @@ To use `groupBy().apply()`, the user needs to define the following: * A Python function that defines the computation for each group. * A `StructType` object or a string that defines the schema of the output `DataFrame`. -The output schema will be applied to the columns of the returned `pandas.DataFrame` in order by position, -not by name. This means that the columns in the `pandas.DataFrame` must be indexed so that their -position matches the corresponding field in the schema. - -Note that when creating a new `pandas.DataFrame` using a dictionary, the actual position of the column -can differ from the order that it was placed in the dictionary. It is recommended in this case to -explicitly define the column order using the `columns` keyword, e.g. -`pandas.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])`, or alternatively use an `OrderedDict`. +The column labels of the returned `pandas.DataFrame` must either match the field names in the +defined output schema if specified as strings, or match the field data types by position if not +strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) +on how to label columns when constructing a `pandas.DataFrame`. Note that all data for a group will be loaded into memory before the function is applied. This can lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for @@ -1831,6 +1827,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. 
To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. 
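
Editorial aside (not part of the patch): the non-cascading cache invalidation note above is easiest to see with a small example. The following is a minimal Scala sketch under stated assumptions — it assumes a local SparkSession, and the object, column, and variable names are illustrative only.

    // Sketch of the non-cascading cache invalidation behavior described in the migration note.
    import org.apache.spark.sql.SparkSession

    object CacheInvalidationSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("cache-invalidation-sketch")
          .master("local[*]")
          .getOrCreate()
        import spark.implicits._

        val base = spark.range(0, 1000).toDF("id").cache()   // first cached Dataset
        val derived = base.filter($"id" % 2 === 0).cache()    // cache that depends on `base`
        derived.count()                                       // materialize both caches

        // Per the note above, since Spark 2.4 unpersisting `base` is non-cascading:
        // the cached data for `derived` is expected to stay valid and continue to be used.
        base.unpersist()
        derived.count()

        spark.stop()
      }
    }

Per the same note, in 2.3 and earlier the unpersist() call would also have invalidated the dependent cache, so the second count() would recompute `derived` from scratch.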
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index ae5b5c52d514e..32923dc9f5a6b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -67,7 +67,7 @@ case class KafkaStreamWriterFactory( override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 1e34bb8c73279..d967aa39a4827 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -54,10 +55,12 @@ public static void main(String[] argsArray) throws Exception { String className = args.remove(0); boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - AbstractCommandBuilder builder; + Map env = new HashMap<>(); + List cmd; if (className.equals("org.apache.spark.deploy.SparkSubmit")) { try { - builder = new SparkSubmitCommandBuilder(args); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(args); + cmd = buildCommand(builder, env, printLaunchCommand); } catch (IllegalArgumentException e) { printLaunchCommand = false; System.err.println("Error: " + e.getMessage()); @@ -76,17 +79,12 @@ public static void main(String[] argsArray) throws Exception { help.add(parser.className); } help.add(parser.USAGE_ERROR); - builder = new SparkSubmitCommandBuilder(help); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(help); + cmd = buildCommand(builder, env, printLaunchCommand); } } else { - builder = new SparkClassCommandBuilder(className, args); - } - - Map env = new HashMap<>(); - List cmd = builder.buildCommand(env); - if (printLaunchCommand) { - System.err.println("Spark Command: " + join(" ", cmd)); - System.err.println("========================================"); + AbstractCommandBuilder builder = new SparkClassCommandBuilder(className, args); + cmd = buildCommand(builder, env, printLaunchCommand); } if (isWindows()) { @@ -101,6 +99,22 @@ public static void main(String[] argsArray) throws Exception { } } + /** + * Prepare spark commands with the appropriate command builder. + * If printLaunchCommand is set then the commands will be printed to the stderr. + */ + private static List buildCommand( + AbstractCommandBuilder builder, + Map env, + boolean printLaunchCommand) throws IOException, IllegalArgumentException { + List cmd = builder.buildCommand(env); + if (printLaunchCommand) { + System.err.println("Spark Command: " + join(" ", cmd)); + System.err.println("========================================"); + } + return cmd; + } + /** * Prepare a command line for execution from a Windows batch script. 
* diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5cb6457bf5c21..cc65f78b45c30 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -90,7 +90,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { final List userArgs; private final List parsedArgs; - private final boolean requiresAppResource; + // Special command means no appResource and no mainClass required + private final boolean isSpecialCommand; private final boolean isExample; /** @@ -105,7 +106,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { * spark-submit argument list to be modified after creation. */ SparkSubmitCommandBuilder() { - this.requiresAppResource = true; + this.isSpecialCommand = false; this.isExample = false; this.parsedArgs = new ArrayList<>(); this.userArgs = new ArrayList<>(); @@ -138,25 +139,26 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { case RUN_EXAMPLE: isExample = true; + appResource = SparkLauncher.NO_RESOURCE; submitArgs = args.subList(1, args.size()); } this.isExample = isExample; OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.requiresAppResource = parser.requiresAppResource; + this.isSpecialCommand = parser.isSpecialCommand; } else { this.isExample = isExample; - this.requiresAppResource = false; + this.isSpecialCommand = true; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { + if (PYSPARK_SHELL.equals(appResource) && !isSpecialCommand) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { + } else if (SPARKR_SHELL.equals(appResource) && !isSpecialCommand) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -166,18 +168,18 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); OptionParser parser = new OptionParser(false); - final boolean requiresAppResource; + final boolean isSpecialCommand; // If the user args array is not empty, we need to parse it to detect exactly what // the user is trying to run, so that checks below are correct. 
if (!userArgs.isEmpty()) { parser.parse(userArgs); - requiresAppResource = parser.requiresAppResource; + isSpecialCommand = parser.isSpecialCommand; } else { - requiresAppResource = this.requiresAppResource; + isSpecialCommand = this.isSpecialCommand; } - if (!allowsMixedArguments && requiresAppResource) { + if (!allowsMixedArguments && !isSpecialCommand) { checkArgument(appResource != null, "Missing application resource."); } @@ -229,7 +231,7 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isExample) { + if (isExample && !isSpecialCommand) { checkArgument(mainClass != null, "Missing example class name."); } @@ -421,7 +423,7 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean requiresAppResource = true; + boolean isSpecialCommand = false; private final boolean errorOnUnknownArgs; OptionParser(boolean errorOnUnknownArgs) { @@ -470,17 +472,14 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - requiresAppResource = false; - parsedArgs.add(opt); - break; case VERSION: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); break; default: diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 2e050f8413074..b343094b2e7b8 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.launcher; import java.io.File; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -27,7 +28,10 @@ import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; + import static org.junit.Assert.*; public class SparkSubmitCommandBuilderSuite extends BaseSuite { @@ -35,6 +39,9 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { private static File dummyPropsFile; private static SparkSubmitOptionParser parser; + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @BeforeClass public static void setUp() throws Exception { dummyPropsFile = File.createTempFile("spark", "properties"); @@ -74,8 +81,11 @@ public void testCliHelpAndNoArg() throws Exception { @Test public void testCliKillAndStatus() throws Exception { - testCLIOpts(parser.STATUS); - testCLIOpts(parser.KILL_SUBMISSION); + List params = Arrays.asList("driver-20160531171222-0000"); + testCLIOpts(null, parser.STATUS, params); + testCLIOpts(null, parser.KILL_SUBMISSION, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.STATUS, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.KILL_SUBMISSION, params); } @Test @@ -190,6 +200,33 @@ public void testSparkRShell() throws Exception { env.get("SPARKR_SUBMIT_ARGS")); } + @Test(expected = IllegalArgumentException.class) + public void testExamplesRunnerNoArg() throws Exception { + List sparkSubmitArgs = Arrays.asList(SparkSubmitCommandBuilder.RUN_EXAMPLE); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + + @Test + public void testExamplesRunnerNoMainClass() throws Exception { + 
testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.HELP, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.USAGE_ERROR, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.VERSION, null); + } + + @Test + public void testExamplesRunnerWithMasterNoMainClass() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing example class name."); + + List sparkSubmitArgs = Arrays.asList( + SparkSubmitCommandBuilder.RUN_EXAMPLE, + parser.MASTER + "=foo" + ); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + @Test public void testExamplesRunner() throws Exception { List sparkSubmitArgs = Arrays.asList( @@ -344,10 +381,17 @@ private List buildCommand(List args, Map env) th return newCommandBuilder(args).buildCommand(env); } - private void testCLIOpts(String opt) throws Exception { - List helpArgs = Arrays.asList(opt, "driver-20160531171222-0000"); + private void testCLIOpts(String appResource, String opt, List params) throws Exception { + List args = new ArrayList<>(); + if (appResource != null) { + args.add(appResource); + } + args.add(opt); + if (params != null) { + args.addAll(params); + } Map env = new HashMap<>(); - List cmd = buildCommand(helpArgs, env); + List cmd = buildCommand(args, env); assertTrue(opt + " should be contained in the final cmd.", cmd.contains(opt)); } diff --git a/licenses/LICENSE-jmock.txt b/licenses/LICENSE-jmock.txt new file mode 100644 index 0000000000000..ed7964fe3d9ef --- /dev/null +++ b/licenses/LICENSE-jmock.txt @@ -0,0 +1,28 @@ +Copyright (c) 2000-2017, jMock.org +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. Redistributions +in binary form must reproduce the above copyright notice, this list of +conditions and the following disclaimer in the documentation and/or +other materials provided with the distribution. + +Neither the name of jMock nor the names of its contributors may be +used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 3a1c166d46257..11f46eb9e4359 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -30,6 +30,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.Param import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.util.Utils /** * A small wrapper that defines a training session for an estimator, and some methods to log @@ -47,7 +48,9 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( private val id = UUID.randomUUID() private val prefix = { - val className = estimator.getClass.getSimpleName + // estimator.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + val className = Utils.getSimpleName(estimator.getClass) s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " } diff --git a/pom.xml b/pom.xml index 23bbd3b09734e..90e64ff71d229 100644 --- a/pom.xml +++ b/pom.xml @@ -155,7 +155,7 @@ 3.4.1 3.2.2 - 2.11.8 + 2.11.12 2.11 1.9.13 2.6.7 @@ -740,13 +740,13 @@ org.scala-lang.modules scala-parser-combinators_${scala.binary.version} - 1.0.4 + 1.1.0 jline jline - 2.12.1 + 2.14.3 org.scalatest @@ -760,6 +760,12 @@ 1.10.19 test + + org.jmock + jmock-junit4 + test + 2.8.4 + org.scalacheck scalacheck_${scala.binary.version} @@ -2749,7 +2755,7 @@ scala-2.12 - 2.12.4 + 2.12.6 2.12 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index adc2b6b5d273d..b606f9355e03b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -57,11 +57,11 @@ object BuildCommons { val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn, streamingFlumeSink, streamingFlume, streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, - dockerIntegrationTests, hadoopCloud) = + dockerIntegrationTests, hadoopCloud, kubernetesIntegrationTests) = Seq("kubernetes", "mesos", "yarn", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", - "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) + "docker-integration-tests", "hadoop-cloud", "kubernetes-integration-tests").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 0afbe9dc6aa3e..fa2d5e8db716a 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -31,7 +31,7 @@ if sys.version >= '3': xrange = range -from py4j.java_gateway import java_import, JavaGateway, GatewayParameters +from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters from pyspark.find_spark_home import _find_spark_home from pyspark.serializers import read_int, write_with_length, UTF8Deserializer @@ -145,3 +145,26 @@ def do_server_auth(conn, auth_secret): if reply != "ok": conn.close() raise Exception("Unexpected reply from iterator server.") + + +def ensure_callback_server_started(gw): + """ + Start callback server if not already started. 
The callback server is needed if the Java + driver process needs to callback into the Python driver process to execute Python code. + """ + + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__ or gw._callback_server is None: + gw.callback_server_parameters.eager_load = True + gw.callback_server_parameters.daemonize = True + gw.callback_server_parameters.daemonize_connections = True + gw.callback_server_parameters.port = 0 + gw.start_callback_server(gw.callback_server_parameters) + cbport = gw._callback_server.server_socket.getsockname()[1] + gw._callback_server.port = cbport + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 14d9128502ab0..7e7e5822a6b20 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -74,6 +74,7 @@ class PythonEvalType(object): SQL_SCALAR_PANDAS_UDF = 200 SQL_GROUPED_MAP_PANDAS_UDF = 201 SQL_GROUPED_AGG_PANDAS_UDF = 202 + SQL_WINDOW_AGG_PANDAS_UDF = 203 def portable_hash(x): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 15753f77bd903..4c16b5fc26f3d 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,8 +33,9 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -PySpark serialize objects in batches; By default, the batch size is chosen based -on the size of objects, also configurable by SparkContext's C{batchSize} parameter: +PySpark serializes objects in batches; by default, the batch size is chosen based +on the size of objects and is also configurable by SparkContext's C{batchSize} +parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -100,7 +101,7 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): """ Return an iterator of deserialized batches (iterable) of objects from the input stream. - if the serializer does not operate on batches the default implementation returns an + If the serializer does not operate on batches the default implementation returns an iterator of single element lists. """ return map(lambda x: [x], self.load_stream(stream)) @@ -461,7 +462,7 @@ def dumps(self, obj): return obj -# Hook namedtuple, make it picklable +# Hack namedtuple, make it picklable __cls = {} @@ -525,15 +526,15 @@ def namedtuple(*args, **kwargs): cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) - # replace namedtuple with new one + # replace namedtuple with the new one collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 - # hack the cls already generated by namedtuple - # those created in other module can be pickled as normal, + # hack the cls already generated by namedtuple. 
+ # Those created in other modules can be pickled as normal, # so only hack those in __main__ module for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple @@ -627,7 +628,7 @@ def loads(self, obj): elif _type == b'P': return pickle.loads(obj[1:]) else: - raise ValueError("invalid sevialization type: %s" % _type) + raise ValueError("invalid serialization type: %s" % _type) class CompressedSerializer(FramedSerializer): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1759195c6fcc0..9652d3e79b875 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1819,6 +1819,25 @@ def create_map(*cols): return Column(jc) +@since(2.4) +def map_from_arrays(col1, col2): + """Creates a new map from two arrays. + + :param col1: name of column containing a set of keys. All elements should not be null + :param col2: name of column containing a set of values + + >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v']) + >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show() + +----------------+ + | map| + +----------------+ + |[2 -> a, 5 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def array(*cols): """Creates a new array column. @@ -1980,6 +1999,20 @@ def array_remove(col, element): return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) +@since(2.4) +def array_distinct(col): + """ + Collection function: removes duplicate values from the array. + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(array_distinct(df.data)).collect() + [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. @@ -2149,8 +2182,7 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] - >>> schema = MapType(StringType(), IntegerType()) - >>> df.select(from_json(df.value, schema).alias("json")).collect() + >>> df.select(from_json(df.value, "MAP").alias("json")).collect() [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) @@ -2380,6 +2412,26 @@ def map_entries(col): return Column(sc._jvm.functions.map_entries(_to_java_column(col))) +@since(2.4) +def map_from_entries(col): + """ + Collection function: Returns a map created from the given array of entries. 
+ + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_from_entries + >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data") + >>> df.select(map_from_entries("data").alias("map")).show() + +----------------+ + | map| + +----------------+ + |[1 -> a, 2 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_entries(_to_java_column(col))) + + @ignore_unicode_prefix @since(2.4) def array_repeat(col, count): @@ -2394,6 +2446,23 @@ def array_repeat(col, count): return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count)) +@since(2.4) +def arrays_zip(*cols): + """ + Collection function: Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. + + :param cols: columns of arrays to be merged. + + >>> from pyspark.sql.functions import arrays_zip + >>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2']) + >>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect() + [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): @@ -2515,9 +2584,10 @@ def pandas_udf(f=None, returnType=None, functionType=None): A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned - `pandas.DataFrame`. - The length of the returned `pandas.DataFrame` can be arbitrary and the columns must be - indexed so that their position matches the corresponding field in the schema. + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined returnType schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pandas.DataFrame` can be arbitrary. Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. @@ -2580,10 +2650,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as - output types. + :class:`MapType` and :class:`StructType` are currently not supported as output types. - Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and + :class:`pyspark.sql.Window` + + This example shows using grouped aggregated UDFs with groupby: >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( @@ -2600,7 +2672,31 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 6.0| +---+-----------+ - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + This example shows using grouped aggregated UDFs as window functions. Note that only + unbounded window frame is supported at the moment: + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... 
("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> w = Window.partitionBy('id') \\ + ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP + +---+----+------+ + | id| v|mean_v| + +---+----+------+ + | 1| 1.0| 1.5| + | 1| 2.0| 1.5| + | 2| 3.0| 6.0| + | 2| 5.0| 6.0| + | 2|10.0| 6.0| + +---+----+------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` .. note:: The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index a0e20d39c20da..3efe2adb6e2a4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, - encoding=None): + dropFieldIfAllNull=None, encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -246,6 +246,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. :param samplingRatio: defines fraction of input JSON objects used for schema inferring. If None is set, it uses the default value, ``1.0``. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. 
>>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e880dd1ca6d1a..f1ad6b1212ed9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -567,14 +567,7 @@ def _create_shell_session(): .getOrCreate() else: return SparkSession.builder.getOrCreate() - except py4j.protocol.Py4JError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - - try: - return SparkSession.builder.getOrCreate() - except TypeError: + except (py4j.protocol.Py4JError, TypeError): if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': warnings.warn("Fall back to non-hive support because failing to access HiveConf, " "please make sure you build spark with hive") diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fae50b3d5d532..8c1fd4af674d7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -24,12 +24,14 @@ else: intlike = (int, long) +from py4j.java_gateway import java_import + from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.types import * -from pyspark.sql.utils import StreamingQueryException +from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] @@ -854,6 +856,197 @@ def trigger(self, processingTime=None, once=None, continuous=None): self._jwrite = self._jwrite.trigger(jTrigger) return self + @since(2.4) + def foreach(self, f): + """ + Sets the output of the streaming query to be processed using the provided writer ``f``. + This is often used to write the output of a streaming query to arbitrary storage systems. + The processing logic can be specified in two ways. + + #. A **function** that takes a row as input. + This is a simple way to express your processing logic. Note that this does + not allow you to deduplicate generated data when failures cause reprocessing of + some input data. That would require you to specify the processing logic in the next + way. + + #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods. + The object can have the following methods. + + * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing + (for example, open a connection, start a transaction, etc). Additionally, you can + use the `partition_id` and `epoch_id` to deduplicate regenerated data + (discussed later). + + * ``process(row)``: *Non-optional* method that processes each :class:`Row`. + + * ``close(error)``: *Optional* method that finalizes and cleans up (for example, + close connection, commit transaction, etc.) after all rows have been processed. + + The object will be used by Spark in the following way. + + * A single copy of this object is responsible of all the data generated by a + single task in a query. In other words, one instance is responsible for + processing one partition of the data generated in a distributed manner. + + * This object must be serializable because each task will get a fresh + serialized-deserialized copy of the provided object. Hence, it is strongly + recommended that any initialization for writing data (e.g. 
opening a + connection or starting a transaction) is done after the `open(...)` + method has been called, which signifies that the task is ready to generate data. + + * The lifecycle of the methods are as follows. + + For each partition with ``partition_id``: + + ... For each batch/epoch of streaming data with ``epoch_id``: + + ....... Method ``open(partitionId, epochId)`` is called. + + ....... If ``open(...)`` returns true, for each row in the partition and + batch/epoch, method ``process(row)`` is called. + + ....... Method ``close(errorOrNull)`` is called with error (if any) seen while + processing rows. + + Important points to note: + + * The `partitionId` and `epochId` can be used to deduplicate generated data when + failures cause reprocessing of some input data. This depends on the execution + mode of the query. If the streaming query is being executed in the micro-batch + mode, then every partition represented by a unique tuple (partition_id, epoch_id) + is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used + to deduplicate and/or transactionally commit data and achieve exactly-once + guarantees. However, if the streaming query is being executed in the continuous + mode, then this guarantee does not hold and therefore should not be used for + deduplication. + + * The ``close()`` method (if exists) will be called if `open()` method exists and + returns successfully (irrespective of the return value), except if the Python + crashes in the middle. + + .. note:: Evolving. + + >>> # Print every row using a function + >>> def print_row(row): + ... print(row) + ... + >>> writer = sdf.writeStream.foreach(print_row) + >>> # Print every row using a object with process() method + >>> class RowPrinter: + ... def open(self, partition_id, epoch_id): + ... print("Opened %d, %d" % (partition_id, epoch_id)) + ... return True + ... def process(self, row): + ... print(row) + ... def close(self, error): + ... print("Closed with error: %s" % str(error)) + ... + >>> writer = sdf.writeStream.foreach(RowPrinter()) + """ + + from pyspark.rdd import _wrap_function + from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + from pyspark.taskcontext import TaskContext + + if callable(f): + # The provided object is a callable function that is supposed to be called on each row. + # Construct a function that takes an iterator and calls the provided function on each + # row. + def func_without_process(_, iterator): + for x in iterator: + f(x) + return iter([]) + + func = func_without_process + + else: + # The provided object is not a callable function. Then it is expected to have a + # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and + # 'close(error)' methods. 
+ + if not hasattr(f, 'process'): + raise Exception("Provided object does not have a 'process' method") + + if not callable(getattr(f, 'process')): + raise Exception("Attribute 'process' in provided object is not callable") + + def doesMethodExist(method_name): + exists = hasattr(f, method_name) + if exists and not callable(getattr(f, method_name)): + raise Exception( + "Attribute '%s' in provided object is not callable" % method_name) + return exists + + open_exists = doesMethodExist('open') + close_exists = doesMethodExist('close') + + def func_with_open_process_close(partition_id, iterator): + epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId') + if epoch_id: + epoch_id = int(epoch_id) + else: + raise Exception("Could not get batch id from TaskContext") + + # Check if the data should be processed + should_process = True + if open_exists: + should_process = f.open(partition_id, epoch_id) + + error = None + + try: + if should_process: + for x in iterator: + f.process(x) + except Exception as ex: + error = ex + finally: + if close_exists: + f.close(error) + if error: + raise error + + return iter([]) + + func = func_with_open_process_close + + serializer = AutoBatchedSerializer(PickleSerializer()) + wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) + jForeachWriter = \ + self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( + wrapped_func, self._df._jdf.schema()) + self._jwrite.foreach(jForeachWriter) + return self + + @since(2.4) + def foreachBatch(self, func): + """ + Sets the output of the streaming query to be processed using the provided + function. This is supported only the in the micro-batch execution modes (that is, when the + trigger is not continuous). In every micro-batch, the provided function will be called in + every micro-batch with (i) the output rows as a DataFrame and (ii) the batch identifier. + The batchId can be used deduplicate and transactionally write the output + (that is, the provided Dataset) to external systems. The output DataFrame is guaranteed + to exactly same for the same batchId (assuming all operations are deterministic in the + query). + + .. note:: Evolving. + + >>> def func(batch_df, batch_id): + ... batch_df.collect() + ... 
+ >>> writer = sdf.writeStream.foreach(func) + """ + + from pyspark.java_gateway import ensure_callback_server_started + gw = self._spark._sc._gateway + java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*") + + wrapped_func = ForeachBatchFunction(self._spark, func) + gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func) + ensure_callback_server_started(gw) + return self + @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4a3941de8a6a6..35a0636e5cfc0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1869,6 +1869,299 @@ def test_query_manager_await_termination(self): q.stop() shutil.rmtree(tmpPath) + class ForeachWriterTester: + + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event( + self.open_events_dir, + {'partition': partitionId, 'epoch': epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {'value': 'text'}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {'error': str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, 'partition INT, epoch INT') + + def process_events(self): + return self._read_events(self.process_events_dir, 'value STRING') + + def close_events(self): + return self._read_events(self.close_events_dir, 'error STRING') + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert msg in str(e), "%s not in %s" % (msg, str(e)) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import uuid + with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester 
= self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise Exception("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException as e: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() 
+ self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess(): + process = True + + tester.assert_invalid_writer(WriterWithNonCallableProcess(), + "'process' in provided object is not callable") + + class WriterWithNoParamProcess(): + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess(): + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer(WriterWithNonCallableOpen(), + "'open' in provided object is not callable") + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer(WriterWithNonCallableClose(), + "'close' in provided object is not callable") + + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.sql.utils import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise Exception("this should fail the query") + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) @@ -4449,7 +4742,6 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -5034,6 +5326,109 @@ def foo3(key, pdf): expected4 = udf3.func((), pdf) self.assertPandasEqual(expected4, result4) + def test_column_order(self): + from collections import OrderedDict + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + # Helper function to set column names from a list + def rename_pdf(pdf, names): + pdf.rename(columns={old: new for old, new in + zip(pd_result.columns, names)}, inplace=True) + + df = self.data + grouped_df = df.groupby('id') + grouped_pdf = df.toPandas().groupby('id') + + # Function returns a pdf with required column names, but order could be arbitrary using dict + def change_col_order(pdf): + # Constructing a DataFrame from a dict should result in the same order, + # but use from_items to ensure the pdf column order is 
different than schema + return pd.DataFrame.from_items([ + ('id', pdf.id), + ('u', pdf.v * 2), + ('v', pdf.v)]) + + ordered_udf = pandas_udf( + change_col_order, + 'id long, v int, u int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by name from the pdf + result = grouped_df.apply(ordered_udf).sort('id', 'v')\ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(change_col_order) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with positional columns, indexed by range + def range_col_order(pdf): + # Create a DataFrame with positional columns, fix types to long + return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64') + + range_udf = pandas_udf( + range_col_order, + 'id long, u long, v long', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result uses positional columns from the pdf + result = grouped_df.apply(range_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(range_col_order) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with columns indexed with integers + def int_index(pdf): + return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)])) + + int_index_udf = pandas_udf( + int_index, + 'id long, u int, v int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by position of integer index + result = grouped_df.apply(int_index_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(int_index) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def column_name_typo(pdf): + return pd.DataFrame({'iid': pdf.id, 'v': pdf.v}) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def invalid_positional_types(pdf): + return pd.DataFrame([(u'a', 1.2)]) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): + grouped_df.apply(column_name_typo).collect() + with self.assertRaisesRegexp(Exception, "No cast implemented"): + grouped_df.apply(invalid_positional_types).collect() + + def test_positional_assignment_conf(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with self.sql_conf({"spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": True}): + + @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) + def foo(_): + return pd.DataFrame([('hi', 1)], columns=['x', 'y']) + + df = self.data + result = df.groupBy('id').apply(foo).select('a', 'b').collect() + for r in result: + self.assertEqual(r.a, 'hi') + self.assertEqual(r.b, 1) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -5454,6 +5849,15 @@ def test_retain_group_columns(self): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2')) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + def test_invalid_args(self): from pyspark.sql.functions import mean 
@@ -5479,6 +5883,235 @@ def test_invalid_args(self): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) +class WindowPandasUDFTests(ReusedSQLTestCase): + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + return udf(lambda v: v + 1, 'double') + + @property + def pandas_scalar_time_two(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + return pandas_udf(lambda v: v * 2, 'double') + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_max_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def max(v): + return v.max() + return max + + @property + def pandas_agg_min_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def min(v): + return v.min() + return min + + @property + def unbounded_window(self): + return Window.partitionBy('id') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + + @property + def ordered_window(self): + return Window.partitionBy('id').orderBy('v') + + @property + def unpartitioned_window(self): + return Window.partitionBy() + + def test_simple(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max + + df = self.data + w = self.unbounded_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_multiple_udfs(self): + from pyspark.sql.functions import max, min, mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ + .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ + .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) + + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ + .withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('min_w', min(df['w']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_replace_existing(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v', mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1) + 
expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_udf(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + plus_one = self.python_plus_one + time_two = self.pandas_scalar_time_two + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn( + 'v2', + plus_one(mean_udf(plus_one(df['v'])).over(w))) + expected1 = df.withColumn( + 'v2', + plus_one(mean(plus_one(df['v'])).over(w))) + + result2 = df.withColumn( + 'v2', + time_two(mean_udf(time_two(df['v'])).over(w))) + expected2 = df.withColumn( + 'v2', + time_two(mean(time_two(df['v'])).over(w))) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_without_partitionBy(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unpartitioned_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v2', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v2', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_mixed_sql_and_udf(self): + from pyspark.sql.functions import max, min, rank, col + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + max_udf = self.pandas_agg_max_udf + min_udf = self.pandas_agg_min_udf + + result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w)) + expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w)) + + # Test mixing sql window function and window udf in the same expression + result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w)) + expected2 = expected1 + + # Test chaining sql aggregate function and udf + result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('min_v', min(df['v']).over(w)) \ + .withColumn('v_diff', col('max_v') - col('min_v')) \ + .drop('max_v', 'min_v') + expected3 = expected1 + + # Test mixing sql window function and udf + result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.withColumn('v2', array_udf(df['v']).over(w)) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + + def test_invalid_args(self): + from pyspark.sql.functions import mean, pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*not supported within a window function'): + foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) + df.withColumn('v2', foo_udf(df['v']).over(w)) + + with 
QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*Only unbounded window frame is supported.*'): + df.withColumn('mean_v', mean_udf(df['v']).over(ow)) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 45363f089a73d..bb9ce02c4b60f 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -150,3 +150,26 @@ def require_minimum_pyarrow_version(): if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): raise ImportError("PyArrow >= %s must be installed; however, " "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) + + +class ForeachBatchFunction(object): + """ + This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps + the user-defined 'foreachBatch' function such that it can be called from the JVM when + the query is active. + """ + + def __init__(self, sql_ctx, func): + self.sql_ctx = sql_ctx + self.func = func + + def call(self, jdf, batch_id): + from pyspark.sql.dataframe import DataFrame + try: + self.func(DataFrame(jdf, self.sql_ctx), batch_id) + except Exception as e: + self.error = e + raise e + + class Java: + implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index dd924ef89868e..a4515828d180c 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -79,22 +79,8 @@ def _ensure_initialized(cls): java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") - # start callback server - # getattr will fallback to JVM, so we cannot test by hasattr() - if "_callback_server" not in gw.__dict__ or gw._callback_server is None: - gw.callback_server_parameters.eager_load = True - gw.callback_server_parameters.daemonize = True - gw.callback_server_parameters.daemonize_connections = True - gw.callback_server_parameters.port = 0 - gw.start_callback_server(gw.callback_server_parameters) - cbport = gw._callback_server.server_socket.getsockname()[1] - gw._callback_server.port = cbport - # gateway with real port - gw._python_proxy_port = gw._callback_server.port - # get the GatewayServer object in JVM by ID - jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) - # update the port of CallbackClient with real port - jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) + from pyspark.java_gateway import ensure_callback_server_started + ensure_callback_server_started(gw) # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 18b2f251dc9fd..a4c5fb1db8b37 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -581,9 +581,9 @@ def test_get_local_property(self): self.sc.setLocalProperty(key, value) try: rdd = self.sc.parallelize(range(1), 1) - prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0] + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] self.assertEqual(prop1, value) - prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] self.assertTrue(prop2 is None) finally: self.sc.setLocalProperty(key, None) 
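The `ForeachBatchFunction` wrapper above is what lets a plain Python callable receive each micro-batch through the Py4J callback server. A user-side sketch of how it is typically driven is shown below; the paths are placeholders and `spark` is assumed to be an active SparkSession, so treat it as illustrative rather than part of this change.

```python
# Minimal sketch, assuming `spark` is an active SparkSession; the paths are placeholders.
def write_batch(batch_df, batch_id):
    # batch_id lets the sink deduplicate or commit transactionally across retries.
    batch_df.write.mode("overwrite").parquet("/tmp/out/batch=%d" % batch_id)

query = (spark.readStream.format("text").load("/tmp/in")
         .writeStream.foreachBatch(write_batch)
         .start())
query.processAllAvailable()
query.stop()
```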
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a30d6bf523a50..eaaae2b14e107 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -38,6 +38,9 @@ from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle +if sys.version >= '3': + basestring = str + pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -92,7 +95,10 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type, argspec): +def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): + assign_cols_by_pos = runner_conf.get( + "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", False) + def wrapped(key_series, value_series): import pandas as pd @@ -110,9 +116,13 @@ def wrapped(key_series, value_series): "Number of columns of the returned pandas.DataFrame " "doesn't match specified schema. " "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) - arrow_return_types = (to_arrow_type(field.dataType) for field in return_type) - return [(result[result.columns[i]], arrow_type) - for i, arrow_type in enumerate(arrow_return_types)] + + # Assign result columns by schema name if user labeled with strings, else use position + if not assign_cols_by_pos and any(isinstance(name, basestring) for name in result.columns): + return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] + else: + return [(result[result.columns[i]], to_arrow_type(field.dataType)) + for i, field in enumerate(return_type)] return wrapped @@ -128,7 +138,22 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type): +def wrap_window_agg_pandas_udf(f, return_type): + # This is similar to grouped_agg_pandas_udf, the only difference + # is that window_agg_pandas_udf needs to repeat the return value + # to match window length, where grouped_agg_pandas_udf just returns + # the scalar value. 
+ arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series([result]).repeat(len(series[0])) + + return lambda *a: (wrapped(*a), arrow_return_type) + + +def read_single_udf(pickleSer, infile, eval_type, runner_conf): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] row_func = None @@ -148,9 +173,11 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(row_func) # signature was lost when wrapping it - return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) + elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: + return arg_offsets, wrap_window_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(func, return_type) else: @@ -158,6 +185,26 @@ def read_single_udf(pickleSer, infile, eval_type): def read_udfs(pickleSer, infile, eval_type): + runner_conf = {} + + if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): + + # Load conf used for pandas_udf evaluation + num_conf = read_int(infile) + for i in range(num_conf): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + runner_conf[k] = v + + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True + timezone = runner_conf.get("spark.sql.session.timeZone", None) + ser = ArrowStreamPandasSerializer(timezone) + else: + ser = BatchedSerializer(PickleSerializer(), 100) + num_udfs = read_int(infile) udfs = {} call_udf = [] @@ -172,7 +219,7 @@ def read_udfs(pickleSer, infile, eval_type): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f'] = udf split_offset = arg_offsets[0] + 1 arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] @@ -184,7 +231,7 @@ def read_udfs(pickleSer, infile, eval_type): # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. 
for i in range(num_udfs): - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f%d' % i] = udf args = ["a[%d]" % o for o in arg_offsets] call_udf.append("f%d(%s)" % (i, ", ".join(args))) @@ -193,14 +240,6 @@ def read_udfs(pickleSer, infile, eval_type): mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) - if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - timezone = utf8_deserializer.loads(infile) - ser = ArrowStreamPandasSerializer(timezone) - else: - ser = BatchedSerializer(PickleSerializer(), 100) - # profiling is not supported for UDF return func, None, ser, ser diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e69441a475e9a..a44051b351e19 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -36,7 +36,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this() = this(None, new JPrintWriter(Console.out, true)) override def createInterpreter(): Unit = { - intp = new SparkILoopInterpreter(settings, out) + intp = new SparkILoopInterpreter(settings, out, initializeSpark) } val initializationCommands: Seq[String] = Seq( @@ -73,11 +73,15 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) "import org.apache.spark.sql.functions._" ) - def initializeSpark() { - intp.beQuietDuring { - savingReplayStack { // remove the commands from session history. - initializationCommands.foreach(processLine) + def initializeSpark(): Unit = { + if (!intp.reporter.hasErrors) { + // `savingReplayStack` removes the commands from session history. + savingReplayStack { + initializationCommands.foreach(intp quietRun _) } + } else { + throw new RuntimeException(s"Scala $versionString interpreter encountered " + + "errors during initialization") } } @@ -101,16 +105,6 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) /** Available commands */ override def commands: List[LoopCommand] = standardCommands - /** - * We override `loadFiles` because we need to initialize Spark *before* the REPL - * sees any files, so that the Spark context is visible in those files. This is a bit of a - * hack, but there isn't another hook available to us at this point. 
- */ - override def loadFiles(settings: Settings): Unit = { - initializeSpark() - super.loadFiles(settings) - } - override def resetCommand(line: String): Unit = { super.resetCommand(line) initializeSpark() diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala index e736607a9a6b9..4e63816402a10 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -21,8 +21,22 @@ import scala.collection.mutable import scala.tools.nsc.Settings import scala.tools.nsc.interpreter._ -class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) { - self => +class SparkILoopInterpreter(settings: Settings, out: JPrintWriter, initializeSpark: () => Unit) + extends IMain(settings, out) { self => + + /** + * We override `initializeSynchronous` to initialize Spark *after* `intp` is properly initialized + * and *before* the REPL sees any files in the private `loadInitFiles` functions, so that + * the Spark context is visible in those files. + * + * This is a bit of a hack, but there isn't another hook available to us at this point. + * + * See the discussion in Scala community https://github.com/scala/bug/issues/10913 for detail. + */ + override def initializeSynchronous(): Unit = { + super.initializeSynchronous() + initializeSpark() + } override lazy val memberHandlers = new { val intp: self.type = self diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a62f271273465..a6dd47a6b7d95 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -77,6 +77,12 @@ + + com.squareup.okhttp3 + okhttp + 3.8.1 + + org.mockito mockito-core @@ -84,9 +90,9 @@ - com.squareup.okhttp3 - okhttp - 3.8.1 + org.jmock + jmock-junit4 + test diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 590deaa72e7ee..bf33179ae3dab 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -176,6 +176,24 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") + val KUBERNETES_EXECUTOR_API_POLLING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.apiPollingInterval") + .doc("Interval between polls against the Kubernetes API server to inspect the " + + "state of executors.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"API server polling interval must be a" + + " positive time value.") + .createWithDefaultString("30s") + + val KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.eventProcessingInterval") + .doc("Interval between successive inspection of executor events sent from the" + + " Kubernetes API.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"Event processing interval must be a positive" + + " time value.") + .createWithDefaultString("1s") + val MEMORY_OVERHEAD_FACTOR = ConfigBuilder("spark.kubernetes.memoryOverheadFactor") .doc("This sets the Memory Overhead 
Factor that will allocate memory to non-JVM jobs " + @@ -193,7 +211,6 @@ private[spark] object Config extends Logging { "Ensure that major Python version is either Python2 or Python3") .createWithDefault("2") - val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala new file mode 100644 index 0000000000000..83daddf714489 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +sealed trait ExecutorPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends ExecutorPodState + +case class PodPending(pod: Pod) extends ExecutorPodState + +sealed trait FinalPodState extends ExecutorPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends ExecutorPodState diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala new file mode 100644 index 0000000000000..5a143ad3600fd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import io.fabric8.kubernetes.api.model.PodBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, Utils} + +private[spark] class ExecutorPodsAllocator( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + clock: Clock) extends Logging { + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + + private val totalExpectedExecutors = new AtomicInteger(0) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + + private val driverPod = kubernetesClient.pods() + .withName(kubernetesDriverPodName) + .get() + + // Executor IDs that have been requested from Kubernetes but have not been detected in any + // snapshot yet. Mapped to the timestamp when they were created. + private val newlyCreatedExecutors = mutable.Map.empty[Long, Long] + + def start(applicationId: String): Unit = { + snapshotsStore.addSubscriber(podAllocationDelay) { + onNewSnapshots(applicationId, _) + } + } + + def setTotalExpectedExecutors(total: Int): Unit = totalExpectedExecutors.set(total) + + private def onNewSnapshots(applicationId: String, snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + newlyCreatedExecutors --= snapshots.flatMap(_.executorPods.keys) + // For all executors we've created against the API but have not seen in a snapshot + // yet - check the current time. If the current time has exceeded some threshold, + // assume that the pod was either never created (the API server never properly + // handled the creation request), or the API server created the pod but we missed + // both the creation and deletion events. In either case, delete the missing pod + // if possible, and mark such a pod to be rescheduled below. + newlyCreatedExecutors.foreach { case (execId, timeCreated) => + val currentTime = clock.getTimeMillis() + if (currentTime - timeCreated > podCreationTimeout) { + logWarning(s"Executor with id $execId was not detected in the Kubernetes" + + s" cluster after $podCreationTimeout milliseconds despite the fact that a" + + " previous allocation attempt tried to create it. The executor may have been" + + " deleted but the application missed the deletion event.") + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString) + .delete() + } + newlyCreatedExecutors -= execId + } else { + logDebug(s"Executor with id $execId was not found in the Kubernetes cluster since it" + + s" was created ${currentTime - timeCreated} milliseconds ago.") + } + } + + if (snapshots.nonEmpty) { + // Only need to examine the cluster as of the latest snapshot, the "current" state, to see if + // we need to allocate more executors or not. 
+ val latestSnapshot = snapshots.last + val currentRunningExecutors = latestSnapshot.executorPods.values.count { + case PodRunning(_) => true + case _ => false + } + val currentPendingExecutors = latestSnapshot.executorPods.values.count { + case PodPending(_) => true + case _ => false + } + val currentTotalExpectedExecutors = totalExpectedExecutors.get + logDebug(s"Currently have $currentRunningExecutors running executors and" + + s" $currentPendingExecutors pending executors. $newlyCreatedExecutors executors" + + s" have been requested but are pending appearance in the cluster.") + if (newlyCreatedExecutors.isEmpty + && currentPendingExecutors == 0 + && currentRunningExecutors < currentTotalExpectedExecutors) { + val numExecutorsToAllocate = math.min( + currentTotalExpectedExecutors - currentRunningExecutors, podAllocationSize) + logInfo(s"Going to request $numExecutorsToAllocate executors from Kubernetes.") + for ( _ <- 0 until numExecutorsToAllocate) { + val newExecutorId = EXECUTOR_ID_COUNTER.incrementAndGet() + val executorConf = KubernetesConf.createExecutorConf( + conf, + newExecutorId.toString, + applicationId, + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + kubernetesClient.pods().create(podWithAttachedContainer) + newlyCreatedExecutors(newExecutorId) = clock.getTimeMillis() + logDebug(s"Requested executor with id $newExecutorId from Kubernetes.") + } + } else if (currentRunningExecutors >= currentTotalExpectedExecutors) { + // TODO handle edge cases if we end up with more running executors than expected. + logDebug("Current number of running executors is equal to the number of requested" + + " executors. Not scaling up further.") + } else if (newlyCreatedExecutors.nonEmpty || currentPendingExecutors != 0) { + logDebug(s"Still waiting for ${newlyCreatedExecutors.size + currentPendingExecutors}" + + s" executors to begin running before requesting for more executors. # of executors in" + + s" pending status in the cluster: $currentPendingExecutors. # of executors that we have" + + s" created but we have not observed as being present in the cluster yet:" + + s" ${newlyCreatedExecutors.size}.") + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala new file mode 100644 index 0000000000000..b28d93990313e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.Cache +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsLifecycleManager( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + // Use a best-effort to track which executors have been removed already. It's not generally + // job-breaking if we remove executors more than once but it's ideal if we make an attempt + // to avoid doing so. Expire cache entries so that this data structure doesn't grow beyond + // bounds. + removedExecutorsCache: Cache[java.lang.Long, java.lang.Long]) extends Logging { + + import ExecutorPodsLifecycleManager._ + + private val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) + + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { + snapshotsStore.addSubscriber(eventProcessingInterval) { + onNewSnapshots(schedulerBackend, _) + } + } + + private def onNewSnapshots( + schedulerBackend: KubernetesClusterSchedulerBackend, + snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + val execIdsRemovedInThisRound = mutable.HashSet.empty[Long] + snapshots.foreach { snapshot => + snapshot.executorPods.foreach { case (execId, state) => + state match { + case deleted@PodDeleted(_) => + logDebug(s"Snapshot reported deleted executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + removeExecutorFromSpark(schedulerBackend, deleted, execId) + execIdsRemovedInThisRound += execId + case failed@PodFailed(_) => + logDebug(s"Snapshot reported failed executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + onFinalNonDeletedState(failed, execId, schedulerBackend, execIdsRemovedInThisRound) + case succeeded@PodSucceeded(_) => + logDebug(s"Snapshot reported succeeded executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}. Note that succeeded executors are" + + s" unusual unless Spark specifically informed the executor to exit.") + onFinalNonDeletedState(succeeded, execId, schedulerBackend, execIdsRemovedInThisRound) + case _ => + } + } + } + + // Reconcile the case where Spark claims to know about an executor but the corresponding pod + // is missing from the cluster. This would occur if we miss a deletion event and the pod + // transitions immediately from running to absent. We only need to check against the latest + // snapshot for this, and we don't do this for executors in the deleted executors cache or + // that we just removed in this round. + if (snapshots.nonEmpty) { + val latestSnapshot = snapshots.last + (schedulerBackend.getExecutorIds().map(_.toLong).toSet + -- latestSnapshot.executorPods.keySet + -- execIdsRemovedInThisRound).foreach { missingExecutorId => + if (removedExecutorsCache.getIfPresent(missingExecutorId) == null) { + val exitReasonMessage = s"The executor with ID $missingExecutorId was not found in the" + + s" cluster but we didn't get a reason why. Marking the executor as failed.
The" + + s" executor may have been deleted but the driver missed the deletion event." + logDebug(exitReasonMessage) + val exitReason = ExecutorExited( + UNKNOWN_EXIT_CODE, + exitCausedByApp = false, + exitReasonMessage) + schedulerBackend.doRemoveExecutor(missingExecutorId.toString, exitReason) + execIdsRemovedInThisRound += missingExecutorId + } + } + } + logDebug(s"Removed executors with ids ${execIdsRemovedInThisRound.mkString(",")}" + + s" from Spark that were either found to be deleted or non-existent in the cluster.") + } + + private def onFinalNonDeletedState( + podState: FinalPodState, + execId: Long, + schedulerBackend: KubernetesClusterSchedulerBackend, + execIdsRemovedInRound: mutable.Set[Long]): Unit = { + removeExecutorFromK8s(podState.pod) + removeExecutorFromSpark(schedulerBackend, podState, execId) + execIdsRemovedInRound += execId + } + + private def removeExecutorFromK8s(updatedPod: Pod): Unit = { + // If deletion failed on a previous try, we can try again if resync informs us the pod + // is still around. + // Delete as best attempt - duplicate deletes will throw an exception but the end state + // of getting rid of the pod is what matters. + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withName(updatedPod.getMetadata.getName) + .delete() + } + } + + private def removeExecutorFromSpark( + schedulerBackend: KubernetesClusterSchedulerBackend, + podState: FinalPodState, + execId: Long): Unit = { + if (removedExecutorsCache.getIfPresent(execId) == null) { + removedExecutorsCache.put(execId, execId) + val exitReason = findExitReason(podState, execId) + schedulerBackend.doRemoveExecutor(execId.toString, exitReason) + } + } + + private def findExitReason(podState: FinalPodState, execId: Long): ExecutorExited = { + val exitCode = findExitCode(podState) + val (exitCausedByApp, exitMessage) = podState match { + case PodDeleted(_) => + (false, s"The executor with id $execId was deleted by a user or the framework.") + case _ => + val msg = exitReasonMessage(podState, execId, exitCode) + (true, msg) + } + ExecutorExited(exitCode, exitCausedByApp, exitMessage) + } + + private def exitReasonMessage(podState: FinalPodState, execId: Long, exitCode: Int) = { + val pod = podState.pod + s""" + |The executor with id $execId exited with exit code $exitCode. 
+ |The API gave the following brief reason: ${pod.getStatus.getReason} + |The API gave the following message: ${pod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${pod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def findExitCode(podState: FinalPodState): Int = { + podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus => + containerStatus.getState.getTerminated != null + }.map { terminatedContainer => + terminatedContainer.getState.getTerminated.getExitCode.toInt + }.getOrElse(UNKNOWN_EXIT_CODE) + } +} + +private object ExecutorPodsLifecycleManager { + val UNKNOWN_EXIT_CODE = -1 +} + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala new file mode 100644 index 0000000000000..e77e604d00e0f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.{Future, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.ThreadUtils + +private[spark] class ExecutorPodsPollingSnapshotSource( + conf: SparkConf, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + pollingExecutor: ScheduledExecutorService) extends Logging { + + private val pollingInterval = conf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + private var pollingFuture: Future[_] = _ + + def start(applicationId: String): Unit = { + require(pollingFuture == null, "Cannot start polling more than once.") + logDebug(s"Starting to check for executor pod state every $pollingInterval ms.") + pollingFuture = pollingExecutor.scheduleWithFixedDelay( + new PollRunnable(applicationId), pollingInterval, pollingInterval, TimeUnit.MILLISECONDS) + } + + def stop(): Unit = { + if (pollingFuture != null) { + pollingFuture.cancel(true) + pollingFuture = null + } + ThreadUtils.shutdown(pollingExecutor) + } + + private class PollRunnable(applicationId: String) extends Runnable { + override def run(): Unit = { + logDebug(s"Resynchronizing full executor pod state from Kubernetes.") + snapshotsStore.replaceSnapshot(kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .list() + .getItems + .asScala) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala new file mode 100644 index 0000000000000..26be918043412 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging + +/** + * An immutable view of the current executor pods that are running in the cluster. 
+ */ +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { + + import ExecutorPodsSnapshot._ + + def withUpdate(updatedPod: Pod): ExecutorPodsSnapshot = { + val newExecutorPods = executorPods ++ toStatesByExecutorId(Seq(updatedPod)) + new ExecutorPodsSnapshot(newExecutorPods) + } +} + +object ExecutorPodsSnapshot extends Logging { + + def apply(executorPods: Seq[Pod]): ExecutorPodsSnapshot = { + ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) + } + + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + executorPods.map { pod => + (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) + }.toMap + } + + private def toState(pod: Pod): ExecutorPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..dd264332cf9e8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +private[spark] trait ExecutorPodsSnapshotsStore { + + def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) + + def stop(): Unit + + def updatePod(updatedPod: Pod): Unit + + def replaceSnapshot(newSnapshot: Seq[Pod]): Unit +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala new file mode 100644 index 0000000000000..5583b4617eeb2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent._ + +import io.fabric8.kubernetes.api.model.Pod +import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Controls the propagation of the Spark application's executor pods state to subscribers that + * react to that state. + *
+ * Roughly follows a producer-consumer model. Producers report states of executor pods, and these + * states are then published to consumers that can perform any actions in response to these states. + *
+ * Producers push updates in one of two ways. An incremental update sent by updatePod() represents + * a known new state of a single executor pod. A full sync sent by replaceSnapshot() indicates that + * the passed pods are all of the most up to date states of all executor pods for the application. + * The combination of the states of all executor pods for the application is collectively known as + * a snapshot. The store keeps track of the most up to date snapshot, and applies updates to that + * most recent snapshot - either by incrementally updating the snapshot with a single new pod state, + * or by replacing the snapshot entirely on a full sync. + *
+ * Consumers, or subscribers, register that they want to be informed about all snapshots of the + * executor pods. Every time the store replaces its most up to date snapshot from either an + * incremental update or a full sync, the most recent snapshot after the update is posted to the + * subscriber's buffer. Subscribers receive blocks of snapshots produced by the producers in + * time-windowed chunks. Each subscriber can choose to receive their snapshot chunks at different + * time intervals. + */ +private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: ScheduledExecutorService) + extends ExecutorPodsSnapshotsStore { + + private val SNAPSHOT_LOCK = new Object() + + private val subscribers = mutable.Buffer.empty[SnapshotsSubscriber] + private val pollingTasks = mutable.Buffer.empty[Future[_]] + + @GuardedBy("SNAPSHOT_LOCK") + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber( + processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + val newSubscriber = SnapshotsSubscriber( + new LinkedBlockingQueue[ExecutorPodsSnapshot](), onNewSnapshots) + SNAPSHOT_LOCK.synchronized { + newSubscriber.snapshotsBuffer.add(currentSnapshot) + } + subscribers += newSubscriber + pollingTasks += subscribersExecutor.scheduleWithFixedDelay( + toRunnable(() => callSubscriber(newSubscriber)), + 0L, + processBatchIntervalMillis, + TimeUnit.MILLISECONDS) + } + + override def stop(): Unit = { + pollingTasks.foreach(_.cancel(true)) + ThreadUtils.shutdown(subscribersExecutor) + } + + override def updatePod(updatedPod: Pod): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + addCurrentSnapshotToSubscribers() + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + addCurrentSnapshotToSubscribers() + } + + private def addCurrentSnapshotToSubscribers(): Unit = { + subscribers.foreach { subscriber => + subscriber.snapshotsBuffer.add(currentSnapshot) + } + } + + private def callSubscriber(subscriber: SnapshotsSubscriber): Unit = { + Utils.tryLogNonFatalError { + val currentSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot].asJava + subscriber.snapshotsBuffer.drainTo(currentSnapshots) + subscriber.onNewSnapshots(currentSnapshots.asScala) + } + } + + private def toRunnable[T](runnable: () => Unit): Runnable = new Runnable { + override def run(): Unit = runnable() + } + + private case class SnapshotsSubscriber( + snapshotsBuffer: BlockingQueue[ExecutorPodsSnapshot], + onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala new file mode 100644 index 0000000000000..a6749a644e00c --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.Closeable + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsWatchSnapshotSource( + snapshotsStore: ExecutorPodsSnapshotsStore, + kubernetesClient: KubernetesClient) extends Logging { + + private var watchConnection: Closeable = _ + + def start(applicationId: String): Unit = { + require(watchConnection == null, "Cannot start the watcher twice.") + logDebug(s"Starting watch for pods with labels $SPARK_APP_ID_LABEL=$applicationId," + + s" $SPARK_ROLE_LABEL=$SPARK_POD_EXECUTOR_ROLE.") + watchConnection = kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .watch(new ExecutorPodsWatcher()) + } + + def stop(): Unit = { + if (watchConnection != null) { + Utils.tryLogNonFatalError { + watchConnection.close() + } + watchConnection = null + } + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + override def eventReceived(action: Action, pod: Pod): Unit = { + val podName = pod.getMetadata.getName + logDebug(s"Received executor pod update for pod named $podName, action $action") + snapshotsStore.updatePod(pod) + } + + override def onClose(e: KubernetesClientException): Unit = { + logWarning("Kubernetes client has been closed (this is expected if the application is" + + " shutting down.)", e) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index 0ea80dfbc0d97..c6e931a38405f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -17,7 +17,9 @@ package org.apache.spark.scheduler.cluster.k8s import java.io.File +import java.util.concurrent.TimeUnit +import com.google.common.cache.CacheBuilder import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} @@ -26,7 +28,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SystemClock, ThreadUtils} private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { @@ -56,17 +58,45 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new 
File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val allocatorExecutor = ThreadUtils - .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( "kubernetes-executor-requests") + + val subscribersExecutor = ThreadUtils + .newDaemonThreadPoolScheduledExecutor( + "kubernetes-executor-snapshots-subscribers", 2) + val snapshotsStore = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor) + val removedExecutorsCache = CacheBuilder.newBuilder() + .expireAfterWrite(3, TimeUnit.MINUTES) + .build[java.lang.Long, java.lang.Long]() + val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager( + sc.conf, + new KubernetesExecutorBuilder(), + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + + val executorPodsAllocator = new ExecutorPodsAllocator( + sc.conf, new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock()) + + val podsWatchEventSource = new ExecutorPodsWatchSnapshotSource( + snapshotsStore, + kubernetesClient) + + val eventsPollingExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "kubernetes-executor-pod-polling-sync") + val podsPollingEventSource = new ExecutorPodsPollingSnapshotSource( + sc.conf, kubernetesClient, snapshotsStore, eventsPollingExecutor) + new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - new KubernetesExecutorBuilder, kubernetesClient, - allocatorExecutor, - requestExecutorsService) + requestExecutorsService, + snapshotsStore, + executorPodsAllocator, + executorPodsLifecycleEventHandler, + podsWatchEventSource, + podsPollingEventSource) } override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index d86664c81071b..fa6dc2c479bbf 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -16,60 +16,32 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.io.Closeable -import java.net.InetAddress -import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} -import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} -import javax.annotation.concurrent.GuardedBy +import java.util.concurrent.ExecutorService -import io.fabric8.kubernetes.api.model._ -import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import scala.collection.JavaConverters._ -import scala.collection.mutable +import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesConf -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} -import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import 
org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, - allocatorExecutor: ScheduledExecutorService, - requestExecutorsService: ExecutorService) + requestExecutorsService: ExecutorService, + snapshotsStore: ExecutorPodsSnapshotsStore, + podAllocator: ExecutorPodsAllocator, + lifecycleEventHandler: ExecutorPodsLifecycleManager, + watchEvents: ExecutorPodsWatchSnapshotSource, + pollEvents: ExecutorPodsPollingSnapshotSource) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { - import KubernetesClusterSchedulerBackend._ - - private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) - private val RUNNING_EXECUTOR_PODS_LOCK = new Object - @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK") - private val runningExecutorsToPods = new mutable.HashMap[String, Pod] - private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]() - private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]() - private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]() - - private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) - - private val kubernetesDriverPodName = conf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(throw new SparkException("Must specify the driver pod name")) private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( requestExecutorsService) - private val driverPod = kubernetesClient.pods() - .inNamespace(kubernetesNamespace) - .withName(kubernetesDriverPodName) - .get() - protected override val minRegisteredRatio = if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { 0.8 @@ -77,372 +49,93 @@ private[spark] class KubernetesClusterSchedulerBackend( super.minRegisteredRatio } - private val executorWatchResource = new AtomicReference[Closeable] - private val totalExpectedExecutors = new AtomicInteger(0) - - private val driverUrl = RpcEndpointAddress( - conf.get("spark.driver.host"), - conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString - private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) - private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) - - private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) - - private val executorLostReasonCheckMaxAttempts = conf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - private val allocatorRunnable = new Runnable { - - // Maintains a map of executor id to count of checks performed to learn the loss reason - // for an executor. 
- private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] - - override def run(): Unit = { - handleDisconnectedExecutors() - - val executorsToAllocate = mutable.Map[String, Pod]() - val currentTotalRegisteredExecutors = totalRegisteredExecutors.get - val currentTotalExpectedExecutors = totalExpectedExecutors.get - val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts() - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) { - logDebug("Waiting for pending executors before scaling") - } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) { - logDebug("Maximum allowed executor limit reached. Not scaling up further.") - } else { - for (_ <- 0 until math.min( - currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { - val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorConf = KubernetesConf.createExecutorConf( - conf, - executorId, - applicationId(), - driverPod) - val executorPod = executorBuilder.buildFromFeatures(executorConf) - val podWithAttachedContainer = new PodBuilder(executorPod.pod) - .editOrNewSpec() - .addToContainers(executorPod.container) - .endSpec() - .build() - - executorsToAllocate(executorId) = podWithAttachedContainer - logInfo( - s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") - } - } - } - - val allocatedExecutors = executorsToAllocate.mapValues { pod => - Utils.tryLog { - kubernetesClient.pods().create(pod) - } - } - - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - allocatedExecutors.map { - case (executorId, attemptedAllocatedExecutor) => - attemptedAllocatedExecutor.map { successfullyAllocatedExecutor => - runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor) - } - } - } - } - - def handleDisconnectedExecutors(): Unit = { - // For each disconnected executor, synchronize with the loss reasons that may have been found - // by the executor pod watcher. If the loss reason was discovered by the watcher, - // inform the parent class with removeExecutor. - disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach { - case (executorId, executorPod) => - val knownExitReason = Option(podsWithKnownExitReasons.remove( - executorPod.getMetadata.getName)) - knownExitReason.fold { - removeExecutorOrIncrementLossReasonCheckCount(executorId) - } { executorExited => - logWarning(s"Removing executor $executorId with loss reason " + executorExited.message) - removeExecutor(executorId, executorExited) - // We don't delete the pod running the executor that has an exit condition caused by - // the application from the Kubernetes API server. This allows users to debug later on - // through commands such as "kubectl logs " and - // "kubectl describe pod ". Note that exited containers have terminated and - // therefore won't take CPU and memory resources. - // Otherwise, the executor pod is marked to be deleted from the API server. 
- if (executorExited.exitCausedByApp) { - logInfo(s"Executor $executorId exited because of the application.") - deleteExecutorFromDataStructures(executorId) - } else { - logInfo(s"Executor $executorId failed because of a framework error.") - deleteExecutorFromClusterAndDataStructures(executorId) - } - } - } - } - - def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { - val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) - if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) { - removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) - deleteExecutorFromClusterAndDataStructures(executorId) - } else { - executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) - } - } - - def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { - deleteExecutorFromDataStructures(executorId).foreach { pod => - kubernetesClient.pods().delete(pod) - } - } - - def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = { - disconnectedPodsByExecutorIdPendingRemoval.remove(executorId) - executorReasonCheckAttemptCounts -= executorId - podsWithKnownExitReasons.remove(executorId) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.remove(executorId).orElse { - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - } - - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + // Allow removeExecutor to be accessible by ExecutorPodsLifecycleEventHandler + private[k8s] def doRemoveExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + removeExecutor(executorId, reason) } override def start(): Unit = { super.start() - executorWatchResource.set( - kubernetesClient - .pods() - .withLabel(SPARK_APP_ID_LABEL, applicationId()) - .watch(new ExecutorPodsWatcher())) - - allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable, 0L, podAllocationInterval, TimeUnit.MILLISECONDS) - if (!Utils.isDynamicAllocationEnabled(conf)) { - doRequestTotalExecutors(initialExecutors) + podAllocator.setTotalExpectedExecutors(initialExecutors) } + lifecycleEventHandler.start(this) + podAllocator.start(applicationId()) + watchEvents.start(applicationId()) + pollEvents.start(applicationId()) } override def stop(): Unit = { - // stop allocation of new resources and caches. - allocatorExecutor.shutdown() - allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS) - - // send stop message to executors so they shut down cleanly super.stop() - try { - val resource = executorWatchResource.getAndSet(null) - if (resource != null) { - resource.close() - } - } catch { - case e: Throwable => logWarning("Failed to close the executor pod watcher", e) + Utils.tryLogNonFatalError { + snapshotsStore.stop() } - // then delete the executor pods Utils.tryLogNonFatalError { - deleteExecutorPodsOnStop() - executorPodsByIPs.clear() + watchEvents.stop() } + Utils.tryLogNonFatalError { - logInfo("Closing kubernetes client") - kubernetesClient.close() + pollEvents.stop() } - } - /** - * @return A map of K8s cluster nodes to the number of tasks that could benefit from data - * locality if an executor launches on the cluster node. 
- */ - private def getNodesWithLocalTaskCounts() : Map[String, Int] = { - val nodeToLocalTaskCount = synchronized { - mutable.Map[String, Int]() ++ hostToLocalTaskCount + Utils.tryLogNonFatalError { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .delete() } - for (pod <- executorPodsByIPs.values().asScala) { - // Remove cluster nodes that are running our executors already. - // TODO: This prefers spreading out executors across nodes. In case users want - // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut - // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html - nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || - nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || - nodeToLocalTaskCount.remove( - InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + Utils.tryLogNonFatalError { + ThreadUtils.shutdown(requestExecutorsService) + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() } - nodeToLocalTaskCount.toMap[String, Int] } override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { - totalExpectedExecutors.set(requestedTotal) + // TODO when we support dynamic allocation, the pod allocator should be told to process the + // current snapshot in order to decrease/increase the number of executors accordingly. + podAllocator.setTotalExpectedExecutors(requestedTotal) true } - override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { - val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - executorIds.flatMap { executorId => - runningExecutorsToPods.remove(executorId) match { - case Some(pod) => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - Some(pod) - - case None => - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - - kubernetesClient.pods().delete(podsToDelete: _*) - true + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio } - private def deleteExecutorPodsOnStop(): Unit = { - val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*) - runningExecutorsToPods.clear() - runningExecutorPodsCopy - } - kubernetesClient.pods().delete(executorPodsToDelete: _*) + override def getExecutorIds(): Seq[String] = synchronized { + super.getExecutorIds() } - private class ExecutorPodsWatcher extends Watcher[Pod] { - - private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 - - override def eventReceived(action: Action, pod: Pod): Unit = { - val podName = pod.getMetadata.getName - val podIP = pod.getStatus.getPodIP - - action match { - case Action.MODIFIED if (pod.getStatus.getPhase == "Running" - && pod.getMetadata.getDeletionTimestamp == null) => - val clusterNodeName = pod.getSpec.getNodeName - logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.") - executorPodsByIPs.put(podIP, pod) - - case Action.DELETED | Action.ERROR => - val executorId = getExecutorId(pod) - logDebug(s"Executor pod $podName at IP $podIP was at $action.") - if (podIP != null) { - executorPodsByIPs.remove(podIP) - } - - val executorExitReason = if (action == Action.ERROR) { - logWarning(s"Received error event of executor pod $podName. 
Reason: " + - pod.getStatus.getReason) - executorExitReasonOnError(pod) - } else if (action == Action.DELETED) { - logWarning(s"Received delete event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnDelete(pod) - } else { - throw new IllegalStateException( - s"Unknown action that should only be DELETED or ERROR: $action") - } - podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason) - - if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) { - log.warn(s"Executor with id $executorId was not marked as disconnected, but the " + - s"watch received an event of type $action for this executor. The executor may " + - "have failed to start in the first place and never registered with the driver.") - } - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - - case _ => logDebug(s"Received event of executor pod $podName: " + action) - } - } - - override def onClose(cause: KubernetesClientException): Unit = { - logDebug("Executor pod watch closed.", cause) - } - - private def getExecutorExitStatus(pod: Pod): Int = { - val containerStatuses = pod.getStatus.getContainerStatuses - if (!containerStatuses.isEmpty) { - // we assume the first container represents the pod status. This assumption may not hold - // true in the future. Revisit this if side-car containers start running inside executor - // pods. - getExecutorExitStatus(containerStatuses.get(0)) - } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS - } - - private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { - Option(containerStatus.getState).map { containerState => - Option(containerState.getTerminated).map { containerStateTerminated => - containerStateTerminated.getExitCode.intValue() - }.getOrElse(UNKNOWN_EXIT_CODE) - }.getOrElse(UNKNOWN_EXIT_CODE) - } - - private def isPodAlreadyReleased(pod: Pod): Boolean = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - !runningExecutorsToPods.contains(executorId) - } - } - - private def executorExitReasonOnError(pod: Pod): ExecutorExited = { - val containerExitStatus = getExecutorExitStatus(pod) - // container was probably actively killed by the driver. - if (isPodAlreadyReleased(pod)) { - ExecutorExited(containerExitStatus, exitCausedByApp = false, - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " + - "request.") - } else { - val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " + - s"exited with exit status code $containerExitStatus." - ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) - } - } - - private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = { - val exitMessage = if (isPodAlreadyReleased(pod)) { - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." - } else { - s"Pod ${pod.getMetadata.getName} deleted or lost." 
- } - ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) - } - - private def getExecutorId(pod: Pod): String = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - require(executorId != null, "Unexpected pod metadata; expected all executor pods " + - s"to have label $SPARK_EXECUTOR_ID_LABEL.") - executorId - } + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*) + .delete() + // Don't do anything else - let event handling from the Kubernetes API do the Spark changes } override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { new KubernetesDriverEndpoint(rpcEnv, properties) } - private class KubernetesDriverEndpoint( - rpcEnv: RpcEnv, - sparkProperties: Seq[(String, String)]) + private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends DriverEndpoint(rpcEnv, sparkProperties) { override def onDisconnected(rpcAddress: RpcAddress): Unit = { - addressToExecutorId.get(rpcAddress).foreach { executorId => - if (disableExecutor(executorId)) { - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.get(executorId).foreach { pod => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - } - } - } - } + // Don't do anything besides disabling the executor - allow the Kubernetes API events to + // drive the rest of the lifecycle decisions + // TODO what if we disconnect from a networking issue? Probably want to mark the executor + // to be deleted eventually. + addressToExecutorId.get(rpcAddress).foreach(disableExecutor) } } -} -private object KubernetesClusterSchedulerBackend { - private val UNKNOWN_EXIT_CODE = -1 } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala new file mode 100644 index 0000000000000..527fc6b0d8f87 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, HasMetadata, Pod, PodList} +import io.fabric8.kubernetes.client.{Watch, Watcher} +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} + +object Fabric8Aliases { + type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + type LABELED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + type SINGLE_POD = PodResource[Pod, DoneablePod] + type RESOURCE_LIST = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ + HasMetadata, Boolean] +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index a8a8218c621ea..d045d9ae89c07 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -103,15 +104,11 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { .build() } - private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ - HasMetadata, Boolean] - private type Pods = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - @Mock private var kubernetesClient: KubernetesClient = _ @Mock - private var podOperations: Pods = _ + private var podOperations: PODS = _ @Mock private var namedPods: PodResource[Pod, DoneablePod] = _ @@ -123,7 +120,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private var driverBuilder: KubernetesDriverBuilder = _ @Mock - private var resourceList: ResourceList = _ + private var resourceList: RESOURCE_LIST = _ private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..f7721e6fd6388 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import scala.collection.mutable + +class DeterministicExecutorPodsSnapshotsStore extends ExecutorPodsSnapshotsStore { + + private val snapshotsBuffer = mutable.Buffer.empty[ExecutorPodsSnapshot] + private val subscribers = mutable.Buffer.empty[Seq[ExecutorPodsSnapshot] => Unit] + + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + subscribers += onNewSnapshots + } + + override def stop(): Unit = {} + + def notifySubscribers(): Unit = { + subscribers.foreach(_(snapshotsBuffer)) + snapshotsBuffer.clear() + } + + override def updatePod(updatedPod: Pod): Unit = { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + snapshotsBuffer += currentSnapshot + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + snapshotsBuffer += currentSnapshot + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala new file mode 100644 index 0000000000000..c6b667ed85e8c --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, Pod, PodBuilder} + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkPod + +object ExecutorLifecycleTestUtils { + + val TEST_SPARK_APP_ID = "spark-app-id" + + def failedExecutorWithoutDeletion(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("failed") + .addNewContainerStatus() + .withName("spark-executor") + .withImage("k8s-spark") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .addNewContainerStatus() + .withName("spark-executor-sidecar") + .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .withMessage("Executor failed.") + .withReason("Executor failed because of a thrown error.") + .endStatus() + .build() + } + + def pendingExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("pending") + .endStatus() + .build() + } + + def runningExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() + } + + def succeededExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("succeeded") + .endStatus() + .build() + } + + def deletedExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewMetadata() + .withNewDeletionTimestamp("523012521") + .endMetadata() + .build() + } + + def unknownExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("unknown") + .endStatus() + .build() + } + + def podWithAttachedContainerForId(executorId: Long): Pod = { + val sparkPod = executorPodWithId(executorId) + val podWithAttachedContainer = new PodBuilder(sparkPod.pod) + .editOrNewSpec() + .addToContainers(sparkPod.container) + .endSpec() + .build() + podWithAttachedContainer + } + + def executorPodWithId(executorId: Long): SparkPod = { + val pod = new PodBuilder() + .withNewMetadata() + .withName(s"spark-executor-$executorId") + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) + .endMetadata() + .build() + val container = new ContainerBuilder() + .withName("spark-executor") + .withImage("k8s-spark") + .build() + SparkPod(pod, container) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala new file mode 100644 index 0000000000000..0c19f5946b75f --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{never, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ +import org.apache.spark.util.ManualClock + +class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { + + private val driverPodName = "driver" + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .withUid("driver-pod-uid") + .endMetadata() + .build() + + private val conf = new SparkConf().set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000L) + + private var waitForExecutorPodsClock: ManualClock = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var labeledPods: LABELED_PODS = _ + + @Mock + private var driverPodOperations: PodResource[Pod, DoneablePod] = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + + private var podsAllocatorUnderTest: ExecutorPodsAllocator = _ + + before { + MockitoAnnotations.initMocks(this) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) + when(driverPodOperations.get).thenReturn(driverPod) + when(executorBuilder.buildFromFeatures(kubernetesConfWithCorrectFields())) + .thenAnswer(executorPodAnswer()) + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + waitForExecutorPodsClock = new ManualClock(0L) + podsAllocatorUnderTest = new ExecutorPodsAllocator( + conf, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock) + podsAllocatorUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Initially request executors in batches. 
Do not request another batch if the" + + " first has not finished.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (nextId <- 1 to podAllocationSize) { + verify(podOperations).create(podWithAttachedContainerForId(nextId)) + } + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("Request executors in batches. Allow another batch to be requested if" + + " all pending executors start running.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + snapshotsStore.notifySubscribers() + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations, times(podAllocationSize + 1)).create(any(classOf[Pod])) + } + + test("When a current batch reaches error states immediately, re-request" + + " them on the next batch.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + val failedPod = failedExecutorWithoutDeletion(podAllocationSize) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("When an executor is requested but the API does not report it in a reasonable time, retry" + + " requesting that executor.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + waitForExecutorPodsClock.setTime(podCreationTimeout + 1) + when(podOperations.withLabel(SPARK_EXECUTOR_ID_LABEL, "1")).thenReturn(labeledPods) + snapshotsStore.notifySubscribers() + verify(labeledPods).delete() + verify(podOperations).create(podWithAttachedContainerForId(2)) + } + + private def executorPodAnswer(): Answer[SparkPod] = { + new Answer[SparkPod] { + override def answer(invocation: InvocationOnMock): SparkPod = { + val k8sConf = invocation.getArgumentAt( + 0, classOf[KubernetesConf[KubernetesExecutorSpecificConf]]) + executorPodWithId(k8sConf.roleSpecificConf.executorId.toInt) + } + } + } + + private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] = + Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any): Boolean = { + if (!argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]) { + false + } else { + val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + val executorSpecificConf = k8sConf.roleSpecificConf + val expectedK8sConf = KubernetesConf.createExecutorConf( + conf, + executorSpecificConf.executorId, + TEST_SPARK_APP_ID, + driverPod) + k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && + // 
Since KubernetesConf.createExecutorConf clones the SparkConf object, force + // deep equality comparison for the SparkConf object and use object equality + // comparison on all other fields. + k8sConf.copy(sparkConf = conf) == expectedK8sConf.copy(sparkConf = conf) + } + } + }) + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala new file mode 100644 index 0000000000000..562ace9f49d4d --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.CacheBuilder +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfter { + + private var namedExecutorPods: mutable.Map[String, PodResource[Pod, DoneablePod]] = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + @Mock + private var schedulerBackend: KubernetesClusterSchedulerBackend = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + private var eventHandlerUnderTest: ExecutorPodsLifecycleManager = _ + + before { + MockitoAnnotations.initMocks(this) + val removedExecutorsCache = CacheBuilder.newBuilder().build[java.lang.Long, java.lang.Long] + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + namedExecutorPods = mutable.Map.empty[String, PodResource[Pod, DoneablePod]] + when(schedulerBackend.getExecutorIds()).thenReturn(Seq.empty[String]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer()) + eventHandlerUnderTest = new 
ExecutorPodsLifecycleManager( + new SparkConf(), + executorBuilder, + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + eventHandlerUnderTest.start(schedulerBackend) + } + + test("When an executor reaches error states immediately, remove from the scheduler backend.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName)).delete() + } + + test("Don't remove executors twice from Spark but remove from K8s repeatedly.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend, times(1)).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName), times(2)).delete() + } + + test("When the scheduler backend lists executor ids that aren't present in the cluster," + + " remove those executors from Spark.") { + when(schedulerBackend.getExecutorIds()).thenReturn(Seq("1")) + val msg = s"The executor with ID 1 was not found in the cluster but we didn't" + + s" get a reason why. Marking the executor as failed. The executor may have been" + + s" deleted but the driver missed the deletion event." + val expectedLossReason = ExecutorExited(-1, exitCausedByApp = false, msg) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + } + + private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = { + s""" + |The executor with id $failedExecutorId exited with exit code 1. + |The API gave the following brief reason: ${failedPod.getStatus.getReason} + |The API gave the following message: ${failedPod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${failedPod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = { + new Answer[PodResource[Pod, DoneablePod]] { + override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = { + val podName = invocation.getArgumentAt(0, classOf[String]) + namedExecutorPods.getOrElseUpdate( + podName, mock(classOf[PodResource[Pod, DoneablePod]])) + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..1b26d6af296a5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit + +import io.fabric8.kubernetes.api.model.PodListBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsPollingSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + private val sparkConf = new SparkConf + + private val pollingInterval = sparkConf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + private var pollingExecutor: DeterministicScheduler = _ + private var pollingSourceUnderTest: ExecutorPodsPollingSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + pollingExecutor = new DeterministicScheduler() + pollingSourceUnderTest = new ExecutorPodsPollingSnapshotSource( + sparkConf, + kubernetesClient, + eventQueue, + pollingExecutor) + pollingSourceUnderTest.start(TEST_SPARK_APP_ID) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + } + + test("Items returned by the API should be pushed to the event queue") { + when(executorRoleLabeledPods.list()) + .thenReturn(new PodListBuilder() + .addToItems( + runningExecutor(1), + runningExecutor(2)) + .build()) + pollingExecutor.tick(pollingInterval, TimeUnit.MILLISECONDS) + verify(eventQueue).replaceSnapshot(Seq(runningExecutor(1), runningExecutor(2))) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala new file mode 100644 index 0000000000000..70e19c904eddb --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsSnapshotSuite extends SparkFunSuite { + + test("States are interpreted correctly from pod metadata.") { + val pods = Seq( + pendingExecutor(0), + runningExecutor(1), + succeededExecutor(2), + failedExecutorWithoutDeletion(3), + deletedExecutor(4), + unknownExecutor(5)) + val snapshot = ExecutorPodsSnapshot(pods) + assert(snapshot.executorPods === + Map( + 0L -> PodPending(pods(0)), + 1L -> PodRunning(pods(1)), + 2L -> PodSucceeded(pods(2)), + 3L -> PodFailed(pods(3)), + 4L -> PodDeleted(pods(4)), + 5L -> PodUnknown(pods(5)))) + } + + test("Updates add new pods for non-matching ids and edit existing pods for matching ids") { + val originalPods = Seq( + pendingExecutor(0), + runningExecutor(1)) + val originalSnapshot = ExecutorPodsSnapshot(originalPods) + val snapshotWithUpdatedPod = originalSnapshot.withUpdate(succeededExecutor(1)) + assert(snapshotWithUpdatedPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)))) + val snapshotWithNewPod = snapshotWithUpdatedPod.withUpdate(pendingExecutor(2)) + assert(snapshotWithNewPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)), + 2L -> PodPending(pendingExecutor(2)))) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala new file mode 100644 index 0000000000000..cf54b3c4eb329 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference + +import io.fabric8.kubernetes.api.model.{Pod, PodBuilder} +import org.jmock.lib.concurrent.DeterministicScheduler +import org.scalatest.BeforeAndAfter +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ + +class ExecutorPodsSnapshotsStoreSuite extends SparkFunSuite with BeforeAndAfter { + + private var eventBufferScheduler: DeterministicScheduler = _ + private var eventQueueUnderTest: ExecutorPodsSnapshotsStoreImpl = _ + + before { + eventBufferScheduler = new DeterministicScheduler() + eventQueueUnderTest = new ExecutorPodsSnapshotsStoreImpl(eventBufferScheduler) + } + + test("Subscribers get notified of events periodically.") { + val receivedSnapshots1 = mutable.Buffer.empty[ExecutorPodsSnapshot] + val receivedSnapshots2 = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots1 ++= _ + } + eventQueueUnderTest.addSubscriber(2000) { + receivedSnapshots2 ++= _ + } + + eventBufferScheduler.runUntilIdle() + assert(receivedSnapshots1 === Seq(ExecutorPodsSnapshot())) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + pushPodWithIndex(1) + // Force time to move forward so that the buffer is emitted, scheduling the + // processing task on the subscription executor... + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + // ... then actually execute the subscribers. + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + // Don't repeat snapshots + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + pushPodWithIndex(2) + pushPodWithIndex(3) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots1 === receivedSnapshots2) + } + + test("Even without sending events, initially receive an empty buffer.") { + val receivedInitialSnapshot = new AtomicReference[Seq[ExecutorPodsSnapshot]](null) + eventQueueUnderTest.addSubscriber(1000) { + receivedInitialSnapshot.set + } + assert(receivedInitialSnapshot.get == null) + eventBufferScheduler.runUntilIdle() + assert(receivedInitialSnapshot.get === Seq(ExecutorPodsSnapshot())) + } + + test("Replacing the snapshot passes the new snapshot to subscribers.") { + val receivedSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots ++= _ + } + eventQueueUnderTest.updatePod(podWithIndex(1)) + 
eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + eventQueueUnderTest.replaceSnapshot(Seq(podWithIndex(2))) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(2))))) + } + + private def pushPodWithIndex(index: Int): Unit = + eventQueueUnderTest.updatePod(podWithIndex(index)) + + private def podWithIndex(index: Int): Pod = + new PodBuilder() + .editOrNewMetadata() + .withName(s"pod-$index") + .addToLabels(SPARK_EXECUTOR_ID_LABEL, index.toString) + .endMetadata() + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..ac1968b4ff810 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsWatchSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var watchConnection: Watch = _ + + private var watch: ArgumentCaptor[Watcher[Pod]] = _ + + private var watchSourceUnderTest: ExecutorPodsWatchSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + watch = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + when(executorRoleLabeledPods.watch(watch.capture())).thenReturn(watchConnection) + watchSourceUnderTest = new ExecutorPodsWatchSnapshotSource( + eventQueue, kubernetesClient) + watchSourceUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Watch events should be pushed to the snapshots store as snapshot updates.") { + watch.getValue.eventReceived(Action.ADDED, runningExecutor(1)) + watch.getValue.eventReceived(Action.MODIFIED, runningExecutor(2)) + verify(eventQueue).updatePod(runningExecutor(1)) + verify(eventQueue).updatePod(runningExecutor(2)) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 96065e83f069c..52e7a12dbaf06 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -16,85 +16,36 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} -import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.hamcrest.{BaseMatcher, Description, Matcher} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} -import org.mockito.Matchers.{any, eq => mockitoEq} -import org.mockito.Mockito.{doNothing, never, times, verify, when} +import io.fabric8.kubernetes.client.KubernetesClient +import 
org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{eq => mockitoEq} +import org.mockito.Mockito.{never, verify, when} import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ -import scala.collection.JavaConverters._ -import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{ExecutorKilled, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.ThreadUtils +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils.TEST_SPARK_APP_ID class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter { - private val APP_ID = "test-spark-app" - private val DRIVER_POD_NAME = "spark-driver-pod" - private val NAMESPACE = "test-namespace" - private val SPARK_DRIVER_HOST = "localhost" - private val SPARK_DRIVER_PORT = 7077 - private val POD_ALLOCATION_INTERVAL = "1m" - private val FIRST_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod1") - .endMetadata() - .withNewSpec() - .withNodeName("node1") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private val SECOND_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod2") - .endMetadata() - .withNewSpec() - .withNodeName("node2") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.101") - .endStatus() - .build() - - private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - private type LABELED_PODS = FilterWatchListDeletable[ - Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] - private type IN_NAMESPACE_PODS = NonNamespaceOperation[ - Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - - @Mock - private var sparkContext: SparkContext = _ - - @Mock - private var listenerBus: LiveListenerBus = _ - - @Mock - private var taskSchedulerImpl: TaskSchedulerImpl = _ + private val requestExecutorsService = new DeterministicScheduler() + private val sparkConf = new SparkConf(false) + .set("spark.executor.instances", "3") @Mock - private var allocatorExecutor: ScheduledExecutorService = _ + private var sc: SparkContext = _ @Mock - private var requestExecutorsService: ExecutorService = _ + private var rpcEnv: RpcEnv = _ @Mock - private var executorBuilder: KubernetesExecutorBuilder = _ + private var driverEndpointRef: RpcEndpointRef = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -103,347 +54,97 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var podOperations: PODS = _ @Mock - private var podsWithLabelOperations: LABELED_PODS = _ + private var labeledPods: LABELED_PODS = _ @Mock - private var podsInNamespace: IN_NAMESPACE_PODS = _ + private var taskScheduler: 
TaskSchedulerImpl = _ @Mock - private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + private var eventQueue: ExecutorPodsSnapshotsStore = _ @Mock - private var rpcEnv: RpcEnv = _ + private var podAllocator: ExecutorPodsAllocator = _ @Mock - private var driverEndpointRef: RpcEndpointRef = _ + private var lifecycleEventHandler: ExecutorPodsLifecycleManager = _ @Mock - private var executorPodsWatch: Watch = _ + private var watchEvents: ExecutorPodsWatchSnapshotSource = _ @Mock - private var successFuture: Future[Boolean] = _ + private var pollEvents: ExecutorPodsPollingSnapshotSource = _ - private var sparkConf: SparkConf = _ - private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ - private var allocatorRunnable: ArgumentCaptor[Runnable] = _ - private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ - - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(DRIVER_POD_NAME) - .addToLabels(SPARK_APP_ID_LABEL, APP_ID) - .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) - .endMetadata() - .build() + private var schedulerBackendUnderTest: KubernetesClusterSchedulerBackend = _ before { MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) - .set(KUBERNETES_NAMESPACE, NAMESPACE) - .set("spark.driver.host", SPARK_DRIVER_HOST) - .set("spark.driver.port", SPARK_DRIVER_PORT.toString) - .set(KUBERNETES_ALLOCATION_BATCH_DELAY.key, POD_ALLOCATION_INTERVAL) - executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) - allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) - requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + when(taskScheduler.sc).thenReturn(sc) + when(sc.conf).thenReturn(sparkConf) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) - when(sparkContext.conf).thenReturn(sparkConf) - when(sparkContext.listenerBus).thenReturn(listenerBus) - when(taskSchedulerImpl.sc).thenReturn(sparkContext) - when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) - when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) - .thenReturn(executorPodsWatch) - when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) - when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) - when(podsWithDriverName.get()).thenReturn(driverPod) - when(allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable.capture(), - mockitoEq(0L), - mockitoEq(TimeUnit.MINUTES.toMillis(1)), - mockitoEq(TimeUnit.MILLISECONDS))).thenReturn(null) - // Creating Futures in Scala backed by a Java executor service resolves to running - // ExecutorService#execute (as opposed to submit) - doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) when(rpcEnv.setupEndpoint( mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) .thenReturn(driverEndpointRef) - - // Used by the CoarseGrainedSchedulerBackend when making RPC calls. - when(driverEndpointRef.ask[Boolean] - (any(classOf[Any])) - (any())).thenReturn(successFuture) - when(successFuture.failed).thenReturn(Future[Throwable] { - // emulate behavior of the Future.failed method. 
- throw new NoSuchElementException() - }(ThreadUtils.sameThread)) - } - - test("Basic lifecycle expectations when starting and stopping the scheduler.") { - val scheduler = newSchedulerBackend() - scheduler.start() - assert(executorPodsWatcherArgument.getValue != null) - assert(allocatorRunnable.getValue != null) - scheduler.stop() - verify(executorPodsWatch).close() - } - - test("Static allocation should request executors upon first allocator run.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations).create(secondResolvedPod) - } - - test("Killing executors deletes the executor pods") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - scheduler.doKillExecutors(Seq("2")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations).delete(secondResolvedPod) - verify(podOperations, never()).delete(firstResolvedPod) - } - - test("Executors should be requested in batches.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations, never()).create(secondResolvedPod) - val registerFirstExecutorMessage = RegisterExecutor( - "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - allocatorRunnable.getValue.run() - verify(podOperations).create(secondResolvedPod) - } - - test("Scaled down executors should be cleaned up") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - - // The scheduler backend spins up one executor pod. 
- requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - // Request that there are 0 executors and trigger deletion from driver. - scheduler.doRequestTotalExecutors(0) - requestExecutorRunnable.getAllValues.asScala.last.run() - scheduler.doKillExecutors(Seq("1")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations, times(1)).delete(resolvedPod) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - - val exitedPod = exitPod(resolvedPod, 0) - executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) - allocatorRunnable.getValue.run() - - // No more deletion attempts of the executors. - // This is graceful termination and should not be detected as a failure. - verify(podOperations, times(1)).delete(resolvedPod) - verify(driverEndpointRef, times(1)).send( - RemoveExecutor("1", ExecutorExited( - 0, - exitCausedByApp = false, - s"Container in pod ${exitedPod.getMetadata.getName} exited from" + - s" explicit termination request."))) - } - - test("Executors that fail should not be deleted.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - executorPodsWatcherArgument.getValue.eventReceived( - Action.ERROR, exitPod(firstResolvedPod, 1)) - - // A replacement executor should be created but the error pod should persist. 
- val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - scheduler.doRequestTotalExecutors(1) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getAllValues.asScala.last.run() - verify(podOperations, never()).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", ExecutorExited( - 1, - exitCausedByApp = true, - s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + - " exit status code 1."))) - } - - test("Executors disconnected due to unknown reasons are deleted and replaced.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val executorLostReasonCheckMaxAttempts = sparkConf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - 1 to executorLostReasonCheckMaxAttempts foreach { _ => - allocatorRunnable.getValue.run() - verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + when(kubernetesClient.pods()).thenReturn(podOperations) + schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend( + taskScheduler, + rpcEnv, + kubernetesClient, + requestExecutorsService, + eventQueue, + podAllocator, + lifecycleEventHandler, + watchEvents, + pollEvents) { + override def applicationId(): String = TEST_SPARK_APP_ID } - - val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) } - test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(firstResolvedPod)) - .thenThrow(new RuntimeException("test")) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Start all components") { + schedulerBackendUnderTest.start() + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator).start(TEST_SPARK_APP_ID) + verify(lifecycleEventHandler).start(schedulerBackendUnderTest) + verify(watchEvents).start(TEST_SPARK_APP_ID) + verify(pollEvents).start(TEST_SPARK_APP_ID) } - test("Executors that are initially created but the watch notices them fail are rebuilt" + - " in the 
next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Stop all components") { + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + schedulerBackendUnderTest.stop() + verify(eventQueue).stop() + verify(watchEvents).stop() + verify(pollEvents).stop() + verify(labeledPods).delete() + verify(kubernetesClient).close() } - private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { - new KubernetesClusterSchedulerBackend( - taskSchedulerImpl, - rpcEnv, - executorBuilder, - kubernetesClient, - allocatorExecutor, - requestExecutorsService) { - - override def applicationId(): String = APP_ID - } + test("Remove executor") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRemoveExecutor( + "1", ExecutorKilled) + verify(driverEndpointRef).send(RemoveExecutor("1", ExecutorKilled)) } - private def exitPod(basePod: Pod, exitCode: Int): Pod = { - new PodBuilder(basePod) - .editStatus() - .addNewContainerStatus() - .withNewState() - .withNewTerminated() - .withExitCode(exitCode) - .endTerminated() - .endState() - .endContainerStatus() - .endStatus() - .build() + test("Kill executors") { + schedulerBackendUnderTest.start() + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + when(labeledPods.withLabelIn(SPARK_EXECUTOR_ID_LABEL, "1", "2")).thenReturn(labeledPods) + schedulerBackendUnderTest.doKillExecutors(Seq("1", "2")) + verify(labeledPods, never()).delete() + requestExecutorsService.runNextPendingCommand() + verify(labeledPods).delete() } - private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = { - val resolvedPod = new PodBuilder(expectedPod) - .editMetadata() - .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) - .endMetadata() - .build() - val resolvedContainer = new ContainerBuilder().build() - when(executorBuilder.buildFromFeatures(Matchers.argThat( - new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { - override def matches(argument: scala.Any) - : Boolean = { - argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && - argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] - .roleSpecificConf.executorId == executorId.toString - } - - override def describeTo(description: Description): Unit = {} - }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) - new PodBuilder(resolvedPod) - .editSpec() - .addToContainers(resolvedContainer) - .endSpec() - .build() + test("Request total executors") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRequestTotalExecutors(5) + 
verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator, never()).setTotalExpectedExecutors(5) + requestExecutorsService.runNextPendingCommand() + verify(podAllocator).setTotalExpectedExecutors(5) } + } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index acdb4b1f09e0a..2f4e115e84ecd 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -37,11 +37,17 @@ if [ -z "$uidentry" ] ; then fi SPARK_K8S_CMD="$1" -if [ -z "$SPARK_K8S_CMD" ]; then - echo "No command to execute has been provided." 1>&2 - exit 1 -fi -shift 1 +case "$SPARK_K8S_CMD" in + driver | driver-py | executor) + shift 1 + ;; + "") + ;; + *) + echo "Non-spark-on-k8s command provided, proceeding in pass-through mode..." + exec /sbin/tini -s -- "$@" + ;; +esac SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt @@ -92,7 +98,6 @@ case "$SPARK_K8S_CMD" in "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS ) ;; - executor) CMD=( ${JAVA_HOME}/bin/java diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 604978967d6db..15bbe60d6c8fb 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -40,7 +40,7 @@ private[spark] class MesosClusterUI( override def initialize() { attachPage(new MesosClusterPage(this)) attachPage(new DriverPage(this)) - attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR) } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f4bd1ee9da6f7..b790c7cd27794 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -789,6 +789,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.nodeBlacklist).thenReturn(Set[String]()) when(taskScheduler.sc).thenReturn(sc) externalShuffleClient = mock[MesosExternalShuffleClient] diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3d6ee50b070a3..ecc576910db9e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -515,6 +515,10 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, s"Max number of 
executor failures ($maxNumExecutorFailures) reached") + } else if (allocator.isAllNodeBlacklisted) { + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, + "Due to executor failures all available nodes are blacklisted") } else { logDebug("Sending progress") allocator.allocateResources() diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ebee3d431744d..fae054e0eea00 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -24,7 +24,7 @@ import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.control.NonFatal import org.apache.hadoop.yarn.api.records._ @@ -66,7 +66,8 @@ private[yarn] class YarnAllocator( appAttemptId: ApplicationAttemptId, securityMgr: SecurityManager, localResources: Map[String, LocalResource], - resolver: SparkRackResolver) + resolver: SparkRackResolver, + clock: Clock = new SystemClock) extends Logging { import YarnAllocator._ @@ -102,18 +103,14 @@ private[yarn] class YarnAllocator( private var executorIdCounter: Int = driverRef.askSync[Int](RetrieveLastAllocatedExecutorId) - // Queue to store the timestamp of failed executors - private val failedExecutorsTimeStamps = new Queue[Long]() + private[spark] val failureTracker = new FailureTracker(sparkConf, clock) - private var clock: Clock = new SystemClock - - private val executorFailuresValidityInterval = - sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + private val allocatorBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClient, failureTracker) @volatile private var targetNumExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf) - private var currentNodeBlacklist = Set.empty[String] // Executor loss reason requests that are pending - maps from executor ID for inquiry to a // list of requesters that should be responded to once we find out why the given executor @@ -149,7 +146,6 @@ private[yarn] class YarnAllocator( private val labelExpression = sparkConf.get(EXECUTOR_NODE_LABEL_EXPRESSION) - // A map to store preferred hostname and possible task numbers running on it. private var hostToLocalTaskCounts: Map[String, Int] = Map.empty @@ -160,26 +156,11 @@ private[yarn] class YarnAllocator( private[yarn] val containerPlacementStrategy = new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource, resolver) - /** - * Use a different clock for YarnAllocator. This is mainly used for testing. 
- */ - def setClock(newClock: Clock): Unit = { - clock = newClock - } - def getNumExecutorsRunning: Int = runningExecutors.size() - def getNumExecutorsFailed: Int = synchronized { - val endTime = clock.getTimeMillis() + def getNumExecutorsFailed: Int = failureTracker.numFailedExecutors - while (executorFailuresValidityInterval > 0 - && failedExecutorsTimeStamps.nonEmpty - && failedExecutorsTimeStamps.head < endTime - executorFailuresValidityInterval) { - failedExecutorsTimeStamps.dequeue() - } - - failedExecutorsTimeStamps.size - } + def isAllNodeBlacklisted: Boolean = allocatorBlacklistTracker.isAllNodeBlacklisted /** * A sequence of pending container requests that have not yet been fulfilled. @@ -204,9 +185,8 @@ private[yarn] class YarnAllocator( * @param localityAwareTasks number of locality aware tasks to be used as container placement hint * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as * container placement hint. - * @param nodeBlacklist a set of blacklisted nodes, which is passed in to avoid allocating new - * containers on them. It will be used to update the application master's - * blacklist. + * @param nodeBlacklist blacklisted nodes, which is passed in to avoid allocating new containers + * on them. It will be used to update the application master's blacklist. * @return Whether the new requested total is different than the old value. */ def requestTotalExecutorsWithPreferredLocalities( @@ -220,19 +200,7 @@ private[yarn] class YarnAllocator( if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal - - // Update blacklist infomation to YARN ResouceManager for this application, - // in order to avoid allocating new Containers on the problematic nodes. - val blacklistAdditions = nodeBlacklist -- currentNodeBlacklist - val blacklistRemovals = currentNodeBlacklist -- nodeBlacklist - if (blacklistAdditions.nonEmpty) { - logInfo(s"adding nodes to YARN application master's blacklist: $blacklistAdditions") - } - if (blacklistRemovals.nonEmpty) { - logInfo(s"removing nodes from YARN application master's blacklist: $blacklistRemovals") - } - amClient.updateBlacklist(blacklistAdditions.toList.asJava, blacklistRemovals.toList.asJava) - currentNodeBlacklist = nodeBlacklist + allocatorBlacklistTracker.setSchedulerBlacklistedNodes(nodeBlacklist) true } else { false @@ -268,6 +236,7 @@ private[yarn] class YarnAllocator( val allocateResponse = amClient.allocate(progressIndicator) val allocatedContainers = allocateResponse.getAllocatedContainers() + allocatorBlacklistTracker.setNumClusterNodes(allocateResponse.getNumClusterNodes) if (allocatedContainers.size > 0) { logDebug(("Allocated containers: %d. Current executor count: %d. " + @@ -602,8 +571,9 @@ private[yarn] class YarnAllocator( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) case _ => - // Enqueue the timestamp of failed executor - failedExecutorsTimeStamps.enqueue(clock.getTimeMillis()) + // all the failures which not covered above, like: + // disk failure, kill by app master or resource manager, ... + allocatorBlacklistTracker.handleResourceAllocationFailure(hostOpt) (true, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + ". 
Diagnostics: " + completedContainer.getDiagnostics) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala new file mode 100644 index 0000000000000..1b48a0ee7ad32 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.HashMap + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.scheduler.BlacklistTracker +import org.apache.spark.util.{Clock, SystemClock, Utils} + +/** + * YarnAllocatorBlacklistTracker is responsible for tracking the blacklisted nodes + * and synchronizing the node list to YARN. + * + * Blacklisted nodes are coming from two different sources: + * + *
+ * <ul>
+ *   <li> from the scheduler as task level blacklisted nodes
+ *   <li> from this class (tracked here) as YARN resource allocation problems
+ * </ul>
+ * + * The reason to realize this logic here (and not in the driver) is to avoid possible delays + * between synchronizing the blacklisted nodes with YARN and resource allocations. + */ +private[spark] class YarnAllocatorBlacklistTracker( + sparkConf: SparkConf, + amClient: AMRMClient[ContainerRequest], + failureTracker: FailureTracker) + extends Logging { + + private val blacklistTimeoutMillis = BlacklistTracker.getBlacklistTimeout(sparkConf) + + private val launchBlacklistEnabled = sparkConf.get(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED) + + private val maxFailuresPerHost = sparkConf.get(MAX_FAILED_EXEC_PER_NODE) + + private val allocatorBlacklist = new HashMap[String, Long]() + + private var currentBlacklistedYarnNodes = Set.empty[String] + + private var schedulerBlacklist = Set.empty[String] + + private var numClusterNodes = Int.MaxValue + + def setNumClusterNodes(numClusterNodes: Int): Unit = { + this.numClusterNodes = numClusterNodes + } + + def handleResourceAllocationFailure(hostOpt: Option[String]): Unit = { + hostOpt match { + case Some(hostname) if launchBlacklistEnabled => + // failures on an already blacklisted nodes are not even tracked. + // otherwise, such failures could shutdown the application + // as resource requests are asynchronous + // and a late failure response could exceed MAX_EXECUTOR_FAILURES + if (!schedulerBlacklist.contains(hostname) && + !allocatorBlacklist.contains(hostname)) { + failureTracker.registerFailureOnHost(hostname) + updateAllocationBlacklistedNodes(hostname) + } + case _ => + failureTracker.registerExecutorFailure() + } + } + + private def updateAllocationBlacklistedNodes(hostname: String): Unit = { + val failuresOnHost = failureTracker.numFailuresOnHost(hostname) + if (failuresOnHost > maxFailuresPerHost) { + logInfo(s"blacklisting $hostname as YARN allocation failed $failuresOnHost times") + allocatorBlacklist.put( + hostname, + failureTracker.clock.getTimeMillis() + blacklistTimeoutMillis) + refreshBlacklistedNodes() + } + } + + def setSchedulerBlacklistedNodes(schedulerBlacklistedNodesWithExpiry: Set[String]): Unit = { + this.schedulerBlacklist = schedulerBlacklistedNodesWithExpiry + refreshBlacklistedNodes() + } + + def isAllNodeBlacklisted: Boolean = currentBlacklistedYarnNodes.size >= numClusterNodes + + private def refreshBlacklistedNodes(): Unit = { + removeExpiredYarnBlacklistedNodes() + val allBlacklistedNodes = schedulerBlacklist ++ allocatorBlacklist.keySet + synchronizeBlacklistedNodeWithYarn(allBlacklistedNodes) + } + + private def synchronizeBlacklistedNodeWithYarn(nodesToBlacklist: Set[String]): Unit = { + // Update blacklist information to YARN ResourceManager for this application, + // in order to avoid allocating new Containers on the problematic nodes. 
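    // Illustrative walk-through, not part of the patch: if currentBlacklistedYarnNodes is
    // Set("host1", "host2") and the freshly combined nodesToBlacklist is Set("host2", "host3"),
    // then the code below computes additions = List("host3") and removals = List("host1"),
    // so only the delta is reported to the ResourceManager via amClient.updateBlacklist.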
+ val additions = (nodesToBlacklist -- currentBlacklistedYarnNodes).toList.sorted + val removals = (currentBlacklistedYarnNodes -- nodesToBlacklist).toList.sorted + if (additions.nonEmpty) { + logInfo(s"adding nodes to YARN application master's blacklist: $additions") + } + if (removals.nonEmpty) { + logInfo(s"removing nodes from YARN application master's blacklist: $removals") + } + amClient.updateBlacklist(additions.asJava, removals.asJava) + currentBlacklistedYarnNodes = nodesToBlacklist + } + + private def removeExpiredYarnBlacklistedNodes(): Unit = { + val now = failureTracker.clock.getTimeMillis() + allocatorBlacklist.retain { (_, expiryTime) => expiryTime > now } + } +} + +/** + * FailureTracker is responsible for tracking executor failures both for each host separately + * and for all hosts altogether. + */ +private[spark] class FailureTracker( + sparkConf: SparkConf, + val clock: Clock = new SystemClock) extends Logging { + + private val executorFailuresValidityInterval = + sparkConf.get(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + + // Queue to store the timestamp of failed executors for each host + private val failedExecutorsTimeStampsPerHost = mutable.Map[String, mutable.Queue[Long]]() + + private val failedExecutorsTimeStamps = new mutable.Queue[Long]() + + private def updateAndCountFailures(failedExecutorsWithTimeStamps: mutable.Queue[Long]): Int = { + val endTime = clock.getTimeMillis() + while (executorFailuresValidityInterval > 0 && + failedExecutorsWithTimeStamps.nonEmpty && + failedExecutorsWithTimeStamps.head < endTime - executorFailuresValidityInterval) { + failedExecutorsWithTimeStamps.dequeue() + } + failedExecutorsWithTimeStamps.size + } + + def numFailedExecutors: Int = synchronized { + updateAndCountFailures(failedExecutorsTimeStamps) + } + + def registerFailureOnHost(hostname: String): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + val failedExecutorsOnHost = + failedExecutorsTimeStampsPerHost.getOrElse(hostname, { + val failureOnHost = mutable.Queue[Long]() + failedExecutorsTimeStampsPerHost.put(hostname, failureOnHost) + failureOnHost + }) + failedExecutorsOnHost.enqueue(timeMillis) + } + + def registerExecutorFailure(): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + } + + def numFailuresOnHost(hostname: String): Int = { + failedExecutorsTimeStampsPerHost.get(hostname).map { failedExecutorsOnHost => + updateAndCountFailures(failedExecutorsOnHost) + }.getOrElse(0) + } + +} + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 1a99b3bd57672..129084a86597a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -328,4 +328,10 @@ package object config { CACHED_FILES_TYPES, CACHED_CONF_ARCHIVE) + /* YARN allocator-level blacklisting related config entries. 
*/ + private[spark] val YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED = + ConfigBuilder("spark.yarn.blacklist.executor.launch.blacklisting.enabled") + .booleanConf + .createWithDefault(false) + } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala new file mode 100644 index 0000000000000..4f77b9c99dd25 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.ManualClock + +class FailureTrackerSuite extends SparkFunSuite with Matchers { + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("failures expire if validity interval is set") { + val sparkConf = new SparkConf() + sparkConf.set(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS, 100L) + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(101) + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(231) + failureTracker.numFailuresOnHost("host1") should be (0) + failureTracker.numFailuresOnHost("host2") should be (0) + failureTracker.numFailedExecutors should be (0) + } + + + test("failures never expire if validity interval is not set (-1)") { + val sparkConf = new SparkConf() + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + 
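    // Illustrative note, not part of the patch: with no validity interval configured,
    // EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS falls back to -1, so
    // FailureTracker.updateAndCountFailures never dequeues old timestamps and the
    // counts asserted below never shrink, no matter how far the clock advances.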
clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(1000) + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + } + +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala new file mode 100644 index 0000000000000..aeac68e6ed330 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
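A note before the allocator blacklist tracker suite that follows: the launch blacklisting added in this patch is off by default and must be enabled explicitly. A small sketch of the configuration wiring, assuming only the key names shown in this diff and the existing failure-validity setting; the concrete values are arbitrary examples, not recommendations.

    import org.apache.spark.SparkConf

    object BlacklistingConfSketch {
      // Hypothetical wiring sketch; only the key names come from this patch / existing Spark configs.
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          // New in this patch, default false: blacklist nodes on repeated executor launch failures.
          .set("spark.yarn.blacklist.executor.launch.blacklisting.enabled", "true")
          // Existing setting: failures older than this window stop counting (see FailureTracker above).
          .set("spark.yarn.executor.failuresValidityInterval", "1h")
        conf.getAll.foreach { case (k, v) => println(s"$k=$v") }
      }
    }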
+ */ +package org.apache.spark.deploy.yarn + +import java.util.Arrays +import java.util.Collections + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.yarn.config.YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED +import org.apache.spark.internal.config.{BLACKLIST_TIMEOUT_CONF, MAX_FAILED_EXEC_PER_NODE} +import org.apache.spark.util.ManualClock + +class YarnAllocatorBlacklistTrackerSuite extends SparkFunSuite with Matchers + with BeforeAndAfterEach { + + val BLACKLIST_TIMEOUT = 100L + val MAX_FAILED_EXEC_PER_NODE_VALUE = 2 + + var amClientMock: AMRMClient[ContainerRequest] = _ + var yarnBlacklistTracker: YarnAllocatorBlacklistTracker = _ + var failureTracker: FailureTracker = _ + var clock: ManualClock = _ + + override def beforeEach(): Unit = { + val sparkConf = new SparkConf() + sparkConf.set(BLACKLIST_TIMEOUT_CONF, BLACKLIST_TIMEOUT) + sparkConf.set(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED, true) + sparkConf.set(MAX_FAILED_EXEC_PER_NODE, MAX_FAILED_EXEC_PER_NODE_VALUE) + clock = new ManualClock() + + amClientMock = mock(classOf[AMRMClient[ContainerRequest]]) + failureTracker = new FailureTracker(sparkConf, clock) + yarnBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClientMock, failureTracker) + yarnBlacklistTracker.setNumClusterNodes(4) + super.beforeEach() + } + + test("expiring its own blacklisted nodes") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // host should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + } + } + + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // the third failure on the host triggers the blacklisting + verify(amClientMock).updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + + clock.advance(BLACKLIST_TIMEOUT) + + // trigger synchronisation of blacklisted nodes with YARN + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set()) + verify(amClientMock).updateBlacklist(Collections.emptyList(), Arrays.asList("host")) + } + + test("not handling the expiry of scheduler blacklisted nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2"), Collections.emptyList()) + + // advance timer more then host1, host2 expiry time + clock.advance(200L) + + // expired blacklisted nodes (simulating a resource request) + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + // no change is communicated to YARN regarding the blacklisting + verify(amClientMock).updateBlacklist(Collections.emptyList(), Collections.emptyList()) + } + + test("combining scheduler and allocation blacklist") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + // host1 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + } + } + + // as this is the third failure on host1 the node will be blacklisted + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + verify(amClientMock) + 
.updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host2", "host3"), Collections.emptyList()) + + clock.advance(10L) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host3", "host4")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host4"), Arrays.asList("host2")) + } + + test("blacklist all available nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2", "host3"), Collections.emptyList()) + + clock.advance(60L) + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + // host4 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + } + } + + // the third failure on the host triggers the blacklisting + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + + verify(amClientMock).updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + assert(yarnBlacklistTracker.isAllNodeBlacklisted === true) + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 525abb6f2b350..3f783baed110d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -59,6 +59,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter var rmClient: AMRMClient[ContainerRequest] = _ + var clock: ManualClock = _ + var containerNum = 0 override def beforeEach() { @@ -66,6 +68,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter rmClient = AMRMClient.createAMRMClient() rmClient.init(conf) rmClient.start() + clock = new ManualClock() } override def afterEach() { @@ -101,7 +104,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter appAttemptId, new SecurityManager(sparkConf), Map(), - new MockResolver()) + new MockResolver(), + clock) } def createContainer(host: String): Container = { @@ -332,10 +336,14 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map(), Set("hostA")) verify(mockAmClient).updateBlacklist(Seq("hostA").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), Set("hostA", "hostB")) + val blacklistedNodes = Set( + "hostA", + "hostB" + ) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), blacklistedNodes) verify(mockAmClient).updateBlacklist(Seq("hostB").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set()) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set.empty) verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava) } @@ -353,8 +361,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter test("window based failure executor counting") { sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s") val handler = createAllocator(4) - val clock = 
new ManualClock(0L) - handler.setClock(clock) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d5d934bc91cab..4dd2b7365652a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -83,7 +83,7 @@ public static long calculateSizeOfUnderlyingByteArray(long numFields, int elemen private long elementOffset; private long getElementOffset(int ordinal, int elementSize) { - return elementOffset + ordinal * elementSize; + return elementOffset + ordinal * (long)elementSize; } public Object getBaseObject() { return baseObject; } @@ -414,7 +414,7 @@ public byte[] toByteArray() { public short[] toShortArray() { short[] values = new short[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2L); return values; } @@ -422,7 +422,7 @@ public short[] toShortArray() { public int[] toIntArray() { int[] values = new int[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -430,7 +430,7 @@ public int[] toIntArray() { public long[] toLongArray() { long[] values = new long[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8L); return values; } @@ -438,7 +438,7 @@ public long[] toLongArray() { public float[] toFloatArray() { float[] values = new float[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -446,14 +446,14 @@ public float[] toFloatArray() { public double[] toDoubleArray() { double[] values = new double[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8L); return values; } private static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); - final long valueRegionInBytes = elementSize * length; + final long valueRegionInBytes = (long)elementSize * length; final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; if (totalSizeInLongs > Integer.MAX_VALUE / 8) { throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java index 905e6820ce6e2..c823de4810f2b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java 
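The UnsafeArrayData changes above, and the VariableLengthRowBasedKeyValueBatch hunk that follows, widen size and offset arithmetic from int to long: expressions such as numElements * 8 are otherwise evaluated in 32-bit arithmetic and silently wrap around for very large arrays. A minimal, self-contained sketch of that failure mode (the element count below is a hypothetical value, not taken from the patch):

    object IntOverflowSketch {
      def main(args: Array[String]): Unit = {
        val numElements = 300000000             // 300 million 8-byte elements (hypothetical)
        val intBytes: Int   = numElements * 8   // 32-bit multiply wraps around to a negative size
        val longBytes: Long = numElements * 8L  // widening one operand keeps the maths in 64 bits
        println(s"int arithmetic:  $intBytes")  // prints -1894967296
        println(s"long arithmetic: $longBytes") // prints 2400000000
      }
    }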
@@ -41,7 +41,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB @Override public UnsafeRow appendRow(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) { - final long recordLength = 8 + klen + vlen + 8; + final long recordLength = 8L + klen + vlen + 8; // if run out of max supported rows or page size, return null if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) { return null; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java index d224332d8a6c9..023ec139652c5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java @@ -21,6 +21,9 @@ import java.io.Reader; import javax.xml.namespace.QName; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPath; import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathExpression; @@ -37,9 +40,15 @@ * This is based on Hive's UDFXPathUtil implementation. */ public class UDFXPathUtil { + public static final String SAX_FEATURE_PREFIX = "http://xml.org/sax/features/"; + public static final String EXTERNAL_GENERAL_ENTITIES_FEATURE = "external-general-entities"; + public static final String EXTERNAL_PARAMETER_ENTITIES_FEATURE = "external-parameter-entities"; + private DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); + private DocumentBuilder builder = null; private XPath xpath = XPathFactory.newInstance().newXPath(); private ReusableStringReader reader = new ReusableStringReader(); private InputSource inputSource = new InputSource(reader); + private XPathExpression expression = null; private String oldPath = null; @@ -65,14 +74,31 @@ public Object eval(String xml, String path, QName qname) throws XPathExpressionE return null; } + if (builder == null){ + try { + initializeDocumentBuilderFactory(); + builder = dbf.newDocumentBuilder(); + } catch (ParserConfigurationException e) { + throw new RuntimeException( + "Error instantiating DocumentBuilder, cannot build xml parser", e); + } + } + reader.set(xml); try { - return expression.evaluate(inputSource, qname); + return expression.evaluate(builder.parse(inputSource), qname); } catch (XPathExpressionException e) { throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e); + } catch (Exception e) { + throw new RuntimeException("Error loading expression '" + oldPath + "'", e); } } + private void initializeDocumentBuilderFactory() throws ParserConfigurationException { + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_GENERAL_ENTITIES_FEATURE, false); + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_PARAMETER_ENTITIES_FEATURE, false); + } + public Boolean evalBoolean(String xml, String path) throws XPathExpressionException { return (Boolean) eval(xml, path, XPathConstants.BOOLEAN); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 0b95a8821b05a..b47ec0b72c638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -132,7 +132,7 @@ object Encoders { * - primitive types: boolean, int, 
double, etc. * - boxed types: Boolean, Integer, Double, etc. * - String - * - java.math.BigDecimal + * - java.math.BigDecimal, java.math.BigInteger * - time related: java.sql.Date, java.sql.Timestamp * - collection types: only array and java.util.List currently, map support is in progress * - nested java bean. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9e9105a157abe..93df73ab1eaf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -286,6 +286,7 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 + case chr: Char => UTF8String.fromString(chr.toString) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to the string type") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f9947d1fa6c78..e187133d03b17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1739,9 +1739,10 @@ class Analyzer( * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions * it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for * all regular expressions. - * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s. - * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts - * it into the plan tree. + * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s + * and [[WindowFunctionType]]s. + * 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a + * [[Window]] operator and inserts it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { private def hasWindowFunction(exprs: Seq[Expression]): Boolean = @@ -1901,7 +1902,7 @@ class Analyzer( s"Please file a bug report with this error message, stack trace, and the query.") } else { val spec = distinctWindowSpec.head - (spec.partitionSpec, spec.orderSpec) + (spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr)) } }.toSeq @@ -1909,7 +1910,7 @@ class Analyzer( // setting this to the child of the next Window operator. val windowOps = groupedWindowExpressions.foldLeft(child) { - case (last, ((partitionSpec, orderSpec), windowExpressions)) => + case (last, ((partitionSpec, orderSpec, _), windowExpressions)) => Window(windowExpressions, partitionSpec, orderSpec, last) } @@ -1922,6 +1923,9 @@ class Analyzer( // "Aggregate with Having clause" will be triggered. def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case Filter(condition, _) if hasWindowFunction(condition) => + failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses") + // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. 
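The Filter case added above makes the analyzer reject window functions inside WHERE and HAVING with an explicit error instead of letting them reach later phases. A hedged end-to-end illustration, assuming a local SparkSession is available; only the failing query shape and the quoted error behaviour come from this rule, the rest is illustrative scaffolding:

    import org.apache.spark.sql.{AnalysisException, SparkSession}

    object WindowInWhereSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("window-in-where").getOrCreate()
        spark.range(10).createOrReplaceTempView("t")

        // Window functions are still fine in the SELECT list ...
        spark.sql("SELECT id, row_number() OVER (ORDER BY id) AS rn FROM t").show()

        // ... but referencing one inside WHERE now fails analysis with the message quoted above.
        try {
          spark.sql("SELECT id FROM t WHERE row_number() OVER (ORDER BY id) = 1").collect()
        } catch {
          case e: AnalysisException => println(e.getMessage)
        }
        spark.stop()
      }
    }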
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 90bda2a72ad82..af256b98b34f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ @@ -112,12 +113,19 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") + case _ @ WindowExpression(_: PythonUDF, + WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame)) + if !frame.isUnbounded => + failAnalysis("Only unbounded window frame is supported with Pandas UDFs.") + case w @ WindowExpression(e, s) => // Only allow window functions with an aggregate expression or an offset window - // function. + // function or a Pandas window UDF. e match { case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => w + case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) => + w case _ => failAnalysis(s"Expression '$e' not supported within a window function.") } @@ -154,7 +162,7 @@ trait CheckAnalysis extends PredicateHelper { case Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression) = { - expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) } def checkValidAggregateExpression(expr: Expression): Unit = expr match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 49fb35b083580..8abc616c1a3f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -417,12 +417,15 @@ object FunctionRegistry { expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), + expression[MapFromArrays]("map_from_arrays"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), + expression[MapFromEntries]("map_from_entries"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), + expression[ArraysZip]("arrays_zip"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), @@ -431,6 +434,7 @@ object FunctionRegistry { expression[Flatten]("flatten"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayDistinct]("array_distinct"), CreateStruct.registryEntry, // mask functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b2817b0538a7f..637923928a7da 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -536,6 +536,14 @@ object TypeCoercion { case None => c } + case aj @ ArrayJoin(arr, d, nr) if !ArrayType(StringType).acceptsType(arr.dataType) && + ArrayType.acceptsType(arr.dataType) => + val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull + ImplicitTypeCasts.implicitCast(arr, ArrayType(StringType, containsNull)) match { + case Some(castedArr) => ArrayJoin(castedArr, d, nr) + case None => aj + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index efc2882f0a3d3..cbea3c017a265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -128,7 +128,7 @@ object ExpressionEncoder { case b: BoundReference if b == originalInputObject => newInputObject }) - if (enc.flat) { + val serializerExpr = if (enc.flat) { newSerializer.head } else { // For non-flat encoder, the input object is not top level anymore after being combined to @@ -146,6 +146,7 @@ object ExpressionEncoder { Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) If(nullCheck, Literal.create(null, struct.dataType), struct) } + Alias(serializerExpr, s"_${index + 1}")() } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index efd664dde725a..6530b176968f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -34,10 +34,14 @@ object PythonUDF { e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) } - def isGroupAggPandasUDF(e: Expression): Boolean = { + def isGroupedAggPandasUDF(e: Expression): Boolean = { e.isInstanceOf[PythonUDF] && e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF } + + // This is currently same as GroupedAggPandasUDF, but we might support new types in the future, + // e.g, N -> N transform. + def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 66315e5906253..4cc0968911cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -819,6 +819,36 @@ class CodegenContext { } } + /** + * Generates code to do null safe execution when accessing properties of complex + * ArrayData elements. + * + * @param nullElements used to decide whether the ArrayData might contain null or not. 
+ * @param isNull a variable indicating whether the result will be evaluated to null or not. + * @param arrayData a variable name representing the ArrayData. + * @param execute the code that should be executed only if the ArrayData doesn't contain + * any null. + */ + def nullArrayElementsSaveExec( + nullElements: Boolean, + isNull: String, + arrayData: String)( + execute: String): String = { + val i = freshName("idx") + if (nullElements) { + s""" + |for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) { + | $isNull |= $arrayData.isNullAt($i); + |} + |if (!$isNull) { + | $execute + |} + """.stripMargin + } else { + execute + } + } + /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 176995affe701..58612f65c1a53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -25,12 +25,13 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.collection.OpenHashSet /** * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit @@ -128,6 +129,172 @@ case class MapKeys(child: Expression) override def prettyName: String = "map_keys" } +@ExpressionDescription( + usage = """ + _FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. 
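The helper above only emits Java source, so as a plain-Scala illustration of the guard it generates (not Spark API, just the same control flow): when the array may contain nulls, scan for a null first and skip the core logic if one is found.

// Illustrative analogue only; the function name and types are hypothetical.
def nullSafeOverArray[T](elements: Seq[Any], mayContainNull: Boolean)(core: => T): Option[T] =
  if (mayContainNull && elements.contains(null)) None // whole result becomes null; core never runs
  else Some(core)

nullSafeOverArray(Seq(1, 2, 3), mayContainNull = true)(42) // Some(42)
nullSafeOverArray(Seq(1, null), mayContainNull = true)(42) // None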
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4)); + [[1, 2], [2, 3], [3, 4]] + > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4)); + [[1, 2, 3], [2, 3, 4]] + """, + since = "2.4.0") +case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) + + override def dataType: DataType = ArrayType(mountSchema) + + override def nullable: Boolean = children.exists(_.nullable) + + private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) + + private lazy val arrayElementTypes = arrayTypes.map(_.elementType) + + @transient private lazy val mountSchema: StructType = { + val fields = children.zip(arrayElementTypes).zipWithIndex.map { + case ((expr: NamedExpression, elementType), _) => + StructField(expr.name, elementType, nullable = true) + case ((_, elementType), idx) => + StructField(idx.toString, elementType, nullable = true) + } + StructType(fields) + } + + @transient lazy val numberOfArrays: Int = children.length + + @transient lazy val genericArrayData = classOf[GenericArrayData].getName + + def emptyInputGenCode(ev: ExprCode): ExprCode = { + ev.copy(code""" + |${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]); + |boolean ${ev.isNull} = false; + """.stripMargin) + } + + def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val genericInternalRow = classOf[GenericInternalRow].getName + val arrVals = ctx.freshName("arrVals") + val biggestCardinality = ctx.freshName("biggestCardinality") + + val currentRow = ctx.freshName("currentRow") + val j = ctx.freshName("j") + val i = ctx.freshName("i") + val args = ctx.freshName("args") + + val evals = children.map(_.genCode(ctx)) + val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) => + s""" + |if ($biggestCardinality != -1) { + | ${eval.code} + | if (!${eval.isNull}) { + | $arrVals[$index] = ${eval.value}; + | $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements()); + | } else { + | $biggestCardinality = -1; + | } + |} + """.stripMargin + } + + val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs( + expressions = getValuesAndCardinalities, + funcName = "getValuesAndCardinalities", + returnType = "int", + makeSplitFunction = body => + s""" + |$body + |return $biggestCardinality; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), + extraArguments = + ("ArrayData[]", arrVals) :: + ("int", biggestCardinality) :: Nil) + + val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) => + val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i) + s""" + |if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) { + | $currentRow[$idx] = $g; + |} else { + | $currentRow[$idx] = null; + |} + """.stripMargin + } + + val getValueForTypeSplitted = ctx.splitExpressions( + expressions = getValueForType, + funcName = "extractValue", + arguments = + ("int", i) :: + ("Object[]", currentRow) :: + ("ArrayData[]", arrVals) :: Nil) + + val initVariables = s""" + |ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; + |int $biggestCardinality = 0; + |${CodeGenerator.javaType(dataType)} ${ev.value} = null; + """.stripMargin + + ev.copy(code""" + |$initVariables + |$splittedGetValuesAndCardinalities + |boolean ${ev.isNull} = $biggestCardinality == -1; + |if (!${ev.isNull}) { + 
| Object[] $args = new Object[$biggestCardinality]; + | for (int $i = 0; $i < $biggestCardinality; $i ++) { + | Object[] $currentRow = new Object[$numberOfArrays]; + | $getValueForTypeSplitted + | $args[$i] = new $genericInternalRow($currentRow); + | } + | ${ev.value} = new $genericArrayData($args); + |} + """.stripMargin) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (numberOfArrays == 0) { + emptyInputGenCode(ev) + } else { + nonEmptyInputGenCode(ctx, ev) + } + } + + override def eval(input: InternalRow): Any = { + val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData]) + if (inputArrays.contains(null)) { + null + } else { + val biggestCardinality = if (inputArrays.isEmpty) { + 0 + } else { + inputArrays.map(_.numElements()).max + } + + val result = new Array[InternalRow](biggestCardinality) + val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex + + for (i <- 0 until biggestCardinality) { + val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) => + if (i < arr.numElements() && !arr.isNullAt(i)) { + arr.get(i, arrayElementTypes(index)) + } else { + null + } + } + + result(i) = InternalRow.apply(currentLayer: _*) + } + new GenericArrayData(result) + } + } + + override def prettyName: String = "arrays_zip" +} + /** * Returns an unordered array containing the values of the map. */ @@ -308,6 +475,223 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def prettyName: String = "map_entries" } +/** + * Returns a map created from the given array of entries. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.", + examples = """ + Examples: + > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); + {1:"a",2:"b"} + """, + since = "2.4.0") +case class MapFromEntries(child: Expression) extends UnaryExpression { + + @transient + private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { + case ArrayType( + StructType(Array( + StructField(_, keyType, keyNullable, _), + StructField(_, valueType, valueNullable, _))), + containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) + case _ => None + } + + private def nullEntries: Boolean = dataTypeDetails.get._3 + + override def nullable: Boolean = child.nullable || nullEntries + + override def dataType: MapType = dataTypeDetails.get._1 + + override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { + case Some(_) => TypeCheckResult.TypeCheckSuccess + case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + + s"${child.dataType.simpleString} type. 
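A hedged SQL-level companion to ArraysZip.eval above, assuming a SparkSession `spark`: shorter inputs are padded with nulls up to the largest cardinality, and a null input array makes the whole result null.

spark.sql("SELECT arrays_zip(array(1, 2, 3), array('a', 'b'))").show(false)
// [[1, a], [2, b], [3, null]] -- the shorter array is padded with nulls
spark.sql("SELECT arrays_zip(array(1, 2), CAST(NULL AS array<int>))").show(false)
// null -- any null input array nulls the result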
$prettyName accepts only arrays of pair structs.") + } + + override protected def nullSafeEval(input: Any): Any = { + val arrayData = input.asInstanceOf[ArrayData] + val numEntries = arrayData.numElements() + var i = 0 + if(nullEntries) { + while (i < numEntries) { + if (arrayData.isNullAt(i)) return null + i += 1 + } + } + val keyArray = new Array[AnyRef](numEntries) + val valueArray = new Array[AnyRef](numEntries) + i = 0 + while (i < numEntries) { + val entry = arrayData.getStruct(i, 2) + val key = entry.get(0, dataType.keyType) + if (key == null) { + throw new RuntimeException("The first field from a struct (key) can't be null.") + } + keyArray.update(i, key) + val value = entry.get(1, dataType.valueType) + valueArray.update(i, value) + i += 1 + } + ArrayBasedMapData(keyArray, valueArray) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numEntries = ctx.freshName("numEntries") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, c, ev.value, numEntries) + } else { + genCodeForAnyElements(ctx, c, ev.value, numEntries) + } + ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) { + s""" + |final int $numEntries = $c.numElements(); + |$code + """.stripMargin + } + }) + } + + private def genCodeForAssignmentLoop( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String, + keyAssignment: (String, String) => String, + valueAssignment: (String, String) => String): String = { + val entry = ctx.freshName("entry") + val i = ctx.freshName("idx") + + val nullKeyCheck = if (dataTypeDetails.get._2) { + s""" + |if ($entry.isNullAt(0)) { + | throw new RuntimeException("The first field from a struct (key) can't be null."); + |} + """.stripMargin + } else { + "" + } + + s""" + |for (int $i = 0; $i < $numEntries; $i++) { + | InternalRow $entry = $childVariable.getStruct($i, 2); + | $nullKeyCheck + | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)} + | ${valueAssignment(entry, i)} + |} + """.stripMargin + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val byteArraySize = ctx.freshName("byteArraySize") + val keySectionSize = ctx.freshName("keySectionSize") + val valueSectionSize = ctx.freshName("valueSectionSize") + val data = ctx.freshName("byteArray") + val unsafeMapData = ctx.freshName("unsafeMapData") + val keyArrayData = ctx.freshName("keyArrayData") + val valueArrayData = ctx.freshName("valueArrayData") + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val keySize = dataType.keyType.defaultSize + val valueSize = dataType.valueType.defaultSize + val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" + val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" + val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) + + val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" + 
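The behaviour of MapFromEntries.nullSafeEval above, sketched at the SQL level (SparkSession `spark` assumed): a null entry nulls out the whole map, while a null key is rejected at runtime.

spark.sql("SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b')))").show(false)
// {1:"a", 2:"b"}
spark.sql("SELECT map_from_entries(array(struct(1, 'a'), null))").show(false)
// null -- a null entry makes the whole map null
// A null key instead raises: "The first field from a struct (key) can't be null."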
if (dataType.valueContainsNull) { + s""" + |if ($entry.isNullAt(1)) { + | $valueArrayData.setNullAt($idx); + |} else { + | $valueNullUnsafeAssignment + |} + """.stripMargin + } else { + valueNullUnsafeAssignment + } + } + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment + ) + + s""" + |final long $keySectionSize = $kByteSize; + |final long $valueSectionSize = $vByteSize; + |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; + |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)} + |} else { + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeMapData $unsafeMapData = new UnsafeMapData(); + | Platform.putLong($data, $baseOffset, $keySectionSize); + | Platform.putLong($data, ${baseOffset + 8}, $numEntries); + | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries); + | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); + | ArrayData $keyArrayData = $unsafeMapData.keyArray(); + | ArrayData $valueArrayData = $unsafeMapData.valueArray(); + | $assignmentLoop + | $mapData = $unsafeMapData; + |} + """.stripMargin + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val mapDataClass = classOf[ArrayBasedMapData].getName() + + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + if (dataType.valueContainsNull && isValuePrimitive) { + s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;" + } else { + s"$values[$idx] = $value;" + } + } + val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;" + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment) + + s""" + |final Object[] $keys = new Object[$numEntries]; + |final Object[] $values = new Object[$numEntries]; + |$assignmentLoop + |$mapData = $mapDataClass.apply($keys, $values); + """.stripMargin + } + + override def prettyName: String = "map_from_entries" +} + + /** * Common base class for [[SortArray]] and [[ArraySort]]. */ @@ -1237,6 +1621,7 @@ case class ArrayJoin( override def dataType: DataType = StringType + override def prettyName: String = "array_join" } /** @@ -1823,24 +2208,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) } - if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) }) } - private def nullElementsProtection( - ev: ExprCode, - childVariableName: String, - coreLogic: String): String = { - s""" - |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { - | ${ev.isNull} |= $childVariableName.isNullAt(z); - |} - |if (!${ev.isNull}) { - | $coreLogic - |} - """.stripMargin - } - private def genCodeForNumberOfElements( ctx: CodegenContext, childVariableName: String) : (String, String) = { @@ -2189,3 +2560,281 @@ case class ArrayRemove(left: Expression, right: Expression) override def prettyName: String = "array_remove" } + +/** + * Removes duplicate values from the array. 
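The Flatten hunk below in this file swaps its hand-rolled null loop for ctx.nullArrayElementsSaveExec; the observable SQL behaviour stays the same (sketch assumes a SparkSession `spark`):

spark.sql("SELECT flatten(array(array(1, 2), array(3)))").show(false) // [1, 2, 3]
spark.sql("SELECT flatten(array(array(1, 2), null))").show(false)     // null: a null inner array short-circuits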
+ */ +@ExpressionDescription( + usage = "_FUNC_(array) - Removes duplicate values from the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3)); + [1,2,3,null] + """, since = "2.4.0") +case class ArrayDistinct(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + } + } + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + override def nullSafeEval(array: Any): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementTypeSupportEquals) { + new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + var foundNullElement = false + var pos = 0 + for (i <- 0 until data.length) { + if (data(i) == null) { + if (!foundNullElement) { + foundNullElement = true + pos = pos + 1 + } + } else { + var j = 0 + var done = false + while (j <= i && !done) { + if (data(j) != null && ordering.equiv(data(j), data(i))) { + done = true + } + j = j + 1 + } + if (i == j - 1) { + pos = pos + 1 + } + } + } + new GenericArrayData(data.slice(0, pos)) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (array) => { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val getValue1 = CodeGenerator.getValue(array, elementType, i) + val getValue2 = CodeGenerator.getValue(array, elementType, j) + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + if (elementTypeSupportEquals) { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add($getValue1); + | } + |} + |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 
1 : 0); + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } else { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | if (!($foundNullElement)) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | $foundNullElement = true; + | } + | } else { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { + | break; + | } + | } + | if ($i == $j) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | } + | } + |} + | + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } + }) + } + + private def setNull( + isPrimitive: Boolean, + foundNullElement: String, + distinctArray: String, + pos: String): String = { + val setNullValue = + if (!isPrimitive) { + s"$distinctArray[$pos] = null"; + } else { + s"$distinctArray.setNullAt($pos)"; + } + + s""" + |if (!($foundNullElement)) { + | $setNullValue; + | $pos = $pos + 1; + | $foundNullElement = true; + |} + """.stripMargin + } + + private def setNotNullValue(isPrimitive: Boolean, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + if (!isPrimitive) { + s"$distinctArray[$pos] = $getValue1"; + } else { + s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"; + } + } + + private def setValueForFastEval( + isPrimitive: Boolean, + hs: String, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |if (!($hs.contains($getValue1))) { + | $hs.add($getValue1); + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + private def setValueForBruteForceEval( + isPrimitive: Boolean, + i: String, + j: String, + inputArray: String, + distinctArray: String, + pos: String, + getValue1: String, + isEqual: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |int $j; + |for ($j = 0; $j < $i; $j ++) { + | if (!$inputArray.isNullAt($j) && $isEqual) { + | break; + | } + |} + |if ($i == $j) { + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + size: String): String = { + val distinctArray = ctx.freshName("distinctArray") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val pos = ctx.freshName("pos") + val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) + val getValue2 = CodeGenerator.getValue(inputArray, elementType, j) + val isEqual = ctx.genEqual(elementType, getValue1, getValue2) + val foundNullElement = ctx.freshName("foundNullElement") + val hs = ctx.freshName("hs") + val openHashSet = classOf[OpenHashSet[_]].getName + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + val setNullForNonPrimitive = + setNull(false, foundNullElement, distinctArray, pos) + if (elementTypeSupportEquals) { + val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + 
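For orderable element types the interpreted path of ArrayDistinct above simply calls data.distinct, so the first occurrence of each value wins and at most one null is kept. A hedged SQL sketch (SparkSession `spark` assumed):

spark.sql("SELECT array_distinct(array(3, 1, 3, null, 1, null))").show(false)
// [3, 1, null] -- first occurrences, in input order, with a single null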
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } else { + val setValueForBruteForce = setValueForBruteForceEval( + false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForBruteForce; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" + val setValueForFast = + setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin + } + } + + override def prettyName: String = "array_distinct" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a9867aaeb0cfe..0a5f8a907b50a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods @@ -236,6 +236,76 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def prettyName: String = "map" } +/** + * Returns a catalyst Map containing the two arrays in children expressions as keys and values. + */ +@ExpressionDescription( + usage = """ + _FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. 
All elements + in keys should not be null""", + examples = """ + Examples: + > SELECT _FUNC_([1.0, 3.0], ['2', '4']); + {1.0:"2",3.0:"4"} + """, since = "2.4.0") +case class MapFromArrays(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def dataType: DataType = { + MapType( + keyType = left.dataType.asInstanceOf[ArrayType].elementType, + valueType = right.dataType.asInstanceOf[ArrayType].elementType, + valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) + } + + override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { + val keyArrayData = keyArray.asInstanceOf[ArrayData] + val valueArrayData = valueArray.asInstanceOf[ArrayData] + if (keyArrayData.numElements != valueArrayData.numElements) { + throw new RuntimeException("The given two arrays should have the same length") + } + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + if (leftArrayType.containsNull) { + var i = 0 + while (i < keyArrayData.numElements) { + if (keyArrayData.isNullAt(i)) { + throw new RuntimeException("Cannot use null as map key!") + } + i += 1 + } + } + new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => { + val arrayBasedMapData = classOf[ArrayBasedMapData].getName + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else { + val i = ctx.freshName("i") + s""" + |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) { + | if ($keyArrayData.isNullAt($i)) { + | throw new RuntimeException("Cannot use null as map key!"); + | } + |} + """.stripMargin + } + s""" + |if ($keyArrayData.numElements() != $valueArrayData.numElements()) { + | throw new RuntimeException("The given two arrays should have the same length"); + |} + |$keyArrayElemNullCheck + |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy()); + """.stripMargin + }) + } + + override def prettyName: String = "map_from_arrays" +} + /** * An expression representing a not yet available attribute name. 
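MapFromArrays.nullSafeEval above enforces two runtime checks; a hedged SQL sketch (SparkSession `spark` assumed):

spark.sql("SELECT map_from_arrays(array('k1', 'k2'), array(1, 2))").show(false)
// {"k1":1, "k2":2}
// Mismatched lengths raise "The given two arrays should have the same length":
// spark.sql("SELECT map_from_arrays(array('k1'), array(1, 2))").collect()
// A null key raises "Cannot use null as map key!":
// spark.sql("SELECT map_from_arrays(array('k1', null), array(1, 2))").collect()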
This expression is unevaluable * and as its name suggests it is a temporary place holder until we're able to determine the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 04a4eb0ffc032..f6d74f5b74c8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -747,8 +746,8 @@ case class StructsToJson( object JsonExprUtils { - def validateSchemaLiteral(exp: Expression): StructType = exp match { - case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString) + def validateSchemaLiteral(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) case e => throw new AnalysisException(s"Expected a string literal instead of $e") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 246025b82d59e..0cc2a332f2c30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -57,6 +57,7 @@ object Literal { case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) + case c: Char => Literal(UTF8String.fromString(c.toString), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 9fe2fb2b95e4d..f957aaa96e98c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ /** @@ -297,6 +297,37 @@ trait WindowFunction extends Expression { def frame: WindowFrame = UnspecifiedFrame } +/** + * Case objects that describe whether a window function is a SQL window function or a Python + * user-defined window function. 
+ */ +sealed trait WindowFunctionType + +object WindowFunctionType { + case object SQL extends WindowFunctionType + case object Python extends WindowFunctionType + + def functionType(windowExpression: NamedExpression): WindowFunctionType = { + val t = windowExpression.collectFirst { + case _: WindowFunction | _: AggregateFunction => SQL + case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python + } + + // Normally a window expression would either have a SQL window function, a SQL + // aggregate function or a python window UDF. However, sometimes the optimizer will replace + // the window function if the value of the window function can be predetermined. + // For example, for query: + // + // select count(NULL) over () from values 1.0, 2.0, 3.0 T(a) + // + // The window function will be replaced by expression literal(0) + // To handle this case, if a window expression doesn't have a regular window function, we + // consider its type to be SQL as literal(0) is also a SQL expression. + t.getOrElse(SQL) + } +} + + /** * An offset window function is a window function that returns the value of the input column offset * by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 2ff12acb2946f..47eeb70e00427 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -73,6 +73,9 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + // Whether to ignore column of all null values or empty array/struct during schema inference + val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) @@ -94,32 +97,16 @@ private[sql] class JSONOptions( sep } + protected def checkedEncoding(enc: String): String = enc + /** * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. - * If the encoding is not specified (None), it will be detected automatically - * when the multiLine option is set to `true`. + * If the encoding is not specified (None) in read, it will be detected automatically + * when the multiLine option is set to `true`. If encoding is not specified in write, + * UTF-8 is used by default. */ val encoding: Option[String] = parameters.get("encoding") - .orElse(parameters.get("charset")).map { enc => - // The following encodings are not supported in per-line mode (multiline is false) - // because they cause some problems in reading files with BOM which is supposed to - // present in the files with such encodings. After splitting input files by lines, - // only the first lines will have the BOM which leads to impossibility for reading - // the rest lines. Besides of that, the lineSep option must have the BOM in such - // encodings which can never present between lines. - val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) - val isBlacklisted = blacklist.contains(Charset.forName(enc)) - require(multiLine || !isBlacklisted, - s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled. 
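The functionType comment above cites the case where the optimizer folds the window function away entirely; the query it mentions can be run as-is (SparkSession `spark` assumed):

spark.sql("select count(NULL) over () from values 1.0, 2.0, 3.0 T(a)").show()
// Three rows of 0: count(NULL) is replaced by a literal, and the remaining window
// expression is still classified as a SQL window function by the fallback above.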
- |Blacklist: ${blacklist.mkString(", ")}""".stripMargin) - - val isLineSepRequired = - multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty - - require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") - - enc - } + .orElse(parameters.get("charset")).map(checkedEncoding) val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => lineSep.getBytes(encoding.getOrElse("UTF-8")) @@ -138,3 +125,46 @@ private[sql] class JSONOptions( factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, allowUnquotedControlChars) } } + +private[sql] class JSONOptionsInRead( + @transient override val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) + extends JSONOptions(parameters, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) { + + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } + + protected override def checkedEncoding(enc: String): String = { + val isBlacklisted = JSONOptionsInRead.blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: + |Blacklist: ${JSONOptionsInRead.blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } +} + +private[sql] object JSONOptionsInRead { + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq( + Charset.forName("UTF-16"), + Charset.forName("UTF-32") + ) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bfa61116a6658..aa992def1ce6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -621,12 +621,15 @@ object CollapseRepartition extends Rule[LogicalPlan] { /** * Collapse Adjacent Window Expression. * - If the partition specs and order specs are the same and the window expression are - * independent, collapse into the parent. + * independent and are of the same window function type, collapse into the parent. */ object CollapseWindow extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) - if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty => + if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty && + // This assumes Window contains the same type of window expressions. This is ensured + // by ExtractWindowFunctions. 
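A hedged reader-side sketch of the two JSON option changes above (SparkSession `spark` assumed; the paths are placeholders): dropFieldIfAllNull prunes all-null columns during schema inference, and the UTF-16/UTF-32 blacklist now only applies on the read path when multiLine is disabled.

val pruned = spark.read
  .option("dropFieldIfAllNull", "true") // new JSONOptions flag
  .json("/path/to/input.json")          // hypothetical path

val utf16 = spark.read
  .option("encoding", "UTF-16")         // blacklisted unless multiLine is enabled
  .option("multiLine", "true")
  .json("/path/to/utf16.json")          // hypothetical path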
+ WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) => w1.copy(windowExpressions = we2 ++ we1, child = grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 626f905707191..84be677e438a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.planning import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -215,7 +216,7 @@ object PhysicalAggregation { case agg: AggregateExpression if !equivalentAggregateExpressions.addExpr(agg) => agg case udf: PythonUDF - if PythonUDF.isGroupAggPandasUDF(udf) && + if PythonUDF.isGroupedAggPandasUDF(udf) && !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -245,7 +246,7 @@ object PhysicalAggregation { equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute // Similar to AggregateExpression - case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) => equivalentAggregateExpressions.getEquivalentExprs(ue).headOption .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => @@ -268,3 +269,40 @@ object PhysicalAggregation { case _ => None } } + +/** + * An extractor used when planning physical execution of a window. This extractor outputs + * the window function type of the logical window. + * + * The input logical window must contain same type of window functions, which is ensured by + * the rule ExtractWindowExpressions in the analyzer. + */ +object PhysicalWindow { + // windowFunctionType, windowExpression, partitionSpec, orderSpec, child + private type ReturnType = + (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) => + + // The window expression should not be empty here, otherwise it's a bug. + if (windowExpressions.isEmpty) { + throw new AnalysisException(s"Window expression is empty in $expr") + } + + val windowFunctionType = windowExpressions.map(WindowFunctionType.functionType) + .reduceLeft { (t1: WindowFunctionType, t2: WindowFunctionType) => + if (t1 != t2) { + // We shouldn't have different window function type here, otherwise it's a bug. 
+ throw new AnalysisException( + s"Found different window function type in $windowExpressions") + } else { + t1 + } + } + + Some((windowFunctionType, windowExpressions, partitionSpec, orderSpec, child)) + + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 64cb8c726772f..e431c9523a9da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -119,6 +119,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case Some(value) => Some(recursiveTransform(value)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other case null => null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4d9a9925fe3ff..cc1a5e835d9cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -99,16 +99,19 @@ case class ClusteredDistribution( * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the * number of partitions, this distribution strictly requires which partition the tuple should be in. */ -case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { +case class HashClusteredDistribution( + expressions: Seq[Expression], + requiredNumPartitions: Option[Int] = None) extends Distribution { require( expressions != Nil, - "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "The expressions for hash of a HashClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - override def requiredNumPartitions: Option[Int] = None - override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") HashPartitioning(expressions, numPartitions) } } @@ -163,11 +166,22 @@ trait Partitioning { * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). * + * A [[Partitioning]] can never satisfy a [[Distribution]] if its `numPartitions` does't match + * [[Distribution.requiredNumPartitions]]. + */ + final def satisfies(required: Distribution): Boolean = { + required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required) + } + + /** + * The actual method that defines whether this [[Partitioning]] can satisfy the given + * [[Distribution]], after the `numPartitions` check. + * * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if - * the [[Partitioning]] only have one partition. Implementations can overwrite this method with - * special logic. + * the [[Partitioning]] only have one partition. 
Implementations can also overwrite this method + * with special logic. */ - def satisfies(required: Distribution): Boolean = required match { + protected def satisfies0(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case AllTuples => numPartitions == 1 case _ => false @@ -186,9 +200,8 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } } @@ -205,16 +218,15 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case h: HashClusteredDistribution => expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { case (l, r) => l.semanticEquals(r) } - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -246,15 +258,14 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -295,7 +306,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) * Returns true if any `partitioning` of this collection satisfies the given * [[Distribution]]. 
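A hedged sketch of the new satisfies/satisfies0 contract above, built with Catalyst classes directly (the attribute construction is illustrative; DistributionSuite further down exercises the same cases with the test DSL):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val hashPart = HashPartitioning(Seq(a), numPartitions = 10)

hashPart.satisfies(ClusteredDistribution(Seq(a)))          // true
hashPart.satisfies(ClusteredDistribution(Seq(a), Some(5))) // false: requiredNumPartitions != 10
hashPart.satisfies(HashClusteredDistribution(Seq(a)))      // true: hash expressions match exactly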
*/ - override def satisfies(required: Distribution): Boolean = + override def satisfies0(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) override def toString: String = { @@ -310,7 +321,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { override val numPartitions: Int = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case BroadcastDistribution(m) if m == mode => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9c7d47f99ee10..becfa8d982213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { var changed = false val remainingNewChildren = newChildren.toBuffer val remainingOldChildren = children.toBuffer + def mapTreeNode(node: TreeNode[_]): TreeNode[_] = { + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + } + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + case nonChild: AnyRef => nonChild + case null => null + } val newArgs = mapProductIterator { case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. 
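Why the explicit .force calls in the Stream cases above and below: in the Scala 2.11/2.12 collections used here, Stream.map is head-strict but tail-lazy, so per-child side effects (such as consuming the remaining-children buffers) would only run for the first child unless the result is materialized. A minimal, Spark-free illustration:

// Illustrative only; mirrors the laziness the TreeNode/QueryPlan Stream cases guard against.
val children = Stream(1, 2, 3)
var visited = 0
val mapped = children.map { c => visited += 1; c * 10 }
assert(visited == 1) // only the head has been mapped so far
mapped.force
assert(visited == 3) // forcing evaluates the rest of the stream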
- case s: Seq[_] => s.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - } - case m: Map[_, _] => m.mapValues { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - }.view.force // `mapValues` is lazy and we need to force it to materialize - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } + case s: Stream[_] => + // Stream is lazy so we need to force materialization + s.map(mapChild).force + case s: Seq[_] => + s.map(mapChild) + case m: Map[_, _] => + // `mapValues` is lazy and we need to force it to materialize + m.mapValues(mapChild).view.force + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) case nonChild: AnyRef => nonChild case null => null } @@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def mapChildren(f: BaseType => BaseType): BaseType = { if (children.nonEmpty) { var changed = false + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } + case other => other + } + val newArgs = mapProductIterator { case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) @@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case other => other }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs - case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = if (containsChild(arg1)) { - f(arg1.asInstanceOf[BaseType]) - } else { - arg1.asInstanceOf[BaseType] - } - - val newChild2 = if (containsChild(arg2)) { - f(arg2.asInstanceOf[BaseType]) - } else { - arg2.asInstanceOf[BaseType] - } - - if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { - changed = true - (newChild1, newChild2) - } else { - tuple - } - case other => other - } + case args: Stream[_] => args.map(mapChild).force // Force materialization on stream + case args: Traversable[_] => args.map(mapChild) case nonChild: AnyRef => nonChild case null => null } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8d2320d8a6ed7..e768416f257c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1161,6 +1161,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION = + buildConf("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition") + .internal() + .doc("When true, a grouped map Pandas UDF will assign columns from the returned " + + "Pandas DataFrame based on position, regardless of column label type. When false, " + + "columns will be looked up by name if labeled with a string and fallback to use " + + "position if not. This configuration will be deprecated in future releases.") + .booleanConf + .createWithDefault(false) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -1647,6 +1657,9 @@ class SQLConf extends Serializable with Logging { def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE) + def pandasGroupedMapAssignColumnssByPosition: Boolean = + getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index fe0ad39c29025..382ef28f49a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -96,6 +96,14 @@ object StaticSQLConf { .toSequence .createOptional + val STREAMING_QUERY_LISTENERS = buildStaticConf("spark.sql.streaming.streamingQueryListeners") + .doc("List of class names implementing StreamingQueryListener that will be automatically " + + "added to newly created sessions. 
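Both configuration additions in this file pair are plain string settings; a hedged sketch of wiring them up (the listener class name is hypothetical and must provide a no-arg or SparkConf constructor):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  // Static conf: listeners are registered once, when the session is created.
  .config("spark.sql.streaming.streamingQueryListeners", "com.example.MyQueryListener")
  .getOrCreate()

// Internal, session-level conf controlling grouped map Pandas UDF column matching:
spark.conf.set("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", "true")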
The classes should have either a no-arg constructor, " + + "or a constructor that expects a SparkConf argument.") + .stringConf + .toSequence + .createOptional + val UI_RETAINED_EXECUTIONS = buildStaticConf("spark.sql.ui.retainedExecutions") .doc("Number of executions to retain in the Spark UI.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0bef11659fc9e..fd40741cfb5f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.types import java.util.Locale +import scala.util.control.NonFatal + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -26,6 +28,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -110,6 +113,14 @@ abstract class DataType extends AbstractDataType { @InterfaceStability.Stable object DataType { + def fromDDL(ddl: String): DataType = { + try { + CatalystSqlParser.parseDataType(ddl) + } catch { + case NonFatal(_) => CatalystSqlParser.parseTableSchema(ddl) + } + } + def fromJson(json: String): DataType = parseDataType(parse(json)) private val nonDecimalNameToType = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f99af9b84d959..89452ee05cff3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class CatalystTypeConvertersSuite extends SparkFunSuite { @@ -139,4 +140,11 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(exception.getMessage.contains("The value (0.1) of the type " + "(java.lang.Double) cannot be converted to the string type")) } + + test("SPARK-24571: convert Char to String") { + val chr: Char = 'X' + val converter = CatalystTypeConverters.createToCatalystConverter(StringType) + val expected = UTF8String.fromString("X") + assert(converter(chr) === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index b47b8adfe5d55..39228102682b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -41,34 +41,127 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning (with nullSafe = true) is the output partitioning") { - // Cases which do not need an exchange between two data properties. 
+ test("UnspecifiedDistribution and AllTuples") { + // except `BroadcastPartitioning`, all other partitioning can satisfy UnspecifiedDistribution checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), + UnknownPartitioning(-1), UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + RoundRobinPartitioning(10), + UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + SinglePartition, + UnspecifiedDistribution, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + UnspecifiedDistribution, + false) + + // except `BroadcastPartitioning`, all other partitioning can satisfy AllTuples if they have + // only one partition. + checkSatisfied( + UnknownPartitioning(1), + AllTuples, + true) + + checkSatisfied( + UnknownPartitioning(10), + AllTuples, + false) + + checkSatisfied( + RoundRobinPartitioning(1), + AllTuples, + true) + + checkSatisfied( + RoundRobinPartitioning(10), + AllTuples, + false) + + checkSatisfied( + SinglePartition, + AllTuples, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 1), + AllTuples, true) + checkSatisfied( + HashPartitioning(Seq('a), 10), + AllTuples, + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 1), + AllTuples, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + AllTuples, + false) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + AllTuples, + false) + } + + test("SinglePartition is the output partitioning") { + // SinglePartition can satisfy all the distributions except `BroadcastDistribution` checkSatisfied( SinglePartition, ClusteredDistribution(Seq('a, 'b, 'c)), true) + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( SinglePartition, OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), true) - // Cases which need an exchange between two data properties. + checkSatisfied( + SinglePartition, + BroadcastDistribution(IdentityBroadcastMode), + false) + } + + test("HashPartitioning is the output partitioning") { + // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of + // the required clustering expressions. + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), ClusteredDistribution(Seq('b, 'c)), @@ -79,37 +172,43 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('d, 'e)), false) + // HashPartitioning can satisfy HashClusteredDistribution iff its hash expressions are exactly + // same with the required hash clustering expressions. 
checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), - AllTuples, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('c, 'b, 'a), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + checkSatisfied( + HashPartitioning(Seq('a, 'b), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), + false) + + // HashPartitioning cannot satisfy OrderedDistribution checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 1), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - false) + false) // TODO: this can be relaxed. - // TODO: We should check functional dependencies - /* checkSatisfied( - ClusteredDistribution(Seq('b)), - ClusteredDistribution(Seq('b + 1)), - true) - */ + HashPartitioning(Seq('b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) } test("RangePartitioning is the output partitioning") { - // Cases which do not need an exchange between two data properties. - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - UnspecifiedDistribution, - true) - + // RangePartitioning can satisfy OrderedDistribution iff its ordering is a prefix + // of the required ordering, or the required ordering is a prefix of its ordering. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), @@ -125,6 +224,27 @@ class DistributionSuite extends SparkFunSuite { OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)), true) + // TODO: We can have an optimization to first sort the dataset + // by a.asc and then sort b, and c in a partition. This optimization + // should tradeoff the benefit of a less number of Exchange operators + // and the parallelism. + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('b.asc, 'a.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'd.desc)), + false) + + // RangePartitioning can satisfy ClusteredDistribution iff its ordering expressions are a subset + // of the required clustering expressions. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), ClusteredDistribution(Seq('a, 'b, 'c)), @@ -140,34 +260,47 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) - // Cases which need an exchange between two data properties. - // TODO: We can have an optimization to first sort the dataset - // by a.asc and then sort b, and c in a partition. This optimization - // should tradeoff the benefit of a less number of Exchange operators - // and the parallelism. 
checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + ClusteredDistribution(Seq('a, 'b)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('b.asc, 'a.asc)), + ClusteredDistribution(Seq('c, 'd)), false) + // RangePartitioning cannot satisfy HashClusteredDistribution checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b)), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + } + test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") { checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'd)), + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - AllTuples, + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f8ad624ce0e3d..5b8cf5128fe21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ @@ -79,6 +80,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms2), null) } + test("MapFromEntries") { + def arrayType(keyType: DataType, valueType: DataType) : DataType = { + ArrayType( + StructType(Seq( + StructField("a", keyType), + StructField("b", valueType))), + true) + } + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys and values + val aiType = arrayType(IntegerType, IntegerType) + val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) + val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai2 = Literal.create(Seq.empty, aiType) + val ai3 = Literal.create(null, aiType) + val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) + val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) + + checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai2), Map.empty) + checkEvaluation(MapFromEntries(ai3), null) + checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(ai5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(ai6), null) + + // Non-primitive-type keys and values + val asType = arrayType(StringType, StringType) + val as0 = Literal.create(Seq(r("a", "aa"), 
r("b", "bb"), r("c", "bb")), asType) + val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as2 = Literal.create(Seq.empty, asType) + val as3 = Literal.create(null, asType) + val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) + val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) + val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) + + checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as2), Map.empty) + checkEvaluation(MapFromEntries(as3), null) + checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(as5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(as6), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) @@ -315,6 +367,91 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Some(Literal.create(null, StringType))), null) } + test("ArraysZip") { + val literals = Seq( + Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)), + Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType)), + Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType)), + Literal.create(Seq("a", null, "c"), ArrayType(StringType)), + Literal.create(Seq(null, false, true), ArrayType(BooleanType)), + Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType)), + Literal.create(Seq(), ArrayType(NullType)), + Literal.create(Seq(null), ArrayType(NullType)), + Literal.create(Seq(192.toByte), ArrayType(ByteType)), + Literal.create( + Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType))), + Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType)) + ) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(1))), + List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(2))), + List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(3))), + List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(4))), + List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(5))), + List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(6))), + List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(7))), + List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(1), literals(2), literals(3))), + List( + Row(9001, null, -1, "a"), + Row(9002, 1L, -3, null), + Row(9003, null, 900, "c"), + Row(null, 4L, null, null), + Row(null, 11L, null, null))) + + checkEvaluation(ArraysZip(Seq(literals(4), literals(5), literals(6), literals(7), literals(8))), + List( + Row(null, 1.1, null, null, 192.toByte), + Row(false, null, null, null, null), + Row(true, 1.3, null, null, null), + Row(null, null, null, null, null))) + + checkEvaluation(ArraysZip(Seq(literals(9), literals(0))), + List( + 
Row(List(1, 2, 3), 9001), + Row(null, 9002), + Row(List(4, 5), 9003), + Row(List(1, null, 3), null))) + + checkEvaluation(ArraysZip(Seq(literals(7), literals(10))), + List(Row(null, Array[Byte](1.toByte, 5.toByte)))) + + val longLiteral = + Literal.create((0 to 1000).toSeq, ArrayType(IntegerType)) + + checkEvaluation(ArraysZip(Seq(literals(0), longLiteral)), + List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++ + (3 to 1000).map { Row(null, _) }.toList) + + val manyLiterals = (0 to 1000).map { _ => + Literal.create(Seq(1), ArrayType(IntegerType)) + }.toSeq + + val numbers = List( + Row(Seq(9001) ++ (0 to 1000).map { _ => 1 }.toSeq: _*), + Row(Seq(9002) ++ (0 to 1000).map { _ => null }.toSeq: _*), + Row(Seq(9003) ++ (0 to 1000).map { _ => null }.toSeq: _*), + Row(Seq(null) ++ (0 to 1000).map { _ => null }.toSeq: _*)) + checkEvaluation(ArraysZip(Seq(literals(0)) ++ manyLiterals), + List(numbers(0), numbers(1), numbers(2), numbers(3))) + + checkEvaluation(ArraysZip(Seq(literals(0), Literal.create(null, ArrayType(IntegerType)))), null) + checkEvaluation(ArraysZip(Seq()), List()) + } + test("Array Min") { checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) checkEvaluation( @@ -680,4 +817,49 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } + + test("Array Distinct") { + val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType)) + val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234), + ArrayType(DoubleType)) + val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f), + ArrayType(FloatType)) + + checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) + checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer]) + checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c")) + checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a")) + checkEvaluation(new ArrayDistinct(a4), Seq(null)) + checkEvaluation(new ArrayDistinct(a5), Seq(true, false)) + checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121)) + checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2), + null, Array[Byte](5, 6), null), ArrayType(BinaryType)) + + checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2))) + checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null, + Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2), + Seq[Int](3, 4), Seq[Int](1, 2)), 
ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b4138ce366b3a..726193b411737 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -186,6 +186,50 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("MapFromArrays") { + def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { + // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. + scala.collection.immutable.ListMap(keys.zip(values): _*) + } + + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25) + val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25) + val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_)) + + val intArray = Literal.create(intSeq, ArrayType(IntegerType, false)) + val longArray = Literal.create(longSeq, ArrayType(LongType, false)) + val strArray = Literal.create(strSeq, ArrayType(StringType, false)) + + val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true)) + val intWithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true)) + val longWithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true)) + + val nullArray = Literal.create(null, ArrayType(StringType, false)) + + checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) + checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) + checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq)) + + checkEvaluation( + MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq)) + checkEvaluation( + MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation( + MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation(MapFromArrays(nullArray, nullArray), null) + + intercept[RuntimeException] { + checkEvaluation(MapFromArrays(intWithNullArray, strArray), null) + } + intercept[RuntimeException] { + checkEvaluation( + MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) + } + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index a9e0eb0e377a6..86f80fe66d28b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -219,4 +219,11 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) } + + test("SPARK-24571: char literals") { + checkEvaluation(Literal('X'), "X") + checkEvaluation(Literal.create('0'), "0") + checkEvaluation(Literal('\u0000'), "\u0000") + checkEvaluation(Literal.create('\n'), "\n") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala index c4cde7091154b..0fec15bc42c17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -77,6 +77,27 @@ class UDFXPathUtilSuite extends SparkFunSuite { assert(ret == "foo") } + test("embedFailure") { + import org.apache.commons.io.FileUtils + import java.io.File + val secretValue = String.valueOf(Math.random) + val tempFile = File.createTempFile("verifyembed", ".tmp") + tempFile.deleteOnExit() + val fname = tempFile.getAbsolutePath + + FileUtils.writeStringToFile(tempFile, secretValue) + + val xml = + s""" + | + |]> + |&embed; + """.stripMargin + val evaled = new UDFXPathUtil().evalString(xml, "/foo") + assert(evaled.isEmpty) + } + test("number eval") { var ret = util.evalNumber("truefalseb3c1-77", "a/c[2]") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala index bfa18a0919e45..c6f6d3abb860c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -40,8 +40,9 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { // Test error message for invalid XML document val e1 = intercept[RuntimeException] { testExpr("/a>", "a", null.asInstanceOf[T]) } - assert(e1.getCause.getMessage.contains("Invalid XML document") && - e1.getCause.getMessage.contains("/a>")) + assert(e1.getCause.getCause.getMessage.contains( + "XML document structures must start and end within the same entity.")) + assert(e1.getMessage.contains("/a>")) // Test error message for invalid xpath val e2 = intercept[RuntimeException] { testExpr("", "!#$", null.asInstanceOf[T]) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 14041747fd20e..bf569cb869428 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ 
-101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite { assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true) assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming) } + + test("transformExpressions works with a Stream") { + val id1 = NamedExpression.newExprId + val id2 = NamedExpression.newExprId + val plan = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(2), "b")(exprId = id2)), + OneRowRelation()) + val result = plan.transformExpressions { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(3), "b")(exprId = id2)), + OneRowRelation()) + assert(result.sameResult(expected)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 84d0ba7bef642..b7092f4c42d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions.DslString import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union} import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { @@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite { val right = JsonMethods.parse(rightJson) assert(left == right) } + + test("transform works on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the + // situation in which the TreeNode.mapChildren function's change detection is not triggered. A + // stream's first element is typically materialized, so in order to not trip the TreeNode change + // detection logic, we should not change the first element in the sequence. 
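Editor's note: the comment above relies on Scala Stream laziness. A minimal standalone sketch (not part of the patch) of that behaviour: map on a Stream evaluates the head eagerly but leaves the tail as a lazy thunk, which is why change detection that only inspects materialized children can miss a modification further down.

    object LazyStreamSketch extends App {
      var evaluated = 0
      val mapped = Stream(1, 2, 3).map { x => evaluated += 1; x + 1 }

      // Only the head has been transformed so far.
      println(evaluated)      // 1
      // Forcing the stream runs the function on the remaining elements.
      println(mapped.toList)  // List(2, 3, 4)
      println(evaluated)      // 3
    }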
+ val result = before.transform { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } + + test("withNewChildren on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + val result = before.withNewChildren(Stream(Literal(1), Literal(3))) + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 4733f36174f42..6fdadde628551 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -216,12 +216,12 @@ protected UTF8String getBytesAsUTF8String(int rowId, int count) { @Override public void putShort(int rowId, short value) { - Platform.putShort(null, data + 2 * rowId, value); + Platform.putShort(null, data + 2L * rowId, value); } @Override public void putShorts(int rowId, int count, short value) { - long offset = data + 2 * rowId; + long offset = data + 2L * rowId; for (int i = 0; i < count; ++i, offset += 2) { Platform.putShort(null, offset, value); } @@ -229,20 +229,20 @@ public void putShorts(int rowId, int count, short value) { @Override public void putShorts(int rowId, int count, short[] src, int srcIndex) { - Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, - null, data + 2 * rowId, count * 2); + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2L, + null, data + 2L * rowId, count * 2L); } @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 2, count * 2); + null, data + rowId * 2L, count * 2L); } @Override public short getShort(int rowId) { if (dictionary == null) { - return Platform.getShort(null, data + 2 * rowId); + return Platform.getShort(null, data + 2L * rowId); } else { return (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -252,7 +252,7 @@ public short getShort(int rowId) { public short[] getShorts(int rowId, int count) { assert(dictionary == null); short[] array = new short[count]; - Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2); + Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L); return array; } @@ -262,12 +262,12 @@ public short[] getShorts(int rowId, int count) { @Override public void putInt(int rowId, int value) { - Platform.putInt(null, data + 4 * rowId, value); + Platform.putInt(null, data + 4L * rowId, value); } @Override public void putInts(int rowId, int count, int value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putInt(null, offset, value); } @@ -275,24 +275,24 @@ public void putInts(int rowId, int count, int value) { @Override public void putInts(int rowId, int count, int[] src, int srcIndex) { - Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { 
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 4 * rowId, count * 4); + null, data + 4L * rowId, count * 4L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4, srcOffset += 4) { Platform.putInt(null, offset, java.lang.Integer.reverseBytes(Platform.getInt(src, srcOffset))); @@ -303,7 +303,7 @@ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public int getInt(int rowId) { if (dictionary == null) { - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } else { return dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -313,7 +313,7 @@ public int getInt(int rowId) { public int[] getInts(int rowId, int count) { assert(dictionary == null); int[] array = new int[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L); return array; } @@ -325,7 +325,7 @@ public int[] getInts(int rowId, int count) { public int getDictId(int rowId) { assert(dictionary == null) : "A ColumnVector dictionary should not have a dictionary for itself."; - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } // @@ -334,12 +334,12 @@ public int getDictId(int rowId) { @Override public void putLong(int rowId, long value) { - Platform.putLong(null, data + 8 * rowId, value); + Platform.putLong(null, data + 8L * rowId, value); } @Override public void putLongs(int rowId, int count, long value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putLong(null, offset, value); } @@ -347,24 +347,24 @@ public void putLongs(int rowId, int count, long value) { @Override public void putLongs(int rowId, int count, long[] src, int srcIndex) { - Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 8 * rowId, count * 8); + null, data + 8L * rowId, count * 8L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8, srcOffset += 8) { Platform.putLong(null, offset, java.lang.Long.reverseBytes(Platform.getLong(src, srcOffset))); @@ -375,7 +375,7 @@ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public long getLong(int rowId) { if (dictionary == null) { - return Platform.getLong(null, data + 8 * rowId); + return Platform.getLong(null, data 
+ 8L * rowId); } else { return dictionary.decodeToLong(dictionaryIds.getDictId(rowId)); } @@ -385,7 +385,7 @@ public long getLong(int rowId) { public long[] getLongs(int rowId, int count) { assert(dictionary == null); long[] array = new long[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L); return array; } @@ -395,12 +395,12 @@ public long[] getLongs(int rowId, int count) { @Override public void putFloat(int rowId, float value) { - Platform.putFloat(null, data + rowId * 4, value); + Platform.putFloat(null, data + rowId * 4L, value); } @Override public void putFloats(int rowId, int count, float value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, value); } @@ -408,18 +408,18 @@ public void putFloats(int rowId, int count, float value) { @Override public void putFloats(int rowId, int count, float[] src, int srcIndex) { - Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, bb.getFloat(srcIndex + (4 * i))); } @@ -429,7 +429,7 @@ public void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public float getFloat(int rowId) { if (dictionary == null) { - return Platform.getFloat(null, data + rowId * 4); + return Platform.getFloat(null, data + rowId * 4L); } else { return dictionary.decodeToFloat(dictionaryIds.getDictId(rowId)); } @@ -439,7 +439,7 @@ public float getFloat(int rowId) { public float[] getFloats(int rowId, int count) { assert(dictionary == null); float[] array = new float[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L); return array; } @@ -450,12 +450,12 @@ public float[] getFloats(int rowId, int count) { @Override public void putDouble(int rowId, double value) { - Platform.putDouble(null, data + rowId * 8, value); + Platform.putDouble(null, data + rowId * 8L, value); } @Override public void putDoubles(int rowId, int count, double value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, value); } @@ -463,18 +463,18 @@ public void putDoubles(int rowId, int count, double value) { @Override public void putDoubles(int rowId, int count, double[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + 
srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, bb.getDouble(srcIndex + (8 * i))); } @@ -484,7 +484,7 @@ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public double getDouble(int rowId) { if (dictionary == null) { - return Platform.getDouble(null, data + rowId * 8); + return Platform.getDouble(null, data + rowId * 8L); } else { return dictionary.decodeToDouble(dictionaryIds.getDictId(rowId)); } @@ -494,7 +494,7 @@ public double getDouble(int rowId) { public double[] getDoubles(int rowId, int count) { assert(dictionary == null); double[] array = new double[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8L); return array; } @@ -504,26 +504,26 @@ public double[] getDoubles(int rowId, int count) { @Override public void putArray(int rowId, int offset, int length) { assert(offset >= 0 && offset + length <= childColumns[0].capacity); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, offset); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, offset); } @Override public int getArrayLength(int rowId) { - return Platform.getInt(null, lengthData + 4 * rowId); + return Platform.getInt(null, lengthData + 4L * rowId); } @Override public int getArrayOffset(int rowId) { - return Platform.getInt(null, offsetData + 4 * rowId); + return Platform.getInt(null, offsetData + 4L * rowId); } // APIs dealing with ByteArrays @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { int result = arrayData().appendBytes(length, value, offset); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, result); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, result); return result; } @@ -533,19 +533,19 @@ protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 
0 : capacity; if (isArray() || type instanceof MapType) { this.lengthData = - Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L); this.offsetData = - Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2L, newCapacity * 2L); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L); } else if (childColumns != null) { // Nothing to store. } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 23dcc104e67c4..577eab6ed14c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -231,7 +231,7 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, - Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + Platform.SHORT_ARRAY_OFFSET + rowId * 2L, count * 2L); } @Override @@ -276,7 +276,7 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, - Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.INT_ARRAY_OFFSET + rowId * 4L, count * 4L); } @Override @@ -342,7 +342,7 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, - Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.LONG_ARRAY_OFFSET + rowId * 8L, count * 8L); } @Override @@ -394,7 +394,7 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) { public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, floatData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { @@ -443,7 +443,7 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) { public void 
putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 11bb13fd3b211..926396414816c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -22,6 +22,10 @@ /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. + * + * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader. + * Implementations that return more accurate statistics based on pushed operators will not improve + * query performance until the planner can push operators before getting stats. */ @InterfaceStability.Evolving public interface SupportsReportStatistics extends DataSourceReader { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 0030a9f05dba7..7eedc85a5d6f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -64,8 +64,8 @@ public interface DataSourceWriter { DataWriterFactory createWriterFactory(); /** - * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for - * each task commits. + * Returns whether Spark should use the commit coordinator to ensure that at most one task for + * each partition commits. * * @return true if commit coordinator should be used, false otherwise. */ @@ -90,9 +90,9 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * * Note that speculative execution may cause multiple tasks to run for a partition. By default, - * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * Spark uses the commit coordinator to allow at most one task to commit. Implementations can * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple - * attempts may have committed successfully and one successful commit message per task will be + * tasks may have committed successfully and one successful commit message per task will be * passed to this commit method. The remaining commit messages are ignored by Spark. 
*/ void commit(WriterCommitMessage[] messages); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 39bf458298862..1626c0013e4e7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -39,14 +39,14 @@ * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark may retry this writing task a few times. - * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a - * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a + * different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the * previous one fails, speculative tasks are running simultaneously. It's possible that one input - * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * RDD partition has multiple data writers with different `taskId` running at the same time, * and data sources should guarantee that these data writers don't conflict and can work together. * Implementations can coordinate with driver during {@link #commit()} to make sure only one of * these data writers can commit successfully. Or implementations can allow all of them to commit diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 7527bcc0c4027..0932ff8f8f8a7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -42,15 +42,12 @@ public interface DataWriterFactory extends Serializable { * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for * different partitions. - * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task - * failed, Spark launches a new task wth the same task id but different - * attempt number. Or a task is too slow, Spark launches new tasks wth the - * same task id but different attempt number, which means there are multiple - * tasks with the same task id running at the same time. Implementations can - * use this attempt number to distinguish writers of different task attempts. 
+ * @param taskId A unique identifier for a task that is performing the write of the partition + * data. Spark may run multiple tasks for the same partition (due to speculation + * or task failures, for example). * @param epochId A monotonically increasing id for streaming queries that are split in to * discrete periods of execution. For non-streaming queries, * this ID will always be 0. */ - DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); + DataWriter createDataWriter(int partitionId, long taskId, long epochId); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index de6be5f76e15a..ec9352a7fa055 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -381,6 +381,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * that should be used for parsing. *
   * <li>`samplingRatio` (default is 1.0): defines fraction of input JSON objects used
   * for schema inferring.</li>
+   * <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
+   * empty array/struct during schema inference.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f5526104690d2..57f1e173211af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2964,6 +2964,7 @@ class Dataset[T] private[sql]( /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @param blocking Whether to block until all blocks are deleted. * @@ -2971,12 +2972,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { - sparkSession.sharedState.cacheManager.uncacheQuery(this, blocking) + sparkSession.sharedState.cacheManager.uncacheQuery(this, cascade = false, blocking) this } /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @group basic * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 86e02e98c01f3..b21c50af18433 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -20,10 +20,48 @@ package org.apache.spark.sql import org.apache.spark.annotation.InterfaceStability /** - * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the - * generated data to external systems. Each partition will use a new deserialized instance, so you - * usually should do all the initialization (e.g. opening a connection or initiating a transaction) - * in the `open` method. + * The abstract class for writing custom logic to process data generated by a query. + * This is often used to write the output of a streaming query to arbitrary storage systems. + * Any implementation of this base class will be used by Spark in the following way. + * + *
+ * <ul>
+ * <li>A single instance of this class is responsible for all the data generated by a single task
+ * in a query. In other words, one instance is responsible for processing one partition of the
+ * data generated in a distributed manner.
+ *
+ * <li>Any implementation of this class must be serializable because each task will get a fresh
+ * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that
+ * any initialization for writing data (e.g. opening a connection or starting a transaction)
+ * is done after the `open(...)` method has been called, which signifies that the task is
+ * ready to generate data.
+ *
+ * <li>The lifecycle of the methods is as follows.
+ *
+ *   For each partition with `partitionId`:
+ *       For each batch/epoch of streaming data (if it's a streaming query) with `epochId`:
+ *           Method `open(partitionId, epochId)` is called.
+ *           If `open` returns true:
+ *                For each row in the partition and batch/epoch, method `process(row)` is called.
+ *           Method `close(errorOrNull)` is called with error (if any) seen while processing rows.
+ *
+ * </ul>
+ *
+ * Important points to note:
+ * <ul>
+ * <li>The `partitionId` and `epochId` can be used to deduplicate generated data when failures
+ * cause reprocessing of some input data. This depends on the execution mode of the query. If
+ * the streaming query is being executed in the micro-batch mode, then every partition
+ * represented by a unique tuple (partitionId, epochId) is guaranteed to have the same data.
+ * Hence, (partitionId, epochId) can be used to deduplicate and/or transactionally commit data
+ * and achieve exactly-once guarantees. However, if the streaming query is being executed in the
+ * continuous mode, then this guarantee does not hold and therefore should not be used for
+ * deduplication.
+ *
+ * <li>The `close()` method will be called if the `open()` method returns successfully (irrespective
+ * of the return value), except if the JVM crashes in the middle.
+ * </ul>
    * * Scala example: * {{{ @@ -63,6 +101,7 @@ import org.apache.spark.annotation.InterfaceStability * } * }); * }}} + * * @since 2.0.0 */ @InterfaceStability.Evolving @@ -71,23 +110,18 @@ abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. /** - * Called when starting to process one partition of new data in the executor. The `version` is - * for data deduplication when there are failures. When recovering from a failure, some data may - * be generated multiple times but they will always have the same version. - * - * If this method finds using the `partitionId` and `version` that this partition has already been - * processed, it can return `false` to skip the further data processing. However, `close` still - * will be called for cleaning up resources. + * Called when starting to process one partition of new data in the executor. See the class + * docs for more information on how to use the `partitionId` and `epochId`. * * @param partitionId the partition id. - * @param version a unique id for data deduplication. + * @param epochId a unique id for data deduplication. * @return `true` if the corresponding partition and version id should be processed. `false` * indicates the partition should be skipped. */ - def open(partitionId: Long, version: Long): Boolean + def open(partitionId: Long, epochId: Long): Boolean /** - * Called to process the data in the executor side. This method will be called only when `open` + * Called to process the data in the executor side. This method will be called only if `open` * returns `true`. */ def process(value: T): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 93bf91e56f1bd..39d9a95ca4710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -97,7 +97,7 @@ class CacheManager extends Logging { val inMemoryRelation = InMemoryRelation( sparkSession.sessionState.conf.useCompression, sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sparkSession.sessionState.executePlan(planToCache).executedPlan, + sparkSession.sessionState.executePlan(AnalysisBarrier(planToCache)).executedPlan, tableName, planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) @@ -105,24 +105,50 @@ class CacheManager extends Logging { } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param query The [[Dataset]] to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * [[Dataset]]; otherwise un-cache the given [[Dataset]] only. + * @param blocking Whether to block until all blocks are deleted. 
*/ - def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { - uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + def uncacheQuery( + query: Dataset[_], + cascade: Boolean, + blocking: Boolean = true): Unit = writeLock { + uncacheQuery(query.sparkSession, query.logicalPlan, cascade, blocking) } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param spark The Spark session. + * @param plan The plan to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * plan; otherwise un-cache the given plan only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock { + def uncacheQuery( + spark: SparkSession, + plan: LogicalPlan, + cascade: Boolean, + blocking: Boolean): Unit = writeLock { + val shouldRemove: LogicalPlan => Boolean = + if (cascade) { + _.find(_.sameResult(plan)).isDefined + } else { + _.sameResult(plan) + } val it = cachedData.iterator() while (it.hasNext) { val cd = it.next() - if (cd.plan.find(_.sameResult(plan)).isDefined) { + if (shouldRemove(cd.plan)) { cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } + // Re-compile dependent cached queries after removing the cached query. + if (!cascade) { + recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined, clearCache = false) + } } /** @@ -132,20 +158,24 @@ class CacheManager extends Logging { recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined) } - private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = { + private def recacheByCondition( + spark: SparkSession, + condition: LogicalPlan => Boolean, + clearCache: Boolean = true): Unit = { val it = cachedData.iterator() val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData] while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cacheBuilder.clearCache() + if (clearCache) { + cd.cachedRepresentation.cacheBuilder.clearCache() + } // Remove the cache entry before we create a new one, so that we can have a different // physical plan. 
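As an aside, a standalone toy model of the two `shouldRemove` predicates above (plain case classes standing in for `LogicalPlan`, illustration only) shows how cascading removal also drops entries that merely contain the target plan as a sub-plan, while the non-cascading form only matches the plan itself.

object CascadeSketch {
  // Toy plan tree standing in for LogicalPlan; == plays the role of sameResult here.
  sealed trait Plan {
    def children: Seq[Plan]
    def find(p: Plan => Boolean): Option[Plan] =
      if (p(this)) Some(this) else children.flatMap(_.find(p)).headOption
  }
  case class Scan(name: String) extends Plan { def children: Seq[Plan] = Nil }
  case class Project(child: Plan) extends Plan { def children: Seq[Plan] = Seq(child) }

  def main(args: Array[String]): Unit = {
    val target: Plan = Scan("t")
    val cached: Seq[Plan] = Seq(Scan("t"), Project(Scan("t")), Scan("u"))

    // Mirrors the predicate selection in uncacheQuery: containment vs. exact match.
    def shouldRemove(cascade: Boolean): Plan => Boolean =
      if (cascade) _.find(_ == target).isDefined else _ == target

    println(cached.filter(shouldRemove(cascade = true)))  // List(Scan(t), Project(Scan(t)))
    println(cached.filter(shouldRemove(cascade = false))) // List(Scan(t))
  }
}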
it.remove() - val plan = spark.sessionState.executePlan(cd.plan).executedPlan + val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan val newCache = InMemoryRelation( - cacheBuilder = cd.cachedRepresentation - .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), + cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan), logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1c8e4050978dc..00ff4c8ac310b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions -import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -32,8 +31,7 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ - Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 74048871f8d42..75f5ec0e253df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -41,6 +41,7 @@ class SparkPlanner( DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: + Window :: JoinSelection :: InMemoryScans :: BasicOperators :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index be34387f6a874..07a6fcae83b70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -327,7 +327,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => - if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) { throw new AnalysisException( "Streaming aggregation doesn't support group aggregate pandas UDF") } @@ -428,6 +428,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object Window extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalWindow( + WindowFunctionType.SQL, windowExprs, partitionSpec, 
orderSpec, child) => + execution.window.WindowExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case PhysicalWindow( + WindowFunctionType.Python, windowExprs, partitionSpec, orderSpec, child) => + execution.python.WindowInPandasExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object InMemoryScans extends Strategy { @@ -478,7 +494,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil @@ -548,8 +563,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil - case logical.Window(windowExprs, partitionSpec, orderSpec, child) => - execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index aab8cc50b9526..6d44890704f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object TypedAggregateExpression { def apply[BUF : Encoder, OUT : Encoder]( @@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction { s"$nodeName($input)" } - override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") + // aggregator.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + override def nodeName: String = Utils.getSimpleName(aggregator.getClass).stripSuffix("$"); } // TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface. 
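With window planning moved into the dedicated `Window` strategy above (`WindowFunctionType.SQL` maps to `WindowExec`, `WindowFunctionType.Python` to `WindowInPandasExec`), an illustrative way to observe the SQL path, assuming a local SparkSession, is:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object WindowPlanExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("window-plan").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("k", "v")
    val w = Window.partitionBy($"k").orderBy($"v")

    // The physical plan should contain a WindowExec node produced by the new Window strategy.
    df.withColumn("rnk", rank().over(w)).explain()

    spark.stop()
  }
}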
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 6ad11bda84bf6..93c8127681b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -23,6 +23,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ object ArrowUtils { @@ -120,4 +121,19 @@ object ArrowUtils { StructField(field.getName, dt, field.isNullable) }) } + + /** Return Map with conf settings to be used in ArrowPythonRunner */ + def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { + val timeZoneConf = if (conf.pandasRespectSessionTimeZone) { + Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) + } else { + Nil + } + val pandasColsByPosition = if (conf.pandasGroupedMapAssignColumnsByPosition) { + Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION.key -> "true") + } else { + Nil + } + Map(timeZoneConf ++ pandasColsByPosition: _*) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index da35a4734e65a..7c8faec53a828 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -74,6 +74,16 @@ case class CachedRDDBuilder( } } + def withCachedPlan(cachedPlan: SparkPlan): CachedRDDBuilder = { + new CachedRDDBuilder( + useCompression, + batchSize, + storageLevel, + cachedPlan = cachedPlan, + tableName + )(_cachedColumnBuffers) + } + private def buildBuffers(): RDD[CachedBatch] = { val output = cachedPlan.output val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 0b4dd76c7d860..997cf92449c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ @@ -169,8 +169,8 @@ case class InMemoryTableScanExec( // But the cached version could alias output, so we need to replace output.
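Going back to `ArrowUtils.getPythonRunnerConfMap` above: the helper simply merges optional settings into one map. A standalone sketch of the same pattern follows; the key strings are illustrative stand-ins, not the SQLConf constants.

object RunnerConfSketch {
  // Each optional setting contributes either a single-entry Seq or Nil; the pieces are merged.
  def confMap(respectSessionTimeZone: Boolean,
              sessionLocalTimeZone: String,
              assignColumnsByPosition: Boolean): Map[String, String] = {
    val timeZoneConf =
      if (respectSessionTimeZone) Seq("spark.sql.session.timeZone" -> sessionLocalTimeZone) else Nil
    val byPosition =
      if (assignColumnsByPosition) Seq("pandas.groupedMap.assignColumnsByPosition" -> "true") else Nil
    Map(timeZoneConf ++ byPosition: _*)
  }

  def main(args: Array[String]): Unit = {
    println(confMap(respectSessionTimeZone = true, "UTC", assignColumnsByPosition = false))
    // Map(spark.sql.session.timeZone -> UTC)
  }
}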
override def outputPartitioning: Partitioning = { relation.cachedPlan.outputPartitioning match { - case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.cachedPlan.outputPartitioning + case e: Expression => updateAttribute(e).asInstanceOf[Partitioning] + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index bf4d96fa18d0d..04bf8c6dd917f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -189,8 +189,9 @@ case class DropTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog + val isTempView = catalog.isTemporaryTable(tableName) - if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) { + if (!isTempView && catalog.tableExists(tableName)) { // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view // issue an exception. catalog.getTableMetadata(tableName).tableType match { @@ -204,9 +205,10 @@ case class DropTableCommand( } } - if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) { + if (isTempView || catalog.tableExists(tableName)) { try { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession.table(tableName), cascade = !isTempView) } catch { case NonFatal(e) => log.warn(e.toString, e) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 44749190c79eb..ec3961f84bd8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -493,7 +493,7 @@ case class TruncateTableCommand( spark.sessionState.refreshTable(tableName.unquotedString) // Also try to drop the contents of the table from the columnar cache try { - spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier)) + spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier), cascade = true) } catch { case NonFatal(e) => log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index a813829d50cb1..80d7608a22891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -38,9 +38,8 @@ case class InsertIntoDataSourceCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = Dataset.ofRows(sparkSession, query) - // Apply the schema of the existing table to the new data. - val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) + // Data has been casted to the target relation's schema by the PreprocessTableInsertion rule. 
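Because `DropTableCommand` above now passes `cascade = !isTempView`, dropping a temporary view only removes the view's own cache entry while dependent caches are re-compiled rather than discarded. A usage sketch (local mode; the expected behaviour is based on the new `uncacheQuery` semantics in this patch):

import org.apache.spark.sql.SparkSession

object DropTempViewCacheExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("drop-view-cache").getOrCreate()
    import spark.implicits._

    Seq(1, 2, 3).toDF("id").createOrReplaceTempView("v")
    spark.table("v").cache().count()                  // cache the temp view itself
    val derived = spark.sql("SELECT id * 2 AS d FROM v").cache()
    derived.count()                                   // cache a query built on top of the view

    spark.sql("DROP VIEW v")                          // DropTableCommand, cascade = !isTempView = false
    // The dependent cache should be re-compiled rather than dropped, so it stays in memory.
    println(derived.storageLevel.useMemory)

    spark.stop()
  }
}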
+ relation.insert(data, overwrite) // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this // data source relation. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b23e5a7722004..b84543ccd7869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Instructions on how to partition the table among workers. @@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging { * Null value predicate is added to the first partition where clause to include * the rows with null value for the partitions column. * + * @param schema resolved schema of a JDBC table * @param partitioning partition information to generate the where clause for each partition + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url * @return an array of partitions with where clause for each partition */ - def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { + def columnPartition( + schema: StructType, + partitioning: JDBCPartitioningInfo, + resolver: Resolver, + jdbcOptions: JDBCOptions): Array[Partition] = { if (partitioning == null || partitioning.numPartitions <= 1 || partitioning.lowerBound == partitioning.upperBound) { return Array[Partition](JDBCPartition(null, 0)) @@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging { // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. val stride: Long = upperBound / numPartitions - lowerBound / numPartitions - val column = partitioning.column + + val column = verifyAndGetNormalizedColumnName( + schema, partitioning.column, resolver, jdbcOptions) + var i: Int = 0 var currentValue: Long = lowerBound val ans = new ArrayBuffer[Partition]() @@ -99,10 +111,57 @@ private[sql] object JDBCRelation extends Logging { } ans.toArray } + + // Verify column name based on the JDBC resolved schema + private def verifyAndGetNormalizedColumnName( + schema: StructType, + columnName: String, + resolver: Resolver, + jdbcOptions: JDBCOptions): String = { + val dialect = JdbcDialects.get(jdbcOptions.url) + schema.map(_.name).find { fieldName => + resolver(fieldName, columnName) || + resolver(dialect.quoteIdentifier(fieldName), columnName) + }.map(dialect.quoteIdentifier).getOrElse { + throw new AnalysisException(s"User-defined partition column $columnName not " + + s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") + } + } + + /** + * Takes a (schema, table) specification and returns the table's Catalyst schema. 
+ * If `customSchema` defined in the JDBC options, replaces the schema's dataType with the + * custom schema's type. + * + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url, table and other information. + * @return resolved Catalyst schema of a JDBC table + */ + def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = { + val tableSchema = JDBCRDD.resolveTable(jdbcOptions) + jdbcOptions.customSchema match { + case Some(customSchema) => JdbcUtils.getCustomSchema( + tableSchema, customSchema, resolver) + case None => tableSchema + } + } + + /** + * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema. + */ + def apply( + parts: Array[Partition], + jdbcOptions: JDBCOptions)( + sparkSession: SparkSession): JDBCRelation = { + val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sparkSession) + } } private[sql] case class JDBCRelation( - parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) + override val schema: StructType, + parts: Array[Partition], + jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) extends BaseRelation with PrunedFilteredScan with InsertableRelation { @@ -111,15 +170,6 @@ private[sql] case class JDBCRelation( override val needConversion: Boolean = false - override val schema: StructType = { - val tableSchema = JDBCRDD.resolveTable(jdbcOptions) - jdbcOptions.customSchema match { - case Some(customSchema) => JdbcUtils.getCustomSchema( - tableSchema, customSchema, sparkSession.sessionState.conf.resolver) - case None => tableSchema - } - } - // Check if JDBCRDD.compileFilter can accept input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index f8c5677ea0f2a..2b488bb7121dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider JDBCPartitioningInfo( partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get) } - val parts = JDBCRelation.columnPartition(partitionInfo) - JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession) + val resolver = sqlContext.conf.resolver + val schema = JDBCRelation.getSchema(resolver, jdbcOptions) + val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession) } override def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 3b04510d29695..e9a0b383b5f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import 
org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -40,7 +40,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -52,7 +52,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -99,7 +99,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -158,6 +158,11 @@ private[json] class JsonOutputWriter( case None => StandardCharsets.UTF_8 } + if (JSONOptionsInRead.blacklist.contains(encoding)) { + logWarning(s"The JSON file ($path) was written in the encoding ${encoding.displayName()}" + + " which can be read back by Spark only if multiLine is enabled.") + } + private val writer = CodecStreams.createOutputStreamWriter( context, new Path(path), encoding) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e7eed95a560a3..f6edc7bfb3750 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -75,7 +75,7 @@ private[sql] object JsonInferSchema { // active SparkSession and `SQLConf.get` may point to the wrong configs. val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) - canonicalizeType(rootType) match { + canonicalizeType(rootType, configOptions) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep @@ -181,33 +181,33 @@ private[sql] object JsonInferSchema { } /** - * Convert NullType to StringType and remove StructTypes with no fields + * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields, + * drops NullTypes or converts them to StringType based on provided options. 
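The `dropFieldIfAllNull` handling described above can be exercised end to end from the JSON read path; an illustrative snippet using in-memory JSON lines and a local SparkSession:

import org.apache.spark.sql.SparkSession

object DropNullFieldExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("json-drop-null").getOrCreate()
    import spark.implicits._

    val jsonLines = Seq("""{"a": 1, "b": null}""", """{"a": 2, "b": null}""").toDS()

    // Default: the all-null column "b" is inferred as StringType.
    spark.read.json(jsonLines).printSchema()

    // With the option enabled, "b" is dropped from the inferred schema entirely.
    spark.read.option("dropFieldIfAllNull", "true").json(jsonLines).printSchema()

    spark.stop()
  }
}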
*/ - private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { - case at @ ArrayType(elementType, _) => - for { - canonicalType <- canonicalizeType(elementType) - } yield { - at.copy(canonicalType) - } + private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { + case at: ArrayType => + canonicalizeType(at.elementType, options) + .map(t => at.copy(elementType = t)) case StructType(fields) => - val canonicalFields: Array[StructField] = for { - field <- fields - if field.name.length > 0 - canonicalType <- canonicalizeType(field.dataType) - } yield { - field.copy(dataType = canonicalType) + val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f => + canonicalizeType(f.dataType, options) + .map(t => f.copy(dataType = t)) } - - if (canonicalFields.length > 0) { - Some(StructType(canonicalFields)) + // SPARK-8093: empty structs should be deleted + if (canonicalFields.isEmpty) { + None } else { - // per SPARK-8093: empty structs should be deleted + Some(StructType(canonicalFields)) + } + + case NullType => + if (options.dropFieldIfAllNull) { None + } else { + Some(StringType) } - case NullType => Some(StringType) case other => Some(other) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 017a6737161a6..33079d5912506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -30,8 +30,8 @@ class DataSourcePartitioning( override val numPartitions: Int = partitioning.numPartitions() - override def satisfies(required: physical.Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: physical.Distribution): Boolean = { + super.satisfies0(required) || { required match { case d: physical.ClusteredDistribution if isCandidate(d.clustering) => val attrs = d.clustering.map(_.asInstanceOf[Attribute]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 90fb5a14c9fc9..7613eb210c659 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -22,79 +22,36 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, 
SupportsReportStatistics} import org.apache.spark.sql.types.StructType +/** + * A logical plan representing a data source v2 scan. + * + * @param source An instance of a [[DataSourceV2]] implementation. + * @param options The options for this scan. Used to create fresh [[DataSourceReader]]. + * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh + * [[DataSourceReader]]. + */ case class DataSourceV2Relation( source: DataSourceV2, + output: Seq[AttributeReference], options: Map[String, String], - projection: Seq[AttributeReference], - filters: Option[Seq[Expression]] = None, - userSpecifiedSchema: Option[StructType] = None) + userSpecifiedSchema: Option[StructType]) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = "RelationV2 " + metadataString - - override lazy val schema: StructType = reader.readSchema() - - override lazy val output: Seq[AttributeReference] = { - // use the projection attributes to avoid assigning new ids. fields that are not projected - // will be assigned new ids, which is okay because they are not projected. - val attrMap = projection.map(a => a.name -> a).toMap - schema.map(f => attrMap.getOrElse(f.name, - AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) - } + override def pushedFilters: Seq[Expression] = Seq.empty - private lazy val v2Options: DataSourceOptions = makeV2Options(options) - - // postScanFilters: filters that need to be evaluated after the scan. - // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. - // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter. - lazy val ( - reader: DataSourceReader, - postScanFilters: Seq[Expression], - pushedFilters: Seq[Expression]) = { - val newReader = userSpecifiedSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) - } - - DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - - val (postScanFilters, pushedFilters) = filters match { - case Some(filterSeq) => - DataSourceV2Relation.pushFilters(newReader, filterSeq) - case _ => - (Nil, Nil) - } - logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") - logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") - - (newReader, postScanFilters, pushedFilters) - } - - override def doCanonicalize(): LogicalPlan = { - val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] - - // override output with canonicalized output to avoid attempting to configure a reader - val canonicalOutput: Seq[AttributeReference] = this.output - .map(a => QueryPlan.normalizeExprId(a, projection)) + override def simpleString: String = "RelationV2 " + metadataString - new DataSourceV2Relation(c.source, c.options, c.projection) { - override lazy val output: Seq[AttributeReference] = canonicalOutput - } - } + def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => @@ -102,9 +59,7 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - // projection is used to maintain id assignment. 
- // if projection is not set, use output so the copy is not equal to the original - copy(projection = projection.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -184,77 +139,26 @@ object DataSourceV2Relation { source.getClass.getSimpleName } } - } - - private def makeV2Options(options: Map[String, String]): DataSourceOptions = { - new DataSourceOptions(options.asJava) - } - private def schema( - source: DataSourceV2, - v2Options: DataSourceOptions, - userSchema: Option[StructType]): StructType = { - val reader = userSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) + def createReader( + options: Map[String, String], + userSpecifiedSchema: Option[StructType]): DataSourceReader = { + val v2Options = new DataSourceOptions(options.asJava) + userSpecifiedSchema match { + case Some(s) => + asReadSupportWithSchema.createReader(s, v2Options) + case _ => + asReadSupport.createReader(v2Options) + } } - reader.readSchema() } def create( source: DataSourceV2, options: Map[String, String], - filters: Option[Seq[Expression]] = None, - userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) - } - - private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { - reader match { - case projectionSupport: SupportsPushDownRequiredColumns => - projectionSupport.pruneColumns(struct) - case _ => - } - } - - private def pushFilters( - reader: DataSourceReader, - filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - reader match { - case r: SupportsPushDownCatalystFilters => - val postScanFilters = r.pushCatalystFilters(filters.toArray) - val pushedFilters = r.pushedCatalystFilters() - (postScanFilters, pushedFilters) - - case r: SupportsPushDownFilters => - // A map from translated data source filters to original catalyst filter expressions. - val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression] - // Catalyst filter expression that can't be translated to data source filters. - val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression] - - for (filterExpr <- filters) { - val translated = DataSourceStrategy.translateFilter(filterExpr) - if (translated.isDefined) { - translatedFilterToExpr(translated.get) = filterExpr - } else { - untranslatableExprs += filterExpr - } - } - - // Data source filters that need to be evaluated again after scanning. which means - // the data source cannot guarantee the rows returned can pass these filters. - // As a result we must return it so Spark can plan an extra filter operator. 
- val postScanFilters = - r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) - // The filters which are marked as pushed to this data source - val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) - - (untranslatableExprs ++ postScanFilters, pushedFilters) - - case _ => (filters, Nil) - } + userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = { + val reader = source.createReader(options, userSpecifiedSchema) + DataSourceV2Relation( + source, reader.readSchema().toAttributes, options, userSpecifiedSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1b7c639f10f98..182aa2906cf1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,15 +17,120 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.Strategy +import scala.collection.mutable + +import org.apache.spark.sql.{sources, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} object DataSourceV2Strategy extends Strategy { + + /** + * Pushes down filters to the data source reader + * + * @return pushed filter and post-scan filters. + */ + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (pushedFilters, postScanFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } + } + + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. 
+ val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray) + .map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + (pushedFilters, untranslatableExprs ++ postScanFilters) + + case _ => (Nil, filters) + } + } + + /** + * Applies column pruning to the data source, w.r.t. the references of the given expressions. + * + * @return new output attributes after column pruning. + */ + // TODO: nested column pruning. + private def pruneColumns( + reader: DataSourceReader, + relation: DataSourceV2Relation, + exprs: Seq[Expression]): Seq[AttributeReference] = { + reader match { + case r: SupportsPushDownRequiredColumns => + val requiredColumns = AttributeSet(exprs.flatMap(_.references)) + val neededOutput = relation.output.filter(requiredColumns.contains) + if (neededOutput != relation.output) { + r.pruneColumns(neededOutput.toStructType) + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + r.readSchema().toAttributes.map { + // We have to keep the attribute id during transformation. + a => a.withExprId(nameToAttr(a.name).exprId) + } + } else { + relation.output + } + + case _ => relation.output + } + } + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + val reader = relation.newReader() + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
+ val (pushedFilters, postScanFilters) = pushFilters(reader, filters) + val output = pruneColumns(reader, relation, project ++ postScanFilters) + logInfo( + s""" + |Pushing operators to ${relation.source.getClass} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + val scan = DataSourceV2ScanExec( + output, relation.source, relation.options, pushedFilters, reader) + + val filterCondition = postScanFilters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) + + val withProjection = if (withFilter.output != project) { + ProjectExec(project, withFilter) + } else { + withFilter + } + + withProjection :: Nil case r: StreamingDataSourceV2Relation => DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index 693e67dcd108e..97e6c6d702acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -53,7 +53,9 @@ trait DataSourceV2StringFormat { private def sourceName: String = source match { case registered: DataSourceRegister => registered.shortName() - case _ => source.getClass.getSimpleName.stripSuffix("$") + // source.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + case _ => Utils.getSimpleName(source.getClass) } def metadataString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala deleted file mode 100644 index e894f8afd6762..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} -import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule - -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan match { - // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - assert(relation.filters.isEmpty, "data source v2 should do push down only once.") - - val projectAttrs = project.map(_.toAttribute) - val projectSet = AttributeSet(project.flatMap(_.references)) - val filterSet = AttributeSet(filters.flatMap(_.references)) - - val projection = if (filterSet.subsetOf(projectSet) && - AttributeSet(projectAttrs) == projectSet) { - // When the required projection contains all of the filter columns and column pruning alone - // can produce the required projection, push the required projection. - // A final projection may still be needed if the data source produces a different column - // order or if it cannot prune all of the nested columns. - projectAttrs - } else { - // When there are filter columns not already in the required projection or when the required - // projection is more complicated than column pruning, base column pruning on the set of - // all columns needed by both. - (projectSet ++ filterSet).toSeq - } - - val newRelation = relation.copy( - projection = projection.asInstanceOf[Seq[AttributeReference]], - filters = Some(filters)) - - // Add a Filter for any filters that need to be evaluated after scan. 
- val postScanFilterCond = newRelation.postScanFilters.reduceLeftOption(And) - val filtered = postScanFilterCond.map(Filter(_, newRelation)).getOrElse(newRelation) - - // Add a Project to ensure the output matches the required projection - if (newRelation.output != projectAttrs) { - Project(project, filtered) - } else { - filtered - } - - case other => other.mapChildren(apply) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index ea4bda327f36f..b1148c0f62f7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -29,10 +29,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} -import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} +import org.apache.spark.sql.execution.streaming.MicroBatchExecution import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -109,10 +107,12 @@ object DataWritingSparkTask extends Logging { iter: Iterator[InternalRow], useCommitCoordinator: Boolean): WriterCommitMessage = { val stageId = context.stageId() + val stageAttempt = context.stageAttemptNumber() val partId = context.partitionId() + val taskId = context.taskAttemptId() val attemptId = context.attemptNumber() val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") - val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) + val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) // write the data and commit this writer. 
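The extra identifiers captured above (stage, stage attempt, partition, task and attempt ids) feed the commit-coordination handshake. The following standalone sketch (illustrative names, not Spark's OutputCommitCoordinator API) conveys the idea that only one attempt per partition is allowed to commit:

trait CommitCoordinator {
  def canCommit(stageId: Int, stageAttempt: Int, partitionId: Int, attemptNumber: Int): Boolean
}

// First attempt to ask for a given (stage, stageAttempt, partition) wins; later attempts are denied.
final class FirstWriterWins extends CommitCoordinator {
  private val winners = scala.collection.concurrent.TrieMap.empty[(Int, Int, Int), Int]
  override def canCommit(stageId: Int, stageAttempt: Int, partitionId: Int, attemptNumber: Int): Boolean =
    winners.putIfAbsent((stageId, stageAttempt, partitionId), attemptNumber).forall(_ == attemptNumber)
}

object CommitProtocolSketch {
  def main(args: Array[String]): Unit = {
    val coordinator = new FirstWriterWins
    println(coordinator.canCommit(0, 0, partitionId = 7, attemptNumber = 0)) // true
    println(coordinator.canCommit(0, 0, partitionId = 7, attemptNumber = 1)) // false: speculative attempt denied
  }
}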
Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -122,12 +122,14 @@ object DataWritingSparkTask extends Logging { val msg = if (useCommitCoordinator) { val coordinator = SparkEnv.get.outputCommitCoordinator - val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) + val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId) if (commitAuthorized) { - logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") + logInfo(s"Commit authorized for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") dataWriter.commit() } else { - val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" + val message = s"Commit denied for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)" logInfo(message) // throwing CommitDeniedException will trigger the catch block for abort throw new CommitDeniedException(message, stageId, partId, attemptId) @@ -138,15 +140,18 @@ object DataWritingSparkTask extends Logging { dataWriter.commit() } - logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") + logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") msg })(catchBlock = { // If there is an error, abort this writer - logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") + logError(s"Aborting commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") dataWriter.abort() - logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") + logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") }) } } @@ -157,10 +162,10 @@ class InternalRowDataWriterFactory( override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), + rowWriterFactory.createDataWriter(partitionId, taskId, epochId), RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e3d28388c5470..ad95879d86f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.exchange +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ @@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { val leftKeysBuffer = ArrayBuffer[Expression]() val rightKeysBuffer = ArrayBuffer[Expression]() + val pickedIndexes = mutable.Set[Int]() + val keysAndIndexes = currentOrderOfKeys.zipWithIndex expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + val index = keysAndIndexes.find { case (e, idx) => + // As we may have the same key used many times, we need to filter out its occurrence we + // have already used. 
+ e.semanticEquals(expression) && !pickedIndexes.contains(idx) + }.map(_._2).get + pickedIndexes += index leftKeysBuffer.append(leftKeys(index)) rightKeysBuffer.append(rightKeys(index)) }) @@ -270,7 +278,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { * partitioning of the join nodes' children. */ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { - plan.transformUp { + plan match { case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = @@ -288,6 +296,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 09f79a2de0ba0..1a5b7599bb7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -24,7 +24,7 @@ import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan } override def outputPartitioning: Partitioning = child.outputPartitioning match { - case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 77b907870d678..b4f0ae1eb1a18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat import java.util.Locale +import java.util.concurrent.atomic.LongAdder import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo @@ -32,40 +33,45 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { + // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. 
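The `pickedIndexes` bookkeeping added to `reorderJoinKeys` above only matters when the same join key appears more than once; a standalone sketch of that matching logic on plain values (not Spark expressions):

object ReorderKeysSketch {
  // For each expected key, pick the first not-yet-used index of an equal key in `current`.
  def reorder[A](expected: Seq[A], current: Seq[A]): Seq[Int] = {
    val picked = scala.collection.mutable.Set[Int]()
    expected.map { e =>
      val idx = current.zipWithIndex
        .find { case (c, i) => c == e && !picked.contains(i) }
        .map(_._2)
        .get
      picked += idx
      idx
    }
  }

  def main(args: Array[String]): Unit = {
    // A naive indexWhere would pick index 0 twice for the duplicated key "k1";
    // tracking picked indexes yields 0 and then 2.
    println(reorder(Seq("k1", "k1"), Seq("k1", "k2", "k1"))) // List(0, 2)
  }
}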
- private[this] var _value = initValue - private var _zeroValue = initValue + private[this] val _value = new LongAdder + private val _zeroValue = initValue + _value.add(initValue) override def copy(): SQLMetric = { - val newAcc = new SQLMetric(metricType, _value) - newAcc._zeroValue = initValue + val newAcc = new SQLMetric(metricType, initValue) + newAcc.add(_value.sum()) newAcc } - override def reset(): Unit = _value = _zeroValue + override def reset(): Unit = this.set(_zeroValue) override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { - case o: SQLMetric => _value += o.value + case o: SQLMetric => _value.add(o.value) case _ => throw new UnsupportedOperationException( s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def isZero(): Boolean = _value == _zeroValue + override def isZero(): Boolean = _value.sum() == _zeroValue - override def add(v: Long): Unit = _value += v + override def add(v: Long): Unit = _value.add(v) // We can set a double value to `SQLMetric` which stores only long value, if it is // average metrics. def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v) - def set(v: Long): Unit = _value = v + def set(v: Long): Unit = { + _value.reset() + _value.add(v) + } - def +=(v: Long): Unit = _value += v + def +=(v: Long): Unit = _value.add(v) - override def value: Long = _value + override def value: Long = _value.sum() // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { @@ -153,7 +159,7 @@ object SQLMetrics { Seq.fill(3)(0L) } else { val sorted = validValues.sorted - Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted.head, sorted(validValues.length / 2), sorted(validValues.length - 1)) } metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } @@ -173,7 +179,8 @@ object SQLMetrics { Seq.fill(4)(0L) } else { val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted.sum, sorted.head, sorted(validValues.length / 2), + sorted(validValues.length - 1)) } metric.map(strFormat) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 8e01e8e56a5bd..d00f6f042d6e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -81,7 +82,7 @@ case class AggregateInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) val 
(pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip @@ -135,10 +136,14 @@ case class AggregateInPandasExec( } val columnarBatchIter = new ArrowPythonRunner( - pyFuncs, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(projectedRowIter, context.partitionId(), context) + pyFuncs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + argOffsets, + aggInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c4de214679ae4..0bc21c0986e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -63,7 +64,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -80,10 +81,14 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( - funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(batchIter, context.partitionId(), context) + funcs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + argOffsets, + schema, + sessionLocalTimeZone, + pythonRunnerConf).compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 01e19bddbfb66..ca665652f204d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -45,7 +45,7 @@ class ArrowPythonRunner( argOffsets: Array[Array[Int]], schema: StructType, timeZoneId: String, - respectTimeZone: Boolean) + conf: Map[String, String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { @@ -58,12 +58,15 @@ class ArrowPythonRunner( new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - if (respectTimeZone) 
{ - PythonRDD.writeUTF(timeZoneId, dataOut) - } else { - dataOut.writeInt(SpecialLengths.NULL) + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 9d56f48249982..1e096100f7f43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || - PythonUDF.isGroupAggPandasUDF(e) || + PythonUDF.isGroupedAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 513e174c7733e..f5a563baf52df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -77,7 +78,7 @@ case class FlatMapGroupsInPandasExec( val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Deduplicate the grouping attributes. 
// If a grouping attribute also appears in data attributes, then we don't need to send the @@ -139,10 +140,14 @@ case class FlatMapGroupsInPandasExec( val context = TaskContext.get() val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(grouped, context.partitionId(), context) + chainedFunc, + bufferSize, + reuseWorker, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + argOffsets, + dedupSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala new file mode 100644 index 0000000000000..a58773122922f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.locks.ReentrantLock
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python._
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.{NextIterator, Utils}
+
+class PythonForeachWriter(func: PythonFunction, schema: StructType)
+  extends ForeachWriter[UnsafeRow] {
+
+  private lazy val context = TaskContext.get()
+  private lazy val buffer = new PythonForeachWriter.UnsafeRowBuffer(
+    context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length)
+  private lazy val inputRowIterator = buffer.iterator
+
+  private lazy val inputByteIterator = {
+    EvaluatePython.registerPicklers()
+    val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) }
+    new SerDeUtil.AutoBatchedPickler(objIterator)
+  }
+
+  private lazy val pythonRunner = {
+    val conf = SparkEnv.get.conf
+    val bufferSize = conf.getInt("spark.buffer.size", 65536)
+    val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)
+    PythonRunner(func, bufferSize, reuseWorker)
+  }
+
+  private lazy val outputIterator =
+    pythonRunner.compute(inputByteIterator, context.partitionId(), context)
+
+  override def open(partitionId: Long, version: Long): Boolean = {
+    outputIterator // initialize everything
+    TaskContext.get.addTaskCompletionListener { _ => buffer.close() }
+    true
+  }
+
+  override def process(value: UnsafeRow): Unit = {
+    buffer.add(value)
+  }
+
+  override def close(errorOrNull: Throwable): Unit = {
+    buffer.allRowsAdded()
+    if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one
+  }
+}
+
+object PythonForeachWriter {
+
+  /**
+   * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter.
+   * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader
+   * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python
+   * worker stdin). Additions to the buffer are non-blocking, and reads through the buffer's
+   * iterator are blocking, that is, they block until new data is available or all data has been
+   * added.
+   *
+   * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue
+   * across memory and local disk. However, HybridRowQueue is designed to be used only with
+   * EvalPythonExec where the reader is always behind the writer, that is, the reader does not
+   * try to read n+1 rows if the writer has only written n rows at any point in time. This
+   * assumption is not true for PythonForeachWriter where rows may be added at a different rate
+   * than they are consumed by the python worker. Hence, to maintain the invariant of the reader
+   * being behind the writer while using HybridRowQueue, the buffer does the following:
+   * - Keeps a count of the rows in the HybridRowQueue
+   * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not
+   *   try to read more rows than what has been written.
+   *
+   * The implementation of the blocking iterator (ReentrantLock, Condition, etc.) has been borrowed
+   * from that of ArrayBlockingQueue.
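+   *
+   * A minimal sketch of the same lock/condition scheme on a plain in-memory queue (illustrative
+   * only, not used by this class; `finish()` plays the role of `allRowsAdded()`):
+   *
+   * {{{
+   *   val lock = new ReentrantLock()
+   *   val notEmpty = lock.newCondition()
+   *   val queue = scala.collection.mutable.Queue[String]()
+   *   var allAdded = false
+   *
+   *   def add(s: String): Unit = {
+   *     lock.lock()
+   *     try { queue.enqueue(s); notEmpty.signal() } finally lock.unlock()
+   *   }
+   *
+   *   def finish(): Unit = {
+   *     lock.lock()
+   *     try { allAdded = true; notEmpty.signal() } finally lock.unlock()
+   *   }
+   *
+   *   // Blocks while the queue is empty and the producer has not finished.
+   *   def remove(): String = {
+   *     lock.lock()
+   *     try {
+   *       while (queue.isEmpty && !allAdded) notEmpty.await(100, TimeUnit.MILLISECONDS)
+   *       if (queue.nonEmpty) queue.dequeue() else null
+   *     } finally lock.unlock()
+   *   }
+   * }}}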
+ */ + class UnsafeRowBuffer(taskMemoryManager: TaskMemoryManager, tempDir: File, numFields: Int) + extends Logging { + private val queue = HybridRowQueue(taskMemoryManager, tempDir, numFields) + private val lock = new ReentrantLock() + private val unblockRemove = lock.newCondition() + + // All of these are guarded by `lock` + private var count = 0L + private var allAdded = false + private var exception: Throwable = null + + val iterator = new NextIterator[UnsafeRow] { + override protected def getNext(): UnsafeRow = { + val row = remove() + if (row == null) finished = true + row + } + override protected def close(): Unit = { } + } + + def add(row: UnsafeRow): Unit = withLock { + assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count += 1 + unblockRemove.signal() + logTrace(s"Added $row, $count left") + } + + private def remove(): UnsafeRow = withLock { + while (count == 0 && !allAdded && exception == null) { + unblockRemove.await(100, TimeUnit.MILLISECONDS) + } + + // If there was any error in the adding thread, then rethrow it in the removing thread + if (exception != null) throw exception + + if (count > 0) { + val row = queue.remove() + assert(row != null, "HybridRowQueue.remove() returned null " + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count -= 1 + logTrace(s"Removed $row, $count left") + row + } else { + null + } + } + + def allRowsAdded(): Unit = withLock { + allAdded = true + unblockRemove.signal() + } + + def close(): Unit = { queue.close() } + + private def withLock[T](f: => T): T = { + lock.lockInterruptibly() + try { f } catch { + case e: Throwable => + if (exception == null) exception = e + throw e + } finally { lock.unlock() } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala new file mode 100644 index 0000000000000..628029b13a6c3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +case class WindowInPandasExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) extends UnaryExecNode { + + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else { + ClusteredDistribution(partitionSpec) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. 
+ */ + private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) + } + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + // Extract window expressions and window functions + val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) + + val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val context = TaskContext.get() + + val grouped = if (partitionSpec.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, partitionSpec, child.output) + } + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. 
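// Illustrative, standalone sketch (hypothetical names; plain strings stand in for Catalyst
// expressions and plain equality stands in for semanticEquals): the argOffsets loop above maps
// duplicate UDF inputs to one shared column index, so each input column is sent to the Python
// worker only once.
import scala.collection.mutable.ArrayBuffer

object ArgOffsetsSketch {
  def main(args: Array[String]): Unit = {
    val inputs = Seq(Seq("a", "b"), Seq("a", "c"))  // inputs of two hypothetical UDFs
    val allInputs = new ArrayBuffer[String]
    val argOffsets = inputs.map { input =>
      input.map { e =>
        if (allInputs.contains(e)) {
          allInputs.indexOf(e)
        } else {
          allInputs += e
          allInputs.length - 1
        }
      }.toArray
    }.toArray
    println(allInputs.mkString(", "))                         // a, b, c
    println(argOffsets.map(_.mkString(",")).mkString(" | "))  // 0,1 | 0,2
  }
}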
+ val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + val inputProj = UnsafeProjection.create(allInputs, child.output) + val pythonInput = grouped.map { case (_, rows) => + rows.map { row => + queue.add(row.asInstanceOf[UnsafeRow]) + inputProj(row) + } + } + + val windowFunctionResult = new ArrowPythonRunner( + pyFuncs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + argOffsets, + windowInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(pythonInput, context.partitionId(), context) + + val joined = new JoinedRow + val resultProj = createResultProjection(expressions) + + windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, windowOutput) + resultProj(joinedRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index afa664eb76525..50cf971e4ec3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,8 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil + HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index ef5f0da1e7cc2..76f3f5baa8d56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -56,7 +56,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor val dataIterator = prev.compute(split, context) dataWriter = writeTask.createDataWriter( context.partitionId(), - context.attemptNumber(), + context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) while (dataIterator.hasNext) { dataWriter.write(dataIterator.next()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 801b28b751bee..cf6572d3de1f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -34,8 +34,10 @@ case class ContinuousShuffleReadPartition( // Initialized only on the executor, and only once even as we call compute() 
multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env) - val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver) + TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..47b1f78b24505 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for writing to a continuous processing shuffle. + */ +trait ContinuousShuffleWriter { + def write(epoch: Iterator[UnsafeRow]): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala index d81f552d56626..834e84675c7d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -20,26 +20,24 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.mutable - import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.util.NextIterator /** - * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming row or an epoch marker. 
* * Each message comes tagged with writerId, identifying which writer the message is coming * from. The receiver will only begin the next epoch once all writers have sent an epoch * marker ending the current epoch. */ -private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { +private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable { def writerId: Int } private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) - extends UnsafeRowReceiverMessage -private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage + extends RPCContinuousShuffleMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContinuousShuffleMessage /** * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle @@ -48,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRow * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class UnsafeRowReceiver( +private[shuffle] class RPCContinuousShuffleReader( queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long, @@ -57,7 +55,7 @@ private[shuffle] class UnsafeRowReceiver( // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. private val queues = Array.fill(numShuffleWriters) { - new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize) } // Exposed for testing to determine if the endpoint gets stopped on task end. @@ -68,7 +66,9 @@ private[shuffle] class UnsafeRowReceiver( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case r: UnsafeRowReceiverMessage => + case r: RPCContinuousShuffleMessage => + // Note that this will block a thread the shared RPC handler pool! + // The TCP based shuffle handler (SPARK-24541) will avoid this problem. queues(r.writerId).put(r) context.reply(()) } @@ -79,10 +79,10 @@ private[shuffle] class UnsafeRowReceiver( private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) private val executor = Executors.newFixedThreadPool(numShuffleWriters) - private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) + private val completion = new ExecutorCompletionService[RPCContinuousShuffleMessage](executor) - private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] { - override def call(): UnsafeRowReceiverMessage = queues(writerId).take() + private def completionTask(writerId: Int) = new Callable[RPCContinuousShuffleMessage] { + override def call(): RPCContinuousShuffleMessage = queues(writerId).take() } // Initialize by submitting tasks to read the first row from each writer. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..1c6f3ddb395e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import scala.concurrent.Future +import scala.concurrent.duration.Duration + +import org.apache.spark.Partitioner +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.ThreadUtils + +/** + * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances. + * + * @param writerId The partition ID of this writer. + * @param outputPartitioner The partitioner on the reader side of the shuffle. + * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by + * partition ID within outputPartitioner. + */ +class RPCContinuousShuffleWriter( + writerId: Int, + outputPartitioner: Partitioner, + endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter { + + if (outputPartitioner.numPartitions != 1) { + throw new IllegalArgumentException("multiple readers not yet supported") + } + + if (outputPartitioner.numPartitions != endpoints.length) { + throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " + + s"not match endpoint count ${endpoints.length}") + } + + def write(epoch: Iterator[UnsafeRow]): Unit = { + while (epoch.hasNext) { + val row = epoch.next() + endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row)) + } + + val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq + implicit val ec = ThreadUtils.sameThread + ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b137f98045c5a..7fa13c4aa2c01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode @@ -221,19 +222,60 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow]) } /** A common trait for MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink { +trait MemorySinkBase extends BaseStreamingSink with Logging { def allData: Seq[Row] def latestBatchData: Seq[Row] def 
dataSinceBatch(sinceBatchId: Long): Seq[Row] def latestBatchId: Option[Long] + + /** + * Truncates the given rows to return at most maxRows rows. + * @param rows The data that may need to be truncated. + * @param batchLimit Number of rows to keep in this batch; the rest will be truncated + * @param sinkLimit Total number of rows kept in this sink, for logging purposes. + * @param batchId The ID of the batch that sent these rows, for logging purposes. + * @return Truncated rows. + */ + protected def truncateRowsIfNeeded( + rows: Array[Row], + batchLimit: Int, + sinkLimit: Int, + batchId: Long): Array[Row] = { + if (rows.length > batchLimit && batchLimit >= 0) { + logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit") + rows.take(batchLimit) + } else { + rows + } + } +} + +/** + * Companion object to MemorySinkBase. + */ +object MemorySinkBase { + val MAX_MEMORY_SINK_ROWS = "maxRows" + val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 + + /** + * Gets the max number of rows a MemorySink should store. This number is based on the memory + * sink row limit option if it is set. If not, we use a large value so that data truncates + * rather than causing out of memory errors. + * @param options Options for writing from which we get the max rows option + * @return The maximum number of rows a memorySink should store. + */ + def getMemorySinkCapacity(options: DataSourceOptions): Int = { + val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) + if (maxRows >= 0) maxRows else Int.MaxValue - 10 + } } /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink - with MemorySinkBase with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions) + extends Sink with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -241,6 +283,12 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + + /** The capacity in rows of this sink. */ + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + /** Returns all rows that are stored in this [[Sink]]. 
*/ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -273,14 +321,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } + var rowsToAdd = data.collect() + synchronized { + rowsToAdd = + truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, data.collect()) + var rowsToAdd = data.collect() synchronized { + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -294,6 +351,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySink" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala new file mode 100644 index 0000000000000..03c567c58d46a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.api.python.PythonException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.streaming.DataStreamWriter + +class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T]) + extends Sink { + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + val resolvedEncoder = encoder.resolveAndBind( + data.logicalPlan.output, + data.sparkSession.sessionState.analyzer) + val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag) + val ds = data.sparkSession.createDataset(rdd)(encoder) + batchWriter(ds, batchId) + } + + override def toString(): String = "ForeachBatchSink" +} + + +/** + * Interface that is meant to be extended by Python classes via Py4J. + * Py4J allows Python classes to implement Java interfaces so that the JVM can call back + * Python objects. In this case, this allows the user-defined Python `foreachBatch` function + * to be called from JVM when the query is active. 
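+ *
+ * For reference, a rough sketch of the corresponding Scala-side usage, assuming a streaming
+ * Dataset `df` (the paths below are placeholders):
+ *
+ * {{{
+ *   val writeBatch: (Dataset[Row], Long) => Unit = (batchDF, batchId) =>
+ *     batchDF.write.mode("append").parquet(s"/tmp/output/batch-$batchId")
+ *
+ *   df.writeStream
+ *     .option("checkpointLocation", "/tmp/output/checkpoint")
+ *     .foreachBatch(writeBatch)
+ *     .start()
+ * }}}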
+ * */ +trait PythonForeachBatchFunction { + /** Call the Python implementation of this function */ + def call(batchDF: DataFrame, batchId: Long): Unit +} + +object PythonForeachBatchHelper { + def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = { + dsw.foreachBatch(pythonFunc.call _) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index df5d69d57e36f..bc9b6d93ce7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -31,9 +33,14 @@ import org.apache.spark.sql.types.StructType * [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. + * @param converter An object to convert internal rows to target type T. Either it can be + * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. 
*/ -case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { +case class ForeachWriterProvider[T]( + writer: ForeachWriter[T], + converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + override def createStreamWriter( queryId: String, schema: StructType, @@ -44,10 +51,16 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - val encoder = encoderFor[T].resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - ForeachWriterFactory(writer, encoder) + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) } override def toString: String = "ForeachSink" @@ -55,29 +68,44 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S } } -case class ForeachWriterFactory[T: Encoder]( +object ForeachWriterProvider { + def apply[T]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + writer match { + case pythonWriter: PythonForeachWriter => + new ForeachWriterProvider[UnsafeRow]( + pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) + case _ => + new ForeachWriterProvider[T](writer, Left(encoder)) + } + } +} + +case class ForeachWriterFactory[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]) + rowConverter: InternalRow => T) extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): ForeachDataWriter[T] = { - new ForeachDataWriter(writer, encoder, partitionId, epochId) + new ForeachDataWriter(writer, rowConverter, partitionId, epochId) } } /** * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * * @param writer The [[ForeachWriter]] to process all data. - * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] * @param partitionId * @param epochId * @tparam T The type expected by the writer. 
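+ *
+ * A rough sketch of the user-facing side that ends up driving this writer, assuming a streaming
+ * Dataset `df` (illustrative only):
+ *
+ * {{{
+ *   df.writeStream.foreach(new ForeachWriter[Row] {
+ *     override def open(partitionId: Long, epochId: Long): Boolean = true
+ *     override def process(value: Row): Unit = println(value)
+ *     override def close(errorOrNull: Throwable): Unit = ()
+ *   }).start()
+ * }}}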
*/ -class ForeachDataWriter[T : Encoder]( +class ForeachDataWriter[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T], + rowConverter: InternalRow => T, partitionId: Int, epochId: Long) extends DataWriter[InternalRow] { @@ -89,7 +117,7 @@ class ForeachDataWriter[T : Encoder]( if (!opened) return try { - writer.process(encoder.fromRow(record)) + writer.process(rowConverter(record)) } catch { case t: Throwable => writer.close(t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index e07355aa37dba..b501d90c81f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat case object PackedRowWriterFactory extends DataWriterFactory[Row] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { new PackedRowDataWriter() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index fbff8db987110..b393c48baee8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -202,7 +202,7 @@ class RateStreamMicroBatchInputPartitionReader( rangeEnd: Long, localStartTimeMs: Long, relativeMsPerValue: Double) extends InputPartitionReader[Row] { - private var count = 0 + private var count: Long = 0 override def next(): Boolean = { rangeStart + partitionId + numPartitions * count < rangeEnd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 468313bfe8c3c..29f8cca476722 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode) + new MemoryStreamWriter(this, mode, options) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -55,6 +55,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + /** Returns all rows that are stored in this [[Sink]]. 
*/ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -81,7 +84,11 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB }.mkString("\n") } - def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { + def write( + batchId: Long, + outputMode: OutputMode, + newRows: Array[Row], + sinkCapacity: Int): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } @@ -89,14 +96,21 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, newRows) - synchronized { batches += rows } + synchronized { + val rowsToAdd = + truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, newRows) synchronized { + val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -110,6 +124,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySinkV2" @@ -117,16 +132,22 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) +class MemoryWriter( + sink: MemorySinkV2, + batchId: Long, + outputMode: OutputMode, + options: DataSourceOptions) extends DataSourceWriter with Logging { + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(batchId, outputMode, newRows) + sink.write(batchId, outputMode, newRows, sinkCapacity) } override def abort(messages: Array[WriterCommitMessage]): Unit = { @@ -134,16 +155,21 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) +class MemoryStreamWriter( + val sink: MemorySinkV2, + outputMode: OutputMode, + options: DataSourceOptions) extends StreamWriter { + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows) + sink.write(epochId, outputMode, newRows, sinkCapacity) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { @@ -154,7 +180,7 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { new MemoryDataWriter(partitionId, outputMode) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 8240e06d4ab72..91e3b7179c34a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -22,6 +22,7 @@ import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat import java.util.{Calendar, List => JList, Locale, Optional} +import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ @@ -76,7 +77,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR @GuardedBy("this") private var lastOffsetCommitted: LongOffset = LongOffset(-1L) - initialize() + private val initialized: AtomicBoolean = new AtomicBoolean(false) /** This method is only used for unit test */ private[sources] def getCurrentOffset(): LongOffset = synchronized { @@ -149,6 +150,10 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { + if (initialized.compareAndSet(false, true)) { + initialize() + } + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 batches.slice(sliceStart, sliceEnd) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index df722b953228b..118c82aa75e68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ -import java.nio.channels.ClosedChannelException import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams @@ -280,38 +278,49 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit if (loadedCurrentVersionMap.isDefined) { return loadedCurrentVersionMap.get } - val snapshotCurrentVersionMap = readSnapshotFile(version) - if (snapshotCurrentVersionMap.isDefined) { - synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } - return snapshotCurrentVersionMap.get - } - // Find the most recent map before this version that we can. - // [SPARK-22305] This must be done iteratively to avoid stack overflow. - var lastAvailableVersion = version - var lastAvailableMap: Option[MapType] = None - while (lastAvailableMap.isEmpty) { - lastAvailableVersion -= 1 + logWarning(s"The state for version $version doesn't exist in loadedMaps. " + + "Reading snapshot file and delta files if needed..." + + "Note that this is normal for the first batch of starting query.") - if (lastAvailableVersion <= 0) { - // Use an empty map for versions 0 or less. 
- lastAvailableMap = Some(new MapType) - } else { - lastAvailableMap = - synchronized { loadedMaps.get(lastAvailableVersion) } - .orElse(readSnapshotFile(lastAvailableVersion)) + val (result, elapsedMs) = Utils.timeTakenMs { + val snapshotCurrentVersionMap = readSnapshotFile(version) + if (snapshotCurrentVersionMap.isDefined) { + synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } + return snapshotCurrentVersionMap.get + } + + // Find the most recent map before this version that we can. + // [SPARK-22305] This must be done iteratively to avoid stack overflow. + var lastAvailableVersion = version + var lastAvailableMap: Option[MapType] = None + while (lastAvailableMap.isEmpty) { + lastAvailableVersion -= 1 + + if (lastAvailableVersion <= 0) { + // Use an empty map for versions 0 or less. + lastAvailableMap = Some(new MapType) + } else { + lastAvailableMap = + synchronized { loadedMaps.get(lastAvailableVersion) } + .orElse(readSnapshotFile(lastAvailableVersion)) + } + } + + // Load all the deltas from the version after the last available one up to the target version. + // The last available version is the one with a full snapshot, so it doesn't need deltas. + val resultMap = new MapType(lastAvailableMap.get) + for (deltaVersion <- lastAvailableVersion + 1 to version) { + updateFromDeltaFile(deltaVersion, resultMap) } - } - // Load all the deltas from the version after the last available one up to the target version. - // The last available version is the one with a full snapshot, so it doesn't need deltas. - val resultMap = new MapType(lastAvailableMap.get) - for (deltaVersion <- lastAvailableVersion + 1 to version) { - updateFromDeltaFile(deltaVersion, resultMap) + synchronized { loadedMaps.put(version, resultMap) } + resultMap } - synchronized { loadedMaps.put(version, resultMap) } - resultMap + logDebug(s"Loading state for $version takes $elapsedMs ms.") + + result } private def writeUpdateToDeltaFile( @@ -490,7 +499,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Perform a snapshot of the store to allow delta files to be consolidated */ private def doSnapshot(): Unit = { try { - val files = fetchFiles() + val (files, e1) = Utils.timeTakenMs(fetchFiles()) + logDebug(s"fetchFiles() took $e1 ms.") + if (files.nonEmpty) { val lastVersion = files.last.version val deltaFilesForLastVersion = @@ -498,7 +509,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit synchronized { loadedMaps.get(lastVersion) } match { case Some(map) => if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { - writeSnapshotFile(lastVersion, map) + val (_, e2) = Utils.timeTakenMs(writeSnapshotFile(lastVersion, map)) + logDebug(s"writeSnapshotFile() took $e2 ms.") } case None => // The last map is not loaded, probably some other instance is in charge @@ -517,7 +529,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit */ private[state] def cleanup(): Unit = { try { - val files = fetchFiles() + val (files, e1) = Utils.timeTakenMs(fetchFiles()) + logDebug(s"fetchFiles() took $e1 ms.") + if (files.nonEmpty) { val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain if (earliestVersionToRetain > 0) { @@ -527,9 +541,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit mapsToRemove.foreach(loadedMaps.remove) } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) - filesToDelete.foreach { f => - 
fm.delete(f.path) + val (_, e2) = Utils.timeTakenMs { + filesToDelete.foreach { f => + fm.delete(f.path) + } } + logDebug(s"deleting files took $e2 ms.") logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 1691a6320a526..6759fb42b4052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ -import org.apache.spark.util.{CompletionIterator, NextIterator} +import org.apache.spark.util.{CompletionIterator, NextIterator, Utils} /** Used to identify the state store for a given operator. */ @@ -97,12 +97,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => } /** Records the duration of running `body` for the next query progress update. */ - protected def timeTakenMs(body: => Unit): Long = { - val startTime = System.nanoTime() - val result = body - val endTime = System.nanoTime() - math.max(NANOSECONDS.toMillis(endTime - startTime), 0) - } + protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 /** * Set the SQL metrics related to the state store. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a2aae9a708ff3..40c40e7083d1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1070,6 +1070,17 @@ object functions { @scala.annotation.varargs def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } + /** + * Creates a new map column. The array in the first column is used for keys. The array in the + * second column is used for values. All elements in the array for key should not be null. + * + * @group normal_funcs + * @since 2.4 + */ + def map_from_arrays(keys: Column, values: Column): Column = withExpr { + MapFromArrays(keys.expr, values.expr) + } + /** * Marks a DataFrame as small enough for use in broadcast joins. * @@ -3082,7 +3093,7 @@ object functions { * @since 1.5.0 */ def array_contains(column: Column, value: Any): Column = withExpr { - ArrayContains(column.expr, Literal(value)) + ArrayContains(column.expr, lit(value).expr) } /** @@ -3146,7 +3157,7 @@ object functions { * @since 2.4.0 */ def array_position(column: Column, value: Any): Column = withExpr { - ArrayPosition(column.expr, Literal(value)) + ArrayPosition(column.expr, lit(value).expr) } /** @@ -3157,7 +3168,7 @@ object functions { * @since 2.4.0 */ def element_at(column: Column, value: Any): Column = withExpr { - ElementAt(column.expr, Literal(value)) + ElementAt(column.expr, lit(value).expr) } /** @@ -3175,9 +3186,16 @@ object functions { * @since 2.4.0 */ def array_remove(column: Column, element: Any): Column = withExpr { - ArrayRemove(column.expr, Literal(element)) + ArrayRemove(column.expr, lit(element).expr) } + /** + * Removes duplicate values from the array. 
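// Illustrative usage sketch for the functions touched above (assumes an active SparkSession
// `spark` with `import spark.implicits._`; it is not part of the patch itself).
import org.apache.spark.sql.functions._

val df = Seq((Seq(1, 2), Seq("a", "b"), 2)).toDF("k", "v", "c")

// map_from_arrays: the first array supplies the keys, the second array the values.
df.select(map_from_arrays($"k", $"v"))    // Map(1 -> "a", 2 -> "b")

// Replacing Literal(value) with lit(value).expr lets these functions take a Column
// (here column `c`) as well as a plain Scala literal:
df.select(array_contains($"k", $"c"))     // true
df.select(array_position($"k", $"c"))     // 2
df.select(element_at($"k", $"c"))         // element at position 2, i.e. 2
df.select(array_remove($"k", $"c"))       // [1]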
+ * @group collection_funcs + * @since 2.4.0 + */ + def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** * Creates a new row for each element in the given array or map column. * @@ -3358,7 +3376,7 @@ object functions { val dataType = try { DataType.fromJson(schema) } catch { - case NonFatal(_) => StructType.fromDDL(schema) + case NonFatal(_) => DataType.fromDDL(schema) } from_json(e, dataType, options) } @@ -3508,6 +3526,22 @@ object functions { */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + /** + * Returns a map created from the given array of entries. + * @group collection_funcs + * @since 2.4.0 + */ + def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) } + + /** + * Returns a merged array of structs in which the N-th struct contains all N-th values of input + * arrays. + * @group collection_funcs + * @since 2.4.0 + */ + @scala.annotation.varargs + def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } + ////////////////////////////////////////////////////////////////////////////////////////////// // Mask functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 6ae307bce10c8..4698e8ab13ce3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -364,7 +364,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropTempView(viewName) } } @@ -379,7 +380,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropGlobalTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropGlobalTempView(viewName) } } @@ -438,7 +440,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + val cascade = !sessionCatalog.isTemporaryTable(tableIdent) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName), cascade) } /** @@ -490,7 +494,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // cached version and make the new version cached lazily. if (isCached(table)) { // Uncache the logicalPlan. - sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery(table, cascade = true, blocking = true) // Cache it again. 
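// Illustrative sketch of the non-cascading invalidation that the `cascade` flag above enables
// (it mirrors the SPARK-24596 tests added later in this patch; `testData` is an assumed
// pre-registered temp view, as in those tests).
spark.sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1")
spark.sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1")

spark.sql("UNCACHE TABLE t1")
spark.catalog.isCached("t1")   // false: t1 itself is uncached
spark.catalog.isCached("t2")   // true: caches built on top of t1 are left intact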
sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ae93965bc50ed..ef8dc3a325a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -270,6 +270,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * per file *
   * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
   * that should be used for parsing.</li>
+   * <li>`dropFieldIfAllNull` (default `false`): whether to ignore a column of all null values or
+   * empty arrays/structs during schema inference.</li>
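// Illustrative streaming-read sketch for the `dropFieldIfAllNull` option documented above
// (the path is a placeholder; the option only matters while the JSON schema is being
// inferred, e.g. with spark.sql.streaming.schemaInference enabled).
val parsed = spark.readStream
  .format("json")
  .option("dropFieldIfAllNull", "true")
  .load("/path/to/json/dir")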
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index effc1471e8e12..926c0b69a03fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,15 +21,16 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.annotation.{InterfaceStability, Since} +import org.apache.spark.api.java.function.VoidFunction2 +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.execution.streaming.sources._ +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -249,7 +250,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) (s, r) case _ => - val s = new MemorySink(df.schema, outputMode) + val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) (s, r) } @@ -269,7 +270,22 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) + val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) + df.sparkSession.sessionState.streamingQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + extraOptions.toMap, + sink, + outputMode, + useTempCheckpointLocation = true, + trigger = trigger) + } else if (source == "foreachBatch") { + assertNotPartitioned("foreachBatch") + if (trigger.isInstanceOf[ContinuousTrigger]) { + throw new AnalysisException("'foreachBatch' is not supported with continuous trigger") + } + val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -307,49 +323,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * Starts the execution of the streaming query, which will continually send results to the given - * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data - * generated by the `DataFrame`/`Dataset` to an external system. 
- * - * Scala example: - * {{{ - * datasetOfString.writeStream.foreach(new ForeachWriter[String] { - * - * def open(partitionId: Long, version: Long): Boolean = { - * // open connection - * } - * - * def process(record: String) = { - * // write string to connection - * } - * - * def close(errorOrNull: Throwable): Unit = { - * // close the connection - * } - * }).start() - * }}} - * - * Java example: - * {{{ - * datasetOfString.writeStream().foreach(new ForeachWriter() { - * - * @Override - * public boolean open(long partitionId, long version) { - * // open connection - * } - * - * @Override - * public void process(String value) { - * // write string to connection - * } - * - * @Override - * public void close(Throwable errorOrNull) { - * // close the connection - * } - * }).start(); - * }}} - * + * Sets the output of the streaming query to be processed using the provided writer object. + * See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and + * semantics. * @since 2.0.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { @@ -362,6 +338,45 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { this } + /** + * :: Experimental :: + * + * (Scala-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only in micro-batch execution modes (that is, when the + * trigger is not continuous). The provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used to deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to be exactly the same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { + this.source = "foreachBatch" + if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") + this.foreachBatchWriter = function + this + } + + /** + * :: Experimental :: + * + * (Java-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only in micro-batch execution modes (that is, when the + * trigger is not continuous). The provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used to deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to be exactly the same for the same batchId (assuming all operations are deterministic in the query).
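// Illustrative usage sketch for foreachBatch described above; `streamingDF` is an assumed
// streaming DataFrame and the JDBC url/table are placeholders, not part of this patch.
val query = streamingDF.writeStream
  .foreachBatch { (batchDF: DataFrame, batchId: Long) =>
    // batchId can be used to deduplicate or to write transactionally per micro-batch.
    batchDF.write
      .format("jdbc")
      .option("url", "jdbc:postgresql://host/db")   // placeholder
      .option("dbtable", "events")                  // placeholder
      .mode("append")
      .save()
  }
  .start()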
+ * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = { + foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) + } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => cols.map(normalize(_, "Partition")) } @@ -398,5 +413,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var foreachWriter: ForeachWriter[T] = null + private var foreachBatchWriter: (Dataset[T], Long) => Unit = null + private var partitioningColumns: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 97da2b1325f58..25bb05212d66f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} @@ -32,6 +33,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -55,6 +57,19 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo @GuardedBy("awaitTerminationLock") private var lastTerminatedQuery: StreamingQuery = null + try { + sparkSession.sparkContext.conf.get(STREAMING_QUERY_LISTENERS).foreach { classNames => + Utils.loadExtensions(classOf[StreamingQueryListener], classNames, + sparkSession.sparkContext.conf).foreach(listener => { + addListener(listener) + logInfo(s"Registered listener ${listener.getClass.getName}") + }) + } + } catch { + case e: Exception => + throw new SparkException("Exception when registering StreamingQueryListener", e) + } + /** * Returns a list of active queries associated with this SQLContext * diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index c132cab1b38cf..2c695fc58fd8c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -34,6 +34,7 @@ import org.junit.*; import org.junit.rules.ExpectedException; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; @@ -336,6 +337,23 @@ public void testTupleEncoder() { Assert.assertEquals(data5, ds5.collectAsList()); } + @Test + public void testTupleEncoderSchema() { + Encoder>> encoder = + Encoders.tuple(Encoders.STRING(), Encoders.tuple(Encoders.STRING(), Encoders.STRING())); + List>> data = Arrays.asList(tuple2("1", tuple2("a", "b")), + tuple2("2", tuple2("c", "d"))); + Dataset ds1 = spark.createDataset(data, encoder).toDF("value1", "value2"); 
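// Illustrative sketch of a listener that the STREAMING_QUERY_LISTENERS handling above can
// auto-register. The conf key is assumed to be "spark.sql.streaming.streamingQueryListeners",
// and Utils.loadExtensions needs a no-arg (or SparkConf-taking) constructor on the class.
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

class AuditListener extends StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit =
    println(s"query started: ${event.id}")
  override def onQueryProgress(event: QueryProgressEvent): Unit =
    println(s"batch ${event.progress.batchId} of query ${event.progress.id}")
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
    println(s"query terminated: ${event.id}")
}
// e.g. spark-submit --conf spark.sql.streaming.streamingQueryListeners=com.example.AuditListener ...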
+ + JavaPairRDD> pairRDD = jsc.parallelizePairs(data); + Dataset ds2 = spark.createDataset(JavaPairRDD.toRDD(pairRDD), encoder) + .toDF("value1", "value2"); + + Assert.assertEquals(ds1.schema(), ds2.schema()); + Assert.assertEquals(ds1.select(expr("value2._1")).collectAsList(), + ds2.select(expr("value2._1")).collectAsList()); + } + @Test public void testNestedTupleEncoder() { // test ((int, string), string) diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index fea069eac4d48..dc15d13cd1dd3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -31,3 +31,7 @@ CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable; -- Clean up DROP VIEW IF EXISTS jsonTable; + +-- from_json - complex types +select from_json('{"a":1, "b":2}', 'map'); +select from_json('{"a":1, "b":"2"}', 'struct'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql new file mode 100644 index 0000000000000..99729c007b104 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql @@ -0,0 +1,11 @@ +SELECT array_join(array(true, false), ', '); +SELECT array_join(array(2Y, 1Y), ', '); +SELECT array_join(array(2S, 1S), ', '); +SELECT array_join(array(2, 1), ', '); +SELECT array_join(array(2L, 1L), ', '); +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', '); +SELECT array_join(array(2.0D, 1.0D), ', '); +SELECT array_join(array(float(2.0), float(1.0)), ', '); +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', '); +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', '); +SELECT array_join(array('a', 'b'), ', '); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 14a69128ffb41..2b3288dc5a137 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 28 -- !query 0 @@ -258,3 +258,19 @@ DROP VIEW IF EXISTS jsonTable struct<> -- !query 25 output + + +-- !query 26 +select from_json('{"a":1, "b":2}', 'map') +-- !query 26 schema +struct> +-- !query 26 output +{"a":1,"b":2} + + +-- !query 27 +select from_json('{"a":1, "b":"2"}', 'struct') +-- !query 27 schema +struct> +-- !query 27 output +{"a":1,"b":"2"} diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out new file mode 100644 index 0000000000000..b23a62dacef7c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out @@ -0,0 +1,90 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +SELECT array_join(array(true, false), ', ') +-- !query 0 schema +struct +-- !query 0 output +true, false + + +-- !query 1 +SELECT array_join(array(2Y, 1Y), ', ') +-- !query 1 schema +struct +-- !query 1 output +2, 1 + + +-- !query 2 +SELECT 
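// Illustrative sketch of the DataType.fromDDL fallback exercised by the from_json tests above:
// map types, not only structs, can now be passed as a DDL-formatted schema string
// (assumes an active SparkSession with spark.implicits._ imported).
import org.apache.spark.sql.functions.from_json

val in = Seq("""{"a": 1, "b": 2}""").toDF("value")
in.select(from_json($"value", "map<string, int>", Map.empty[String, String]))
// yields a map<string,int> column: Map(a -> 1, b -> 2)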
array_join(array(2S, 1S), ', ') +-- !query 2 schema +struct +-- !query 2 output +2, 1 + + +-- !query 3 +SELECT array_join(array(2, 1), ', ') +-- !query 3 schema +struct +-- !query 3 output +2, 1 + + +-- !query 4 +SELECT array_join(array(2L, 1L), ', ') +-- !query 4 schema +struct +-- !query 4 output +2, 1 + + +-- !query 5 +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ') +-- !query 5 schema +struct +-- !query 5 output +9223372036854775809, 9223372036854775808 + + +-- !query 6 +SELECT array_join(array(2.0D, 1.0D), ', ') +-- !query 6 schema +struct +-- !query 6 output +2.0, 1.0 + + +-- !query 7 +SELECT array_join(array(float(2.0), float(1.0)), ', ') +-- !query 7 schema +struct +-- !query 7 output +2.0, 1.0 + + +-- !query 8 +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ') +-- !query 8 schema +struct +-- !query 8 output +2016-03-14, 2016-03-13 + + +-- !query 9 +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', ') +-- !query 9 schema +struct +-- !query 9 output +2016-11-15 20:54:00, 2016-11-12 20:54:00 + + +-- !query 10 +SELECT array_join(array('a', 'b'), ', ') +-- !query 10 schema +struct +-- !query 10 output +a, b diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 81b7e18773f81..60c73df88896b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -83,25 +82,6 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext }.sum } - test("withColumn doesn't invalidate cached dataframe") { - var evalCount = 0 - val myUDF = udf((x: String) => { evalCount += 1; "result" }) - val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) - df.cache() - - df.collect() - assert(evalCount === 1) - - df.collect() - assert(evalCount === 1) - - val df2 = df.withColumn("newColumn", lit(1)) - df2.collect() - - // We should not reevaluate the cached dataframe - assert(evalCount === 1) - } - test("cache temp table") { withTempView("tempTable") { testData.select('key).createOrReplaceTempView("tempTable") @@ -820,4 +800,69 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } assert(cachedData.collect === Seq(1001)) } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t1") + assert(!spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + 
assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop persistent view") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withView("t1") { + withTempView("t2") { + sql("CREATE VIEW t1 AS SELECT * FROM t WHERE key > 1") + + sql("CACHE TABLE t1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(!spark.catalog.isCached("t2")) + } + } + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache table") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withTempView("t1", "t2") { + sql("CACHE TABLE t") + sql("CACHE TABLE t1 AS SELECT * FROM t WHERE key > 1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t")) + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t") + assert(!spark.catalog.isCached("t")) + assert(!spark.catalog.isCached("t1")) + assert(!spark.catalog.isCached("t2")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 59119bbbd8a2c..5d6a6c0832c96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -62,6 +62,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(row.getMap[Int, String](0) === Map(2 -> "a")) } + test("map with arrays") { + val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v") + val expectedType = MapType(IntegerType, StringType, valueContainsNull = true) + val row = df1.select(map_from_arrays($"k", $"v")).first() + assert(row.schema(0).dataType === expectedType) + assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b")) + checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b")))) + + val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v") + checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b")))) + + val df3 = Seq((null, null)).toDF("k", "v") + checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null))) + + val df4 = Seq((1, "a")).toDF("k", "v") + intercept[AnalysisException] { + df4.select(map_from_arrays($"k", $"v")) + } + + val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") + intercept[RuntimeException] { + df5.select(map_from_arrays($"k", $"v")).collect + } + + val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") + intercept[RuntimeException] { + df6.select(map_from_arrays($"k", $"v")).collect + } + } + test("struct with column name") { val df = Seq((1, "str")).toDF("a", "b") val row = df.select(struct("a", "b")).first() @@ -479,6 +509,64 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("dataframe arrays_zip function") { + val df1 = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2") + val df2 = Seq((Seq("a", "b"), Seq(true, false), Seq(10, 11))).toDF("val1", "val2", "val3") + val df3 = Seq((Seq("a", "b"), Seq(4, 5, 6))).toDF("val1", "val2") + val df4 = Seq((Seq("a", "b", null), Seq(4L))).toDF("val1", "val2") + val df5 = 
Seq((Seq(-1), Seq(null), Seq(), Seq(null, null))).toDF("val1", "val2", "val3", "val4") + val df6 = Seq((Seq(192.toByte, 256.toByte), Seq(1.1), Seq(), Seq(null, null))) + .toDF("v1", "v2", "v3", "v4") + val df7 = Seq((Seq(Seq(1, 2, 3), Seq(4, 5)), Seq(1.1, 2.2))).toDF("v1", "v2") + val df8 = Seq((Seq(Array[Byte](1.toByte, 5.toByte)), Seq(null))).toDF("v1", "v2") + + val expectedValue1 = Row(Seq(Row(9001, 4), Row(9002, 5), Row(9003, 6))) + checkAnswer(df1.select(arrays_zip($"val1", $"val2")), expectedValue1) + checkAnswer(df1.selectExpr("arrays_zip(val1, val2)"), expectedValue1) + + val expectedValue2 = Row(Seq(Row("a", true, 10), Row("b", false, 11))) + checkAnswer(df2.select(arrays_zip($"val1", $"val2", $"val3")), expectedValue2) + checkAnswer(df2.selectExpr("arrays_zip(val1, val2, val3)"), expectedValue2) + + val expectedValue3 = Row(Seq(Row("a", 4), Row("b", 5), Row(null, 6))) + checkAnswer(df3.select(arrays_zip($"val1", $"val2")), expectedValue3) + checkAnswer(df3.selectExpr("arrays_zip(val1, val2)"), expectedValue3) + + val expectedValue4 = Row(Seq(Row("a", 4L), Row("b", null), Row(null, null))) + checkAnswer(df4.select(arrays_zip($"val1", $"val2")), expectedValue4) + checkAnswer(df4.selectExpr("arrays_zip(val1, val2)"), expectedValue4) + + val expectedValue5 = Row(Seq(Row(-1, null, null, null), Row(null, null, null, null))) + checkAnswer(df5.select(arrays_zip($"val1", $"val2", $"val3", $"val4")), expectedValue5) + checkAnswer(df5.selectExpr("arrays_zip(val1, val2, val3, val4)"), expectedValue5) + + val expectedValue6 = Row(Seq( + Row(192.toByte, 1.1, null, null), Row(256.toByte, null, null, null))) + checkAnswer(df6.select(arrays_zip($"v1", $"v2", $"v3", $"v4")), expectedValue6) + checkAnswer(df6.selectExpr("arrays_zip(v1, v2, v3, v4)"), expectedValue6) + + val expectedValue7 = Row(Seq( + Row(Seq(1, 2, 3), 1.1), Row(Seq(4, 5), 2.2))) + checkAnswer(df7.select(arrays_zip($"v1", $"v2")), expectedValue7) + checkAnswer(df7.selectExpr("arrays_zip(v1, v2)"), expectedValue7) + + val expectedValue8 = Row(Seq( + Row(Array[Byte](1.toByte, 5.toByte), null))) + checkAnswer(df8.select(arrays_zip($"v1", $"v2")), expectedValue8) + checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8) + } + + test("SPARK-24633: arrays_zip splits input processing correctly") { + Seq("true", "false").foreach { wholestageCodegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholestageCodegenEnabled) { + val df = spark.range(1) + val exprs = (0 to 5).map(x => array($"id" + lit(x))) + checkAnswer(df.select(arrays_zip(exprs: _*)), + Row(Seq(Row(0, 1, 2, 3, 4, 5)))) + } + } + } + test("map size function") { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), @@ -556,11 +644,61 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) } + test("map_from_entries function") { + def dummyFilter(c: Column): Column = c.isNull || c.isNotNull + val oneRowDF = Seq(3215).toDF("i") + + // Test cases with primitive-type keys and values + val idf = Seq( + Seq((1, 10), (2, 20), (3, 10)), + Seq((1, 10), null, (2, 20)), + Seq.empty, + null + ).toDF("a") + val iExpected = Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 10)), + Row(null), + Row(Map.empty), + Row(null)) + + checkAnswer(idf.select(map_from_entries('a)), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) + checkAnswer(idf.filter(dummyFilter('a)).select(map_from_entries('a)), iExpected) + checkAnswer( + 
oneRowDF.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)) + .selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + + // Test cases with non-primitive-type keys and values + val sdf = Seq( + Seq(("a", "aa"), ("b", "bb"), ("c", "aa")), + Seq(("a", "aa"), null, ("b", "bb")), + Seq(("a", null), ("b", null)), + Seq.empty, + null + ).toDF("a") + val sExpected = Seq( + Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")), + Row(null), + Row(Map("a" -> null, "b" -> null)), + Row(Map.empty), + Row(null)) + + checkAnswer(sdf.select(map_from_entries('a)), sExpected) + checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('a)).select(map_from_entries('a)), sExpected) + } + test("array contains function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") // Simple test cases checkAnswer( @@ -571,6 +709,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) + checkAnswer( + df.select(array_contains(df("a"), df("c"))), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, c)"), + Seq(Row(true), Row(false)) + ) // In hive, this errors because null has no type information intercept[AnalysisException] { @@ -659,6 +805,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.selectExpr("array_join(x, delimiter, 'NULL')"), Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + + val idf = Seq(Seq(1, 2, 3)).toDF("x") + + checkAnswer( + idf.select(array_join(idf("x"), ", ")), + Seq(Row("1, 2, 3")) + ) + checkAnswer( + idf.selectExpr("array_join(x, ', ')"), + Seq(Row("1, 2, 3")) + ) + intercept[AnalysisException] { + idf.selectExpr("array_join(x, 1)") + } + intercept[AnalysisException] { + idf.selectExpr("array_join(x, ', ', 1)") + } } test("array_min function") { @@ -785,9 +948,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array position function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") checkAnswer( df.select(array_position(df("a"), 1)), @@ -797,7 +960,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(a, 1)"), Seq(Row(1L), Row(0L)) ) - + checkAnswer( + df.selectExpr("array_position(a, c)"), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.select(array_position(df("a"), df("c"))), + Seq(Row(1L), Row(0L)) + ) checkAnswer( df.select(array_position(df("a"), null)), Seq(Row(null), Row(null)) @@ -824,10 +994,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("element_at function") { val df = Seq( - (Seq[String]("1", "2", "3")), - (Seq[String](null, "")), - (Seq[String]()) - ).toDF("a") + (Seq[String]("1", "2", "3"), 1), + (Seq[String](null, ""), -1), + (Seq[String](), 2) + ).toDF("a", "b") intercept[Exception] { checkAnswer( @@ -845,6 +1015,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(element_at(df("a"), 4)), Seq(Row(null), Row(null), Row(null)) ) + checkAnswer( + df.select(element_at(df("a"), df("b"))), + Seq(Row("1"), Row(""), Row(null)) + ) + checkAnswer( + 
df.selectExpr("element_at(a, b)"), + Seq(Row("1"), Row(""), Row(null)) + ) checkAnswer( df.select(element_at(df("a"), 1)), @@ -1112,10 +1290,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array remove") { val df = Seq( - (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")), - (Array.empty[Int], Array.empty[String], Array.empty[String]), - (null, null, null) - ).toDF("a", "b", "c") + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2) + ).toDF("a", "b", "c", "d") checkAnswer( df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), Seq( @@ -1124,6 +1302,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null, null, null)) ) + checkAnswer( + df.select(array_remove($"a", $"d")), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + + checkAnswer( + df.selectExpr("array_remove(a, d)"), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + checkAnswer( df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", "array_remove(c, \"\")"), @@ -1139,6 +1333,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } + test("array_distinct functions") { + val df = Seq( + (Array[Int](2, 1, 3, 4, 3, 5), Array("b", "c", "a", "c", "b", "", "")), + (Array.empty[Int], Array.empty[String]), + (null, null) + ).toDF("a", "b") + checkAnswer( + df.select(array_distinct($"a"), array_distinct($"b")), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_distinct(a)", "array_distinct(b)"), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 3ea398aad7375..97a843978f0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import java.sql.{Date, Timestamp} - -import scala.collection.mutable +import org.scalatest.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} @@ -27,7 +25,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval /** * Window function testing for DataFrame API. 
@@ -624,4 +621,41 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24575: Window functions inside WHERE and HAVING clauses") { + def checkAnalysisError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses")) + } + + checkAnalysisError(testData2.select('a).where(rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError(testData2.where('b === 2 && rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(avg('b).as("avgb")) + .where('a > 'avgb && rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where(rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where('sumb === 5 && rank().over(Window.orderBy('a)) === 1)) + + checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError( + sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql( + s"""SELECT a, MAX(b) + |FROM testData2 + |GROUP BY a + |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index e0561ee2797a5..5c6a021d5b767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -17,14 +17,28 @@ package org.apache.spark.sql +import org.scalatest.concurrent.TimeLimits +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.StorageLevel -class DatasetCacheSuite extends QueryTest with SharedSQLContext { +class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits { import testImplicits._ + /** + * Asserts that a cached [[Dataset]] will be built using the given number of other cached results. 
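// Illustrative sketch of the rewrite that the SPARK-24575 analysis error above points users
// toward: compute the window function in a projection first, then filter on its result
// (`testData2` as in the suite above).
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank

val ranked = testData2.withColumn("rnk", rank().over(Window.orderBy($"b")))
ranked.where($"rnk" === 1).select($"a")   // allowed; rank() directly inside WHERE is not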
+ */ + private def assertCacheDependency(df: DataFrame, numOfCachesDependedUpon: Int = 1): Unit = { + val plan = df.queryExecution.withCachedData + assert(plan.isInstanceOf[InMemoryRelation]) + val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan + assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).size == numOfCachesDependedUpon) + } + test("get storage level") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -96,4 +110,100 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { agged.unpersist() assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.") } + + test("persist and then withColumn") { + val df = Seq(("test", 1)).toDF("s", "i") + val df2 = df.withColumn("newColumn", lit(1)) + + df.cache() + assertCached(df) + assertCached(df2) + + df.count() + assertCached(df2) + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("cache UDF result correctly") { + val expensiveUDF = udf({x: Int => Thread.sleep(5000); x}) + val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + assertCached(df2) + + // udf has been evaluated during caching, and thus should not be re-evaluated here + failAfter(3 seconds) { + df2.collect() + } + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("SPARK-24613 Cache with UDF could not be matched with subsequent dependent caches") { + val udf1 = udf({x: Int => x + 1}) + val df = spark.range(0, 10).toDF("a").withColumn("b", udf1($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + df2.cache() + + assertCacheDependency(df2) + } + + test("SPARK-24596 Non-cascading Cache Invalidation") { + val df = Seq(("a", 1), ("b", 2)).toDF("s", "i") + val df2 = df.filter('i > 1) + val df3 = df.filter('i < 2) + + df2.cache() + df.cache() + df.count() + df3.cache() + + df.unpersist() + + // df un-cached; df2 and df3's cache plan re-compiled + assert(df.storageLevel == StorageLevel.NONE) + assertCacheDependency(df2, 0) + assertCacheDependency(df3, 0) + } + + test("SPARK-24596 Non-cascading Cache Invalidation - verify cached data reuse") { + val expensiveUDF = udf({ x: Int => Thread.sleep(5000); x }) + val df = spark.range(0, 5).toDF("a") + val df1 = df.withColumn("b", expensiveUDF($"a")) + val df2 = df1.groupBy('a).agg(sum('b)) + val df3 = df.agg(sum('a)) + + df1.cache() + df2.cache() + df2.collect() + df3.cache() + + assertCacheDependency(df2) + + df1.unpersist(blocking = true) + + // df1 un-cached; df2's cache plan re-compiled + assert(df1.storageLevel == StorageLevel.NONE) + assertCacheDependency(df1.groupBy('a).agg(sum('b)), 0) + + val df4 = df1.groupBy('a).agg(sum('b)).agg(sum("sum(b)")) + assertCached(df4) + // reuse loaded cache + failAfter(3 seconds) { + checkDataset(df4, Row(10)) + } + + val df5 = df.agg(sum('a)).filter($"sum(a)" > 1) + assertCached(df5) + // first time use, load cache + checkDataset(df5, Row(10)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d477d78dc14e3..2d20c50584c03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1466,6 +1466,27 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() 
intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + + test("SPARK-24548: Dataset with tuple encoders should have correct schema") { + val encoder = Encoders.tuple(newStringEncoder, + Encoders.tuple(newStringEncoder, newStringEncoder)) + + val data = Seq(("a", ("1", "2")), ("b", ("3", "4"))) + val rdd = sparkContext.parallelize(data) + + val ds1 = spark.createDataset(rdd) + val ds2 = spark.createDataset(rdd)(encoder) + assert(ds1.schema == ds2.schema) + checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*) + } + + test("SPARK-24571: filtering of string values by char literal") { + val df = Seq("Amsterdam", "San Francisco", "X").toDF("city") + checkAnswer(df.where('city === 'X'), Seq(Row("X"))) + checkAnswer( + df.where($"city".contains(new java.lang.Character('A'))), + Seq(Row("Amsterdam"))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala deleted file mode 100644 index c6dd7dadc9d93..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala +++ /dev/null @@ -1,243 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.io.File - -import scala.util.{Random, Try} - -import org.apache.spark.SparkConf -import org.apache.spark.sql.functions.monotonically_increasing_id -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - - -/** - * Benchmark to measure read performance with Filter pushdown. 
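// Illustrative note on the SPARK-24571 test above: a Scala Char (or java.lang.Character)
// can now be compared against a string column and is treated as a one-character string.
val cities = Seq("Amsterdam", "San Francisco", "X").toDF("city")
cities.where($"city" === 'X')                                    // matches the single row "X"
cities.where($"city".contains(new java.lang.Character('A')))     // matches "Amsterdam"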
- */ -object FilterPushdownBenchmark { - val conf = new SparkConf() - conf.set("orc.compression", "snappy") - conf.set("spark.sql.parquet.compression.codec", "snappy") - - private val spark = SparkSession.builder() - .master("local[1]") - .appName("FilterPushdownBenchmark") - .config(conf) - .getOrCreate() - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { - import spark.implicits._ - val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") - val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) - .withColumn("id", monotonically_increasing_id()) - - val dirORC = dir.getCanonicalPath + "/orc" - val dirParquet = dir.getCanonicalPath + "/parquet" - - df.write.mode("overwrite").orc(dirORC) - df.write.mode("overwrite").parquet(dirParquet) - - spark.read.orc(dirORC).createOrReplaceTempView("orcTable") - spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") - } - - def filterPushDownBenchmark( - values: Int, - title: String, - whereExpr: String, - selectExpr: String = "*"): Unit = { - val benchmark = new Benchmark(title, values, minNumIters = 5) - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() - } - } - } - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() - } - } - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz - - Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X - Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X - Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X - Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X - - Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X - Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X - Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X - Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X - - Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - 
----------------------------------------------------------------------------------------------- - Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X - Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X - Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X - Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X - - Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X - Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X - Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X - - Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X - Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X - Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X - Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X - - Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X - Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X - Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X - - Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X - Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X - Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X - Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X - - Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X - Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X - Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X - Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X - - Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X - Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X - Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X - Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X - - Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X - Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X - Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X - Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X - - Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X - Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X - Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X - Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 
767.9 1.2X - - Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X - Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X - Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X - Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X - */ - benchmark.run() - } - - def main(args: Array[String]): Unit = { - val numRows = 1024 * 1024 * 15 - val width = 5 - val mid = numRows / 2 - - withTempPath { dir => - withTempTable("orcTable", "patquetTable") { - prepareTable(dir, numRows, width) - - Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => - val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - Seq( - s"id = $mid", - s"id <=> $mid", - s"$mid <= id AND id <= $mid", - s"${mid - 1} < id AND id < ${mid + 1}" - ).foreach { whereExpr => - val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") - - Seq(10, 50, 90).foreach { percent => - filterPushDownBenchmark( - numRows, - s"Select $percent% rows (id < ${numRows * percent / 100})", - s"id < ${numRows * percent / 100}", - selectExpr - ) - } - - Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => - filterPushDownBenchmark( - numRows, - s"Select all rows ($whereExpr)", - whereExpr, - selectExpr) - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8fa747465cb1a..44767dfc92497 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -882,4 +882,15 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil) } } + + test("SPARK-24495: Join may return wrong result when having duplicated equal-join keys") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.range(0, 100, 1, 2) + val df2 = spark.range(100).select($"id".as("b1"), (- $"id").as("b2")) + val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id") + checkAnswer(res, Row(0, 0, 0)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 055e1fc5640f3..7bf17cbcd9c97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -354,8 +354,8 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { test("SPARK-24027: from_json - map>") { val in = Seq("""{"a": {"b": 1}}""").toDS() - val schema = MapType(StringType, MapType(StringType, IntegerType)) - val out = in.select(from_json($"value", schema)) + val schema = "map>" + val out = in.select(from_json($"value", schema, Map.empty[String, String])) checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 
98a50fbd52b4d..d254345e8fa54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.{execution, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -679,6 +679,90 @@ class PlannerSuite extends SharedSQLContext { } assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) } + + test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") { + val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: exprA :: Nil, 5)) + val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + val smjExec = SortMergeJoinExec( + exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) => + assert(leftKeys == Seq(exprA, exprA)) + assert(rightKeys == Seq(exprB, exprC)) + case _ => fail() + } + } + + test("SPARK-24500: create union with stream of children") { + val df = Union(Stream( + Range(1, 1, 1, 1), + Range(1, 2, 1, 1))) + df.queryExecution.executedPlan.execute() + } + + test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " + + "and InMemoryTableScanExec") { + def checkOutputPartitioningRewrite( + plans: Seq[SparkPlan], + expectedPartitioningClass: Class[_]): Unit = { + assert(plans.size == 1) + val plan = plans.head + val partitioning = plan.outputPartitioning + assert(partitioning.getClass == expectedPartitioningClass) + val partitionedAttrs = partitioning.asInstanceOf[Expression].references + assert(partitionedAttrs.subsetOf(plan.outputSet)) + } + + def checkReusedExchangeOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val reusedExchange = df.queryExecution.executedPlan.collect { + case r: ReusedExchangeExec => r + } + checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass) + } + + def checkInMemoryTableScanOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val inMemoryScan = df.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass) + } + + // ReusedExchange is HashPartitioning + val df1 = Seq(1 -> "a").toDF("i", 
"j").repartition($"i") + val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning]) + + // ReusedExchange is RangePartitioning + val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning]) + + // InMemoryTableScan is HashPartitioning + Seq(1 -> "a").toDF("i", "j").repartition($"i").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning]) + + // InMemoryTableScan is RangePartitioning + spark.range(1, 100, 1, 10).toDF().persist() + checkInMemoryTableScanOutputPartitioningRewrite( + spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning]) + + // InMemoryTableScan is PartitioningCollection + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"), + classOf[PartitioningCollection]) + } + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala new file mode 100644 index 0000000000000..6d7c7de9a856e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure read performance with Filter pushdown. 
+ * To run this: + * spark-submit --class + */ +object FilterPushdownBenchmark { + val conf = new SparkConf() + .setAppName("FilterPushdownBenchmark") + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") + .setIfMissing("orc.compression", "snappy") + .setIfMissing("spark.sql.parquet.compression.codec", "snappy") + + private val spark = SparkSession.builder().config(conf).getOrCreate() + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable( + dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val valueCol = if (useStringForValue) { + monotonically_increasing_id().cast("string") + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("value", valueCol) + .sort("value") + + saveAsOrcTable(df, dir.getCanonicalPath + "/orc") + saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + } + + private def prepareStringDictTable( + dir: File, numRows: Int, numDistinctValues: Int, width: Int): Unit = { + val selectExpr = (0 to width).map { + case 0 => s"CAST(id % $numDistinctValues AS STRING) AS value" + case i => s"CAST(rand() AS STRING) c$i" + } + val df = spark.range(numRows).selectExpr(selectExpr: _*).sort("value") + + saveAsOrcTable(df, dir.getCanonicalPath + "/orc") + saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + } + + private def saveAsOrcTable(df: DataFrame, dir: String): Unit = { + // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) + df.write.mode("overwrite").option("orc.dictionary.key.threshold", 1.0).orc(dir) + spark.read.orc(dir).createOrReplaceTempView("orcTable") + } + + private def saveAsParquetTable(df: DataFrame, dir: String): Unit = { + df.write.mode("overwrite").parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> 
s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9201 / 9300 1.7 585.0 1.0X + Parquet Vectorized (Pushdown) 89 / 105 176.3 5.7 103.1X + Native ORC Vectorized 8886 / 8898 1.8 564.9 1.0X + Native ORC Vectorized (Pushdown) 110 / 128 143.4 7.0 83.9X + + + Select 0 string row + ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9336 / 9357 1.7 593.6 1.0X + Parquet Vectorized (Pushdown) 927 / 937 17.0 58.9 10.1X + Native ORC Vectorized 9026 / 9041 1.7 573.9 1.0X + Native ORC Vectorized (Pushdown) 257 / 272 61.1 16.4 36.3X + + + Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9209 / 9223 1.7 585.5 1.0X + Parquet Vectorized (Pushdown) 908 / 925 17.3 57.7 10.1X + Native ORC Vectorized 8878 / 8904 1.8 564.4 1.0X + Native ORC Vectorized (Pushdown) 248 / 261 63.4 15.8 37.1X + + + Select 1 string row + (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9194 / 9216 1.7 584.5 1.0X + Parquet Vectorized (Pushdown) 899 / 908 17.5 57.2 10.2X + Native ORC Vectorized 8934 / 8962 1.8 568.0 1.0X + Native ORC Vectorized (Pushdown) 249 / 254 63.3 15.8 37.0X + + + Select 1 string row + ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9332 / 9351 1.7 593.3 1.0X + Parquet Vectorized (Pushdown) 915 / 934 17.2 58.2 10.2X + Native ORC Vectorized 9049 / 9057 1.7 575.3 1.0X + Native ORC Vectorized (Pushdown) 248 / 258 63.5 15.8 37.7X + + + Select all string rows + (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 20478 / 20497 0.8 1301.9 1.0X + Parquet Vectorized (Pushdown) 20461 / 20550 0.8 1300.9 1.0X + Native ORC Vectorized 27464 / 27482 0.6 1746.1 0.7X + Native ORC Vectorized (Pushdown) 27454 / 27488 0.6 1745.5 0.7X + + + Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8489 / 8519 1.9 539.7 1.0X + Parquet Vectorized (Pushdown) 64 / 69 246.1 4.1 132.8X + Native ORC Vectorized 8064 / 8099 2.0 512.7 1.1X + Native ORC Vectorized (Pushdown) 88 / 94 178.6 5.6 96.4X + + + Select 0 int row + (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8494 / 8514 1.9 540.0 1.0X + Parquet Vectorized (Pushdown) 835 / 840 18.8 53.1 10.2X + Native ORC Vectorized 8090 / 8106 1.9 514.4 1.0X + Native ORC Vectorized 
(Pushdown) 249 / 257 63.2 15.8 34.1X + + + Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8552 / 8560 1.8 543.7 1.0X + Parquet Vectorized (Pushdown) 837 / 841 18.8 53.2 10.2X + Native ORC Vectorized 8178 / 8188 1.9 519.9 1.0X + Native ORC Vectorized (Pushdown) 249 / 258 63.2 15.8 34.4X + + + Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8562 / 8580 1.8 544.3 1.0X + Parquet Vectorized (Pushdown) 833 / 836 18.9 53.0 10.3X + Native ORC Vectorized 8164 / 8185 1.9 519.0 1.0X + Native ORC Vectorized (Pushdown) 245 / 254 64.3 15.6 35.0X + + + Select 1 int row + (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8540 / 8555 1.8 542.9 1.0X + Parquet Vectorized (Pushdown) 837 / 839 18.8 53.2 10.2X + Native ORC Vectorized 8182 / 8231 1.9 520.2 1.0X + Native ORC Vectorized (Pushdown) 250 / 259 62.9 15.9 34.1X + + + Select 1 int row + (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8535 / 8555 1.8 542.6 1.0X + Parquet Vectorized (Pushdown) 835 / 841 18.8 53.1 10.2X + Native ORC Vectorized 8159 / 8179 1.9 518.8 1.0X + Native ORC Vectorized (Pushdown) 244 / 250 64.5 15.5 35.0X + + + Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9609 / 9634 1.6 610.9 1.0X + Parquet Vectorized (Pushdown) 2663 / 2672 5.9 169.3 3.6X + Native ORC Vectorized 9824 / 9850 1.6 624.6 1.0X + Native ORC Vectorized (Pushdown) 2717 / 2722 5.8 172.7 3.5X + + + Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 13592 / 13613 1.2 864.2 1.0X + Parquet Vectorized (Pushdown) 9720 / 9738 1.6 618.0 1.4X + Native ORC Vectorized 16366 / 16397 1.0 1040.5 0.8X + Native ORC Vectorized (Pushdown) 12437 / 12459 1.3 790.7 1.1X + + + Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 17580 / 17617 0.9 1117.7 1.0X + Parquet Vectorized (Pushdown) 16803 / 16827 0.9 1068.3 1.0X + Native ORC Vectorized 24169 / 24187 0.7 1536.6 0.7X + Native ORC Vectorized (Pushdown) 22147 / 22341 0.7 1408.1 0.8X + + + Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 18461 / 18491 0.9 1173.7 1.0X + Parquet Vectorized (Pushdown) 18466 / 18530 0.9 1174.1 1.0X + Native ORC Vectorized 24231 / 24270 0.6 1540.6 0.8X + Native ORC Vectorized (Pushdown) 24207 / 24304 0.6 1539.0 0.8X + + + Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + 
------------------------------------------------------------------------------------------------ + Parquet Vectorized 18414 / 18453 0.9 1170.7 1.0X + Parquet Vectorized (Pushdown) 18435 / 18464 0.9 1172.1 1.0X + Native ORC Vectorized 24430 / 24454 0.6 1553.2 0.8X + Native ORC Vectorized (Pushdown) 24410 / 24465 0.6 1552.0 0.8X + + + Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 18446 / 18457 0.9 1172.8 1.0X + Parquet Vectorized (Pushdown) 18428 / 18440 0.9 1171.6 1.0X + Native ORC Vectorized 24414 / 24450 0.6 1552.2 0.8X + Native ORC Vectorized (Pushdown) 24385 / 24472 0.6 1550.4 0.8X + + + Select 0 distinct string row + (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8322 / 8352 1.9 529.1 1.0X + Parquet Vectorized (Pushdown) 53 / 57 296.3 3.4 156.7X + Native ORC Vectorized 7903 / 7953 2.0 502.4 1.1X + Native ORC Vectorized (Pushdown) 80 / 82 197.2 5.1 104.3X + + + Select 0 distinct string row + ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8712 / 8743 1.8 553.9 1.0X + Parquet Vectorized (Pushdown) 995 / 1030 15.8 63.3 8.8X + Native ORC Vectorized 8345 / 8362 1.9 530.6 1.0X + Native ORC Vectorized (Pushdown) 84 / 87 187.6 5.3 103.9X + + + Select 1 distinct string row + (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8574 / 8610 1.8 545.1 1.0X + Parquet Vectorized (Pushdown) 1127 / 1135 14.0 71.6 7.6X + Native ORC Vectorized 8163 / 8181 1.9 519.0 1.1X + Native ORC Vectorized (Pushdown) 426 / 433 36.9 27.1 20.1X + + + Select 1 distinct string row + (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8549 / 8568 1.8 543.5 1.0X + Parquet Vectorized (Pushdown) 1124 / 1131 14.0 71.4 7.6X + Native ORC Vectorized 8163 / 8210 1.9 519.0 1.0X + Native ORC Vectorized (Pushdown) 426 / 436 36.9 27.1 20.1X + + + Select 1 distinct string row + ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8889 / 8896 1.8 565.2 1.0X + Parquet Vectorized (Pushdown) 1161 / 1168 13.6 73.8 7.7X + Native ORC Vectorized 8519 / 8554 1.8 541.6 1.0X + Native ORC Vectorized (Pushdown) 430 / 437 36.6 27.3 20.7X + + + Select all distinct string rows + (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 20433 / 20533 0.8 1299.1 1.0X + Parquet Vectorized (Pushdown) 20433 / 20456 0.8 1299.1 1.0X + Native ORC Vectorized 25435 / 25513 0.6 1617.1 0.8X + Native ORC Vectorized (Pushdown) 25435 / 25507 0.6 1617.1 0.8X + */ + + benchmark.run() + } + + private def runIntBenchmark(numRows: Int, width: Int, mid: Int): Unit = { + Seq("value IS NULL", s"$mid < value AND value < $mid").foreach { whereExpr => + val title = s"Select 0 int row 
($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = $mid", + s"value <=> $mid", + s"$mid <= value AND value <= $mid", + s"${mid - 1} < value AND value < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% int rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("value IS NOT NULL", "value > -1", "value != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all int rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + private def runStringBenchmark( + numRows: Int, width: Int, searchValue: Int, colType: String): Unit = { + Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'") + .foreach { whereExpr => + val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = '$searchValue'", + s"value <=> '$searchValue'", + s"'$searchValue' <= value AND value <= '$searchValue'" + ).foreach { whereExpr => + val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq("value IS NOT NULL").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all $colType rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + def main(args: Array[String]): Unit = { + val numRows = 1024 * 1024 * 15 + val width = 5 + + // Pushdown for many distinct value case + withTempPath { dir => + val mid = numRows / 2 + + withTempTable("orcTable", "patquetTable") { + Seq(true, false).foreach { useStringForValue => + prepareTable(dir, numRows, width, useStringForValue) + if (useStringForValue) { + runStringBenchmark(numRows, width, mid, "string") + } else { + runIntBenchmark(numRows, width, mid) + } + } + } + } + + // Pushdown for few distinct value case (use dictionary encoding) + withTempPath { dir => + val numDistinctValues = 200 + val mid = numDistinctValues / 2 + + withTempTable("orcTable", "patquetTable") { + prepareStringDictTable(dir, numRows, numDistinctValues, width) + runStringBenchmark(numRows, width, mid, "distinct string") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4b3921c61a000..897424daca0cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, FileOutputStream, StringWriter} -import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} -import java.nio.file.{Files, Paths, StandardOpenOption} +import java.io._ +import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException} +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -2262,7 +2262,7 @@ class JsonSuite 
extends QueryTest with SharedSQLContext with TestJsonData { withTempPath { path => val df = spark.createDataset(Seq(("Dog", 42))) df.write - .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .options(Map("encoding" -> encoding)) .json(path.getCanonicalPath) checkEncoding( @@ -2286,16 +2286,22 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-23723: wrong output encoding") { val encoding = "UTF-128" - val exception = intercept[UnsupportedCharsetException] { + val exception = intercept[SparkException] { withTempPath { path => val df = spark.createDataset(Seq((0))) df.write - .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .options(Map("encoding" -> encoding)) .json(path.getCanonicalPath) } } - assert(exception.getMessage == encoding) + val baos = new ByteArrayOutputStream() + val ps = new PrintStream(baos, true, "UTF-8") + exception.printStackTrace(ps) + ps.flush() + + assert(baos.toString.contains( + "java.nio.charset.UnsupportedCharsetException: UTF-128")) } test("SPARK-23723: read back json in UTF-16LE") { @@ -2316,18 +2322,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-23723: write json in UTF-16/32 with multiline off") { Seq("UTF-16", "UTF-32").foreach { encoding => withTempPath { path => - val ds = spark.createDataset(Seq( - ("a", 1), ("b", 2), ("c", 3)) - ).repartition(2) - val e = intercept[IllegalArgumentException] { - ds.write - .option("encoding", encoding) - .option("multiline", "false") - .format("json").mode("overwrite") - .save(path.getCanonicalPath) - }.getMessage - assert(e.contains( - s"$encoding encoding in the blacklist is not allowed when multiLine is disabled")) + val ds = spark.createDataset(Seq(("a", 1))).repartition(1) + ds.write + .option("encoding", encoding) + .option("multiline", false) + .json(path.getCanonicalPath) + val jsonFiles = path.listFiles().filter(_.getName.endsWith("json")) + jsonFiles.foreach { jsonFile => + val readback = Files.readAllBytes(jsonFile.toPath) + val expected = ("""{"_1":"a","_2":1}""" + "\n").getBytes(Charset.forName(encoding)) + assert(readback === expected) + } } } } @@ -2427,4 +2432,66 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), Row(badJson)) } + + test("SPARK-23772 ignore column of all null values or empty array during schema inference") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + + // primitive types + Seq( + """{"a":null, "b":1, "c":3.0}""", + """{"a":null, "b":null, "c":"string"}""", + """{"a":null, "b":null, "c":null}""") + .toDS().write.text(path) + var df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + var expectedSchema = new StructType() + .add("b", LongType).add("c", StringType) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(1, "3.0") :: Row(null, "string") :: Row(null, null) :: Nil) + + // arrays + Seq( + """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]], "e":[[], null, [[]]]}""", + """{"a":[null], "b":[null], "c":[], "d":[null, []], "e":null}""", + """{"a":null, "b":null, "c":[], "d":null, "e":[null, [], null]}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", ArrayType(LongType)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Array(2, 1)) :: Row(Array(null)) :: 
Row(null) :: Nil) + + // structs + Seq( + """{"a":{"a1": 1, "a2":"string"}, "b":{}}""", + """{"a":{"a1": 2, "a2":null}, "b":{"b1":[null]}}""", + """{"a":null, "b":null}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType) + :: Nil)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil) + } + } + + test("SPARK-24190: restrictions for JSONOptions in read") { + for (encoding <- Set("UTF-16", "UTF-32")) { + val exception = intercept[IllegalArgumentException] { + spark.read + .option("encoding", encoding) + .option("multiLine", false) + .json(testFile("test-data/utf16LE.json")) + .count() + } + assert(exception.getMessage.contains("encoding must not be included in the blacklist")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a3a3f3851e21c..8263c9c81c49e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.metric import java.io.File +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.ui.SQLAppStatusStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -504,4 +504,38 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared test("writing data out metrics with dynamic partition: parquet") { testMetricsDynamicPartition("parquet", "parquet", "t1") } + + test("writing metrics from single thread") { + val nAdds = 10 + val acc = new SQLMetric("test", -10) + assert(acc.isZero()) + acc.set(0) + for (i <- 1 to nAdds) acc.add(1) + assert(!acc.isZero()) + assert(nAdds === acc.value) + acc.reset() + assert(acc.isZero()) + } + + test("writing metrics from multiple threads") { + implicit val ec: ExecutionContextExecutor = ExecutionContext.global + val nFutures = 1000 + val nAdds = 100 + val acc = new SQLMetric("test", -10) + assert(acc.isZero() === true) + acc.set(0) + val l = for ( i <- 1 to nFutures ) yield { + Future { + for (j <- 1 to nAdds) acc.add(1) + i + } + } + for (futures <- Future.sequence(l)) { + assert(nFutures === futures.length) + assert(!acc.isZero()) + assert(nFutures * nAdds === acc.value) + acc.reset() + assert(acc.isZero()) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala new file mode 100644 index 0000000000000..07e6034770127 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer +import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils + +class PythonForeachWriterSuite extends SparkFunSuite with Eventually { + + testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b => + b.assertIteratorBlocked() + + b.add(Seq(1)) + b.assertOutput(Seq(1)) + b.assertIteratorBlocked() + + b.add(2 to 100) + b.assertOutput(1 to 100) + b.assertIteratorBlocked() + } + + testWithBuffer("UnsafeRowBuffer: iterator unblocks when all data added") { b => + b.assertIteratorBlocked() + b.add(Seq(1)) + b.assertIteratorBlocked() + + b.allAdded() + b.assertThreadTerminated() + b.assertOutput(Seq(1)) + } + + testWithBuffer( + "UnsafeRowBuffer: handles more data than memory", + memBytes = 5, + sleepPerRowReadMs = 1) { b => + + b.assertIteratorBlocked() + b.add(1 to 2000) + b.assertOutput(1 to 2000) + } + + def testWithBuffer( + name: String, + memBytes: Long = 4 << 10, + sleepPerRowReadMs: Int = 0 + )(f: BufferTester => Unit): Unit = { + + test(name) { + var tester: BufferTester = null + try { + tester = new BufferTester(memBytes, sleepPerRowReadMs) + f(tester) + } finally { + if (tester == null) tester.close() + } + } + } + + + class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) { + private val buffer = { + val mem = new TestMemoryManager(new SparkConf()) + mem.limit(memBytes) + val taskM = new TaskMemoryManager(mem, 0) + new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1) + } + private val iterator = buffer.iterator + private val outputBuffer = new ArrayBuffer[Int] + private val testTimeout = timeout(20.seconds) + private val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val thread = new Thread() { + override def run(): Unit = { + while (iterator.hasNext) { + outputBuffer.synchronized { + outputBuffer += iterator.next().getInt(0) + } + Thread.sleep(sleepPerRowReadMs) + } + } + } + thread.start() + + def add(ints: Seq[Int]): Unit = { + ints.foreach { i => buffer.add(intProj.apply(new GenericInternalRow(Array[Any](i)))) } + } + + def allAdded(): Unit = { buffer.allRowsAdded() } + + def assertOutput(expectedOutput: Seq[Int]): Unit = { + eventually(testTimeout) { + val output = outputBuffer.synchronized { outputBuffer.toArray }.toSeq + assert(output == expectedOutput) + } + } + + def assertIteratorBlocked(): Unit = { + import Thread.State._ + eventually(testTimeout) { + assert(thread.isAlive) + assert(thread.getState == TIMED_WAITING || thread.getState == WAITING) + } + } + + def 
assertThreadTerminated(): Unit = { + eventually(testTimeout) { assert(!thread.isAlive) } + } + + def close(): Unit = { + thread.interrupt() + thread.join() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 3bc36ce55d902..b2fd6ba27ebb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -36,7 +38,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -68,9 +70,35 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 1 to 9) } + test("directly add data in Append output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Append, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 5) + checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit + } + test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Update) + val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -104,7 +132,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Complete) + val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -136,6 +164,32 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 7 to 9) } + test("directly add data in Complete output mode with row limit") { + implicit val schema 
= new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Complete, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 10) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 8) + checkAnswer(sink.allData, 4 to 8) // new data should replace old data + } + test("registering as a table in Append output mode") { val input = MemoryStream[Int] @@ -211,7 +265,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) val plan = new MemoryPlan(sink) // Before adding data, check output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 9be22d94b5654..e539510e15755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.sources._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { @@ -40,7 +45,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append(), DataSourceOptions.empty()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), @@ -62,7 +67,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append()).commit( + new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()).commit( Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -70,7 +75,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append()).commit( + new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()).commit( Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), 
MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -80,4 +85,73 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } + + test("continuous writer with row limit") { + val sink = new MemorySinkV2 + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + val appendWriter = new MemoryStreamWriter(sink, OutputMode.Append(), options) + appendWriter.commit(0, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))))) + assert(sink.latestBatchId.contains(0)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) + appendWriter.commit(19, Array( + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))))) + assert(sink.latestBatchId.contains(19)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11)) + + val completeWriter = new MemoryStreamWriter(sink, OutputMode.Complete(), options) + completeWriter.commit(20, Array( + MemoryWriterCommitMessage(4, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(5, Seq(Row(33))))) + assert(sink.latestBatchId.contains(20)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) + completeWriter.commit(21, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2), Row(3))), + MemoryWriterCommitMessage(1, Seq(Row(4), Row(5), Row(6))), + MemoryWriterCommitMessage(2, Seq(Row(7), Row(8), Row(9))))) + assert(sink.latestBatchId.contains(21)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + } + + test("microbatch writer with row limit") { + val sink = new MemorySinkV2 + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + + new MemoryWriter(sink, 25, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) + assert(sink.latestBatchId.contains(25)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + new MemoryWriter(sink, 26, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), + MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) + assert(sink.latestBatchId.contains(26)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) + + new MemoryWriter(sink, 27, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), + MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) + assert(sink.latestBatchId.contains(27)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + new MemoryWriter(sink, 28, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), + MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) + 
assert(sink.latestBatchId.contains(28)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala new file mode 100644 index 0000000000000..a4233e15e4ffd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming._ + +case class KV(key: Int, value: Long) + +class ForeachBatchSinkSuite extends StreamTest { + import testImplicits._ + + test("foreachBatch with non-stateful query") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), // out = in + 2 (i.e. 
1 in query, 1 in writer) + check(in = 5, 6, 7)(out = 7, 8, 9)) + } + + test("foreachBatch with stateful query in update mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Update)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (1, 1L)), + check(in = 2, 3)(out = (0, 2L), (1, 2L))) + } + + test("foreachBatch with stateful query in complete mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Complete)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (0, 1L), (1, 1L)), + check(in = 2)(out = (0, 2L), (1, 1L))) + } + + test("foreachBatchSink does not affect metric generation") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), + checkMetrics) + } + + test("throws errors in invalid situations") { + val ds = MemoryStream[Int].toDS + val ex1 = intercept[IllegalArgumentException] { + ds.writeStream.foreachBatch(null.asInstanceOf[(Dataset[Int], Long) => Unit]).start() + } + assert(ex1.getMessage.contains("foreachBatch function cannot be null")) + val ex2 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).trigger(Trigger.Continuous("1 second")).start() + } + assert(ex2.getMessage.contains("'foreachBatch' is not supported with continuous trigger")) + val ex3 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).partitionBy("value").start() + } + assert(ex3.getMessage.contains("'foreachBatch' does not support partitioning")) + } + + // ============== Helper classes and methods ================= + + private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) { + trait Test + private case class Check(in: Seq[Int], out: Seq[T]) extends Test + private case object CheckMetrics extends Test + + private val recordedOutput = new mutable.HashMap[Long, Seq[T]] + + def testWriter( + ds: Dataset[T], + outputBatchWriter: (Dataset[T], Long) => Unit, + outputMode: OutputMode = OutputMode.Append())(tests: Test*): Unit = { + try { + var expectedBatchId = -1 + val query = ds.writeStream.outputMode(outputMode).foreachBatch(outputBatchWriter).start() + + tests.foreach { + case Check(in, out) => + expectedBatchId += 1 + memoryStream.addData(in) + query.processAllAvailable() + assert(recordedOutput.contains(expectedBatchId)) + val ds: Dataset[T] = spark.createDataset[T](recordedOutput(expectedBatchId)) + checkDataset[T](ds, out: _*) + case CheckMetrics => + assert(query.recentProgress.exists(_.numInputRows > 0)) + } + } finally { + sqlContext.streams.active.foreach(_.stop()) + } + } + + def check(in: Int*)(out: T*): Test = Check(in, out) + def checkMetrics: Test = CheckMetrics + def record(batchId: Long, ds: Dataset[T]): Unit = recordedOutput.put(batchId, ds.collect()) + implicit def conv(x: (Int, 
Long)): KV = KV(x._1, x._2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index a15a980bb92fd..52e8386f6b1fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.streaming.sources -import java.io.IOException -import java.net.InetSocketAddress +import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp @@ -33,9 +32,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -101,7 +101,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val ref = spark import ref.implicits._ @@ -130,7 +130,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val socket = spark .readStream .format("socket") @@ -216,20 +216,11 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "socket source does not support a user-specified schema")) } - test("no server up") { - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> "0") - intercept[IOException] { - batchReader = provider.createMicroBatchReader( - Optional.empty(), "", new DataSourceOptions(parameters.asJava)) - } - } - test("input row metrics") { serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val ref = spark import ref.implicits._ @@ -256,6 +247,66 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("verify ServerThread only accepts the first connection") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + 
AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + + // we are trying to connect to the server once again which should fail + try { + val socket2 = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + testStream(socket2)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + + fail("StreamingQueryException is expected!") + } catch { + case e: StreamingQueryException if e.cause.isInstanceOf[SocketException] => // pass + } + } + } + + /** + * This class tries to mimic the behavior of netcat, so that we can ensure + * TextSocketStream supports netcat, which only accepts the first connection + * and exits the process when the first connection is closed. + * + * Please refer SPARK-24466 for more details. + */ private class ServerThread extends Thread with Logging { private val serverSocketChannel = ServerSocketChannel.open() serverSocketChannel.bind(new InetSocketAddress(0)) @@ -265,36 +316,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before override def run(): Unit = { try { + val clientSocketChannel = serverSocketChannel.accept() + + // Close server socket channel immediately to mimic the behavior that + // only first connection will be made and deny any further connections + // Note that the first client socket channel will be available + serverSocketChannel.close() + + clientSocketChannel.configureBlocking(false) + clientSocketChannel.socket().setTcpNoDelay(true) + while (true) { - val clientSocketChannel = serverSocketChannel.accept() - clientSocketChannel.configureBlocking(false) - clientSocketChannel.socket().setTcpNoDelay(true) - - // Check whether remote client is closed but still send data to this closed socket. - // This happens in DataStreamReader where a source will be created to get the schema. - var remoteIsClosed = false - var cnt = 0 - while (cnt < 3 && !remoteIsClosed) { - if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) { - cnt += 1 - Thread.sleep(100) - } else { - remoteIsClosed = true - } - } - - if (remoteIsClosed) { - logInfo(s"remote client ${clientSocketChannel.socket()} is closed") - } else { - while (true) { - val line = messageQueue.take() + "\n" - clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) - } - } + val line = messageQueue.take() + "\n" + clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) } } catch { case e: InterruptedException => } finally { + // no harm to call close() again... 
serverSocketChannel.close() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index bc2aca65e803f..6ea61f02a8206 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils} import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " + + "AS SELECT 1, 1") + .executeUpdate() + conn.commit() + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -1206,4 +1212,47 @@ class JDBCSuite extends SparkFunSuite }.getMessage assert(errMsg.contains("Statement was canceled or the session timed out")) } + + test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") { + def testJdbcParitionColumn(partColName: String, expectedColumnName: String): Unit = { + val df = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PARTITION") + .option("partitionColumn", partColName) + .option("lowerBound", 1) + .option("upperBound", 4) + .option("numPartitions", 3) + .load() + + val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName) + df.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + s"$quotedPrtColName < 2 or $quotedPrtColName is null", + s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3", + s"$quotedPrtColName >= 3")) + } + } + + testJdbcParitionColumn("THEID", "THEID") + testJdbcParitionColumn("\"THEID\"", "THEID") + withSQLConf("spark.sql.caseSensitive" -> "false") { + testJdbcParitionColumn("ThEiD", "THEID") + } + testJdbcParitionColumn("THE ID", "THE ID") + + def testIncorrectJdbcPartitionColumn(partColName: String): Unit = { + val errMsg = intercept[AnalysisException] { + testJdbcParitionColumn(partColName, "THEID") + }.getMessage + assert(errMsg.contains(s"User-defined partition column $partColName not found " + + "in the JDBC relation:")) + } + + testIncorrectJdbcPartitionColumn("NoExistingColumn") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index fef01c860db6e..438d5d8176b8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -20,12 +20,36 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class SimpleInsertSource extends SchemaRelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + SimpleInsert(schema)(sqlContext.sparkSession) + } +} + +case class SimpleInsert(userSpecifiedSchema: StructType)(@transient val sparkSession: SparkSession) + extends BaseRelation with InsertableRelation { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override def schema: StructType = userSpecifiedSchema + + override def insert(input: DataFrame, overwrite: Boolean): Unit = { + input.collect + } +} + class InsertSuite extends DataSourceTest with SharedSQLContext { import testImplicits._ @@ -520,4 +544,29 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } } } + + test("SPARK-24583 Wrong schema type in InsertIntoDataSourceCommand") { + withTable("test_table") { + val schema = new StructType() + .add("i", LongType, false) + .add("s", StringType, false) + val newTable = CatalogTable( + identifier = TableIdentifier("test_table", None), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + properties = Map.empty), + schema = schema, + provider = Some(classOf[SimpleInsertSource].getName)) + + spark.sessionState.catalog.createTable(newTable, false) + + sql("INSERT INTO TABLE test_table SELECT 1, 'a'") + sql("INSERT INTO TABLE test_table SELECT 2, null") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 505a3f3465c02..e96cd4500458d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -323,21 +323,22 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23315: get output from canonicalized data source v2 related plans") { - def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + def checkCanonicalizedOutput( + df: DataFrame, logicalNumOutput: Int, physicalNumOutput: Int): Unit = { val logical = df.queryExecution.optimizedPlan.collect { case d: DataSourceV2Relation => d }.head - assert(logical.canonicalized.output.length == numOutput) + assert(logical.canonicalized.output.length == logicalNumOutput) val physical = df.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => d }.head - assert(physical.canonicalized.output.length == numOutput) + assert(physical.canonicalized.output.length == physicalNumOutput) } val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() - checkCanonicalizedOutput(df, 2) - checkCanonicalizedOutput(df.select('i), 1) + 
checkCanonicalizedOutput(df, 2, 2) + checkCanonicalizedOutput(df.select('i), 2, 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 694bb3b95b0f0..1334cf71ae988 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -209,10 +209,10 @@ class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: Serializable override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) new SimpleCSVDataWriter(fs, filePath) } @@ -245,10 +245,10 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) new InternalRowCSVDataWriter(fs, filePath) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4c3fd58cb2e45..e41b4534ed51d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -337,7 +338,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) + val sink = if (useV2Sink) new MemorySinkV2 + else new MemorySink(stream.schema, outputMode, DataSourceOptions.empty()) val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 1f62357e6d09e..c5cc8df4356a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -404,6 +404,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with 
AddData(input3, 5, 10),
       CheckNewAnswer((5, 10, 5, 15, 5, 25)))
   }
+
+  test("streaming join should require HashClusteredDistribution from children") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b)
+    val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b)
+    val joined = df1.join(df2, Seq("a", "b")).select('a)
+
+    testStream(joined)(
+      AddData(input1, 1.to(1000): _*),
+      AddData(input2, 1.to(1000): _*),
+      CheckAnswer(1.to(1000): _*))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala
new file mode 100644
index 0000000000000..1aaf8a9aa2d55
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import scala.language.reflectiveCalls
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.streaming.StreamingQueryListener._
+
+
+class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter {
+
+  import testImplicits._
+
+
+  override protected def sparkConf: SparkConf =
+    super.sparkConf.set("spark.sql.streaming.streamingQueryListeners",
+      "org.apache.spark.sql.streaming.TestListener")
+
+  test("test if the configured query listener is loaded") {
+    testStream(MemoryStream[Int].toDS)(
+      StartStream(),
+      StopStream
+    )
+
+    assert(TestListener.queryStartedEvent != null)
+    assert(TestListener.queryTerminatedEvent != null)
+  }
+
+}
+
+object TestListener {
+  @volatile var queryStartedEvent: QueryStartedEvent = null
+  @volatile var queryTerminatedEvent: QueryTerminatedEvent = null
+}
+
+class TestListener(sparkConf: SparkConf) extends StreamingQueryListener {
+
+  override def onQueryStarted(event: QueryStartedEvent): Unit = {
+    TestListener.queryStartedEvent = event
+  }
+
+  override def onQueryProgress(event: QueryProgressEvent): Unit = {}
+
+  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {
+    TestListener.queryTerminatedEvent = event
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
similarity index 65%
rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
rename to
sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 2e4d607a403ca..a8e3611b585cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -17,29 +17,14 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.unsafe.types.UTF8String -class ContinuousShuffleReadSuite extends StreamTest { - - private def unsafeRow(value: Int) = { - UnsafeProjection.create(Array(IntegerType : DataType))( - new GenericInternalRow(Array(value: Any))) - } - - private def unsafeRow(value: String) = { - UnsafeProjection.create(Array(StringType : DataType))( - new GenericInternalRow(Array(UTF8String.fromString(value): Any))) - } - - private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { - messages.foreach(endpoint.askSync[Unit](_)) - } - +class ContinuousShuffleSuite extends StreamTest { // In this unit test, we emulate that we're in the task thread where // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context // thread local to be set. @@ -58,39 +43,29 @@ class ContinuousShuffleReadSuite extends StreamTest { super.afterEach() } - test("receiver stopped with row last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(111)) - ) + private implicit def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) - } + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) } - test("receiver stopped with marker last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0) - ) + private def send(endpoint: RpcEndpointRef, messages: RPCContinuousShuffleMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) - } + private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = { + rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint } - test("one epoch") { + private def readEpoch(rdd: 
ContinuousShuffleReadRDD) = { + rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0)) + } + + test("reader - one epoch") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -105,7 +80,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) } - test("multiple epochs") { + test("reader - multiple epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -124,7 +99,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) } - test("empty epochs") { + test("reader - empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint @@ -148,7 +123,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) } - test("multiple partitions") { + test("reader - multiple partitions") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) // Send all data before processing to ensure there's no crossover. for (p <- rdd.partitions) { @@ -169,7 +144,7 @@ class ContinuousShuffleReadSuite extends StreamTest { } } - test("blocks waiting for new rows") { + test("reader - blocks waiting for new rows") { val rdd = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) val epoch = rdd.compute(rdd.partitions(0), ctx) @@ -195,7 +170,7 @@ class ContinuousShuffleReadSuite extends StreamTest { } } - test("multiple writers") { + test("reader - multiple writers") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -213,7 +188,7 @@ class ContinuousShuffleReadSuite extends StreamTest { Set("writer0-row0", "writer1-row0", "writer2-row0")) } - test("epoch only ends when all writers send markers") { + test("reader - epoch only ends when all writers send markers") { val rdd = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint @@ -233,6 +208,7 @@ class ContinuousShuffleReadSuite extends StreamTest { // After checking the right rows, block until we get an epoch marker indicating there's no next. // (Also fail the assertion if for some reason we get a row.) + val readEpochMarkerThread = new Thread { override def run(): Unit = { assert(!epoch.hasNext) @@ -251,10 +227,10 @@ class ContinuousShuffleReadSuite extends StreamTest { } // Join to pick up assertion failures. - readEpochMarkerThread.join() + readEpochMarkerThread.join(streamingTimeout.toMillis) } - test("writer epochs non aligned") { + test("reader - writer epochs non aligned") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint // We send multiple epochs for 0, then multiple for 1, then multiple for 2. 
The receiver should @@ -288,4 +264,153 @@ class ContinuousShuffleReadSuite extends StreamTest { val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) } + + test("one epoch") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + } + + test("multiple epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + writer.write(Iterator(4, 5, 6)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + assert(readEpoch(reader) == Seq(4, 5, 6)) + } + + test("empty epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator()) + writer.write(Iterator(1, 2)) + writer.write(Iterator()) + writer.write(Iterator()) + writer.write(Iterator(3, 4)) + writer.write(Iterator()) + + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(1, 2)) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(3, 4)) + assert(readEpoch(reader) == Seq()) + } + + test("blocks waiting for writer") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1)) + } + } + readRowThread.start() + + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.TIMED_WAITING) + } + + // Once we write the epoch the thread should stop waiting and succeed. + writer.write(Iterator(1)) + readRowThread.join(streamingTimeout.toMillis) + } + + test("multiple writer partitions") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(0).write(Iterator(1, 4, 7)) + writers(1).write(Iterator(2, 5)) + writers(2).write(Iterator(3, 6)) + + writers(0).write(Iterator(4, 7, 10)) + writers(1).write(Iterator(5, 8)) + writers(2).write(Iterator(6, 9)) + + // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. + // The epochs should be deterministically preserved, however. 
+ assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet) + assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet) + } + + test("reader epoch only ends when all writer partitions write it") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(1).write(Iterator()) + writers(2).write(Iterator()) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!readerEpoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + writers(0).write(Iterator()) + readEpochMarkerThread.join(streamingTimeout.toMillis) + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 130e258e78ca2..8620f3f6d99fb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -342,7 +342,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { - conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000L } override def loadPartition( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6f904c937348d..514921875f1f9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. 
- val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.1") protected var spark: SparkSession = _ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 9d1b82a6341b1..25e71258b9369 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -49,7 +49,7 @@ private[spark] class StreamingTab(val ssc: StreamingContext) def detach() { getSparkUI(ssc).detachTab(this) - getSparkUI(ssc).removeStaticHandler("/static/streaming") + getSparkUI(ssc).detachHandler("/static/streaming") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index ab7c8558321c8..2e8599026ea1d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -222,7 +222,7 @@ private[streaming] class FileBasedWriteAheadLog( pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, _) } currentLogWriterStartTime = currentTime - currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) + currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000L) val newLogPath = new Path(logDirectory, timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) currentLogPath = Some(newLogPath.toString)
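Note on the two "* 1000L" changes above (HiveShim.scala and FileBasedWriteAheadLog.scala): getIntVar returns an Int and rollingIntervalSecs is an Int, so multiplying by the Int literal 1000 is evaluated in 32-bit arithmetic and can overflow before the product is assigned to a Long; the 1000L literal forces the multiplication itself to happen in Long arithmetic. A minimal, standalone Scala sketch of the difference (illustrative only, not part of the patch; the object name and example value are hypothetical):

    // Why "* 1000L" matters when turning an Int number of seconds into Long milliseconds.
    object SecondsToMillisSketch {
      def main(args: Array[String]): Unit = {
        val delaySeconds: Int = 3000000               // hypothetical value, roughly 34 days in seconds
        val overflowed: Long = delaySeconds * 1000    // Int * Int overflows first, then widens: -1294967296
        val correct: Long = delaySeconds * 1000L      // Int * Long is computed as Long: 3000000000
        println(s"overflowed = $overflowed, correct = $correct")
      }
    }

Making either operand a Long is enough to promote the whole expression, which is why only the literal needs to change.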